# result_validate_0.py

import argparse
import datetime
import os

import pandas as pd

from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
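# Behavior of the data_loader helpers, as inferred from their call sites below
# (the actual signatures live in data_loader; treat these notes as assumptions):
# - mongo_con_parse() -> (client, db): opens a MongoDB connection.
# - validate_one_line(db, city_pair, flight_day, flight_number_1,
#   flight_number_2, baggage, begin_hour) -> DataFrame of crawled price rows
#   for one flight, starting at begin_hour.
# - fill_hourly_crawl_date(df, rear_fill=2) -> DataFrame padded onto an hourly
#   grid, with an 'is_filled' column marking the padded rows.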
def _validate_predict_df(df_predict):
    client, db = mongo_con_parse()
    count = 0
    for idx, row in df_predict.iterrows():
        city_pair = row['city_pair']
        flight_day = row['flight_day']
        flight_number_1 = row['flight_number_1']
        flight_number_2 = row['flight_number_2']
        baggage = row['baggage']
        valid_begin_hour = row['valid_begin_hour']
        valid_begin_dt = pd.to_datetime(valid_begin_hour, format='%Y-%m-%d %H:%M:%S')
        # valid_end_hour = row['valid_end_hour']
        # valid_end_dt = pd.to_datetime(valid_end_hour, format='%Y-%m-%d %H:%M:%S')
        update_hour = row['update_hour']
        update_dt = pd.to_datetime(update_hour, format='%Y-%m-%d %H:%M:%S')
        valid_begin_hour_modify = max(
            valid_begin_dt,
            update_dt
        ).strftime('%Y-%m-%d %H:%M:%S')
        df_val = validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour_modify)
        # At validation time, the database may contain no data after valid_begin_hour.
        if not df_val.empty:
            df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
            df_val_f = df_val_f[df_val_f['is_filled'] == 0]  # keep only the raw rows, drop the padded ones
            # df_val_f = df_val_f[df_val_f['update_hour'] <= valid_end_dt]
            if df_val_f.empty:
                drop_flag = 0
                # first_drop_amount = pd.NA
                first_drop_price = pd.NA
                first_drop_hours_until_departure = pd.NA
                first_drop_update_hour = pd.NA
                last_hours_util = pd.NA
                last_update_hour = pd.NA
                list_change_price = []
                list_change_hours = []
            else:
                # Last row of the valid data
                last_row = df_val_f.iloc[-1]
                last_hours_util = last_row['hours_until_departure']
                last_update_hour = last_row['update_hour']
                # Keep only the rows where the price changed
                df_price_changes = df_val_f.loc[
                    df_val_f["adult_total_price"].shift() != df_val_f["adult_total_price"]
                ].copy()
                # Magnitude of each price change
                df_price_changes['change_amount'] = df_price_changes['adult_total_price'].diff().fillna(0)
                # Find the first row whose change_amount is below -5
                first_negative_change = df_price_changes[df_price_changes['change_amount'] < -5].head(1)
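                # A minimal sketch of the change filter above, on hypothetical
                # prices (illustration only, not data from this pipeline):
                #   prices [100, 100, 90, 90, 120] -> shift() [NaN, 100, 100, 90, 90]
                #   shift() != price keeps rows 0, 2, 4 -> prices [100, 90, 120]
                #   diff().fillna(0) on the kept rows -> [0, -10, 30]
                # so the first change_amount < -5 would be the drop to 90.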
                # Extract the values we need
                if not first_negative_change.empty:
                    drop_flag = 1
                    # first_drop_amount = first_negative_change['change_amount'].iloc[0].round(2)
                    first_drop_price = first_negative_change['adult_total_price'].iloc[0].round(2)
                    first_drop_hours_until_departure = first_negative_change['hours_until_departure'].iloc[0]
                    first_drop_update_hour = first_negative_change['update_hour'].iloc[0]
                else:
                    drop_flag = 0
                    # first_drop_amount = pd.NA
                    first_drop_price = pd.NA
                    first_drop_hours_until_departure = pd.NA
                    first_drop_update_hour = pd.NA
                list_change_price = df_price_changes['adult_total_price'].tolist()
                list_change_hours = df_price_changes['hours_until_departure'].tolist()
        else:
            drop_flag = 0
            # first_drop_amount = pd.NA
            first_drop_price = pd.NA
            first_drop_hours_until_departure = pd.NA
            first_drop_update_hour = pd.NA
            last_hours_util = pd.NA
            last_update_hour = pd.NA
            list_change_price = []
            list_change_hours = []
        safe_sep = "; "
        df_predict.at[idx, 'change_prices'] = safe_sep.join(map(str, list_change_price))
        df_predict.at[idx, 'change_hours'] = safe_sep.join(map(str, list_change_hours))
        df_predict.at[idx, 'last_hours_util'] = last_hours_util
        df_predict.at[idx, 'last_update_hour'] = last_update_hour
        # df_predict.at[idx, 'first_drop_amount'] = first_drop_amount * -1  # flip negative to positive
        df_predict.at[idx, 'first_drop_price'] = first_drop_price
        df_predict.at[idx, 'first_drop_hours_until_departure'] = first_drop_hours_until_departure
        df_predict.at[idx, 'first_drop_update_hour'] = first_drop_update_hour
        df_predict.at[idx, 'drop_flag'] = drop_flag
        count += 1
        if count % 5 == 0:
            print(f"cal count: {count}")
    print("calculation finished")
    client.close()
    return df_predict
def validate_process(node, interval_hours, pred_time_str):
    '''Manual validation entry point.'''
    date = pred_time_str[4:8]  # MMDD part of YYYYMMDDHHMM
    output_dir = f"./validate/{node}_{date}"
    os.makedirs(output_dir, exist_ok=True)
    object_dir = "./predictions_0"
    csv_file = f'future_predictions_{pred_time_str}.csv'
    csv_path = os.path.join(object_dir, csv_file)
    try:
        df_predict = pd.read_csv(csv_path)
    except Exception as e:
        print(f"read {csv_path} error: {str(e)}")
        df_predict = pd.DataFrame()
    if df_predict.empty:
        print("prediction data is empty")
        return
    df_predict = _validate_predict_df(df_predict)
    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_validate_{node}_{pred_time_str}_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)
    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"saved: {output_path}")
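# Example manual run, mirroring the commented-out invocation in __main__ below
# (node name and batch time are placeholders):
#   validate_process("node0127", 0, "202601301500")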
def validate_process_auto(node, interval_hours):
    '''Automatic validation entry point.'''
    # Current time, truncated to the hour
    current_time = datetime.datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"validation time: {current_time_str}, (truncated): {hourly_time_str}")
    output_dir = f"./validate/{node}"
    os.makedirs(output_dir, exist_ok=True)
    object_dir = "./predictions_0"
    # Check that the predictions directory exists
    if not os.path.exists(object_dir):
        print(f"directory does not exist: {object_dir}")
        return
    # Collect all CSV files whose names start with future_predictions_
    csv_files = []
    for file in os.listdir(object_dir):
        if file.startswith("future_predictions_") and file.endswith(".csv"):
            csv_files.append(file)
    if not csv_files:
        print(f"no future_predictions_ CSV files found in {object_dir}")
        return
    # Extract the timestamps and convert them to datetime objects
    file_times = []
    for file in csv_files:
        # Extract the timestamp part: future_predictions_202601151600.csv -> 202601151600
        timestamp_str = file.replace("future_predictions_", "").replace(".csv", "")
        try:
            file_time = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M")
            file_times.append((file, file_time))
        except ValueError as e:
            print(f"bad timestamp format in file {file}: {e}")
            continue
    if not file_times:
        print("no files with valid timestamps found")
        return
    # Target file to validate (current hour minus 56 hours: 48 + (12 - 4) = 56)
    target_time = hourly_time - datetime.timedelta(hours=56)
    target_time_str = target_time.strftime("%Y%m%d%H%M")
    print(f"target validation time: {target_time_str}")
    valid_files = [(f, t) for f, t in file_times if t == target_time]
    if not valid_files:
        print(f"no file found for target time {target_time_str}")
        return
    valid_file, valid_time = valid_files[0]
    valid_time_str = valid_time.strftime("%Y%m%d%H%M")
    print(f"matching file found: {valid_file} (time: {valid_time_str})")
    csv_path = os.path.join(object_dir, valid_file)
    # Run the validation
    try:
        df_predict = pd.read_csv(csv_path)
    except Exception as e:
        print(f"read {csv_path} error: {str(e)}")
        df_predict = pd.DataFrame()
    if df_predict.empty:
        print("prediction data is empty")
        return
    df_predict = _validate_predict_df(df_predict)
    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_validate_{node}_{valid_time_str}_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)
    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"saved: {output_path}")
    print(f"validation finished: {node} {valid_time_str}")
    print()
def validate_process_zong(node, enable_min_batch_flag=False, min_batch_time_str=None):
    '''Aggregated ("zong") validation over all prediction batches.'''
    object_dir = "./predictions_0"
    output_dir = f"./validate/{node}_zong"
    os.makedirs(output_dir, exist_ok=True)
    # Check that the predictions directory exists
    if not os.path.exists(object_dir):
        print(f"directory does not exist: {object_dir}")
        return
    # Collect all CSV files whose names start with future_predictions_
    csv_files = []
    for file in os.listdir(object_dir):
        if file.startswith("future_predictions_") and file.endswith(".csv"):
            csv_files.append(file)
    if not csv_files:
        print(f"no future_predictions_ CSV files found in {object_dir}")
        return
    csv_files.sort()
    list_df_will_drop = []
    min_batch_dt = None
    if enable_min_batch_flag:
        if not min_batch_time_str:
            print("enable_min_batch_flag=True but min_batch_time_str not provided, exiting")
            return
        min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
        min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
    # Walk through all prediction files
    for csv_file in csv_files:
        batch_time_str = (
            csv_file.replace("future_predictions_", "").replace(".csv", "")
        )
        batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
        batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
        # Skip batches earlier than min_batch_dt
        if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
            continue
        csv_path = os.path.join(object_dir, csv_file)
        try:
            df_predict = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df_predict = pd.DataFrame()
        if df_predict.empty:
            print(f"prediction data is empty: {csv_file}")
            continue
        if "will_price_drop" not in df_predict.columns:
            print(f"missing will_price_drop column, skipping: {csv_file}")
            continue
        df_predict_will_drop = df_predict[df_predict["will_price_drop"] == 1].copy()
        if df_predict_will_drop.empty:
            continue
        # df_predict_will_drop["batch_file"] = csv_file
        df_predict_will_drop["batch_time"] = batch_time_str
        list_df_will_drop.append(df_predict_will_drop)  # keep each batch's will_drop rows
        del df_predict
    if not list_df_will_drop:
        print("will_drop is empty for every batch")
        return
    # === 1. Concatenate the will_drop results of all batches ===
    df_predict_will_drop_all = pd.concat(list_df_will_drop, ignore_index=True)
    # Free the temporary list (worthwhile when it is large)
    del list_df_will_drop
    before_rows = len(df_predict_will_drop_all)
    # Grouping keys that uniquely identify a flight
    group_keys = ["city_pair", "flight_number_1", "flight_number_2", "flight_day"]
    # === 2. Convert batch_time to datetime for the gap checks below ===
    df_predict_will_drop_all["batch_dt"] = pd.to_datetime(
        df_predict_will_drop_all["batch_time"],
        format="%Y%m%d%H%M",
        errors="coerce",  # invalid timestamps become NaT
    )
    # === 3. Infer the "normal" batch_time step, in minutes ===
    diff_minutes = (
        df_predict_will_drop_all["batch_dt"]
        .dropna()
        .sort_values()
        .drop_duplicates()
        .diff()
        .dt.total_seconds()
        .div(60)
        .dropna()
    )
    # Use the most frequent gap as the expected step; default to 60 minutes
    expected_step_minutes = (
        int(diff_minutes.value_counts().idxmax()) if not diff_minutes.empty else 60
    )
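    # Sketch of the step inference on hypothetical batch times (illustration only):
    #   10:00, 11:00, 12:00, 15:00 -> gaps of 60, 60, 180 minutes
    #   value_counts().idxmax() picks 60, so expected_step_minutes = 60.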
    # === 4. Sort by flight + batch time to prepare the continuity check ===
    df_predict_will_drop_all.sort_values(
        by=group_keys + ["batch_dt"],
        inplace=True,
        ignore_index=True,
        na_position="last",
    )
    # === 5. Compute the gap between adjacent batch_dt values within each group ===
    df_predict_will_drop_all["prev_batch_dt"] = df_predict_will_drop_all.groupby(group_keys)[
        "batch_dt"
    ].shift(1)
    df_predict_will_drop_all["gap_minutes"] = (
        (df_predict_will_drop_all["batch_dt"] - df_predict_will_drop_all["prev_batch_dt"])
        .dt.total_seconds()
        .div(60)
    )
    # === 6. Flag the start of each new contiguous segment ===
    # A row starts a new segment when:
    # 1) prev_batch_dt is missing (first row of its group)
    # 2) batch_dt is missing (uncommon)
    # 3) the gap to the previous row differs from the expected step
    df_predict_will_drop_all["is_new_segment"] = (
        df_predict_will_drop_all["prev_batch_dt"].isna()
        | df_predict_will_drop_all["batch_dt"].isna()
        | (df_predict_will_drop_all["gap_minutes"] != expected_step_minutes)
    )
    # === 7. Assign segment ids (segment_id) ===
    # Within one flight, the id increments at every new segment
    df_predict_will_drop_all["segment_id"] = df_predict_will_drop_all.groupby(group_keys)[
        "is_new_segment"
    ].cumsum()
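    # Sketch of the segment numbering on one hypothetical flight (illustration only):
    #   is_new_segment [True, False, False, True, False]
    #   cumsum()       [1,    1,     1,     2,    2    ]
    # i.e. rows 0-2 form segment 1 and rows 3-4 form segment 2.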
    # === 8. Compute each contiguous segment's tail hours_until_departure ===
    df_segment_last = df_predict_will_drop_all.groupby(
        group_keys + ["segment_id"], as_index=False
    ).agg(last_hours_until_departure=("hours_until_departure", "last"))
    # === 9. Keep only the first record of each segment, then attach the tail info ===
    df_predict_will_drop_filter = df_predict_will_drop_all.drop_duplicates(
        subset=group_keys + ["segment_id"], keep="first"
    ).merge(
        df_segment_last,
        on=group_keys + ["segment_id"],
        how="left",
    )
    # === 10. Drop the intermediate helper columns ===
    df_predict_will_drop_filter = (
        df_predict_will_drop_filter.drop(
            columns=[
                "batch_dt",
                "prev_batch_dt",
                "gap_minutes",
                "is_new_segment",
                "segment_id",
            ]
        )
        .reset_index(drop=True)
    )
    # === 11. Reorder columns (put last_hours_until_departure right before price_change_percent) ===
    if "price_change_percent" in df_predict_will_drop_filter.columns:
        cols = df_predict_will_drop_filter.columns.tolist()
        if "last_hours_until_departure" in cols:
            cols.remove("last_hours_until_departure")
            cols.insert(cols.index("price_change_percent"), "last_hours_until_departure")
            df_predict_will_drop_filter = df_predict_will_drop_filter[cols]
    after_rows = len(df_predict_will_drop_filter)
    print(
        f"will_drop contiguous-segment filter done (step={expected_step_minutes}min): {before_rows} -> {after_rows}"
    )
    # Current time, truncated to the hour
    current_time = datetime.datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
    before_rows = len(df_predict_will_drop_filter)
    df_predict_will_drop_filter["valid_end_dt"] = pd.to_datetime(
        df_predict_will_drop_filter["valid_end_hour"],
        errors="coerce",
    )
    df_predict_will_drop_filter_1 = df_predict_will_drop_filter[
        (df_predict_will_drop_filter["valid_end_dt"] + pd.Timedelta(hours=8))
        <= hourly_time
    ].copy()
    df_predict_will_drop_filter_1.drop(columns=["valid_end_dt"], inplace=True)
    after_rows = len(df_predict_will_drop_filter_1)
    print(
        f"valid_end_hour (+8h) filter done: {before_rows} -> {after_rows} (hourly_time={hourly_time_str})"
    )
    # Run the validation
    df_predict_will_drop_validate = _validate_predict_df(df_predict_will_drop_filter_1)
    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_validate_{node}_zong_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)
    df_predict_will_drop_validate.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"saved: {output_path}")
    print(f"validation finished: {node} zong")
    print()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='validation script')
    # The original choices=[1] with default=0 rejected an explicit --interval 0;
    # allow both values so the manual mode can be requested explicitly.
    parser.add_argument('--interval', type=int, choices=[0, 1],
                        default=0, help='interval in hours (0 = manual, 1 = auto)')
    args = parser.parse_args()
    interval_hours = args.interval
    # 0: manual validation
    if interval_hours == 0:
        # node, pred_time_str = "node0127", "202601301500"
        # validate_process(node, interval_hours, pred_time_str)
        node = "node0127"
        validate_process_zong(node)  # unconditional aggregation
        # node = "node0203"
        # validate_process_zong(node, True, "202602031100")  # conditional aggregation
    # 1: automatic validation
    else:
        node = "node0127"
        validate_process_auto(node, interval_hours)
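# Typical invocations (the hourly schedule is an assumption, not part of this script):
#   python result_validate_0.py                # manual / aggregated validation
#   python result_validate_0.py --interval 1   # automatic validation, e.g. from an hourly cron job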