Explorar o código

提交近期修改

node04 hai 1 semana
pai
achega
575bc3d789
Modificáronse 4 ficheiros con 492 adicións e 26 borrados
  1. 98 24
      config.py
  2. 5 2
      data_loader.py
  3. 185 0
      data_preprocess.py
  4. 204 0
      main_tr.py

+ 98 - 24
config.py

@@ -1,3 +1,6 @@
+import holidays
+import pandas as pd
+
 CLEAN_VJ_HOT_NEAR_INFO_TAB = "clean_flights_vj_hot_0_7_info_tab"
 CLEAN_VJ_HOT_FAR_INFO_TAB = "clean_flights_vj_hot_7_30_info_tab"
 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB = "clean_flights_vj_nothot_0_7_info_tab"
@@ -12,6 +15,74 @@ mongodb_config = {
     "pwd": ""
 }
 
+# 城市码-国家码的映射
+city_to_country = {
+    "CAN": "CN",  # 广州,中国
+    "DPS": "ID",  # 巴厘岛,印度尼西亚
+    "HAN": "VN",  # 河内,越南
+    "SGN": "VN",  # 胡志明(西贡),越南
+    "CTU": "CN",  # 成都,中国
+    "DAD": "VN",  # 岘港,越南
+    "SEL": "KR",  # 首尔,韩国
+    "DEL": "IN",  # 德里,印度
+    "UIH": "VN",  # 归仁,越南
+    "HKG": "HK",  # 香港,中国
+    "PQC": "VN",  # 富国岛,越南
+    "KUL": "MY",  # 吉隆坡,马来西亚
+    "NGO": "JP",  # 名古屋,日本
+    "NHA": "VN",  # 芽庄,越南
+    "PUS": "KR",  # 釜山,韩国
+    "SHA": "CN",  # 上海,中国
+    "SIN": "SG",  # 新加坡,新加坡
+    "TPE": "TW",  # 台北,中国台湾
+    "TYO": "JP",  # 东京,日本
+    "BKK": "TH",  # 曼谷,泰国
+    "BLR": "IN",  # 班加罗尔,印度
+    "FUK": "JP",  # 福冈,日本
+    "BMV": "VN",  # 邦美蜀,越南
+    "BNE": "AU",  # 布里斯班,澳大利亚
+    "BOM": "IN",  # 孟买,印度
+    "DLI": "VN",  # 大叻,越南
+    "OSA": "JP",  # 大阪,日本
+    "RMQ": "TW",  # 台中,中国台湾
+    "HKT": "TH",  # 普吉岛,泰国
+    "HPH": "VN",  # 海防,越南
+    "KHH": "TW",  # 高雄,中国台湾
+    "MEL": "AU",  # 墨尔本,澳大利亚
+    "MNL": "PH",  # 马尼拉,菲律宾
+    "SYD": "AU",  # 悉尼,澳大利亚
+    "REP": "KH",  # 暹粒,柬埔寨
+    "VTE": "LA",  # 万象,老挝
+    "HYD": "IN",  # 海得拉巴,印度
+    "AMD": "IN",  # 艾哈迈达巴德,印度
+}
+
+# 生成各个国家(地区)的节假日
+def build_country_holidays(city_to_country):
+    countries = sorted(set(city_to_country.values()))
+    start_date = pd.Timestamp('2025-11-01')
+    end_date = pd.Timestamp('2026-12-31')
+
+    country_holidays = {}
+
+    for country in countries:
+        try:
+            hdays = holidays.country_holidays(
+                country,
+                years=[2025, 2026]
+            )
+            # 转成 set[date],方便高速查询
+            country_holidays[country] = {
+                d for d in hdays
+                if start_date.date() <= d <= end_date.date()
+            }
+        except Exception:
+            # 个别国家 holidays 库可能不支持
+            country_holidays[country] = set()
+
+    return country_holidays
+
+
 # 热门的航线
 vj_flight_route_list_hot = [
     "CAN-DPS", "CAN-HAN", "CAN-SGN", "CTU-HAN", "CTU-SGN",
@@ -46,27 +117,30 @@ vj_flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
 
 
 if __name__ == '__main__':
-    from collections import Counter
-    # 检查重复项
-    # 统计每个航线出现的次数
-    route_counter = Counter(vj_flight_route_list)
-
-    # 找出重复的航线
-    duplicates = {route: count for route, count in route_counter.items() if count > 1}
-
-    # 输出结果
-    if duplicates:
-        print("发现重复的航线:")
-        for route, count in duplicates.items():
-            print(f"  {route}: 出现 {count} 次")
-
-        print(f"\n总共发现 {len(duplicates)} 条重复航线")
-
-        # 查找这些航线分别在哪个列表中
-        print("\n重复航线分布:")
-        for route in duplicates:
-            hot_count = vj_flight_route_list_hot.count(route)
-            nothot_count = vj_flight_route_list_nothot.count(route)
-            print(f"  {route}: hot列表中出现 {hot_count} 次, nothot列表中出现 {nothot_count} 次")
-    else:
-        print("没有发现重复航线")
+    # from collections import Counter
+    # # 检查重复项
+    # # 统计每个航线出现的次数
+    # route_counter = Counter(vj_flight_route_list)
+
+    # # 找出重复的航线
+    # duplicates = {route: count for route, count in route_counter.items() if count > 1}
+
+    # # 输出结果
+    # if duplicates:
+    #     print("发现重复的航线:")
+    #     for route, count in duplicates.items():
+    #         print(f"  {route}: 出现 {count} 次")
+
+    #     print(f"\n总共发现 {len(duplicates)} 条重复航线")
+
+    #     # 查找这些航线分别在哪个列表中
+    #     print("\n重复航线分布:")
+    #     for route in duplicates:
+    #         hot_count = vj_flight_route_list_hot.count(route)
+    #         nothot_count = vj_flight_route_list_nothot.count(route)
+    #         print(f"  {route}: hot列表中出现 {hot_count} 次, nothot列表中出现 {nothot_count} 次")
+    # else:
+    #     print("没有发现重复航线")
+
+    COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
+    print(COUNTRY_HOLIDAYS)

+ 5 - 2
data_loader.py

@@ -568,12 +568,15 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
         route = f"{from_city}-{to_city}"
         print(f"开始处理航线: {route}")
         all_groups = query_groups_of_city_code(db, from_city, to_city, table_name)
+        all_groups_len = len(all_groups)
+        print(f"该航线共有{all_groups_len}个航班号")
         # 每一组航班号
         for each_group in all_groups:
             flight_nums = each_group.get("flight_numbers")
             print(f"开始处理航班号: {flight_nums}")
             details = each_group.get("details")
-            # 查远期表
+
+            print(f"查远期表")
             if is_hot == 1:
                 df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
                                                 date_begin_s, date_end_s, flight_nums)
