
Commit the entire training pipeline

node04 2 days ago
parent
commit
c88dd378ea
7 changed files with 1131 additions and 14 deletions
  1. evaluate.py (+174, -0)
  2. font.py (+4, -0)
  3. main_tr.py (+327, -10)
  4. model.py (+87, -0)
  5. simhei.ttf (BIN)
  6. train.py (+443, -0)
  7. utils.py (+96, -4)

+ 174 - 0
evaluate.py

@@ -0,0 +1,174 @@
+import os
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, mean_absolute_error
+from matplotlib import font_manager
+import matplotlib.pyplot as plt
+import seaborn as sns
+from utils import FlightDataset
+
+
+# 分布式模型评估
+def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16, test_loader=None, 
+                              batch_flight_routes=None, target_scaler=None, 
+                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', batch_idx=-1,
+                              csv_file='evaluate_results.csv', evalute_flag='evaluate', save_mode='a'):
+    
+    if test_loader is None:
+        if not sequences:
+            print("没有足够的数据进行评估。")
+            return
+        test_dataset = FlightDataset(sequences, targets, group_ids)
+        test_loader = DataLoader(test_dataset, batch_size=batch_size)   # 评估阶段按原始顺序遍历,不做分布式采样和打乱
+    
+    batch_fn_str = ' '.join([route.replace('|', ' ') for route in batch_flight_routes]) if batch_flight_routes else ''
+
+    model.eval()
+
+    # 初始化存储容器(各进程先在本地收集,分布式模式下再聚合为张量通信)
+    y_preds_list = []
+    y_trues_list = []
+    group_info_list = []
+
+    with torch.no_grad():
+        for X_batch, y_batch, group_ids_batch in test_loader:
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+
+            # 分布式模式下需确保不同进程处理不同数据分片
+            outputs = model(X_batch)    
+
+            # 收集当前批次的结果
+            y_preds_list.append(outputs.cpu().numpy())  # 移动到CPU以节省GPU内存
+            y_trues_list.append(y_batch.cpu().numpy())
+
+            # 处理 group_info(需转换为可序列化格式)
+            for i in range(len(group_ids_batch[0])):
+                group_id = tuple(g[i].item() if isinstance(g, torch.Tensor) else g[i] for g in group_ids_batch)
+                group_info_list.append(group_id)
+    
+    # 合并当前进程的结果
+    y_preds = np.concatenate(y_preds_list, axis=0)
+    y_trues = np.concatenate(y_trues_list, axis=0)
+    group_info = group_info_list
+    
+    # --- 分布式结果聚合 ---
+    if flag_distributed:
+
+        # 收集所有进程的预测结果
+        y_preds_tensor = torch.tensor(y_preds, device=device)
+        y_trues_tensor = torch.tensor(y_trues, device=device)
+
+        # 收集所有进程的 y_preds 和 y_trues
+        gather_y_preds = [torch.zeros_like(y_preds_tensor) for _ in range(world_size)]
+        gather_y_trues = [torch.zeros_like(y_trues_tensor) for _ in range(world_size)]
+        dist.all_gather(gather_y_preds, y_preds_tensor)
+        dist.all_gather(gather_y_trues, y_trues_tensor)
+
+        # 合并结果到 rank 0
+        if rank == 0:
+            y_preds = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_preds], axis=0)
+            y_trues = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_trues], axis=0)
+
+        # 将 group_info 转换为字符串列表以便传输
+        group_info_str = ['|'.join(map(str, info)) for info in group_info]
+        gather_group_info = [None for _ in range(world_size)]
+        dist.all_gather_object(gather_group_info, group_info_str)
+
+        if rank == 0:
+            group_info = []
+            for info_list in gather_group_info:
+                for info_str in info_list:
+                    group_info.append(tuple(info_str.split('|')))   
+
+    # --- 仅在 rank 0 计算指标并保存结果 ---
+    if rank == 0:
+        
+        # 分类任务结果
+        y_preds_class = y_preds[:, 0]
+        y_trues_class = y_trues[:, 0]
+        y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
+        y_trues_class_labels = y_trues_class.astype(int)
+
+        # 打印指标
+        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str, batch_idx=batch_idx, evalute_flag=evalute_flag)
+
+        # 构造 DataFrame
+        results_df = pd.DataFrame({
+            'city_pair': [info[0] for info in group_info],
+            'flight_day': [info[1] for info in group_info],
+            'flight_number_1': [info[2] for info in group_info],
+            'flight_number_2': [info[3] for info in group_info],
+            'from_date': [info[4] for info in group_info],
+            'baggage': [info[5] for info in group_info],
+            'price': [info[6] for info in group_info],
+            'Hours_until_Departure': [info[7] for info in group_info],
+            'update_hour': [info[8] for info in group_info],
+            'probability': y_preds_class,
+            'Actual_Will_Price_Drop': y_trues_class_labels,
+            'Predicted_Will_Price_Drop': y_preds_class_labels,
+        })
+
+        # 数值处理
+        threshold = 1e-3
+        numeric_columns = ['probability',
+                           # 'Actual_Amount_Of_Drop', 'Predicted_Amount_Of_Drop', 'Actual_Time_To_Drop', 'Predicted_Time_To_Drop'
+                           ]
+        for col in numeric_columns:
+            results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
+        
+        # 保存结果
+        results_df_path = os.path.join(output_dir, csv_file)
+        if save_mode == 'a':
+            # 追加模式
+            results_df.to_csv(results_df_path, mode='a', index=False, header=not os.path.exists(results_df_path))
+        else:
+            # 重写模式
+            results_df.to_csv(results_df_path, mode='w', index=False, header=True) 
+        print(f"预测结果已保存到 '{results_df_path}'")
+        
+        return results_df
+    
+    else:
+        return None
+
+
+def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate'):
+    
+    accuracy = accuracy_score(y_trues_class_labels, y_preds_class_labels)
+    precision = precision_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+    recall = recall_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+    f1 = f1_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+
+    print(f"分类准确率: {accuracy:.4f}")
+    print(f"分类精确率: {precision:.4f}")
+    print(f"分类召回率: {recall:.4f}")
+    print(f"分类F1值: {f1:.4f}")
+
+    # 获取二者都为1的正例索引
+    indices = np.where((y_trues_class_labels == 1) & (y_preds_class_labels == 1))[0]
+    if len(indices) == 0:
+        print("没有预测正确的正例")
+
+    font_path = "./simhei.ttf"
+    font_prop = font_manager.FontProperties(fname=font_path)
+    # font_prop = font.font_prop
+    
+    # 混淆矩阵
+    cm = confusion_matrix(y_trues_class_labels, y_preds_class_labels)
+    plt.figure(figsize=(6, 5))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                xticklabels=['预测:不会降价', '预测:会降价'],
+                yticklabels=['实际:不会降价', '实际:会降价'])
+    plt.xticks(fontproperties=font_prop)
+    plt.yticks(fontproperties=font_prop)
+    plt.xlabel('预测情况', fontproperties=font_prop)
+    plt.ylabel('实际结果', fontproperties=font_prop)
+    plt.title('分类结果的混淆矩阵', fontproperties=font_prop)
+    plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")

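The metrics printed by `printScore_cc` are the standard scikit-learn classification scores, and the heatmap uses the usual [[TN, FP], [FN, TP]] confusion-matrix layout. A minimal sketch on made-up labels (the arrays below are illustrative only, not pipeline output) reproduces the same calls:

import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# 虚构的真实标签与预测标签,仅作演示
y_true = np.array([1, 0, 0, 1, 1, 0, 0, 0])
y_pred = np.array([1, 0, 1, 0, 1, 0, 0, 0])

print("accuracy :", accuracy_score(y_true, y_pred))                     # 6/8 = 0.75
print("precision:", precision_score(y_true, y_pred, zero_division=0))   # TP/(TP+FP) = 2/3
print("recall   :", recall_score(y_true, y_pred, zero_division=0))      # TP/(TP+FN) = 2/3
print("f1       :", f1_score(y_true, y_pred, zero_division=0))          # 2PR/(P+R) = 2/3
print(confusion_matrix(y_true, y_pred))                                 # [[TN FP], [FN TP]] = [[4 1], [1 2]]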
+ 4 - 0
font.py

@@ -0,0 +1,4 @@
+from matplotlib import font_manager
+# 设置字体
+font_path = "./simhei.ttf"
+font_prop = font_manager.FontProperties(fname=font_path)

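font.py only exposes a shared `font_prop` built from the bundled simhei.ttf so that the Chinese axis labels and titles in evaluate.py and train.py render correctly. A quick standalone check (the Agg backend and output filename are assumptions for a headless run):

import matplotlib
matplotlib.use("Agg")            # 假设在无显示环境下运行,使用离屏后端
import matplotlib.pyplot as plt
import font                      # 复用 font.py 暴露的 font_prop

plt.figure()
plt.title("中文标题渲染测试", fontproperties=font.font_prop)
plt.savefig("font_check.png")    # 输出文件名仅作示例
plt.close()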
+ 327 - 10
main_tr.py

@@ -6,15 +6,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import joblib
 import gc
 import pandas as pd
-import numpy as np
+# import numpy as np
 import redis
 import time
 import pickle
 import shutil
 from datetime import datetime, timedelta
 from utils import chunk_list_with_index, create_fixed_length_sequences
+from model import PriceDropClassifiTransModel
 from data_loader import mongo_con_parse, load_train_data
 from data_preprocess import preprocess_data, standardization
+from train import prepare_data_distribute, train_model_distribute
+from evaluate import printScore_cc
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
     CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

@@ -68,7 +71,19 @@ def init_distributed_backend():

 # 初始化模型和相关参数
 def initialize_model(device):
-    return None
+    input_size = len(features)
+    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
+    model.to(device)
+
+    if FLAG_Distributed:
+        model = DDP(model, device_ids=[device], find_unused_parameters=True)   # 使用DDP包装模型
+
+    if FLAG_Distributed:
+        print(f"Rank:{dist.get_rank()}, 模型已初始化,输入尺寸:{input_size}")
+    else:
+        print(f"模型已初始化,输入尺寸:{input_size}")
+
+    return model

 def continue_before_process(redis_client, lock_key):
     # rank0 跳出循环前的处理
@@ -95,7 +110,7 @@ def start_train():
     photo_dir = "./photo"

     date_end = datetime.today().strftime("%Y-%m-%d")
-    date_begin = (datetime.today() - timedelta(days=18)).strftime("%Y-%m-%d")
+    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")

     # 仅在 rank == 0 时要做的
     if rank == 0:
@@ -121,9 +136,9 @@ def start_train():

         print(f"最终特征列表:{features}")

-    # 定义优化器和损失函数(只回归)
-    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
-    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
+    # 定义优化器和损失函数
+    criterion = None   #  后面在训练之前定义
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)

     group_size = 1              # 每几组作为一个批次
     num_epochs_per_batch = 200  # 每个批次训练的轮数,可以根据需要调整
@@ -136,7 +151,9 @@ def start_train():
     lock_key = "data_loading_lock_11"
     barrier_key = 'distributed_barrier_11'

+    assemble_size = 1   # 几个batch作为一个集群assemble
     batch_idx = -1
+    batch_flight_routes = None   # 占位, 避免其它rank找不到定义

     # 主干代码
     # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
@@ -146,9 +163,9 @@ def start_train():

     # 调试代码
     s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉
-    flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
+    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
     flight_route_list_len = len(flight_route_list)
-    route_len_hot = len(vj_flight_route_list_hot[:0])
+    route_len_hot = len(vj_flight_route_list_hot[0:])
     route_len_nothot = len(vj_flight_route_list_nothot[s:])

     if local_rank == 0:
@@ -222,6 +239,7 @@ def start_train():
             # 使用默认配置
             client, db = mongo_con_parse()
             print(f"第 {i} 组 :", group_route_list)
+            batch_flight_routes = group_route_list

             # 根据索引位置决定是 热门 还是 冷门
             if 0 <= i < route_len_hot:
@@ -277,11 +295,310 @@ def start_train():

             # 生成序列
             sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
-            pass
+            
+            # 新增有效性检查
+            if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
+                valid_batch[0] = 0
+                print("警告:当前批次数据为空,标记为无效批次")
+            
+            # 数据加载及预处理完成,设置 Redis 锁 key 的值为 1
+            redis_client.set(lock_key, 1)
+            print("rank0 数据加载完成,已将 Redis 锁 key 值设置为 1")

         else:
-            pass
+            val = None
+            # 其它 rank 等待:只有当 lock key 存在且其值为 "1" 时才算数据加载完成
+            print(f"rank{rank} 正在等待 rank0 完成数据加载...")
+            while True:
+                val = redis_client.get(lock_key)
+                if val is not None and val.decode('utf-8') in ["1", "2"]:
+                    break
+                time.sleep(1)
+            if val is not None and val.decode('utf-8') == "2":
+                print(f"rank{rank} 跳过空批次 {i}")
+                time.sleep(3)
+                continue
+            print(f"rank{rank} 检测到数据加载已完成,继续后续处理...")
+
+        # 同步点:所有 Rank 在此等待
+        if FLAG_Distributed:
+            # 确保所有 CUDA 操作完成并释放缓存
+            print(f"rank{rank} ready synchronize ...")
+            torch.cuda.synchronize()
+            
+            print(f"rank{rank} ready empty_cache ...")
+            torch.cuda.empty_cache()
+            
+            print(f"rank{rank} ready barrier ...")
+            dist.barrier()  # 移除 device_ids 参数
+            # dist.barrier(device_ids=[local_rank])
+
+        print(f"rank{rank} done barrier ...")
+
+        # 广播批次有效性标志
+        if FLAG_Distributed:
+            dist.broadcast(valid_batch, src=0)
+
+        # 所有 Rank 检查批次有效性
+        if valid_batch.item() == 0:
+            print(f"Rank {rank} 跳过无效批次 {i}")
+            continue  # 所有 Rank 跳过当前循环
+
+        # 所有 Rank 同时进入数据分发
+        if rank == 0:
+            # 分片并分发
+            my_sequences, my_targets, my_group_ids = distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed=FLAG_Distributed)
+        else:
+            # 其它 Rank 接收数据
+            my_sequences, my_targets, my_group_ids = distribute_sharded_data([], [], [], world_size, rank, device, flag_distributed=FLAG_Distributed)
+
+        # 查看一下各rank是否分到数据
+        debug_print_shard_info([], my_targets, my_group_ids, rank, local_rank, world_size)
+
+        pre_flag, train_single, val_single, test_single = prepare_data_distribute(my_sequences, my_targets, my_group_ids, 
+            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size)
+        
+        del my_sequences
+        del my_targets
+        del my_group_ids
+        gc.collect()
+
+        if not pre_flag:
+            print(f"Rank {rank} 跳过无效数据批次 {i}")
+            continue
+
+        train_sequences = train_single['sequences']
+        train_targets = train_single['targets']
+        train_group_ids = train_single['group_ids']
+
+        val_sequences = val_single['sequences']
+        val_targets = val_single['targets']
+        val_group_ids = val_single['group_ids']
+
+        # test_sequences = test_single['sequences']
+        # test_targets = test_single['targets']
+        # test_group_ids = test_single['group_ids']
+
+        if FLAG_Distributed:
+            dist.barrier()
+
+        # 训练模型
+        model = train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
+            model, criterion, optimizer, device, num_epochs=num_epochs_per_batch, batch_size=16, target_scaler=target_scaler, 
+            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size, 
+            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx, 
+            batch_flight_routes=batch_flight_routes, patience=40, delta=0.001)
+
+        del train_single
+        del val_single
+        del test_single
+        gc.collect()
+
+        # 重置模型参数
+        if (i + 1) % assemble_size == 0:
+            
+            if FLAG_Distributed:
+                dist.barrier()
+
+            del model, optimizer 
+            torch.cuda.empty_cache()  # 清理GPU缓存
+
+            model = initialize_model(device)  # 重置模型
+            
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)  # 重置优化器
+            print(f"Rank {rank}, Reset Model at batch {i} due to performance drop")
+
+    ###############################################################################################################
+
+    # 在整体批次训练结束后
+    if rank == 0:
+        # pass
+        # torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))
+        print("模型训练完成,各批次的最优模型已由早停机制保存。")
+
+        csv_file = 'evaluate_results.csv'
+        csv_path = os.path.join(output_dir, csv_file)
+        # 汇总评估结果
+        try:
+            df = pd.read_csv(csv_path)
+        except Exception as e:
+            print(f"read {csv_path} error: {str(e)}")
+            df = None
+
+        if df is not None:
+            # 提取真实值和预测值
+            y_trues_class_labels = df['Actual_Will_Price_Drop']
+            y_preds_class_labels = df['Predicted_Will_Price_Drop']
+
+            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')
+
+    if FLAG_Distributed:
+        dist.destroy_process_group()   # 显式调用 destroy_process_group 来清理 NCCL 的进程组资源
+
+
+def distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed):
+    # --- 非分布式模式:直接返回全量数据
+    if not flag_distributed:
+        return sequences, targets, group_ids
+
+    # ================== 第一阶段:元数据广播 ==================
+    if rank == 0:
+        # 将 group_ids 序列化为字节流
+        group_bytes = pickle.dumps(group_ids)
+        # 转换为张量用于分块传输
+        group_tensor = torch.frombuffer(bytearray(group_bytes), dtype=torch.uint8).to(device)
+        # 处理其他数据
+        seq_tensor = torch.stack(sequences, dim=0).to(device)    # shape [N, 2, 452, 25]
+        tgt_tensor = torch.stack(targets, dim=0).to(device)      # shape [N, 1]
+
+        meta_data = {
+            # sequences/targets 元数据
+            'seq_shape': seq_tensor.shape,
+            'tgt_shape': tgt_tensor.shape,
+            'seq_dtype': str(seq_tensor.dtype).replace('torch.', ''),   # 关键修改点
+            'tgt_dtype': str(tgt_tensor.dtype).replace('torch.', ''),
+            
+            # group_ids 元数据
+            'group_shape': group_tensor.shape,
+            'group_bytes_len': len(group_bytes),
+            'pickle_protocol': pickle.HIGHEST_PROTOCOL
+        }
+    else:
+        meta_data = None
+
+    # 广播元数据(所有rank都需要)
+    meta_data = broadcast(meta_data, src=0, rank=rank, device=device)
+    
+    # ================== 第二阶段:分块传输 ==================
+    # 初始化接收缓冲区(所有Rank)
+    if rank == 0:
+        pass    # rank 0 已持有待广播的张量,直接复用,无需分配接收缓冲区
+    else:
+        seq_dtype = getattr(torch, meta_data['seq_dtype'])   # 例如 meta_data['seq_dtype'] = "float32"
+        tgt_dtype = getattr(torch, meta_data['tgt_dtype'])
+        
+        group_tensor = torch.zeros(meta_data['group_shape'], dtype=torch.uint8, device=device)
+        seq_tensor = torch.zeros(meta_data['seq_shape'], dtype=seq_dtype, device=device)
+        tgt_tensor = torch.zeros(meta_data['tgt_shape'], dtype=tgt_dtype, device=device)
+    
+    # 并行传输所有数据(按传输量排序:先大后小)
+    _chunked_broadcast(seq_tensor, src=0, rank=rank)     # 最大数据优先
+    _chunked_broadcast(tgt_tensor, src=0, rank=rank)
+    _chunked_broadcast(group_tensor, src=0, rank=rank)    # 最后传输group_ids
+
+    # ================== 第三阶段:数据重建 ==================
+    # 重建 sequences 和 targets
+    sequences_list = [seq.cpu().clone() for seq in seq_tensor]   # 自动按第0维切分
+    targets_list = [tgt.cpu().clone() for tgt in tgt_tensor]
+
+    # 重建 group_ids(关键步骤)
+    if rank == 0:
+        # Rank0直接使用原始数据避免重复序列化
+        group_ids_rebuilt = group_ids
+    else:
+        # 1. 提取有效字节(去除填充)
+        group_bytes = bytes(group_tensor.cpu().numpy().tobytes()[:meta_data['group_bytes_len']])
+        # 2. 反序列化
+        try:
+            group_ids_rebuilt = pickle.loads(group_bytes)
+        except pickle.UnpicklingError as e:
+            raise RuntimeError(f"反序列化 group_ids 失败: {str(e)}")
+        
+        # 3. 结构校验
+        _validate_group_structure(group_ids_rebuilt)
+
+    return sequences_list, targets_list, group_ids_rebuilt
+
+def broadcast(data, src, rank, device):
+    """安全地广播任意数据,确保张量在正确的设备上"""
+    if rank == src:
+        # 序列化数据
+        data_bytes = pickle.dumps(data)
+        data_size = torch.tensor([len(data_bytes)], dtype=torch.long, device=device)
+        # 创建数据张量并移动到设备
+        data_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8).to(device)
+        # 先广播数据大小
+        dist.broadcast(data_size, src=src)
+        # 然后广播数据
+        dist.broadcast(data_tensor, src=src)
+        return data
+    else:
+        # 接收数据大小
+        data_size = torch.tensor([0], dtype=torch.long, device=device)
+        dist.broadcast(data_size, src=src)
+        # 分配数据张量
+        data_tensor = torch.empty(data_size.item(), dtype=torch.uint8, device=device)
+        dist.broadcast(data_tensor, src=src)
+        # 反序列化
+        data = pickle.loads(data_tensor.cpu().numpy().tobytes())
+        return data
+    
+def _chunked_broadcast(tensor, src, rank, chunk_size=1024*1024*128):    # chunk_size 单位是字节
+    """分块广播张量优化通信效率"""
+    # Step 1. 准备连续内存缓冲
+    buffer = tensor.detach().contiguous()
+    # Step 2. 计算字节数
+    element_size = buffer.element_size()  # 每个元素的字节数(如 float32 是 4)
+    total_elements = buffer.numel()
+
+    # 计算每个块最多包含多少元素(根据字节数换算)
+    elements_per_chunk = chunk_size // element_size
+    # 分块数量
+    num_chunks = (total_elements + elements_per_chunk - 1) // elements_per_chunk
+    
+    # Step 4. 逐块广播
+    for chunk_idx in range(num_chunks):
+        # 计算当前块的字节范围
+        start_element = chunk_idx * elements_per_chunk
+        end_element = min((chunk_idx+1)*elements_per_chunk, total_elements)
+        # Step 5. 从大张量中切出当前块
+        chunk = buffer.view(-1).narrow(0, start_element, end_element - start_element)
+        # Step 6. 执行广播
+        dist.broadcast(chunk, src=src)
+        # 说明: 虽然单个chunk是一维的, 但通过其内部的 1.严格的传输顺序 2.接收端的内存预分配 3.最终reshape操作 原始张量的形状得以完美恢复
+
+def _validate_group_structure(group_ids):
+    """校验 group_ids 数据结构完整性"""
+    assert isinstance(group_ids, list), "Group IDs 必须是列表"
+    if len(group_ids) == 0:
+        print("还原的 group_ids 长度为0")
+        return
+    
+    sample = group_ids[0]
+    assert isinstance(sample, tuple), "元素必须是元组"
+    assert len(sample) == 9, "元组长度必须为9"
+
+def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
+    """分布式环境下按Rank顺序打印分片前5条样本"""
+    # 同步所有进程
+    if FLAG_Distributed:
+        dist.barrier(device_ids=[local_rank])
+
+    # 按Rank顺序逐个打印(避免输出混杂)
+    for r in range(world_size):
+        if r == rank:
+            print(f"\n=== Rank {rank}/{world_size} Data Shard Samples (showing first 5) ===")
+            
+            # 打印序列数据
+            # print("[Sequences]")
+            # for i, seq in enumerate(sequences[:5]):
+            #     print(f"Sample {i}: {seq[:3]}...")  # 只显示前3元素示意
+
+            # 打印目标数据
+            print("\n[Targets]")
+            print(targets[:5])
+
+            # 打印Group ID分布
+            print("\n[Group IDs]")
+            # unique_gids = list(set(group_ids[:50]))  # 检查前50条的group分布
+            print(f"First 5 GIDs: {group_ids[:5]}")
+            
+            # sys.stdout.flush()  # 确保立即输出

+        if FLAG_Distributed:
+            dist.barrier(device_ids=[local_rank])  # 等待当前Rank打印完成


 if __name__ == "__main__":

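The `broadcast()` helper above pickles an arbitrary Python object, broadcasts its byte length first and then the raw payload, so receiving ranks can pre-allocate a uint8 buffer of the right size; `_chunked_broadcast()` applies the same idea to large tensors in fixed-size slices. A self-contained sketch of the size-then-payload pattern, run as a single-process gloo group so it executes standalone (the address, port and sample payload are placeholders; the real pipeline runs multi-process NCCL):

import os
import pickle
import torch
import torch.distributed as dist

# 单进程 gloo 进程组,仅用于演示(实际训练为多进程 NCCL)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")   # 占位地址
os.environ.setdefault("MASTER_PORT", "29500")       # 占位端口
dist.init_process_group("gloo", rank=0, world_size=1)

obj = [("SGN|MNL", "2025-12-08"), ("HAN|SIN", "2025-12-09")]   # 虚构的 group_ids 样例

# 发送端:序列化 -> 先广播字节长度 -> 再广播字节内容
payload = pickle.dumps(obj)
size = torch.tensor([len(payload)], dtype=torch.long)
dist.broadcast(size, src=0)
buf = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
dist.broadcast(buf, src=0)

# 接收端会先用 torch.empty(size.item(), dtype=torch.uint8) 分配缓冲区,
# 收到广播后再 pickle.loads 还原对象
restored = pickle.loads(buf.numpy().tobytes())
print(restored == obj)   # True

dist.destroy_process_group()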
+ 87 - 0
model.py

@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+
+# 分类模型 (Transformer) 
+class PriceDropClassifiTransModel(nn.Module):
+    def __init__(self, input_size, num_periods=2, hidden_size=128, num_layers=3, output_size=1, dropout=0.3, conv_out_channels=64, kernel_size=3, num_heads=8):
+        super(PriceDropClassifiTransModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_periods = num_periods
+
+        # 卷积层
+        self.conv1 = nn.Conv1d(
+            in_channels=input_size * num_periods,
+            out_channels=conv_out_channels,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2,
+            bias=False,
+        )
+        self.relu = nn.ReLU()
+
+        # Transformer Encoder
+        self.transformer_layer = nn.TransformerEncoderLayer(
+            d_model=conv_out_channels,
+            # d_model=input_size * num_periods,   # 这里的d_model应为输入的特征数量, d_model能被num_heads整除
+            nhead=num_heads,
+            dim_feedforward=hidden_size,
+            dropout=dropout
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            self.transformer_layer,
+            num_layers=num_layers
+        )
+
+        # 注意力机制
+        self.attention_layer = nn.Sequential(
+            nn.Linear(conv_out_channels, hidden_size),
+            # nn.Linear(input_size * num_periods, hidden_size),
+            # nn.Conv1d(conv_out_channels, hidden_size),
+            # nn.Tanh(),
+            nn.ReLU(),
+            nn.Linear(hidden_size, 1)
+        )
+
+        # 分类输出层
+        self.fc_classification = nn.Linear(conv_out_channels, 1)
+
+    def forward(self, x):
+        """
+        输入x的形状应为 [batch_size, num_periods, seq_length, input_size]
+        """
+        batch_size, num_periods, seq_length, input_size = x.size()
+        # x = x[:,0,:,:].view(batch_size, 1, input_size, seq_length)
+
+        # 将输入转换为 [batch_size, num_periods * input_size, seq_length]
+        x = x.permute(0, 1, 3, 2).contiguous()  # [batch_size, num_periods, input_size, seq_length]
+        x = x.view(batch_size, num_periods * input_size, seq_length)  # [batch_size, num_periods * input_size, seq_length]
+        # x = x.view(batch_size, 1 * input_size, seq_length)
+
+        # 经过卷积层和激活函数
+        x = self.conv1(x)    # [batch_size, conv_out_channels, seq_length]
+        x = self.relu(x)
+
+        # 转置以适应Transformer输入要求
+        x = x.permute(2, 0, 1)  # [seq_length, batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 经过Transformer编码器
+        x = self.transformer_encoder(x)  # [seq_length, batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 计算注意力
+        attention_scores = self.attention_layer(x)  # [seq_length, batch_size, 1]
+        attention_weights = torch.softmax(attention_scores, dim=0)  # [seq_length, batch_size, 1]
+        # 对所有时间步进行加权求和
+        context_vector = torch.sum(attention_weights * x, dim=0)  # [batch_size, conv_out_channels(num_periods * input_size)]
+        
+        # 取最后一个时间步的输出进行分类和回归
+        # context_vector = x[-1, :, :]  # [batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 分类输出(sigmoid 概率)
+        classification_output = torch.sigmoid(self.fc_classification(context_vector))  # [batch_size, 1]
+        # 打印检查:输出范围
+        # print(f"Before clamp: min: {classification_output.min().item()}, max: {classification_output.max().item()}")
+        # 将输出值限制在 [0.0001, 0.9999] 范围内,以避免数值极端
+        # classification_output = torch.clamp(classification_output, min=1e-4, max=1 - 1e-4)
+        
+        return classification_output

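A quick forward-pass shape check for `PriceDropClassifiTransModel`, using the hyperparameters that `initialize_model` in main_tr.py passes in (hidden_size=64, num_layers=3, dropout=0.2, output_size=1). The 25-feature width and the short sequence length below are assumptions for illustration; the real pipeline uses len(features) features and input_length=452:

import torch
from model import PriceDropClassifiTransModel

# 假设 25 个特征、2 个时间段;序列长度取 32 仅为快速演示(实际为 452)
input_size, num_periods, seq_len, batch = 25, 2, 32, 4
model = PriceDropClassifiTransModel(input_size, num_periods=num_periods,
                                    hidden_size=64, num_layers=3,
                                    output_size=1, dropout=0.2)
model.eval()

x = torch.randn(batch, num_periods, seq_len, input_size)   # [batch_size, num_periods, seq_length, input_size]
with torch.no_grad():
    y = model(x)

print(y.shape)                            # torch.Size([4, 1])
print(y.min().item(), y.max().item())     # 经过 sigmoid,输出位于 (0, 1)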
BIN
simhei.ttf


+ 443 - 0
train.py

@@ -0,0 +1,443 @@
+import gc
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+from sklearn.model_selection import train_test_split
+from imblearn.over_sampling import SMOTE, RandomOverSampler
+from collections import Counter
+from evaluate import evaluate_model_distribute
+from utils import FlightDataset, EarlyStoppingDist  # EarlyStopping, train_process, train_process_distribute, CombinedLoss
+import numpy as np
+import matplotlib.pyplot as plt
+import font
+import config
+import redis
+import time
+
+
+# 智能分层划分函数
+def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None, rank=0, local_rank=0):
+    if stratify is not None:
+        counts = Counter(stratify)
+        min_count = min(counts.values()) if counts else 0
+        if min_count < 2:
+            if local_rank == 0:
+                print(f"Rank:{rank}, Local Rank:{local_rank}, 安全分层:检测到最小类别样本数={min_count},禁用分层")
+            stratify = None
+    
+    return train_test_split(
+        *arrays,
+        test_size=test_size,
+        random_state=random_state,
+        stratify=stratify
+    )
+
+
+# 分布式数据集准备
+def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
+    if len(sequences) == 0 or len(targets) == 0:
+        print(f"Rank:{rank}, 没有足够的数据参与训练。")
+        return False, None, None, None
+    
+    targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0] for t in targets])
+    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
+    if len(unique_classes) == 1:
+        print(f"Rank:{rank}, 警告:目标变量只有一个类别,无法参与训练。")
+        return False, None, None, None
+    
+    # --- 高效过滤样本数 ≤ 1 的类别(浮点兼容版)---
+    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
+    class_to_count = dict(zip(unique_classes, class_counts))
+    valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
+
+    if not np.any(valid_mask):
+        print(f"Rank:{rank}, 警告:所有类别的样本数均 ≤ 1,无法分层拆分。")
+        return False, None, None, None
+
+    # 一次性筛选数据(兼容列表/Tensor/Array)
+    sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
+    targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
+    group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
+    targets_array_filtered = targets_array[valid_mask]
+
+    # 第一步:将样本按 8:2 拆分为训练集(80%)和临时集(20%)
+    train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
+        sequences_filtered, targets_filtered, group_ids_filtered,
+        stratify=targets_array_filtered,
+        test_size=0.2,
+        random_state=42,
+        rank=rank,
+        local_rank=local_rank
+    )
+
+    # 验证集与测试集全部引用临时集
+    val_28 = temp_28
+    test_28 = temp_28
+    val_28_targets = temp_28_targets
+    test_28_targets = temp_28_targets
+    val_28_gids = temp_28_gids
+    test_28_gids = temp_28_gids
+
+    # 合并训练集
+    train_sequences = train_28
+    train_targets = train_28_targets
+    train_group_ids = train_28_gids
+    
+    # 合并验证集
+    val_sequences = val_28 
+    val_targets = val_28_targets 
+    val_group_ids = val_28_gids
+
+    # 测试集
+    test_sequences = test_28
+    test_targets = test_28_targets
+    test_group_ids = test_28_gids
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次训练集数量:{len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次验证集数量:{len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次测试集数量:{len(test_sequences)}")
+
+    train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
+    train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
+
+    if local_rank == 0:
+        # 打印检查
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].shape:{train_targets_tensors[0].shape}")  # 应该是 torch.Size([1])
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}")  # 应该是 torch.float32
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}")  # 应该是 torch.float32
+
+    train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids}
+    val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids}
+    test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids}    
+
+    def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1):
+        # 每个 rank 到达 barrier 时,将计数加 1
+        redis_client.incr(barrier_key)
+        
+        start_time = time.time()
+        while True:
+            count = redis_client.get(barrier_key)
+            count = int(count) if count else 0
+            if count >= world_size:
+                break
+            if time.time() - start_time > timeout:
+                raise TimeoutError("等待 barrier 超时")
+            time.sleep(poll_interval)
+
+    # 等待其他进程生成数据,并同步
+    if flag_distributed:
+        redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+        barrier_key = 'distributed_barrier_11'
+
+        # 等待所有进程都到达 barrier
+        _redis_barrier(redis_client, barrier_key, world_size)
+
+    return True, train_single, val_single, test_single
+
+# 分布式训练
+def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
+                           model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
+                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
+                           photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
+                           ):
+    
+    # 统计正负样本数量
+    all_targets = torch.cat(train_targets)   # 将所有目标值拼接成一个张量
+    positive_count = torch.sum(all_targets == 1).item()
+    negative_count = torch.sum(all_targets == 0).item()
+    total_samples = len(all_targets)
+
+    # 计算比例
+    positive_ratio = positive_count / total_samples
+    negative_ratio = negative_count / total_samples
+
+    if local_rank == 0:
+        # 打印检查
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集数量:{len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集目标数量:{len(train_targets)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集数量:{len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集目标数量:{len(val_targets)}")
+
+        # 打印正负样本统计
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集总样本数: {total_samples}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集正样本数: {positive_count} ({positive_ratio*100:.2f}%)")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集负样本数: {negative_count} ({negative_ratio*100:.2f}%)")
+
+    # 计算并打印推荐的 pos_weight
+    if positive_count > 0:
+        recommended_pos_weight = negative_count / positive_count
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, 推荐的 pos_weight: {recommended_pos_weight:.2f}")
+    else:
+        recommended_pos_weight = 1.0
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, 警告: 没有正样本!")
+
+    train_dataset = FlightDataset(train_sequences, train_targets)
+    val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids)
+    # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids)
+
+    del train_sequences
+    del train_targets
+    del train_group_ids
+    del val_sequences
+    del val_targets
+    del val_group_ids
+    gc.collect()
+
+    if flag_distributed:
+        sampler_train = DistributedSampler(train_dataset, shuffle=True)  # 分布式采样器
+        sampler_val = DistributedSampler(val_dataset, shuffle=False)
+        
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size,  sampler=sampler_val)
+    else:
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size)
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}")  # 特征尺寸
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}")  # 目标尺寸
+
+    pos_weight_value = recommended_pos_weight  # 从上面的计算中获取
+    # 创建带权重的损失函数
+    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
+
+    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
+                                       rank=rank, local_rank=local_rank)
+    
+    # 分布式训练模型
+    train_losses, val_losses = train_process_distribute(
+        model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=num_epochs, criterion=criterion,
+        flag_distributed=flag_distributed, rank=rank, local_rank=local_rank, loss_call_label="val")
+    
+    if rank == 0:
+        font_prop = font.font_prop
+
+        # 绘制损失曲线(可选)
+        plt.figure(figsize=(10, 6))
+        epochs = range(1, len(train_losses) + 1)
+        plt.plot(epochs, train_losses, 'b-', label='训练集损失')
+        plt.plot(epochs, val_losses, 'r-', label='验证集损失')
+        plt.title('训练和验证集损失曲线', fontproperties=font_prop)
+        plt.xlabel('Epochs', fontproperties=font_prop)
+        plt.ylabel('Loss', fontproperties=font_prop)
+        plt.legend(prop=font_prop)
+        plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
+
+    # 训练结束后加载最佳模型参数
+    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
+
+    # 确保所有进程都看到相同的文件系统状态
+    if flag_distributed:
+        dist.barrier()
+    
+    # 创建用于广播的列表(只有一个元素)
+    checkpoint_list = [None]
+
+    if rank == 0:
+        if os.path.exists(best_model_path):
+            print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
+            # 直接加载到 CPU,避免设备不一致问题
+            checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
+        else:
+            print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
+            # 使用当前模型状态(确保在 CPU 上)
+            state_dict = model.module.state_dict() if flag_distributed else model.state_dict()
+            checkpoint_list[0] = {k: v.cpu() for k, v in state_dict.items()}   # 仅把参数拷贝到 CPU,避免把整个模型移出 GPU
+
+    # 广播模型状态字典
+    if flag_distributed:
+        dist.broadcast_object_list(checkpoint_list, src=0)
+
+    # 所有进程获取广播后的状态字典
+    checkpoint = checkpoint_list[0]
+
+    # 加载模型状态
+    if flag_distributed:
+        model.module.load_state_dict(checkpoint)
+    else:
+        model.load_state_dict(checkpoint)
+    
+    # 确保所有进程完成加载
+    if flag_distributed:
+        dist.barrier()
+
+    if flag_distributed:
+        # 调用评估函数
+        evaluate_model_distribute(
+            model.module,       # 使用 DDP 包裹前的原始模型
+            device,
+            None, None, None,
+            test_loader=val_loader,  # 使用累积验证集 
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=flag_distributed,
+            rank=rank, local_rank=local_rank, world_size=world_size, 
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+    else:
+        evaluate_model_distribute(
+            model,
+            device,
+            None, None, None,
+            test_loader=val_loader,  # 使用累积验证集
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=False,
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+
+    return model
+
+
+def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=200, criterion=None, save_file='best_model.pth',
+                             flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
+    # 具体训练过程
+    train_losses = []
+    val_losses = []
+    
+    # 初始化损失为张量(兼容非分布式和分布式)
+    # total_train_loss = torch.tensor(0.0, device=device)
+    # total_val_loss = torch.tensor(0.0, device=device)
+
+    # 初始化 TensorBoard(只在主进程)
+    # if rank == 0:
+    #     writer = SummaryWriter(log_dir='runs/experiment_name')
+    #     train_global_step = 0
+    #     val_global_step = 0
+    
+    for epoch in range(num_epochs):
+        # --- 训练阶段 ---
+        model.train()
+        if flag_distributed:
+            train_loader.sampler.set_epoch(epoch)  # 确保每个进程一致地打乱顺序
+
+        # total_train_loss.zero_()   # 重置损失累计
+        total_train_loss = torch.tensor(0.0, device=device)
+        num_train_samples = torch.tensor(0, device=device)     # 当前进程的样本数
+
+        for batch_idx, batch in enumerate(train_loader):
+            X_batch, y_batch = batch[:2]  # 假设 group_ids 不需要参与训练
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+            
+            optimizer.zero_grad()
+            outputs = model(X_batch)
+            loss = criterion(outputs, y_batch)
+            loss.backward()
+
+            # 打印
+            # if rank == 0:
+            #     # print_gradient_range(model)
+            #     # 记录损失值
+            #     writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
+            #     # 记录元数据
+            #     writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
+            #     writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
+
+            #     log_gradient_stats(model, writer, train_global_step, "train")
+
+            #     # 更新全局步数
+            #     train_global_step += 1
+
+            # 梯度裁剪(已兼容 DDP)
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            optimizer.step()
+            
+            # 累计损失
+            total_train_loss += loss.detach() * X_batch.size(0)   # detach() 保留张量形式以支持跨进程通信
+            num_train_samples += X_batch.size(0)
+        
+        # --- 同步训练损失 ---
+        if flag_distributed:
+            # 会将所有进程的 total_train_loss 求和后, 同步到每个进程
+            dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
+
+        # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
+        avg_train_loss = total_train_loss.item() / num_train_samples.item()
+        train_losses.append(avg_train_loss)
+
+        # --- 验证阶段 ---
+        model.eval()
+        # total_val_loss.zero_()  # 重置验证损失
+        total_val_loss = torch.tensor(0.0, device=device)
+        num_val_samples = torch.tensor(0, device=device)
+
+        with torch.no_grad():
+            for batch_idx, batch in enumerate(val_loader):
+                X_val, y_val = batch[:2]
+                X_val = X_val.to(device)
+                y_val = y_val.to(device)
+                
+                outputs = model(X_val)
+                val_loss = criterion(outputs, y_val)
+                total_val_loss += val_loss.detach() * X_val.size(0)
+                num_val_samples += X_val.size(0)
+
+                # if rank == 0:
+                #     # 记录验证集batch loss
+                #     writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
+                #     # 记录验证集元数据
+                #     writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
+                #     writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
+
+                #     # 更新验证集全局步数
+                #     val_global_step += 1
+
+                # if local_rank == 0:
+                #     print(f"rank:{rank}, outputs:{outputs}")
+                #     print(f"rank:{rank}, y_val:{y_val}")
+                #     print(f"rank:{rank}, val_loss:{val_loss.detach()}")
+                #     print(f"rank:{rank}, size:{X_val.size(0)}")
+
+        # --- 同步验证损失 ---
+        if flag_distributed:
+            dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
+
+        # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
+        avg_val_loss = total_val_loss.item() / num_val_samples.item()
+        val_losses.append(avg_val_loss)
+
+        # if rank == 0:
+        #     # 记录epoch平均损失
+        #     writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
+        #     writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
+
+        if local_rank == 0:
+            print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, 训练集损失: {avg_train_loss:.4f}, 验证集损失: {avg_val_loss:.4f}")
+
+        # --- 早停与保存逻辑(仅在 rank 0 执行)---
+        if rank == 0:
+            
+            # 模型保存兼容分布式和非分布式
+            model_to_save = model.module if flag_distributed else model   # 当使用 model = DDP(model) 封装后,原始模型会被包裹在 model.module 属性
+            if loss_call_label == "train":
+                early_stopping(avg_train_loss, model_to_save)  # 平均训练集损失 
+            else:
+                early_stopping(avg_val_loss, model_to_save)   # 平均验证集损失
+            
+            if early_stopping.early_stop:
+                print(f"Rank:{rank}, 早停触发,停止训练 at epoch {epoch}")
+                # 非分布式模式下直接退出循环
+                if not flag_distributed:
+                    break
+
+        # --- 同步早停状态(仅分布式需要)---
+        if flag_distributed:
+            # 将早停标志转换为张量广播
+            early_stop_flag = torch.tensor([early_stopping.early_stop], device=device)
+            dist.broadcast(early_stop_flag, src=0)
+            if early_stop_flag.item():   # item()取张量的布尔值
+                break
+        # else:
+        #     # 非分布式模式下,直接检查早停标志
+        #     if early_stopping.early_stop:
+        #         break
+
+    return train_losses, val_losses

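The recommended pos_weight above is simply negative_count / positive_count, and `nn.BCEWithLogitsLoss` scales the positive-class term by that factor; note that this loss applies its own sigmoid internally, so it expects raw logits. A tiny numeric sketch with made-up class counts:

import torch
import torch.nn as nn

# 虚构的样本分布:90 个负样本、10 个正样本,仅作演示
targets = torch.cat([torch.zeros(90), torch.ones(10)]).unsqueeze(1)
positive_count = int((targets == 1).sum())
negative_count = int((targets == 0).sum())
pos_weight = negative_count / positive_count          # 90 / 10 = 9.0
print(f"推荐的 pos_weight: {pos_weight:.2f}")

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))

logits = torch.zeros_like(targets)                    # 全 0 logits,相当于每个样本预测概率 0.5
print(criterion(logits, targets).item())              # 正样本的损失项被放大约 9 倍后再取平均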
+ 96 - 4
utils.py

@@ -1,5 +1,7 @@
+import gc
+import time
 import torch
-
+from torch.utils.data import Dataset

 # 航线列表分组切片并带上索引
 def chunk_list_with_index(lst, group_size):
@@ -27,6 +29,9 @@ def insert_df_col(df, insert_col_name, base_col_name, inplace=True):

 # 真正创建序列过程
 def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
+    print(">>开始创建序列")
+    start_time = time.time()
+
     sequences = []
     targets = []
     group_ids = []
@@ -77,7 +82,94 @@ def create_fixed_length_sequences(df, features, target_vars, input_length=452, i
                            str(last_row['Hours_Until_Departure']), 
                            str(last_row['update_hour'])])
             group_ids.append(tuple(name_c))
-            pass
+        
+        del df_group_bag_30_filtered, df_group_bag_20_filtered
+        del df_group_bag_30, df_group_bag_20
+        del df_group
+
+    gc.collect()
+    print(">>结束创建序列")
+    end_time = time.time()
+    run_time = round(end_time - start_time, 3)
+    print(f"用时: {run_time} 秒")
+    print(f"生成的序列数量:{len(sequences)}")
+    
+    return sequences, targets, group_ids
+
+
+class FlightDataset(Dataset):
+    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
+        self.X_sequences = X_sequences
+        self.y_sequences = y_sequences
+        self.group_ids = group_ids
+        self.return_group_ids = group_ids is not None
+
+    def __len__(self):
+        return len(self.X_sequences)
+
+    def __getitem__(self, idx):
+        if self.return_group_ids:
+            if self.y_sequences:
+                return self.X_sequences[idx], self.y_sequences[idx], self.group_ids[idx]
+            else:
+                return self.X_sequences[idx], self.group_ids[idx]
+        else:
+            if self.y_sequences:
+                return self.X_sequences[idx], self.y_sequences[idx]
+            else:
+                return self.X_sequences[idx]
+
+
+class EarlyStoppingDist:
+    """早停机制(分布式)"""
+    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pth', rank=0, local_rank=0):
+        """
+        Args:
+            patience (int): 在训练集(或验证集)损失不再改善时,等待多少个epoch后停止训练
+            verbose (bool): 是否打印相关信息
+            delta (float): 训练集损失需要减少的最小变化量
+            path (str): 保存最佳模型的路径
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.delta = delta
+        self.path = path
+        self.counter = 0
+        self.best_loss = None
+        self.early_stop = False
+        self.rank = rank
+        self.local_rank = local_rank
+
+    def __call__(self, loss, model):
+        # NaN 损失直接触发早停,不保存异常的模型参数
+        if self.is_nan(loss):
+            self.counter += self.patience  # 立即触发早停
+            self.early_stop = True
+            return
+        if self.best_loss is None:
+            self.best_loss = loss
+            self.save_checkpoint(loss, model)
+        elif loss > self.best_loss - self.delta:
+            self.counter += 1
+            if self.verbose and self.rank == 0:
+                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.save_checkpoint(loss, model)
+            self.best_loss = loss
+            self.counter = 0
+
+    def is_nan(self, loss):
+        """检查损失值是否为NaN(通用方法)"""
+        try:
+            # 所有NaN类型都不等于自身
+            return loss != loss
+        except Exception:
+            # 处理不支持比较的类型
+            return False
+
+    def save_checkpoint(self, loss, model):
+        """保存模型"""
+        if self.verbose and self.rank == 0:
+            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}).  Saving model ...')
+        torch.save(model.state_dict(), self.path)   # 无论 verbose 与否都要保存最优模型
-        pass
-    pass
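A minimal usage sketch for `EarlyStoppingDist` with a throwaway linear model and an invented validation-loss sequence, showing how the patience counter triggers the stop and where the best checkpoint is written (the file name below is a placeholder):

import torch
import torch.nn as nn
from utils import EarlyStoppingDist

# 玩具模型与虚构的验证损失序列,仅用于演示早停行为
model = nn.Linear(4, 1)
stopper = EarlyStoppingDist(patience=2, verbose=True, delta=0.0,
                            path='demo_best_model.pth', rank=0, local_rank=0)

for epoch, val_loss in enumerate([0.90, 0.85, 0.86, 0.87, 0.88]):
    stopper(val_loss, model)                    # 损失改善则保存 checkpoint,否则计数 +1
    if stopper.early_stop:
        print(f"early stop at epoch {epoch}")   # 连续 patience 轮未改善后触发
        break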