# result_gather.py
  1. import argparse
  2. import datetime
  3. import os
  4. import pandas as pd
  5. def build_predict_board(
  6. node,
  7. enable_min_max_batch_flag=False,
  8. min_batch_time_str=None,
  9. max_batch_time_str=None,
  10. object_dir="./predictions_0",
  11. output_dir="./predictions",
  12. ):
  13. os.makedirs(output_dir, exist_ok=True)
  14. if not os.path.exists(object_dir):
  15. print(f"目录不存在: {object_dir}")
  16. return
  17. csv_files = []
  18. for file in os.listdir(object_dir):
  19. if file.startswith("future_predictions_") and file.endswith(".csv"):
  20. csv_files.append(file)
  21. if not csv_files:
  22. print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
  23. return
  24. csv_files.sort()
  25. min_batch_dt = None
  26. max_batch_dt = None
  27. if enable_min_max_batch_flag:
  28. if not min_batch_time_str and not max_batch_time_str:
  29. print("enable_min_max_batch_flag=True 但未提供 min_batch_time_str/max_batch_time_str,退出")
  30. return
  31. if min_batch_time_str:
  32. min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
  33. min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
  34. if max_batch_time_str:
  35. max_batch_dt = datetime.datetime.strptime(max_batch_time_str, "%Y%m%d%H%M")
  36. max_batch_dt = max_batch_dt.replace(minute=0, second=0, microsecond=0)
  37. if min_batch_dt is not None and max_batch_dt is not None and min_batch_dt > max_batch_dt:
  38. print(f"时间范围非法: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}),退出")
  39. return
  40. group_keys = ["city_pair", "flight_day", "flight_number_1", "flight_number_2"]
  41. list_df_all = []
  42. count = 0
  43. for csv_file in csv_files:
  44. batch_time_str = csv_file.replace("future_predictions_", "").replace(".csv", "")
  45. try:
  46. batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
  47. except Exception:
  48. print(f"批次时间解析失败,跳过: {csv_file}")
  49. continue
  50. batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
  51. if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
  52. continue
  53. if max_batch_dt is not None and batch_hour_dt > max_batch_dt:
  54. continue
  55. csv_path = os.path.join(object_dir, csv_file)
  56. try:
  57. df_predict = pd.read_csv(csv_path)
  58. except Exception as e:
  59. print(f"read {csv_path} error: {str(e)}")
  60. continue
  61. if df_predict.empty:
  62. continue
  63. missing = [c for c in group_keys if c not in df_predict.columns]
  64. if missing:
  65. print(f"缺少关键字段 {missing},跳过: {csv_file}")
  66. continue
  67. if "will_price_drop" not in df_predict.columns:
  68. print(f"缺少 will_price_drop 字段,跳过: {csv_file}")
  69. continue
  70. # df_predict = df_predict.copy()
  71. df_predict["batch_time"] = batch_time_str
  72. # 指定 update_hour 列前插入 batch_time 列
  73. insert_idx = df_predict.columns.get_loc("update_hour")
  74. col_data = df_predict.pop("batch_time")
  75. df_predict.insert(insert_idx, "batch_time", col_data)
  76. # df_predict["batch_dt"] = pd.to_datetime(batch_time_str, format="%Y%m%d%H%M", errors="coerce")
  77. # if "hours_until_departure" in df_predict.columns:
  78. # hud = pd.to_numeric(df_predict["hours_until_departure"], errors="coerce")
  79. # df_predict = df_predict.loc[hud.between(13, 60)].copy()
  80. list_df_all.append(df_predict)
  81. count += 1
  82. if count % 10 == 0:
  83. print(f"已处理 {count} 个批次文件")
  84. if not list_df_all:
  85. print("汇总为空:没有符合条件的批次文件")
  86. return
  87. print(f"汇总 {len(list_df_all)} 个批次文件")
  88. df_all = pd.concat(list_df_all, ignore_index=True)
  89. del list_df_all
  90. df_all = df_all.drop_duplicates(subset=group_keys + ["batch_time"], keep="last").reset_index(drop=True)
  91. df_all.sort_values(
  92. by=group_keys + ["batch_time"],
  93. inplace=True,
  94. ignore_index=True,
  95. na_position="last",
  96. )
  97. timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  98. save_csv = f"prediction_board_{node}_{timestamp_str}.csv"
  99. output_path = os.path.join(output_dir, save_csv)
  100. df_all.to_csv(output_path, index=False, encoding="utf-8-sig")
  101. print(f"保存完成: {output_path} (rows={len(df_all)})")
  102. if __name__ == "__main__":
  103. # parser = argparse.ArgumentParser(description="预测看板汇总脚本")
  104. # parser.add_argument("--node", type=str, default="node", help="节点名/标识,用于输出文件名")
  105. # parser.add_argument("--enable_min_max_batch", action="store_true", help="启用批次时间范围过滤")
  106. # parser.add_argument("--min_batch_time", type=str, default=None, help="最小批次时间,如 202602061000")
  107. # parser.add_argument("--max_batch_time", type=str, default=None, help="最大批次时间,如 202602091000")
  108. # args = parser.parse_args()
  109. # build_predict_board(
  110. # node=args.node,
  111. # enable_min_max_batch_flag=args.enable_min_max_batch,
  112. # min_batch_time_str=args.min_batch_time,
  113. # max_batch_time_str=args.max_batch_time,
  114. # )
  115. node = "node0205"
  116. enable_min_max_batch_flag = True
  117. min_batch_time_str = "202602061000"
  118. max_batch_time_str = "202602091000"
  119. build_predict_board(node, enable_min_max_batch_flag, min_batch_time_str, max_batch_time_str)