| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- import argparse
- import datetime
- import os
- import pandas as pd
def build_predict_board(
    node,
    enable_min_max_batch_flag=False,
    min_batch_time_str=None,
    max_batch_time_str=None,
    object_dir="./predictions_0",
    output_dir="./predictions",
):
    """Merge per-batch prediction CSVs into a single "prediction board" CSV.

    Files in ``object_dir`` named ``future_predictions_<YYYYmmddHHMM>.csv`` are
    read in name order. Each row is tagged with the batch time taken from its
    file name, all rows are concatenated, de-duplicated on the flight key
    columns plus ``batch_time`` (last occurrence wins), sorted by the same
    columns, and written to
    ``<output_dir>/prediction_board_<node>_<YYYYmmddHHMMSS>.csv``.

    Args:
        node: Identifier embedded in the output file name.
        enable_min_max_batch_flag: When true, only batches whose time
            (truncated to the hour) lies within the inclusive
            [min, max] range are kept. At least one bound must be given.
        min_batch_time_str: Lower bound in ``%Y%m%d%H%M`` format, or None.
        max_batch_time_str: Upper bound in ``%Y%m%d%H%M`` format, or None.
        object_dir: Directory holding the per-batch CSV files.
        output_dir: Directory the merged CSV is written to (created if
            missing, even when nothing ends up being written).

    Returns:
        The path of the written CSV, or None when nothing was written
        (missing directory, no matching files, invalid range, or no usable
        batch).

    Raises:
        ValueError: If a supplied bound does not match ``%Y%m%d%H%M``.
    """
    time_format = "%Y%m%d%H%M"
    prefix = "future_predictions_"
    suffix = ".csv"
    group_keys = ["city_pair", "flight_day", "flight_number_1", "flight_number_2"]

    def _floor_to_hour(dt):
        # Range filtering is done at hour granularity.
        return dt.replace(minute=0, second=0, microsecond=0)

    os.makedirs(output_dir, exist_ok=True)
    # isdir (not exists): a regular file at this path would make listdir fail.
    if not os.path.isdir(object_dir):
        print(f"目录不存在: {object_dir}")
        return None

    csv_files = sorted(
        name
        for name in os.listdir(object_dir)
        if name.startswith(prefix) and name.endswith(suffix)
    )
    if not csv_files:
        print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
        return None

    min_batch_dt = None
    max_batch_dt = None
    if enable_min_max_batch_flag:
        if not min_batch_time_str and not max_batch_time_str:
            print("enable_min_max_batch_flag=True 但未提供 min_batch_time_str/max_batch_time_str,退出")
            return None
        if min_batch_time_str:
            min_batch_dt = _floor_to_hour(
                datetime.datetime.strptime(min_batch_time_str, time_format)
            )
        if max_batch_time_str:
            max_batch_dt = _floor_to_hour(
                datetime.datetime.strptime(max_batch_time_str, time_format)
            )
        if (
            min_batch_dt is not None
            and max_batch_dt is not None
            and min_batch_dt > max_batch_dt
        ):
            print(f"时间范围非法: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}),退出")
            return None

    frames = []
    for csv_file in csv_files:
        # Strip exactly one leading prefix and one trailing suffix; the rest
        # of the name is the batch timestamp.
        batch_time_str = csv_file[len(prefix):-len(suffix)]
        try:
            batch_dt = datetime.datetime.strptime(batch_time_str, time_format)
        except ValueError:
            print(f"批次时间解析失败,跳过: {csv_file}")
            continue

        batch_hour_dt = _floor_to_hour(batch_dt)
        if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
            continue
        if max_batch_dt is not None and batch_hour_dt > max_batch_dt:
            continue

        csv_path = os.path.join(object_dir, csv_file)
        try:
            df_predict = pd.read_csv(csv_path)
        except Exception as e:
            # Best effort: one unreadable/corrupt batch must not abort the run.
            print(f"read {csv_path} error: {e}")
            continue
        if df_predict.empty:
            continue

        missing = [c for c in group_keys if c not in df_predict.columns]
        if missing:
            print(f"缺少关键字段 {missing},跳过: {csv_file}")
            continue
        if "will_price_drop" not in df_predict.columns:
            print(f"缺少 will_price_drop 字段,跳过: {csv_file}")
            continue

        # Drop any pre-existing batch_time first so the position computed
        # below is not shifted by its removal.
        if "batch_time" in df_predict.columns:
            df_predict.pop("batch_time")
        if "update_hour" in df_predict.columns:
            # Place batch_time immediately before update_hour.
            insert_idx = df_predict.columns.get_loc("update_hour")
            df_predict.insert(insert_idx, "batch_time", batch_time_str)
        else:
            # update_hour is only a positioning anchor; without it, keep the
            # data and append batch_time as the last column.
            print(f"缺少 update_hour 字段,batch_time 追加到末尾: {csv_file}")
            df_predict["batch_time"] = batch_time_str

        frames.append(df_predict)
        if len(frames) % 10 == 0:
            print(f"已处理 {len(frames)} 个批次文件")

    if not frames:
        print("汇总为空:没有符合条件的批次文件")
        return None

    print(f"汇总 {len(frames)} 个批次文件")
    df_all = pd.concat(frames, ignore_index=True)
    del frames  # release the per-batch frames before further processing

    sort_keys = group_keys + ["batch_time"]
    df_all = df_all.drop_duplicates(subset=sort_keys, keep="last")
    df_all = df_all.sort_values(by=sort_keys, ignore_index=True, na_position="last")

    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    output_path = os.path.join(
        output_dir, f"prediction_board_{node}_{timestamp_str}.csv"
    )
    # utf-8-sig writes a BOM (commonly needed for spreadsheet tools to
    # detect UTF-8).
    df_all.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"保存完成: {output_path} (rows={len(df_all)})")
    return output_path
if __name__ == "__main__":
    # Command-line entry point. Defaults reproduce the previously hard-coded
    # run (node0205, range filter on, 202602061000..202602091000), so
    # invoking the script without arguments makes the same call as before.
    parser = argparse.ArgumentParser(description="预测看板汇总脚本")
    parser.add_argument("--node", type=str, default="node0205", help="节点名/标识,用于输出文件名")
    # Filtering is on by default, so the switch turns it off.
    parser.add_argument("--disable_min_max_batch", action="store_true", help="关闭批次时间范围过滤")
    parser.add_argument("--min_batch_time", type=str, default="202602061000", help="最小批次时间,如 202602061000")
    parser.add_argument("--max_batch_time", type=str, default="202602091000", help="最大批次时间,如 202602091000")
    args = parser.parse_args()
    build_predict_board(
        node=args.node,
        enable_min_max_batch_flag=not args.disable_min_max_batch,
        min_batch_time_str=args.min_batch_time,
        max_batch_time_str=args.max_batch_time,
    )
|