
Add prediction code and make related changes

node04 1 month ago
Parent
Commit
ad9fcd622e
8 changed files with 280 additions and 14 deletions
  1. data_loader.py (+10 -4)
  2. data_preprocess.py (+3 -3)
  3. evaluate.py (+3 -2)
  4. main_pe.py (+181 -0)
  5. main_tr.py (+6 -2)
  6. predict.py (+73 -0)
  7. train.py (+3 -3)
  8. utils.py (+1 -0)

+ 10 - 4
data_loader.py

@@ -1,6 +1,6 @@
 import gc
 import time
-from datetime import datetime
+from datetime import datetime, timedelta
 import pymongo
 from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
 import pandas as pd
@@ -349,19 +349,25 @@ def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
     return df
 
 
-def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=15, max_retries=3, base_sleep=1.0):
+def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=20, max_retries=3, base_sleep=1.0):
     """
     Find all group (flight number + departure time) combinations for a given city pair
     Sorted by: first-leg flight number → second-leg flight number → departure time
-    (automatic retry on failure)
+    (automatic retry on failure); keeps only routes with departures on at least 20 days within the past 2 months
     """
     print(f"{from_city}-{to_city} 查找所有分组")
+    date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
+    date_end = datetime.today().strftime("%Y%m%d")
     pipeline = [
         # 1️⃣ First, filter by the city pair
         {
             "$match": {
                 "from_city_code": from_city,
-                "to_city_code": to_city
+                "to_city_code": to_city,
+                "search_dep_time": {
+                    "$gte": date_begin,
+                    "$lte": date_end
+                }
             }
         },
         # 2️⃣ Project fields + split first/second-leg flight numbers for sorting
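
For context, the new 60-day window narrows the aggregation before grouping. Below is a minimal sketch of the $match stage it builds (pymongo pipeline syntax; the city codes are hypothetical examples, not values from the repo, and search_dep_time appears to be stored as a "YYYYMMDD" string based on the formatting used above):

```python
from datetime import datetime, timedelta

# Dates are compared as zero-padded "YYYYMMDD" strings, so lexicographic
# comparison with $gte / $lte matches chronological order.
date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
date_end = datetime.today().strftime("%Y%m%d")

match_stage = {
    "$match": {
        "from_city_code": "SGN",   # hypothetical city pair, for illustration only
        "to_city_code": "HAN",
        "search_dep_time": {"$gte": date_begin, "$lte": date_end},
    }
}
print(match_stage)
```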

+ 3 - 3
data_preprocess.py

@@ -486,7 +486,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         "city_pair", "from_city_code", "from_city_num", "to_city_code", "to_city_num", "flight_day", 
         "seats_remaining", "baggage", "baggage_level", 
         "price_change_times_total", "price_last_change_hours", "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_amount_of_drop", "target_time_to_drop",
-        "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "gid",
+        "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "crawl_date", "gid",
         "flight_number_1", "flight_1_num", "airport_pair_1", "dep_time_1", "arr_time_1", "fly_duration_1", 
         "flight_by_hour", "flight_by_day", "flight_day_of_month", "flight_day_of_week", "flight_day_of_quarter", "flight_day_is_weekend", "is_transfer", 
         "flight_number_2", "flight_2_num", "airport_pair_2", "dep_time_2", "arr_time_2", "fly_duration_2", "fly_duration", "stop_duration", 
@@ -498,7 +498,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     return df_input
 
 
-def standardization(df, feature_scaler, target_scaler, is_training=True, is_test=False):
+def standardization(df, feature_scaler, target_scaler=None, is_training=True, is_val=False):
     print(">>> 开始标准化处理")
 
     # Features that go through standardization
@@ -508,7 +508,7 @@ def standardization(df, feature_scaler, target_scaler, is_training=True, is_test
         print(">>> 特征数据标准化开始")
         if feature_scaler is None:
             feature_scaler = StandardScaler()
-        if not is_test:
+        if not is_val:
             feature_scaler.fit(df[scaler_features])
         df[scaler_features] = feature_scaler.transform(df[scaler_features])
         print(">>> 特征数据标准化完成")

+ 3 - 2
evaluate.py

@@ -108,8 +108,9 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
             'price': [info[6] for info in group_info],
             'Hours_until_Departure': [info[7] for info in group_info],
             'update_hour': [info[8] for info in group_info],
-            'target_amount_of_drop': [info[9] for info in group_info],  # these two target columns exist only for validation during training
-            'target_time_to_drop': [info[10] for info in group_info],
+            'crawl_date': [info[9] for info in group_info],
+            'target_amount_of_drop': [info[10] for info in group_info],  # these two target columns exist only for validation during training
+            'target_time_to_drop': [info[11] for info in group_info],
             'probability': y_preds_class,
             'Actual_Will_Price_Drop': y_trues_class_labels,
             'Predicted_Will_Price_Drop': y_preds_class_labels,

+ 181 - 0
main_pe.py

@@ -0,0 +1,181 @@
+import os
+import torch
+import joblib
+import pandas as pd
+import numpy as np
+import pickle
+import time
+from datetime import datetime, timedelta
+from config import vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+from data_loader import mongo_con_parse, load_train_data
+from data_preprocess import preprocess_data, standardization
+from utils import chunk_list_with_index, create_fixed_length_sequences
+from model import PriceDropClassifiTransModel
+from predict import predict_future_distribute
+from main_tr import features, categorical_features, target_vars
+
+
+output_dir = "./data_shards"
+photo_dir = "./photo"
+
+
+def initialize_model():
+    input_size = len(features)
+    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+
+    print(f"模型已初始化,输入尺寸:{input_size}")
+    return model, device
+
+def convert_date_format(date_str):
+    """将 '2025-09-19 19:35:00' 转换为 '20250919193500' 格式"""
+    dt = datetime.strptime(date_str, '%Y-%m-%d %H:%M:%S')
+    return dt
+    # return dt.strftime('%Y%m%d%H%M00')
+
+def start_predict():
+
+    # Ensure the output directories exist
+    os.makedirs(output_dir, exist_ok=True) 
+    os.makedirs(photo_dir, exist_ok=True)
+
+    # Clear the previous prediction results
+    csv_file_list = ['future_predictions.csv']
+
+    for csv_file in csv_file_list:
+        try:
+            csv_path = os.path.join(output_dir, csv_file)
+            os.remove(csv_path)
+        except Exception as e:
+            print(f"remove {csv_path} error: {str(e)}")
+
+    model, _ = initialize_model()
+
+    date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
+    date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
+
+    # Load the list of feature scalers
+    feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
+    # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
+    feature_scaler_list = joblib.load(feature_scaler_path)
+    # target_scaler_list = joblib.load(target_scaler_path) 
+
+    # Load the flight-route list order saved during training
+    with open(os.path.join(output_dir, f'order.pkl'), "rb") as f:
+        flight_route_list = pickle.load(f)   
+
+    flight_route_list_len = len(flight_route_list)
+    route_len_hot = len(vj_flight_route_list_hot)
+    route_len_nothot = len(vj_flight_route_list_nothot)
+    
+    assemble_size = 1           # how many batches form one assemble (cluster)
+    current_assembled = -1      # index of the assemble currently loaded into the model
+    group_size = 1              # how many route groups form one batch
+
+    chunks = chunk_list_with_index(flight_route_list, group_size)
+
+    # To resume prediction from a later batch, change the start index
+    resume_chunk_idx = 0
+    chunks = chunks[resume_chunk_idx:]
+
+    batch_starts = [start_idx for start_idx, _ in chunks]
+    print(f"预测阶段起始索引顺序:{batch_starts}")
+
+    # Test (prediction) phase
+    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
+        # Special handling: skip bad batches
+        client, db = mongo_con_parse()
+        print(f"第 {i} 组 :", group_route_list)
+        # batch_flight_routes = group_route_list
+
+        # Decide hot vs. non-hot based on the index position
+        if 0 <= i < route_len_hot:
+            is_hot = 1
+            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
+        elif route_len_hot <= i < route_len_hot + route_len_nothot:
+            is_hot = 0
+            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+        else:
+            print(f"无法确定热门还是冷门, 跳过此批次。")
+            continue
+        
+        # Load the test data (only difference: the date range extends to the day after tomorrow)
+        start_time = time.time()
+        df_test = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+        end_time = time.time()
+        run_time = round(end_time - start_time, 3)
+        print(f"用时: {run_time} 秒")
+
+        client.close()
+
+        if df_test.empty:
+            print(f"测试数据为空,跳过此批次。")
+            continue
+
+        # Data preprocessing
+        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
+        
+        total_rows = df_test_inputs.shape[0]
+        print(f"行数: {total_rows}")
+        if total_rows == 0:
+            print(f"预处理后的测试数据为空,跳过此批次。")
+            continue
+
+        # Find the matching feature scaler for this batch
+        batch_idx = i
+        print("batch_idx:", batch_idx)
+        feature_scaler = feature_scaler_list[batch_idx]
+        if feature_scaler is None:
+            print(f"批次{batch_idx}没有找到特征标准化缩放文件")
+            continue
+
+        # Standardization and normalization
+        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
+
+        print("标准化后数据样本:\n", df_test_inputs.head())
+
+        # Generate sequences
+        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, is_train=False)
+        print(f"序列数量:{len(sequences)}")
+
+        #----- New: smart model loading -----#
+        assemble_idx = batch_idx // assemble_size  # compute the current assemble index
+        print("assemble_idx:", assemble_idx)
+        if assemble_idx != current_assembled:
+            # Load from file and cache
+            model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
+            if os.path.exists(model_path):
+                state_dict = torch.load(model_path)
+                model.load_state_dict(state_dict)
+                current_assembled = assemble_idx
+                print(f"从文件加载并缓存 assemble {assemble_idx} 的模型参数")
+            else:
+                print(f"未找到 assemble {assemble_idx} 的模型文件,跳过")
+                continue
+        else:
+            # Same assemble: reuse the already loaded parameters
+            print(f"复用 assemble {assemble_idx} 的已加载模型参数")
+
+        target_scaler = None
+        # Predict future data
+        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir)
+
+    print("所有批次的预测结束")
+    # After all batches are predicted, apply filtering in one pass
+    # csv_file = 'future_predictions.csv'
+    # csv_path = os.path.join(output_dir, csv_file)
+
+    # # Aggregate the prediction results
+    # try:
+    #     df_predict = pd.read_csv(csv_path)
+    # except Exception as e:
+    #     print(f"read {csv_path} error: {str(e)}")
+    #     df_predict = None
+
+    # Further processing
+    pass
+
+
+if __name__ == "__main__":
+    start_predict()
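
The model-loading logic above keys checkpoints by assemble rather than by batch: `assemble_idx = batch_idx // assemble_size`, and the file is reloaded only when the assemble changes. A small sketch of the mapping (the committed code sets assemble_size = 1, i.e. one checkpoint per batch; the value 2 below is only to illustrate sharing):

```python
import os

output_dir = "./data_shards"
assemble_size = 2  # illustrative; the commit uses assemble_size = 1

for batch_idx in range(5):
    assemble_idx = batch_idx // assemble_size
    model_path = os.path.join(output_dir, f"best_model_as_{assemble_idx}.pth")
    print(f"batch {batch_idx} -> {model_path}")
# batch 0 -> ./data_shards/best_model_as_0.pth
# batch 1 -> ./data_shards/best_model_as_0.pth
# batch 2 -> ./data_shards/best_model_as_1.pth
# batch 3 -> ./data_shards/best_model_as_1.pth
# batch 4 -> ./data_shards/best_model_as_2.pth
```

This matches the checkpoint path renamed to `best_model_as_{assemble_idx}.pth` in train.py below.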

+ 6 - 2
main_tr.py

@@ -153,6 +153,7 @@ def start_train():
     barrier_key = 'distributed_barrier_11'
 
     assemble_size = 1   # how many batches form one assemble (cluster)
+    assemble_idx = -1   
     batch_idx = -1
     batch_flight_routes = None   # placeholder so other ranks always have a definition
 
@@ -293,6 +294,9 @@ def start_train():
             # target_scaler_path = os.path.join(output_dir, f'target_scalers.joblib')
             joblib.dump(feature_scaler_list, feature_scaler_path)
             # joblib.dump(target_scaler_list, target_scaler_path)
+
+            assemble_idx = batch_idx // assemble_size  # compute the current assemble index
+            print("assemble_idx:", assemble_idx)
             
             # Generate sequences
             sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars)
@@ -387,7 +391,7 @@ def start_train():
         model = train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
             model, criterion, optimizer, device, num_epochs=num_epochs_per_batch, batch_size=16, target_scaler=target_scaler, 
             flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size, 
-            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx, 
+            output_dir=output_dir, assemble_idx=assemble_idx, photo_dir=photo_dir, batch_idx=batch_idx, 
             batch_flight_routes=batch_flight_routes, patience=40, delta=0.001)
 
         del train_single
@@ -569,7 +573,7 @@ def _validate_group_structure(group_ids):
     
     sample = group_ids[0]
     assert isinstance(sample, tuple), "元素必须是元组"
-    assert len(sample) == 11, "元组长度必须为11"
+    assert len(sample) == 12, "元组长度必须为12"
 
 def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
     """分布式环境下按Rank顺序打印分片前5条样本"""

+ 73 - 0
predict.py

@@ -0,0 +1,73 @@
+import datetime
+import torch
+import pandas as pd
+import numpy as np
+import os
+from torch.utils.data import DataLoader
+from utils import FlightDataset
+
+
+def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir="."):
+    if not sequences:
+        print("没有足够的数据进行预测。")
+        return
+    
+    test_dataset = FlightDataset(sequences, None, group_ids)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size)
+    model.eval()
+
+    y_preds = []
+    group_info = []
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    with torch.no_grad():
+        for X_batch, group_ids in test_loader:
+            X_batch = X_batch.to(device)
+            outputs = model(X_batch)
+            y_preds.extend(outputs.cpu().numpy())
+            for i in range(len(group_ids[0])):
+                group_id = tuple(group_ids_elem[i] for group_ids_elem in group_ids)
+                group_info.append(group_id)                
+    
+    y_preds = np.array(y_preds)
+    y_preds_class = y_preds[:, 0]
+    y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
+    
+    results_df = pd.DataFrame({
+        'city_pair': [info[0] for info in group_info],
+        'flight_day': [info[1] for info in group_info],
+        'flight_number_1': [info[2] for info in group_info],
+        'flight_number_2': [info[3] for info in group_info],
+        'from_date': [info[4] for info in group_info],
+        'baggage': [info[5] for info in group_info],
+        'price': [info[6] for info in group_info],
+        'Hours_until_Departure': [info[7] for info in group_info],
+        'update_hour': [info[8] for info in group_info],
+        'crawl_date': [info[9] for info in group_info],
+        'probability': y_preds_class,
+        'Predicted_Will_Price_Drop': y_preds_class_labels,
+    })
+
+    # Convert to datetime first
+    update_hour_dt = pd.to_datetime(results_df['update_hour'])   # time (on the hour) 48 hours before departure
+    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=20)     # time (on the hour) 28 hours before departure (48 - 20 = 28)
+
+    # Insert a new column right before 'probability'
+    results_df.insert(
+        loc=results_df.columns.get_loc('probability'),
+        column='valid_begin_hour',
+        value=valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
+    )
+    
+    # Numeric processing
+    threshold = 1e-3
+    numeric_columns = ['probability']
+    for col in numeric_columns:
+        results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
+    
+    csv_path1 = os.path.join(output_dir, 'future_predictions.csv')
+    results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
+
+    print("预测结果已追加")
+
+    return results_df
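
`predict_future_distribute` relies on PyTorch's default collation: each dataset item carries its group id as a tuple of strings, so the DataLoader batches the ids field-wise. Inside the loop `group_ids` is therefore a sequence of per-field batches, and `tuple(group_ids_elem[i] for group_ids_elem in group_ids)` rebuilds the i-th sample's tuple. A self-contained illustration of that transposition (the toy dataset is an assumption, not the project's FlightDataset):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Yields (features, group_id) pairs where group_id is a tuple of strings."""
    def __init__(self):
        self.items = [
            (torch.zeros(3), ("SGN-HAN", "2025-09-21", "VJ120")),
            (torch.ones(3),  ("SGN-HAN", "2025-09-22", "VJ122")),
        ]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

loader = DataLoader(ToyDataset(), batch_size=2)
for X_batch, group_ids in loader:
    # Default collation transposes the id tuples: group_ids[0] holds field 0
    # for every sample in the batch, group_ids[1] holds field 1, and so on.
    print(len(group_ids[0]))                         # batch size: 2
    for i in range(len(group_ids[0])):
        print(tuple(elem[i] for elem in group_ids))  # i-th sample's original tuple
```

The derived `valid_begin_hour` column is then just `update_hour + 20 hours`: since `update_hour` marks the point 48 hours before departure, the result marks the point 28 hours before departure.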

+ 3 - 3
train.py

@@ -164,7 +164,7 @@ def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=Fals
 # Distributed training
 def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
                            model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
-                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
+                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', assemble_idx=-1,
                            photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
                            ):
     
@@ -230,7 +230,7 @@ def train_model_distribute(train_sequences, train_targets, train_group_ids, val_
     # Create a weighted loss function
     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
 
-    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
+    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth'),
                                        rank=rank, local_rank=local_rank)
     
     # Train the model (distributed)
@@ -253,7 +253,7 @@ def train_model_distribute(train_sequences, train_targets, train_group_ids, val_
         plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
 
     # After training, load the best model parameters
-    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
+    best_model_path = os.path.join(output_dir, f'best_model_as_{assemble_idx}.pth')
 
     # Make sure all processes see the same filesystem state
     if flag_distributed:

+ 1 - 0
utils.py

@@ -79,6 +79,7 @@ def create_fixed_length_sequences(df, features, target_vars, threshold=48, input
                             str(last_row['Adult_Total_Price']), 
                             str(last_row['Hours_Until_Departure']),
                             str(last_row['update_hour']), 
+                            str(last_row['crawl_date']), 
                             ]
             if is_train:
                 next_name_li.append(last_row['target_amount_of_drop'])
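
With `crawl_date` appended here, the per-sequence group id has 12 fields during training, which is what the updated `len(sample) == 12` assertion in main_tr.py checks and what the column order in evaluate.py / predict.py consumes. A documentation-only sketch of that order, reconstructed from those readers (placeholder values are hypothetical; indices 10-11 are appended only when is_train=True):

```python
group_id = (
    "SGN-HAN",              # 0  city_pair
    "2025-09-21",           # 1  flight_day
    "VJ120",                # 2  flight_number_1
    "",                     # 3  flight_number_2 (assumed empty for direct flights)
    "2025-09-19",           # 4  from_date
    "7kg",                  # 5  baggage
    "1250000",              # 6  Adult_Total_Price -> 'price' column in the result CSVs
    "48",                   # 7  Hours_Until_Departure
    "2025-09-19 19:00:00",  # 8  update_hour
    "2025-09-19 19:05:12",  # 9  crawl_date (added by this commit)
    0.0,                    # 10 target_amount_of_drop (training/validation only)
    0.0,                    # 11 target_time_to_drop   (training/validation only)
)
assert len(group_id) == 12
```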