| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- import os
- import time
- from datetime import datetime, timedelta
- from config import mongo_config, uo_city_pairs_new
- from data_loader import load_data
- from data_process import preprocess_data_simple, predict_data_simple
def start_predict():
    """Run one prediction pass over every configured city pair.

    The current time is floored to the hour; that hour stamp names the
    output CSV (``future_predictions_<stamp>.csv`` under ``./predictions``),
    and any file left by an earlier run in the same hour is removed first.

    For each city pair the function loads prediction data via ``load_data``,
    keeps only rows whose ``from_time`` (floored to the hour) lies in
    ``[hour + 24h, hour + 360h)``, then hands the remaining rows to
    ``preprocess_data_simple`` and ``predict_data_simple``.  Pairs with no
    data, or with no rows left after filtering, are skipped.

    Progress is reported with ``print``; nothing is returned.
    """
    print("开始预测")

    output_dir = "./data_shards"
    predict_dir = "./predictions"
    os.makedirs(predict_dir, exist_ok=True)

    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to 1 so that min() never compares an int with None.
    cpu_cores = os.cpu_count() or 1
    max_workers = min(4, cpu_cores)  # never use more than 4 worker processes

    # Current time, plus the same instant floored to the whole hour.
    current_time = datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")

    # Clear the previous prediction result written within the same hour.
    csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
    for csv_file in csv_file_list:
        # Build the path outside the try so the handler can always name it.
        csv_path = os.path.join(predict_dir, csv_file)
        try:
            os.remove(csv_path)
        except OSError as e:
            # Best effort: a missing file (the usual case on the first run of
            # the hour) or any other OS-level failure is reported, not raised.
            print(f"remove {csv_path} info: {e}")

    # Prediction window: departures from 24 hours up to 360 hours from now.
    pred_hour_begin = hourly_time + timedelta(hours=24)
    pred_hour_end = hourly_time + timedelta(hours=360)
    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
    print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")

    # Each configured pair is a 6-character code; render it as "AAA-BBB".
    uo_city_pairs = uo_city_pairs_new.copy()
    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]

    # If a run was interrupted, look up the index where it stopped in the
    # log output and set it here to resume from that group.
    resume_idx = 0
    uo_city_pair_list = uo_city_pair_list[resume_idx:]

    # Report the index range this run will cover.
    max_len = len(uo_city_pair_list) + resume_idx
    print(f"预测阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")

    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
        print(f"第 {idx} 组 :", uo_city_pair)

        # Load the prediction data. The loader is only given day-level
        # bounds; the hour-level window is enforced below.
        start_time = time.time()
        df_test = load_data(mongo_config, uo_city_pair, pred_date_begin, pred_date_end, is_train=False,
                            use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")

        if df_test.empty:
            print("预测数据为空,跳过此批次。")
            continue

        # Filter by departure time, compared at whole-hour granularity.
        df_test['from_hour'] = df_test['from_time'].dt.floor('h')
        mask = (df_test['from_hour'] >= pred_hour_begin) & (df_test['from_hour'] < pred_hour_end)
        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)
        # Drop the temporary helper column again.
        df_test = df_test.drop(columns=['from_hour'])
        print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")

        if filtered_count == 0:
            print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
            continue

        df_test_inputs, _, _, _ = preprocess_data_simple(df_test, is_train=False, hourly_time=hourly_time)
        # The return value is not needed here.
        # NOTE(review): presumably predict_data_simple writes its results
        # under predict_dir itself — confirm against its implementation.
        predict_data_simple(df_test_inputs, uo_city_pair, output_dir, predict_dir, hourly_time_str)

        # Release the model inputs before the next group is loaded.
        del df_test_inputs

        print(f"第 {idx} 组 预测完成")
        print()
        time.sleep(1)

    print("所有批次的预测结束")
    print()
# Script entry point: run a single prediction pass when this file is executed
# directly; importing it as a module does not start a run.
if __name__ == "__main__":
    start_predict()
|