Add prediction result aggregation script (non-validation)

node04 committed 1 day ago
commit fce924884c
1 changed file with 150 additions and 0 deletions

+ 150 - 0
result_gather.py

@@ -0,0 +1,150 @@
+import argparse
+import datetime
+import os
+import pandas as pd
+
+
+def build_predict_board(
+    node,
+    enable_min_max_batch_flag=False,
+    min_batch_time_str=None,
+    max_batch_time_str=None,
+    object_dir="./predictions_0",
+    output_dir="./predictions",
+):
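+    """Aggregate per-batch future_predictions_*.csv files from object_dir into
+    a single prediction board CSV under output_dir.
+
+    Batch times are parsed from the file names (future_predictions_YYYYMMDDHHMM.csv)
+    and, when enable_min_max_batch_flag is set, filtered to the inclusive range
+    [min_batch_time_str, max_batch_time_str] at hour granularity.
+    """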
+    os.makedirs(output_dir, exist_ok=True)
+
+    if not os.path.exists(object_dir):
+        print(f"Directory does not exist: {object_dir}")
+        return
+
+    csv_files = []
+    for file in os.listdir(object_dir):
+        if file.startswith("future_predictions_") and file.endswith(".csv"):
+            csv_files.append(file)
+    
+    if not csv_files:
+        print(f"No CSV files starting with future_predictions_ found in {object_dir}")
+        return
+
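+    # File names embed the batch timestamp, so lexicographic sort is chronological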
+    csv_files.sort()
+
+    min_batch_dt = None
+    max_batch_dt = None
+    if enable_min_max_batch_flag:
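+        # Require at least one bound; bounds are truncated to the hour for comparison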
+        if not min_batch_time_str and not max_batch_time_str:
+            print("enable_min_max_batch_flag=True but min_batch_time_str/max_batch_time_str not provided, exiting")
+            return
+        if min_batch_time_str:
+            min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
+            min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
+        if max_batch_time_str:
+            max_batch_dt = datetime.datetime.strptime(max_batch_time_str, "%Y%m%d%H%M")
+            max_batch_dt = max_batch_dt.replace(minute=0, second=0, microsecond=0)
+        if min_batch_dt is not None and max_batch_dt is not None and min_batch_dt > max_batch_dt:
+            print(f"Invalid time range: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}), exiting")
+            return
+
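+    # Columns that identify a single flight prediction; combined with batch_time for deduplication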
+    group_keys = ["city_pair", "flight_day", "flight_number_1", "flight_number_2"]
+
+    list_df_all = []
+    count = 0
+    for csv_file in csv_files:
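+        # The batch timestamp is embedded in the file name: future_predictions_YYYYMMDDHHMM.csv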
+        batch_time_str = csv_file.replace("future_predictions_", "").replace(".csv", "")
+        try:
+            batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
+        except ValueError:
+            print(f"Failed to parse batch time, skipping: {csv_file}")
+            continue
+        
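+        # Compare at hour granularity to match the truncated min/max bounds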
+        batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
+        if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
+            continue
+        if max_batch_dt is not None and batch_hour_dt > max_batch_dt:
+            continue
+        
+        csv_path = os.path.join(object_dir, csv_file)
+        try:
+            df_predict = pd.read_csv(csv_path)
+        except Exception as e:
+            print(f"Failed to read {csv_path}: {e}")
+            continue
+
+        if df_predict.empty:
+            continue
+
+        missing = [c for c in group_keys if c not in df_predict.columns]
+        if missing:
+            print(f"Missing key columns {missing}, skipping: {csv_file}")
+            continue
+
+        if "will_price_drop" not in df_predict.columns:
+            print(f"Missing will_price_drop column, skipping: {csv_file}")
+            continue
+
+        df_predict["batch_time"] = batch_time_str
+
+        # Move batch_time so it sits just before the update_hour column, if present
+        if "update_hour" in df_predict.columns:
+            insert_idx = df_predict.columns.get_loc("update_hour")
+            col_data = df_predict.pop("batch_time")
+            df_predict.insert(insert_idx, "batch_time", col_data)
+
+        # df_predict["batch_dt"] = pd.to_datetime(batch_time_str, format="%Y%m%d%H%M", errors="coerce")
+
+        # if "hours_until_departure" in df_predict.columns:
+        #     hud = pd.to_numeric(df_predict["hours_until_departure"], errors="coerce")
+        #     df_predict = df_predict.loc[hud.between(13, 60)].copy()
+
+        list_df_all.append(df_predict)
+
+        count += 1
+        if count % 10 == 0:
+            print(f"Processed {count} batch files")
+
+    if not list_df_all:
+        print("Nothing to aggregate: no batch files matched the filters")
+        return
+
+    print(f"Aggregating {len(list_df_all)} batch files")
+
+    df_all = pd.concat(list_df_all, ignore_index=True)
+    del list_df_all
+
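+    # Keep only the last row per (flight key, batch_time); later rows in a file win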
+    df_all = df_all.drop_duplicates(subset=group_keys + ["batch_time"], keep="last").reset_index(drop=True)
+
+    df_all.sort_values(
+        by=group_keys + ["batch_time"],
+        inplace=True,
+        ignore_index=True,
+        na_position="last",
+    )
+
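+    # Timestamp the output file name so repeated runs do not overwrite each other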
+    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    save_csv = f"prediction_board_{node}_{timestamp_str}.csv"
+    output_path = os.path.join(output_dir, save_csv)
+    df_all.to_csv(output_path, index=False, encoding="utf-8-sig")
+    print(f"Save complete: {output_path} (rows={len(df_all)})")
+
+
+if __name__ == "__main__":
+    # parser = argparse.ArgumentParser(description="Prediction board aggregation script")
+    # parser.add_argument("--node", type=str, default="node", help="Node name/identifier, used in the output file name")
+    # parser.add_argument("--enable_min_max_batch", action="store_true", help="Enable batch-time range filtering")
+    # parser.add_argument("--min_batch_time", type=str, default=None, help="Minimum batch time, e.g. 202602061000")
+    # parser.add_argument("--max_batch_time", type=str, default=None, help="Maximum batch time, e.g. 202602091000")
+    # args = parser.parse_args()
+
+    # build_predict_board(
+    #     node=args.node,
+    #     enable_min_max_batch_flag=args.enable_min_max_batch,
+    #     min_batch_time_str=args.min_batch_time,
+    #     max_batch_time_str=args.max_batch_time,
+    # )
+
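+    # Hard-coded parameters for this run; the argparse CLI above is kept for reference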
+    node = "node0205"
+    enable_min_max_batch_flag = True
+    min_batch_time_str = "202602061000"
+    max_batch_time_str = "202602091000"
+
+    build_predict_board(node, enable_min_max_batch_flag, min_batch_time_str, max_batch_time_str)