Переглянути джерело

调整结构,支持训练和预测多个不同时间间隔下的模型

node04 4 тижні тому
батько
коміт
36c67cbdc2
9 змінених файлів з 92 додано та 36 видалено
  1. 1 0
      config.py
  2. 2 2
      data_loader.py
  3. 1 1
      data_preprocess.py
  4. 4 4
      evaluate.py
  5. 49 15
      main_pe.py
  6. 26 3
      main_tr.py
  7. 2 6
      predict.py
  8. 5 3
      result_validate.py
  9. 2 2
      train.py

+ 1 - 0
config.py

@@ -6,6 +6,7 @@ CLEAN_VJ_HOT_FAR_INFO_TAB = "clean_flights_vj_hot_7_30_info_tab"
 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB = "clean_flights_vj_nothot_0_7_info_tab"
 CLEAN_VJ_NOTHOT_FAR_INFO_TAB = "clean_flights_vj_nothot_7_30_info_tab"
 
+INTERVAL_HOURS = 8
 
 mongodb_config = {
     "host": "192.168.20.218",

+ 2 - 2
data_loader.py

@@ -1005,10 +1005,10 @@ if __name__ == "__main__":
     os.makedirs(output_dir, exist_ok=True)
 
     # 加载热门航线数据
-    date_begin = "2025-12-07"
+    date_begin = "2026-01-08"
     date_end = datetime.today().strftime("%Y-%m-%d")
 
-    flight_route_list = vj_flight_route_list_hot[4:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
+    flight_route_list = vj_flight_route_list_hot[:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
     table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
     is_hot = 1   # 1 热门 0 冷门
     group_size = 1

+ 1 - 1
data_preprocess.py

@@ -496,7 +496,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         
         print(">>> 计算 price_at_n_hours")
         df_input_object = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['baggage'] == 30)].copy()
-        df_last = df_input_object.groupby('gid', observed=True).last().reset_index()   # 一般落在起飞前48小时
+        df_last = df_input_object.groupby('gid', observed=True).last().reset_index()   # 一般落在起飞前36\32\30小时
         
         # 提取并重命名 price 列
         df_last_price_at_n_hours = df_last[['gid', 'adult_total_price']].rename(columns={'adult_total_price': 'price_at_n_hours'})

+ 4 - 4
evaluate.py

@@ -14,7 +14,7 @@ from utils import FlightDataset
 # 分布式模型评估
 def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16, test_loader=None, 
                               batch_flight_routes=None, target_scaler=None, 
-                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', batch_idx=-1,
+                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', photo_dir='.', batch_idx=-1,
                               csv_file='evaluate_results.csv', evalute_flag='evaluate', save_mode='a'):
     
     if test_loader is None:
@@ -95,7 +95,7 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
         y_trues_class_labels = y_trues_class.astype(int)
 
         # 打印指标
