Browse files

Add statistics-based fast model training and prediction code

node04 2 weeks ago
parent
commit
f5374dda0a
4 changed files with 358 additions and 45 deletions
  1. data_loader.py  +4 -4
  2. data_preprocess.py  +297 -6
  3. main_pe_0.py  +19 -9
  4. main_tr_0.py  +38 -26

+ 4 - 4
data_loader.py

@@ -332,8 +332,8 @@ def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
     # Suppose you want to keep the earliest record
     df = df.sort_values(['update_hour', 'crawl_date'])
 
-    # 3. Dedup by hour, keeping the earliest record within the hour
-    df = df.drop_duplicates(subset=['update_hour'], keep='first')
+    # 3. Dedup by hour, keeping the latest record within the hour
+    df = df.drop_duplicates(subset=['update_hour'], keep='last')   # changed keep='first' to keep='last'
 
     # 删除原始时间戳列
     # df = df.drop(columns=['crawl_date'])
@@ -1005,7 +1005,7 @@ if __name__ == "__main__":
     os.makedirs(output_dir, exist_ok=True)
 
     # Load popular-route data
-    date_begin = "2026-01-08"
+    date_begin = "2026-01-15"
     date_end = datetime.today().strftime("%Y-%m-%d")
 
     flight_route_list = vj_flight_route_list_hot[:]  # popular: vj_flight_route_list_hot, unpopular: vj_flight_route_list_nothot
@@ -1019,7 +1019,7 @@ if __name__ == "__main__":
         # client, db = mongo_con_parse()
         print(f"Group {idx}:", group_route_list)
         start_time = time.time()
-        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=False,
+        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=True,
                         use_multiprocess=True, max_workers=max_workers)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
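
A minimal sketch of the dedup behavior this hunk switches to, on toy data (column names follow the diff; the values are invented): after the ascending sort by update_hour and crawl_date, keep='last' retains the latest crawl within each hour instead of the earliest.

    import pandas as pd

    # Toy frame with two crawls falling into the same hour (invented values)
    df = pd.DataFrame({
        'update_hour': ['2026-01-15 10:00', '2026-01-15 10:00', '2026-01-15 11:00'],
        'crawl_date':  ['2026-01-15 10:05', '2026-01-15 10:50', '2026-01-15 11:20'],
    })
    df = df.sort_values(['update_hour', 'crawl_date'])

    # keep='first' would retain the 10:05 crawl; keep='last' retains 10:50 for the 10:00 hour
    df = df.drop_duplicates(subset=['update_hour'], keep='last')
    print(df)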

+ 297 - 6
data_preprocess.py

@@ -2,6 +2,7 @@ import pandas as pd
 import numpy as np
 import bisect
 import gc
+import os
 from datetime import datetime, timedelta
 from sklearn.preprocessing import StandardScaler
 from config import city_to_country, vj_city_code_map, vi_flight_number_map, build_country_holidays
@@ -833,7 +834,7 @@ def standardization(df, feature_scaler, target_scaler=None, is_training=True, is
     return df, feature_scaler, target_scaler
 
 
-def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
+def preprocess_data_simple(df_input, is_train=False):
 
     df_input = preprocess_data_first_half(df_input)
     
@@ -884,6 +885,7 @@ def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
     df_input['Hours_Until_Departure'] = hours_until
     df_input['Baggage'] = df_input['baggage']
 
+    # Training flow
     if is_train:
         df_target = df_input[(df_input['hours_until_departure'] >= 18) & (df_input['hours_until_departure'] <= 54)].copy()
         df_target = df_target.sort_values(
@@ -891,6 +893,7 @@ def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
             ascending=[True, False]
         ).reset_index(drop=True)
 
+        # Analysis of rise-then-drop sequences
         prev_pct = df_target.groupby('gid', group_keys=False)['price_change_percent'].shift(1)
         prev_amo = df_target.groupby('gid', group_keys=False)['price_change_amount'].shift(1)
         prev_dur = df_target.groupby('gid', group_keys=False)['price_duration_hours'].shift(1)
@@ -907,11 +910,13 @@ def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
         df_drop_nodes = df_drop_nodes.reset_index(drop=True)
 
         flight_info_cols = [
-            'city_pair', 
+            'city_pair',
             'flight_number_1', 'seg1_dep_air_port', 'seg1_dep_time', 'seg1_arr_air_port', 'seg1_arr_time',
             'flight_number_2', 'seg2_dep_air_port', 'seg2_dep_time', 'seg2_arr_air_port', 'seg2_arr_time',
             'currency', 'baggage', 'flight_day',
         ]
+
+        flight_info_cols = [c for c in flight_info_cols if c in df_target.columns]
         
         df_gid_info = df_target[['gid'] + flight_info_cols].drop_duplicates(subset=['gid']).reset_index(drop=True)
         df_drop_nodes = df_drop_nodes.merge(df_gid_info, on='gid', how='left')
@@ -920,12 +925,298 @@ def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
                           'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount'
         ]
         # Arrange columns in order, dropping gid
-        order_columns = flight_info_cols + drop_info_cols
-        df_drop_nodes = df_drop_nodes[order_columns]
+        df_drop_nodes = df_drop_nodes[flight_info_cols + drop_info_cols]
+
+        # Analyze gids that never rose then dropped
+        gids_with_drop = df_target.loc[drop_mask, 'gid'].unique()
+        df_no_drop = df_target[~df_target['gid'].isin(gids_with_drop)].copy()
+
+        keep_info_cols = [
+            'keep_hours_until_departure', 'keep_price_change_percent', 'keep_price_change_amount', 'keep_price_duration_hours'
+        ]
+        
+        if df_no_drop.empty:
+            df_keep_nodes = pd.DataFrame(columns=flight_info_cols + keep_info_cols)
+        else:
+            df_no_drop = df_no_drop.sort_values(
+                by=['gid', 'hours_until_departure'],
+                ascending=[True, False]
+            ).reset_index(drop=True)
+
+            df_no_drop['keep_segment'] = df_no_drop.groupby('gid')['price_change_percent'].transform(
+                lambda s: (s != s.shift()).cumsum()
+            )
+
+            df_keep_row = (
+                df_no_drop.groupby(['gid', 'keep_segment'], as_index=False)
+                .tail(1)
+                .reset_index(drop=True)
+            )
+
+            df_keep_nodes = df_keep_row[
+                ['gid', 'hours_until_departure', 'price_change_percent', 'price_change_amount', 'price_duration_hours']
+            ].copy()
+            df_keep_nodes.rename(
+                columns={
+                    'hours_until_departure': 'keep_hours_until_departure',
+                    'price_change_percent': 'keep_price_change_percent',
+                    'price_change_amount': 'keep_price_change_amount',
+                    'price_duration_hours': 'keep_price_duration_hours',
+                },
+                inplace=True,
+            )
+
+            df_keep_nodes = df_keep_nodes.merge(df_gid_info, on='gid', how='left')
+            df_keep_nodes = df_keep_nodes[flight_info_cols + keep_info_cols]
+
+            del df_keep_row
         
         del df_gid_info
         del df_target
+        del df_no_drop
+
+        return df_input, df_drop_nodes, df_keep_nodes
+
+    return df_input, None, None
+
+
+def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".", pred_time_str=""):
+    if df_input is None or df_input.empty:
+        return pd.DataFrame()
+    
+    df_sorted = df_input.sort_values(
+        by=['gid', 'hours_until_departure'],
+        ascending=[True, False],
+    ).reset_index(drop=True)
+
+    df_sorted = df_sorted[
+        df_sorted['hours_until_departure'].between(18, 54)
+    ].reset_index(drop=True)
+
+    # For each gid, take the row with the smallest hours_until_departure
+    df_min_hours = (
+        df_sorted.drop_duplicates(subset=['gid'], keep='last')
+        .reset_index(drop=True)
+    )
+
+    # Ensure hours_until_departure falls within [18, 54]
+    # df_min_hours = df_min_hours[
+    #     df_min_hours['hours_until_departure'].between(18, 54)
+    # ].reset_index(drop=True)
+
+    drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
+    if os.path.exists(drop_info_csv_path):
+        df_drop_nodes = pd.read_csv(drop_info_csv_path)
     else:
-        df_drop_nodes = None
+        df_drop_nodes = pd.DataFrame()
+
+    keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
+    if os.path.exists(keep_info_csv_path):
+        df_keep_nodes = pd.read_csv(keep_info_csv_path)
+    else:
+        df_keep_nodes = pd.DataFrame()
+
+    df_min_hours['simple_will_price_drop'] = -1   # -1 means unknown
+    df_min_hours['simple_drop_in_hours'] = 0
+    df_min_hours['simple_drop_in_hours_prob'] = 0.0
+    df_min_hours['simple_drop_in_hours_dist'] = ''
+    
+    # What value should this threshold take?
+    pct_threshold = 0.01
+    # pct_threshold = 2
+    pct_threshold_1 = 0.001
+    pct_threshold_c = 0.001
+
+    for idx, row in df_min_hours.iterrows(): 
+        city_pair = row['city_pair']
+        flight_number_1 = row['flight_number_1']
+        flight_number_2 = row['flight_number_2']
+        price_change_percent = row['price_change_percent']
+        price_duration_hours = row['price_duration_hours']
+        hours_until_departure = row['hours_until_departure']
+        # Handle historical high-price -> low-price transitions
+        if not df_drop_nodes.empty:
+            # Match on flight number across different departure dates
+            if flight_number_2 and flight_number_2 != 'VJ':
+                df_drop_nodes_part = df_drop_nodes[
+                    (df_drop_nodes['city_pair'] == city_pair) &
+                    (df_drop_nodes['flight_number_1'] == flight_number_1) &
+                    (df_drop_nodes['flight_number_2'] == flight_number_2)
+                ]
+            else:
+                df_drop_nodes_part = df_drop_nodes[
+                    (df_drop_nodes['city_pair'] == city_pair) &
+                    (df_drop_nodes['flight_number_1'] == flight_number_1)
+                ]
+            
+            # Match the pre-drop increase against a threshold, then use historical high-price durations to derive drop-timing probabilities
+            if not df_drop_nodes_part.empty and pd.notna(price_change_percent):   
+                # Drop rows whose increase is too small
+                df_drop_nodes_part = df_drop_nodes_part[df_drop_nodes_part['high_price_change_percent'] >= 0.1]
+                # pct_vals = df_drop_nodes_part['high_price_change_percent'].replace([np.inf, -np.inf], np.nan).dropna()
+                # # Keep data between the 10th and 90th percentiles
+                # if not pct_vals.empty:
+                #     q10 = float(pct_vals.quantile(0.10))
+                #     q90 = float(pct_vals.quantile(0.90))
+                #     df_drop_nodes_part = df_drop_nodes_part[
+                #         df_drop_nodes_part['high_price_change_percent'].between(q10, q90)
+                #     ]
+                # if df_drop_nodes_part.empty:
+                #     continue
+                pct_diff = (df_drop_nodes_part['high_price_change_percent'] - float(price_change_percent)).abs()
+                df_match = df_drop_nodes_part.loc[pct_diff <= pct_threshold, ['high_price_duration_hours', 'high_price_change_percent']].copy()
+
+                if not df_match.empty and pd.notna(price_duration_hours):
+                    remaining_hours = (df_match['high_price_duration_hours'] - float(price_duration_hours)).clip(lower=0)
+                    remaining_hours = remaining_hours.round().astype(int)
+
+                    counts = remaining_hours.value_counts().sort_index()
+                    probs = (counts / counts.sum()).round(4)
+
+                    top_hours = int(probs.idxmax())
+                    top_prob = float(probs.max())
+
+                    dist_items = list(zip(probs.index.tolist(), probs.tolist()))
+                    dist_items = dist_items[:10]
+                    dist_str = ' | '.join([f"{int(h)}:{float(p)}" for h, p in dist_items])
+
+                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
+                    df_min_hours.loc[idx, 'simple_drop_in_hours'] = top_hours
+                    df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = top_prob
+                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = dist_str
+
+                    continue   # already judged as a drop; skip the remaining checks
+
+        # Handle historical patterns: always low, always high, low -> high, consecutive lows, etc.
+        if not df_keep_nodes.empty:
+            # Match on flight number across different departure dates
+            if flight_number_2 and flight_number_2 != 'VJ':
+                df_keep_nodes_part = df_keep_nodes[
+                    (df_keep_nodes['city_pair'] == city_pair) &
+                    (df_keep_nodes['flight_number_1'] == flight_number_1) &
+                    (df_keep_nodes['flight_number_2'] == flight_number_2)
+                ]
+            else:
+                df_keep_nodes_part = df_keep_nodes[
+                    (df_keep_nodes['city_pair'] == city_pair) &
+                    (df_keep_nodes['flight_number_1'] == flight_number_1)
+                ]
+
+            if not df_keep_nodes_part.empty and pd.notna(price_change_percent):
+                # pct_vals_1 = df_keep_nodes_part['keep_price_change_percent'].replace([np.inf, -np.inf], np.nan).dropna()
+                # # Keep data between the 10th and 90th percentiles
+                # if not pct_vals_1.empty:
+                #     q10_1 = float(pct_vals_1.quantile(0.10))
+                #     q90_1 = float(pct_vals_1.quantile(0.90))
+                #     df_keep_nodes_part = df_keep_nodes_part[
+                #         df_keep_nodes_part['keep_price_change_percent'].between(q10_1, q90_1)
+                #     ]
+                # if df_keep_nodes_part.empty:
+                #     continue
+                
+                # Special decision scenario
+                if price_change_percent < 0:
+
+                    df_tmp = df_keep_nodes_part.copy()
+                    # Ensure correct in-group order (can be omitted if already sorted upstream)
+                    df_tmp = df_tmp.sort_values(
+                        by=["flight_day", "keep_hours_until_departure"],
+                        ascending=[True, False]
+                    )
+                    # Flag negative values
+                    df_tmp["is_negative"] = df_tmp["keep_price_change_percent"] < 0
+                    
+                    if df_tmp["is_negative"].any():
+                        # Mark the start of each "negative block":
+                        # a new block begins where is_negative is True and the previous row is not negative
+                        df_tmp["neg_block_id"] = (
+                            df_tmp["is_negative"]
+                            & ~df_tmp.groupby("flight_day")["is_negative"].shift(fill_value=False)
+                        ).groupby(df_tmp["flight_day"]).cumsum()
+                        # Rank within each negative block (which negative this is)
+                        df_tmp["neg_rank_in_block"] = (
+                            df_tmp.groupby(["flight_day", "neg_block_id"])
+                            .cumcount() + 1
+                        )
+                        # Length of each consecutive negative block
+                        df_tmp["neg_block_size"] = (
+                            df_tmp.groupby(["flight_day", "neg_block_id"])["is_negative"]
+                            .transform("sum")
+                        )
+                        # Keep only rows that:
+                        # 1) are negative, and
+                        # 2) are not the last one in their consecutive negative block
+                        df_continuous_price_drop = df_tmp[
+                            (df_tmp["is_negative"]) &
+                            (df_tmp["neg_rank_in_block"] < df_tmp["neg_block_size"])
+                        ].drop(
+                            columns=[
+                                "is_negative",
+                                "neg_block_id",
+                                "neg_rank_in_block",
+                                "neg_block_size",
+                            ]
+                        )
+                        pct_diff_c = (df_continuous_price_drop['keep_price_change_percent'] - float(price_change_percent)).abs()
+                        df_match_c = df_continuous_price_drop.loc[pct_diff_c <= pct_threshold_c, ['flight_day', 'keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent']].copy()
+
+                        # Meets the consecutive-price-drop condition
+                        if not df_match_c.empty and pd.notna(price_duration_hours):
+                            vals_c = df_match_c['keep_price_duration_hours'].replace([np.inf, -np.inf], np.nan).dropna()
+                            if not vals_c.empty:
+                                min_val = vals_c.min()
+                                if min_val <= float(price_duration_hours):
+                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
+                                    df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
+                                    df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.5
+                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = ''
+                                    continue
+
+                # General decision scenario
+                pct_diff_1 = (df_keep_nodes_part['keep_price_change_percent'] - float(price_change_percent)).abs()
+                df_match_1 = df_keep_nodes_part.loc[pct_diff_1 <= pct_threshold_1, ['flight_day', 'keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent']].copy()
+
+                if not df_match_1.empty and pd.notna(price_duration_hours):
+                    
+                    df_match_1['hours_delta'] = hours_until_departure - df_match_1['keep_hours_until_departure']
+                    df_match_1['modify_keep_price_duration_hours'] = df_match_1['keep_price_duration_hours'] - df_match_1['hours_delta']
+                    df_match_1 = df_match_1[df_match_1['modify_keep_price_duration_hours'] > 0]
+
+                    # Compare price_duration_hours against the percentiles of modify_keep_price_duration_hours
+                    vals = df_match_1['modify_keep_price_duration_hours'].replace([np.inf, -np.inf], np.nan).dropna()
+                    if not vals.empty:
+                        q10_11 = float(vals.quantile(0.10))
+                        # q90_11 = float(vals.quantile(0.90))
+                        if q10_11 <= float(price_duration_hours):
+                            df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
+                            df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
+                            df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
+                            df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = ''
+    
+    df_min_hours = df_min_hours.rename(columns={'seg1_dep_time': 'from_time'})
+    _pred_dt = pd.to_datetime(str(pred_time_str), format="%Y%m%d%H%M", errors="coerce")
+    df_min_hours["update_hour"] = _pred_dt
+    _dep_hour = pd.to_datetime(df_min_hours["from_time"], errors="coerce").dt.floor("h")
+    df_min_hours["valid_begin_hour"] = _dep_hour - pd.to_timedelta(54, unit="h")
+    df_min_hours["valid_end_hour"] = _dep_hour - pd.to_timedelta(18, unit="h")
+
+    order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 'baggage', 'currency', 
+                  'adult_total_price', 'hours_until_departure', 'price_change_percent', 'price_duration_hours',
+                  'update_hour', 'crawl_date',
+                  'valid_begin_hour', 'valid_end_hour',
+                  'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist'
+                 ]
+    df_predict = df_min_hours[order_cols]
+    df_predict = df_predict.rename(columns={
+            'simple_will_price_drop': 'will_price_drop',
+            'simple_drop_in_hours': 'drop_in_hours',
+            'simple_drop_in_hours_prob': 'drop_in_hours_prob',
+            'simple_drop_in_hours_dist': 'drop_in_hours_dist',
+        }
+    )
+
+    csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
+    df_predict.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
 
-    return df_input, df_drop_nodes
+    print("Prediction results appended")
+    return df_predict
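
Two standalone sketches of the statistics this function relies on, with invented input values. First, the drop-timing estimate: the matched historical high-price durations, less the hours the current high price has already lasted, are binned into a discrete distribution whose mode becomes the predicted hours until the drop.

    import pandas as pd

    # Invented remaining high-price durations (hours) from matched history
    remaining_hours = pd.Series([4.6, 5.2, 5.4, 12.0, 5.0]).clip(lower=0).round().astype(int)

    counts = remaining_hours.value_counts().sort_index()
    probs = (counts / counts.sum()).round(4)

    top_hours = int(probs.idxmax())  # most likely hours until the drop -> 5
    top_prob = float(probs.max())    # its probability -> 0.8
    dist_str = ' | '.join(f"{int(h)}:{float(p)}" for h, p in zip(probs.index, probs))
    print(top_hours, top_prob, dist_str)  # 5 0.8 5:0.8 | 12:0.2

Second, the consecutive-negative-block detection from the special scenario: block starts are flagged where a negative row follows a non-negative one, a cumulative sum turns the flags into block ids, and the last negative of each run is excluded (continuing from the import above; the price sequence is invented).

    # Invented keep_price_change_percent sequence for a single flight_day
    df = pd.DataFrame({
        'flight_day': ['2026-01-10'] * 5,
        'keep_price_change_percent': [-1.0, -2.0, 3.0, -4.0, -5.0],
    })
    df['is_negative'] = df['keep_price_change_percent'] < 0
    # A new block starts where a negative follows a non-negative (or the sequence start)
    df['neg_block_id'] = (
        df['is_negative'] & ~df.groupby('flight_day')['is_negative'].shift(fill_value=False)
    ).groupby(df['flight_day']).cumsum()
    df['neg_rank_in_block'] = df.groupby(['flight_day', 'neg_block_id']).cumcount() + 1
    df['neg_block_size'] = df.groupby(['flight_day', 'neg_block_id'])['is_negative'].transform('sum')

    # Negatives that are not the last of their run: the -1.0 and -4.0 rows survive
    print(df[df['is_negative'] & (df['neg_rank_in_block'] < df['neg_block_size'])])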

+ 19 - 9
main_pe_0.py

@@ -3,7 +3,7 @@ import time
 from datetime import datetime, timedelta
 from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
 from data_loader import load_train_data
-from data_preprocess import preprocess_data_simple
+from data_preprocess import preprocess_data_simple, predict_data_simple
 from utils import chunk_list_with_index
 
 
@@ -11,12 +11,12 @@ def start_predict():
     print("Starting prediction")
 
     output_dir = "./data_shards_0"
-    photo_dir = "./photo_0"
+    # photo_dir = "./photo_0"
     predict_dir = "./predictions_0"
 
     # Ensure the directories exist
     os.makedirs(output_dir, exist_ok=True) 
-    os.makedirs(photo_dir, exist_ok=True)
+    # os.makedirs(photo_dir, exist_ok=True)
     os.makedirs(predict_dir, exist_ok=True)        
 
     cpu_cores = os.cpu_count()  # 72 on this system
@@ -29,6 +29,15 @@ def start_predict():
     hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
     print(f"Prediction time: {current_time_str}, rounded to the hour: {hourly_time_str}")
 
+    # Clear the previous prediction results from the same hour
+    csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
+    for csv_file in csv_file_list:
+        try:
+            csv_path = os.path.join(predict_dir, csv_file)
+            os.remove(csv_path)
+        except Exception as e:
+            print(f"remove {csv_path} info: {str(e)}")
+
     # Prediction window: departure time between 18 and 54 hours from now
     pred_hour_begin = hourly_time + timedelta(hours=18)
     pred_hour_end = hourly_time + timedelta(hours=54)
@@ -61,6 +70,7 @@ def start_predict():
         # client, db = mongo_con_parse()
         print(f"Group {i}:", group_route_list)
         # batch_flight_routes = group_route_list
+        group_route_str = ','.join(group_route_list)
 
         # Decide popular vs. unpopular based on index position
         if 0 <= i < route_len_hot:
@@ -101,16 +111,16 @@ def start_predict():
             print(f"No flights departing between {pred_hour_begin} and {pred_hour_end}; skipping this batch.")
             continue
         
-        df_test_inputs = preprocess_data_simple(df_test)
+        df_test_inputs, _, _ = preprocess_data_simple(df_test)
 
-        # Save temporary file
-        csv_path = os.path.join(output_dir, f'temp.csv')
-        df_test_inputs.to_csv(csv_path, mode='a', index=False, header=not os.path.exists(csv_path), encoding='utf-8-sig')
+        df_predict = predict_data_simple(df_test_inputs, group_route_str, output_dir, predict_dir, hourly_time_str)
         
         del df_test_inputs
-        pass
+        del df_predict
+        time.sleep(1)
 
-    pass
+    print("All batches' predictions finished")
+    print()
 
 if __name__ == "__main__":
     start_predict()
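
A standalone sketch of the same-hour cleanup added above, assuming hourly_time is obtained by truncating the current time to the hour (the truncation itself is not shown in this hunk); the directory and file-name pattern follow the diff, the clock time is invented.

    import os
    from datetime import datetime

    predict_dir = "./predictions_0"
    now = datetime(2026, 1, 15, 10, 37)                   # invented clock time
    hourly_time = now.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")  # "202601151000"

    # Remove a same-hour file from an earlier run so this hour's appends start fresh
    csv_path = os.path.join(predict_dir, f"future_predictions_{hourly_time_str}.csv")
    try:
        os.remove(csv_path)
    except FileNotFoundError:
        pass  # first run in this hour: nothing to remove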

+ 38 - 26
main_tr_0.py

@@ -8,6 +8,31 @@ from data_preprocess import preprocess_data_simple
 from utils import chunk_list_with_index
 
 
+def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
+    key_cols = [c for c in dedup_cols if c in df_new.columns]
+
+    # Re-training some days later: if a flight_day (together with the flight key) in this df_new already appears in the historical CSV df_old,
+    # treat that day as already processed and append none of its nodes; only append data whose flight_day (with the same flight key) is absent from the history
+    if os.path.exists(csv_path):
+        df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
+        if key_cols and all(c in df_old.columns for c in key_cols):
+            df_old_keys = df_old[key_cols].drop_duplicates()
+            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
+            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
+        else:
+            df_add = df_new.copy()
+        df_merged = pd.concat([df_old, df_add], ignore_index=True)
+    # First training run: keep everything as-is, no dedup
+    else:
+        df_merged = df_new.copy()
+
+    sort_cols = [c for c in dedup_cols if c in df_merged.columns]
+    if sort_cols:
+        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
+
+    df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')
+
+
 def start_train():
     print("Starting training")
 
@@ -23,8 +48,8 @@ def start_train():
 
     # date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
-    date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y-%m-%d")
-    # date_begin = "2025-12-01"
+    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
+    date_begin = "2025-12-01"
 
     print(f"Training date range: {date_begin} to {date_end}")
 
@@ -80,37 +105,24 @@ def start_train():
             print("Training data is empty; skipping this batch.")
             continue
 
-        _, df_drop_nodes = preprocess_data_simple(df_train, is_train=True)
+        _, df_drop_nodes, df_keep_nodes = preprocess_data_simple(df_train, is_train=True)
+
+        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
 
         drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
         if df_drop_nodes.empty:
             print(f"df_drop_nodes is empty; skipping save: {drop_info_csv_path}")
-            continue
-
-        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
-        key_cols = [c for c in dedup_cols if c in df_drop_nodes.columns]
-
-        # Re-training some days later: if a flight_day (together with the flight key) in this df_drop_nodes already appears in the historical CSV,
-        # treat that day as already processed and append none of its nodes; only append data whose flight_day (with the same flight key) is absent from the history
-        if os.path.exists(drop_info_csv_path):
-            df_old = pd.read_csv(drop_info_csv_path, encoding='utf-8-sig')
-            if key_cols and all(c in df_old.columns for c in key_cols):
-                df_old_keys = df_old[key_cols].drop_duplicates()
-                df_new = df_drop_nodes.merge(df_old_keys, on=key_cols, how='left', indicator=True)
-                df_new = df_new[df_new['_merge'] == 'left_only'].drop(columns=['_merge'])
-            else:
-                df_new = df_drop_nodes.copy()
-            df_merged = pd.concat([df_old, df_new], ignore_index=True)
-        # First training run: keep everything as-is, no dedup on dedup_cols
         else:
-            df_merged = df_drop_nodes.copy()
+            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
+            print(f"Saved this batch's training CSV: {drop_info_csv_path}")
 
-        sort_cols = [c for c in dedup_cols if c in df_merged.columns]
-        if sort_cols:
-            df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
-        df_merged.to_csv(drop_info_csv_path, index=False, encoding='utf-8-sig')
+        keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
+        if df_keep_nodes.empty:
+            print(f"df_keep_nodes is empty; skipping save: {keep_info_csv_path}")
+        else:
+            merge_and_overwrite_csv(df_keep_nodes, keep_info_csv_path, dedup_cols)
+            print(f"Saved this batch's training CSV: {keep_info_csv_path}")
 
-        print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
         time.sleep(1)
 
     print("All batches' training completed")
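
The dedup-and-append logic in merge_and_overwrite_csv is a standard anti-join via merge(..., indicator=True); a self-contained sketch with invented rows and a shortened key set:

    import pandas as pd

    key_cols = ['city_pair', 'flight_number_1', 'flight_day']   # subset of dedup_cols
    df_old = pd.DataFrame({'city_pair': ['HAN-SGN'], 'flight_number_1': ['VJ123'],
                           'flight_day': ['2026-01-10'], 'x': [1.0]})
    df_new = pd.DataFrame({'city_pair': ['HAN-SGN', 'HAN-SGN'],
                           'flight_number_1': ['VJ123', 'VJ123'],
                           'flight_day': ['2026-01-10', '2026-01-11'], 'x': [2.0, 3.0]})

    # Keep only df_new rows whose key combination is absent from df_old
    df_old_keys = df_old[key_cols].drop_duplicates()
    df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
    df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
    df_merged = pd.concat([df_old, df_add], ignore_index=True)
    print(df_merged)   # the 2026-01-10 row of df_new is skipped; 2026-01-11 is appended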