main_tr.py

import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
# import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta
from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from train import prepare_data_distribute, train_model_distribute
from evaluate import printScore_cc
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')

# Toggle distributed mode based on the presence of the LOCAL_RANK environment variable
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']  # must match the gid grouping keys
common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining', 'is_cross_country', 'is_transfer',
                   'fly_duration', 'stop_duration',
                   'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop
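# Note: features totals 5 encoded + 3 price + 17 common columns = 25 inputs,
# matching the trailing dimension of the [N, 2, 452, 25] sequence shape noted
# in distribute_sharded_data below.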
# Distributed environment initialization
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Important: the device must be bound before the process group is initialized
        torch.cuda.set_device(local_rank)  # explicitly set the GPU used by this process
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")  # log
        except Exception as e:
            print(f"Failed to initialize process group: {e}")  # surface the error
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Outside a distributed environment, fall back to the default device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device


# Initialize the model and related parameters
def initialize_model(device):
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
    model.to(device)
    if FLAG_Distributed:
        model = DDP(model, device_ids=[device], find_unused_parameters=True)  # wrap the model with DDP
    if FLAG_Distributed:
        print(f"Rank:{dist.get_rank()}, model initialized, input size: {input_size}")
    else:
        print(f"Model initialized, input size: {input_size}")
    return model


def continue_before_process(redis_client, lock_key):
    # Cleanup performed by rank0 before skipping to the next batch
    redis_client.set(lock_key, 2)  # set the Redis lock key to 2
    print("rank0 set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")
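# Rank coordination via Redis (as used in start_train below): the lock key is reset
# to 0 at the start of each batch; rank0 sets it to 1 once loading and preprocessing
# finish, or to 2 (via continue_before_process) when the batch should be skipped.
# The other ranks poll the key and either continue on "1" or skip the batch on "2".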
def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)
    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1
    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
    # rank == 0 only
    if rank == 0:
        # If processing was interrupted, comment out the following block
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")
        # If processing was interrupted, comment out the following block
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            try:
                csv_path = os.path.join(output_dir, csv_file)
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")
    # Make sure the directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    print(f"Final feature list: {features}")
    # Optimizer and loss function
    criterion = None  # defined later, right before training
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1  # number of route groups per batch
    num_epochs_per_batch = 200  # training epochs per batch, adjust as needed
    feature_scaler = None  # feature scaler placeholder
    target_scaler = None  # target scaler placeholder
    # Initialize the Redis client (adjust host, port and db as needed)
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'
    assemble_size = 1  # number of batches per assemble group
    batch_idx = -1
    batch_flight_routes = None  # placeholder so every rank has a definition

    # Main path
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)
    # Debug path
    s = 38  # 2025-12-08 is a holiday in the Philippines; s=38 selects Manila
    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot[0:])
    route_len_nothot = len(vj_flight_route_list_nothot[s:])
    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")
    # If processing was interrupted, uncomment to reload the batch order
    # with open(os.path.join(output_dir, f'order.pkl'), "rb") as f:
    #     flight_route_list = pickle.load(f)
    if rank == 0:
        pass
        # Save the batch order; comment this block out when resuming after an interruption
        with open(os.path.join(output_dir, f'order.pkl'), "wb") as f:
            pickle.dump(flight_route_list, f)
    chunks = chunk_list_with_index(flight_route_list, group_size)
    # Compute the total number of batches and pre-allocate the scaler list
    if rank == 0:
        total_batches = len(chunks)
        feature_scaler_list = [None] * total_batches  # pre-allocate list slots
        # target_scaler_list = [None] * total_batches  # pre-allocate list slots
    # When resuming after an interruption, uncomment the block below to reload the scaler lists
    # if rank == 0:
    #     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    #     target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    #     if os.path.exists(feature_scaler_path):
    #         # Load the old scaler list
    #         old_feature_scaler_list = joblib.load(feature_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_feature_scaler_list)
    #         # Only replace the overlapping part
    #         min_batches = min(old_total_batches, total_batches)
    #         feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches]
    #     if os.path.exists(target_scaler_path):
    #         # Load the old scaler list
    #         old_target_scaler_list = joblib.load(target_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_target_scaler_list)
    #         # Only replace the overlapping part
    #         min_batches = min(old_total_batches, total_batches)
    #         target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches]
    # If processing was interrupted, find the interrupted index in the logs and set it here
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    if local_rank == 0:
        batch_starts = [start_idx for start_idx, _ in chunks]
        print(f"rank:{rank}, local_rank:{local_rank}, training start-index order: {batch_starts}")
    # Training phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed
        pass
        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)
        # Flag shared by all ranks
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)  # 1 marks a valid batch
        # rank == 0 only
        if rank == 0:
            # Rank0 sets the Redis lock key to 0: data loading is not finished yet
            redis_client.set(lock_key, 0)
            print("rank0 starting data loading...")
            # Use the default configuration
            client, db = mongo_con_parse()
            print(f"Group {i}:", group_route_list)
            batch_flight_routes = group_route_list
            # Decide hot vs. non-hot based on the index position
            if 0 <= i < route_len_hot:
                is_hot = 1
                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
            elif route_len_hot <= i < route_len_hot + route_len_nothot:
                is_hot = 0
                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
            else:
                print(f"Cannot determine hot or non-hot, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed: {run_time} s")
            client.close()
            if df_train.empty:
                print(f"Training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Data preprocessing
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            print("Preprocessed data sample:\n", df_train_inputs.head())
            total_rows = df_train_inputs.shape[0]
            print(f"Row count: {total_rows}")
            if total_rows == 0:
                print(f"Preprocessed training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Standardization / normalization
            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs, feature_scaler=None, target_scaler=None)
            # Store the scaler in the list
            batch_idx = i
            print("batch_idx:", batch_idx)
            feature_scaler_list[batch_idx] = feature_scaler
            # target_scaler_list[batch_idx] = target_scaler
            # Persist the scalers every batch
            feature_scaler_path = os.path.join(output_dir, f'feature_scalers.joblib')
            # target_scaler_path = os.path.join(output_dir, f'target_scalers.joblib')
            joblib.dump(feature_scaler_list, feature_scaler_path)
            # joblib.dump(target_scaler_list, target_scaler_path)
            # Build sequences
            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
            # Validity check
            if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
                valid_batch[0] = 0
                print("Warning: this batch has no data, marking it as invalid")
            # Data loading and preprocessing done: set the Redis lock key to 1
            redis_client.set(lock_key, 1)
            print("rank0 finished data loading, set the Redis lock key to 1")
        else:
            val = None
            # Other ranks wait: loading counts as finished only when the lock key exists and equals "1"
            print(f"rank{rank} waiting for rank0 to finish data loading...")
            while True:
                val = redis_client.get(lock_key)
                if val is not None and val.decode('utf-8') in ["1", "2"]:
                    break
                time.sleep(1)
            if val is not None and val.decode('utf-8') == "2":
                print(f"rank{rank} skipping empty batch {i}")
                time.sleep(3)
                continue
            print(f"rank{rank} detected that data loading is complete, continuing...")
        # Synchronization point: all ranks wait here
        if FLAG_Distributed:
            # Make sure all CUDA work has finished and cached memory is released
            print(f"rank{rank} ready synchronize ...")
            torch.cuda.synchronize()
            print(f"rank{rank} ready empty_cache ...")
            torch.cuda.empty_cache()
            print(f"rank{rank} ready barrier ...")
            dist.barrier()  # device_ids argument removed
            # dist.barrier(device_ids=[local_rank])
            print(f"rank{rank} done barrier ...")
        # Broadcast the batch-validity flag
        if FLAG_Distributed:
            dist.broadcast(valid_batch, src=0)
        # Every rank checks the batch validity
        if valid_batch.item() == 0:
            print(f"Rank {rank} skipping invalid batch {i}")
            continue  # every rank skips this iteration
        # All ranks enter data distribution together
        if rank == 0:
            # Shard and distribute
            my_sequences, my_targets, my_group_ids = distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed=FLAG_Distributed)
        else:
            # Other ranks receive the data
            my_sequences, my_targets, my_group_ids = distribute_sharded_data([], [], [], world_size, rank, device, flag_distributed=FLAG_Distributed)
        # Check whether each rank received data
        debug_print_shard_info([], my_targets, my_group_ids, rank, local_rank, world_size)
        pre_flag, train_single, val_single, test_single = prepare_data_distribute(
            my_sequences, my_targets, my_group_ids,
            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size)
        del my_sequences
        del my_targets
        del my_group_ids
        gc.collect()
        if not pre_flag:
            print(f"Rank {rank} skipping invalid data batch {i}")
            continue
        train_sequences = train_single['sequences']
        train_targets = train_single['targets']
        train_group_ids = train_single['group_ids']
        val_sequences = val_single['sequences']
        val_targets = val_single['targets']
        val_group_ids = val_single['group_ids']
        # test_sequences = test_single['sequences']
        # test_targets = test_single['targets']
        # test_group_ids = test_single['group_ids']
        if FLAG_Distributed:
            dist.barrier()
        # Train the model
        model = train_model_distribute(
            train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
            model, criterion, optimizer, device, num_epochs=num_epochs_per_batch, batch_size=16, target_scaler=target_scaler,
            flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size,
            output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx,
            batch_flight_routes=batch_flight_routes, patience=40, delta=0.001)
        del train_single
        del val_single
        del test_single
        gc.collect()
        # Reset the model parameters
        if (i + 1) % assemble_size == 0:
            if FLAG_Distributed:
                dist.barrier()
            del model, optimizer
            torch.cuda.empty_cache()  # clear the GPU cache
            model = initialize_model(device)  # reset the model
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)  # reset the optimizer
            print(f"Rank {rank}, Reset Model at batch {i} due to performance drop")
    ###############################################################################################################
    # After all batches have been trained
    if rank == 0:
        # pass
        # torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))
        print("Model training finished and saved.")
        csv_file = 'evaluate_results.csv'
        csv_path = os.path.join(output_dir, csv_file)
        # Aggregate the evaluation results
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df = None
        if df is not None:
            # Extract the ground truth and the predictions
            y_trues_class_labels = df['Actual_Will_Price_Drop']
            y_preds_class_labels = df['Predicted_Will_Price_Drop']
            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')
    if FLAG_Distributed:
        dist.destroy_process_group()  # explicitly destroy the process group to release NCCL resources
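# Note on the broadcast strategy below: sequences and targets are fixed-shape tensors,
# so they are stacked and broadcast directly, while group_ids are Python tuples
# (length 11, per _validate_group_structure), so they are pickled into a byte tensor,
# broadcast, and deserialized on the receiving ranks.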
def distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed):
    # --- Non-distributed mode: return the full data directly
    if not flag_distributed:
        return sequences, targets, group_ids
    # ================== Stage 1: metadata broadcast ==================
    if rank == 0:
        # Serialize group_ids into a byte stream
        group_bytes = pickle.dumps(group_ids)
        # Convert to a tensor for chunked transfer
        group_tensor = torch.frombuffer(bytearray(group_bytes), dtype=torch.uint8).to(device)
        # Handle the remaining data
        seq_tensor = torch.stack(sequences, dim=0).to(device)  # shape [N, 2, 452, 25]
        tgt_tensor = torch.stack(targets, dim=0).to(device)  # shape [N, 1]
        meta_data = {
            # sequences/targets metadata
            'seq_shape': seq_tensor.shape,
            'tgt_shape': tgt_tensor.shape,
            'seq_dtype': str(seq_tensor.dtype).replace('torch.', ''),  # key detail: strip the "torch." prefix
            'tgt_dtype': str(tgt_tensor.dtype).replace('torch.', ''),
            # group_ids metadata
            'group_shape': group_tensor.shape,
            'group_bytes_len': len(group_bytes),
            'pickle_protocol': pickle.HIGHEST_PROTOCOL
        }
    else:
        meta_data = None
    # Broadcast the metadata (every rank needs it)
    meta_data = broadcast(meta_data, src=0, rank=rank, device=device)
    # ================== Stage 2: chunked transfer ==================
    # Initialize receive buffers (non-source ranks only; rank 0 already holds the tensors on device)
    if rank != 0:
        seq_dtype = getattr(torch, meta_data['seq_dtype'])  # e.g. meta_data['seq_dtype'] == "float32"
        tgt_dtype = getattr(torch, meta_data['tgt_dtype'])
        group_tensor = torch.zeros(meta_data['group_shape'], dtype=torch.uint8, device=device)
        seq_tensor = torch.zeros(meta_data['seq_shape'], dtype=seq_dtype, device=device)
        tgt_tensor = torch.zeros(meta_data['tgt_shape'], dtype=tgt_dtype, device=device)
    # Transfer all data, ordered by payload size (largest first)
    _chunked_broadcast(seq_tensor, src=0, rank=rank)  # largest data first
    _chunked_broadcast(tgt_tensor, src=0, rank=rank)
    _chunked_broadcast(group_tensor, src=0, rank=rank)  # group_ids last
    # ================== Stage 3: data reconstruction ==================
    # Rebuild sequences and targets
    sequences_list = [seq.cpu().clone() for seq in seq_tensor]  # automatically split along dim 0
    targets_list = [tgt.cpu().clone() for tgt in tgt_tensor]
    # Rebuild group_ids (key step)
    if rank == 0:
        # Rank0 uses the original data to avoid re-serialization
        group_ids_rebuilt = group_ids
    else:
        # 1. Extract the valid bytes (drop the padding)
        group_bytes = bytes(group_tensor.cpu().numpy().tobytes()[:meta_data['group_bytes_len']])
        # 2. Deserialize
        try:
            group_ids_rebuilt = pickle.loads(group_bytes)
        except pickle.UnpicklingError as e:
            raise RuntimeError(f"Failed to deserialize group_ids: {str(e)}")
        # 3. Structure validation
        _validate_group_structure(group_ids_rebuilt)
    return sequences_list, targets_list, group_ids_rebuilt


def broadcast(data, src, rank, device):
    """Safely broadcast arbitrary data, making sure the tensors are on the right device."""
    if rank == src:
        # Serialize the data
        data_bytes = pickle.dumps(data)
        data_size = torch.tensor([len(data_bytes)], dtype=torch.long, device=device)
        # Build the data tensor and move it to the device
        data_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8).to(device)
        # Broadcast the size first
        dist.broadcast(data_size, src=src)
        # Then broadcast the payload
        dist.broadcast(data_tensor, src=src)
        return data
    else:
        # Receive the data size
        data_size = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(data_size, src=src)
        # Allocate the data tensor
        data_tensor = torch.empty(data_size.item(), dtype=torch.uint8, device=device)
        dist.broadcast(data_tensor, src=src)
        # Deserialize
        data = pickle.loads(data_tensor.cpu().numpy().tobytes())
        return data


def _chunked_broadcast(tensor, src, rank, chunk_size=1024*1024*128):  # chunk_size is in bytes
    """Broadcast a tensor in chunks to keep individual collectives small."""
    # Step 1. Prepare a contiguous buffer
    buffer = tensor.detach().contiguous()
    # Step 2. Work out the byte counts
    element_size = buffer.element_size()  # bytes per element (e.g. 4 for float32)
    total_elements = buffer.numel()
    # Maximum number of elements per chunk (converted from bytes)
    elements_per_chunk = chunk_size // element_size
    # Number of chunks
    num_chunks = (total_elements + elements_per_chunk - 1) // elements_per_chunk
    # Step 3. Broadcast chunk by chunk
    for chunk_idx in range(num_chunks):
        # Element range of the current chunk
        start_element = chunk_idx * elements_per_chunk
        end_element = min((chunk_idx + 1) * elements_per_chunk, total_elements)
        # Slice the current chunk out of the flattened buffer
        chunk = buffer.view(-1).narrow(0, start_element, end_element - start_element)
        # Broadcast it
        dist.broadcast(chunk, src=src)
    # Note: although each chunk is one-dimensional, the original tensor shape is fully
    # recovered thanks to 1. the strict transfer order, 2. the pre-allocated receive
    # buffers, and 3. the chunks being views into those buffers.
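# Worked example (assuming float32 sequences): with the default chunk_size of
# 1024*1024*128 bytes, elements_per_chunk is 33,554,432, so a hypothetical
# [1000, 2, 452, 25] sequence tensor (22.6M elements, ~90 MB) still goes out as a
# single chunk; chunking only kicks in for payloads larger than 128 MiB.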
def _validate_group_structure(group_ids):
    """Validate the structural integrity of group_ids."""
    assert isinstance(group_ids, list), "Group IDs must be a list"
    if len(group_ids) == 0:
        print("Rebuilt group_ids has length 0")
        return
    sample = group_ids[0]
    assert isinstance(sample, tuple), "Elements must be tuples"
    assert len(sample) == 11, "Tuples must have length 11"


def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
    """Print the first 5 samples of each rank's data shard, one rank at a time."""
    # Synchronize all processes
    if FLAG_Distributed:
        dist.barrier(device_ids=[local_rank])
    # Print rank by rank to avoid interleaved output
    for r in range(world_size):
        if r == rank:
            print(f"\n=== Rank {rank}/{world_size} Data Shard Samples (showing first 5) ===")
            # Print the sequence data
            # print("[Sequences]")
            # for i, seq in enumerate(sequences[:5]):
            #     print(f"Sample {i}: {seq[:3]}...")  # only show the first 3 elements
            # Print the target data
            print("\n[Targets]")
            print(targets[:5])
            # Print the group ID distribution
            print("\n[Group IDs]")
            # unique_gids = list(set(group_ids[:50]))  # inspect the group distribution of the first 50 rows
            print(f"First 5 GIDs: {group_ids[:5]}")
            # sys.stdout.flush()  # force immediate output
        if FLAG_Distributed:
            dist.barrier(device_ids=[local_rank])  # wait for the current rank to finish printing


if __name__ == "__main__":
    start_train()
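# Launch note (assumption based on the LOCAL_RANK/RANK/WORLD_SIZE environment variables
# read above, which torchrun sets automatically):
#   torchrun --nproc_per_node=<num_gpus> main_tr.py
# Running `python main_tr.py` directly falls back to the single-process, non-distributed path.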