Ver código fonte

增加6个价格相关特征 优化数据加载

node04 1 mês atrás
pai
commit
977b8a0867
6 arquivos alterados com 556 adições e 247 exclusões
  1. 374 228
      data_loader.py
  2. 166 7
      data_preprocess.py
  3. 4 4
      main_pe.py
  4. 7 5
      main_tr.py
  5. 2 0
      result_validate.py
  6. 3 3
      utils.py

+ 374 - 228
data_loader.py

@@ -6,6 +6,8 @@ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
 import pandas as pd
 import os
 import random
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib import font_manager
@@ -67,7 +69,7 @@ def test_mongo_connection(db):
 
 
 def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums, 
-                              limit=0, max_retries=3, base_sleep=1.0):
+                              limit=0, max_retries=3, base_sleep=1.0, thread_id=0):
     """
     从指定表(4类)查询数据(指定起飞天的范围) (失败自动重试)
     """
@@ -132,7 +134,7 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
                 # 1️⃣ 展开 segments
                 print(f"📊 开始扩展segments 稍等...")
                 t1 = time.time()
-                df = expand_segments_columns(df)
+                df = expand_segments_columns_optimized(df)  # 改为调用优化版
                 t2 = time.time()
                 rt = round(t2 - t1, 3)
                 print(f"用时: {rt} 秒")
@@ -157,73 +159,148 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
             time.sleep(sleep_time)
 
 
-def expand_segments_columns(df):
-    """展开 segments"""
+# def expand_segments_columns(df):
+#     """展开 segments"""
+#     df = df.copy()
+
+#     # 定义要展开的列
+#     seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
+#     seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
+
+#     # 定义 apply 函数一次返回字典
+#     def extract_segments(row):
+#         segments = row.get('segments')
+#         result = {}
+#         # 默认缺失使用 pd.NA(对字符串友好)
+#         missing = pd.NA
+#         if isinstance(segments, list):
+#             # 第一段
+#             if len(segments) >= 1 and isinstance(segments[0], dict):
+#                 for col in seg1_cols:
+#                     result[f'seg1_{col}'] = segments[0].get(col)
+#             else:
+#                 for col in seg1_cols:
+#                     result[f'seg1_{col}'] = missing
+#             # 第二段
+#             if len(segments) >= 2 and isinstance(segments[1], dict):
+#                 for col in seg2_cols:
+#                     result[f'seg2_{col}'] = segments[1].get(col)
+#             else:
+#                 for col in seg2_cols:
+#                     result[f'seg2_{col}'] = missing
+#         else:
+#             # segments 不是 list,全都置空
+#             for col in seg1_cols:
+#                 result[f'seg1_{col}'] = missing
+#             for col in seg2_cols:
+#                 result[f'seg2_{col}'] = missing
+
+#         return pd.Series(result)
+
+#     # 一次 apply
+#     df_segments = df.apply(extract_segments, axis=1)
+
+#     # 拼回原 df
+#     df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
+
+#     # 统一转换时间字段为 datetime
+#     time_cols = [
+#         'seg1_dep_time', 'seg1_arr_time',
+#         'seg2_dep_time', 'seg2_arr_time'
+#     ]
+#     for col in time_cols:
+#         if col in df.columns:
+#             df[col] = pd.to_datetime(
+#                 df[col],
+#                 format='%Y%m%d%H%M%S',
+#                 errors='coerce'
+#             )
+
+#     # 站点来源 -> 是否近期
+#     df['source_website'] = np.where(
+#         df['source_website'].str.contains('7_30'),
+#         0,  # 远期 -> 0
+#         np.where(df['source_website'].str.contains('0_7'),
+#                  1,  # 近期 -> 1
+#                  df['source_website'])  # 其他情况保持原值
+#     )
+
+#     # 行李配额字符 -> 数字
+#     conditions = [
+#         df['seg1_baggage'] == '-;-;-;-',
+#         df['seg1_baggage'] == '1-20',
+#         df['seg1_baggage'] == '1-30',
+#         df['seg1_baggage'] == '1-40',
+#     ]
+#     choices = [0, 20, 30, 40]
+#     df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
+
+#     # 重命名字段
+#     df = df.rename(columns={
+#         'seg1_cabin': 'cabin',
+#         'seg1_baggage': 'baggage',
+#         'source_website': 'is_near',
+#     })
+
+#     return df
+
+def expand_segments_columns_optimized(df):
+    """优化版的展开segments函数(避免逐行apply)"""
+    if df.empty:
+        return df
+    
     df = df.copy()
 
-    # 定义要展开的列
-    seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
-    seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
-
-    # 定义 apply 函数一次返回字典
-    def extract_segments(row):
-        segments = row.get('segments')
-        result = {}
-        # 默认缺失使用 pd.NA(对字符串友好)
-        missing = pd.NA
-        if isinstance(segments, list):
-            # 第一段
-            if len(segments) >= 1 and isinstance(segments[0], dict):
+    # 直接操作segments列表,避免逐行apply
+    if 'segments' in df.columns:
+        # 提取第一段信息
+        seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
+        # 提取第二段信息
+        seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
+    
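+        # Plain for-loop over the 'segments' column building one dict per row, instead of DataFrame.apply(axis=1) returning a pd.Series per row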
+        seg1_data = []
+        seg2_data = []
+
+        for segments in df['segments']:
+            seg1_dict = {}
+            seg2_dict = {}
+
+            if isinstance(segments, list) and len(segments) >= 1 and isinstance(segments[0], dict):
                 for col in seg1_cols:
-                    result[f'seg1_{col}'] = segments[0].get(col)
+                    seg1_dict[f'seg1_{col}'] = segments[0].get(col)
             else:
                 for col in seg1_cols:
-                    result[f'seg1_{col}'] = missing
-            # 第二段
-            if len(segments) >= 2 and isinstance(segments[1], dict):
+                    seg1_dict[f'seg1_{col}'] = pd.NA
+            
+            if isinstance(segments, list) and len(segments) >= 2 and isinstance(segments[1], dict):
                 for col in seg2_cols:
-                    result[f'seg2_{col}'] = segments[1].get(col)
+                    seg2_dict[f'seg2_{col}'] = segments[1].get(col)
             else:
                 for col in seg2_cols:
-                    result[f'seg2_{col}'] = missing
-        else:
-            # segments 不是 list,全都置空
-            for col in seg1_cols:
-                result[f'seg1_{col}'] = missing
-            for col in seg2_cols:
-                result[f'seg2_{col}'] = missing
-
-        return pd.Series(result)
+                    seg2_dict[f'seg2_{col}'] = pd.NA
+                    
+            seg1_data.append(seg1_dict)
+            seg2_data.append(seg2_dict)
 
-    # 一次 apply
-    df_segments = df.apply(extract_segments, axis=1)
+        # 创建DataFrame
+        df_seg1 = pd.DataFrame(seg1_data, index=df.index)
+        df_seg2 = pd.DataFrame(seg2_data, index=df.index)
 
-    # 拼回原 df
-    df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
+        # 合并到原DataFrame
+        df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_seg1, df_seg2], axis=1)
 
-    # 统一转换时间字段为 datetime
-    time_cols = [
-        'seg1_dep_time', 'seg1_arr_time',
-        'seg2_dep_time', 'seg2_arr_time'
-    ]
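+    # Post-processing carried over from expand_segments_columns. NOTE(review): the lines below read 'source_website' and 'seg1_baggage' unconditionally, so a frame without 'segments' (and without pre-existing seg1_* columns) would presumably raise KeyError - confirm callers always supply it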
+    time_cols = ['seg1_dep_time', 'seg1_arr_time', 'seg2_dep_time', 'seg2_arr_time']
     for col in time_cols:
         if col in df.columns:
-            df[col] = pd.to_datetime(
-                df[col],
-                format='%Y%m%d%H%M%S',
-                errors='coerce'
-            )
-
-    # 站点来源 -> 是否近期
+            df[col] = pd.to_datetime(df[col], format='%Y%m%d%H%M%S', errors='coerce')
+    
     df['source_website'] = np.where(
-        df['source_website'].str.contains('7_30'),
-        0,  # 远期 -> 0
-        np.where(df['source_website'].str.contains('0_7'),
-                 1,  # 近期 -> 1
-                 df['source_website'])  # 其他情况保持原值
+        df['source_website'].str.contains('7_30'), 0,
+        np.where(df['source_website'].str.contains('0_7'), 1, df['source_website'])
     )
 
-    # 行李配额字符 -> 数字
     conditions = [
         df['seg1_baggage'] == '-;-;-;-',
         df['seg1_baggage'] == '1-20',
@@ -233,13 +310,12 @@ def expand_segments_columns(df):
     choices = [0, 20, 30, 40]
     df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
 
-    # 重命名字段
     df = df.rename(columns={
         'seg1_cabin': 'cabin',
         'seg1_baggage': 'baggage',
         'source_website': 'is_near',
     })
-
+    
     return df
 
 
@@ -564,152 +640,217 @@ def plot_c12_trend(df, output_dir="."):
     plt.close(fig)
 
 
-def load_train_data(db, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1):
-    """加载训练数据"""
+def process_flight_group(args):
+    """处理单个航班号的线程函数(独立数据库连接)"""
+    thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
+    flight_nums = each_group.get("flight_numbers")
+    details = each_group.get("details")
+
+    print(f"[线程{thread_id}] 开始处理航班号: {flight_nums}")
+
+    # 为每个线程创建独立的数据库连接
+    try:
+        client, db = mongo_con_parse(db_config)
+        print(f"[线程{thread_id}] ✅ 数据库连接创建成功")
+    except Exception as e:
+        print(f"[线程{thread_id}] ❌ 数据库连接创建失败: {e}")
+        return pd.DataFrame()
+
+    try:
+        # 查询远期表
+        if is_hot == 1:
+            df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, flight_nums)
+        else:
+            df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, flight_nums)
+        
+        # 保证远期表里有数据
+        if df1.empty:
+            print(f"[线程{thread_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
+            return pd.DataFrame()
+        
+        # 查询近期表
+        if is_hot == 1:
+            df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, flight_nums)
+        else:
+            df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, flight_nums)
+            
+        # 保证近期表里有数据
+        if df2.empty:
+            print(f"[线程{thread_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
+            return pd.DataFrame()
+        
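+        # Departure dates and baggage allowances are taken from the near-term table (df2). NOTE(review): df2 is already known to be non-empty here (early return above), so the empty branch below is unreachable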
+        if df2.empty:
+            common_dep_dates = []
+            common_baggages = []
+        else:
+            common_dep_dates = df2['search_dep_time'].unique()
+            common_baggages = df2['baggage'].unique()
+
+        list_mid = []
+        for dep_date in common_dep_dates:
+            # 起飞日期筛选
+            df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
+            if not df_d1.empty:
+                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
+                    mode_series_1 = df_d1[col].mode()
+                    if mode_series_1.empty:
+                        zong_1 = pd.NaT
+                    else:
+                        zong_1 = mode_series_1.iloc[0]
+                    df_d1[col] = zong_1
+
+            df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
+            if not df_d2.empty:
+                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
+                    mode_series_2 = df_d2[col].mode()
+                    if mode_series_2.empty:
+                        zong_2 = pd.NaT
+                    else:
+                        zong_2 = mode_series_2.iloc[0]
+                    df_d2[col] = zong_2
+
+            list_12 = []
+            for baggage in common_baggages:
+                # 行李配额筛选
+                df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
+                df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
+
+                # 合并前检查是否都有数据
+                if df_b1.empty and df_b2.empty:
+                    print(f"[线程{thread_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
+                    continue
+
+                cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
+                        "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
+                df_b1[cols] = df_b1[cols].astype("string")
+                df_b2[cols] = df_b2[cols].astype("string")
+
+                df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
+                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
+                df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
+                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
+                list_12.append(df_b12)
+
+                del df_b12
+                del df_b2
+                del df_b1
+
+            if list_12:
+                df_c12 = pd.concat(list_12, ignore_index=True)
+                if plot_flag:
+                    print(f"[线程{thread_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
+                    plot_c12_trend(df_c12, output_dir)
+                    print(f"[线程{thread_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
+            else:
+                df_c12 = pd.DataFrame()
+                if plot_flag:
+                    print(f"[线程{thread_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
+
+            del list_12
+            list_mid.append(df_c12)
+
+            del df_c12
+            del df_d1
+            del df_d2
+            # print(f"结束处理起飞日期: {dep_date}")
+
+        if list_mid:
+            df_mid = pd.concat(list_mid, ignore_index=True)
+            print(f"[线程{thread_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
+        else:
+            df_mid = pd.DataFrame()
+            print(f"[线程{thread_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
+        
+        del list_mid
+        del df1
+        del df2
+        gc.collect()
+        print(f"[线程{thread_id}] 结束处理航班号: {flight_nums}")
+        return df_mid
+    
+    except Exception as e:
+        print(f"[线程{thread_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
+        return pd.DataFrame()
+    finally:
+        # 确保关闭数据库连接
+        try:
+            client.close()
+            print(f"[线程{thread_id}] ✅ 数据库连接已关闭")
+        except:
+            pass
+
+
+def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, plot_flag=False,
+                    use_multithread=False, max_workers=None):
+    """加载训练数据(支持多线程)"""
     timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
     date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d")  # 查询时的格式
     date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
     list_all = []
+
     # 每一航线对
     for flight_route in flight_route_list:
         from_city = flight_route.split('-')[0]
         to_city = flight_route.split('-')[1]
         route = f"{from_city}-{to_city}"
         print(f"开始处理航线: {route}")
-        all_groups = query_groups_of_city_code(db, from_city, to_city, table_name)
-        all_groups_len = len(all_groups)
-        print(f"该航线共有{all_groups_len}个航班号")
-        # 每一组航班号
-        for each_group in all_groups:
-            flight_nums = each_group.get("flight_numbers")
-            print(f"开始处理航班号: {flight_nums}")
-            details = each_group.get("details")
-
-            print(f"查远期表")
-            if is_hot == 1:
-                df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
-                                                date_begin_s, date_end_s, flight_nums)
-            else:
-                df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
-                                                date_begin_s, date_end_s, flight_nums)
-
-            # 保证远期表里有数据
-            if df1.empty:
-                print(f"航班号:{flight_nums} 远期表无数据, 跳过")
-                continue
-
-            print(f"查近期表")
-            if is_hot == 1:
-                df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
-                                                date_begin_s, date_end_s, flight_nums)
-            else:
-                df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
-                                                date_begin_s, date_end_s, flight_nums)
-
-            # 保证近期表里有数据
-            if df2.empty:
-                print(f"航班号:{flight_nums} 近期表无数据, 跳过")
-                continue
-
-            # 起飞天数、行李配额以近期表的为主
-            if df2.empty:
-                common_dep_dates = []
-                common_baggages = []
-            else:
-                common_dep_dates = df2['search_dep_time'].unique()
-                common_baggages = df2['baggage'].unique()
-
-            list_mid = []
-            for dep_date in common_dep_dates:
-                # 起飞日期筛选
-                df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
-                if not df_d1.empty:
-                    for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
-                        mode_series_1 = df_d1[col].mode()
-                        if mode_series_1.empty:
-                            # 如果整个列都是 NaT,则众数为空,直接赋 NaT
-                            zong_1 = pd.NaT
-                        else:
-                            zong_1 = mode_series_1.iloc[0]
-                        df_d1[col] = zong_1
-
-                df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
-                if not df_d2.empty:
-                    for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
-                        mode_series_2 = df_d2[col].mode()
-                        if mode_series_2.empty:
-                            # 如果整个列都是 NaT,则众数为空,直接赋 NaT
-                            zong_2 = pd.NaT
-                        else:
-                            zong_2 = mode_series_2.iloc[0]
-                        df_d2[col] = zong_2
-
-                list_12 = []
-                for baggage in common_baggages:
-                    # 行李配额筛选
-                    df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
-                    df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
-
-                    # 合并前检查是否都有数据
-                    if df_b1.empty and df_b2.empty:
-                        print(f"⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
-                        continue
-
-                    cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
-                            "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
-                    # df_b1 = df_b1.copy()
-                    # df_b2 = df_b2.copy()
-                    df_b1[cols] = df_b1[cols].astype("string")
-                    df_b2[cols] = df_b2[cols].astype("string")
-
-                    df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
-                    # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
-                    df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
-                    # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
-                    # print(df_b12.dtypes)
-                    list_12.append(df_b12)
-                    del df_b12
-                    del df_b2
-                    del df_b1
-
-                if list_12:
-                    df_c12 = pd.concat(list_12, ignore_index=True)
-                    # print(f"✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
-                    # plot_c12_trend(df_c12, output_dir)
-                    # print(f"✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
-                else:
-                    df_c12 = pd.DataFrame()
-                    # print(f"⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
-
-                del list_12
-                list_mid.append(df_c12)
-
-                del df_c12
-                del df_d1
-                del df_d2
-
-                # print(f"结束处理起飞日期: {dep_date}")
-
-            if list_mid:
-                df_mid = pd.concat(list_mid, ignore_index=True)
-                print(f"✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
-            else:
-                df_mid = pd.DataFrame()
-                print(f"⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
 
-            del list_mid
-            list_all.append(df_mid)
+        # 在主线程中查询航班号分组(避免多线程重复查询)
+        main_client, main_db = mongo_con_parse(db_config)
+        all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
+        main_client.close()
 
-            del df1
-            del df2
-            
-            # output_path = os.path.join(output_dir, f"./{route}_{timestamp_str}.csv")
-            # df_mid.to_csv(output_path, index=False, encoding="utf-8-sig", mode="a", header=not os.path.exists(output_path))
+        all_groups_len = len(all_groups)
+        print(f"该航线共有{all_groups_len}个航班号")
+        
+        if use_multithread and all_groups_len > 1:
+            print(f"启用多线程处理,最大线程数: {max_workers}")
+            # 多线程处理
+            thread_args = []
+            thread_id = 0
+            for each_group in all_groups:
+                thread_id += 1
+                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                thread_args.append(args)
             
-            del df_mid
-            gc.collect()
-            print(f"结束处理航班号: {flight_nums}")
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(thread_args, all_groups)}
+                
+                for future in as_completed(future_to_group):
+                    each_group = future_to_group[future]
+                    flight_nums = each_group.get("flight_numbers", "未知")
+                    try:
+                        df_mid = future.result()
+                        if not df_mid.empty:
+                            list_all.append(df_mid)
+                            print(f"✅ 航班号:{flight_nums} 处理完成")
+                        else:
+                            print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
+                    except Exception as e:
+                        print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
 
+        else:
+            # 单线程处理(线程编号为0)
+            print("使用单线程处理")
+            thread_id = 0
+            for each_group in all_groups:
+                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                flight_nums = each_group.get("flight_numbers", "未知")
+                try:
+                    df_mid = process_flight_group(args)
+                    if not df_mid.empty:
+                        list_all.append(df_mid)
+                        print(f"✅ 航班号:{flight_nums} 处理完成")
+                    else:
+                        print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
+                except Exception as e:
+                    print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
+                
         print(f"结束处理航线: {from_city}-{to_city}")
 
     if list_all:
@@ -828,7 +969,7 @@ def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_
                 # 1️⃣ 展开 segments
                 print(f"📊 开始扩展segments 稍等...")
                 t1 = time.time()
-                df = expand_segments_columns(df)
+                df = expand_segments_columns_optimized(df)
                 t2 = time.time()
                 rt = round(t2 - t1, 3)
                 print(f"用时: {rt} 秒")
@@ -856,46 +997,51 @@ def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_
 if __name__ == "__main__":
 
     # test_mongo_connection(db)
-
-    # output_dir = f"./output"
-    # os.makedirs(output_dir, exist_ok=True)
-
-    # # 加载热门航线数据
-    # date_begin = "2025-11-20"
-    # date_end = datetime.today().strftime("%Y-%m-%d")
-
-    # flight_route_list = vj_flight_route_list_hot[0:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
-    # table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
-    # is_hot = 1   # 1 热门 0 冷门
-    # group_size = 1
-    # chunks = chunk_list(flight_route_list, group_size)
-
-    # for idx, group_route_list in enumerate(chunks, 1):
-    #     # 使用默认配置
-    #     client, db = mongo_con_parse()
-    #     print(f"第 {idx} 组 :", group_route_list)
-    #     start_time = time.time()
-    #     load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
-    #     end_time = time.time()
-    #     run_time = round(end_time - start_time, 3)
-    #     print(f"用时: {run_time} 秒")
-
-    #     client.close()
-    #     time.sleep(3)
-
-    # print("整体结束")
-
-    client, db = mongo_con_parse()
-    list_flight_number_1 = query_all_flight_number(db, CLEAN_VJ_HOT_NEAR_INFO_TAB)
-    list_flight_number_2 = query_all_flight_number(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB)
-
-    list_flight_number_all = list_flight_number_1 + list_flight_number_2
-    list_flight_number_all = list(set(list_flight_number_all))
-    list_flight_number_all.sort()
+    from utils import chunk_list_with_index
+
+    cpu_cores = os.cpu_count()  # 你的系统是72
+    max_workers = min(16, cpu_cores)  # 最大不超过16个线程
+
+    output_dir = f"./output"
+    os.makedirs(output_dir, exist_ok=True)
+
+    # 加载热门航线数据
+    date_begin = "2025-12-07"
+    date_end = datetime.today().strftime("%Y-%m-%d")
+
+    flight_route_list = vj_flight_route_list_hot[0:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
+    table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+    is_hot = 1   # 1 热门 0 冷门
+    group_size = 1
+    chunks = chunk_list_with_index(flight_route_list, group_size)
+
+    for idx, (_, group_route_list) in enumerate(chunks, 1):
+        # 使用默认配置
+        # client, db = mongo_con_parse()
+        print(f"第 {idx} 组 :", group_route_list)
+        start_time = time.time()
+        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=False,
+                        use_multithread=False, max_workers=max_workers)
+        end_time = time.time()
+        run_time = round(end_time - start_time, 3)
+        print(f"用时: {run_time} 秒")
+
+        # client.close()
+        time.sleep(3)
+
+    print("整体结束")
+
+    # client, db = mongo_con_parse()
+    # list_flight_number_1 = query_all_flight_number(db, CLEAN_VJ_HOT_NEAR_INFO_TAB)
+    # list_flight_number_2 = query_all_flight_number(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB)
+
+    # list_flight_number_all = list_flight_number_1 + list_flight_number_2
+    # list_flight_number_all = list(set(list_flight_number_all))
+    # list_flight_number_all.sort()
     
-    print(list_flight_number_all)
-    print(len(list_flight_number_all))
+    # print(list_flight_number_all)
+    # print(len(list_flight_number_all))
 
-    flight_map = {v: i for i, v in enumerate(list_flight_number_all, start=1)}
-    print(flight_map)
+    # flight_map = {v: i for i, v in enumerate(list_flight_number_all, start=1)}
+    # print(flight_map)
     

+ 166 - 7
data_preprocess.py

@@ -10,7 +10,7 @@ from utils import insert_df_col
 COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
 
 
-def preprocess_data(df_input, features, categorical_features, is_training=True, current_n_hours=48):
+def preprocess_data(df_input, features, categorical_features, is_training=True, current_n_hours=36):
     print(">>> 开始数据预处理") 
 
     # 生成 城市对
@@ -110,10 +110,10 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         .ngroup()
     )
 
-    # 做一下时间段裁剪, 保留起飞前480小时之内的
-    df_input = df_input[df_input['hours_until_departure'] < 480].reset_index(drop=True)
-    pass
-
+    # 做一下时间段裁剪, 保留起飞前480小时之内且大于等于4小时
+    df_input = df_input[(df_input['hours_until_departure'] < 480) & 
+                        (df_input['hours_until_departure'] >= 4)].reset_index(drop=True)
+    
     # 在 gid 与 baggage 内按时间降序
     df_input = df_input.sort_values(
         by=['gid', 'baggage', 'hours_until_departure'],
@@ -161,6 +161,160 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     df_input = df_input[new_order]
     pass
 
+    print(">>> 计算价格区间特征")
+    #  1. 基于绝对价格水平的价格区间划分
+    # 先计算每个(gid, baggage)的价格统计特征
+    # g = df_input.groupby(['gid', 'baggage'])
+    price_stats = df_input.groupby(['gid', 'baggage'])['adult_total_price'].agg(
+        min_price='min',
+        max_price='max',
+        mean_price='mean',
+        std_price='std'
+    ).reset_index()
+
+    # 合并统计特征到原数据
+    df_input = df_input.merge(price_stats, on=['gid', 'baggage'], how='left')
+
+    # 2. 基于绝对价格的价格区间划分 (可以删除,因为后面有更精细的基于频率加权的分类)
+    # # 高价区间:超过均值+1倍标准差
+    # df_input['price_absolute_high'] = (df_input['adult_total_price'] > 
+    #                                   (df_input['mean_price'] + df_input['std_price'])).astype(int)
+
+    # # 中高价区间:均值到均值+1倍标准差
+    # df_input['price_absolute_mid_high'] = ((df_input['adult_total_price'] > df_input['mean_price']) & 
+    #                                        (df_input['adult_total_price'] <= (df_input['mean_price'] + df_input['std_price']))).astype(int)
+
+    # # 中低价区间:均值-1倍标准差到均值
+    # df_input['price_absolute_mid_low'] = ((df_input['adult_total_price'] > (df_input['mean_price'] - df_input['std_price'])) & 
+    #                                       (df_input['adult_total_price'] <= df_input['mean_price'])).astype(int)
+
+    # # 低价区间:低于均值-1倍标准差
+    # df_input['price_absolute_low'] = (df_input['adult_total_price'] <= (df_input['mean_price'] - df_input['std_price'])).astype(int)
+
+    # 3. 基于频率加权的价格百分位数(改进版)
+    # 计算每个价格出现的频率
+    price_freq = df_input.groupby(['gid', 'baggage', 'adult_total_price']).size().reset_index(name='price_frequency')
+    df_input = df_input.merge(price_freq, on=['gid', 'baggage', 'adult_total_price'], how='left')
+
+    # 计算频率加权的百分位数
+    def weighted_percentile(group):
+        # Purpose: for the rows of one group (applied per (gid, baggage) via
+        # groupby(...).apply), return a 4-element Series holding the prices at
+        # which the cumulative 'price_frequency' first reaches 25%, 50%, 75%
+        # and 90% of the group's total 'price_frequency'.
+        #
+        # Input:  `group` - DataFrame slice with at least the columns
+        #         'adult_total_price' and 'price_frequency'.
+        # Output: pd.Series indexed by 'price_weighted_percentile_25/50/75/90'.
+        #
+        # NOTE(review): 'price_frequency' was merged back onto every row, so
+        # the cumulative sum below runs over rows, not over distinct prices.
+        # A price that occurs k times therefore contributes k rows of weight k
+        # each - confirm this double weighting is intended.
+
+        # Guard: an empty group yields NaN for all four percentiles.
+        if len(group) == 0:
+            return pd.Series([np.nan] * 4, index=['price_weighted_percentile_25', 
+                                                'price_weighted_percentile_50', 
+                                                'price_weighted_percentile_75', 
+                                                'price_weighted_percentile_90'])
+        
+        # Sort by price (ascending) and build the cumulative frequency.
+        # sort_values returns a new frame, so adding 'cum_freq' does not touch
+        # the caller's DataFrame.
+        group = group.sort_values('adult_total_price')
+        group['cum_freq'] = group['price_frequency'].cumsum()
+        total_freq = group['price_frequency'].sum()
+        
+        # Compute the weighted percentiles.
+        percentiles = []
+        for p in [0.25, 0.5, 0.75, 0.9]:
+            threshold = total_freq * p
+            # Find the first row whose cumulative frequency reaches or exceeds
+            # the threshold (the comparison is >=, not strictly greater).
+            mask = group['cum_freq'] >= threshold
+            if mask.any():
+                # idxmax on a boolean mask returns the label of the first True.
+                # NOTE(review): .loc with that label yields a scalar only if
+                # the index labels are unique within the group - confirm.
+                percentile_value = group.loc[mask.idxmax(), 'adult_total_price']
+            else:
+                # Fallback when no row reaches the threshold: use the group's
+                # maximum price.
+                percentile_value = group['adult_total_price'].max()
+            percentiles.append(percentile_value)
+        
+        return pd.Series(percentiles, index=['price_weighted_percentile_25', 
+                                             'price_weighted_percentile_50', 
+                                             'price_weighted_percentile_75', 
+                                             'price_weighted_percentile_90'])
+        
+    # 按gid和baggage分组计算加权百分位数
+    weighted_percentiles = df_input.groupby(['gid', 'baggage']).apply(weighted_percentile).reset_index()
+    df_input = df_input.merge(weighted_percentiles, on=['gid', 'baggage'], how='left')
+
+    # 4. 结合绝对价格和频率的综合判断(改进版)
+    freq_median = df_input.groupby(['gid', 'baggage'])['price_frequency'].transform('median')
+
+    # 计算价格相对于90%百分位数的倍数,用于区分不同级别的高价
+    df_input['price_relative_to_90p'] = df_input['adult_total_price'] / df_input['price_weighted_percentile_90']
+
+    # 添加价格容忍度:避免相近价格被分到不同区间
+    # 计算价格差异容忍度(使用各百分位数的1%作为容忍度阈值)
+    # tolerance_90p = df_input['price_weighted_percentile_90'] * 0.01
+    tolerance_75p = df_input['price_weighted_percentile_75'] * 0.01
+    tolerance_50p = df_input['price_weighted_percentile_50'] * 0.01
+    tolerance_25p = df_input['price_weighted_percentile_25'] * 0.01
+
+    # 重新设计价格区间分类(确保无重叠):
+    # 首先定义各个区间的mask
+
+    # 4.1 异常高价:价格远高于90%百分位数(超过1.5倍)且频率极低(低于中位数的1/3)
+    price_abnormal_high_mask = ((df_input['price_relative_to_90p'] > 1.5) & 
+                                (df_input['price_frequency'] < freq_median * 0.33))
+
+    # 4.2 真正高位:严格满足条件(价格 > 90%分位数 且 频率 < 中位数)
+    price_real_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_90']) & 
+                            (df_input['price_frequency'] < freq_median) &
+                             ~price_abnormal_high_mask)
+
+    # 4.3 正常高位:使用容忍度(价格接近75%分位数)
+    price_normal_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_75'] - tolerance_75p) & 
+                               ~price_real_high_mask & ~price_abnormal_high_mask)
+
+    # 4.4 中高价:使用容忍度(价格在50%-75%分位数之间)
+    price_mid_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_50'] - tolerance_50p) & 
+                           (df_input['adult_total_price'] <= df_input['price_weighted_percentile_75'] + tolerance_75p) &
+                            ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
+
+    # 4.5 中低价:使用容忍度(价格在25%-50%分位数之间)
+    price_mid_low_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_25'] - tolerance_25p) & 
+                          (df_input['adult_total_price'] <= df_input['price_weighted_percentile_50'] + tolerance_50p) &
+                           ~price_mid_high_mask & ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
+    
+    # 4.6 低价:严格满足条件(价格 ≤ 25%分位数) 
+    price_low_mask = ((df_input['adult_total_price'] <= df_input['price_weighted_percentile_25']) &
+                 ~price_mid_low_mask & ~price_mid_high_mask & ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
+
+    # 使用np.select确保互斥性
+    price_zone_masks = [
+        price_abnormal_high_mask,  # 异常高价区(5级)
+        price_real_high_mask,      # 真正高价区(4级)
+        price_normal_high_mask,    # 正常高价区(3级)
+        price_mid_high_mask,       # 中高价区(2级)
+        price_mid_low_mask,        # 中低价区(1级)
+        price_low_mask,            # 低价区(0级)
+    ]  
+    price_zone_values = [5, 4, 3, 2, 1, 0]  # 5:异常高价, 4:真正高价, 3:正常高价, 2:中高价, 1:中低价, 0:低价 
+
+    # 使用np.select确保每个价格只被分到一个区间
+    price_zone_result = np.select(price_zone_masks, price_zone_values, default=2)  # 默认中高价
+    # 4.8 价格区间综合标记
+    df_input['price_zone_comprehensive'] = price_zone_result
+
+    # 5. 价格异常度检测
+    # 价格相对于均值的标准化偏差
+    df_input['price_z_score'] = (df_input['adult_total_price'] - df_input['mean_price']) / df_input['std_price']
+    
+    # 价格异常度:基于Z-score的绝对值
+    df_input['price_anomaly_score'] = np.abs(df_input['price_z_score'])
+    
+    # 6. 价格稳定性特征
+    # 计算价格波动系数(标准差/均值)
+    df_input['price_coefficient_variation'] = df_input['std_price'] / df_input['mean_price']
+
+    # 7. 价格趋势特征
+    # 计算当前价格相对于历史价格的位置
+    df_input['price_relative_position'] = (df_input['adult_total_price'] - df_input['min_price']) / (df_input['max_price'] - df_input['min_price'])
+    df_input['price_relative_position'] = df_input['price_relative_position'].fillna(0.5)  # 兜底
+
+    # 删除中间计算列
+    df_input.drop(columns=['price_frequency', 'price_z_score', 'price_relative_to_90p'], inplace=True, errors='ignore')
+
+    del price_freq
+    del price_stats
+    del weighted_percentiles
+    del freq_median
+
+    print(">>> 改进版价格区间特征计算完成")
+
     # 生成第一机场对
     df_input['airport_pair_1'] = (
         df_input['seg1_dep_air_port'].astype(str) + "-" + df_input['seg1_arr_air_port'].astype(str)
@@ -492,6 +646,8 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         "flight_number_2", "flight_2_num", "airport_pair_2", "dep_time_2", "arr_time_2", "fly_duration_2", "fly_duration", "stop_duration", 
         "global_dep_time", "dep_country", "dep_country_is_holiday", "is_cross_country",
         "global_arr_time", "arr_country", "arr_country_is_holiday", "any_country_is_holiday",
+        "price_weighted_percentile_25", "price_weighted_percentile_50", "price_weighted_percentile_75", "price_weighted_percentile_90",
+        "price_zone_comprehensive", "price_relative_position",
     ]
     df_input = df_input[order_columns]
     
@@ -502,7 +658,9 @@ def standardization(df, feature_scaler, target_scaler=None, is_training=True, is
     print(">>> 开始标准化处理")
 
     # 准备走标准化的特征
-    scaler_features = ['adult_total_price', 'fly_duration', 'stop_duration']
+    scaler_features = ['adult_total_price', 'fly_duration', 'stop_duration', 
+                       'price_weighted_percentile_25', 'price_weighted_percentile_50', 
+                       'price_weighted_percentile_75', 'price_weighted_percentile_90']
     
     if is_training:
         print(">>> 特征数据标准化开始")
@@ -527,7 +685,8 @@ def standardization(df, feature_scaler, target_scaler=None, is_training=True, is
         'flight_2_num': (0, 341),
         'seats_remaining': (1, 5),
         'price_change_times_total': (0, 30),     # 假设价格变更次数不会超过30次
-        'price_last_change_hours': (0, 480),     
+        'price_last_change_hours': (0, 480), 
+        'price_zone_comprehensive': (0, 5),    
         'days_to_departure': (0, 30),
         'days_to_holiday': (0, 120),             # 最长的越南节假日间隔120天
         'flight_by_hour': (0, 23),

+ 4 - 4
main_pe.py

@@ -6,7 +6,7 @@ import numpy as np
 import pickle
 import time
 from datetime import datetime, timedelta
-from config import vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
 from data_loader import mongo_con_parse, load_train_data
 from data_preprocess import preprocess_data, standardization
 from utils import chunk_list_with_index, create_fixed_length_sequences
@@ -85,7 +85,7 @@ def start_predict():
     # 测试阶段
     for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
         # 特殊处理,跳过不好的批次
-        client, db = mongo_con_parse()
+        # client, db = mongo_con_parse()
         print(f"第 {i} 组 :", group_route_list)
         # batch_flight_routes = group_route_list
 
@@ -102,12 +102,12 @@ def start_predict():
         
         # 加载测试数据 (仅仅是时间段取到后天)
         start_time = time.time()
-        df_test = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+        df_test = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
         print(f"用时: {run_time} 秒")
 
-        client.close()
+        # client.close()
 
         if df_test.empty:
             print(f"测试数据为空,跳过此批次。")

+ 7 - 5
main_tr.py

@@ -38,9 +38,11 @@ common_features = ['hours_until_departure', 'days_to_departure', 'seats_remainin
                    'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                    'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
+price_info_features = ['price_weighted_percentile_25', 'price_weighted_percentile_50', 'price_weighted_percentile_75', 'price_weighted_percentile_90',
+                       'price_zone_comprehensive', 'price_relative_position']
 price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
 encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
-features = encoded_columns + price_features + common_features
+features = encoded_columns + price_info_features + price_features + common_features
 target_vars = ['target_will_price_drop']   # 是否降价
 
 
@@ -111,7 +113,7 @@ def start_train():
 
     date_end = datetime.today().strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
-    date_begin = "2025-11-20"
+    date_begin = "2025-12-01"
 
     # 仅在 rank == 0 时要做的
     if rank == 0:
@@ -239,7 +241,7 @@ def start_train():
             redis_client.set(lock_key, 0)
             print("rank0 开始数据加载...")
             # 使用默认配置
-            client, db = mongo_con_parse()
+            # client, db = mongo_con_parse()
             print(f"第 {i} 组 :", group_route_list)
             batch_flight_routes = group_route_list
 
@@ -257,12 +259,12 @@ def start_train():
             
             # 加载训练数据
             start_time = time.time()
-            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+            df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
             end_time = time.time()
             run_time = round(end_time - start_time, 3)
             print(f"用时: {run_time} 秒")
 
-            client.close()
+            # client.close()
 
             if df_train.empty:
                 print(f"训练数据为空,跳过此批次。")

+ 2 - 0
result_validate.py

@@ -36,6 +36,7 @@ def validate_process(node, date):
         baggage = row['baggage']
         valid_begin_hour = row['valid_begin_hour'] 
         df_val= validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour)
+        # 有可能在当前验证时刻,数据库里没有在valid_begin_hour之后的数据
         if not df_val.empty:
             df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
             df_val_f = df_val_f[df_val_f['is_filled']==0]    # 只要原始数据,不要补齐的
@@ -76,6 +77,7 @@ def validate_process(node, date):
                     
                 list_change_price = df_price_changes['adult_total_price'].tolist()
                 list_change_hours = df_price_changes['hours_until_departure'].tolist()
+        
         else:
             drop_flag = 0
             first_drop_amount = pd.NA

+ 3 - 3
utils.py

@@ -28,7 +28,7 @@ def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
     return df
 
 # 真正创建序列过程
-def create_fixed_length_sequences(df, features, target_vars, threshold=48, input_length=432, is_train=True):
+def create_fixed_length_sequences(df, features, target_vars, threshold=36, input_length=444, is_train=True):
     print(">>开始创建序列")
     start_time = time.time()
 
@@ -49,7 +49,7 @@ def create_fixed_length_sequences(df, features, target_vars, threshold=48, input
         df_group_bag_30 = df_group[df_group['baggage']==30]
         df_group_bag_20 = df_group[df_group['baggage']==20]
 
-        # 过滤训练时间段 (48 ~ 480)
+        # 过滤训练时间段 (36 ~ 480)
         df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
         df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
 
@@ -62,7 +62,7 @@ def create_fixed_length_sequences(df, features, target_vars, threshold=48, input
             seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
             seq_features_2 = df_group_bag_20_filtered[features].to_numpy()
             
-            # 将几个特征序列沿着第 0 维拼接,得到形状为 (2, 432, 25)
+            # 将几个特征序列沿着第 0 维拼接,得到形状为 (2, 444, 31)
             combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),    
                                              torch.tensor(seq_features_2, dtype=torch.float32)])