
Adjust the training and prediction architecture, add price-related features, and add an automatic validation function

node04 committed 3 weeks ago
parent commit 8b6d05f544
5 files changed, 386 additions and 68 deletions
  1. data_preprocess.py  (+151, -29)
  2. main_pe.py  (+16, -14)
  3. main_tr.py  (+25, -16)
  4. result_validate.py  (+187, -2)
  5. utils.py  (+7, -7)

data_preprocess.py  (+151, -29)

@@ -10,9 +10,37 @@ from utils import insert_df_col
 COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
 
 
-def preprocess_data(df_input, features, categorical_features, is_training=True, current_n_hours=36):
-    print(">>> 开始数据预处理") 
+def preprocess_data_cycle(df_input, interval_hours=8, feature_length=240, target_length=24, is_training=True):
+
+    # df_input_part = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['hours_until_departure'] < current_n_hours)].copy()
+
+    df_input = preprocess_data_first_half(df_input)
+
+    # Create an empty list to collect all processed parts
+    list_df_parts = []
+
+    crop_lower_limit_list = [4]   # [4, 28, 52, 76, 100]
+    for crop_lower_limit in crop_lower_limit_list:
+        target_n_hours = crop_lower_limit + target_length
+        feature_n_hours = target_n_hours + interval_hours
+        crop_upper_limit = feature_n_hours + feature_length
+        df_input_part = preprocess_data(df_input, is_training=is_training, crop_upper_limit=crop_upper_limit, feature_n_hours=feature_n_hours,
+                                        target_n_hours=target_n_hours, crop_lower_limit=crop_lower_limit)
+        # Append the processed part to the list
+        list_df_parts.append(df_input_part)
+        if not is_training:
+            break
+    
+    # Concatenate all processed parts
+    if list_df_parts:
+        df_combined = pd.concat(list_df_parts, ignore_index=True)
+        return df_combined
+    else:
+        return pd.DataFrame()  # return an empty DataFrame if there were no parts
 
+def preprocess_data_first_half(df_input):
+    '''First half of the preprocessing pipeline'''
+    print(">>> Starting data preprocessing")
     # Build the city pair
     df_input['city_pair'] = (
         df_input['from_city_code'].astype(str) + "-" + df_input['to_city_code'].astype(str)
@@ -110,9 +138,14 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         .ngroup()
     )
 
+    return df_input
+
+def preprocess_data(df_input, is_training=True, crop_upper_limit=480, feature_n_hours=36, target_n_hours=28, crop_lower_limit=4):
+    print(f"裁剪范围: [{crop_lower_limit}, {crop_upper_limit}], 间隔窗口: [{target_n_hours}, {feature_n_hours}]") 
+
     # 做一下时间段裁剪, 保留起飞前480小时之内且大于等于4小时的
-    df_input = df_input[(df_input['hours_until_departure'] < 480) & 
-                        (df_input['hours_until_departure'] >= 4)].reset_index(drop=True)
+    df_input = df_input[(df_input['hours_until_departure'] < crop_upper_limit) & 
+                        (df_input['hours_until_departure'] >= crop_lower_limit)].reset_index(drop=True)
     
     # Sort by time descending within each gid and baggage group
     df_input = df_input.sort_values(
@@ -120,34 +153,115 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         ascending=[True, True, False]
     ).reset_index(drop=True)
 
+    # Minimum magnitude for a change to count as a price move
+    VALID_DROP_MIN = 5
+
     # Price-change masks
     g = df_input.groupby(['gid', 'baggage'])
     diff = g['adult_total_price'].transform('diff')
-    change_mask = diff.abs() >= 5   # changes that are too small don't count
-
+    # change_mask = diff.abs() >= VALID_DROP_MIN   # changes that are too small don't count
+    decrease_mask = diff <= -VALID_DROP_MIN        # price drop (too-small changes ignored)
+    increase_mask = diff >= VALID_DROP_MIN         # price rise (too-small changes ignored)
+
+    df_input['_price_event_dir'] = np.where(increase_mask, 1, np.where(decrease_mask, -1, 0))
+
+    # Count consecutive price increases / decreases within each group
+    def _calc_price_streaks(df_group):
+        dirs = df_group['_price_event_dir'].to_numpy()
+        n = len(dirs)
+        inc = np.full(n, np.nan)
+        dec = np.full(n, np.nan)
+
+        last_dir = 0
+        inc_cnt = 0
+        dec_cnt = 0
+        for i, d in enumerate(dirs):
+            if d == 1:
+                inc_cnt = inc_cnt + 1 if last_dir == 1 else 1
+                dec_cnt = 0
+                last_dir = 1
+                inc[i] = inc_cnt
+                dec[i] = dec_cnt
+            elif d == -1:
+                dec_cnt = dec_cnt + 1 if last_dir == -1 else 1
+                inc_cnt = 0
+                last_dir = -1
+                inc[i] = inc_cnt
+                dec[i] = dec_cnt
+        
+        inc_s = pd.Series(inc, index=df_group.index).ffill().fillna(0).astype(int)
+        dec_s = pd.Series(dec, index=df_group.index).ffill().fillna(0).astype(int)
+        return pd.DataFrame(
+            {
+                'price_increase_times_consecutive': inc_s,
+                'price_decrease_times_consecutive': dec_s,
+            },
+            index=df_group.index,
+        )
+    
+    streak_df = df_input.groupby(['gid', 'baggage'], sort=False, group_keys=False).apply(_calc_price_streaks)
+    df_input = df_input.join(streak_df)
+    df_input.drop(columns=['_price_event_dir'], inplace=True)
+    
     # Number of price changes
-    df_input['price_change_times_total'] = (
-        change_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
+    # df_input['price_change_times_total'] = (
+    #     change_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
+    # )
+    # Cumulative number of price decreases
+    df_input['price_decrease_times_total'] = (
+        decrease_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
+    )
+    # Cumulative number of price increases
+    df_input['price_increase_times_total'] = (
+        increase_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
     )
 
     # Hours-until-departure at the last price change
-    last_change_hour = (
+    # last_change_hour = (
+    #     df_input['hours_until_departure']
+    #     .where(change_mask)
+    #     .groupby([df_input['gid'], df_input['baggage']])
+    #     .ffill()  # forward fill
+    # )
+    # Hours-until-departure at the last price decrease
+    last_decrease_hour = (
+        df_input['hours_until_departure']
+        .where(decrease_mask)
+        .groupby([df_input['gid'], df_input['baggage']])
+        .ffill()  # forward fill
+    )
+    # Hours-until-departure at the last price increase
+    last_increase_hour = (
         df_input['hours_until_departure']
-        .where(change_mask)
+        .where(increase_mask)
         .groupby([df_input['gid'], df_input['baggage']])
         .ffill()  # forward fill
     )
 
     # Hours elapsed since the last price change
-    df_input['price_last_change_hours'] = (
-        last_change_hour - df_input['hours_until_departure']
+    # df_input['price_last_change_hours'] = (
+    #     last_change_hour - df_input['hours_until_departure']
+    # ).fillna(0)
+    # Hours elapsed since the last price decrease
+    df_input['price_last_decrease_hours'] = (
+        last_decrease_hour - df_input['hours_until_departure']
+    ).fillna(0)
+    # Hours elapsed since the last price increase
+    df_input['price_last_increase_hours'] = (
+        last_increase_hour - df_input['hours_until_departure']
     ).fillna(0)
     pass
 
     # New columns to insert before seats_remaining
     new_cols = [
-        'price_change_times_total',
-        'price_last_change_hours'
+        # 'price_change_times_total',
+        # 'price_last_change_hours',
+        'price_decrease_times_total',
+        'price_decrease_times_consecutive',
+        'price_last_decrease_hours',
+        'price_increase_times_total',
+        'price_increase_times_consecutive',
+        'price_last_increase_hours',
     ]
     # All current columns
     cols = df_input.columns.tolist()
@@ -481,9 +595,9 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     # 训练模式
     if is_training:
         print(">>> 训练模式:计算 target 相关列")
-        print(f"\n>>> 开始处理 对应区间: n_hours = {current_n_hours}")
-        target_lower_limit = 4
-        target_upper_limit = current_n_hours
+        print(f"\n>>> 开始处理 对应区间: n_hours = {target_n_hours}")
+        target_lower_limit = crop_lower_limit
+        target_upper_limit = target_n_hours
         mask_targets = (df_input['hours_until_departure'] >= target_lower_limit) & (df_input['hours_until_departure'] < target_upper_limit) & (df_input['baggage'] == 30)
         df_targets = df_input.loc[mask_targets].copy()
 
@@ -491,11 +605,11 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         print(f"当前 目标区间数据量: {targets_amout}, 区间: [{target_lower_limit}, {target_upper_limit})")
 
         if targets_amout == 0:
-            print(f">>> n_hours = {current_n_hours} 无有效数据,跳过")
+            print(f">>> n_hours = {target_n_hours} 无有效数据,跳过")
             return pd.DataFrame()
         
         print(">>> 计算 price_at_n_hours")
-        df_input_object = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['baggage'] == 30)].copy()
+        df_input_object = df_input[(df_input['hours_until_departure'] >= feature_n_hours) & (df_input['baggage'] == 30)].copy()
         df_last = df_input_object.groupby('gid', observed=True).last().reset_index()   # usually lands 36/32/30 hours before departure
         
         # Extract and rename the price column
@@ -514,14 +628,14 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         g = df_targets.groupby('gid', group_keys=False)
         df_targets['price_diff'] = g['adult_total_price'].diff()
         
-        VALID_DROP_MIN = 10
-        LOWER_HOUR = 4
-        UPPER_HOUR = 28
+        # VALID_DROP_MIN = 5
+        # LOWER_HOUR = 4
+        # UPPER_HOUR = 28
 
         valid_drop_mask = (
-            (df_targets['price_diff'] <= -VALID_DROP_MIN) &
-            (df_targets['hours_until_departure'] >= LOWER_HOUR) &
-            (df_targets['hours_until_departure'] <= UPPER_HOUR)
+            (df_targets['price_diff'] <= -VALID_DROP_MIN)
+            # (df_targets['hours_until_departure'] >= LOWER_HOUR) &
+            # (df_targets['hours_until_departure'] <= UPPER_HOUR)
         )
         # Valid price drops
         df_valid_drops = df_targets.loc[valid_drop_mask]
@@ -639,7 +753,9 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     order_columns = [
         "city_pair", "from_city_code", "from_city_num", "to_city_code", "to_city_num", "flight_day", 
         "seats_remaining", "baggage", "baggage_level", 
-        "price_change_times_total", "price_last_change_hours", "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_amount_of_drop", "target_time_to_drop",
+        "price_decrease_times_total", "price_decrease_times_consecutive", "price_last_decrease_hours", 
+        "price_increase_times_total", "price_increase_times_consecutive", "price_last_increase_hours",
+        "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_amount_of_drop", "target_time_to_drop",
         "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "crawl_date", "gid",
         "flight_number_1", "flight_1_num", "airport_pair_1", "dep_time_1", "arr_time_1", "fly_duration_1", 
         "flight_by_hour", "flight_by_day", "flight_day_of_month", "flight_day_of_week", "flight_day_of_quarter", "flight_day_is_weekend", "is_transfer", 
@@ -654,7 +770,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     return df_input
 
 
-def standardization(df, feature_scaler, target_scaler=None, is_training=True, is_val=False):
+def standardization(df, feature_scaler, target_scaler=None, is_training=True, is_val=False, feature_length=240):
     print(">>> 开始标准化处理")
 
     # Features that go through standardization
@@ -684,8 +800,14 @@ def standardization(df, feature_scaler, target_scaler=None, is_training=True, is
         'flight_1_num': (0, 341),
         'flight_2_num': (0, 341),
         'seats_remaining': (1, 5),
-        'price_change_times_total': (0, 30),     # assume no more than 30 price changes
-        'price_last_change_hours': (0, 480),
+        # 'price_change_times_total': (0, 30),     # assume no more than 30 price changes
+        # 'price_last_change_hours': (0, 480),
+        'price_decrease_times_total': (0, 20),             # assume at most 20 price decreases
+        'price_decrease_times_consecutive': (0, 10),       # assume at most 10 consecutive decreases
+        'price_last_decrease_hours': (0, feature_length),  # (0-240 hours)
+        'price_increase_times_total': (0, 20),             # assume at most 20 price increases
+        'price_increase_times_consecutive': (0, 10),       # assume at most 10 consecutive increases
+        'price_last_increase_hours': (0, feature_length),  # (0-240 hours)
         'price_zone_comprehensive': (0, 5),
         'days_to_departure': (0, 30),
         'days_to_holiday': (0, 120),             # the longest Vietnamese holiday gap is 120 days
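
Two quick editor's sketches for the new logic above (illustrations only, not part of the commit). First, the crop-window arithmetic in preprocess_data_cycle, using the defaults from the new signature:

    # Reproduce the window arithmetic (interval_hours=8, feature_length=240, target_length=24).
    interval_hours, feature_length, target_length = 8, 240, 24
    for crop_lower_limit in [4, 28, 52, 76, 100]:            # the commit currently iterates only [4]
        target_n_hours = crop_lower_limit + target_length    # end of the label window
        feature_n_hours = target_n_hours + interval_hours    # start of the feature window
        crop_upper_limit = feature_n_hours + feature_length  # far edge of the crop
        print(crop_lower_limit, target_n_hours, feature_n_hours, crop_upper_limit)
    # crop_lower_limit=4 gives targets in [4, 28) and features in [36, 276).

Second, the consecutive-streak features on a toy price series, assuming the same VALID_DROP_MIN = 5 threshold and the ffill-over-quiet-rows behaviour of _calc_price_streaks:

    import numpy as np
    import pandas as pd

    prices = pd.Series([100, 110, 121, 120, 90, 80, 80], dtype=float)
    dirs = np.where(prices.diff() >= 5, 1, np.where(prices.diff() <= -5, -1, 0))
    # dirs -> [0, 1, 1, 0, -1, -1, 0]  (the -1 move at index 3 is below the threshold)
    # Expected streaks after the ffill/fillna(0) step:
    #   price_increase_times_consecutive: 0 1 2 2 0 0 0
    #   price_decrease_times_consecutive: 0 0 0 0 1 2 2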

main_pe.py  (+16, -14)

@@ -8,12 +8,12 @@ import time
 import argparse
 from datetime import datetime, timedelta
 from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
-from data_loader import mongo_con_parse, load_train_data
-from data_preprocess import preprocess_data, standardization
+from data_loader import load_train_data
+from data_preprocess import preprocess_data_cycle, standardization
 from utils import chunk_list_with_index, create_fixed_length_sequences
 from model import PriceDropClassifiTransModel
 from predict import predict_future_distribute
-from main_tr import features, categorical_features, target_vars
+from main_tr import features, target_vars
 
 
 def initialize_model():
@@ -104,7 +104,7 @@ def start_predict(interval_hours):
 
     flight_route_list_len = len(flight_route_list)
     route_len_hot = len(vj_flight_route_list_hot)
-    route_len_nothot = len(vj_flight_route_list_nothot)
+    route_len_nothot = len(vj_flight_route_list_nothot[:0])  # exclude cold (unpopular) routes: [:0] keeps none of them
     
     assemble_size = 1           # 几个batch作为一个集群assemble
     current_assembled = -1      # 当前已加载的assemble索引
@@ -166,9 +166,11 @@ def start_predict(interval_hours):
         if filtered_count == 0:
             print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
             continue
+        
+        feature_length = 240
 
         # Data preprocessing
-        df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False, current_n_hours=current_n_hours)
+        df_test_inputs = preprocess_data_cycle(df_test, is_training=False, interval_hours=interval_hours, feature_length=feature_length)
         
         total_rows = df_test_inputs.shape[0]
         print(f"行数: {total_rows}")
@@ -185,22 +187,22 @@ def start_predict(interval_hours):
             continue
 
         # Standardization and normalization
-        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
+        df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False, feature_length=feature_length)
         print("标准化后数据样本:\n", df_test_inputs.head())
 
         threshold = current_n_hours
-        input_length = 444
+        # input_length = 444
 
         # Ensure threshold and input_length sum to 480
-        if threshold == 36:
-            input_length = 444
-        elif threshold == 32:
-            input_length = 448
-        elif threshold == 30:
-            input_length = 450
+        # if threshold == 36:
+        #     input_length = 444
+        # elif threshold == 32:
+        #     input_length = 448
+        # elif threshold == 30:
+        #     input_length = 450
 
         # Build sequences
-        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, threshold, input_length, is_train=False)
+        sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, threshold, feature_length, is_train=False)
         print(f"序列数量:{len(sequences)}")
 
         #----- New: smart model loading -----#
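
An editor's note on the change above: the old code kept the invariant threshold + input_length == 480 (444/448/450 paired with thresholds 36/32/30), while the new code uses a fixed 240-hour feature window [threshold, threshold + feature_length) for every threshold. A runnable check of both schemes:

    # Old scheme: threshold-dependent input_length summing to 480 with the threshold.
    old_lengths = {36: 444, 32: 448, 30: 450}
    assert all(t + l == 480 for t, l in old_lengths.items())

    # New scheme: one fixed-length feature window regardless of threshold.
    feature_length = 240
    for threshold in (36, 32, 30):
        print(threshold, "->", (threshold, threshold + feature_length))   # window [t, t+240)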

main_tr.py  (+25, -16)

@@ -14,8 +14,8 @@ import shutil
 from datetime import datetime, timedelta
 from utils import chunk_list_with_index, create_fixed_length_sequences
 from model import PriceDropClassifiTransModel
-from data_loader import mongo_con_parse, load_train_data
-from data_preprocess import preprocess_data, standardization
+from data_loader import load_train_data
+from data_preprocess import standardization, preprocess_data_cycle
 from train import prepare_data_distribute, train_model_distribute
 from evaluate import printScore_cc
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
@@ -40,7 +40,11 @@ common_features = ['hours_until_departure', 'days_to_departure', 'seats_remainin
                   ]
 price_info_features = ['price_weighted_percentile_25', 'price_weighted_percentile_50', 'price_weighted_percentile_75', 'price_weighted_percentile_90',
                        'price_zone_comprehensive', 'price_relative_position']
-price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
+price_features = ['adult_total_price', 
+                  # 'price_change_times_total', 'price_last_change_hours'
+                  'price_decrease_times_total', 'price_decrease_times_consecutive', 'price_last_decrease_hours',
+                  'price_increase_times_total', 'price_increase_times_consecutive', 'price_last_increase_hours',
+                  ]
 encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
 features = encoded_columns + price_info_features + price_features + common_features
 target_vars = ['target_will_price_drop']   # whether the price will drop
@@ -168,11 +172,11 @@ def start_train():
     batch_idx = -1
     batch_flight_routes = None   # placeholder so other ranks have it defined
 
-    # Main path
-    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
+    # Main path (cold routes excluded)
+    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
     flight_route_list_len = len(flight_route_list)
     route_len_hot = len(vj_flight_route_list_hot)
-    route_len_nothot = len(vj_flight_route_list_nothot)
+    route_len_nothot = len(vj_flight_route_list_nothot[:0])
 
     # Debug code
     # s = 38   # 2025-12-08 is a holiday in the Philippines; s=38 selects Manila
@@ -287,8 +291,11 @@ def start_train():
             elif INTERVAL_HOURS == 2:
                 current_n_hours = 30
             
+            feature_length = 240
+
             # Data preprocessing
-            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True, current_n_hours=current_n_hours)
+            df_train_inputs = preprocess_data_cycle(
+                df_train, is_training=True, interval_hours=INTERVAL_HOURS, feature_length=feature_length)
             print("预处理后数据样本:\n", df_train_inputs.head())
 
             total_rows = df_train_inputs.shape[0]
@@ -299,7 +306,8 @@ def start_train():
                 continue
             
             # Standardization and normalization
-            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs, feature_scaler=None, target_scaler=None)
+            df_train_inputs, feature_scaler, target_scaler = standardization(
+                df_train_inputs, feature_scaler=None, target_scaler=None, feature_length=feature_length)
 
             # Store the scalers in the list
             batch_idx = i
@@ -317,18 +325,19 @@ def start_train():
             print("assemble_idx:", assemble_idx)
             
             threshold = current_n_hours
-            input_length = 444
+            # input_length = 444
 
             # Ensure threshold and input_length sum to 480
-            if threshold == 36:
-                input_length = 444
-            elif threshold == 32:
-                input_length = 448
-            elif threshold == 30:
-                input_length = 450
+            # if threshold == 36:
+            #     input_length = 444
+            # elif threshold == 32:
+            #     input_length = 448
+            # elif threshold == 30:
+            #     input_length = 450
 
             # Build sequences
-            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, threshold, input_length)
+            sequences, targets, group_ids = create_fixed_length_sequences(
+                df_train_inputs, features, target_vars, threshold, feature_length)
             
             # New validity check
             if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
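
The (min, max) pairs passed to standardization (see the data_preprocess.py hunk above) suggest fixed-range scaling for the new price features. A minimal sketch of how such bounds might be applied (an assumption for illustration; the actual scaling code lies outside this diff):

    import pandas as pd

    def scale_with_bounds(col: pd.Series, lo: float, hi: float) -> pd.Series:
        # Hypothetical helper: clip to the assumed bound, then min-max scale to [0, 1].
        return (col.clip(lo, hi) - lo) / (hi - lo)

    # e.g. price_decrease_times_consecutive with its (0, 10) bound from the diff:
    print(scale_with_bounds(pd.Series([0, 3, 12]), 0, 10).tolist())   # [0.0, 0.3, 1.0]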

result_validate.py  (+187, -2)

@@ -1,6 +1,7 @@
 import os
 import datetime
 import pandas as pd
+import argparse
 from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
 
 
@@ -118,7 +119,191 @@ def validate_process(node, interval_hours, pred_time_str):
     df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
     print(f"保存完成: {output_path}")
 
+def validate_process_auto(node, interval_hours):
+    '''Automatic validation routine'''
+    # Current time, floored to the hour
+    current_time = datetime.datetime.now()
+    current_time_str = current_time.strftime("%Y%m%d%H%M")
+    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
+    vali_time_str = hourly_time.strftime("%Y%m%d%H%M")
+    print(f"Validation time: {current_time_str}, (floored to hour): {vali_time_str}")
+
+    output_dir = f"./validate/{node}"
+    os.makedirs(output_dir, exist_ok=True)
+
+    object_dir = "./predictions"
+    if interval_hours == 4:
+        object_dir = "./predictions_4"
+    elif interval_hours == 2:
+        object_dir = "./predictions_2"
+    
+    # Check that the directory exists
+    if not os.path.exists(object_dir):
+        print(f"Directory does not exist: {object_dir}")
+        return
+    
+    # Collect all CSV files whose names start with future_predictions_
+    csv_files = []
+    for file in os.listdir(object_dir):
+        if file.startswith("future_predictions_") and file.endswith(".csv"):
+            csv_files.append(file)
+    
+    if not csv_files:
+        print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
+        return
+    
+    # Parse each timestamp into a datetime object
+    file_times = []
+    for file in csv_files:
+        # Extract the timestamp part: future_predictions_202601151600.csv -> 202601151600
+        timestamp_str = file.replace("future_predictions_", "").replace(".csv", "")
+        try:
+            # Convert the timestamp to a datetime object
+            file_time = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M")
+            file_times.append((file, file_time))
+        except ValueError as e:
+            print(f"Bad timestamp format in file {file}: {e}")
+            continue
+    
+    if not file_times:
+        print("没有找到有效的时间戳文件")
+        return
+    
+    # Compute the corresponding time yesterday
+    yesterday_time = hourly_time - datetime.timedelta(hours=24)
+    print(f"Corresponding time yesterday: {yesterday_time.strftime('%Y%m%d%H%M')}")
+
+    # Keep files earlier than yesterday's corresponding time, sorted by time
+    valid_files = [(f, t) for f, t in file_times if t < yesterday_time]
+    valid_files.sort(key=lambda x: x[1])  # ascending by time
+
+    if not valid_files:
+        print(f"没有找到小于昨天对应时间 {yesterday_time.strftime('%Y%m%d%H%M')} 的文件")
+        return
+    
+    # Take the last file earlier than yesterday's corresponding time
+    last_valid_file, last_valid_time = valid_files[-1]
+    last_valid_time_str = last_valid_time.strftime("%Y%m%d%H%M")
+    print(f"Found qualifying file: {last_valid_file} (time: {last_valid_time_str})")
+
+    csv_path = os.path.join(object_dir, last_valid_file)
+
+    # Start validation
+    try:
+        df_predict = pd.read_csv(csv_path)
+    except Exception as e:
+        print(f"read {csv_path} error: {str(e)}")
+        df_predict = pd.DataFrame()
+    
+    if df_predict.empty:
+        print(f"预测数据为空")
+        return
+    
+    client, db = mongo_con_parse()
+
+    count = 0
+    for idx, row in df_predict.iterrows(): 
+        city_pair = row['city_pair']
+        flight_day = row['flight_day']
+        flight_number_1 = row['flight_number_1']
+        flight_number_2 = row['flight_number_2']
+        baggage = row['baggage']
+        valid_begin_hour = row['valid_begin_hour']
+        df_val = validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour)
+        # At validation time the database may not yet have data after valid_begin_hour
+        if not df_val.empty:
+            df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
+            df_val_f = df_val_f[df_val_f['is_filled']==0]    # keep only original rows, not back-filled ones
+            if df_val_f.empty:
+                drop_flag = 0
+                first_drop_amount = pd.NA
+                first_drop_hours = pd.NA
+                last_hours_util = pd.NA
+                last_update_hour = pd.NA
+                list_change_price = []
+                list_change_hours = []
+            else:
+                # Last row of the valid data
+                last_row = df_val_f.iloc[-1]
+                last_hours_util = last_row['hours_until_departure']
+                last_update_hour = last_row['update_hour']
+                
+                # Keep only rows where the price changed
+                df_price_changes = df_val_f.loc[
+                    df_val_f["adult_total_price"].shift() != df_val_f["adult_total_price"]
+                ].copy()
+            
+                # Price-change magnitude
+                df_price_changes['change_amount'] = df_price_changes['adult_total_price'].diff().fillna(0)
+
+                # Find the first row with change_amount below -10
+                first_negative_change = df_price_changes[df_price_changes['change_amount'] < -10].head(1)
+
+                # Extract the required values
+                if not first_negative_change.empty:
+                    drop_flag = 1
+                    first_drop_amount = first_negative_change['change_amount'].iloc[0].round(2)
+                    first_drop_hours = first_negative_change['hours_until_departure'].iloc[0]
+                else:
+                    drop_flag = 0
+                    first_drop_amount = pd.NA
+                    first_drop_hours = pd.NA
+                    
+                list_change_price = df_price_changes['adult_total_price'].tolist()
+                list_change_hours = df_price_changes['hours_until_departure'].tolist()
+        
+        else:
+            drop_flag = 0
+            first_drop_amount = pd.NA
+            first_drop_hours = pd.NA
+            last_hours_util = pd.NA
+            last_update_hour = pd.NA
+            list_change_price = []
+            list_change_hours = []
+
+        safe_sep = "; "
+        
+        df_predict.at[idx, 'change_prices'] = safe_sep.join(map(str, list_change_price))
+        df_predict.at[idx, 'change_hours'] = safe_sep.join(map(str, list_change_hours))
+        df_predict.at[idx, 'last_hours_util'] = last_hours_util
+        df_predict.at[idx, 'last_update_hour'] = last_update_hour
+        df_predict.at[idx, 'first_drop_amount'] = first_drop_amount * -1  # flip negative to positive
+        df_predict.at[idx, 'first_drop_hours'] = first_drop_hours
+        df_predict.at[idx, 'drop_flag'] = drop_flag
+
+        count += 1
+        if count % 5 == 0:
+            print(f"cal count: {count}")
+    
+    print(f"计算结束")
+    client.close()
+
+    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    save_csv = f"result_validate_{node}_{interval_hours}_{last_valid_time_str}_{timestamp_str}.csv"
+
+    output_path = os.path.join(output_dir, save_csv)
+    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
+    print(f"Saved: {output_path}")
+    print(f"Validation done: {node} {interval_hours} {last_valid_time_str}")
+    print()
 
 if __name__ == "__main__":
-    node, interval_hours, pred_time_str = "node0112", 8, "202601141600"
-    validate_process(node, interval_hours, pred_time_str)
+    parser = argparse.ArgumentParser(description='Validation script')
+    parser.add_argument('--interval', type=int, choices=[2, 4, 8],
+                        default=0, help='interval hours (2, 4, 8); default 0 runs manual validation')
+    args = parser.parse_args()
+    interval_hours = args.interval
+
+    # 0: manual validation
+    if interval_hours == 0:
+        node, interval_hours, pred_time_str = "node0112", 8, "202601151600"
+        validate_process(node, interval_hours, pred_time_str)
+    # Automatic validation
+    else:
+        # this node can be changed manually as needed
+        node = "node0112"
+        if interval_hours == 4:
+            node = "node0114"
+        if interval_hours == 2:
+            node = "node0115"
+        validate_process_auto(node, interval_hours)
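
The file-selection rule in validate_process_auto (take the newest future_predictions_*.csv strictly older than the same hour yesterday) can be exercised standalone; the timestamps below are illustrative only:

    import datetime

    names = ["future_predictions_202601141600.csv",
             "future_predictions_202601151600.csv",
             "future_predictions_202601161200.csv"]
    now = datetime.datetime(2026, 1, 16, 17, 30)
    cutoff = now.replace(minute=0, second=0, microsecond=0) - datetime.timedelta(hours=24)
    stamps = [(n, datetime.datetime.strptime(
        n.replace("future_predictions_", "").replace(".csv", ""), "%Y%m%d%H%M")) for n in names]
    valid = sorted([(n, t) for n, t in stamps if t < cutoff], key=lambda x: x[1])
    print(valid[-1][0])   # future_predictions_202601151600.csv (cutoff is 2026-01-15 17:00)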

utils.py  (+7, -7)

@@ -28,7 +28,7 @@ def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
     return df
 
 # The actual sequence-building routine
-def create_fixed_length_sequences(df, features, target_vars, threshold=36, input_length=444, is_train=True):
+def create_fixed_length_sequences(df, features, target_vars, threshold=36, feature_length=240, is_train=True):
     print(">>开始创建序列")
     start_time = time.time()
 
@@ -49,20 +49,20 @@ def create_fixed_length_sequences(df, features, target_vars, threshold=36, input
         df_group_bag_30 = df_group[df_group['baggage']==30]
         df_group_bag_20 = df_group[df_group['baggage']==20]
 
-        # Filter the training time window (36 ~ 480)
-        df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
-        df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
+        # Filter the training time window (36 ~ 36 + 240)
+        df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + feature_length)]
+        df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + feature_length)]
 
         # Condition: both lengths must match
         condition_list = [
-            len(df_group_bag_30_filtered) == input_length,
-            len(df_group_bag_20_filtered) == input_length,
+            len(df_group_bag_30_filtered) == feature_length,
+            len(df_group_bag_20_filtered) == feature_length,
         ]
         if all(condition_list):
             seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
             seq_features_2 = df_group_bag_20_filtered[features].to_numpy()
             
-            # Stack the feature sequences along dim 0, giving shape (2, 444, 31)
+            # Stack the feature sequences along dim 0, giving shape (2, 240, 35)
             combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),    
                                              torch.tensor(seq_features_2, dtype=torch.float32)])