Просмотр исходного кода

提交近期修改 特征处理 标准化 序列构建

node04 1 день назад
Родитель
Commit
10cd052ace
5 измененных файлов с 411 добавлено и 66 удалено
  1. 41 1
      config.py
  2. 71 32
      data_loader.py
  3. 128 16
      data_preprocess.py
  4. 88 17
      main_tr.py
  5. 83 0
      utils.py

+ 41 - 1
config.py

@@ -55,11 +55,51 @@ city_to_country = {
     "VTE": "LA",  # 万象,老挝
     "VTE": "LA",  # 万象,老挝
     "HYD": "IN",  # 海得拉巴,印度
     "HYD": "IN",  # 海得拉巴,印度
     "AMD": "IN",  # 艾哈迈达巴德,印度
     "AMD": "IN",  # 艾哈迈达巴德,印度
+    "JKT": "ID",  # 雅加达,印度尼西亚
 }
 }
 
 
-# 城市码-数字映射
+# 城市码-数字映射
 vj_city_code_map = {k: i for i, k in enumerate(city_to_country.keys())}
 vj_city_code_map = {k: i for i, k in enumerate(city_to_country.keys())}
 
 
# Flight-number -> integer id mapping.
# Index 0 is the bare carrier code 'VJ', used as the fill value for rows whose
# flight number is missing (see the fillna('VJ') calls in data_preprocess.py);
# real flight numbers start at 1.
# NOTE(review): presumably pasted from the __main__ helper in data_loader.py
# that prints `flight_map` — confirm before regenerating.
# NOTE(review): the `vi_` prefix is inconsistent with the `vj_` prefix used by
# vj_city_code_map; renaming would require updating all importers.
vi_flight_number_map = {
    'VJ': 0,
    'VJ082': 1, 'VJ083': 2, 'VJ084': 3, 'VJ085': 4, 'VJ086': 5, 'VJ1159': 6, 'VJ120': 7, 'VJ121': 8, 'VJ122': 9, 'VJ123': 10,
    'VJ124': 11, 'VJ125': 12, 'VJ126': 13, 'VJ127': 14, 'VJ128': 15, 'VJ129': 16, 'VJ130': 17, 'VJ131': 18, 'VJ132': 19, 'VJ1321': 20,
    'VJ1322': 21, 'VJ1329': 22, 'VJ133': 23, 'VJ1330': 24, 'VJ134': 25, 'VJ135': 26, 'VJ136': 27, 'VJ137': 28, 'VJ138': 29, 'VJ139': 30,
    'VJ142': 31, 'VJ143': 32, 'VJ144': 33, 'VJ1441': 34, 'VJ1443': 35, 'VJ1445': 36, 'VJ1447': 37, 'VJ145': 38, 'VJ146': 39, 'VJ147': 40,
    'VJ148': 41, 'VJ149': 42, 'VJ1492': 43, 'VJ150': 44, 'VJ151': 45, 'VJ152': 46, 'VJ153': 47, 'VJ154': 48, 'VJ155': 49, 'VJ156': 50,
    'VJ157': 51, 'VJ158': 52, 'VJ159': 53, 'VJ160': 54, 'VJ1602': 55, 'VJ1604': 56, 'VJ1606': 57, 'VJ1608': 58, 'VJ161': 59, 'VJ1612': 60,
    'VJ1614': 61, 'VJ1616': 62, 'VJ1618': 63, 'VJ163': 64, 'VJ164': 65, 'VJ1643': 66, 'VJ165': 67, 'VJ166': 68, 'VJ167': 69, 'VJ169': 70,
    'VJ171': 71, 'VJ1718': 72, 'VJ172': 73, 'VJ173': 74, 'VJ174': 75, 'VJ175': 76, 'VJ176': 77, 'VJ177': 78, 'VJ178': 79, 'VJ179': 80,
    'VJ180': 81, 'VJ1801': 82, 'VJ1802': 83, 'VJ1803': 84, 'VJ1805': 85, 'VJ181': 86, 'VJ182': 87, 'VJ183': 88, 'VJ1831': 89, 'VJ184': 90,
    'VJ185': 91, 'VJ186': 92, 'VJ187': 93, 'VJ188': 94, 'VJ189': 95, 'VJ190': 96, 'VJ191': 97, 'VJ192': 98, 'VJ1925': 99, 'VJ193': 100,
    'VJ194': 101, 'VJ195': 102, 'VJ196': 103, 'VJ198': 104, 'VJ199': 105, 'VJ270': 106, 'VJ272': 107, 'VJ274': 108, 'VJ278': 109, 'VJ280': 110,
    'VJ282': 111, 'VJ284': 112, 'VJ2841': 113, 'VJ288': 114, 'VJ320': 115, 'VJ321': 116, 'VJ322': 117, 'VJ324': 118, 'VJ325': 119, 'VJ326': 120,
    'VJ327': 121, 'VJ328': 122, 'VJ329': 123, 'VJ330': 124, 'VJ331': 125, 'VJ343': 126, 'VJ344': 127, 'VJ345': 128, 'VJ347': 129, 'VJ3524': 130,
    'VJ3526': 131, 'VJ381': 132, 'VJ385': 133, 'VJ387': 134, 'VJ3900': 135, 'VJ3901': 136, 'VJ3908': 137, 'VJ3909': 138, 'VJ3930': 139, 'VJ3931': 140,
    'VJ401': 141, 'VJ402': 142, 'VJ403': 143, 'VJ404': 144, 'VJ407': 145, 'VJ408': 146, 'VJ409': 147, 'VJ410': 148, 'VJ431': 149, 'VJ433': 150,
    'VJ441': 151, 'VJ443': 152, 'VJ445': 153, 'VJ492': 154, 'VJ494': 155, 'VJ501': 156, 'VJ502': 157, 'VJ503': 158, 'VJ504': 159, 'VJ505': 160,
    'VJ506': 161, 'VJ507': 162, 'VJ508': 163, 'VJ509': 164, 'VJ510': 165, 'VJ511': 166, 'VJ512': 167, 'VJ513': 168, 'VJ514': 169, 'VJ515': 170,
    'VJ516': 171, 'VJ517': 172, 'VJ518': 173, 'VJ519': 174, 'VJ520': 175, 'VJ521': 176, 'VJ522': 177, 'VJ523': 178, 'VJ524': 179, 'VJ527': 180,
    'VJ528': 181, 'VJ600': 182, 'VJ602': 183, 'VJ620': 184, 'VJ621': 185, 'VJ622': 186, 'VJ623': 187, 'VJ624': 188, 'VJ625': 189, 'VJ626': 190,
    'VJ627': 191, 'VJ628': 192, 'VJ629': 193, 'VJ630': 194, 'VJ631': 195, 'VJ632': 196, 'VJ633': 197, 'VJ634': 198, 'VJ635': 199, 'VJ636': 200,
    'VJ637': 201, 'VJ638': 202, 'VJ639': 203, 'VJ640': 204, 'VJ641': 205, 'VJ642': 206, 'VJ643': 207, 'VJ644': 208, 'VJ645': 209, 'VJ646': 210,
    'VJ647': 211, 'VJ648': 212, 'VJ649': 213, 'VJ6922': 214, 'VJ716': 215, 'VJ718': 216, 'VJ7238': 217, 'VJ7239': 218, 'VJ729': 219, 'VJ7306': 220,
    'VJ7307': 221, 'VJ731': 222, 'VJ7526': 223, 'VJ7527': 224, 'VJ7614': 225, 'VJ770': 226, 'VJ771': 227, 'VJ772': 228, 'VJ773': 229, 'VJ774': 230,
    'VJ775': 231, 'VJ778': 232, 'VJ779': 233, 'VJ780': 234, 'VJ781': 235, 'VJ783': 236, 'VJ784': 237, 'VJ802': 238, 'VJ8021': 239, 'VJ804': 240,
    'VJ806': 241, 'VJ808': 242, 'VJ809': 243, 'VJ812': 244, 'VJ814': 245, 'VJ816': 246, 'VJ820': 247, 'VJ821': 248, 'VJ822': 249, 'VJ823': 250,
    'VJ824': 251, 'VJ828': 252, 'VJ829': 253, 'VJ832': 254, 'VJ833': 255, 'VJ834': 256, 'VJ835': 257, 'VJ836': 258, 'VJ837': 259, 'VJ838': 260,
    'VJ839': 261, 'VJ840': 262, 'VJ841': 263, 'VJ842': 264, 'VJ843': 265, 'VJ845': 266, 'VJ848': 267, 'VJ849': 268, 'VJ854': 269, 'VJ856': 270,
    'VJ857': 271, 'VJ858': 272, 'VJ859': 273, 'VJ860': 274, 'VJ861': 275, 'VJ862': 276, 'VJ863': 277, 'VJ864': 278, 'VJ865': 279, 'VJ874': 280,
    'VJ875': 281, 'VJ876': 282, 'VJ877': 283, 'VJ878': 284, 'VJ879': 285, 'VJ880': 286, 'VJ881': 287, 'VJ8814': 288, 'VJ883': 289, 'VJ885': 290,
    'VJ890': 291, 'VJ892': 292, 'VJ893': 293, 'VJ894': 294, 'VJ895': 295, 'VJ896': 296, 'VJ897': 297, 'VJ898': 298, 'VJ899': 299, 'VJ900': 300,
    'VJ905': 301, 'VJ906': 302, 'VJ907': 303, 'VJ913': 304, 'VJ916': 305, 'VJ919': 306, 'VJ920': 307, 'VJ921': 308, 'VJ928': 309, 'VJ930': 310,
    'VJ931': 311, 'VJ932': 312, 'VJ933': 313, 'VJ934': 314, 'VJ935': 315, 'VJ938': 316, 'VJ939': 317, 'VJ940': 318, 'VJ941': 319, 'VJ942': 320,
    'VJ943': 321, 'VJ948': 322, 'VJ959': 323, 'VJ970': 324, 'VJ971': 325, 'VJ972': 326, 'VJ974': 327, 'VJ976': 328, 'VJ978': 329, 'VJ9812': 330,
    'VJ984': 331, 'VJ985': 332, 'VJ986': 333, 'VJ991': 334, 'VJ997': 335, 'VJ998': 336, 'VJ9986': 337, 'VZ3525': 338, 'VZ566': 339, 'VZ568': 340,
    'VZ570': 341
}
 
 
 # 生成各个国家(地区)的节假日
 # 生成各个国家(地区)的节假日
 def build_country_holidays(city_to_country):
 def build_country_holidays(city_to_country):

