main_pe_0.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import os
  2. import time
  3. from datetime import datetime, timedelta
  4. from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  5. from data_loader import load_train_data
  6. from data_preprocess import preprocess_data_simple, predict_data_simple
  7. from utils import chunk_list_with_index
  8. def start_predict():
  9. print(f"开始预测")
  10. output_dir = "./data_shards_0"
  11. # photo_dir = "./photo_0"
  12. predict_dir = "./predictions_0"
  13. # 确保目录存在
  14. os.makedirs(output_dir, exist_ok=True)
  15. # os.makedirs(photo_dir, exist_ok=True)
  16. os.makedirs(predict_dir, exist_ok=True)
  17. cpu_cores = os.cpu_count() # 你的系统是72
  18. max_workers = min(4, cpu_cores) # 最大不超过4个进程
  19. # 当前时间,取整时
  20. current_time = datetime.now()
  21. current_time_str = current_time.strftime("%Y%m%d%H%M")
  22. hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
  23. hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
  24. print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")
  25. # 清空上一次(同小时内)预测结果
  26. csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
  27. for csv_file in csv_file_list:
  28. try:
  29. csv_path = os.path.join(predict_dir, csv_file)
  30. os.remove(csv_path)
  31. except Exception as e:
  32. print(f"remove {csv_path} info: {str(e)}")
  33. # 预测时间范围,满足起飞时间 在18小时后到54小时后
  34. pred_hour_begin = hourly_time + timedelta(hours=18)
  35. pred_hour_end = hourly_time + timedelta(hours=54)
  36. pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
  37. pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
  38. print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")
  39. # 主干代码 (排除冷门航线)
  40. flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
  41. flight_route_list_len = len(flight_route_list)
  42. route_len_hot = len(vj_flight_route_list_hot)
  43. route_len_nothot = len(vj_flight_route_list_nothot[:0])
  44. group_size = 1 # 每几组作为一个批次
  45. chunks = chunk_list_with_index(flight_route_list, group_size)
  46. # 如果从中途某个批次预测, 修改起始索引
  47. resume_chunk_idx = 0
  48. chunks = chunks[resume_chunk_idx:]
  49. batch_starts = [start_idx for start_idx, _ in chunks]
  50. print(f"预测阶段起始索引顺序:{batch_starts}")
  51. # 预测阶段
  52. for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
  53. # 特殊处理,跳过不好的批次
  54. # client, db = mongo_con_parse()
  55. print(f"第 {i} 组 :", group_route_list)
  56. # batch_flight_routes = group_route_list
  57. group_route_str = ','.join(group_route_list)
  58. # 根据索引位置决定是 热门 还是 冷门
  59. if 0 <= i < route_len_hot:
  60. is_hot = 1
  61. table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
  62. elif route_len_hot <= i < route_len_hot + route_len_nothot:
  63. is_hot = 0
  64. table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  65. else:
  66. print(f"无法确定热门还是冷门, 跳过此批次。")
  67. continue
  68. # 加载测试数据 (仅仅是时间段取到后天)
  69. start_time = time.time()
  70. df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
  71. use_multiprocess=True, max_workers=max_workers)
  72. end_time = time.time()
  73. run_time = round(end_time - start_time, 3)
  74. print(f"用时: {run_time} 秒")
  75. if df_test.empty:
  76. print(f"测试数据为空,跳过此批次。")
  77. continue
  78. # 按起飞时间过滤
  79. # 创建临时字段:seg1_dep_time 的整点时间
  80. df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
  81. # 使用整点时间进行比较过滤
  82. mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
  83. original_count = len(df_test)
  84. df_test = df_test[mask].reset_index(drop=True)
  85. filtered_count = len(df_test)
  86. # 删除临时字段
  87. df_test = df_test.drop(columns=['seg1_dep_hour'])
  88. print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")
  89. if filtered_count == 0:
  90. print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
  91. continue
  92. df_test_inputs, _, _, = preprocess_data_simple(df_test)
  93. df_predict = predict_data_simple(df_test_inputs, group_route_str, output_dir, predict_dir, hourly_time_str)
  94. del df_test_inputs
  95. del df_predict
  96. time.sleep(1)
  97. print("所有批次的预测结束")
  98. print()
  99. if __name__ == "__main__":
  100. start_predict()