# train.py
import gc
import os
import time
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
import redis
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE, RandomOverSampler

import font
import config
from evaluate import evaluate_model_distribute
from utils import FlightDataset, EarlyStoppingDist  # EarlyStopping, train_process, train_process_distribute, CombinedLoss
# Stratified split helper: falls back to a plain random split when any class
# has fewer than two samples (sklearn's stratified split would raise otherwise).
def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=None, rank=0, local_rank=0):
    if stratify is not None:
        counts = Counter(stratify)
        min_count = min(counts.values()) if counts else 0
        if min_count < 2:
            if local_rank == 0:
                print(f"Rank:{rank}, Local Rank:{local_rank}, safe split: smallest class has {min_count} sample(s), disabling stratification")
            stratify = None
    return train_test_split(
        *arrays,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify
    )
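
# Illustrative smoke test for safe_train_test_split (an addition for clarity,
# not part of the pipeline): class 2 below has a single sample, so the helper
# disables stratification instead of letting sklearn raise a ValueError.
if __name__ == "__main__":
    _X = list(range(10))
    _y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2]
    _ids = list(range(10))
    _X_tr, _X_tmp, _y_tr, _y_tmp, _id_tr, _id_tmp = safe_train_test_split(
        _X, _y, _ids, test_size=0.2, random_state=42, stratify=_y)
    print(f"demo split: train={len(_X_tr)}, temp={len(_X_tmp)}")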
# Prepare train/val/test splits for (optionally distributed) training
def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
    if len(sequences) == 0 or len(targets) == 0:
        print(f"Rank:{rank}, not enough data to train on.")
        return False, None, None, None
    targets_array = np.array([t[0].item() if isinstance(t[0], torch.Tensor) else t[0] for t in targets])
    unique_classes, class_counts = np.unique(targets_array, return_counts=True)
    if len(unique_classes) == 1:
        print(f"Rank:{rank}, warning: the target variable has only one class; training is not possible.")
        return False, None, None, None
    # --- Efficiently drop classes with <= 1 sample (float-compatible) ---
    class_to_count = dict(zip(unique_classes, class_counts))
    valid_mask = np.array([class_to_count[cls] >= 2 for cls in targets_array])
    if not np.any(valid_mask):
        print(f"Rank:{rank}, warning: every class has <= 1 sample; a stratified split is impossible.")
        return False, None, None, None
    # Filter everything in one pass (works for lists, tensors, and arrays)
    sequences_filtered = [seq for i, seq in enumerate(sequences) if valid_mask[i]]
    targets_filtered = [t for i, t in enumerate(targets) if valid_mask[i]]
    group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
    targets_array_filtered = targets_array[valid_mask]
    # Step 1: split the samples 80/20 into a training set and a temporary set
    # (the _28 suffix refers to this 2:8 split)
    train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
        sequences_filtered, targets_filtered, group_ids_filtered,
        stratify=targets_array_filtered,
        test_size=0.2,
        random_state=42,
        rank=rank,
        local_rank=local_rank
    )
    # Validation and test sets both alias the temporary set
    val_28 = temp_28
    test_28 = temp_28
    val_28_targets = temp_28_targets
    test_28_targets = temp_28_targets
    val_28_gids = temp_28_gids
    test_28_gids = temp_28_gids
    # Training set
    train_sequences = train_28
    train_targets = train_28_targets
    train_group_ids = train_28_gids
    # Validation set
    val_sequences = val_28
    val_targets = val_28_targets
    val_group_ids = val_28_gids
    # Test set
    test_sequences = test_28
    test_targets = test_28_targets
    test_group_ids = test_28_gids
    if local_rank == 0:
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch training set size: {len(train_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch validation set size: {len(val_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, batch test set size: {len(test_sequences)}")
    train_sequences_tensors = [torch.tensor(seq, dtype=torch.float32) for seq in train_sequences]
    train_targets_tensors = [torch.tensor(target, dtype=torch.float32) for target in train_targets]
    if local_rank == 0:
        # Sanity checks
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].shape:{train_targets_tensors[0].shape}")  # expect torch.Size([1])
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_sequences_tensors[0].dtype:{train_sequences_tensors[0].dtype}")  # expect torch.float32
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_targets_tensors[0].dtype:{train_targets_tensors[0].dtype}")  # expect torch.float32
    train_single = {'sequences': train_sequences_tensors, 'targets': train_targets_tensors, 'group_ids': train_group_ids}
    val_single = {'sequences': val_sequences, 'targets': val_targets, 'group_ids': val_group_ids}
    test_single = {'sequences': test_sequences, 'targets': test_targets, 'group_ids': test_group_ids}
    def _redis_barrier(redis_client, barrier_key, world_size, timeout=3600, poll_interval=1):
        # Each rank increments the counter when it reaches the barrier,
        # then polls until every rank has checked in
        redis_client.incr(barrier_key)
        start_time = time.time()
        while True:
            count = redis_client.get(barrier_key)
            count = int(count) if count else 0
            if count >= world_size:
                break
            if time.time() - start_time > timeout:
                raise TimeoutError("Timed out waiting at the barrier")
            time.sleep(poll_interval)
    # Wait for the other processes to finish generating data, then synchronize
    if flag_distributed:
        redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
        barrier_key = 'distributed_barrier_11'
        # Wait until every process has reached the barrier
        _redis_barrier(redis_client, barrier_key, world_size)
    return True, train_single, val_single, test_single
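
# Hypothetical companion helper (an illustrative addition, not called anywhere):
# _redis_barrier only ever increments its counter, so a given barrier_key is
# single-use. A sketch of how a coordinator (e.g. rank 0) could reset the key
# before reusing it for another barrier round:
def _reset_redis_barrier(redis_client, barrier_key):
    # Deleting the key makes subsequent incr() calls start again from zero
    redis_client.delete(barrier_key)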
# Distributed training
def train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
                           model, criterion, optimizer, device, num_epochs=200, batch_size=16, target_scaler=None,
                           flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.',
                           photo_dir='.', batch_idx=-1, batch_flight_routes=None, patience=20, delta=0.01
                           ):
    # Count positive and negative samples
    all_targets = torch.cat(train_targets)  # concatenate all targets into one tensor
    positive_count = torch.sum(all_targets == 1).item()
    negative_count = torch.sum(all_targets == 0).item()
    total_samples = len(all_targets)
    # Class ratios
    positive_ratio = positive_count / total_samples
    negative_ratio = negative_count / total_samples
    if local_rank == 0:
        # Sanity checks
        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training set size: {len(train_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated training target count: {len(train_targets)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation set size: {len(val_sequences)}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, aggregated validation target count: {len(val_targets)}")
        # Class balance statistics
        print(f"Rank:{rank}, Local Rank:{local_rank}, total training samples: {total_samples}")
        print(f"Rank:{rank}, Local Rank:{local_rank}, positive samples: {positive_count} ({positive_ratio*100:.2f}%)")
        print(f"Rank:{rank}, Local Rank:{local_rank}, negative samples: {negative_count} ({negative_ratio*100:.2f}%)")
    # Compute and report the recommended pos_weight
    if positive_count > 0:
        recommended_pos_weight = negative_count / positive_count
        if local_rank == 0:
            print(f"Rank:{rank}, Local Rank:{local_rank}, recommended pos_weight: {recommended_pos_weight:.2f}")
    else:
        recommended_pos_weight = 1.0
        if local_rank == 0:
            print(f"Rank:{rank}, Local Rank:{local_rank}, warning: no positive samples!")
    train_dataset = FlightDataset(train_sequences, train_targets)
    val_dataset = FlightDataset(val_sequences, val_targets, val_group_ids)
    # test_dataset = FlightDataset(test_sequences, test_targets, test_group_ids)
    # Free the raw lists; the datasets hold their own references
    del train_sequences
    del train_targets
    del train_group_ids
    del val_sequences
    del val_targets
    del val_group_ids
    gc.collect()
    if flag_distributed:
        sampler_train = DistributedSampler(train_dataset, shuffle=True)  # distributed sampler
        sampler_val = DistributedSampler(val_dataset, shuffle=False)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_train)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler_val)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
    if local_rank == 0:
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 0 {train_dataset[0][0].shape}")  # feature shape
        print(f"Rank:{rank}, Local Rank:{local_rank}, train_dataset 0 1 {train_dataset[0][1].shape}")  # target shape
    pos_weight_value = recommended_pos_weight  # computed above
    # Build a class-weighted loss (note: this replaces the criterion passed in)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight_value])).to(device)
    early_stopping = EarlyStoppingDist(patience=patience, verbose=True, delta=delta,
                                       path=os.path.join(output_dir, f'best_model_as_{batch_idx}.pth'),
                                       rank=rank, local_rank=local_rank)
    # Run the distributed training loop
    train_losses, val_losses = train_process_distribute(
        model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=num_epochs, criterion=criterion,
        flag_distributed=flag_distributed, rank=rank, local_rank=local_rank, loss_call_label="val")
    if rank == 0:
        font_prop = font.font_prop
        # Plot the loss curves (optional)
        plt.figure(figsize=(10, 6))
        epochs = range(1, len(train_losses) + 1)
        plt.plot(epochs, train_losses, 'b-', label='Training loss')
        plt.plot(epochs, val_losses, 'r-', label='Validation loss')
        plt.title('Training and validation loss', fontproperties=font_prop)
        plt.xlabel('Epochs', fontproperties=font_prop)
        plt.ylabel('Loss', fontproperties=font_prop)
        plt.legend(prop=font_prop)
        plt.savefig(os.path.join(photo_dir, f"train_loss_batch_{batch_idx}.png"))
        plt.close()  # release the figure to avoid accumulating open figures across batches
    # After training, load the best checkpoint
    best_model_path = os.path.join(output_dir, f'best_model_as_{batch_idx}.pth')
    # Make sure every process sees the same filesystem state
    if flag_distributed:
        dist.barrier()
    # List used for broadcasting (a single element)
    checkpoint_list = [None]
    if rank == 0:
        if os.path.exists(best_model_path):
            print(f"Rank 0: batch_idx:{batch_idx} Loading best model from {best_model_path}")
            # Load straight to CPU to avoid device mismatches across ranks
            checkpoint_list[0] = torch.load(best_model_path, map_location='cpu')
        else:
            print(f"Rank 0: batch_idx:{batch_idx} Warning - Best model not found at {best_model_path}")
            # Fall back to the current model state, copying the tensors to CPU
            # (model.cpu() would move the live model off its device as a side effect)
            state_dict = model.module.state_dict() if flag_distributed else model.state_dict()
            checkpoint_list[0] = {k: v.cpu() for k, v in state_dict.items()}
    # Broadcast the state dict
    if flag_distributed:
        dist.broadcast_object_list(checkpoint_list, src=0)
    # Every process picks up the broadcast state dict
    checkpoint = checkpoint_list[0]
    # Load the model state
    if flag_distributed:
        model.module.load_state_dict(checkpoint)
    else:
        model.load_state_dict(checkpoint)
    # Make sure every process has finished loading
    if flag_distributed:
        dist.barrier()
    if flag_distributed:
        # Run evaluation
        evaluate_model_distribute(
            model.module,  # the raw model inside the DDP wrapper
            device,
            None, None, None,
            test_loader=val_loader,  # evaluate on the accumulated validation set
            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
            flag_distributed=flag_distributed,
            rank=rank, local_rank=local_rank, world_size=world_size,
            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
        )
    else:
        evaluate_model_distribute(
            model,
            device,
            None, None, None,
            test_loader=val_loader,  # evaluate on the accumulated validation set
            batch_flight_routes=batch_flight_routes, target_scaler=target_scaler,
            flag_distributed=False,
            output_dir=output_dir, batch_idx=batch_idx, save_mode='a'
        )
    return model
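
# Quick sanity check of the pos_weight heuristic used above (illustrative only):
# with 90 negatives and 10 positives, pos_weight = 90 / 10 = 9, so
# BCEWithLogitsLoss weights each positive sample 9x, roughly balancing the classes.
if __name__ == "__main__":
    _targets = torch.cat([torch.ones(10), torch.zeros(90)])
    _pos = torch.sum(_targets == 1).item()
    _neg = torch.sum(_targets == 0).item()
    print(f"demo pos_weight = {_neg / _pos:.1f}")  # -> 9.0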
def train_process_distribute(model, optimizer, early_stopping, train_loader, val_loader, device, num_epochs=200, criterion=None, save_file='best_model.pth',
                             flag_distributed=False, rank=0, local_rank=0, loss_call_label="train"):
    # Core training loop (save_file is currently unused; the checkpoint path comes from early_stopping)
    train_losses = []
    val_losses = []
    # Initialize losses as tensors (works in both distributed and single-process mode)
    # total_train_loss = torch.tensor(0.0, device=device)
    # total_val_loss = torch.tensor(0.0, device=device)
    # Initialize TensorBoard (main process only)
    # if rank == 0:
    #     writer = SummaryWriter(log_dir='runs/experiment_name')
    #     train_global_step = 0
    #     val_global_step = 0
    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        if flag_distributed:
            train_loader.sampler.set_epoch(epoch)  # keep shuffling consistent across processes
        # total_train_loss.zero_()  # reset the running loss
        total_train_loss = torch.tensor(0.0, device=device)
        num_train_samples = torch.tensor(0, device=device)  # sample count on this process
        for batch_idx, batch in enumerate(train_loader):
            X_batch, y_batch = batch[:2]  # group_ids, if present, are not needed for training
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            # Debug logging
            # if rank == 0:
            #     # print_gradient_range(model)
            #     # log the batch loss
            #     writer.add_scalar('Loss/train_batch', loss.item(), train_global_step)
            #     # log metadata
            #     writer.add_scalar('Metadata/train_epoch', epoch, train_global_step)
            #     writer.add_scalar('Metadata/train_batch_in_epoch', batch_idx, train_global_step)
            #     log_gradient_stats(model, writer, train_global_step, "train")
            #     # advance the global step
            #     train_global_step += 1
            # Gradient clipping (DDP-compatible)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Accumulate the loss
            total_train_loss += loss.detach() * X_batch.size(0)  # detach() keeps it a tensor for cross-process reduction
            num_train_samples += X_batch.size(0)
        # --- Synchronize the training loss ---
        if flag_distributed:
            # Sums total_train_loss across all processes and shares the result with each of them
            dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_train_samples, op=dist.ReduceOp.SUM)
        # avg_train_loss = total_train_loss.item() / len(train_loader.dataset)
        avg_train_loss = total_train_loss.item() / num_train_samples.item()
        train_losses.append(avg_train_loss)
        # --- Validation phase ---
        model.eval()
        # total_val_loss.zero_()  # reset the running validation loss
        total_val_loss = torch.tensor(0.0, device=device)
        num_val_samples = torch.tensor(0, device=device)
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                X_val, y_val = batch[:2]
                X_val = X_val.to(device)
                y_val = y_val.to(device)
                outputs = model(X_val)
                val_loss = criterion(outputs, y_val)
                total_val_loss += val_loss.detach() * X_val.size(0)
                num_val_samples += X_val.size(0)
                # if rank == 0:
                #     # log the validation batch loss
                #     writer.add_scalar('Loss/val_batch', val_loss.item(), val_global_step)
                #     # log validation metadata
                #     writer.add_scalar('Metadata/val_epoch', epoch, val_global_step)
                #     writer.add_scalar('Metadata/val_batch_in_epoch', batch_idx, val_global_step)
                #     # advance the validation global step
                #     val_global_step += 1
                # if local_rank == 0:
                #     print(f"rank:{rank}, outputs:{outputs}")
                #     print(f"rank:{rank}, y_val:{y_val}")
                #     print(f"rank:{rank}, val_loss:{val_loss.detach()}")
                #     print(f"rank:{rank}, size:{X_val.size(0)}")
        # --- Synchronize the validation loss ---
        if flag_distributed:
            dist.all_reduce(total_val_loss, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_val_samples, op=dist.ReduceOp.SUM)
        # avg_val_loss = total_val_loss.item() / len(val_loader.dataset)
        avg_val_loss = total_val_loss.item() / num_val_samples.item()
        val_losses.append(avg_val_loss)
        # if rank == 0:
        #     # log the per-epoch average losses
        #     writer.add_scalar('Loss/train_epoch_avg', avg_train_loss, epoch)
        #     writer.add_scalar('Loss/val_epoch_avg', avg_val_loss, epoch)
        if local_rank == 0:
            print(f"Rank:{rank}, Epoch {epoch+1}/{num_epochs}, training loss: {avg_train_loss:.4f}, validation loss: {avg_val_loss:.4f}")
        # --- Early stopping and checkpointing (rank 0 only) ---
        if rank == 0:
            # Checkpointing works in both distributed and single-process mode:
            # after model = DDP(model), the original model lives in model.module
            model_to_save = model.module if flag_distributed else model
            if loss_call_label == "train":
                early_stopping(avg_train_loss, model_to_save)  # average training loss
            else:
                early_stopping(avg_val_loss, model_to_save)  # average validation loss
            if early_stopping.early_stop:
                print(f"Rank:{rank}, early stopping triggered at epoch {epoch}")
                # In single-process mode, exit the loop directly
                if not flag_distributed:
                    break
        # --- Synchronize the early-stop flag (distributed mode only) ---
        if flag_distributed:
            # Broadcast the flag as an integer tensor (bool tensors are not
            # reliably supported by every collective backend)
            early_stop_flag = torch.tensor([1 if early_stopping.early_stop else 0], device=device)
            dist.broadcast(early_stop_flag, src=0)
            if early_stop_flag.item():  # nonzero means stop
                break
        # else:
        #     # In single-process mode, check the early-stop flag directly
        #     if early_stopping.early_stop:
        #         break
    return train_losses, val_losses
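
# Minimal single-process smoke test of train_process_distribute (an illustrative
# addition, not part of the pipeline; it assumes utils.EarlyStoppingDist accepts
# the same arguments used in train_model_distribute above and behaves like a
# standard early-stopping helper that checkpoints on improvement).
if __name__ == "__main__":
    torch.manual_seed(0)
    _X = torch.randn(64, 8)                   # 64 samples, 8 features
    _y = (_X[:, 0] > 0).float().unsqueeze(1)  # synthetic binary targets, shape [64, 1]
    _loader = DataLoader(torch.utils.data.TensorDataset(_X, _y), batch_size=16, shuffle=True)
    _model = nn.Linear(8, 1)                  # toy model that outputs logits
    _optimizer = torch.optim.Adam(_model.parameters(), lr=1e-2)
    _early_stopping = EarlyStoppingDist(patience=3, verbose=True, delta=0.0,
                                        path='demo_best_model.pth', rank=0, local_rank=0)
    _train_losses, _val_losses = train_process_distribute(
        _model, _optimizer, _early_stopping, _loader, _loader, device='cpu',
        num_epochs=5, criterion=nn.BCEWithLogitsLoss(), loss_call_label="val")
    print(f"demo final losses: train={_train_losses[-1]:.4f}, val={_val_losses[-1]:.4f}")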