Browse Source

提交训练过程相关

node04 3 days ago
parent
commit
c586c3cdb1
4 changed files with 274 additions and 0 deletions
  1. 1 0
      .gitignore
  2. 158 0
      data_process.py
  3. 89 0
      main_tr.py
  4. 26 0
      utils.py

+ 1 - 0
.gitignore

@@ -32,6 +32,7 @@ desktop.ini
 # ===== Project-specific =====
 # 生成的图表图片
 photo/
+data_shards/
 
 # 字体文件(体积大,不适合版本控制)
 *.ttf

+ 158 - 0
data_process.py

@@ -0,0 +1,158 @@
+import pandas as pd
+import numpy as np
+import gc
+import os
+
+
+def preprocess_data_simple(df_input, is_train=False):
+
+    print(">>> 开始数据预处理")
+    # 城市码映射成数字    
+    
+    # gid:基于指定字段的分组标记(整数)
+    df_input['gid'] = (
+        df_input
+        .groupby(
+            ['citypair', 'flight_numbers', 'from_date'],    # 'baggage_weight' 先不进分组
+            sort=False
+        )
+        .ngroup()
+    )
+
+    # 在 gid 与 baggage_weight 内按时间降序
+    df_input = df_input.sort_values(
+        by=['gid', 'baggage_weight', 'hours_until_departure'],
+        ascending=[True, True, False]
+    ).reset_index(drop=True)
+
+    df_input = df_input[df_input['hours_until_departure'] <= 480]
+    df_input = df_input[df_input['baggage_weight'] == 20]   # 先保留20公斤行李的
+
+    # 在hours_until_departure 的末尾 保留真实的而不是补齐的数据
+    if not is_train:
+        _tail_filled = df_input.groupby(['gid', 'baggage_weight'])['is_filled'].transform(
+            lambda s: s.iloc[::-1].cummin().iloc[::-1]
+        )
+        df_input = df_input[~((df_input['is_filled'] == 1) & (_tail_filled == 1))]
+
+    # 价格变化最小量阈值
+    price_change_amount_threshold = 5
+    df_input['_raw_price_diff'] = df_input.groupby(['gid', 'baggage_weight'], group_keys=False)['price_total'].diff()
+
+    # 计算价格变化量
+    df_input['price_change_amount'] = (
+        df_input['_raw_price_diff']
+        .mask(df_input['_raw_price_diff'].abs() < price_change_amount_threshold, 0)
+        .replace(0, np.nan)
+        .groupby([df_input['gid'], df_input['baggage_weight']], group_keys=False)
+        .ffill()
+        .fillna(0)
+        .round(2)
+    )
+
+    # 计算价格变化百分比(相对于上一时间点的变化率)
+    df_input['price_change_percent'] = (
+        df_input.groupby(['gid', 'baggage_weight'], group_keys=False)['price_total']
+        .pct_change()
+        .mask(df_input['_raw_price_diff'].abs() < price_change_amount_threshold, 0)
+        .replace(0, np.nan)
+        .groupby([df_input['gid'], df_input['baggage_weight']], group_keys=False)
+        .ffill()
+        .fillna(0)
+        .round(4)
+    )
+
+    # 第一步:标记价格变化段
+    df_input['price_change_segment'] = (
+        df_input.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_amount']
+        .apply(lambda s: (s != s.shift()).cumsum())
+    )
+
+    # 第二步:计算每个变化段内的持续时间
+    df_input['price_duration_hours'] = (
+        df_input.groupby(['gid', 'baggage_weight', 'price_change_segment'], group_keys=False)
+        .cumcount()
+        .add(1)
+    )
+
+    # 可选:删除临时列
+    df_input = df_input.drop(columns=['price_change_segment', '_raw_price_diff'])
+
+    # 训练过程
+    if is_train:
+        df_target = df_input[(df_input['hours_until_departure'] >= 24) & (df_input['hours_until_departure'] <= 360)].copy()
+        df_target = df_target.sort_values(
+            by=['gid', 'baggage_weight', 'hours_until_departure'],
+            ascending=[True, True, False]
+        ).reset_index(drop=True)
+
+        # 对于先升后降的分析
+        prev_pct = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_percent'].shift(1)
+        prev_amo = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_amount'].shift(1)
+        prev_dur = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_duration_hours'].shift(1)
+        prev_price = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_total'].shift(1)
+        drop_mask = (prev_pct > 0) & (df_target['price_change_percent'] < 0)
+
+        df_drop_nodes = df_target.loc[drop_mask, ['gid', 'baggage_weight', 'hours_until_departure']].copy()
+        df_drop_nodes.rename(columns={'hours_until_departure': 'drop_hours_until_departure'}, inplace=True)
+        df_drop_nodes['drop_price_change_percent'] = df_target.loc[drop_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
+        df_drop_nodes['drop_price_change_amount'] = df_target.loc[drop_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
+        df_drop_nodes['high_price_duration_hours'] = prev_dur.loc[drop_mask].astype(float).to_numpy()
+        df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask].astype(float).round(4).to_numpy()
+        df_drop_nodes['high_price_change_amount'] = prev_amo.loc[drop_mask].astype(float).round(2).to_numpy()
+        df_drop_nodes['high_price_amount'] = prev_price.loc[drop_mask].astype(float).round(2).to_numpy()
+        df_drop_nodes = df_drop_nodes.reset_index(drop=True)
+
+        flight_info_cols = [
+            'citypair', 'flight_numbers', 'from_time', 'from_date', 'currency',
+        ]
+        flight_info_cols = [c for c in flight_info_cols if c in df_target.columns]
+        df_gid_info = df_target[['gid', 'baggage_weight'] + flight_info_cols].drop_duplicates(subset=['gid', 'baggage_weight']).reset_index(drop=True)
+        df_drop_nodes = df_drop_nodes.merge(df_gid_info, on=['gid', 'baggage_weight'], how='left')
+
+        drop_info_cols = [
+            'drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
+            'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount', 'high_price_amount', 
+        ]
+        # 按顺序排列 去掉gid
+        df_drop_nodes = df_drop_nodes[flight_info_cols + ['baggage_weight'] + drop_info_cols]
+        
+        # 对于“上涨后再次上涨”的分析(连续两个正向变价段)
+        seg_start_mask = df_target['price_duration_hours'].eq(1)
+        rise_mask = seg_start_mask & (prev_pct > 0) & (df_target['price_change_percent'] > 0)
+
+        df_rise_nodes = df_target.loc[rise_mask, ['gid', 'baggage_weight', 'hours_until_departure']].copy()
+        df_rise_nodes.rename(columns={'hours_until_departure': 'rise_hours_until_departure'}, inplace=True)
+        df_rise_nodes['rise_price_change_percent'] = df_target.loc[rise_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
+        df_rise_nodes['rise_price_change_amount'] = df_target.loc[rise_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_duration_hours'] = prev_dur.loc[rise_mask].astype(float).to_numpy()
+        df_rise_nodes['prev_rise_change_percent'] = prev_pct.loc[rise_mask].astype(float).round(4).to_numpy()
+        df_rise_nodes['prev_rise_change_amount'] = prev_amo.loc[rise_mask].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_amount'] = prev_price.loc[rise_mask].astype(float).round(2).to_numpy()
+        df_rise_nodes = df_rise_nodes.reset_index(drop=True)
+
+        df_rise_nodes = df_rise_nodes.merge(df_gid_info, on=['gid', 'baggage_weight'], how='left')
+        
+        rise_info_cols = [
+            'rise_hours_until_departure', 'rise_price_change_percent', 'rise_price_change_amount',
+            'prev_rise_duration_hours', 'prev_rise_change_percent', 'prev_rise_change_amount', 'prev_rise_amount',
+        ]
+        df_rise_nodes = df_rise_nodes[flight_info_cols + ['baggage_weight'] + rise_info_cols]
+        
+        # 制作历史包络线
+        envelope_group = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']
+        idx_peak = df_input.groupby(envelope_group)['price_total'].idxmax()
+        df_envelope = df_input.loc[idx_peak, envelope_group + [
+            'price_total', 'hours_until_departure'
+        ]].rename(columns={
+            'price_total': 'peak_price',
+            'hours_until_departure': 'peak_hours',
+        }).reset_index(drop=True)
+        
+        del df_gid_info
+        del df_target
+
+        return df_input, df_drop_nodes, df_rise_nodes, df_envelope
+    
+    return df_input, None, None, None
+ 

+ 89 - 0
main_tr.py

@@ -0,0 +1,89 @@
+import os
+import time
+import gc
+import pandas as pd
+from datetime import datetime, timedelta
+from config import mongo_config, uo_city_pairs_new
+from data_loader import load_data
+from data_process import preprocess_data_simple
+from utils import merge_and_overwrite_csv
+
+
+def start_train():
+    print(f"开始训练")
+
+    output_dir = "./data_shards"
+
+    # 确保目录存在
+    os.makedirs(output_dir, exist_ok=True) 
+
+    cpu_cores = os.cpu_count()  # 你的系统是72
+    max_workers = min(8, cpu_cores)  # 最大不超过8个进程
+
+    from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
+    from_date_begin = "2026-03-17"
+
+    print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
+
+    uo_city_pairs = uo_city_pairs_new.copy()
+    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]            
+
+    # 如果临时处理中断,从日志里找到 中断的索引 修改它
+    resume_idx = 0
+    uo_city_pair_list = uo_city_pair_list[resume_idx:]
+
+    # 打印训练阶段起始索引顺序
+    max_len = len(uo_city_pair_list) + resume_idx
+    print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")
+
+    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
+        print(f"第 {idx} 组 :", uo_city_pair)
+
+        # 加载训练数据
+        start_time = time.time()
+        df_train = load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
+                             use_multiprocess=True, max_workers=max_workers)
+        end_time = time.time()
+        run_time = round(end_time - start_time, 3)
+        print(f"用时: {run_time} 秒")
+
+        if df_train.empty:
+            print(f"训练数据为空,跳过此批次。")
+            continue
+        
+        _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
+
+        dedup_cols = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']
+
+        drop_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv')
+        if df_drop_nodes.empty:
+            print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
+        else:
+            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
+            print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
+        
+        rise_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv')
+        if df_rise_nodes.empty:
+            print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
+        else:
+            merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
+            print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
+        
+        envelope_csv_path = os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv')
+        if df_envelope.empty:
+            print(f"df_envelope 为空,跳过保存: {envelope_csv_path}")
+        else:
+            merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
+            print(f"本批次训练已保存csv文件: {envelope_csv_path}")
+        
+        del df_drop_nodes
+        del df_rise_nodes
+        del df_envelope
+
+        gc.collect()
+        time.sleep(1)
+
+    print(f"所有批次训练已完成")
+
+if __name__ == "__main__":
+    start_train()

+ 26 - 0
utils.py

@@ -0,0 +1,26 @@
+import os
+import pandas as pd
+
+def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
+    key_cols = [c for c in dedup_cols if c in df_new.columns]
+
+    # 若干天后的训练:如果本次 df_new 里某些 flight_day(连同航班键)在历史 CSV df_old 里已经出现过,就认为这一天已经处理过了,
+    # 本次不再追加这一天的任何节点;只追加“历史里不存在的 flight_day(同航班键)”的数据
+    if os.path.exists(csv_path):
+        df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
+        if key_cols and all(c in df_old.columns for c in key_cols):
+            df_old_keys = df_old[key_cols].drop_duplicates()
+            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)  # indicator=True 会在结果df中添加一个_merge列
+            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])    # left_only 只在左表(df_new)中存在
+        else:
+            df_add = df_new.copy()
+        df_merged = pd.concat([df_old, df_add], ignore_index=True)
+    # 第一次训练:直接保留,不做去重
+    else:
+        df_merged = df_new.copy()
+
+    sort_cols = [c for c in dedup_cols if c in df_merged.columns]
+    if sort_cols:
+        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)  # 重新分组排序
+
+    df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')