Преглед изворни кода

提交快速模型相关代码

node04 пре 3 недеља
родитељ
комит
212712f15b
3 измењених фајлова са 333 додато и 0 уклоњено
  1. 98 0
      data_preprocess.py
  2. 116 0
      main_pe_0.py
  3. 119 0
      main_tr_0.py

+ 98 - 0
data_preprocess.py

@@ -831,3 +831,101 @@ def standardization(df, feature_scaler, target_scaler=None, is_training=True, is
     print(">>> 基于固定范围的特征数据归一化完成")
 
     return df, feature_scaler, target_scaler
+
+
def preprocess_data_simple(df_input, is_train=False, output_dir='.'):
    """Lightweight preprocessing for the fast model.

    Runs the shared first-half preprocessing, restricts rows to flights within
    480 hours of departure with baggage == 30, derives price-change features
    (amount, percent, duration of the current price level), and — in training
    mode — extracts "price drop nodes" (rows where a price rise is followed by
    a price drop) in the 18–54 hours-until-departure window.

    Parameters
    ----------
    df_input : pandas.DataFrame
        Raw flight/price rows; must contain at least 'gid', 'baggage',
        'hours_until_departure', 'adult_total_price', 'is_filled' and the
        flight-info columns listed below.
    is_train : bool
        When True, keep back-filled rows and build the drop-node frame;
        when False, filter to real (is_filled == 0) rows only.
    output_dir : str
        # NOTE(review): currently unused inside this function — confirm
        # whether it was meant for saving intermediate files.

    Returns
    -------
    (df_input, df_drop_nodes) : tuple
        The feature frame, and the drop-node frame (None when not training).
    """

    df_input = preprocess_data_first_half(df_input)
    
    # Sort descending by time-to-departure within each (gid, baggage) group,
    # so diff()/shift() below move forward in calendar time.
    df_input = df_input.sort_values(
        by=['gid', 'baggage', 'hours_until_departure'],
        ascending=[True, True, False]
    ).reset_index(drop=True)

    # Keep flights within 480 hours of departure and the baggage==30 product only.
    df_input = df_input[df_input['hours_until_departure'] <= 480]
    df_input = df_input[df_input['baggage'] == 30]

    # Keep real observations rather than back-filled rows.
    # NOTE(review): this filter applies only in inference mode — training keeps
    # filled rows, presumably so gap-filled series stay continuous; confirm.
    if not is_train:
        df_input = df_input[df_input['is_filled'] == 0]
    
    # Price change amount: per-group first difference; zeros (no change) are
    # replaced by the last nonzero change (ffill), leading NaNs become 0.
    df_input['price_change_amount'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
        .apply(lambda s: s.diff().replace(0, np.nan).ffill().fillna(0)).round(2)
    )

    # Price change percent: same treatment, relative to the previous point.
    df_input['price_change_percent'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
        .apply(lambda s: s.pct_change().replace(0, np.nan).ffill().fillna(0)).round(4)
    )

    # Step 1: label contiguous segments of identical price_change_amount
    # (segment id increments whenever the value differs from the previous row).
    df_input['price_change_segment'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['price_change_amount']
        .apply(lambda s: (s != s.shift()).cumsum())
    )

    # Step 2: hours elapsed inside the current segment (1-based cumcount).
    df_input['price_duration_hours'] = (
        df_input.groupby(['gid', 'baggage', 'price_change_segment'], group_keys=False)
        .cumcount()
        .add(1)
    )
    
    # Drop the temporary segment-id column.
    df_input = df_input.drop(columns=['price_change_segment'])
    
    # Move price/time columns to the end under capitalized names.
    # NOTE(review): these pops use capitalized column names while all filters
    # above use lowercase ones — presumably both casings exist after
    # preprocess_data_first_half; if not, pop() raises KeyError. Confirm.
    adult_price = df_input.pop('Adult_Total_Price')
    hours_until = df_input.pop('Hours_Until_Departure')
    df_input['Adult_Total_Price'] = adult_price
    df_input['Hours_Until_Departure'] = hours_until
    df_input['Baggage'] = df_input['baggage']

    if is_train:
        # Training target window: 18 to 54 hours before departure.
        df_target = df_input[(df_input['hours_until_departure'] >= 18) & (df_input['hours_until_departure'] <= 54)].copy()
        df_target = df_target.sort_values(
            by=['gid', 'hours_until_departure'],
            ascending=[True, False]
        ).reset_index(drop=True)

        # Previous-row features per flight (gid); a "drop node" is a row whose
        # previous change was a rise (>0) and whose own change is a drop (<0).
        prev_pct = df_target.groupby('gid', group_keys=False)['price_change_percent'].shift(1)
        prev_amo = df_target.groupby('gid', group_keys=False)['price_change_amount'].shift(1)
        prev_dur = df_target.groupby('gid', group_keys=False)['price_duration_hours'].shift(1)
        drop_mask = (prev_pct > 0) & (df_target['price_change_percent'] < 0)
        
        # Assemble the drop-node frame: the drop itself plus the preceding
        # high-price segment's stats (via .to_numpy() to bypass index alignment).
        df_drop_nodes = df_target.loc[drop_mask, ['gid', 'hours_until_departure']].copy()
        df_drop_nodes.rename(columns={'hours_until_departure': 'drop_hours_until_departure'}, inplace=True)
        df_drop_nodes['drop_price_change_percent'] = df_target.loc[drop_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
        df_drop_nodes['drop_price_change_amount'] = df_target.loc[drop_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
        df_drop_nodes['high_price_duration_hours'] = prev_dur.loc[drop_mask].astype(float).to_numpy()
        df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask].astype(float).round(4).to_numpy()
        df_drop_nodes['high_price_change_amount'] = prev_amo.loc[drop_mask].astype(float).round(2).to_numpy()

        df_drop_nodes = df_drop_nodes.reset_index(drop=True)

        # Static flight attributes to attach to each drop node.
        flight_info_cols = [
            'city_pair', 
            'flight_number_1', 'seg1_dep_air_port', 'seg1_dep_time', 'seg1_arr_air_port', 'seg1_arr_time',
            'flight_number_2', 'seg2_dep_air_port', 'seg2_dep_time', 'seg2_arr_air_port', 'seg2_arr_time',
            'currency', 'baggage', 'flight_day',
        ]
        
        # One info row per gid, joined onto the drop nodes.
        df_gid_info = df_target[['gid'] + flight_info_cols].drop_duplicates(subset=['gid']).reset_index(drop=True)
        df_drop_nodes = df_drop_nodes.merge(df_gid_info, on='gid', how='left')

        drop_info_cols = ['drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
                          'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount'
        ]
        # Reorder columns and drop gid from the output.
        order_columns = flight_info_cols + drop_info_cols
        df_drop_nodes = df_drop_nodes[order_columns]
        
        # Release intermediates before returning.
        del df_gid_info
        del df_target
    else:
        df_drop_nodes = None

    return df_input, df_drop_nodes

+ 116 - 0
main_pe_0.py

@@ -0,0 +1,116 @@
+import os
+import time
+from datetime import datetime, timedelta
+from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+from data_loader import load_train_data
+from data_preprocess import preprocess_data_simple
+from utils import chunk_list_with_index
+
+
def start_predict():
    """Run one prediction pass over the configured flight routes.

    For each route batch: load recent rows from MongoDB, keep only flights
    departing between 18 and 54 hours from now (hour-floored comparison),
    run the simple preprocessing pipeline, and append the resulting feature
    rows to a shared temp CSV under ``output_dir``.
    """
    print("开始预测")

    output_dir = "./data_shards_0"
    photo_dir = "./photo_0"
    predict_dir = "./predictions_0"

    # Make sure all output directories exist.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    os.makedirs(predict_dir, exist_ok=True)

    cpu_cores = os.cpu_count()
    max_workers = min(4, cpu_cores)  # cap at 4 loader processes

    # Current wall-clock time, floored to the hour for the prediction window.
    current_time = datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")

    # Prediction window: departures 18 to 54 hours from now.
    pred_hour_begin = hourly_time + timedelta(hours=18)
    pred_hour_end = hourly_time + timedelta(hours=54)

    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")

    print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")

    # Hot routes only for now; the `[:0]` slice deliberately excludes cold routes.
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])

    group_size = 1  # routes per batch

    chunks = chunk_list_with_index(flight_route_list, group_size)

    # To resume from a specific batch, bump this start index.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]

    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"预测阶段起始索引顺序:{batch_starts}")

    # Prediction loop over batches.
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        print(f"第 {i} 组 :", group_route_list)

        # The batch index decides whether it is a hot or a cold route table.
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("无法确定热门还是冷门, 跳过此批次。")
            continue

        # Load candidate rows (date range only; the precise hour filter follows).
        start_time = time.time()
        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
                                  use_multiprocess=True, max_workers=max_workers)
        run_time = round(time.time() - start_time, 3)
        print(f"用时: {run_time} 秒")

        if df_test.empty:
            print("测试数据为空,跳过此批次。")
            continue

        # Filter by departure time: compare seg1_dep_time floored to the hour
        # against the [begin, end) prediction window.
        df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)
        df_test = df_test.drop(columns=['seg1_dep_hour'])
        print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")

        if filtered_count == 0:
            print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
            continue

        # BUG FIX: preprocess_data_simple returns (df, drop_nodes); the previous
        # code kept the whole tuple, so the to_csv call below raised at runtime.
        # drop_nodes is None in inference mode and is discarded here.
        df_test_inputs, _ = preprocess_data_simple(df_test)

        # Append features to a shared temp CSV; write the header only on creation.
        csv_path = os.path.join(output_dir, 'temp.csv')
        df_test_inputs.to_csv(csv_path, mode='a', index=False, header=not os.path.exists(csv_path), encoding='utf-8-sig')

        del df_test_inputs


