- import gc
- import os
- import torch
- import torch.nn as nn
- import torch.distributed as dist
- from torch.utils.data import DataLoader, DistributedSampler
- from sklearn.model_selection import train_test_split
- from imblearn.over_sampling import SMOTE, RandomOverSampler
- from collections import Counter
- from evaluate import evaluate_model_distribute
- from utils import FlightDataset, EarlyStoppingDist # EarlyStopping, train_process, train_process_distribute, CombinedLoss
- import numpy as np
- import matplotlib.pyplot as plt
- import font
- import config
- import redis
- import time
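- # Note: `evaluate`, `utils`, `font`, and `config` appear to be project-local modules here, not PyPI packages.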
- # Safe stratified split: falls back to an unstratified split when any class is too small
- def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None, rank=0, local_rank=0):
- if stratify is not None:
- counts = Counter(stratify)
- min_count = min(counts.values()) if counts else 0
- if min_count < 2:
- if local_rank == 0:
- print(f"Rank:{rank}, Local Rank:{local_rank}, safe split: smallest class has {min_count} sample(s), disabling stratification")
- stratify = None
-
- return train_test_split(
- *arrays,
- test_size=test_size,
- random_state=random_state,
- stratify=stratify
- )
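- # Illustrative usage (hypothetical data): with labels [0, 0, 0, 1] the minority class has a
- # single sample, so the call below would print the fallback notice and perform a plain,
- # unstratified 80/20 split instead of raising inside train_test_split:
- # X_tr, X_te, y_tr, y_te = safe_train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)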
- # Prepare the train/val/test splits for distributed training
- def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
- if len(sequences) == 0 or len(targets) == 0:
- print(f"Rank:{rank}, not enough data to train on.")
- return False, None, None, None
- 
- targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0] for t in targets])
- unique_classes, class_counts = np.unique(targets_array, return_counts=True)
- if len(unique_classes) == 1:
- print(f"Rank:{rank}, warning: the target has only one class, so training is impossible.")
- return False, None, None, None
- 
- # --- Efficiently filter out classes with <= 1 sample (float-safe; reuses the counts computed above) ---
- class_to_count = dict(zip(unique_classes, class_counts))
- valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
- if not np.any(valid_mask):
- print(f"Rank:{rank}, warning: every class has <= 1 sample, so a stratified split is impossible.")
- return False, None, None, None
- # Filter everything in one pass (works with lists / Tensors / arrays)
- sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
- targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
- group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
- targets_array_filtered = targets_array[valid_mask]
- # Step 1: split the *_28 samples into a training set (80%) and a temporary set (20%)
- train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
- sequences_filtered, targets_filtered, group_ids_filtered,
- stratify=targets_array_filtered,
- test_size=0.2,
- random_state=42,
- rank=rank,
- local_rank=local_rank
- )
- # The validation and test sets both reference the temporary split
- val_28 = temp_28
- test_28 = temp_28
- val_28_targets = temp_28_targets
- test_28_targets = temp_28_targets
- val_28_gids = temp_28_gids
- test_28_gids = temp_28_gids
- # Assemble the training set
- train_sequences = train_28
- train_targets = train_28_targets
- train_group_ids = train_28_gids
- 
- # Assemble the validation set
- val_sequences = val_28
- val_targets = val_28_targets
- val_group_ids = val_28_gids
- # Test set
- test_sequences = test_28
- test_targets = test_28_targets
- test_group_ids = test_28_gids
- if local_rank == 0:
- print(f"Rank:{rank}, Local Rank:{local_rank}, batch training set size: {len(train_sequences)}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, batch validation set size: {len(val_sequences)}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, batch test set size: {len(test_sequences)}")
- train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
- train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
- if local_rank == 0:
- # Sanity-check prints
- print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].shape:{train_targets_tensors[0].shape}") # expected: torch.Size([1])
- print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}") # expected: torch.float32
- print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}") # expected: torch.float32
- train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids}
- val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids}
- test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids}
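- # Only the training split is pre-converted to tensors here; the validation and test splits
- # are passed through as-is (presumably converted downstream, e.g. inside FlightDataset).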
- def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1):
- # Each rank increments the counter when it reaches the barrier
- redis_client.incr(barrier_key)
- 
- start_time = time.time()
- while True:
- count = redis_client.get(barrier_key)
- count = int(count) if count else 0
- if count >= world_size:
- break
- if time.time() - start_time > timeout:
- raise TimeoutError("timed out waiting at the barrier")
- time.sleep(poll_interval)
- # Wait for the other processes to finish producing data, then synchronize
- if flag_distributed:
- redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
- barrier_key = 'distributed_barrier_11'
- # Wait until every process has reached the barrier
- _redis_barrier(redis_client, barrier_key, world_size)
- return True, train_single, val_single, test_single
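- # Note on the barrier above: the Redis counter under barrier_key is never reset, so each
- # synchronization point needs its own unique key (here 'distributed_barrier_11'). A reusable
- # variant would have to delete or expire the key once all ranks have passed, e.g. via
- # redis_client.expire(barrier_key, 60) on one rank -- a sketch/assumption, not part of this codebase.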
- # Distributed training
- def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
- model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
- flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
- photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
- ):
-
- # Count positive and negative samples
- all_targets = torch.cat(train_targets) # concatenate all targets into a single tensor
- positive_count = torch.sum(all_targets == 1).item()
- negative_count = torch.sum(all_targets == 0).item()
- total_samples = len(all_targets)
- # Class ratios
- positive_ratio = positive_count / total_samples
- negative_ratio = negative_count / total_samples
- if local_rank == 0:
- # Sanity-check prints
- print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training set size: {len(train_sequences)}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training target count: {len(train_targets)}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation set size: {len(val_sequences)}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation target count: {len(val_targets)}")
- # Positive/negative sample statistics
- print(f"Rank:{rank}, Local Rank:{local_rank}, total training samples: {total_samples}")
- print(f"Rank:{rank}, Local Rank:{local_rank}, positive training samples: {positive_count} ({positive_ratio*100:.2f}%)")
- print(f"Rank:{rank}, Local Rank:{local_rank}, negative training samples: {negative_count} ({negative_ratio*100:.2f}%)")
- # Compute and report the recommended pos_weight
- if positive_count > 0:
- recommended_pos_weight = negative_count / positive_count
- if local_rank == 0:
- print(f"Rank:{rank}, Local Rank:{local_rank}, recommended pos_weight: {recommended_pos_weight:.2f}")
- else:
- recommended_pos_weight = 1.0
- if local_rank == 0:
- print(f"Rank:{rank}, Local Rank:{local_rank}, warning: no positive samples!")
- train_dataset = FlightDataset(train_sequences, train_targets)
- val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids)
- # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids)
- del train_sequences
- del train_targets
- del train_group_ids
- del val_sequences
- del val_targets
- del val_group_ids
- gc.collect()
- if flag_distributed:
- sampler_train = DistributedSampler(train_dataset, shuffle=True) # distributed sampler
- sampler_val = DistributedSampler(val_dataset, shuffle=False)
-
- train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)
- val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val)
- else:
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
- val_loader = DataLoader(val_dataset, batch_size=batch_size)
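- # DistributedSampler partitions the dataset across the world_size processes, so each rank
- # sees roughly len(dataset) / world_size samples per epoch; set_epoch() is called in the
- # training loop below so shuffles differ across epochs but stay consistent across ranks.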
- if local_rank == 0:
- print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}") # 特征尺寸
- print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}") # 目标尺寸
- pos_weight_value = recommended_pos_weight # 从上面的计算中获取
- # 创建带权重的损失函数
- criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
- early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta, path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
- rank=rank, local_rank=local_rank)
-
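- # The checkpoint path embeds batch_idx, so each data batch keeps its own best-model file.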
- # Train the model (distributed-aware)
- train_losses, val_losses = train_process_distribute(
- model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=num_epochs, criterion=criterion,
- flag_distributed=flag_distributed, rank=rank, local_rank=local_rank, loss_call_label="val")
-
- if rank == 0:
- font_prop = font.font_prop
- # Plot the loss curves (optional)
- plt.figure(figsize=(10, 6))
- epochs = range(1, len(train_losses) + 1)
- plt.plot(epochs, train_losses, 'b-', label='training loss')
- plt.plot(epochs, val_losses, 'r-', label='validation loss')
- plt.title('Training and validation loss', fontproperties=font_prop)
- plt.xlabel('Epochs', fontproperties=font_prop)
- plt.ylabel('Loss', fontproperties=font_prop)
- plt.legend(prop=font_prop)
- plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
- plt.close() # free the figure so repeated batches do not accumulate open figures
- # After training, load the best model parameters
- best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
- # Make sure every process sees the same file-system state
- if flag_distributed:
- dist.barrier()
-
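- # The barrier above makes the other ranks wait until rank 0 (which writes the checkpoint
- # via early stopping) has finished, before the state dict is broadcast to everyone.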
- # List used for broadcasting (a single element)
- checkpoint_list = [None]
- if rank == 0:
- if os.path.exists(best_model_path):
- print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
- # Load straight to CPU to avoid device mismatches
- checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
- else:
- print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
- # Fall back to the current model state, copied to CPU tensors
- # (copying, rather than calling model.cpu(), avoids moving the live model off its device)
- state_dict = model.module.state_dict() if flag_distributed else model.state_dict()
- checkpoint_list[0] = {k: v.detach().cpu() for k, v in state_dict.items()}
- # Broadcast the model state dict
- if flag_distributed:
- dist.broadcast_object_list(checkpoint_list, src=0)
- # Every process picks up the broadcast state dict
- checkpoint = checkpoint_list[0]
- # Load the model state
- if flag_distributed:
- model.module.load_state_dict(checkpoint)
- else:
- model.load_state_dict(checkpoint)
- 
- # Make sure every process has finished loading
- if flag_distributed:
- dist.barrier()
- if flag_distributed:
- # Run evaluation
- evaluate_model_distribute(
- model.module, # the raw model, unwrapped from DDP
- device,
- None, None, None,
- test_loader=val_loader, # evaluate on the accumulated validation set
- batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
- flag_distributed=flag_distributed,
- rank=rank, local_rank=local_rank, world_size=world_size,
- output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
- )
- else:
- evaluate_model_distribute(
- model,
- device,
- None, None, None,
- test_loader=val_loader, # evaluate on the accumulated validation set
- batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
- flag_distributed=False,
- output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
- )
- return model
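- # Minimal invocation sketch (assumptions: names like train_seqs are placeholders, the model is
- # already DDP-wrapped when flag_distributed is True, and ranks come from torchrun's env vars;
- # `criterion` may be None because the function builds its own weighted BCE loss above):
- # local_rank = int(os.environ.get("LOCAL_RANK", 0))
- # device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
- # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
- # model = train_model_distribute(train_seqs, train_tgts, train_gids, val_seqs, val_tgts, val_gids,
- # model, None, optimizer, device, flag_distributed=True,
- # rank=int(os.environ.get("RANK", 0)), local_rank=local_rank,
- # world_size=int(os.environ.get("WORLD_SIZE", 1)))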
- def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=200, criterion=None, save_file='best_model.pth',
- flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
- # The actual training loop
- train_losses = []
- val_losses = []
-
- # Initialize the losses as tensors (works in both non-distributed and distributed modes)
- # total_train_loss = torch.tensor(0.0, device=device)
- # total_val_loss = torch.tensor(0.0, device=device)
- # Initialize TensorBoard (main process only)
- # if rank == 0:
- # writer = SummaryWriter(log_dir='runs/experiment_name')
- # train_global_step = 0
- # val_global_step = 0
-
- for epoch in range(num_epochs):
- # --- Training phase ---
- model.train()
- if flag_distributed:
- train_loader.sampler.set_epoch(epoch) # keep the shuffle order consistent across processes
- # total_train_loss.zero_() # reset the accumulated loss
- total_train_loss = torch.tensor(0.0, device=device)
- num_train_samples = torch.tensor(0, device=device) # sample count on this process
- for batch_idx, batch in enumerate(train_loader):
- X_batch, y_batch = batch[:2] # group_ids, if present, are not needed for training
- X_batch = X_batch.to(device)
- y_batch = y_batch.to(device)
-
- optimizer.zero_grad()
- outputs = model(X_batch)
- loss = criterion(outputs, y_batch)
- loss.backward()
- # Debug logging
- # if rank == 0:
- # # print_gradient_range(model)
- # # log the loss value
- # writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
- # # log metadata
- # writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
- # writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
- # log_gradient_stats(model, writer, train_global_step, "train")
- # # advance the global step
- # train_global_step += 1
- # Gradient clipping (DDP-compatible)
- # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- optimizer.step()
-
- # Accumulate the loss
- total_train_loss += loss.detach() * X_batch.size(0) # detach() keeps it a tensor so it can take part in cross-process communication
- num_train_samples += X_batch.size(0)
-
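- # Why multiply by the batch size: the criterion returns a per-batch mean, so
- # loss * batch_size recovers the batch's loss sum; after all batches (and, in the
- # distributed case, an all_reduce across ranks), total / num_samples is the exact
- # sample-weighted average even when the last batch is smaller than the rest.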
- # --- Synchronize the training loss ---
- if flag_distributed:
- # all_reduce sums total_train_loss across all processes and shares the result with each of them
- dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
- dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
- # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
- avg_train_loss = total_train_loss.item() / num_train_samples.item()
- train_losses.append(avg_train_loss)
- # --- Validation phase ---
- model.eval()
- # total_val_loss.zero_() # reset the validation loss
- total_val_loss = torch.tensor(0.0, device=device)
- num_val_samples = torch.tensor(0, device=device)
- with torch.no_grad():
- for batch_idx, batch in enumerate(val_loader):
- X_val, y_val = batch[:2]
- X_val = X_val.to(device)
- y_val = y_val.to(device)
-
- outputs = model(X_val)
- val_loss = criterion(outputs, y_val)
- total_val_loss += val_loss.detach() * X_val.size(0)
- num_val_samples += X_val.size(0)
- # if rank == 0:
- # # log the per-batch validation loss
- # writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
- # # log validation metadata
- # writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
- # writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
- # # advance the validation global step
- # val_global_step += 1
- # if local_rank == 0:
- # print(f"rank:{rank}, outputs:{outputs}")
- # print(f"rank:{rank}, y_val:{y_val}")
- # print(f"rank:{rank}, val_loss:{val_loss.detach()}")
- # print(f"rank:{rank}, size:{X_val.size(0)}")
- # --- Synchronize the validation loss ---
- if flag_distributed:
- dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
- dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
- # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
- avg_val_loss = total_val_loss.item() / num_val_samples.item()
- val_losses.append(avg_val_loss)
- # if rank == 0:
- # # log the epoch-average losses
- # writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
- # writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
- if local_rank == 0:
- print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, training loss: {avg_train_loss:.4f}, validation loss: {avg_val_loss:.4f}")
- # --- Early stopping and checkpointing (rank 0 only) ---
- if rank == 0:
- 
- # Checkpointing works in both distributed and non-distributed modes
- model_to_save = model.module if flag_distributed else model # after model = DDP(model), the original model lives in model.module
- if loss_call_label == "train":
- early_stopping(avg_train_loss, model_to_save) # average training loss
- else:
- early_stopping(avg_val_loss, model_to_save) # average validation loss
- 
- if early_stopping.early_stop:
- print(f"Rank:{rank}, early stopping triggered at epoch {epoch}")
- # In non-distributed mode, leave the loop directly
- if not flag_distributed:
- break
- # --- Synchronize the early-stop state (distributed only) ---
- if flag_distributed:
- # Broadcast the early-stop flag as a tensor
- early_stop_flag = torch.tensor([early_stopping.early_stop], device=device)
- dist.broadcast(early_stop_flag, src=0)
- if early_stop_flag.item(): # item() extracts the boolean value
- break
- # else:
- # # In non-distributed mode, check the early-stop flag directly
- # if early_stopping.early_stop:
- # break
- return train_losses, val_losses