@@ -0,0 +1,195 @@
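+"""Batch prediction entry point.
+
+Loads the per-batch feature scalers and the flight-route order saved during
+training, pulls near-term flight data for each route group from MongoDB,
+rebuilds fixed-length sequences, loads the matching per-assemble model
+weights, and runs predict_future_distribute for each batch.
+"""
+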
+import os
+import torch
+import joblib
+import pandas as pd
+import numpy as np
+import pickle
+import time
+from datetime import datetime, timedelta
+from config import vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+from data_loader import mongo_con_parse, load_train_data
+from data_preprocess import preprocess_data, standardization
+from utils import chunk_list_with_index, create_fixed_length_sequences
+from model import PriceDropClassifiTransModel
+from predict import predict_future_distribute
+from main_tr import features, categorical_features, target_vars
+
+
+output_dir = "./data_shards"
+photo_dir = "./photo"
+
+
+def initialize_model():
+    input_size = len(features)
+    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+
+    print(f"Model initialized, input size: {input_size}")
+    return model, device
+
+def convert_date_format(date_str):
+    """Parse a string like '2025-09-19 19:35:00' into a datetime object."""
+    dt = datetime.strptime(date_str, '%Y-%m-%d %H:%M:%S')
+    return dt
+    # return dt.strftime('%Y%m%d%H%M00')  # compact '20250919193500' form, currently unused
+
+def start_predict():
+
+    # Make sure the output directories exist
+    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(photo_dir, exist_ok=True)
+
+    # Clear the results of the previous prediction run
+    csv_file_list = ['future_predictions.csv']
+
+    for csv_file in csv_file_list:
+        csv_path = os.path.join(output_dir, csv_file)
+        try:
+            os.remove(csv_path)
+        except Exception as e:
+            print(f"remove {csv_path} error: {str(e)}")
+
+    model, device = initialize_model()
+
+    date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
+    date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
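+    # Both dates resolve to the day after tomorrow, so the query window covers only that one day.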
+
+    # Load the list of feature scalers
+    feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
+    # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
+    feature_scaler_list = joblib.load(feature_scaler_path)
+    # target_scaler_list = joblib.load(target_scaler_path)
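+    # One scaler per batch, indexed by batch index; an entry may be None if no scaler was saved for that batch.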
+
+    # Load the flight-route order saved during training
+    with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
+        flight_route_list = pickle.load(f)
+
+    flight_route_list_len = len(flight_route_list)
+    route_len_hot = len(vj_flight_route_list_hot)
+    route_len_nothot = len(vj_flight_route_list_nothot)
+
+    assemble_size = 1  # number of batches grouped into one assemble
+    current_assembled = -1  # index of the assemble whose weights are currently loaded
+    group_size = 1  # number of route groups per batch
+
+    chunks = chunk_list_with_index(flight_route_list, group_size)
+
+    # To resume prediction from a later batch, change the starting index
+    resume_chunk_idx = 0
+    chunks = chunks[resume_chunk_idx:]
+
+    batch_starts = [start_idx for start_idx, _ in chunks]
+    print(f"Prediction-stage starting index order: {batch_starts}")
+
+    # Prediction stage
+    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
+        # Special handling: skip bad batches here if needed
+        client, db = mongo_con_parse()
+        print(f"Group {i}:", group_route_list)
+        # batch_flight_routes = group_route_list
+
+        # Decide by index position whether the routes are hot or not hot
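+        # Assumes order.pkl stores the hot routes first, followed by the not-hot routes.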
+        if 0 <= i < route_len_hot:
+            is_hot = 1
+            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
+        elif route_len_hot <= i < route_len_hot + route_len_nothot:
+            is_hot = 0
+            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+        else:
+            print("Cannot determine whether the batch is hot or not hot, skipping it.")
+            continue
+
+        # Load the test data (the time window only extends to the day after tomorrow)
+        start_time = time.time()
+        df_test = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+        end_time = time.time()
+        run_time = round(end_time - start_time, 3)
+        print(f"Elapsed: {run_time} s")
+
+        client.close()
+
+        if df_test.empty:
+            print("Test data is empty, skipping this batch.")
+            continue
+
+        # Data preprocessing
+        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
+
+        total_rows = df_test_inputs.shape[0]
+        print(f"Row count: {total_rows}")
+        if total_rows == 0:
+            print("Preprocessed test data is empty, skipping this batch.")
+            continue
+
+        # Find the matching feature-scaler file
+        batch_idx = i
+        print("batch_idx:", batch_idx)
+        feature_scaler = feature_scaler_list[batch_idx]
+        if feature_scaler is None:
+            print(f"No feature scaler file found for batch {batch_idx}")
+            continue
+
+        # Standardization and normalization
+        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
+
+        print("Sample of standardized data:\n", df_test_inputs.head())
+
+        # Build sequences
+        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, is_train=False)
+        print(f"Number of sequences: {len(sequences)}")
+
+        # ----- New: smart model loading -----#
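+        # Each assemble groups assemble_size consecutive batches that share one weights file
+        # (best_model_as_<idx>.pth); weights are reloaded only when the assemble index changes.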
+        assemble_idx = batch_idx // assemble_size  # index of the current assemble
+        print("assemble_idx:", assemble_idx)
+        if assemble_idx != current_assembled:
+            # Load from file and cache
+            model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
+            if os.path.exists(model_path):
+                state_dict = torch.load(model_path, map_location=device)
+                model.load_state_dict(state_dict)
+                current_assembled = assemble_idx
+                print(f"Loaded and cached model weights for assemble {assemble_idx} from file")
+            else:
+                print(f"Model file for assemble {assemble_idx} not found, skipping")
+                continue
+        else:
+            # Same assemble: reuse the weights that are already loaded
+            print(f"Reusing loaded model weights for assemble {assemble_idx}")
+
+        target_scaler = None
+        # Predict future data
+        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir)
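+        # predict_future_distribute presumably appends this batch's results to future_predictions.csv under output_dir.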
+
+    print("Prediction finished for all batches")
+    # After all batches are predicted, apply the unified filtering step
+    # csv_file = 'future_predictions.csv'
+    # csv_path = os.path.join(output_dir, csv_file)
+
+    # # Aggregate the prediction results
+    # try:
+    #     df_predict = pd.read_csv(csv_path)
+    # except Exception as e:
+    #     print(f"read {csv_path} error: {str(e)}")
+    #     df_predict = None
+
+    # Follow-up processing
+    pass
+
+
+if __name__ == "__main__":
+    start_predict()