-        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str, batch_idx=batch_idx, evalute_flag=evalute_flag)
+        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str, batch_idx=batch_idx, evalute_flag=evalute_flag, photo_dir=photo_dir)
 
         # 构造 DataFrame
         results_df = pd.DataFrame({
@@ -154,7 +154,7 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
         return None
 
 
-def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate'):
+def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate', photo_dir='.'):
     
     accuracy = accuracy_score(y_trues_class_labels, y_preds_class_labels)
     precision = precision_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
@@ -188,4 +188,4 @@ def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', b
     plt.xlabel('预测情况', fontproperties=font_prop)
     plt.ylabel('实际结果', fontproperties=font_prop)
     plt.title('分类结果的混淆矩阵', fontproperties=font_prop)
-    plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")
+    plt.savefig(f"{photo_dir}/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")

+ 49 - 15
main_pe.py

@@ -1,10 +1,11 @@
 import os
 import torch
 import joblib
-import pandas as pd
-import numpy as np
+# import pandas as pd
+# import numpy as np
 import pickle
 import time
+import argparse
 from datetime import datetime, timedelta
 from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
 from data_loader import mongo_con_parse, load_train_data
@@ -15,10 +16,6 @@ from predict import predict_future_distribute
 from main_tr import features, categorical_features, target_vars
 
 
-output_dir = "./data_shards"
-photo_dir = "./photo"
-
-
 def initialize_model():
     input_size = len(features)
     model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
@@ -34,11 +31,28 @@ def convert_date_format(date_str):
     return dt
     # return dt.strftime('%Y%m%d%H%M00')
 
-def start_predict():
+def start_predict(interval_hours):
+
+    print(f"开始预测,间隔小时数: {interval_hours}")
+
+    output_dir = "./data_shards"
+    photo_dir = "./photo"
+    predict_dir = "./predictions"
+
+    if interval_hours == 4:
+        output_dir = "./data_shards_4"
+        photo_dir = "./photo_4"
+        predict_dir = "./predictions_4"
+
+    elif interval_hours == 2:
+        output_dir = "./data_shards_2"
+        photo_dir = "./photo_2"
+        predict_dir = "./predictions_2"
 
     # 确保目录存在
     os.makedirs(output_dir, exist_ok=True) 
     os.makedirs(photo_dir, exist_ok=True)
+    os.makedirs(predict_dir, exist_ok=True)
 
     # 清空上一次预测结果
     # csv_file_list = ['future_predictions.csv']
@@ -61,9 +75,15 @@ def start_predict():
     pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
     print(f"预测时间(取整): {pred_time_str}")
 
-    # 预测时间范围,满足起飞时间 在28小时后到40小时后
+    current_n_hours = 36
+    if interval_hours == 4:
+        current_n_hours = 32
+    elif interval_hours == 2:
+        current_n_hours = 30
+
+    # 预测时间范围,满足起飞时间 在28小时后到36/32/30小时后
     pred_hour_begin = hourly_time + timedelta(hours=28)
-    pred_hour_end = hourly_time + timedelta(hours=40)
+    pred_hour_end = hourly_time + timedelta(hours=current_n_hours)
 
     pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
     pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
@@ -134,7 +154,7 @@ def start_predict():
         # 创建临时字段:seg1_dep_time 的整点时间
         df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
         # 使用整点时间进行比较过滤
-        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] <= pred_hour_end)
+        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
         original_count = len(df_test)
         df_test = df_test[mask].reset_index(drop=True)
         filtered_count = len(df_test)
@@ -147,7 +167,7 @@ def start_predict():
             continue
 
         # 数据预处理
-        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
+        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False, current_n_hours=current_n_hours)
         
         total_rows = df_test_inputs.shape[0]
         print(f"行数: {total_rows}")
@@ -165,11 +185,21 @@ def start_predict():
 
         # 标准化与归一化处理
         df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
-
         print("标准化后数据样本:\n", df_test_inputs.head())
 
+        threshold = current_n_hours
+        input_length = 444
+
+        # 确保 threshold 与 input_length 之和为 480
+        if threshold == 36:
+            input_length = 444
+        elif threshold == 32:
+            input_length = 448
+        elif threshold == 30:
+            input_length = 450
+
         # 生成序列
-        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, is_train=False)
+        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, threshold, input_length, is_train=False)
         print(f"序列数量:{len(sequences)}")
 
         #----- 新增:智能模型加载 -----#
@@ -192,7 +222,7 @@ def start_predict():
 
         target_scaler = None
         # 预测未来数据
-        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
+        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, predict_dir=predict_dir, pred_time_str=pred_time_str)
 
     print("所有批次的预测结束")
     print()
@@ -213,4 +243,8 @@ def start_predict():
 
 
 if __name__ == "__main__":
-    start_predict()
+    parser = argparse.ArgumentParser(description='预测脚本')
+    parser.add_argument('--interval', type=int, choices=[2, 4, 8], 
+                        default=8, help='间隔小时数(2, 4, 8)')
+    args = parser.parse_args()
+    start_predict(args.interval)

