import os
import torch
import joblib
# import pandas as pd
# import numpy as np
import pickle
import time
import argparse
from datetime import datetime, timedelta
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
from data_loader import load_train_data
from data_preprocess import preprocess_data_cycle, standardization
from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from predict import predict_future_distribute
from main_tr import features, target_vars


def initialize_model():
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64,
                                        num_layers=3, output_size=1, dropout=0.2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Model initialized, input size: {input_size}")
    return model, device


def convert_date_format(date_str):
    """Parse a string like '2025-09-19 19:35:00' into a datetime object.

    The compact '20250919193500' formatting is currently disabled (see the
    commented-out return below).
    """
    dt = datetime.strptime(date_str, '%Y-%m-%d %H:%M:%S')
    return dt
    # return dt.strftime('%Y%m%d%H%M00')


def start_predict(interval_hours):
    print(f"Starting prediction, interval hours: {interval_hours}")
    output_dir = "./data_shards"
    photo_dir = "./photo"
    predict_dir = "./predictions"
    if interval_hours == 4:
        output_dir = "./data_shards_4"
        photo_dir = "./photo_4"
        predict_dir = "./predictions_4"
    elif interval_hours == 2:
        output_dir = "./data_shards_2"
        photo_dir = "./photo_2"
        predict_dir = "./predictions_2"

    # Make sure the output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    os.makedirs(predict_dir, exist_ok=True)

    # Clear the results of the previous prediction run
    # csv_file_list = ['future_predictions.csv']
    # for csv_file in csv_file_list:
    #     try:
    #         csv_path = os.path.join(output_dir, csv_file)
    #         os.remove(csv_path)
    #     except Exception as e:
    #         print(f"remove {csv_path} error: {str(e)}")

    cpu_cores = os.cpu_count()       # e.g. 72 on the current host
    max_workers = min(4, cpu_cores)  # cap at 4 worker processes

    model, device = initialize_model()

    # Current time, rounded down to the hour
    current_time = datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"Prediction time: {current_time_str}, (rounded): {pred_time_str}")

    current_n_hours = 36
    if interval_hours == 4:
        current_n_hours = 32
    elif interval_hours == 2:
        current_n_hours = 30

    # Prediction window: departure time must fall between 28 and 36/32/30 hours from now
    pred_hour_begin = hourly_time + timedelta(hours=28)
    pred_hour_end = hourly_time + timedelta(hours=current_n_hours)
    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
    # date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")

    # Load the scaler list
    feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    feature_scaler_list = joblib.load(feature_scaler_path)
    # target_scaler_list = joblib.load(target_scaler_path)

    # Load the flight-route ordering saved during training
    with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
        flight_route_list = pickle.load(f)
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])  # exclude non-hot routes

    assemble_size = 1       # number of batches grouped into one assemble
    current_assembled = -1  # index of the assemble currently loaded
    group_size = 1          # number of route groups per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)

    # To resume prediction from a specific batch, change the starting index
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Batch starting indices for the prediction stage: {batch_starts}")

    # Prediction stage
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed
        # client, db = mongo_con_parse()
        print(f"Group {i}:", group_route_list)
        # batch_flight_routes = group_route_list

        # Decide hot vs. non-hot routes based on the batch index
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot determine whether the batch is hot or non-hot, skipping it.")
            continue

        # Load test data (only the date range up to the day after tomorrow)
        start_time = time.time()
        df_test = load_train_data(mongodb_config, group_route_list, table_name,
                                  pred_date_begin, pred_date_end, output_dir, is_hot,
                                  use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        # client.close()

        if df_test.empty:
            print("Test data is empty, skipping this batch.")
            continue

        # Filter by departure time
        # Temporary column: seg1_dep_time floored to the hour
        df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
        # Compare on the floored hour
        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)
        # Drop the temporary column
        df_test = df_test.drop(columns=['seg1_dep_hour'])
        print(f"Departure-time filter: {original_count} rows before, {filtered_count} rows after")
        if filtered_count == 0:
            print(f"No flights departing between {pred_hour_begin} and {pred_hour_end}, skipping this batch.")
            continue

        feature_length = 240
        # Data preprocessing
        df_test_inputs = preprocess_data_cycle(df_test, is_training=False,
                                               interval_hours=interval_hours,
                                               feature_length=feature_length)
        total_rows = df_test_inputs.shape[0]
        print(f"Row count: {total_rows}")
        if total_rows == 0:
            print("Preprocessed test data is empty, skipping this batch.")
            continue

        # Find the matching feature scaler for this batch
        batch_idx = i
        print("batch_idx:", batch_idx)
        feature_scaler = feature_scaler_list[batch_idx]
        if feature_scaler is None:
            print(f"No feature scaler found for batch {batch_idx}")
            continue

        # Standardization / normalization
        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler,
                                                            is_training=False,
                                                            feature_length=feature_length)
        print("Sample of standardized data:\n", df_test_inputs.head())

        threshold = current_n_hours
        # input_length = 444
        # Make sure threshold + input_length sums to 480
        # if threshold == 36:
        #     input_length = 444
        # elif threshold == 32:
        #     input_length = 448
        # elif threshold == 30:
        #     input_length = 450

        # Build sequences
        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars,
                                                                threshold, feature_length, is_train=False)
        print(f"Number of sequences: {len(sequences)}")

        # ----- Added: smart model loading (load and cache per assemble) ----- #
        assemble_idx = batch_idx // assemble_size  # index of the current assemble
        print("assemble_idx:", assemble_idx)
        if assemble_idx != current_assembled:
            # Load from file and cache
            model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
            if os.path.exists(model_path):
                # map_location keeps CPU-only hosts working even if the checkpoint was saved on GPU
                state_dict = torch.load(model_path, map_location=device)
                model.load_state_dict(state_dict)
                current_assembled = assemble_idx
                print(f"Loaded and cached model weights for assemble {assemble_idx}")
            else:
                print(f"Model file for assemble {assemble_idx} not found, skipping")
                continue
        else:
            # Same assemble: reuse the weights already loaded
            print(f"Reusing loaded model weights for assemble {assemble_idx}")

        target_scaler = None
        # Predict future data
        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler,
                                  interval_hours=interval_hours, predict_dir=predict_dir,
                                  pred_time_str=pred_time_str)

    print("Prediction finished for all batches")
    print()

    # After all batches finish, run a unified filtering pass
    # csv_file = 'future_predictions.csv'
    # csv_path = os.path.join(output_dir, csv_file)
    # # Aggregate prediction results
    # try:
    #     df_predict = pd.read_csv(csv_path)
    # except Exception as e:
    #     print(f"read {csv_path} error: {str(e)}")
error: {str(e)}") # df_predict = None # 后续的处理 pass if __name__ == "__main__": parser = argparse.ArgumentParser(description='预测脚本') parser.add_argument('--interval', type=int, choices=[2, 4, 8], default=8, help='间隔小时数(2, 4, 8)') args = parser.parse_args() start_predict(args.interval)