+ 71 - 32
data_loader.py

@@ -31,9 +31,9 @@ def mongo_con_parse(config=None):
             client = pymongo.MongoClient(
             client = pymongo.MongoClient(
                 config['host'],
                 config['host'],
                 config['port'],
                 config['port'],
-                serverSelectionTimeoutMS=6000,  # 6秒
-                connectTimeoutMS=6000,  # 6秒
-                socketTimeoutMS=6000,  # 6秒,
+                serverSelectionTimeoutMS=15000,  # 15 seconds
+                connectTimeoutMS=15000,  # 15 seconds
+                socketTimeoutMS=15000,  # 15 seconds
                 retryReads=True,    # 开启重试
                 retryReads=True,    # 开启重试
                 maxPoolSize=50
                 maxPoolSize=50
             )
             )
@@ -718,38 +718,77 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
     return df_all
     return df_all
 
 
 
 
-def chunk_list(lst, group_size):
-    return [lst[i:i + group_size] for i in range(0, len(lst), group_size)]
def query_all_flight_number(db, table_name):
    """Return the distinct flight numbers found in `segments.flight_number` of `table_name`.

    Args:
        db: pymongo database handle.
        table_name: collection name to aggregate over.

    Returns:
        list[str]: de-duplicated flight numbers (order unspecified).
    """
    print(f"{table_name} 查找所有航班号")
    # Project each document down to its array of segment flight numbers,
    # then group identical arrays (the count is a by-product, unused here).
    agg_stages = [
        {
            "$project": {
                "flight_numbers": "$segments.flight_number"
            }
        },
        {
            "$group": {
                "_id": "$flight_numbers",
                "count": {"$sum": 1}
            }
        },
    ]
    cursor = db[table_name].aggregate(agg_stages)

    # Flatten every grouped array into one de-duplicated collection.
    unique_numbers = {num for doc in cursor for num in doc.get("_id", [])}

    return list(unique_numbers)
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
 
 
     # test_mongo_connection(db)
     # test_mongo_connection(db)
 
 
