# main_pe.py

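# Prediction entry point: loads the per-batch feature scalers and per-assemble
# model checkpoints saved during training (see main_tr), pulls recent flight
# records from MongoDB, filters them to a departure-time window, and writes
# future price-drop predictions for the chosen interval (2/4/8 hours).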
import os
import torch
import joblib
# import pandas as pd
# import numpy as np
import pickle
import time
import argparse
from datetime import datetime, timedelta
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from predict import predict_future_distribute
from main_tr import features, categorical_features, target_vars


def initialize_model():
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Model initialized, input size: {input_size}")
    return model, device


def convert_date_format(date_str):
    """Parse '2025-09-19 19:35:00' into a datetime (the '20250919193500' formatting below is disabled)."""
    dt = datetime.strptime(date_str, '%Y-%m-%d %H:%M:%S')
    return dt
    # return dt.strftime('%Y%m%d%H%M00')
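
# Example (note: the function currently returns a datetime, not a string):
#   convert_date_format('2025-09-19 19:35:00') -> datetime(2025, 9, 19, 19, 35)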


def start_predict(interval_hours):
    print(f"Start predicting, interval in hours: {interval_hours}")
    output_dir = "./data_shards"
    photo_dir = "./photo"
    predict_dir = "./predictions"
    if interval_hours == 4:
        output_dir = "./data_shards_4"
        photo_dir = "./photo_4"
        predict_dir = "./predictions_4"
    elif interval_hours == 2:
        output_dir = "./data_shards_2"
        photo_dir = "./photo_2"
        predict_dir = "./predictions_2"
    # Make sure the output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    os.makedirs(predict_dir, exist_ok=True)
    # Clear the previous prediction results
    # csv_file_list = ['future_predictions.csv']
    # for csv_file in csv_file_list:
    #     try:
    #         csv_path = os.path.join(output_dir, csv_file)
    #         os.remove(csv_path)
    #     except Exception as e:
    #         print(f"remove {csv_path} error: {str(e)}")
    cpu_cores = os.cpu_count()  # this system has 72
    max_workers = min(4, cpu_cores)  # cap at 4 worker processes
    model, _ = initialize_model()
    # Current time, floored to the hour
    current_time = datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"Prediction time: {current_time_str}, (floored): {pred_time_str}")
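    # For example, if the script runs at 2025-09-19 10:47, then
    # current_time_str == '202509191047' and pred_time_str == '202509191000'.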
    current_n_hours = 36
    if interval_hours == 4:
        current_n_hours = 32
    elif interval_hours == 2:
        current_n_hours = 30
    # Prediction window: departure time between 28 and 36/32/30 hours from now
    pred_hour_begin = hourly_time + timedelta(hours=28)
    pred_hour_end = hourly_time + timedelta(hours=current_n_hours)
    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
    # date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
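    # Worked example: with hourly_time == 2025-09-19 10:00 and current_n_hours == 36,
    # the window is [2025-09-20 14:00, 2025-09-20 22:00). Note the window length
    # always matches interval_hours: 36-28=8, 32-28=4, 30-28=2.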
    # Load the scaler list
    feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    feature_scaler_list = joblib.load(feature_scaler_path)
    # target_scaler_list = joblib.load(target_scaler_path)
    # Load the flight route order saved during training
    with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
        flight_route_list = pickle.load(f)
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot)
    assemble_size = 1  # how many batches form one assemble (cluster)
    current_assembled = -1  # index of the currently loaded assemble
    group_size = 1  # how many route groups per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
    # To resume prediction from a later batch, change the start index
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Prediction-stage start index order: {batch_starts}")
    # Test stage
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip bad batches here if needed
        # client, db = mongo_con_parse()
        print(f"Group {i}:", group_route_list)
        # batch_flight_routes = group_route_list
        # Decide hot vs. non-hot from the index position
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot determine hot or non-hot, skipping this batch.")
            continue
        # Load test data (only the time range extends to the day after tomorrow)
        start_time = time.time()
        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
                                  use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        # client.close()
        if df_test.empty:
            print("Test data is empty, skipping this batch.")
            continue
        # Filter by departure time
        # Temporary column: seg1_dep_time floored to the hour
        df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
        # Compare on the floored hour
        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
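        # e.g. a flight departing 2025-09-20 14:35 has seg1_dep_hour 2025-09-20 14:00,
        # so it is kept iff that floored hour falls in [pred_hour_begin, pred_hour_end).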
        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)
        # Drop the temporary column
        df_test = df_test.drop(columns=['seg1_dep_hour'])
        print(f"Filtered by departure time: {original_count} rows before, {filtered_count} after")
        if filtered_count == 0:
            print(f"No flights departing between {pred_hour_begin} and {pred_hour_end}, skipping this batch.")
            continue
        # Preprocess the data
        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False, current_n_hours=current_n_hours)
        total_rows = df_test_inputs.shape[0]
        print(f"Row count: {total_rows}")
        if total_rows == 0:
            print("Preprocessed test data is empty, skipping this batch.")
            continue
        # Find the matching feature scaler
        batch_idx = i
        print("batch_idx:", batch_idx)
        feature_scaler = feature_scaler_list[batch_idx]
        if feature_scaler is None:
            print(f"No feature scaler found for batch {batch_idx}")
            continue
        # Standardize / normalize
        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
        print("Sample of standardized data:\n", df_test_inputs.head())
        threshold = current_n_hours
        input_length = 444
        # Keep threshold + input_length equal to 480
        if threshold == 36:
            input_length = 444
        elif threshold == 32:
            input_length = 448
        elif threshold == 30:
            input_length = 450
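        # Equivalently, input_length = 480 - threshold; the explicit branches just
        # make the supported window sizes (36/32/30 hours) obvious.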
        # Build the sequences
        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, threshold, input_length, is_train=False)
        print(f"Number of sequences: {len(sequences)}")
        # ----- Smart model loading: load each assemble's weights only once -----
        assemble_idx = batch_idx // assemble_size  # index of the current assemble
        print("assemble_idx:", assemble_idx)
        if assemble_idx != current_assembled:
            # Load from file and cache
            model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
            if os.path.exists(model_path):
                state_dict = torch.load(model_path)
                model.load_state_dict(state_dict)
                current_assembled = assemble_idx
                print(f"Loaded and cached model weights for assemble {assemble_idx}")
            else:
                print(f"Model file for assemble {assemble_idx} not found, skipping")
                continue
        else:
            # Same assemble: reuse the weights already loaded
            print(f"Reusing loaded model weights for assemble {assemble_idx}")
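        # Caching by assemble index avoids re-reading the .pth file from disk when
        # consecutive batches share a checkpoint; with assemble_size == 1, every
        # batch maps to its own checkpoint, so each iteration loads a new file.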
        target_scaler = None
        # Predict future data
        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler,
                                  interval_hours=interval_hours, predict_dir=predict_dir, pred_time_str=pred_time_str)
    print("All batches predicted")
    print()
    # After all batches are predicted, apply unified filtering
    # csv_file = 'future_predictions.csv'
    # csv_path = os.path.join(output_dir, csv_file)
    # # Aggregate the prediction results
    # try:
    #     df_predict = pd.read_csv(csv_path)
    # except Exception as e:
    #     print(f"read {csv_path} error: {str(e)}")
    #     df_predict = None
    # Follow-up processing
    pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Prediction script')
    parser.add_argument('--interval', type=int, choices=[2, 4, 8],
                        default=8, help='interval in hours (2, 4, or 8)')
    args = parser.parse_args()
    start_predict(args.interval)
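
# Usage (per the argparse setup above):
#   python main_pe.py                # default: 8-hour interval
#   python main_pe.py --interval 4   # 4-hour interval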