@@ -586,7 +589,7 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
                 print(f"航班号:{flight_nums} 远期表无数据, 跳过")
                 continue
 
-            # 查近期表
+            print(f"查近期表")
             if is_hot == 1:
                 df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
                                                 date_begin_s, date_end_s, flight_nums)

+ 185 - 0
data_preprocess.py

@@ -0,0 +1,185 @@
+import pandas as pd
+import numpy as np
+import bisect
+from datetime import datetime, timedelta
+from sklearn.preprocessing import StandardScaler
+from config import city_to_country, build_country_holidays
+
+COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
+
+
+def preprocess_data(df_train, features, categorical_features, is_training=True):
+    print(">>> 开始数据预处理") 
+
+    # 生成 城市对
+    df_train['city_pair'] = (
+        df_train['from_city_code'].astype(str) + "-" + df_train['to_city_code'].astype(str)
+    )
+    # 把 city_pair、from_city_code、to_city_code 放到前三列
+    cols = df_train.columns.tolist()
+    # 删除已存在的三列(保证顺序正确)
+    for c in ['city_pair', 'from_city_code', 'to_city_code']:
+        cols.remove(c)
+    # 这三列插入到最前面
+    df_train = df_train[['city_pair', 'from_city_code', 'to_city_code'] + cols]
+
+    # 转格式
+    df_train['search_dep_time'] = pd.to_datetime(
+        df_train['search_dep_time'],
+        format='%Y%m%d',
+        errors='coerce'
+    ).dt.strftime('%Y-%m-%d')
+    # 重命名起飞日期
+    df_train.rename(columns={'search_dep_time': 'flight_day'}, inplace=True)
+    
+    # 重命名航班号
+    df_train.rename(
+        columns={
+            'seg1_flight_number': 'flight_number_1',
+            'seg2_flight_number': 'flight_number_2'
+        },
+        inplace=True
+    )
+    # 分开填充
+    df_train['flight_number_1'] = df_train['flight_number_1'].fillna('VJ')
+    df_train['flight_number_2'] = df_train['flight_number_2'].fillna('VJ')
+
+    # 生成第一机场对
+    df_train['airport_pair_1'] = (
+        df_train['seg1_dep_air_port'].astype(str) + "-" + df_train['seg1_arr_air_port'].astype(str)
+    )
+    # 删除原始第一机场码
+    df_train.drop(columns=['seg1_dep_air_port', 'seg1_arr_air_port'], inplace=True)
+    # 第一机场对 放到 seg1_dep_time 列的前面
+    insert_idx = df_train.columns.get_loc('seg1_dep_time')
+    airport_pair_1 = df_train.pop('airport_pair_1')
+    df_train.insert(insert_idx, 'airport_pair_1', airport_pair_1)
+
+    # 生成第二机场对(带缺失兜底)
+    df_train['airport_pair_2'] = np.where(
+        df_train['seg2_dep_air_port'].isna() | df_train['seg2_arr_air_port'].isna(),
+        'NA',
+        df_train['seg2_dep_air_port'].astype(str) + "-" +
+        df_train['seg2_arr_air_port'].astype(str)
+    )
+    # 删除原始第二机场码
+    df_train.drop(columns=['seg2_dep_air_port', 'seg2_arr_air_port'], inplace=True)
+    # 第二机场对 放到 seg2_dep_time 列的前面
+    insert_idx = df_train.columns.get_loc('seg2_dep_time')
+    airport_pair_2 = df_train.pop('airport_pair_2')
+    df_train.insert(insert_idx, 'airport_pair_2', airport_pair_2)
+    
+    # 是否转乘
+    df_train['is_transfer'] = np.where(df_train['flight_number_2'] == 'VJ', 0, 1)
+    insert_idx = df_train.columns.get_loc('flight_number_2')
+    is_transfer = df_train.pop('is_transfer')
+    df_train.insert(insert_idx, 'is_transfer', is_transfer)
+
+    # 重命名起飞时刻与到达时刻
+    df_train.rename(
+        columns={
+            'seg1_dep_time': 'dep_time_1',
+            'seg1_arr_time': 'arr_time_1',
+            'seg2_dep_time': 'dep_time_2',
+            'seg2_arr_time': 'arr_time_2',
+        },
+        inplace=True
+    )
+    
+    # 第一段飞行时长
+    df_train['fly_duration_1'] = (
+        (df_train['arr_time_1'] - df_train['dep_time_1'])
+        .dt.total_seconds() / 3600
+    ).round(2)
+
+    # 第二段飞行时长(无转乘为 0)
+    df_train['fly_duration_2'] = (
+        (df_train['arr_time_2'] - df_train['dep_time_2'])
+        .dt.total_seconds() / 3600
+    ).fillna(0).round(2)
+
+    # 总飞行时长
+    df_train['fly_duration'] = (
+        df_train['fly_duration_1'] + df_train['fly_duration_2']
+    ).round(2)
+
+    # 中转停留时长(无转乘为 0)
+    df_train['stop_duration'] = (
+        (df_train['dep_time_2'] - df_train['arr_time_1'])
+        .dt.total_seconds() / 3600
+    ).fillna(0).round(2)
+
+    # 裁剪,防止负数
+    # for c in ['fly_duration_1', 'fly_duration_2', 'fly_duration', 'stop_duration']:
+    #     df_train[c] = df_train[c].clip(lower=0)
+
+    # 和 is_transfer 逻辑保持一致
+    # df_train.loc[df_train['is_transfer'] == 0, ['fly_duration_2', 'stop_duration']] = 0
+    
+    # 一次性插到 is_filled 前面
+    insert_before = 'is_filled'
+    new_cols = [
+        'fly_duration_1',
+        'fly_duration_2',
+        'fly_duration',
+        'stop_duration'
+    ]
+    cols = df_train.columns.tolist()
+    idx = cols.index(insert_before)
+    # 删除旧位置
+    cols = [c for c in cols if c not in new_cols]
+    # 插入新位置(顺序保持)
+    cols[idx:idx] = new_cols    # python独有空切片插入法
+    df_train = df_train[cols]
+
+    # 一次生成多个字段
+    dep_t1 = df_train['dep_time_1']
+    # 几点起飞(0–23)
+    df_train['flight_by_hour'] = dep_t1.dt.hour
+    # 起飞日期几号(1–31)
+    df_train['flight_by_day'] = dep_t1.dt.day
+    # 起飞日期几月(1–12)
+    df_train['flight_day_of_month'] = dep_t1.dt.month
+    # 起飞日期周几(0=周一, 6=周日)
+    df_train['flight_day_of_week'] = dep_t1.dt.weekday
+    # 起飞日期季度(1–4)
+    df_train['flight_day_of_quarter'] = dep_t1.dt.quarter
+    # 是否周末(周六 / 周日)
+    df_train['flight_day_is_weekend'] = dep_t1.dt.weekday.isin([5, 6]).astype(int)
+
+    # 找到对应的国家码
+    df_train['dep_country'] = df_train['from_city_code'].map(city_to_country)
+    df_train['arr_country'] = df_train['to_city_code'].map(city_to_country) 
+
+    # 整体出发时间 就是 dep_time_1
+    df_train['global_dep_time'] = df_train['dep_time_1']
+    # 整体到达时间:有转乘用 arr_time_2,否则用 arr_time_1
+    df_train['global_arr_time'] = df_train['arr_time_2'].fillna(df_train['arr_time_1'])
+
+    # 出发日期在出发国家是否节假日
+    df_train['dep_country_is_holiday'] = df_train.apply(
+        lambda r: r['global_dep_time'].date()
+        in COUNTRY_HOLIDAYS.get(r['dep_country'], set()),
+        axis=1
+    ).astype(int)
+
+    # 到达日期在到达国家是否节假日
+    df_train['arr_country_is_holiday'] = df_train.apply(
+        lambda r: r['global_arr_time'].date()
+        in COUNTRY_HOLIDAYS.get(r['arr_country'], set()),
+        axis=1
+    ).astype(int)
+
+    # 在任一侧是否节假日
+    df_train['flight_day_is_holiday'] = (
+        df_train[['dep_country_is_holiday', 'arr_country_is_holiday']]
+        .max(axis=1)
+    )
+
+    # 是否跨国航线
+    df_train['is_cross_country'] = (
+        df_train['dep_country'] != df_train['arr_country']
+    ).astype(int)
+
+    pass
+