-    output_dir = f"./output"
-    os.makedirs(output_dir, exist_ok=True)
-
-    # 加载热门航线数据
-    date_begin = "2025-11-20"
-    date_end = datetime.today().strftime("%Y-%m-%d")
-
-    flight_route_list = vj_flight_route_list_hot[0:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
-    table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
-    is_hot = 1   # 1 热门 0 冷门
-    group_size = 1
-    chunks = chunk_list(flight_route_list, group_size)
-
-    for idx, group_route_list in enumerate(chunks, 1):
-        # 使用默认配置
-        client, db = mongo_con_parse()
-        print(f"第 {idx} 组 :", group_route_list)
-        start_time = time.time()
-        load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
-        end_time = time.time()
-        run_time = round(end_time - start_time, 3)
-        print(f"用时: {run_time} 秒")
-
-        client.close()
-        time.sleep(3)
-
-    print("整体结束")
+    # output_dir = f"./output"
+    # os.makedirs(output_dir, exist_ok=True)
+
+    # # 加载热门航线数据
+    # date_begin = "2025-11-20"
+    # date_end = datetime.today().strftime("%Y-%m-%d")
+
+    # flight_route_list = vj_flight_route_list_hot[0:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
+    # table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+    # is_hot = 1   # 1 热门 0 冷门
+    # group_size = 1
+    # chunks = chunk_list(flight_route_list, group_size)
+
+    # for idx, group_route_list in enumerate(chunks, 1):
+    #     # 使用默认配置
+    #     client, db = mongo_con_parse()
+    #     print(f"第 {idx} 组 :", group_route_list)
+    #     start_time = time.time()
+    #     load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+    #     end_time = time.time()
+    #     run_time = round(end_time - start_time, 3)
+    #     print(f"用时: {run_time} 秒")
+
+    #     client.close()
+    #     time.sleep(3)
+
+    # print("整体结束")
+
+    client, db = mongo_con_parse()
+    list_flight_number_1 = query_all_flight_number(db, CLEAN_VJ_HOT_NEAR_INFO_TAB)
+    list_flight_number_2 = query_all_flight_number(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB)
+
+    list_flight_number_all = list_flight_number_1 + list_flight_number_2
+    list_flight_number_all = list(set(list_flight_number_all))
+    list_flight_number_all.sort()
+    
+    print(list_flight_number_all)
+    print(len(list_flight_number_all))
+
+    flight_map = {v: i for i, v in enumerate(list_flight_number_all, start=1)}
+    print(flight_map)
+    

+ 128 - 16
data_preprocess.py

@@ -4,7 +4,8 @@ import bisect
 import gc
 import gc
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from sklearn.preprocessing import StandardScaler
 from sklearn.preprocessing import StandardScaler
-from config import city_to_country, vj_city_code_map, build_country_holidays
+from config import city_to_country, vj_city_code_map, vi_flight_number_map, build_country_holidays
+from utils import insert_df_col
 
 
 COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
 COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
 
 
@@ -16,8 +17,22 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     df_input['city_pair'] = (
     df_input['city_pair'] = (
         df_input['from_city_code'].astype(str) + "-" + df_input['to_city_code'].astype(str)
         df_input['from_city_code'].astype(str) + "-" + df_input['to_city_code'].astype(str)
     )
     )
+    # 城市码映射成数字
     df_input['from_city_num'] = df_input['from_city_code'].map(vj_city_code_map)
     df_input['from_city_num'] = df_input['from_city_code'].map(vj_city_code_map)
     df_input['to_city_num'] = df_input['to_city_code'].map(vj_city_code_map)
     df_input['to_city_num'] = df_input['to_city_code'].map(vj_city_code_map)
+    
+    missing_from = (
+        df_input.loc[df_input['from_city_num'].isna(), 'from_city_code']
+        .unique()
+    )
+    missing_to = (
+        df_input.loc[df_input['to_city_num'].isna(), 'to_city_code']
+        .unique()
+    )
+    if missing_from:
+        print("未映射的 from_city:", missing_from)
+    if missing_to:
+        print("未映射的 to_city:", missing_to)
 
 
     # 把 city_pair、from_city_code、from_city_num, to_city_code, to_city_num 放到前几列
     # 把 city_pair、from_city_code、from_city_num, to_city_code, to_city_num 放到前几列
     cols = df_input.columns.tolist()
     cols = df_input.columns.tolist()
@@ -26,6 +41,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         cols.remove(c)
         cols.remove(c)
     # 这几列插入到最前面
     # 这几列插入到最前面
     df_input = df_input[['city_pair', 'from_city_code', 'from_city_num', 'to_city_code', 'to_city_num'] + cols]
     df_input = df_input[['city_pair', 'from_city_code', 'from_city_num', 'to_city_code', 'to_city_num'] + cols]
+    pass
 
 
     # 转格式
     # 转格式
     df_input['search_dep_time'] = pd.to_datetime(
     df_input['search_dep_time'] = pd.to_datetime(
@@ -48,6 +64,42 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     df_input['flight_number_1'] = df_input['flight_number_1'].fillna('VJ')
     df_input['flight_number_1'] = df_input['flight_number_1'].fillna('VJ')
     df_input['flight_number_2'] = df_input['flight_number_2'].fillna('VJ')
     df_input['flight_number_2'] = df_input['flight_number_2'].fillna('VJ')
 
 
+    # 航班号转数字
+    df_input['flight_1_num'] = df_input['flight_number_1'].map(vi_flight_number_map)
+    df_input['flight_2_num'] = df_input['flight_number_2'].map(vi_flight_number_map)
+
+    missing_flight_1 = (
+        df_input.loc[df_input['flight_1_num'].isna(), 'flight_number_1']
+        .unique()
+    )
+    missing_flight_2 = (
+        df_input.loc[df_input['flight_2_num'].isna(), 'flight_number_2']
+        .unique()
+    )
+    if missing_flight_1:
+        print("未映射的 flight_1:", missing_flight_1)
+    if missing_flight_2:
+        print("未映射的 flight_2:", missing_flight_2)
+    
+    # flight_1_num 放在 seg1_dep_air_port 之前
+    insert_df_col(df_input, 'flight_1_num', 'seg1_dep_air_port')
+    
+    # flight_2_num 放在 seg2_dep_air_port 之前
+    insert_df_col(df_input, 'flight_2_num', 'seg2_dep_air_port')
+
+    df_input['baggage_level'] = (df_input['baggage'] == 30).astype(int)   # 30--> 1  20--> 0 
+    # baggage_level 放在 flight_number_2 之前
+    insert_df_col(df_input, 'baggage_level', 'flight_number_2')
+
+    df_input['Adult_Total_Price'] = df_input['adult_total_price']
+    # Adult_Total_Price 放在 seats_remaining 之前  保存缩放前的原始值
+    insert_df_col(df_input, 'Adult_Total_Price', 'seats_remaining')
+
+    df_input['Hours_Until_Departure'] = df_input['hours_until_departure']
+    # Hours_Until_Departure 放在 days_to_departure 之前  保存缩放前的原始值
+    insert_df_col(df_input, 'Hours_Until_Departure', 'days_to_departure')
+    pass
+
     # gid:基于指定字段的分组标记(整数)
     # gid:基于指定字段的分组标记(整数)
     df_input['gid'] = (
     df_input['gid'] = (
         df_input
         df_input
@@ -116,9 +168,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     # 删除原始第一机场码
     # 删除原始第一机场码
     df_input.drop(columns=['seg1_dep_air_port', 'seg1_arr_air_port'], inplace=True)
     df_input.drop(columns=['seg1_dep_air_port', 'seg1_arr_air_port'], inplace=True)
     # 第一机场对 放到 seg1_dep_time 列的前面
     # 第一机场对 放到 seg1_dep_time 列的前面
-    insert_idx = df_input.columns.get_loc('seg1_dep_time')
-    airport_pair_1 = df_input.pop('airport_pair_1')
-    df_input.insert(insert_idx, 'airport_pair_1', airport_pair_1)
+    insert_df_col(df_input, 'airport_pair_1', 'seg1_dep_time')
 
 
     # 生成第二机场对(带缺失兜底)
     # 生成第二机场对(带缺失兜底)
     df_input['airport_pair_2'] = np.where(
     df_input['airport_pair_2'] = np.where(
@@ -130,15 +180,12 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     # 删除原始第二机场码
     # 删除原始第二机场码
     df_input.drop(columns=['seg2_dep_air_port', 'seg2_arr_air_port'], inplace=True)
     df_input.drop(columns=['seg2_dep_air_port', 'seg2_arr_air_port'], inplace=True)
     # 第二机场对 放到 seg2_dep_time 列的前面
     # 第二机场对 放到 seg2_dep_time 列的前面
-    insert_idx = df_input.columns.get_loc('seg2_dep_time')
-    airport_pair_2 = df_input.pop('airport_pair_2')
-    df_input.insert(insert_idx, 'airport_pair_2', airport_pair_2)
-    
+    insert_df_col(df_input, 'airport_pair_2', 'seg2_dep_time')
+
     # 是否转乘
     # 是否转乘
     df_input['is_transfer'] = np.where(df_input['flight_number_2'] == 'VJ', 0, 1)
     df_input['is_transfer'] = np.where(df_input['flight_number_2'] == 'VJ', 0, 1)
-    insert_idx = df_input.columns.get_loc('flight_number_2')
-    is_transfer = df_input.pop('is_transfer')
-    df_input.insert(insert_idx, 'is_transfer', is_transfer)
+    # 是否转乘 放到 flight_number_2 列的前面
+    insert_df_col(df_input, 'is_transfer', 'flight_number_2')
 
 
     # 重命名起飞时刻与到达时刻
     # 重命名起飞时刻与到达时刻
     df_input.rename(
     df_input.rename(
@@ -236,7 +283,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     ).astype(int)
     ).astype(int)
 
 
     # 在任一侧是否节假日
     # 在任一侧是否节假日
-    df_input['flight_day_is_holiday'] = (
+    df_input['any_country_is_holiday'] = (
         df_input[['dep_country_is_holiday', 'arr_country_is_holiday']]
         df_input[['dep_country_is_holiday', 'arr_country_is_holiday']]
         .max(axis=1)
         .max(axis=1)
     )
     )
@@ -275,9 +322,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     # df_input['days_to_holiday'] = df_input['days_to_holiday'].fillna(999)
     # df_input['days_to_holiday'] = df_input['days_to_holiday'].fillna(999)
 
 
     # days_to_holiday 插在 update_hour 前面
     # days_to_holiday 插在 update_hour 前面
-    insert_idx = df_input.columns.get_loc('update_hour')
-    days_to_holiday = df_input.pop('days_to_holiday')
-    df_input.insert(insert_idx, 'days_to_holiday', days_to_holiday)
+    insert_df_col(df_input, 'days_to_holiday', 'update_hour')
 
 
     # 制作targets
     # 制作targets
     print(f"\n>>> 开始处理 对应区间: n_hours = {current_n_hours}")
     print(f"\n>>> 开始处理 对应区间: n_hours = {current_n_hours}")
@@ -372,5 +417,72 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     print(">>> 合并后 df_input 样例:")
     print(">>> 合并后 df_input 样例:")
     print(df_input[['gid', 'hours_until_departure', 'adult_total_price', 'target_will_price_drop', 'target_amount_of_drop', 'target_time_to_drop']].head(5))
     print(df_input[['gid', 'hours_until_departure', 'adult_total_price', 'target_will_price_drop', 'target_amount_of_drop', 'target_time_to_drop']].head(5))
 
 
-    
+    # 按顺序排列
+    order_columns = [
+        "city_pair", "from_city_code", "from_city_num", "to_city_code", "to_city_num", "flight_day", 
+        "seats_remaining", "baggage", "baggage_level", 
+        "price_change_times_total", "price_last_change_hours", "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_time_to_drop",
+        "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "gid",
+        "flight_number_1", "flight_1_num", "airport_pair_1", "dep_time_1", "arr_time_1", "fly_duration_1", 
+        "flight_by_hour", "flight_by_day", "flight_day_of_month", "flight_day_of_week", "flight_day_of_quarter", "flight_day_is_weekend", "is_transfer", 
+        "flight_number_2", "flight_2_num", "airport_pair_2", "dep_time_2", "arr_time_2", "fly_duration_2", "fly_duration", "stop_duration", 
+        "global_dep_time", "dep_country", "dep_country_is_holiday", "is_cross_country",
+        "global_arr_time", "arr_country", "arr_country_is_holiday", "any_country_is_holiday",
+    ]
+    df_input = df_input[order_columns]
+
     return df_input
     return df_input
+
+
def standardization(df, feature_scaler, target_scaler, is_training=True, is_test=False):
    """Scale feature columns in place.

    Continuous price/duration columns go through a (possibly newly fitted)
    StandardScaler; bounded columns are min-max scaled against fixed,
    hand-chosen ranges and clipped into [0, 1].

    Args:
        df: preprocessed frame, mutated in place.
        feature_scaler: fitted scaler, or None to create one (training only).
        target_scaler: passed through untouched (reserved for targets).
        is_training: True for the training pipeline, False for prediction.
        is_test: within training, True skips re-fitting the scaler.

    Returns:
        (df, feature_scaler, target_scaler)
    """
    print(">>> 开始标准化处理")

    # Columns handled by the StandardScaler.
    scaler_features = ['adult_total_price', 'fly_duration', 'stop_duration']

    if not is_training:
        # Prediction path: only apply the previously fitted scaler.
        df[scaler_features] = feature_scaler.transform(df[scaler_features])
        print(">>> 预测模式下特征标准化处理完成")
    else:
        print(">>> 特征数据标准化开始")
        feature_scaler = StandardScaler() if feature_scaler is None else feature_scaler
        if not is_test:
            # Fit on training batches only; test batches reuse the fitted stats.
            feature_scaler.fit(df[scaler_features])
        df[scaler_features] = feature_scaler.transform(df[scaler_features])
        print(">>> 特征数据标准化完成")

    # Columns handled by fixed-range min-max scaling; ranges chosen up front.
    fixed_ranges = {
        'hours_until_departure': (0, 480),       # 0-20 days
        'from_city_num': (0, 38),
        'to_city_num': (0, 38),
        'flight_1_num': (0, 341),
        'flight_2_num': (0, 341),
        'seats_remaining': (1, 5),
        'price_change_times_total': (0, 30),     # assume at most 30 price changes
        'price_last_change_hours': (0, 480),
        'days_to_departure': (0, 30),
        'days_to_holiday': (0, 120),             # longest Vietnamese holiday gap: 120 days
        'flight_by_hour': (0, 23),
        'flight_by_day': (1, 31),
        'flight_day_of_month': (1, 12),          # NOTE(review): (1, 12) looks like a month range — confirm the column's meaning
        'flight_day_of_week': (0, 6),
        'flight_day_of_quarter': (1, 4),
    }
    normal_features = list(fixed_ranges.keys())

    print(">>> 归一化特征列: ", normal_features)
    print(">>> 基于固定范围的特征数据归一化开始")
    for col, (lo, hi) in fixed_ranges.items():
        if col not in df.columns:
            continue
        # (x - min) / (max - min), clipped so out-of-range values stay in [0, 1].
        df[col] = ((df[col] - lo) / (hi - lo)).clip(0, 1)
    print(">>> 基于固定范围的特征数据归一化完成")

    return df, feature_scaler, target_scaler

+ 88 - 17
main_tr.py

@@ -12,8 +12,9 @@ import time
 import pickle
 import pickle
 import shutil
 import shutil
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from data_loader import chunk_list, mongo_con_parse, load_train_data
-from data_preprocess import preprocess_data
+from utils import chunk_list_with_index, create_fixed_length_sequences
+from data_loader import mongo_con_parse, load_train_data
+from data_preprocess import preprocess_data, standardization
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
     CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
     CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
 
 
@@ -32,10 +33,10 @@ categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_nu
 common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining', 'is_cross_country', 'is_transfer', 
 common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining', 'is_cross_country', 'is_transfer', 
                    'fly_duration', 'stop_duration', 
                    'fly_duration', 'stop_duration', 
                    'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                    'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
-                   'dep_country_is_holiday', 'arr_country_is_holiday', 'flight_day_is_holiday', 'days_to_holiday',
+                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
                   ]
 price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
 price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
-encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'cabin_level', 'baggage_level']
+encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
 features = encoded_columns + price_features + common_features
 features = encoded_columns + price_features + common_features
 target_vars = ['target_will_price_drop']   # 是否降价
 target_vars = ['target_will_price_drop']   # 是否降价
 
 
@@ -94,7 +95,7 @@ def start_train():
     photo_dir = "./photo"
     photo_dir = "./photo"
 
 
     date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = datetime.today().strftime("%Y-%m-%d")
-    date_begin = (datetime.today() - timedelta(days=15)).strftime("%Y-%m-%d")
+    date_begin = (datetime.today() - timedelta(days=18)).strftime("%Y-%m-%d")
 
 
     # 仅在 rank == 0 时要做的
     # 仅在 rank == 0 时要做的
     if rank == 0:
     if rank == 0:
@@ -124,9 +125,12 @@ def start_train():
     # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
     # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
     # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
     # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
 
 
-    group_size = 1
+    group_size = 1              # 每几组作为一个批次
     num_epochs_per_batch = 200  # 每个批次训练的轮数,可以根据需要调整
     num_epochs_per_batch = 200  # 每个批次训练的轮数,可以根据需要调整
 
 
+    feature_scaler = None     # 初始化特征缩放器
+    target_scaler = None      # 初始化目标缩放器
+
     # 初始化 Redis 客户端(请根据实际情况修改 host、port、db)
     # 初始化 Redis 客户端(请根据实际情况修改 host、port、db)
     redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
     redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
     lock_key = "data_loading_lock_11"
     lock_key = "data_loading_lock_11"
@@ -142,19 +146,67 @@ def start_train():
 
 
     # 调试代码
     # 调试代码
     s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
     s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
-    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
+    flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
     flight_route_list_len = len(flight_route_list)
     flight_route_list_len = len(flight_route_list)
-    route_len_hot = len(vj_flight_route_list_hot[0:])
+    route_len_hot = len(vj_flight_route_list_hot[:0])
     route_len_nothot = len(vj_flight_route_list_nothot[s:])
     route_len_nothot = len(vj_flight_route_list_nothot[s:])
     
     
     if local_rank == 0:
     if local_rank == 0:
         print(f"flight_route_list_len:{flight_route_list_len}")
         print(f"flight_route_list_len:{flight_route_list_len}")
         print(f"route_len_hot:{route_len_hot}")
         print(f"route_len_hot:{route_len_hot}")
         print(f"route_len_nothot:{route_len_nothot}")
         print(f"route_len_nothot:{route_len_nothot}")
+
+    # 如果处理中断,打开注释加载批次顺序
+    # with open(os.path.join(output_dir, f'order.pkl'), "rb") as f:
+    #     flight_route_list = pickle.load(f)
+
+    if rank == 0:
+        pass
+        # 保存批次顺序, 如果处理临时中断, 将这段代码注释掉
+        with open(os.path.join(output_dir, f'order.pkl'), "wb") as f:
+            pickle.dump(flight_route_list, f)
+
+    chunks = chunk_list_with_index(flight_route_list, group_size)
+
+    # 新增部分:计算总批次数并初始化 scaler 列表
+    if rank == 0:
+        total_batches = len(chunks)
+        feature_scaler_list = [None] * total_batches  # 预分配列表空间
+        # target_scaler_list = [None] * total_batches   # 预分配列表空间
+
+    # 中断时,打开下面注释, 临时加载一下 scaler 列表
+    # if rank == 0:
+    #     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
+    #     target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
+
+    #     if os.path.exists(feature_scaler_path):
+    #         # 加载旧的scaler列表
+    #         old_feature_scaler_list = joblib.load(feature_scaler_path)
+    #         # 计算旧的总批次数
+    #         old_total_batches = len(old_feature_scaler_list)
+    #         # 只替换重叠部分
+    #         min_batches = min(old_total_batches, total_batches)
+    #         feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches]
+
+    #     if os.path.exists(target_scaler_path):
+    #         # 加载旧的scaler列表
+    #         old_target_scaler_list = joblib.load(target_scaler_path)
+    #         # 计算旧的总批次数
+    #         old_total_batches = len(old_target_scaler_list)
+    #         # 只替换重叠部分
+    #         min_batches = min(old_total_batches, total_batches)
+    #         target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches]
     
     
-    chunks = chunk_list(flight_route_list, group_size)
+    # 如果临时处理中断,从日志里找到 中断的索引 修改它
+    resume_chunk_idx = 0
+    chunks = chunks[resume_chunk_idx:]
 
 
-    for idx, group_route_list in enumerate(chunks, start=0):
+    if local_rank == 0:
+        batch_starts = [start_idx for start_idx, _ in chunks]
+        print(f"rank:{rank}, local_rank:{local_rank}, 训练阶段起始索引顺序:{batch_starts}")
+        
+    # 训练阶段
+    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
         # 特殊处理,跳过不好的批次
         # 特殊处理,跳过不好的批次
         pass
         pass
         redis_client.set(lock_key, 0)
         redis_client.set(lock_key, 0)
@@ -169,13 +221,13 @@ def start_train():
             print("rank0 开始数据加载...")
             print("rank0 开始数据加载...")
             # 使用默认配置
             # 使用默认配置
             client, db = mongo_con_parse()
             client, db = mongo_con_parse()
-            print(f"第 {idx} 组 :", group_route_list)
+            print(f"第 {i} 组 :", group_route_list)
 
 
             # 根据索引位置决定是 热门 还是 冷门
             # 根据索引位置决定是 热门 还是 冷门
-            if 0 <= idx < route_len_hot:
+            if 0 <= i < route_len_hot:
                 is_hot = 1
                 is_hot = 1
                 table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
                 table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
-            elif route_len_hot <= idx < route_len_hot + route_len_nothot:
+            elif route_len_hot <= i < route_len_hot + route_len_nothot:
                 is_hot = 0
                 is_hot = 0
                 table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
                 table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
             else:
             else:
@@ -202,10 +254,29 @@ def start_train():
             print("预处理后数据样本:\n", df_train_inputs.head())
             print("预处理后数据样本:\n", df_train_inputs.head())
 
 
             total_rows = df_train_inputs.shape[0]
             total_rows = df_train_inputs.shape[0]
-
-
-
-
+            print(f"行数: {total_rows}")
+            if total_rows == 0:
+                print(f"预处理后的训练数据为空,跳过此批次。")
+                continue_before_process(redis_client, lock_key)
+                continue
+            
+            # 标准化与归一化处理
+            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs, feature_scaler=None, target_scaler=None)
+
+            # 将 scaler 存入列表
+            batch_idx = i
+            print("batch_idx:", batch_idx)
+            feature_scaler_list[batch_idx] = feature_scaler
+            # target_scaler_list[batch_idx] = target_scaler
+
+            # 每个批次保存一下scaler
+            feature_scaler_path = os.path.join(output_dir, f'feature_scalers.joblib')
+            # target_scaler_path = os.path.join(output_dir, f'target_scalers.joblib')
+            joblib.dump(feature_scaler_list, feature_scaler_path)
+            # joblib.dump(target_scaler_list, target_scaler_path)
+            
+            # 生成序列
+            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
             pass
             pass
 
 
         else:
         else:

+ 83 - 0
utils.py

@@ -0,0 +1,83 @@
+import torch
+
+
# Split a route list into index-tagged batches.
def chunk_list_with_index(lst, group_size):
    """Chunk `lst` into consecutive slices of `group_size` items.

    Returns a list of (start_index, slice) pairs so callers can resume
    processing from a known batch offset.
    """
    batches = []
    for start in range(0, len(lst), group_size):
        batches.append((start, lst[start:start + group_size]))
    return batches
+
# Relocate an existing DataFrame column so it sits before another column.
def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    """Move `insert_col_name` so it sits just before `base_col_name`.

    Mutates `df` when `inplace` is True, otherwise operates on a copy.
    The (possibly copied) frame is returned in both cases.

    Raises:
        ValueError: if either column name is missing from the frame.
    """
    target = df if inplace else df.copy()

    if base_col_name not in target.columns:
        raise ValueError(f"base_col_name '{base_col_name}' 不存在")
    if insert_col_name not in target.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' 不存在")
    if base_col_name == insert_col_name:
        return target

    # NOTE(review): the destination index is captured before the pop; when the
    # moved column originally sits left of the base column, removing it shifts
    # the base one slot left and the column lands AFTER the base — confirm
    # intended (all current callers move a rightmost column leftward).
    insert_at = target.columns.get_loc(base_col_name)
    moved = target.pop(insert_col_name)
    target.insert(insert_at, insert_col_name, moved)

    return target
+
# Build model-ready fixed-length sequences from the preprocessed frame.
def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
    """Build one (2, input_length, n_features) sample per complete flight group.

    Rows are grouped by `gid` (city_pair / flight_day / flight_number_1 /
    flight_number_2 — baggage is excluded, so each group carries both the
    30 kg and the 20 kg series). A sample is emitted only when BOTH baggage
    series cover exactly `input_length` rows inside the training window
    [threshold, threshold + input_length).

    Args:
        df: preprocessed frame; must contain 'gid', 'baggage',
            'Hours_Until_Departure', 'dep_time_1', 'Adult_Total_Price',
            'update_hour' plus the `features` / `target_vars` columns.
        features: feature column names fed to the model.
        target_vars: target column names (read from the 30 kg series).
        input_length: required number of rows per baggage series.
        is_train: when True (and target_vars non-empty) also collect targets.

    Returns:
        (sequences, targets, group_ids): lists of float32 tensors of shape
        (2, input_length, len(features)), float32 target tensors, and
        per-sample identification tuples.
    """
    sequences = []
    targets = []
    group_ids = []

    threshold = 28   # the window starts this many hours before departure

    grouped = df.groupby(['gid'])
    for _, df_group in grouped:
        city_pair = df_group['city_pair'].iloc[0]
        flight_day = df_group['flight_day'].iloc[0]
        flight_number_1 = df_group['flight_number_1'].iloc[0]
        flight_number_2 = df_group['flight_number_2'].iloc[0]
        dep_time_str = df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S')

        # Split the group by baggage allowance.
        df_group_bag_30 = df_group[df_group['baggage'] == 30]
        df_group_bag_20 = df_group[df_group['baggage'] == 20]

        # Keep only the training window: threshold <= H < threshold + input_length.
        window_hi = threshold + input_length
        df_bag_30_win = df_group_bag_30[
            (df_group_bag_30['Hours_Until_Departure'] >= threshold)
            & (df_group_bag_30['Hours_Until_Departure'] < window_hi)
        ]
        df_bag_20_win = df_group_bag_20[
            (df_group_bag_20['Hours_Until_Departure'] >= threshold)
            & (df_group_bag_20['Hours_Until_Departure'] < window_hi)
        ]

        # Both series must cover the window completely, otherwise skip the group.
        if len(df_bag_30_win) != input_length or len(df_bag_20_win) != input_length:
            continue

        seq_features_30 = df_bag_30_win[features].to_numpy()
        seq_features_20 = df_bag_20_win[features].to_numpy()

        # Stack the two baggage series -> shape (2, input_length, n_features).
        combined_features = torch.stack([
            torch.tensor(seq_features_30, dtype=torch.float32),
            torch.tensor(seq_features_20, dtype=torch.float32),
        ])
        sequences.append(combined_features)

        if is_train and target_vars:
            # Targets come from the first row of the 30 kg series.
            seq_targets = df_bag_30_win[target_vars].iloc[0].to_numpy()
            targets.append(torch.tensor(seq_targets, dtype=torch.float32))

        # Identification tuple for tracing a sample back to its flight.
        last_row = df_bag_30_win.iloc[-1]
        group_ids.append((
            city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str,
            str(last_row['baggage']),
            str(last_row['Adult_Total_Price']),
            str(last_row['Hours_Until_Departure']),
            str(last_row['update_hour']),
        ))

    # BUG FIX: the previous version never returned, so callers unpacking
    # `sequences, targets, group_ids` (main_tr.py) crashed on None.
    return sequences, targets, group_ids