follow_up.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import os
  2. import datetime
  3. import pandas as pd
  4. from config import mongodb_config
  5. def follow_up_handle():
  6. '''后续处理'''
  7. object_dir = "./predictions_0"
  8. output_dir = "./keep_0"
  9. # 创建输出目录
  10. os.makedirs(output_dir, exist_ok=True)
  11. # 检查目录是否存在
  12. if not os.path.exists(object_dir):
  13. print(f"目录不存在: {object_dir}")
  14. return
  15. # 获取所有以 future_predictions_ 开头的 CSV 文件
  16. csv_files = []
  17. for file in os.listdir(object_dir):
  18. if file.startswith("future_predictions_") and file.endswith(".csv"):
  19. csv_files.append(file)
  20. if not csv_files:
  21. print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
  22. return
  23. csv_files.sort()
  24. # 调试分支
  25. # target_time = "202603011600"
  26. # matching_files = [f for f in csv_files if target_time in f]
  27. # if matching_files:
  28. # last_csv_file = matching_files[0]
  29. # print(f"指定时间的文件: {last_csv_file}")
  30. # else:
  31. # print(f"未找到时间 {target_time} 的预测文件")
  32. # return
  33. # 正式分支
  34. last_csv_file = csv_files[-1] # 只看最新预测的文件
  35. print(f"最新预测文件: {last_csv_file}")
  36. if last_csv_file.startswith("future_predictions_") and last_csv_file.endswith(".csv"):
  37. target_time = last_csv_file.replace("future_predictions_", "").replace(".csv", "")
  38. else:
  39. target_time = datetime.datetime.now().strftime("%Y%m%d%H%M")
  40. # 读取最新预测文件
  41. last_csv_path = os.path.join(object_dir, last_csv_file)
  42. df_last_predict = pd.read_csv(last_csv_path)
  43. df_last_predict_will_drop = df_last_predict[df_last_predict["will_price_drop"] == 1].reset_index(drop=True)
  44. df_last_predict_not_drop = df_last_predict[df_last_predict["will_price_drop"] == 0].reset_index(drop=True)
  45. print(f"最新预测文件中,预测降价的航班有 {len(df_last_predict_will_drop)} 条,预测不降价的航班有 {len(df_last_predict_not_drop)} 条")
  46. # 建一张 维护表 keep_info.csv 附加一个维护表快照 keep_info_{target_time}.csv
  47. keep_info_path = os.path.join(output_dir, "keep_info.csv")
  48. keep_info_snapshot_path = os.path.join(output_dir, f"keep_info_{target_time}.csv")
  49. key_cols = ["city_pair", "flight_day", "flight_number_1", "flight_number_2"]
  50. df_last_predict_will_drop = df_last_predict_will_drop.drop_duplicates(
  51. subset=key_cols, keep="last"
  52. ).reset_index(drop=True)
  53. df_last_predict_not_drop = df_last_predict_not_drop.drop_duplicates(
  54. subset=key_cols, keep="last"
  55. ).reset_index(drop=True)
  56. # 读取维护表
  57. if os.path.exists(keep_info_path):
  58. try:
  59. df_keep_info = pd.read_csv(keep_info_path)
  60. except Exception as e:
  61. print(f"读取维护表失败: {keep_info_path}, error: {str(e)}")
  62. df_keep_info = pd.DataFrame()
  63. else:
  64. df_keep_info = pd.DataFrame()
  65. # 初始化维护表
  66. if df_keep_info.empty:
  67. df_keep_info = df_last_predict_will_drop.copy()
  68. df_keep_info["keep_flag"] = 1
  69. df_keep_info.to_csv(keep_info_path, index=False, encoding="utf-8-sig")
  70. print(f"维护表已初始化: {keep_info_path} (rows={len(df_keep_info)})")
  71. df_keep_info.to_csv(keep_info_snapshot_path, index=False, encoding="utf-8-sig")
  72. print(f"维护表快照已保存: {keep_info_snapshot_path} (rows={len(df_keep_info)})")
  73. # 已存在维护表
  74. else:
  75. if "keep_flag" not in df_keep_info.columns:
  76. df_keep_info["keep_flag"] = 0
  77. df_keep_info["keep_flag"] = (
  78. pd.to_numeric(df_keep_info["keep_flag"], errors="coerce")
  79. .fillna(0)
  80. .astype(int)
  81. )
  82. missing_cols = [c for c in key_cols if c not in df_keep_info.columns]
  83. if missing_cols:
  84. print(f"维护表缺少字段: {missing_cols}, path={keep_info_path}")
  85. return
  86. for c in key_cols:
  87. df_last_predict_will_drop[c] = df_last_predict_will_drop[c].astype(str)
  88. df_last_predict_not_drop[c] = df_last_predict_not_drop[c].astype(str)
  89. df_keep_info[c] = df_keep_info[c].astype(str)
  90. df_keep_info = df_keep_info.drop_duplicates(subset=key_cols, keep="last").reset_index(drop=True)
  91. # 提取两者的标志位
  92. df_last_keys = df_last_predict_will_drop[key_cols].drop_duplicates().reset_index(drop=True)
  93. df_keep_keys = df_keep_info[key_cols].drop_duplicates().reset_index(drop=True)
  94. df_last_with_merge = df_last_predict_will_drop.merge(
  95. df_keep_keys, on=key_cols, how="left", indicator=True
  96. )
  97. # 场景一: 如果某一行数据在 df_last_predict_will_drop 出现,没有在 df_keep_info 里
  98. df_to_add = (
  99. df_last_with_merge.loc[df_last_with_merge["_merge"] == "left_only"]
  100. .drop(columns=["_merge"])
  101. .copy()
  102. )
  103. # keep_flag 设为 1
  104. if not df_to_add.empty:
  105. df_to_add["keep_flag"] = 1
  106. df_keep_with_merge = df_keep_info.reset_index().merge(
  107. df_last_keys, on=key_cols, how="left", indicator=True
  108. )
  109. # 场景二: 如果某一行数据在 df_last_predict_will_drop 和 df_keep_info 里都出现
  110. matched_idx = df_keep_with_merge.loc[df_keep_with_merge["_merge"] == "both", "index"].tolist()
  111. # 场景三: 如果某一行数据在 df_last_predict_will_drop 没有出现,却在 df_keep_info 里都出现
  112. keep_only_idx = df_keep_with_merge.loc[df_keep_with_merge["_merge"] == "left_only", "index"].tolist()
  113. # 符合场景二的索引 (在 df_keep_with_merge 中)
  114. if matched_idx:
  115. df_matched_keys = df_keep_info.loc[matched_idx, key_cols]
  116. df_latest_matched = df_matched_keys.merge(
  117. df_last_predict_will_drop, on=key_cols, how="left"
  118. )
  119. # 将 df_keep_info 的 df_matched_keys 的内容更新为 df_last_predict_will_drop 里对应的内容
  120. update_cols = [c for c in df_last_predict_will_drop.columns if c not in key_cols]
  121. for c in update_cols:
  122. if c == "keep_flag":
  123. continue
  124. if c not in df_keep_info.columns:
  125. df_keep_info[c] = pd.NA
  126. df_keep_info.loc[matched_idx, c] = df_latest_matched[c].values
  127. # 重新标记 原来是1 -> 0 原来是0 -> 0 原来是-1 -> 1
  128. old_flags = df_keep_info.loc[matched_idx, "keep_flag"]
  129. df_keep_info.loc[matched_idx, "keep_flag"] = old_flags.apply(
  130. lambda x: 0 if x in (0, 1) else (1 if x == -1 else 1)
  131. )
  132. # 符合场景三的索引 (在 df_keep_with_merge 中)
  133. if keep_only_idx:
  134. mask_keep_only = df_keep_info.index.isin(keep_only_idx) # 布尔索引序列
  135. # 如果 df_keep_info 的 keep_flag 为-1,此时标记为-2
  136. # mask_to_remove = mask_keep_only & (df_keep_info["keep_flag"] == -1)
  137. # df_keep_info.loc[mask_to_remove, "keep_flag"] = -2
  138. # 如果 df_keep_info 的 keep_flag 大于等于0
  139. mask_need_observe = mask_keep_only & (df_keep_info["keep_flag"] >= 0) # 布尔索引序列
  140. if mask_need_observe.any():
  141. if "hours_until_departure" not in df_keep_info.columns:
  142. df_keep_info.loc[mask_need_observe, "keep_flag"] = -1
  143. else:
  144. hud = pd.to_numeric(
  145. df_keep_info.loc[mask_need_observe, "hours_until_departure"],
  146. errors="coerce",
  147. )
  148. # hours_until_departure自动减1
  149. new_hud = hud - 1
  150. df_keep_info.loc[mask_need_observe, "hours_until_departure"] = new_hud
  151. df_keep_only_keys = df_keep_info.loc[mask_keep_only, key_cols].copy()
  152. df_keep_only_keys["_row_idx"] = df_keep_only_keys.index
  153. # 检查 df_keep_only_keys 是否在 df_last_predict_not_drop 中
  154. df_keep_only_keys = df_keep_only_keys.merge(
  155. df_last_predict_not_drop[key_cols].drop_duplicates(),
  156. on=key_cols,
  157. how="left",
  158. indicator=True,
  159. )
  160. idx_in_not_drop = df_keep_only_keys.loc[
  161. df_keep_only_keys["_merge"] == "both", "_row_idx"
  162. ].tolist()
  163. mask_in_not_drop = df_keep_info.index.isin(idx_in_not_drop) # 在 df_last_predict_not_drop 中出现 只是will_price_drop为0 未达边界
  164. mask_not_drop_observe = mask_need_observe & mask_in_not_drop # 判断为不降价的布尔索引数组
  165. mask_boundary_observe = mask_need_observe & ~mask_in_not_drop # 判断为到达边界的布尔索引数组
  166. df_keep_info.loc[mask_not_drop_observe, "keep_flag"] = -1 # 删除标志
  167. if mask_boundary_observe.any():
  168. new_hud_full = pd.to_numeric(
  169. df_keep_info["hours_until_departure"], errors="coerce"
  170. )
  171. df_keep_info.loc[mask_boundary_observe, "keep_flag"] = -1 # 默认删除标志
  172. df_keep_info.loc[
  173. mask_boundary_observe & new_hud_full.gt(4), "keep_flag" # 如果达到边界且hours_until_departure大于4 则给保留标志
  174. ] = 0
  175. pass
  176. # idx_eq13 = mask_need_observe.copy()
  177. # idx_eq13.loc[idx_eq13] = hud.eq(13) # 原hours_until_departure等于13
  178. # idx_gt13 = mask_need_observe.copy()
  179. # idx_gt13.loc[idx_gt13] = hud.gt(13) # 原hours_until_departure大于13
  180. # idx_other = mask_need_observe & ~(idx_eq13 | idx_gt13) # 原hours_until_departure小于13
  181. # idx_eq13_gt4 = idx_eq13 & new_hud.gt(4)
  182. # idx_eq13_eq4 = idx_eq13 & new_hud.eq(4)
  183. # # idx_eq13_lt4 = idx_eq13 & new_hud.lt(4)
  184. # df_keep_info.loc[idx_eq13_gt4, "keep_flag"] = 0
  185. # df_keep_info.loc[idx_eq13_eq4, "keep_flag"] = -1
  186. # # df_keep_info.loc[idx_eq13_lt4, "keep_flag"] = -2
  187. # df_keep_info.loc[idx_gt13, "keep_flag"] = -1
  188. # idx_other_gt4 = idx_other & new_hud.gt(4)
  189. # idx_other_eq4 = idx_other & new_hud.eq(4)
  190. # # idx_other_lt4 = idx_other & new_hud.lt(4)
  191. # df_keep_info.loc[idx_other_gt4, "keep_flag"] = 0
  192. # df_keep_info.loc[idx_other_eq4, "keep_flag"] = -1
  193. # # df_keep_info.loc[idx_other_lt4, "keep_flag"] = -2
  194. # 将 df_to_add 添加到 df_keep_info 之后
  195. add_rows = len(df_to_add) if "df_to_add" in locals() else 0
  196. if add_rows:
  197. df_keep_info = pd.concat([df_keep_info, df_to_add], ignore_index=True)
  198. df_keep_info_snapshot = df_keep_info.copy()
  199. df_keep_info_snapshot.to_csv(keep_info_snapshot_path, index=False, encoding="utf-8-sig")
  200. print(
  201. f"维护表快照已保存: {keep_info_snapshot_path} (rows={len(df_keep_info_snapshot)})"
  202. )
  203. # 移除 keep_flag 为 -1 的行
  204. before_rm = len(df_keep_info)
  205. df_keep_info = df_keep_info.loc[df_keep_info["keep_flag"] != -1].reset_index(drop=True)
  206. rm_rows = before_rm - len(df_keep_info)
  207. # 保存更新后的 df_keep_info 到维护表csv文件
  208. df_keep_info.to_csv(keep_info_path, index=False, encoding="utf-8-sig")
  209. print(
  210. f"维护表已更新: {keep_info_path} (rows={len(df_keep_info)} add={add_rows} rm={rm_rows})"
  211. )
  212. # ================================================================
  213. # for idx, row in df_last_predict_will_drop.iterrows():
  214. # city_pair = row['city_pair']
  215. # flight_day = row['flight_day']
  216. # flight_number_1 = row['flight_number_1']
  217. # flight_number_2 = row['flight_number_2']
  218. # baggage = row['baggage']
  219. # from_city_code = city_pair.split('-')[0]
  220. # to_city_code = city_pair.split('-')[1]
  221. # from_day = datetime.datetime.strptime(flight_day, '%Y-%m-%d').strftime('%Y%m%d')
  222. # baggage_str = f"1-{baggage}"
  223. # pass
  224. # adult_total_price = row['adult_total_price']
  225. # hours_until_departure = row['hours_until_departure']
  226. pass
  227. if __name__ == "__main__":
  228. follow_up_handle()