+ 204 - 0
main_tr.py

@@ -0,0 +1,204 @@
+import warnings
+import os
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import joblib
+import gc
+import pandas as pd
+import numpy as np
+import redis
+import time
+import pickle
+import shutil
+from datetime import datetime, timedelta
+from data_loader import chunk_list, mongo_con_parse, load_train_data
+from data_preprocess import preprocess_data
+from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
+    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
+
+warnings.filterwarnings('ignore')
+
+
+# 根据环境变量的存在设置分布式开关
+if 'LOCAL_RANK' in os.environ:
+    FLAG_Distributed = True
+else:
+    FLAG_Distributed = False
+
+
+# 定义特征和参数
+categorical_features = ['city_pair', 'flight_number_1', 'flight_number_2']
+other_features = []
+features = []
+
+target_vars = ['target_min_to_price']   # 最低会降到的价格
+
+# 分布式环境初始化
+def init_distributed_backend():
+    if FLAG_Distributed:
+        local_rank = int(os.environ['LOCAL_RANK'])
+        # 关键:绑定设备必须在初始化进程组之前
+        torch.cuda.set_device(local_rank)            # 显式设置当前进程使用的 GPU
+        try:
+            dist.init_process_group(
+                backend='nccl',
+                init_method='env://',
+                world_size=int(os.environ['WORLD_SIZE']),
+                rank=int(os.environ['RANK']),
+                timeout=timedelta(minutes=30)   
+            )
+            print(f"Process group initialized for rank {dist.get_rank()}")  # 添加日志
+        except Exception as e:
+            print(f"Failed to initialize process group: {e}")  # 捕获异常
+            raise
+        device = torch.device("cuda", local_rank)
+    else:
+        # 如果不在分布式环境中, 使用默认设备
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print("use common environment")
+    return device
+
+# 初始化模型和相关参数
+def initialize_model(device):
+    return None
+
+def continue_before_process(redis_client, lock_key):
+    # rank0 跳出循环前的处理
+    redis_client.set(lock_key, 2)               # 设置 Redis 锁 key 的值为 2
+    print("rank0 已将 Redis 锁 key 值设置为 2")
+    time.sleep(5)
+    print("rank0 5秒等待结束")
+
+def start_train():
+    device = init_distributed_backend()
+
+    model = initialize_model(device)
+
+    if FLAG_Distributed:
+        rank = dist.get_rank()
+        local_rank = int(os.environ.get('LOCAL_RANK'))
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        local_rank = 0
+        world_size = 1
+
+    output_dir = "./data_shards" 
+    photo_dir = "./photo"
+
+    date_end = datetime.today().strftime("%Y-%m-%d")
+    date_begin = (datetime.today() - timedelta(days=10)).strftime("%Y-%m-%d")
+
+    # 仅在 rank == 0 时要做的
+    if rank == 0:
+        # 如果处理中断, 注释掉以下代码
+        batch_dir = os.path.join(output_dir, "batches")
+        try:
+            shutil.rmtree(batch_dir)
+        except FileNotFoundError:
+            print(f"rank:{rank}, {batch_dir} not found")
+
+        # 如果处理中断, 注释掉以下代码
+        csv_file_list = ['evaluate_results.csv']
+        for csv_file in csv_file_list:
+            try:
+                csv_path = os.path.join(output_dir, csv_file)
+                os.remove(csv_path)
+            except Exception as e:
+                print(f"remove {csv_path}: {str(e)}")
+
+        # 确保目录存在
+        os.makedirs(output_dir, exist_ok=True) 
+        os.makedirs(photo_dir, exist_ok=True)
+
+        print(f"最终特征列表:{features}")
+
+    # 定义优化器和损失函数(只回归)
+    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
+    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
+
+    group_size = 1
+    num_epochs_per_batch = 200  # 每个批次训练的轮数,可以根据需要调整
+
+    # 初始化 Redis 客户端(请根据实际情况修改 host、port、db)
+    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+    lock_key = "data_loading_lock_11"
+    barrier_key = 'distributed_barrier_11'
+
+    batch_idx = -1
+
+    # 主干代码
+    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
+    flight_route_list_len = len(flight_route_list)
+    route_len_hot = len(vj_flight_route_list_hot)
+    route_len_nothot = len(vj_flight_route_list_nothot)
+
+    # 调试代码
+    # s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
+    # flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
+    # flight_route_list_len = len(flight_route_list)
+    # route_len_hot = len(vj_flight_route_list_hot[:0])
+    # route_len_nothot = len(vj_flight_route_list_nothot[s:])
+    
+    if local_rank == 0:
+        print(f"flight_route_list_len:{flight_route_list_len}")
+        print(f"route_len_hot:{route_len_hot}")
+        print(f"route_len_nothot:{route_len_nothot}")
+    
+    chunks = chunk_list(flight_route_list, group_size)
+
+    for idx, group_route_list in enumerate(chunks, start=0):
+        # 特殊处理,跳过不好的批次
+        pass
+        redis_client.set(lock_key, 0)
+        redis_client.set(barrier_key, 0)
+        # 所有 Rank 同步的标志变量
+        valid_batch = torch.tensor([1], dtype=torch.int, device=device)  # 1表示有效批次
+
+        # 仅在 rank == 0 时要做的
+        if rank == 0:
+            # Rank0 设置 Redis 锁 key 的初始值为 0,表示数据加载尚未完成
+            redis_client.set(lock_key, 0)
+            print("rank0 开始数据加载...")
+            # 使用默认配置
+            client, db = mongo_con_parse()
+            print(f"第 {idx} 组 :", group_route_list)
+
+            # 根据索引位置决定是 热门 还是 冷门
+            if 0 <= idx < route_len_hot:
+                is_hot = 1
+                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
+            elif route_len_hot <= idx < route_len_hot + route_len_nothot:
+                is_hot = 0
+                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+            else:
+                print(f"无法确定热门还是冷门, 跳过此批次。")
+                continue_before_process(redis_client, lock_key)
+                continue
+            
+            # 加载训练数据
+            start_time = time.time()
+            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+            end_time = time.time()
+            run_time = round(end_time - start_time, 3)
+            print(f"用时: {run_time} 秒")
+
+            client.close()
+
+            if df_train.empty:
+                print(f"训练数据为空,跳过此批次。")
+                continue_before_process(redis_client, lock_key)
+                continue
+            
+            # 数据预处理
+            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
+            pass
+
+        else:
+            pass
+
+
+
+if __name__ == "__main__":
+    start_train()