main_pe.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from config import mongo_config, uo_city_pairs_new
  5. from data_loader import load_data
  6. from data_process import preprocess_data_simple, predict_data_simple
  7. def start_predict():
  8. print(f"开始预测")
  9. output_dir = "./data_shards"
  10. predict_dir = "./predictions"
  11. os.makedirs(predict_dir, exist_ok=True)
  12. cpu_cores = os.cpu_count() # 你的系统是72
  13. max_workers = min(4, cpu_cores) # 最大不超过4个进程
  14. # 当前时间,取整时
  15. current_time = datetime.now()
  16. current_time_str = current_time.strftime("%Y%m%d%H%M")
  17. hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
  18. hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
  19. print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")
  20. # 清空上一次(同小时内)预测结果
  21. csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
  22. for csv_file in csv_file_list:
  23. try:
  24. csv_path = os.path.join(predict_dir, csv_file)
  25. os.remove(csv_path)
  26. except Exception as e:
  27. print(f"remove {csv_path} info: {str(e)}")
  28. # 预测时间范围,满足起飞时间 在24小时后到360小时后
  29. pred_hour_begin = hourly_time + timedelta(hours=24)
  30. pred_hour_end = hourly_time + timedelta(hours=360)
  31. pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
  32. pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
  33. print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")
  34. uo_city_pairs = uo_city_pairs_new.copy()
  35. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  36. # 如果临时处理中断,从日志里找到 中断的索引 修改它
  37. resume_idx = 0
  38. uo_city_pair_list = uo_city_pair_list[resume_idx:]
  39. # 打印预测阶段起始索引顺序
  40. max_len = len(uo_city_pair_list) + resume_idx
  41. print(f"预测阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")
  42. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
  43. print(f"第 {idx} 组 :", uo_city_pair)
  44. # 加载预测数据 (仅仅是天数取到以后)
  45. start_time = time.time()
  46. df_test = load_data(mongo_config, uo_city_pair, pred_date_begin, pred_date_end, is_train=False,
  47. use_multiprocess=True, max_workers=max_workers)
  48. end_time = time.time()
  49. run_time = round(end_time - start_time, 3)
  50. print(f"用时: {run_time} 秒")
  51. if df_test.empty:
  52. print(f"预测数据为空,跳过此批次。")
  53. continue
  54. # 按起飞时间过滤
  55. df_test['from_hour'] = df_test['from_time'].dt.floor('h')
  56. # 使用整点时间进行比较过滤
  57. mask = (df_test['from_hour'] >= pred_hour_begin) & (df_test['from_hour'] < pred_hour_end)
  58. original_count = len(df_test)
  59. df_test = df_test[mask].reset_index(drop=True)
  60. filtered_count = len(df_test)
  61. # 删除临时字段
  62. df_test = df_test.drop(columns=['from_hour'])
  63. print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")
  64. if filtered_count == 0:
  65. print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
  66. continue
  67. df_test_inputs, _, _, _, = preprocess_data_simple(df_test, is_train=False, hourly_time=hourly_time)
  68. df_predict = predict_data_simple(df_test_inputs, uo_city_pair, output_dir, predict_dir, hourly_time_str)
  69. del df_test_inputs
  70. del df_predict
  71. print(f"第 {idx} 组 预测完成")
  72. print()
  73. time.sleep(1)
  74. print("所有批次的预测结束")
  75. print()
  76. if __name__ == "__main__":
  77. start_predict()