main_pe.py 8.3 KB

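"""Prediction entry point (main_pe.py).

Loads the per-batch feature scalers, the route ordering (order.pkl) and the
per-assemble model checkpoints produced by training, pulls near-term VJ flight
data from MongoDB for each route group, and writes price-drop predictions for
flights departing 28-40 hours after the current (hour-truncated) time.
"""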
import os
import torch
import joblib
import pandas as pd
import numpy as np
import pickle
import time
from datetime import datetime, timedelta
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from predict import predict_future_distribute
from main_tr import features, categorical_features, target_vars

output_dir = "./data_shards"
photo_dir = "./photo"

def initialize_model():
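    """Build the PriceDropClassifiTransModel and move it to the GPU when available.

    The architecture arguments must match the ones used during training (main_tr.py),
    otherwise load_state_dict() on the saved checkpoints below will fail.
    """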
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Model initialized, input size: {input_size}")
    return model, device

def convert_date_format(date_str):
    """Parse a timestamp like '2025-09-19 19:35:00' into a datetime.

    The compact '20250919193500' formatting is currently disabled (see the
    commented-out return below).
    """
    dt = datetime.strptime(date_str, '%Y-%m-%d %H:%M:%S')
    return dt
    # return dt.strftime('%Y%m%d%H%M00')

def start_predict():
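    """Run one prediction pass over every flight-route batch.

    Rough flow: load the per-batch feature scalers and the route ordering saved by
    training, pull recent flight data from MongoDB for each route group, keep only
    flights departing 28-40 hours from now, preprocess, standardize and sequence
    the data, load the matching model checkpoint, and write predictions out via
    predict_future_distribute().
    """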
    # Make sure the output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    # Clear the previous prediction results
    # csv_file_list = ['future_predictions.csv']
    # for csv_file in csv_file_list:
    #     try:
    #         csv_path = os.path.join(output_dir, csv_file)
    #         os.remove(csv_path)
    #     except Exception as e:
    #         print(f"remove {csv_path} error: {str(e)}")
    cpu_cores = os.cpu_count()  # 72 on this system
    max_workers = min(4, cpu_cores)  # cap at 4 worker processes
    model, _ = initialize_model()
    # Current time, truncated to the hour
    current_time = datetime.now()
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"Prediction time (truncated to the hour): {pred_time_str}")
    # Prediction window: departure time between 28 and 40 hours from now
    pred_hour_begin = hourly_time + timedelta(hours=28)
    pred_hour_end = hourly_time + timedelta(hours=40)
    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
    # date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
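    # Note: pred_date_begin/pred_date_end are passed to load_train_data as a coarse,
    # day-level query range; the exact 28-40 hour departure window is enforced later
    # with the seg1_dep_hour filter inside the batch loop.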
    # Load the list of feature scalers
    feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    feature_scaler_list = joblib.load(feature_scaler_path)
    # target_scaler_list = joblib.load(target_scaler_path)
    # Load the flight-route ordering saved during training
    with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
        flight_route_list = pickle.load(f)
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot)
    assemble_size = 1  # number of batches grouped into one assemble
    current_assembled = -1  # index of the assemble whose weights are currently loaded
    group_size = 1  # number of route groups per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
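    # chunk_list_with_index (utils) is assumed to split the list into groups of
    # `group_size` and return (start_index, group) pairs, e.g. with group_size=1:
    #   chunk_list_with_index(['r0', 'r1'], 1) -> [(0, ['r0']), (1, ['r1'])]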
    # To resume prediction from a later batch, change the starting index
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Prediction batch start indices: {batch_starts}")
    # Test (prediction) phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed
        # client, db = mongo_con_parse()
        print(f"Group {i}:", group_route_list)
        # batch_flight_routes = group_route_list
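        # The index check below assumes order.pkl stores the hot routes first, followed
        # by the non-hot routes, in the same order as vj_flight_route_list_hot and
        # vj_flight_route_list_nothot in config.py.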
        # Decide hot vs. non-hot by index position
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot determine hot or non-hot, skipping this batch.")
            continue
        # Load test data (only the date range extends to the day after tomorrow)
        start_time = time.time()
        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
                                  use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        # client.close()
        if df_test.empty:
            print("Test data is empty, skipping this batch.")
            continue
        # Filter by departure time
        # Temporary column: seg1_dep_time floored to the hour
        df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
        # Compare and filter on the floored hour
        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] <= pred_hour_end)
        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)
        # Drop the temporary column
        df_test = df_test.drop(columns=['seg1_dep_hour'])
        print(f"Departure-time filter: {original_count} rows before, {filtered_count} rows after")
        if filtered_count == 0:
            print(f"No flights departing between {pred_hour_begin} and {pred_hour_end}, skipping this batch.")
            continue
        # Preprocess the data
        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
        total_rows = df_test_inputs.shape[0]
        print(f"Row count: {total_rows}")
        if total_rows == 0:
            print("Preprocessed test data is empty, skipping this batch.")
            continue
        # Look up the matching feature scaler for this batch
        batch_idx = i
        print("batch_idx:", batch_idx)
        feature_scaler = feature_scaler_list[batch_idx]
        if feature_scaler is None:
            print(f"No feature scaler found for batch {batch_idx}")
            continue
        # Standardization / normalization
        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
        print("Sample of standardized data:\n", df_test_inputs.head())
        # Build sequences
        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, is_train=False)
        print(f"Number of sequences: {len(sequences)}")
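        # group_ids presumably maps each fixed-length sequence back to the flight/route
        # group it was cut from, so predict_future_distribute can attribute its outputs;
        # see create_fixed_length_sequences in utils for the exact contract.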
        # ----- New: smart model loading -----#
        assemble_idx = batch_idx // assemble_size  # index of the current assemble
        print("assemble_idx:", assemble_idx)
        if assemble_idx != current_assembled:
            # Load from file and cache
            model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
            if os.path.exists(model_path):
                state_dict = torch.load(model_path)
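                # Note: torch.load restores tensors to the device they were saved from;
                # on a CPU-only host a GPU-saved checkpoint needs
                # torch.load(model_path, map_location='cpu').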
                model.load_state_dict(state_dict)
                current_assembled = assemble_idx
                print(f"Loaded and cached model weights for assemble {assemble_idx}")
            else:
                print(f"Model file for assemble {assemble_idx} not found, skipping")
                continue
        else:
            # Same assemble: reuse the weights that are already loaded
            print(f"Reusing loaded model weights for assemble {assemble_idx}")
        target_scaler = None
        # Predict future data
        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
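        # predict_future_distribute is assumed to append its per-flight results to a CSV
        # (e.g. future_predictions.csv) under output_dir, tagged with pred_time_str; the
        # commented-out aggregation code after the loop reads that file back.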
    print("Prediction finished for all batches")
    print()
    # After all batches are done, apply the filtering in one pass
    # csv_file = 'future_predictions.csv'
    # csv_path = os.path.join(output_dir, csv_file)
    # # Aggregate the prediction results
    # try:
    #     df_predict = pd.read_csv(csv_path)
    # except Exception as e:
    #     print(f"read {csv_path} error: {str(e)}")
    #     df_predict = None
    # Follow-up processing
    pass

if __name__ == "__main__":
    start_predict()
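
# Usage sketch (an assumption, not documented in this repo): run the script on an hourly
# schedule, e.g. from cron, so the truncated prediction time and the 28-40 hour departure
# window keep rolling forward:
#   0 * * * * cd /path/to/project && python main_pe.py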