import gc
import os
import time
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import redis
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE, RandomOverSampler

from evaluate import evaluate_model_distribute
from utils import FlightDataset, EarlyStoppingDist  # EarlyStopping, train_process, train_process_distribute, CombinedLoss
import font
import config


# Stratification-aware train/test split: falls back to an unstratified split
# when any class has fewer than 2 samples (sklearn would otherwise raise).
def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None,
                          rank=0, local_rank=0):
    if stratify is not None:
        counts = Counter(stratify)
        min_count = min(counts.values()) if counts else 0
        if min_count < 2:
            if local_rank == 0:
                print(f"Rank:{rank}, Local Rank:{local_rank}, "
                      f"safe stratification: smallest class has {min_count} sample(s), disabling stratify")
            stratify = None
    return train_test_split(
        *arrays,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify
    )


# Distributed dataset preparation
def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False,
                            rank=0, local_rank=0, world_size=1):
    if len(sequences) == 0 or len(targets) == 0:
        print(f"Rank:{rank}, not enough data to train on.")
        return False, None, None, None

    targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0]
                              for t in targets])
    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
    if len(unique_classes) == 1:
        print(f"Rank:{rank}, warning: the target variable has only one class; cannot train.")
        return False, None, None, None

    # --- Efficiently filter out classes with <= 1 sample (float-compatible) ---
    class_to_count = dict(zip(unique_classes, class_counts))
    valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
    if not np.any(valid_mask):
        print(f"Rank:{rank}, warning: every class has <= 1 sample; stratified split impossible.")
        return False, None, None, None

    # Filter all arrays in one pass (works for lists / Tensors / arrays)
    sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
    targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
    group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
    targets_array_filtered = targets_array[valid_mask]

    # Step 1: split the samples 8:2 into a training set (80%) and a temporary set (20%)
    train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = \
        safe_train_test_split(
            sequences_filtered, targets_filtered, group_ids_filtered,
            stratify=targets_array_filtered,
            test_size=0.2,
            random_state=42,
            rank=rank,
            local_rank=local_rank
        )

    # Validation and test sets both reference the temporary set
    val_28 = temp_28
    test_28 = temp_28
    val_28_targets = temp_28_targets
    test_28_targets = temp_28_targets
    val_28_gids = temp_28_gids
    test_28_gids = temp_28_gids

    # Assemble the training set
    train_sequences = train_28
    train_targets = train_28_targets
    train_group_ids = train_28_gids

    # Assemble the validation set
    val_sequences = val_28
    val_targets = val_28_targets
    val_group_ids = val_28_gids

    # Test set
    test_sequences = test_28
    test_targets = test_28_targets
    test_group_ids = test_28_gids

    if local_rank == 0:
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch training-set size: {len(train_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch validation-set size: {len(val_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch test-set size: {len(test_sequences)}")

    train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
    train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
    if local_rank == 0:
        # Sanity-check prints
        print(f"Rank:{rank}, Local Rank:{local_rank}, "
              f"train_targets_tensors[0].shape:{train_targets_tensors[0].shape}")  # expected: torch.Size([1])
print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}") # 应该是 torch.float32 print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}") # 应该是 torch.float32 train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids} val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids} test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids} def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1): # 每个 rank 到达 barrier 时,将计数加 1 redis_client.incr(barrier_key) start_time = time.time() while True: count = redis_client.get(barrier_key) count = int(count) if count else 0 if count >= world_size: break if time.time() - start_time > timeout: raise TimeoutError("等待 barrier 超时") time.sleep(poll_interval) # 等待其他进程生成数据,并同步 if flag_distributed: redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0) barrier_key = 'distributed_barrier_11' # 等待所有进程都到达 barrier _redis_barrier(redis_client, barrier_key, world_size) return True, train_single, val_single, test_single # 分布式训练 def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids, model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None, flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01 ): # 统计正负样本数量 all_targets = torch.cat(train_targets) # 将所有目标值拼接成一个张量 positive_count = torch.sum(all_targets == 1).item() negative_count = torch.sum(all_targets == 0).item() total_samples = len(all_targets) # 计算比例 positive_ratio = positive_count / total_samples negative_ratio = negative_count / total_samples if local_rank == 0: # 打印检查 print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集数量:{len(train_sequences)}") print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总训练集目标数量:{len(train_targets)}") print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集数量:{len(val_sequences)}") print(f"Rank:{rank}, Local Rank:{local_rank}, 汇总验证集目标数量:{len(val_targets)}") # 打印正负样本统计 print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集总样本数: {total_samples}") print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集正样本数: {positive_count} ({positive_ratio*100:.2f}%)") print(f"Rank:{rank}, Local Rank:{local_rank}, 训练集负样本数: {negative_count} ({negative_ratio*100:.2f}%)") # 计算并打印推荐的 pos_weight if positive_count > 0: recommended_pos_weight = negative_count / positive_count if local_rank == 0: print(f"Rank:{rank}, Local Rank:{local_rank}, 推荐的 pos_weight: {recommended_pos_weight:.2f}") else: recommended_pos_weight = 1.0 if local_rank == 0: print(f"Rank:{rank}, Local Rank:{local_rank}, 警告: 没有正样本!") train_dataset = FlightDataset(train_sequences, train_targets) val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids) # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids) del train_sequences del train_targets del train_group_ids del val_sequences del val_targets del val_group_ids gc.collect() if flag_distributed: sampler_train = DistributedSampler(train_dataset, shuffle=True) # 分布式采样器 sampler_val = DistributedSampler(val_dataset, shuffle=False) train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train) val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val) else: train_loader = 
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

    if local_rank == 0:
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}")  # feature shape
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}")  # target shape

    pos_weight_value = recommended_pos_weight  # from the statistics computed above
    # Build the weighted loss; note this replaces the `criterion` argument passed in
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)

    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta,
                                       path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
                                       rank=rank, local_rank=local_rank)

    # Run the distributed training loop
    train_losses, val_losses = train_process_distribute(
        model, optimizer, early_stopping, train_loader, val_loader, device,
        num_epochs=num_epochs, criterion=criterion, flag_distributed=flag_distributed,
        rank=rank, local_rank=local_rank, loss_call_label="val")

    if rank == 0:
        font_prop = font.font_prop
        # Plot the loss curves (optional)
        plt.figure(figsize=(10, 6))
        epochs = range(1, len(train_losses) + 1)
        plt.plot(epochs, train_losses, 'b-', label='Training loss')
        plt.plot(epochs, val_losses, 'r-', label='Validation loss')
        plt.title('Training and validation loss', fontproperties=font_prop)
        plt.xlabel('Epochs', fontproperties=font_prop)
        plt.ylabel('Loss', fontproperties=font_prop)
        plt.legend(prop=font_prop)
        plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
        plt.close()  # release the figure; this function may be called once per batch

    # After training, load the best model parameters
    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')

    # Make sure every process sees the same file-system state
    if flag_distributed:
        dist.barrier()

    # Single-element list used for broadcasting
    checkpoint_list = [None]
    if rank == 0:
        if os.path.exists(best_model_path):
            print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
            # Load straight onto the CPU to avoid device mismatches
            checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
        else:
            print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
            # Fall back to the current weights, copied to CPU tensor-by-tensor so
            # the live model itself is not moved off its device
            state_dict = (model.module if flag_distributed else model).state_dict()
            checkpoint_list[0] = {k: v.cpu() for k, v in state_dict.items()}

    # Broadcast the model state dict
    if flag_distributed:
        dist.broadcast_object_list(checkpoint_list, src=0)

    # Every process picks up the broadcast state dict
    checkpoint = checkpoint_list[0]

    # Load the model state
    if flag_distributed:
        model.module.load_state_dict(checkpoint)
    else:
        model.load_state_dict(checkpoint)

    # Make sure every process has finished loading
    if flag_distributed:
        dist.barrier()

    if flag_distributed:
        # Run evaluation
        evaluate_model_distribute(
            model.module,  # the raw model underneath the DDP wrapper
            device, None, None, None,
            test_loader=val_loader,  # use the accumulated validation set
            batch_flight_routes=batch_flight_routes,
            target_scaler=target_scaler,
            flag_distributed=flag_distributed,
            rank=rank, local_rank=local_rank, world_size=world_size,
            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
        )
    else:
        evaluate_model_distribute(
            model, device, None, None, None,
            test_loader=val_loader,  # use the accumulated validation set
            batch_flight_routes=batch_flight_routes,
            target_scaler=target_scaler,
            flag_distributed=False,
            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
        )
    return model


def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device,
                             num_epochs=200, criterion=None,
                             save_file='best_model.pth',  # unused; the checkpoint path comes from early_stopping
                             flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
    # Core training loop
    train_losses = []
    val_losses = []
    # Losses were previously initialised here as tensors (works in both modes):
    # total_train_loss = torch.tensor(0.0, device=device)
    # total_val_loss = torch.tensor(0.0, device=device)
    # Optional TensorBoard setup (main process only):
    # if rank == 0:
    #     writer = SummaryWriter(log_dir='runs/experiment_name')
    #     train_global_step = 0
    #     val_global_step = 0
    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        if flag_distributed:
            train_loader.sampler.set_epoch(epoch)  # keep shuffling consistent across processes
        # total_train_loss.zero_()  # reset accumulated loss
        total_train_loss = torch.tensor(0.0, device=device)
        num_train_samples = torch.tensor(0, device=device)  # samples seen by this process
        for batch_idx, batch in enumerate(train_loader):
            X_batch, y_batch = batch[:2]  # group_ids, if present, are not used for training
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            # Optional TensorBoard batch logging:
            # if rank == 0:
            #     # print_gradient_range(model)
            #     writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
            #     writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
            #     writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
            #     log_gradient_stats(model, writer, train_global_step, "train")
            #     train_global_step += 1
            # Gradient clipping (DDP-compatible):
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Accumulate the loss; detach() keeps it a tensor for cross-process communication
            total_train_loss += loss.detach() * X_batch.size(0)
            num_train_samples += X_batch.size(0)

        # --- Synchronize the training loss ---
        if flag_distributed:
            # Sum total_train_loss across all processes and share the result with each of them
            dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
        # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
        avg_train_loss = total_train_loss.item() / num_train_samples.item()
        train_losses.append(avg_train_loss)

        # --- Validation phase ---
        model.eval()
        # total_val_loss.zero_()  # reset validation loss
        total_val_loss = torch.tensor(0.0, device=device)
        num_val_samples = torch.tensor(0, device=device)
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                X_val, y_val = batch[:2]
                X_val = X_val.to(device)
                y_val = y_val.to(device)
                outputs = model(X_val)
                val_loss = criterion(outputs, y_val)
                total_val_loss += val_loss.detach() * X_val.size(0)
                num_val_samples += X_val.size(0)
                # Optional TensorBoard / debug logging:
                # if rank == 0:
                #     writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
                #     writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
                #     writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
                #     val_global_step += 1
                # if local_rank == 0:
                #     print(f"rank:{rank}, outputs:{outputs}")
                #     print(f"rank:{rank}, y_val:{y_val}")
                #     print(f"rank:{rank}, val_loss:{val_loss.detach()}")
                #     print(f"rank:{rank}, size:{X_val.size(0)}")

        # --- Synchronize the validation loss ---
        if flag_distributed:
            dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
        # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
        avg_val_loss = total_val_loss.item() / num_val_samples.item()
        val_losses.append(avg_val_loss)

        # if rank == 0:
        #     # Epoch-average losses
        #     writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
        #     writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
        if local_rank == 0:
            print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, "
                  f"train loss: {avg_train_loss:.4f}, val loss: {avg_val_loss:.4f}")

        # --- Early stopping and checkpointing (rank 0 only) ---
        if rank == 0:
            # Saving works in both modes: model = DDP(model) wraps the original model in model.module
            model_to_save = model.module if flag_distributed else model
            if loss_call_label == "train":
                early_stopping(avg_train_loss, model_to_save)  # average training loss
            else:
                early_stopping(avg_val_loss, model_to_save)  # average validation loss
            if early_stopping.early_stop:
                print(f"Rank:{rank}, early stopping triggered at epoch {epoch}")
                # In non-distributed mode we can simply break out of the loop
                if not flag_distributed:
                    break
        # --- Synchronize the early-stop flag (distributed mode only) ---
        if flag_distributed:
            # Broadcast the early-stop flag as a tensor
            early_stop_flag = torch.tensor([early_stopping.early_stop], device=device)
            dist.broadcast(early_stop_flag, src=0)
            if early_stop_flag.item():  # item() yields the boolean value
                break
        # else:
        #     # In non-distributed mode, check the early-stop flag directly
        #     if early_stopping.early_stop:
        #         break

    return train_losses, val_losses
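
# For reference, a minimal sketch of the interface this module expects from
# utils.EarlyStoppingDist, which is not shown in this file. This class is
# hypothetical: the real implementation may differ, but the constructor
# keywords, the (loss, model) call signature, and the `early_stop` attribute
# match how the object is used above.
class EarlyStoppingDistSketch:
    """Illustrative sketch (assumption) of the EarlyStoppingDist contract."""

    def __init__(self, patience=20, verbose=False, delta=0.0,
                 path='checkpoint.pth', rank=0, local_rank=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.rank = rank
        self.local_rank = local_rank
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss, model):
        # Checkpoint whenever the monitored loss improves by more than delta;
        # otherwise count up and trip `early_stop` after `patience` misses.
        if self.best_loss is None or loss < self.best_loss - self.delta:
            self.best_loss = loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose and self.local_rank == 0:
                print(f"Rank:{self.rank}, checkpoint saved to {self.path} (loss {loss:.4f})")
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True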
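
# Usage sketch, not part of the pipeline itself: a minimal launch script,
# assuming the process is started with `torchrun` (RANK/LOCAL_RANK/WORLD_SIZE
# env vars). The synthetic data and TinyModel below are placeholders invented
# for illustration; the project's real data loader, model, and `font`/`utils`
# modules are assumed available. In distributed mode, prepare_data_distribute
# also expects its hardcoded Redis host to be reachable.
if __name__ == '__main__':
    rank = int(os.environ.get('RANK', '0'))
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    flag_distributed = world_size > 1
    if flag_distributed:
        dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')

    # Synthetic stand-in data: 100 sequences of 10 timesteps x 4 features,
    # binary targets of shape [1], one group id per sequence.
    rng = np.random.default_rng(0)
    sequences = [rng.standard_normal((10, 4)).astype(np.float32) for _ in range(100)]
    targets = [np.array([float(rng.integers(0, 2))], dtype=np.float32) for _ in range(100)]
    group_ids = list(range(100))

    ok, train_single, val_single, test_single = prepare_data_distribute(
        sequences, targets, group_ids, flag_distributed=flag_distributed,
        rank=rank, local_rank=local_rank, world_size=world_size)

    class TinyModel(nn.Module):
        """Placeholder model: one logit per sequence, for BCEWithLogitsLoss."""

        def __init__(self):
            super().__init__()
            self.rnn = nn.GRU(input_size=4, hidden_size=16, batch_first=True)
            self.fc = nn.Linear(16, 1)

        def forward(self, x):
            _, h = self.rnn(x)
            return self.fc(h[-1])  # raw logits; BCEWithLogitsLoss applies the sigmoid

    model = TinyModel().to(device)
    if flag_distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank] if torch.cuda.is_available() else None)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    if ok:
        # criterion is passed as None: train_model_distribute rebuilds it
        # internally with the pos_weight derived from the class balance.
        train_model_distribute(
            train_single['sequences'], train_single['targets'], train_single['group_ids'],
            val_single['sequences'], val_single['targets'], val_single['group_ids'],
            model, None, optimizer, device, num_epochs=5, batch_size=16,
            flag_distributed=flag_distributed, rank=rank,
            local_rank=local_rank, world_size=world_size)
    if flag_distributed:
        dist.destroy_process_group()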