if __name__ == "__main__":
    start_predict()

+ 119 - 0
main_tr_0.py

@@ -0,0 +1,119 @@
+import os
+import time
+import pandas as pd
+from datetime import datetime, timedelta
+from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+from data_loader import load_train_data
+from data_preprocess import preprocess_data_simple
+from utils import chunk_list_with_index
+
+
def start_train():
    """Train over the configured route list in batches.

    Each batch loads the last month of data for its routes, extracts
    price-drop nodes via ``preprocess_data_simple``, and appends only the
    previously unseen flight-days to that route group's drop-info CSV.
    """
    print(f"开始训练")

    output_dir = "./data_shards_0"

    # The shard directory must exist before any CSV is written.
    os.makedirs(output_dir, exist_ok=True)

    worker_limit = min(8, os.cpu_count())  # never more than 8 loader processes

    # Training window: the 30 days ending yesterday.
    today = datetime.today()
    date_end = (today - timedelta(days=1)).strftime("%Y-%m-%d")
    date_begin = (today - timedelta(days=31)).strftime("%Y-%m-%d")

    print(f"训练时间范围: {date_begin} 到 {date_end}")

    # Hot routes only; the `[:0]` slice deliberately drops the cold routes.
    route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    total_routes = len(route_list)
    hot_count = len(vj_flight_route_list_hot)
    nothot_count = len(vj_flight_route_list_nothot[:0])

    print(f"flight_route_list_len:{total_routes}")
    print(f"route_len_hot:{hot_count}")
    print(f"route_len_nothot:{nothot_count}")

    group_size = 1  # routes per batch

    chunks = chunk_list_with_index(route_list, group_size)

    # If a previous run was interrupted, set this to the failed batch index.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]

    print(f"训练阶段起始索引顺序:{[first for first, _ in chunks]}")

    # Training loop over batches.
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        print(f"第 {i} 组 :", group_route_list)
        group_route_str = ','.join(group_route_list)

        # Position within the combined list decides hot vs. cold source table.
        if 0 <= i < hot_count:
            is_hot, table_name = 1, CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif hot_count <= i < hot_count + nothot_count:
            is_hot, table_name = 0, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print(f"无法确定热门还是冷门, 跳过此批次。")
            continue

        # Load this batch's training rows and report how long it took.
        t0 = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=worker_limit)
        print(f"用时: {round(time.time() - t0, 3)} 秒")

        if df_train.empty:
            print(f"训练数据为空,跳过此批次。")
            continue

        # Only the drop-node frame is needed here; the feature frame is discarded.
        _, df_drop_nodes = preprocess_data_simple(df_train, is_train=True)

        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
            continue

        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
        key_cols = [c for c in dedup_cols if c in df_drop_nodes.columns]

        # On later runs: a flight_day (together with its flight key) already
        # present in the history CSV counts as processed — only rows whose key
        # is absent from the history get appended.
        if os.path.exists(drop_info_csv_path):
            df_history = pd.read_csv(drop_info_csv_path, encoding='utf-8-sig')
            if key_cols and all(c in df_history.columns for c in key_cols):
                seen_keys = df_history[key_cols].drop_duplicates()
                tagged = df_drop_nodes.merge(seen_keys, on=key_cols, how='left', indicator=True)
                fresh = tagged[tagged['_merge'] == 'left_only'].drop(columns=['_merge'])
            else:
                fresh = df_drop_nodes.copy()
            df_merged = pd.concat([df_history, fresh], ignore_index=True)
        else:
            # First run for this route group: keep everything, no dedup.
            df_merged = df_drop_nodes.copy()

        sort_cols = [c for c in dedup_cols if c in df_merged.columns]
        if sort_cols:
            df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
        df_merged.to_csv(drop_info_csv_path, index=False, encoding='utf-8-sig')

        print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
        time.sleep(1)

    print(f"所有批次训练已完成")


if __name__ == "__main__":
    start_train()