node04 2 days ago
parent
commit
c88dd378ea
7 changed files with 1131 additions and 14 deletions
  1. evaluate.py   +174  -0
  2. font.py       +4    -0
  3. main_tr.py    +327  -10
  4. model.py      +87   -0
  5. simhei.ttf    BIN
  6. train.py      +443  -0
  7. utils.py      +96   -4

+ 174 - 0
evaluate.py

@@ -0,0 +1,174 @@
+import os
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, mean_absolute_error
+from matplotlib import font_manager
+import matplotlib.pyplot as plt
+import seaborn as sns
+from utils import FlightDataset
+
+
+# 分布式模型评估
+def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16, test_loader=None, 
+                              batch_flight_routes=None, target_scaler=None, 
+                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', batch_idx=-1,
+                              csv_file='evaluate_results.csv', evalute_flag='evaluate', save_mode='a'):
+    
+    if test_loader is None:
+        if not sequences:
+            print("没有足够的数据进行评估。")
+            return
+        test_dataset = FlightDataset(sequences, targets, group_ids)
+        test_loader = DataLoader(test_dataset, batch_size=batch_size)   # ??
+    
+    batch_fn_str = ' '.join([route.replace('|', ' ') for route in batch_flight_routes]) if batch_flight_routes else ''
+
+    model.eval()
+
+    # 初始化存储容器(张量形式以便跨进程通信)
+    y_preds_list = []
+    y_trues_list = []
+    group_info_list = []
+
+    with torch.no_grad():
+        for X_batch, y_batch, group_ids_batch in test_loader:
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+
+            # 分布式模式下需确保不同进程处理不同数据分片
+            outputs = model(X_batch)    
+
+            # 收集当前批次的结果(保留在GPU上)
+            y_preds_list.append(outputs.cpu().numpy())  # 移动到CPU以节省GPU内存
+            y_trues_list.append(y_batch.cpu().numpy())
+
+            # 处理 group_info(需转换为可序列化格式)
+            for i in range(len(group_ids_batch[0])):
+                group_id = tuple(g[i].item() if isinstance(g, torch.Tensor) else g[i] for g in group_ids_batch)
+                group_info_list.append(group_id)
+                pass
+    
+    # 合并当前进程的结果
+    y_preds = np.concatenate(y_preds_list, axis=0)
+    y_trues = np.concatenate(y_trues_list, axis=0)
+    group_info = group_info_list
+    
+    # --- 分布式结果聚合 ---
+    if flag_distributed:
+
+        # 收集所有进程的预测结果
+        y_preds_tensor = torch.tensor(y_preds, device=device)
+        y_trues_tensor = torch.tensor(y_trues, device=device)
+
+        # 收集所有进程的 y_preds 和 y_trues
+        gather_y_preds = [torch.zeros_like(y_preds_tensor) for _ in range(world_size)]
+        gather_y_trues = [torch.zeros_like(y_trues_tensor) for _ in range(world_size)]
+        dist.all_gather(gather_y_preds, y_preds_tensor)
+        dist.all_gather(gather_y_trues, y_trues_tensor)
+
+        # 合并结果到 rank 0
+        if rank == 0:
+            y_preds = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_preds], axis=0)
+            y_trues = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_trues], axis=0)
+
+        # 将 group_info 转换为字符串列表以便传输
+        group_info_str = ['|'.join(map(str, info)) for info in group_info]
+        gather_group_info = [None for _ in range(world_size)]
+        dist.all_gather_object(gather_group_info, group_info_str)
+
+        if rank == 0:
+            group_info = []
+            for info_list in gather_group_info:
+                for info_str in info_list:
+                    group_info.append(tuple(info_str.split('|')))   
+
+    # --- 仅在 rank 0 计算指标并保存结果 ---
+    if rank == 0:
+        
+        # 分类任务结果
+        y_preds_class = y_preds[:, 0]
+        y_trues_class = y_trues[:, 0]
+        y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
+        y_trues_class_labels = y_trues_class.astype(int)
+
+        # 打印指标
+        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str, batch_idx=batch_idx, evalute_flag=evalute_flag)
+
+        # 构造 DataFrame
+        results_df = pd.DataFrame({
+            'city_pair': [info[0] for info in group_info],
+            'flight_day': [info[1] for info in group_info],
+            'flight_number_1': [info[2] for info in group_info],
+            'flight_number_2': [info[3] for info in group_info],
+            'from_date': [info[4] for info in group_info],
+            'baggage': [info[5] for info in group_info],
+            'price': [info[6] for info in group_info],
+            'Hours_until_Departure': [info[7] for info in group_info],
+            'update_hour': [info[8] for info in group_info],
+            'probability': y_preds_class,
+            'Actual_Will_Price_Drop': y_trues_class_labels,
+            'Predicted_Will_Price_Drop': y_preds_class_labels,
+        })
+
+        # 数值处理
+        threshold = 1e-3
+        numeric_columns = ['probability',
+                           # 'Actual_Amount_Of_Drop', 'Predicted_Amount_Of_Drop', 'Actual_Time_To_Drop', 'Predicted_Time_To_Drop'
+                           ]
+        for col in numeric_columns:
+            results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
+        
+        # 保存结果
+        results_df_path = os.path.join(output_dir, csv_file)
+        if save_mode == 'a':
+            # 追加模式
+            results_df.to_csv(results_df_path, mode='a', index=False, header=not os.path.exists(results_df_path))
+        else:
+            # 重写模式
+            results_df.to_csv(results_df_path, mode='w', index=False, header=True) 
+        print(f"预测结果已保存到 '{results_df_path}'")
+        
+        return results_df
+    
+    else:
+        return None
+
+
+def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate'):
+    
+    accuracy = accuracy_score(y_trues_class_labels, y_preds_class_labels)
+    precision = precision_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+    recall = recall_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+    f1 = f1_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
+
+    print(f"分类准确率: {accuracy:.4f}")
+    print(f"分类精确率: {precision:.4f}")
+    print(f"分类召回率: {recall:.4f}")
+    print(f"分类F1值: {f1:.4f}")
+
+    # Indices where both the true label and the prediction are positive
+    indices = np.where((y_trues_class_labels == 1) & (y_preds_class_labels == 1))[0]
+    if len(indices) == 0:
+        print("没有正例")
+
+    font_path = "./simhei.ttf"
+    font_prop = font_manager.FontProperties(fname=font_path)
+    # font_prop = font.font_prop
+    
+    # 混淆矩阵
+    cm = confusion_matrix(y_trues_class_labels, y_preds_class_labels)
+    plt.figure(figsize=(6, 5))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
+                xticklabels=['预测:不会降价', '预测:会降价'],
+                yticklabels=['实际:不会降价', '实际:会降价'])
+    plt.xticks(fontproperties=font_prop)
+    plt.yticks(fontproperties=font_prop)
+    plt.xlabel('预测情况', fontproperties=font_prop)
+    plt.ylabel('实际结果', fontproperties=font_prop)
+    plt.title('分类结果的混淆矩阵', fontproperties=font_prop)
+    plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")

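For reference, a minimal standalone sketch of the threshold-and-score step that printScore_cc performs above, using synthetic labels and probabilities (illustrative only, not part of the commit):

import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)                          # synthetic 0/1 labels
probs = np.clip(y_true * 0.6 + rng.random(200) * 0.5, 0, 1)    # synthetic probabilities
y_pred = (probs >= 0.5).astype(int)                            # same 0.5 threshold as evaluate.py

print("accuracy :", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred, zero_division=0))
print("recall   :", recall_score(y_true, y_pred, zero_division=0))
print("f1       :", f1_score(y_true, y_pred, zero_division=0))
print("confusion matrix:\n", confusion_matrix(y_true, y_pred))
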
+ 4 - 0
font.py

@@ -0,0 +1,4 @@
+from matplotlib import font_manager
+# 设置字体
+font_path = "./simhei.ttf"
+font_prop = font_manager.FontProperties(fname=font_path)

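A short usage sketch for this font helper, assuming simhei.ttf is present in the working directory (illustrative, not part of the commit):

import matplotlib
matplotlib.use("Agg")              # headless backend for servers
import matplotlib.pyplot as plt
from matplotlib import font_manager

font_prop = font_manager.FontProperties(fname="./simhei.ttf")
plt.figure()
plt.plot([0, 1, 2], [1, 0, 1])
plt.title("损失曲线", fontproperties=font_prop)   # Chinese labels render via SimHei
plt.savefig("demo_font.png")
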
+ 327 - 10
main_tr.py

@@ -6,15 +6,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import joblib
 import gc
 import pandas as pd
-import numpy as np
+# import numpy as np
 import redis
 import time
 import pickle
 import shutil
 from datetime import datetime, timedelta
 from utils import chunk_list_with_index, create_fixed_length_sequences
+from model import PriceDropClassifiTransModel
 from data_loader import mongo_con_parse, load_train_data
 from data_preprocess import preprocess_data, standardization
+from train import prepare_data_distribute, train_model_distribute
+from evaluate import printScore_cc
 from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
     CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
 
@@ -68,7 +71,19 @@ def init_distributed_backend():
 
 # 初始化模型和相关参数
 def initialize_model(device):
-    return None
+    input_size = len(features)
+    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
+    model.to(device)
+
+    if FLAG_Distributed:
+        model = DDP(model, device_ids=[device], find_unused_parameters=True)   # 使用DDP包装模型
+
+    if FLAG_Distributed:
+        print(f"Rank:{dist.get_rank()}, 模型已初始化,输入尺寸:{input_size}")
+    else:
+        print(f"模型已初始化,输入尺寸:{input_size}")
+
+    return model
 
 def continue_before_process(redis_client, lock_key):
     # rank0 跳出循环前的处理
@@ -95,7 +110,7 @@ def start_train():
     photo_dir = "./photo"
 
     date_end = datetime.today().strftime("%Y-%m-%d")
-    date_begin = (datetime.today() - timedelta(days=18)).strftime("%Y-%m-%d")
+    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
 
     # 仅在 rank == 0 时要做的
     if rank == 0:
@@ -121,9 +136,9 @@ def start_train():
 
         print(f"最终特征列表:{features}")
 
-    # 定义优化器和损失函数(只回归)
-    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
-    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
+    # 定义优化器和损失函数
+    criterion = None   #  后面在训练之前定义
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
 
     group_size = 1              # 每几组作为一个批次
     num_epochs_per_batch = 200  # 每个批次训练的轮数,可以根据需要调整
@@ -136,7 +151,9 @@ def start_train():
     lock_key = "data_loading_lock_11"
     barrier_key = 'distributed_barrier_11'
 
+    assemble_size = 1   # 几个batch作为一个集群assemble
     batch_idx = -1
+    batch_flight_routes = None   # 占位, 避免其它rank找不到定义
 
     # 主干代码
     # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
@@ -146,9 +163,9 @@ def start_train():
 
     # 调试代码
     s = 38   # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 
-    flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
+    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
     flight_route_list_len = len(flight_route_list)
-    route_len_hot = len(vj_flight_route_list_hot[:0])
+    route_len_hot = len(vj_flight_route_list_hot[0:])
     route_len_nothot = len(vj_flight_route_list_nothot[s:])
     
     if local_rank == 0:
@@ -222,6 +239,7 @@ def start_train():
             # 使用默认配置
             client, db = mongo_con_parse()
             print(f"第 {i} 组 :", group_route_list)
+            batch_flight_routes = group_route_list
 
             # 根据索引位置决定是 热门 还是 冷门
             if 0 <= i < route_len_hot:
@@ -277,11 +295,310 @@ def start_train():
             
             # 生成序列
             sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
-            pass
+            
+            # 新增有效性检查
+            if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
+                valid_batch[0] = 0
+                print("警告:当前批次数据为空,标记为无效批次")
+            
+            # 数据加载及预处理完成,设置 Redis 锁 key 的值为 1
+            redis_client.set(lock_key, 1)
+            print("rank0 数据加载完成,已将 Redis 锁 key 值设置为 1")
 
         else:
-            pass
+            val = None
+            # 其它 rank 等待:只有当 lock key 存在且其值为 "1" 时才算数据加载完成
+            print(f"rank{rank} 正在等待 rank0 完成数据加载...")
+            while True:
+                val = redis_client.get(lock_key)
+                if val is not None and val.decode('utf-8') in ["1", "2"]:
+                    break
+                time.sleep(1)
+            if val is not None and val.decode('utf-8') == "2":
+                print(f"rank{rank} 跳过空批次 {i}")
+                time.sleep(3)
+                continue
+            print(f"rank{rank} 检测到数据加载已完成,继续后续处理...")
+
+        # 同步点:所有 Rank 在此等待
+        if FLAG_Distributed:
+            # 确保所有 CUDA 操作完成并释放缓存
+            print(f"rank{rank} ready synchronize ...")
+            torch.cuda.synchronize()
+            
+            print(f"rank{rank} ready empty_cache ...")
+            torch.cuda.empty_cache()
+            
+            print(f"rank{rank} ready barrier ...")
+            dist.barrier()  # 移除 device_ids 参数
+            # dist.barrier(device_ids=[local_rank])
+
+        print(f"rank{rank} done barrier ...")
+
+        # 广播批次有效性标志
+        if FLAG_Distributed:
+            dist.broadcast(valid_batch, src=0)
+
+        # 所有 Rank 检查批次有效性
+        if valid_batch.item() == 0:
+            print(f"Rank {rank} 跳过无效批次 {i}")
+            continue  # 所有 Rank 跳过当前循环
+
+        # 所有 Rank 同时进入数据分发
+        if rank == 0:
+            # 分片并分发
+            my_sequences, my_targets, my_group_ids = distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed=FLAG_Distributed)
+        else:
+            # 其它 Rank 接收数据
+            my_sequences, my_targets, my_group_ids = distribute_sharded_data([], [], [], world_size, rank, device, flag_distributed=FLAG_Distributed)
+
+        # 查看一下各rank是否分到数据
+        debug_print_shard_info([], my_targets, my_group_ids, rank, local_rank, world_size)
+
+        pre_flag, train_single, val_single, test_single = prepare_data_distribute(my_sequences, my_targets, my_group_ids, 
+            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size)
+        
+        del my_sequences
+        del my_targets
+        del my_group_ids
+        gc.collect()
+
+        if not pre_flag:
+            print(f"Rank {rank} 跳过无效数据批次 {i}")
+            continue
+
+        train_sequences = train_single['sequences']
+        train_targets = train_single['targets']
+        train_group_ids = train_single['group_ids']
+
+        val_sequences = val_single['sequences']
+        val_targets = val_single['targets']
+        val_group_ids = val_single['group_ids']
+
+        # test_sequences = test_single['sequences']
+        # test_targets = test_single['targets']
+        # test_group_ids = test_single['group_ids']
+
+        if FLAG_Distributed:
+            dist.barrier()
+
+        # 训练模型
+        model = train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
+            model, criterion, optimizer, device, num_epochs=num_epochs_per_batch, batch_size=16, target_scaler=target_scaler, 
+            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size, 
+            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx, 
+            batch_flight_routes=batch_flight_routes, patience=40, delta=0.001)
+
+        del train_single
+        del val_single
+        del test_single
+        gc.collect()
+
+        # 重置模型参数
+        if (i + 1) % assemble_size == 0:
+            
+            if FLAG_Distributed:
+                dist.barrier()
+
+            del model, optimizer 
+            torch.cuda.empty_cache()  # 清理GPU缓存
+
+            model = initialize_model(device)  # 重置模型
+            
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)  # 重置优化器
+            print(f"Rank {rank}, Reset Model at batch {i} due to performance drop")
+
+    ###############################################################################################################
+
+    # 在整体批次训练结束后
+    if rank == 0:
+        # pass
+        # torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))
+        print("模型训练完成并已保存。")
+
+        csv_file = 'evaluate_results.csv'
+        csv_path = os.path.join(output_dir, csv_file)
+        # 汇总评估结果
+        try:
+            df = pd.read_csv(csv_path)
+        except Exception as e:
+            print(f"read {csv_path} error: {str(e)}")
+            df = None
+
+        if df is not None:
+            # 提取真实值和预测值
+            y_trues_class_labels = df['Actual_Will_Price_Drop']
+            y_preds_class_labels = df['Predicted_Will_Price_Drop']
+
+            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')
+
+    if FLAG_Distributed:
+        dist.destroy_process_group()   # 显式调用 destroy_process_group 来清理 NCCL 的进程组资源
+
+
+def distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed):
+    # --- 非分布式模式:直接返回全量数据
+    if not flag_distributed:
+        return sequences, targets, group_ids
+
+    # ================== 第一阶段:元数据广播 ==================
+    if rank == 0:
+        # 将 group_ids 序列化为字节流
+        group_bytes = pickle.dumps(group_ids)
+        # 转换为张量用于分块传输
+        group_tensor = torch.frombuffer(bytearray(group_bytes), dtype=torch.uint8).to(device)
+        # 处理其他数据
+        seq_tensor = torch.stack(sequences, dim=0).to(device)    # shape [N, 2, 452, 25]
+        tgt_tensor = torch.stack(targets, dim=0).to(device)      # shape [N, 1]
+
+        meta_data = {
+            # sequences/targets 元数据
+            'seq_shape': seq_tensor.shape,
+            'tgt_shape': tgt_tensor.shape,
+            'seq_dtype': str(seq_tensor.dtype).replace('torch.', ''),   # 关键修改点
+            'tgt_dtype': str(tgt_tensor.dtype).replace('torch.', ''),
+            
+            # group_ids 元数据
+            'group_shape': group_tensor.shape,
+            'group_bytes_len': len(group_bytes),
+            'pickle_protocol': pickle.HIGHEST_PROTOCOL
+        }
+    else:
+        meta_data = None
+
+    # 广播元数据(所有rank都需要)
+    meta_data = broadcast(meta_data, src=0, rank=rank, device=device)
+    
+    # ================== 第二阶段:分块传输 ==================
+    # Rank 0 already holds the tensors; the other ranks allocate receive buffers of the broadcast shapes
+    if rank != 0:
+        seq_dtype = getattr(torch, meta_data['seq_dtype'])   # 例如 meta_data['seq_dtype'] = "float32"
+        tgt_dtype = getattr(torch, meta_data['tgt_dtype'])
+        
+        group_tensor = torch.zeros(meta_data['group_shape'], dtype=torch.uint8, device=device)
+        seq_tensor = torch.zeros(meta_data['seq_shape'], dtype=seq_dtype, device=device)
+        tgt_tensor = torch.zeros(meta_data['tgt_shape'], dtype=tgt_dtype, device=device)
+    
+    # 并行传输所有数据(按传输量排序:先大后小)
+    _chunked_broadcast(seq_tensor, src=0, rank=rank)     # 最大数据优先
+    _chunked_broadcast(tgt_tensor, src=0, rank=rank)
+    _chunked_broadcast(group_tensor, src=0, rank=rank)    # 最后传输group_ids
+
+    # ================== 第三阶段:数据重建 ==================
+    # 重建 sequences 和 targets
+    sequences_list = [seq.cpu().clone() for seq in seq_tensor]   # 自动按第0维切分
+    targets_list = [tgt.cpu().clone() for tgt in tgt_tensor]
+
+    # 重建 group_ids(关键步骤)
+    if rank == 0:
+        # Rank0直接使用原始数据避免重复序列化
+        group_ids_rebuilt = group_ids
+    else:
+        # 1. 提取有效字节(去除填充)
+        group_bytes = bytes(group_tensor.cpu().numpy().tobytes()[:meta_data['group_bytes_len']])
+        # 2. 反序列化
+        try:
+            group_ids_rebuilt = pickle.loads(group_bytes)
+        except pickle.UnpicklingError as e:
+            raise RuntimeError(f"反序列化 group_ids 失败: {str(e)}")
+        
+        # 3. 结构校验
+        _validate_group_structure(group_ids_rebuilt)
+
+    return sequences_list, targets_list, group_ids_rebuilt
+
+def broadcast(data, src, rank, device):
+    """安全地广播任意数据,确保张量在正确的设备上"""
+    if rank == src:
+        # 序列化数据
+        data_bytes = pickle.dumps(data)
+        data_size = torch.tensor([len(data_bytes)], dtype=torch.long, device=device)
+        # 创建数据张量并移动到设备
+        data_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8).to(device)
+        # 先广播数据大小
+        dist.broadcast(data_size, src=src)
+        # 然后广播数据
+        dist.broadcast(data_tensor, src=src)
+        return data
+    else:
+        # 接收数据大小
+        data_size = torch.tensor([0], dtype=torch.long, device=device)
+        dist.broadcast(data_size, src=src)
+        # 分配数据张量
+        data_tensor = torch.empty(data_size.item(), dtype=torch.uint8, device=device)
+        dist.broadcast(data_tensor, src=src)
+        # 反序列化
+        data = pickle.loads(data_tensor.cpu().numpy().tobytes())
+        return data
+    
+def _chunked_broadcast(tensor, src, rank, chunk_size=1024*1024*128):    # chunk_size is in bytes
+    """Broadcast a tensor in fixed-size chunks to keep individual messages small."""
+    # Step 1. Use a contiguous buffer so the 1-D chunk views alias the tensor's storage
+    buffer = tensor.detach().contiguous()
+    # Step 2. Work out the element size and total element count
+    element_size = buffer.element_size()  # bytes per element (e.g. 4 for float32)
+    total_elements = buffer.numel()
+
+    # Maximum number of elements per chunk, derived from the byte budget
+    elements_per_chunk = chunk_size // element_size
+    # Number of chunks
+    num_chunks = (total_elements + elements_per_chunk - 1) // elements_per_chunk
+    
+    # Step 3. Broadcast chunk by chunk
+    for chunk_idx in range(num_chunks):
+        # Element range of the current chunk
+        start_element = chunk_idx * elements_per_chunk
+        end_element = min((chunk_idx + 1) * elements_per_chunk, total_elements)
+        # Slice the current chunk out of the flattened buffer
+        chunk = buffer.view(-1).narrow(0, start_element, end_element - start_element)
+        # Broadcast this chunk
+        dist.broadcast(chunk, src=src)
+        # Note: each chunk is a 1-D view into the receiver's preallocated buffer, so the strict
+        # chunk order plus that preallocation restores the original tensor without any reshape.
+
+def _validate_group_structure(group_ids):
+    """校验 group_ids 数据结构完整性"""
+    assert isinstance(group_ids, list), "Group IDs 必须是列表"
+    if len(group_ids) == 0:
+        print("还原的 group_ids 长度为0")
+        return
+    
+    sample = group_ids[0]
+    assert isinstance(sample, tuple), "元素必须是元组"
+    assert len(sample) == 9, "元组长度必须为9"
+
+def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
+    """分布式环境下按Rank顺序打印分片前5条样本"""
+    # 同步所有进程
+    if FLAG_Distributed:
+        dist.barrier(device_ids=[local_rank])
+
+    # 按Rank顺序逐个打印(避免输出混杂)
+    for r in range(world_size):
+        if r == rank:
+            print(f"\n=== Rank {rank}/{world_size} Data Shard Samples (showing first 5) ===")
+            
+            # 打印序列数据
+            # print("[Sequences]")
+            # for i, seq in enumerate(sequences[:5]):
+            #     print(f"Sample {i}: {seq[:3]}...")  # 只显示前3元素示意
+
+            # 打印目标数据
+            print("\n[Targets]")
+            print(targets[:5])
+
+            # 打印Group ID分布
+            print("\n[Group IDs]")
+            # unique_gids = list(set(group_ids[:50]))  # 检查前50条的group分布
+            print(f"First 5 GIDs: {group_ids[:5]}")
+            
+            # sys.stdout.flush()  # 确保立即输出
 
+        if FLAG_Distributed:
+            dist.barrier(device_ids=[local_rank])  # 等待当前Rank打印完成
 
 
 if __name__ == "__main__":

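For reference, a self-contained sketch of the size-then-bytes object broadcast pattern used by the broadcast() helper above, runnable as a single-process gloo group; the port number and the example dict are arbitrary (illustrative, not part of the commit):

import pickle
import torch
import torch.distributed as dist

def broadcast_object(obj, src, rank, device):
    # Sender: serialize, broadcast the byte length, then the bytes themselves
    if rank == src:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long, device=device)
        data = torch.frombuffer(bytearray(payload), dtype=torch.uint8).to(device)
        dist.broadcast(size, src=src)
        dist.broadcast(data, src=src)
        return obj
    # Receivers: learn the size first, then allocate and receive the bytes
    size = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(size, src=src)
    data = torch.empty(size.item(), dtype=torch.uint8, device=device)
    dist.broadcast(data, src=src)
    return pickle.loads(data.cpu().numpy().tobytes())

if __name__ == "__main__":
    # world_size=1 makes the broadcasts no-ops; this only demonstrates the mechanics
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    meta = broadcast_object({"seq_shape": (8, 2, 452, 25)}, src=0, rank=0, device=torch.device("cpu"))
    print(meta)
    dist.destroy_process_group()
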
+ 87 - 0
model.py

@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+
+# 分类模型 (Transformer) 
+class PriceDropClassifiTransModel(nn.Module):
+    def __init__(self, input_size, num_periods=2, hidden_size=128, num_layers=3, output_size=1, dropout=0.3, conv_out_channels=64, kernel_size=3, num_heads=8):
+        super(PriceDropClassifiTransModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_periods = num_periods
+
+        # 卷积层
+        self.conv1 = nn.Conv1d(
+            in_channels=input_size * num_periods,
+            out_channels=conv_out_channels,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2,
+            bias=False,
+        )
+        self.relu = nn.ReLU()
+
+        # Transformer Encoder
+        self.transformer_layer = nn.TransformerEncoderLayer(
+            d_model=conv_out_channels,
+            # d_model=input_size * num_periods,   # 这里的d_model应为输入的特征数量, d_model能被num_heads整除
+            nhead=num_heads,
+            dim_feedforward=hidden_size,
+            dropout=dropout
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            self.transformer_layer,
+            num_layers=num_layers
+        )
+
+        # 注意力机制
+        self.attention_layer = nn.Sequential(
+            nn.Linear(conv_out_channels, hidden_size),
+            # nn.Linear(input_size * num_periods, hidden_size),
+            # nn.Conv1d(conv_out_channels, hidden_size),
+            # nn.Tanh(),
+            nn.ReLU(),
+            nn.Linear(hidden_size, 1)
+        )
+
+        # 分类和回归输出层
+        self.fc_classification = nn.Linear(conv_out_channels, 1)
+
+    def forward(self, x):
+        """
+        输入x的形状应为 [batch_size, num_periods, seq_length, input_size]
+        """
+        batch_size, num_periods, seq_length, input_size = x.size()
+        # x = x[:,0,:,:].view(batch_size, 1, input_size, seq_length)
+
+        # 将输入转换为 [batch_size, num_periods * input_size, seq_length]
+        x = x.permute(0, 1, 3, 2).contiguous()  # [batch_size, num_periods, input_size, seq_length]
+        x = x.view(batch_size, num_periods * input_size, seq_length)  # [batch_size, num_periods * input_size, seq_length]
+        # x = x.view(batch_size, 1 * input_size, seq_length)
+
+        # 经过卷积层和激活函数
+        x = self.conv1(x)    # [batch_size, conv_out_channels, seq_length]
+        x = self.relu(x)
+
+        # 转置以适应Transformer输入要求
+        x = x.permute(2, 0, 1)  # [seq_length, batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 经过Transformer编码器
+        x = self.transformer_encoder(x)  # [seq_length, batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 计算注意力
+        attention_scores = self.attention_layer(x)  # [seq_length, batch_size, 1]
+        attention_weights = torch.softmax(attention_scores, dim=0)  # [seq_length, batch_size, 1]
+        # 对所有时间步进行加权求和
+        context_vector = torch.sum(attention_weights * x, dim=0)  # [batch_size, conv_out_channels(num_periods * input_size)]
+        
+        # 取最后一个时间步的输出进行分类和回归
+        # context_vector = x[-1, :, :]  # [batch_size, conv_out_channels(num_periods * input_size)]
+
+        # 分类和回归输出
+        classification_output = torch.sigmoid(self.fc_classification(context_vector))  # [batch_size, 1]
+        # 打印检查:输出范围
+        # print(f"Before clamp: min: {classification_output.min().item()}, max: {classification_output.max().item()}")
+        # 将输出值限制在 [0.0001, 0.9999] 范围内,以避免数值极端
+        # classification_output = torch.clamp(classification_output, min=1e-4, max=1 - 1e-4)
+        
+        return classification_output

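A quick shape-check sketch for this model, assuming model.py is importable and using the [N, 2, 452, 25] input layout noted in main_tr.py; the feature count of 25 and batch size of 4 are illustrative (not part of the commit):

import torch
from model import PriceDropClassifiTransModel

input_size, num_periods, seq_length = 25, 2, 452   # illustrative sizes
model = PriceDropClassifiTransModel(input_size, num_periods=num_periods,
                                    hidden_size=64, num_layers=3,
                                    output_size=1, dropout=0.2)
model.eval()
with torch.no_grad():
    x = torch.randn(4, num_periods, seq_length, input_size)   # [batch, periods, seq, features]
    out = model(x)
print(out.shape)                              # expected: torch.Size([4, 1])
print(out.min().item(), out.max().item())     # sigmoid output, so values lie in (0, 1)
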
BIN
simhei.ttf


+ 443 - 0
train.py

@@ -0,0 +1,443 @@
+import gc
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+from sklearn.model_selection import train_test_split
+from imblearn.over_sampling import SMOTE, RandomOverSampler
+from collections import Counter
+from evaluate import evaluate_model_distribute
+from utils import FlightDataset, EarlyStoppingDist  # EarlyStopping, train_process, train_process_distribute, CombinedLoss
+import numpy as np
+import matplotlib.pyplot as plt
+import font
+import config
+import redis
+import time
+
+
+# 智能分层划分函数
+def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None, rank=0, local_rank=0):
+    if stratify is not None:
+        counts = Counter(stratify)
+        min_count = min(counts.values()) if counts else 0
+        if min_count < 2:
+            if local_rank == 0:
+                print(f"Rank:{rank}, Local Rank:{local_rank}, 安全分层:检测到最小类别样本数={min_count},禁用分层")
+            stratify = None
+    
+    return train_test_split(
+        *arrays,
+        test_size=test_size,
+        random_state=random_state,
+        stratify=stratify
+    )
+
+
+# 分布式数据集准备
+def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
+    if len(sequences) == 0 or len(targets) == 0:
+        print(f"Rank:{rank}, 没有足够的数据参与训练。")
+        return False, None, None, None
+    
+    targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0] for t in targets])
+    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
+    if len(unique_classes) == 1:
+        print(f"Rank:{rank}, 警告:目标变量只有一个类别,无法参与训练。")
+        return False, None, None, None
+    
+    # --- 高效过滤样本数 ≤ 1 的类别(浮点兼容版)---
+    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
+    class_to_count = dict(zip(unique_classes, class_counts))
+    valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
+
+    if not np.any(valid_mask):
+        print(f"Rank:{rank}, 警告:所有类别的样本数均 ≤ 1,无法分层拆分。")
+        return False, None, None, None
+
+    # 一次性筛选数据(兼容列表/Tensor/Array)
+    sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
+    targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
+    group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
+    targets_array_filtered = targets_array[valid_mask]
+
+    # Step 1: split the filtered samples into a training set (80%) and a temporary set (20%)
+    train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
+        sequences_filtered, targets_filtered, group_ids_filtered,
+        stratify=targets_array_filtered,
+        test_size=0.2,
+        random_state=42,
+        rank=rank,
+        local_rank=local_rank
+    )
+
+    # 验证集与测试集全部引用临时集
+    val_28 = temp_28
+    test_28 = temp_28
+    val_28_targets = temp_28_targets
+    test_28_targets = temp_28_targets
+    val_28_gids = temp_28_gids
+    test_28_gids = temp_28_gids
+
+    # 合并训练集
+    train_sequences = train_28
+    train_targets = train_28_targets
+    train_group_ids = train_28_gids
+    
+    # 合并验证集
+    val_sequences = val_28 
+    val_targets = val_28_targets 
+    val_group_ids = val_28_gids
+
+    # 测试集
+    test_sequences = test_28
+    test_targets = test_28_targets
+    test_group_ids = test_28_gids
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次训练集数量:{len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次验证集数量:{len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 批次测试集数量:{len(test_sequences)}")
+
+    train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
+    train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
+
+    if local_rank == 0:
+        # 打印检查
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].shape:{train_targets_tensors[0].shape}")  # 应该是 torch.Size([1])
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}")  # 应该是 torch.float32
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}")  # 应该是 torch.float32
+
+    train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids}
+    val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids}
+    test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids}    
+
+    def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1):
+        # 每个 rank 到达 barrier 时,将计数加 1
+        redis_client.incr(barrier_key)
+        
+        start_time = time.time()
+        while True:
+            count = redis_client.get(barrier_key)
+            count = int(count) if count else 0
+            if count >= world_size:
+                break
+            if time.time() - start_time > timeout:
+                raise TimeoutError("等待 barrier 超时")
+            time.sleep(poll_interval)
+
+    # 等待其他进程生成数据,并同步
+    if flag_distributed:
+        redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+        barrier_key = 'distributed_barrier_11'
+
+        # 等待所有进程都到达 barrier
+        _redis_barrier(redis_client, barrier_key, world_size)
+
+    return True, train_single, val_single, test_single
+
+# 分布式训练
+def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
+                           model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
+                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
+                           photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
+                           ):
+    
+    # 统计正负样本数量
+    all_targets = torch.cat(train_targets)   # 将所有目标值拼接成一个张量
+    positive_count = torch.sum(all_targets == 1).item()
+    negative_count = torch.sum(all_targets == 0).item()
+    total_samples = len(all_targets)
+
+    # 计算比例
+    positive_ratio = positive_count / total_samples
+    negative_ratio = negative_count / total_samples
+
+    if local_rank == 0:
+        # 打印检查
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集数量:{len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集目标数量:{len(train_targets)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集数量:{len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集目标数量:{len(val_targets)}")
+
+        # 打印正负样本统计
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集总样本数: {total_samples}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集正样本数: {positive_count} ({positive_ratio*100:.2f}%)")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集负样本数: {negative_count} ({negative_ratio*100:.2f}%)")
+
+    # 计算并打印推荐的 pos_weight
+    if positive_count > 0:
+        recommended_pos_weight = negative_count / positive_count
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, 推荐的 pos_weight: {recommended_pos_weight:.2f}")
+    else:
+        recommended_pos_weight = 1.0
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, 警告: 没有正样本!")
+
+    train_dataset = FlightDataset(train_sequences, train_targets)
+    val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids)
+    # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids)
+
+    del train_sequences
+    del train_targets
+    del train_group_ids
+    del val_sequences
+    del val_targets
+    del val_group_ids
+    gc.collect()
+
+    if flag_distributed:
+        sampler_train = DistributedSampler(train_dataset, shuffle=True)  # 分布式采样器
+        sampler_val = DistributedSampler(val_dataset, shuffle=False)
+        
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size,  sampler=sampler_val)
+    else:
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size)
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}")  # 特征尺寸
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}")  # 目标尺寸
+
+    pos_weight_value = recommended_pos_weight  # 从上面的计算中获取
+    # 创建带权重的损失函数
+    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
+
+    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
+                                       rank=rank, local_rank=local_rank)
+    
+    # 分布式训练模型
+    train_losses, val_losses = train_process_distribute(
+        model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=num_epochs, criterion=criterion,
+        flag_distributed=flag_distributed, rank=rank, local_rank=local_rank, loss_call_label="val")
+    
+    if rank == 0:
+        font_prop = font.font_prop
+
+        # 绘制损失曲线(可选)
+        plt.figure(figsize=(10, 6))
+        epochs = range(1, len(train_losses) + 1)
+        plt.plot(epochs, train_losses, 'b-', label='训练集损失')
+        plt.plot(epochs, val_losses, 'r-', label='验证集损失')
+        plt.title('训练和验证集损失曲线', fontproperties=font_prop)
+        plt.xlabel('Epochs', fontproperties=font_prop)
+        plt.ylabel('Loss', fontproperties=font_prop)
+        plt.legend(prop=font_prop)
+        plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
+
+    # 训练结束后加载最佳模型参数
+    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
+
+    # 确保所有进程都看到相同的文件系统状态
+    if flag_distributed:
+        dist.barrier()
+    
+    # 创建用于广播的列表(只有一个元素)
+    checkpoint_list = [None]
+
+    if rank == 0:
+        if os.path.exists(best_model_path):
+            print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
+            # 直接加载到 CPU,避免设备不一致问题
+            checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
+        else:
+            print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
+            # Fall back to the current model state, copied to CPU without moving the live model off its device
+            src_model = model.module if flag_distributed else model
+            checkpoint_list[0] = {k: v.detach().cpu() for k, v in src_model.state_dict().items()}
+
+    # 广播模型状态字典
+    if flag_distributed:
+        dist.broadcast_object_list(checkpoint_list, src=0)
+
+    # 所有进程获取广播后的状态字典
+    checkpoint = checkpoint_list[0]
+
+    # 加载模型状态
+    if flag_distributed:
+        model.module.load_state_dict(checkpoint)
+    else:
+        model.load_state_dict(checkpoint)
+    
+    # 确保所有进程完成加载
+    if flag_distributed:
+        dist.barrier()
+
+    if flag_distributed:
+        # 调用评估函数
+        evaluate_model_distribute(
+            model.module,       # 使用 DDP 包裹前的原始模型
+            device,
+            None, None, None,
+            test_loader=val_loader,  # 使用累积验证集 
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=flag_distributed,
+            rank=rank, local_rank=local_rank, world_size=world_size, 
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+    else:
+        evaluate_model_distribute(
+            model,
+            device,
+            None, None, None,
+            test_loader=val_loader,  # 使用累积验证集
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=False,
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+
+    return model
+
+
+def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=200, criterion=None, save_file='best_model.pth',
+                             flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
+    # 具体训练过程
+    train_losses = []
+    val_losses = []
+    
+    # 初始化损失为张量(兼容非分布式和分布式)
+    # total_train_loss = torch.tensor(0.0, device=device)
+    # total_val_loss = torch.tensor(0.0, device=device)
+
+    # 初始化 TensorBoard(只在主进程)
+    # if rank == 0:
+    #     writer = SummaryWriter(log_dir='runs/experiment_name')
+    #     train_global_step = 0
+    #     val_global_step = 0
+    
+    for epoch in range(num_epochs):
+        # --- 训练阶段 ---
+        model.train()
+        if flag_distributed:
+            train_loader.sampler.set_epoch(epoch)  # 确保每个进程一致地打乱顺序
+
+        # total_train_loss.zero_()   # 重置损失累计
+        total_train_loss = torch.tensor(0.0, device=device)
+        num_train_samples = torch.tensor(0, device=device)     # 当前进程的样本数
+
+        for batch_idx, batch in enumerate(train_loader):
+            X_batch, y_batch = batch[:2]  # 假设 group_ids 不需要参与训练
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+            
+            optimizer.zero_grad()
+            outputs = model(X_batch)
+            loss = criterion(outputs, y_batch)
+            loss.backward()
+
+            # 打印
+            # if rank == 0:
+            #     # print_gradient_range(model)
+            #     # 记录损失值
+            #     writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
+            #     # 记录元数据
+            #     writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
+            #     writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
+
+            #     log_gradient_stats(model, writer, train_global_step, "train")
+
+            #     # 更新全局步数
+            #     train_global_step += 1
+
+            # 梯度裁剪(已兼容 DDP)
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            optimizer.step()
+            
+            # 累计损失
+            total_train_loss += loss.detach() * X_batch.size(0)   # detach() 保留张量形式以支持跨进程通信
+            num_train_samples += X_batch.size(0)
+        
+        # --- 同步训练损失 ---
+        if flag_distributed:
+            # 会将所有进程的 total_train_loss 求和后, 同步到每个进程
+            dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
+
+        # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
+        avg_train_loss = total_train_loss.item() / num_train_samples.item()
+        train_losses.append(avg_train_loss)
+
+        # --- 验证阶段 ---
+        model.eval()
+        # total_val_loss.zero_()  # 重置验证损失
+        total_val_loss = torch.tensor(0.0, device=device)
+        num_val_samples = torch.tensor(0, device=device)
+
+        with torch.no_grad():
+            for batch_idx, batch in enumerate(val_loader):
+                X_val, y_val = batch[:2]
+                X_val = X_val.to(device)
+                y_val = y_val.to(device)
+                
+                outputs = model(X_val)
+                val_loss = criterion(outputs, y_val)
+                total_val_loss += val_loss.detach() * X_val.size(0)
+                num_val_samples += X_val.size(0)
+
+                # if rank == 0:
+                #     # 记录验证集batch loss
+                #     writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
+                #     # 记录验证集元数据
+                #     writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
+                #     writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
+
+                #     # 更新验证集全局步数
+                #     val_global_step += 1
+
+                # if local_rank == 0:
+                #     print(f"rank:{rank}, outputs:{outputs}")
+                #     print(f"rank:{rank}, y_val:{y_val}")
+                #     print(f"rank:{rank}, val_loss:{val_loss.detach()}")
+                #     print(f"rank:{rank}, size:{X_val.size(0)}")
+
+        # --- 同步验证损失 ---
+        if flag_distributed:
+            dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
+
+        # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
+        avg_val_loss = total_val_loss.item() / num_val_samples.item()
+        val_losses.append(avg_val_loss)
+
+        # if rank == 0:
+        #     # 记录epoch平均损失
+        #     writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
+        #     writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
+
+        if local_rank == 0:
+            print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, 训练集损失: {avg_train_loss:.4f}, 验证集损失: {avg_val_loss:.4f}")
+
+        # --- 早停与保存逻辑(仅在 rank 0 执行)---
+        if rank == 0:
+            
+            # 模型保存兼容分布式和非分布式
+            model_to_save = model.module if flag_distributed else model   # 当使用 model = DDP(model) 封装后,原始模型会被包裹在 model.module 属性
+            if loss_call_label == "train":
+                early_stopping(avg_train_loss, model_to_save)  # 平均训练集损失 
+            else:
+                early_stopping(avg_val_loss, model_to_save)   # 平均验证集损失
+            
+            if early_stopping.early_stop:
+                print(f"Rank:{rank}, 早停触发,停止训练 at epoch {epoch}")
+                # 非分布式模式下直接退出循环
+                if not flag_distributed:
+                    break
+
+        # --- 同步早停状态(仅分布式需要)---
+        if flag_distributed:
+            # 将早停标志转换为张量广播
+            early_stop_flag = torch.tensor([early_stopping.early_stop], device=device)
+            dist.broadcast(early_stop_flag, src=0)
+            if early_stop_flag.item():   # item()取张量的布尔值
+                break
+        # else:
+        #     # 非分布式模式下,直接检查早停标志
+        #     if early_stopping.early_stop:
+        #         break
+
+    return train_losses, val_losses

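For reference, a standalone sketch of the pos_weight computation used in train_model_distribute, with made-up class counts; note that nn.BCEWithLogitsLoss expects raw (pre-sigmoid) logits as input (illustrative, not part of the commit):

import torch
import torch.nn as nn

positive_count, negative_count = 120, 880                       # illustrative class counts
pos_weight = torch.tensor([negative_count / positive_count])    # ≈ 7.33, up-weights the rare positive class
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([[2.0], [-1.0], [0.3]])                   # raw model scores (pre-sigmoid)
labels = torch.tensor([[1.0], [0.0], [1.0]])
print(criterion(logits, labels).item())
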
+ 96 - 4
utils.py

@@ -1,5 +1,7 @@
+import gc
+import time
 import torch
-
+from torch.utils.data import Dataset
 
 # 航线列表分组切片并带上索引
 def chunk_list_with_index(lst, group_size):
@@ -27,6 +29,9 @@ def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
 
 # 真正创建序列过程
 def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
+    print(">>开始创建序列")
+    start_time = time.time()
+
     sequences = []
     targets = []
     group_ids = []
@@ -77,7 +82,94 @@ def create_fixed_length_sequences(df, features, target_vars, input_length=452, i
                            str(last_row['Hours_Until_Departure']), 
                            str(last_row['update_hour'])])
             group_ids.append(tuple(name_c))
-            pass
+        
+        del df_group_bag_30_filtered, df_group_bag_20_filtered
+        del df_group_bag_30, df_group_bag_20
+        del df_group
+
+    gc.collect()
+    print(">>结束创建序列")
+    end_time = time.time()
+    run_time = round(end_time - start_time, 3)
+    print(f"用时: {run_time} 秒")
+    print(f"生成的序列数量:{len(sequences)}")
+    
+    return sequences, targets, group_ids
+
+
+class FlightDataset(Dataset):
+    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
+        self.X_sequences = X_sequences
+        self.y_sequences = y_sequences
+        self.group_ids = group_ids
+        self.return_group_ids = group_ids is not None
+
+    def __len__(self):
+        return len(self.X_sequences)
+
+    def __getitem__(self, idx):
+        if self.return_group_ids:
+            if self.y_sequences is not None:
+                return self.X_sequences[idx], self.y_sequences[idx], self.group_ids[idx]
+            else:
+                return self.X_sequences[idx], self.group_ids[idx]
+        else:
+            if self.y_sequences is not None:
+                return self.X_sequences[idx], self.y_sequences[idx]
+            else:
+                return self.X_sequences[idx]
+
+
+class EarlyStoppingDist:
+    """早停机制(分布式)"""
+    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pth', rank=0, local_rank=0):
+        """
+        Args:
+            patience (int): 在训练集(或验证集)损失不再改善时,等待多少个epoch后停止训练
+            verbose (bool): 是否打印相关信息
+            delta (float): 训练集损失需要减少的最小变化量
+            path (str): 保存最佳模型的路径
+        """
+        self.patience = patience
+        self.verbose = verbose
+        self.delta = delta
+        self.path = path
+        self.counter = 0
+        self.best_loss = None
+        self.early_stop = False
+        self.rank = rank
+        self.local_rank = local_rank
+
+    def __call__(self, loss, model):
+        # A NaN loss is never an improvement: trigger early stopping without saving a checkpoint
+        if self.is_nan(loss):
+            self.counter += self.patience
+            self.early_stop = True
+            return
+        if self.best_loss is None:
+            self.best_loss = loss
+            self.save_checkpoint(loss, model)
+        elif loss > self.best_loss - self.delta:
+            self.counter += 1
+            if self.verbose and self.rank == 0:
+                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.save_checkpoint(loss, model)
+            self.best_loss = loss
+            self.counter = 0
+
+    def is_nan(self, loss):
+        """检查损失值是否为NaN(通用方法)"""
+        try:
+            # 所有NaN类型都不等于自身
+            return loss != loss
+        except Exception:
+            # 处理不支持比较的类型
+            return False
+
+    def save_checkpoint(self, loss, model):
+        """Save the model whenever the monitored loss improves."""
+        if self.verbose and self.rank == 0:
+            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}).  Saving model ...')
+        torch.save(model.state_dict(), self.path)
 
-        pass
-    pass
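
A usage sketch for EarlyStoppingDist with a dummy model and a fabricated loss curve, assuming utils.py is importable; the checkpoint path and loss values are arbitrary (illustrative, not part of the commit):

import torch
import torch.nn as nn
from utils import EarlyStoppingDist

model = nn.Linear(4, 1)
stopper = EarlyStoppingDist(patience=3, verbose=True, delta=0.001, path="demo_best.pth")

fake_losses = [0.9, 0.7, 0.65, 0.66, 0.66, 0.67, 0.68]   # stops after 3 non-improving epochs
for epoch, loss in enumerate(fake_losses):
    stopper(loss, model)
    if stopper.early_stop:
        print(f"early stop at epoch {epoch}")
        break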