+ 26 - 3
main_tr.py

@@ -19,7 +19,7 @@ from data_preprocess import preprocess_data, standardization
 from train import prepare_data_distribute, train_model_distribute
 from evaluate import printScore_cc
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
-    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
+    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, INTERVAL_HOURS
 
 warnings.filterwarnings('ignore')
 
@@ -110,6 +110,12 @@ def start_train():
 
     output_dir = "./data_shards"
     photo_dir = "./photo"
+    if INTERVAL_HOURS == 4:
+        output_dir = "./data_shards_4"
+        photo_dir = "./photo_4"
+    elif INTERVAL_HOURS == 2:
+        output_dir = "./data_shards_2"
+        photo_dir = "./photo_2"
 
     date_end = datetime.today().strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
@@ -274,9 +280,15 @@ def start_train():
                 print(f"训练数据为空,跳过此批次。")
                 continue_before_process(redis_client, lock_key)
                 continue
+
+            current_n_hours = 36
+            if INTERVAL_HOURS == 4:
+                current_n_hours = 32
+            elif INTERVAL_HOURS == 2:
+                current_n_hours = 30
             
             # 数据预处理
-            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
+            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True, current_n_hours=current_n_hours)
             print("预处理后数据样本:\n", df_train_inputs.head())
 
             total_rows = df_train_inputs.shape[0]
@@ -304,8 +316,19 @@ def start_train():
             assemble_idx = batch_idx // assemble_size  # 计算当前集群索引
             print("assemble_idx:", assemble_idx)
             
+            threshold = current_n_hours
+            input_length = 444
+
+            # 确保 threshold 与 input_length 之和为 480
+            if threshold == 36:
+                input_length = 444
+            elif threshold == 32:
+                input_length = 448
+            elif threshold == 30:
+                input_length = 450
+
             # 生成序列
-            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars)
+            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, threshold, input_length)
             
             # 新增有效性检查
             if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:

+ 2 - 6
predict.py

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from utils import FlightDataset
 
 
-def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir=".", pred_time_str=""):
+def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, predict_dir=".", pred_time_str=""):
     if not sequences:
         print("没有足够的数据进行预测。")
         return
@@ -65,11 +65,7 @@ def predict_future_distribute(model, sequences, group_ids, batch_size=16, target
     for col in numeric_columns:
         results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
     
-    # 修改预测保存路径
-    output_dir = './predictions'
-    os.makedirs(output_dir, exist_ok=True)
-
-    csv_path1 = os.path.join(output_dir, f'future_predictions_{pred_time_str}.csv')
+    csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
     results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
 
     print("预测结果已追加")

+ 5 - 3
result_validate.py

@@ -4,7 +4,9 @@ import pandas as pd
 from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
 
 
-def validate_process(node, date, pred_time_str):
+def validate_process(node, pred_time_str):
+
+    date = pred_time_str[4:8]
 
     output_dir = f"./validate/{node}_{date}"
     os.makedirs(output_dir, exist_ok=True)
@@ -113,5 +115,5 @@ def validate_process(node, date, pred_time_str):
 
 
 if __name__ == "__main__":
-    node, date, pred_time_str = "node0108", "0110", "202601100800"
-    validate_process(node, date, pred_time_str)
+    node, pred_time_str = "node0108", "202601121000"
+    validate_process(node, pred_time_str)

+ 2 - 2
train.py

@@ -302,7 +302,7 @@ def train_model_distribute(train_sequences, train_targets, train_group_ids, val_
             batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
             flag_distributed=flag_distributed,
             rank=rank, local_rank=local_rank, world_size=world_size, 
-            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx, save_mode='a'
         )
     else:
         evaluate_model_distribute(
@@ -312,7 +312,7 @@ def train_model_distribute(train_sequences, train_targets, train_group_ids, val_
             test_loader=val_loader,  # 使用累积验证集
             batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
             flag_distributed=False,
-            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx, save_mode='a'
         )
 
     return model