@@ -0,0 +1,443 @@
+import gc
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+from sklearn.model_selection import train_test_split
+from imblearn.over_sampling import SMOTE, RandomOverSampler
+from collections import Counter
+from evaluate import evaluate_model_distribute
+from utils import FlightDataset, EarlyStoppingDist  # EarlyStopping, train_process, train_process_distribute, CombinedLoss
+import numpy as np
+import matplotlib.pyplot as plt
+import font
+import config
+import redis
+import time
+
+
+# Stratified split with a safe fallback
+def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None, rank=0, local_rank=0):
+    if stratify is not None:
+        counts = Counter(stratify)
+        min_count = min(counts.values()) if counts else 0
+        if min_count < 2:
+            if local_rank == 0:
+                print(f"Rank:{rank}, Local Rank:{local_rank}, safe split: smallest class count = {min_count}; disabling stratification")
+            stratify = None
+
+    return train_test_split(
+        *arrays,
+        test_size=test_size,
+        random_state=random_state,
+        stratify=stratify
+    )
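+# Note: with stratify set, sklearn's train_test_split raises a ValueError when the
+# least-populated class has fewer than 2 members, which is why the fallback above
+# drops stratification instead of letting the split fail.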
+
+
+# Distributed dataset preparation
+def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
+    if len(sequences) == 0 or len(targets) == 0:
+        print(f"Rank:{rank}, not enough data to train on.")
+        return False, None, None, None
+
+    targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0] for t in targets])
+    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
+    if len(unique_classes) == 1:
+        print(f"Rank:{rank}, warning: the target variable has only one class; it cannot be used for training.")
+        return False, None, None, None
+
+    # --- Efficiently filter out classes with <= 1 sample (float-compatible) ---
+    # (unique_classes and class_counts were already computed above)
+    class_to_count = dict(zip(unique_classes, class_counts))
+    valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
+
+    if not np.any(valid_mask):
+        print(f"Rank:{rank}, warning: every class has <= 1 sample; a stratified split is impossible.")
+        return False, None, None, None
+
+    # Filter everything in one pass (works for lists/Tensors/arrays)
+    sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
+    targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
+    group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
+    targets_array_filtered = targets_array[valid_mask]
+
+    # Step 1: split the samples into a training set (80%) and a temporary set (20%)
+    train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
+        sequences_filtered, targets_filtered, group_ids_filtered,
+        stratify=targets_array_filtered,
+        test_size=0.2,
+        random_state=42,
+        rank=rank,
+        local_rank=local_rank
+    )
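+    # train_test_split applied to several arrays returns a train/test pair per
+    # array, in order, which is why the call above unpacks into six values.
+    # (The `_28` suffix appears to refer to this 2:8, i.e. 20/80, split.)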
+
+    # The validation set and the test set both alias the temporary set
+    val_28 = temp_28
+    test_28 = temp_28
+    val_28_targets = temp_28_targets
+    test_28_targets = temp_28_targets
+    val_28_gids = temp_28_gids
+    test_28_gids = temp_28_gids
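+    # Note: because val and test point at the same data, test metrics are not
+    # independent of any validation-based model selection.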
+
+    # Assemble the training set
+    train_sequences = train_28
+    train_targets = train_28_targets
+    train_group_ids = train_28_gids
+
+    # Assemble the validation set
+    val_sequences = val_28
+    val_targets = val_28_targets
+    val_group_ids = val_28_gids
+
+    # Test set
+    test_sequences = test_28
+    test_targets = test_28_targets
+    test_group_ids = test_28_gids
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, batch training set size: {len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, batch validation set size: {len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, batch test set size: {len(test_sequences)}")
+
+    train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
+    train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
+
+    if local_rank == 0:
+        # Sanity checks
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].shape:{train_targets_tensors[0].shape}")  # expected: torch.Size([1])
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}")  # expected: torch.float32
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}")  # expected: torch.float32
+
+    train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids}
+    val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids}
+    test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids}
+
+    def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1):
+        # Each rank increments the shared counter when it reaches the barrier
+        redis_client.incr(barrier_key)
+
+        start_time = time.time()
+        while True:
+            count = redis_client.get(barrier_key)
+            count = int(count) if count else 0
+            if count >= world_size:
+                break
+            if time.time() - start_time > timeout:
+                raise TimeoutError("Timed out waiting at the barrier")
+            time.sleep(poll_interval)
+
+    # Wait for the other processes to finish generating their data, then synchronize
+    if flag_distributed:
+        redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+        barrier_key = 'distributed_barrier_11'
+
+        # Wait until every process has reached the barrier
+        _redis_barrier(redis_client, barrier_key, world_size)
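+        # Note: the counter key is never deleted, so this barrier is single-use;
+        # reusing the same key on a later run would let every rank pass straight
+        # through. The numeric suffix in the key name appears to version it by hand.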
+
+    return True, train_single, val_single, test_single
+
+
+# Distributed training
+def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
+                           model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
+                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
+                           photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
+                           ):
+
+    # Count positive and negative samples
+    all_targets = torch.cat(train_targets)  # concatenate all target values into one tensor
+    positive_count = torch.sum(all_targets == 1).item()
+    negative_count = torch.sum(all_targets == 0).item()
+    total_samples = len(all_targets)
+
+    # Compute the class ratios
+    positive_ratio = positive_count / total_samples
+    negative_ratio = negative_count / total_samples
+
+    if local_rank == 0:
+        # Sanity checks
+        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training set size: {len(train_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training target count: {len(train_targets)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation set size: {len(val_sequences)}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation target count: {len(val_targets)}")
+
+        # Positive/negative sample statistics
+        print(f"Rank:{rank}, Local Rank:{local_rank}, total training samples: {total_samples}")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, positive training samples: {positive_count} ({positive_ratio*100:.2f}%)")
+        print(f"Rank:{rank}, Local Rank:{local_rank}, negative training samples: {negative_count} ({negative_ratio*100:.2f}%)")
+
+    # Compute and report the recommended pos_weight
+    if positive_count > 0:
+        recommended_pos_weight = negative_count / positive_count
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, recommended pos_weight: {recommended_pos_weight:.2f}")
+    else:
+        recommended_pos_weight = 1.0
+        if local_rank == 0:
+            print(f"Rank:{rank}, Local Rank:{local_rank}, warning: no positive samples!")
+
+    train_dataset = FlightDataset(train_sequences, train_targets)
+    val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids)
+    # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids)
+
+    # Drop the raw lists; the datasets hold their own references
+    del train_sequences
+    del train_targets
+    del train_group_ids
+    del val_sequences
+    del val_targets
+    del val_group_ids
+    gc.collect()
+
+    if flag_distributed:
+        sampler_train = DistributedSampler(train_dataset, shuffle=True)  # distributed sampler
+        sampler_val = DistributedSampler(val_dataset, shuffle=False)
+
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val)
+    else:
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+        val_loader = DataLoader(val_dataset, batch_size=batch_size)
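+    # DistributedSampler partitions the dataset across ranks, so each process
+    # iterates roughly len(dataset)/world_size samples per epoch; by default it
+    # pads with repeated samples so every rank sees the same number of batches.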
+
+    if local_rank == 0:
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}")  # feature shape
+        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}")  # target shape
+
+    pos_weight_value = recommended_pos_weight  # from the statistics computed above
+    # Build the weighted loss function (note: this overrides the criterion argument)
+    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
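+    # pos_weight multiplies the positive-class term of BCEWithLogitsLoss, so
+    # setting it to negative_count/positive_count rebalances the classes; e.g.
+    # 900 negatives and 100 positives give pos_weight = 9.0, making each
+    # positive sample weigh as much as nine negatives.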
+
+    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
+                                       rank=rank, local_rank=local_rank)
+
+    # Run the distributed training loop
+    train_losses, val_losses = train_process_distribute(
+        model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=num_epochs, criterion=criterion,
+        flag_distributed=flag_distributed, rank=rank, local_rank=local_rank, loss_call_label="val")
+
+    if rank == 0:
+        font_prop = font.font_prop
+
+        # Plot the loss curves (optional)
+        plt.figure(figsize=(10, 6))
+        epochs = range(1, len(train_losses) + 1)
+        plt.plot(epochs, train_losses, 'b-', label='Training loss')
+        plt.plot(epochs, val_losses, 'r-', label='Validation loss')
+        plt.title('Training and validation loss', fontproperties=font_prop)
+        plt.xlabel('Epochs', fontproperties=font_prop)
+        plt.ylabel('Loss', fontproperties=font_prop)
+        plt.legend(prop=font_prop)
+        plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
+        plt.close()  # release the figure so repeated calls do not accumulate open figures
+
+    # After training, load the best model parameters
+    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
+
+    # Make sure every process sees the same filesystem state
+    if flag_distributed:
+        dist.barrier()
+
+    # Single-element list used for the broadcast
+    checkpoint_list = [None]
+
+    if rank == 0:
+        if os.path.exists(best_model_path):
+            print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
+            # Load straight to CPU to avoid device mismatches
+            checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
+        else:
+            print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
+            # Fall back to the current model state, copied to CPU tensor by tensor
+            # (calling model.cpu() here would move the live model off its device)
+            model_to_read = model.module if flag_distributed else model
+            checkpoint_list[0] = {k: v.cpu() for k, v in model_to_read.state_dict().items()}
+
+    # Broadcast the model state dict
+    if flag_distributed:
+        dist.broadcast_object_list(checkpoint_list, src=0)
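+    # broadcast_object_list pickles the object on rank 0 and unpickles it on the
+    # other ranks; the list must have the same length on every rank.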
+
+    # Every process reads the broadcast state dict
+    checkpoint = checkpoint_list[0]
+
+    # Load the model state
+    if flag_distributed:
+        model.module.load_state_dict(checkpoint)
+    else:
+        model.load_state_dict(checkpoint)
+
+    # Make sure every process has finished loading
+    if flag_distributed:
+        dist.barrier()
+
+    if flag_distributed:
+        # Run the evaluation
+        evaluate_model_distribute(
+            model.module,  # the raw model underneath the DDP wrapper
+            device,
+            None, None, None,
+            test_loader=val_loader,  # evaluate on the accumulated validation set
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=flag_distributed,
+            rank=rank, local_rank=local_rank, world_size=world_size,
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+    else:
+        evaluate_model_distribute(
+            model,
+            device,
+            None, None, None,
+            test_loader=val_loader,  # evaluate on the accumulated validation set
+            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
+            flag_distributed=False,
+            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
+        )
+
+    return model
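+# A minimal single-process usage sketch (illustrative only; `sequences`, `targets`,
+# `group_ids`, `model`, and `optimizer` are assumed to be provided by the caller):
+#
+#   ok, train_s, val_s, _ = prepare_data_distribute(sequences, targets, group_ids)
+#   if ok:
+#       model = train_model_distribute(
+#           train_s['sequences'], train_s['targets'], train_s['group_ids'],
+#           val_s['sequences'], val_s['targets'], val_s['group_ids'],
+#           model, None, optimizer, torch.device('cpu'))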
+
+
+def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=200, criterion=None, save_file='best_model.pth',
+                             flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
+    # The actual training loop
+    train_losses = []
+    val_losses = []
+
+    # Initialize the losses as tensors (works in both distributed and non-distributed mode)
+    # total_train_loss = torch.tensor(0.0, device=device)
+    # total_val_loss = torch.tensor(0.0, device=device)
+
+    # Initialize TensorBoard (main process only)
+    # if rank == 0:
+    #     writer = SummaryWriter(log_dir='runs/experiment_name')
+    #     train_global_step = 0
+    #     val_global_step = 0
+
+    for epoch in range(num_epochs):
+        # --- Training phase ---
+        model.train()
+        if flag_distributed:
+            train_loader.sampler.set_epoch(epoch)  # keeps the shuffle consistent across processes
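+        # Without set_epoch, DistributedSampler reuses the same shuffled order
+        # every epoch; seeding it with the epoch gives a fresh, rank-consistent
+        # permutation each time.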
+
+        # total_train_loss.zero_()  # reset the accumulated loss
+        total_train_loss = torch.tensor(0.0, device=device)
+        num_train_samples = torch.tensor(0, device=device)  # sample count on this process
+
+        for batch_idx, batch in enumerate(train_loader):
+            X_batch, y_batch = batch[:2]  # group_ids, if present, are not needed for training
+            X_batch = X_batch.to(device)
+            y_batch = y_batch.to(device)
+
+            optimizer.zero_grad()
+            outputs = model(X_batch)
+            loss = criterion(outputs, y_batch)
+            loss.backward()
+
+            # Logging (main process only)
+            # if rank == 0:
+            #     # print_gradient_range(model)
+            #     # Record the loss value
+            #     writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
+            #     # Record metadata
+            #     writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
+            #     writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
+
+            #     log_gradient_stats(model, writer, train_global_step, "train")
+
+            #     # Advance the global step
+            #     train_global_step += 1
+
+            # Gradient clipping (DDP-compatible)
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            optimizer.step()
+
+            # Accumulate the loss
+            total_train_loss += loss.detach() * X_batch.size(0)  # detach() keeps a tensor so it can be reduced across processes
+            num_train_samples += X_batch.size(0)
+
+        # --- Synchronize the training loss ---
+        if flag_distributed:
+            # Sums total_train_loss over all processes and shares the result with each of them
+            dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
+
+        # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
+        avg_train_loss = total_train_loss.item() / num_train_samples.item()
+        train_losses.append(avg_train_loss)
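+        # The reduced quotient is the sample-weighted global mean: the sum over
+        # ranks of (batch mean loss * batch size) divided by the total sample
+        # count, which matches what a single process would compute on the full data.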
+
+        # --- Validation phase ---
+        model.eval()
+        # total_val_loss.zero_()  # reset the validation loss
+        total_val_loss = torch.tensor(0.0, device=device)
+        num_val_samples = torch.tensor(0, device=device)
+
+        with torch.no_grad():
+            for batch_idx, batch in enumerate(val_loader):
+                X_val, y_val = batch[:2]
+                X_val = X_val.to(device)
+                y_val = y_val.to(device)
+
+                outputs = model(X_val)
+                val_loss = criterion(outputs, y_val)
+                total_val_loss += val_loss.detach() * X_val.size(0)
+                num_val_samples += X_val.size(0)
+
+                # if rank == 0:
+                #     # Record the per-batch validation loss
+                #     writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
+                #     # Record validation metadata
+                #     writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
+                #     writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
+
+                #     # Advance the validation global step
+                #     val_global_step += 1
+
+                # if local_rank == 0:
+                #     print(f"rank:{rank}, outputs:{outputs}")
+                #     print(f"rank:{rank}, y_val:{y_val}")
+                #     print(f"rank:{rank}, val_loss:{val_loss.detach()}")
+                #     print(f"rank:{rank}, size:{X_val.size(0)}")
+
+        # --- Synchronize the validation loss ---
+        if flag_distributed:
+            dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
+
+        # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
+        avg_val_loss = total_val_loss.item() / num_val_samples.item()
+        val_losses.append(avg_val_loss)
+
+        # if rank == 0:
+        #     # Record the per-epoch average losses
+        #     writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
+        #     writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
+
+        if local_rank == 0:
+            print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, training loss: {avg_train_loss:.4f}, validation loss: {avg_val_loss:.4f}")
+
+        # --- Early stopping and checkpointing (rank 0 only) ---
+        if rank == 0:
+
+            # Saving works for both distributed and non-distributed models
+            model_to_save = model.module if flag_distributed else model  # DDP(model) wraps the original model in model.module
+            if loss_call_label == "train":
+                early_stopping(avg_train_loss, model_to_save)  # average training loss
+            else:
+                early_stopping(avg_val_loss, model_to_save)  # average validation loss
+
+            if early_stopping.early_stop:
+                print(f"Rank:{rank}, early stopping triggered at epoch {epoch}")
+                # In non-distributed mode, exit the loop directly
+                if not flag_distributed:
+                    break
+
+        # --- Synchronize the early-stop state (distributed only) ---
+        if flag_distributed:
+            # Broadcast the early-stop flag as a tensor
+            early_stop_flag = torch.tensor([early_stopping.early_stop], device=device)
+            dist.broadcast(early_stop_flag, src=0)
+            if early_stop_flag.item():  # item() extracts the boolean
+                break
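+        # In distributed mode rank 0 deliberately does not break on its own: it
+        # falls through to the broadcast so that every rank observes the flag and
+        # leaves the loop on the same epoch, avoiding a hang in later collectives.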
+        # else:
+        #     # In non-distributed mode, check the early-stop flag directly
+        #     if early_stopping.early_stop:
+        #         break
+
+    return train_losses, val_losses