Răsfoiți Sursa

再次小幅修改训练, 调整测试集占比

node04 1 lună în urmă
părinte
comite
943818de39
3 a modificat fișierele cu 40 adăugiri și 15 ștergeri
  1. 3 3
      data_loader.py
  2. 12 11
      main_tr.py
  3. 25 1
      train.py

+ 3 - 3
data_loader.py

@@ -668,12 +668,12 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
 
                 if list_12:
                     df_c12 = pd.concat(list_12, ignore_index=True)
-                    print(f"✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
+                    # print(f"✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
                     # plot_c12_trend(df_c12, output_dir)
                     # print(f"✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
                 else:
                     df_c12 = pd.DataFrame()
-                    print(f"⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
+                    # print(f"⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
 
                 del list_12
                 list_mid.append(df_c12)
@@ -682,7 +682,7 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
                 del df_d1
                 del df_d2
 
-                print(f"结束处理起飞日期: {dep_date}")
+                # print(f"结束处理起飞日期: {dep_date}")
 
             if list_mid:
                 df_mid = pd.concat(list_mid, ignore_index=True)

+ 12 - 11
main_tr.py

@@ -106,11 +106,12 @@ def start_train():
         local_rank = 0
         world_size = 1
 
-    output_dir = "./data_shards" 
+    output_dir = "./data_shards"
     photo_dir = "./photo"
 
     date_end = datetime.today().strftime("%Y-%m-%d")
-    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
+    # date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
+    date_begin = "2025-11-20"
 
     # 仅在 rank == 0 时要做的
     if rank == 0:
@@ -156,17 +157,17 @@ def start_train():
     batch_flight_routes = None   # 占位, 避免其它rank找不到定义
 
     # 主干代码
-    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
-    # flight_route_list_len = len(flight_route_list)
-    # route_len_hot = len(vj_flight_route_list_hot)
-    # route_len_nothot = len(vj_flight_route_list_nothot)
+    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
+    flight_route_list_len = len(flight_route_list)
+    route_len_hot = len(vj_flight_route_list_hot)
+    route_len_nothot = len(vj_flight_route_list_nothot)
 
     # 调试代码
-    s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
-    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
-    flight_route_list_len = len(flight_route_list)
-    route_len_hot = len(vj_flight_route_list_hot[0:])
-    route_len_nothot = len(vj_flight_route_list_nothot[s:])
+    # s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
+    # flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
+    # flight_route_list_len = len(flight_route_list)
+    # route_len_hot = len(vj_flight_route_list_hot[:0])
+    # route_len_nothot = len(vj_flight_route_list_nothot[s:])
     
     if local_rank == 0:
         print(f"flight_route_list_len:{flight_route_list_len}")

+ 25 - 1
train.py

@@ -34,6 +34,18 @@ def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=No
         stratify=stratify
     )
 
+def adaptive_test_size(n_samples):
+    """
+    根据总样本数,自适应 test_size
+    """
+    if n_samples < 50:
+        return 0.4      # 极小样本,强行保验证
+    elif n_samples < 100:
+        return 0.3
+    # elif n_samples < 300:
+    #     return 0.25
+    else:
+        return 0.2      # 常规情况
 
 # 分布式数据集准备
 def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
@@ -62,11 +74,23 @@ def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=Fals
     group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
     targets_array_filtered = targets_array[valid_mask]
 
+    n_samples = len(sequences_filtered)
+    test_size = adaptive_test_size(n_samples)
+
+    # 只有在样本数本身不太小的情况下,才做最小验证集兜底
+    min_val_samples = 10
+    if n_samples >= 2 * min_val_samples:
+        expected_val = int(n_samples * test_size)
+        if expected_val < min_val_samples:
+            test_size = min_val_samples / n_samples
+    
+    print(f"test_size: {test_size}")
+
     # 第一步:将28样本拆分为训练集(80%)和临时集(20%)
     train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
         sequences_filtered, targets_filtered, group_ids_filtered,
         stratify=targets_array_filtered,
-        test_size=0.2,
+        test_size=test_size,
         random_state=42,
         rank=rank,
         local_rank=local_rank