import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
# import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta
from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from train import prepare_data_distribute, train_model_distribute
from evaluate import printScore_cc
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')

# Enable distributed mode based on whether the LOCAL_RANK environment variable is present
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']  # must stay consistent with the gid grouping keys
common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining', 'is_cross_country', 'is_transfer',
                   'fly_duration', 'stop_duration',
                   'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop
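
# Sanity check on the feature count (derived from the lists above): 5 encoded columns
# + 3 price features + 17 common features = 25 model inputs, which matches the
# [N, 2, 452, 25] sequence shape noted in distribute_sharded_data below.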


# Distributed environment initialization
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Important: the device must be bound before the process group is initialized
        torch.cuda.set_device(local_rank)  # explicitly pin this process to its GPU
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")  # log for visibility
        except Exception as e:
            print(f"Failed to initialize process group: {e}")  # report the failure before re-raising
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Not in a distributed environment: fall back to the default device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device
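
# NOTE: with init_method='env://', init_process_group reads MASTER_ADDR, MASTER_PORT, RANK and
# WORLD_SIZE from the environment; launchers such as torchrun export these automatically.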


# Initialize the model and its related parameters
def initialize_model(device):
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
    model.to(device)
    if FLAG_Distributed:
        model = DDP(model, device_ids=[device], find_unused_parameters=True)  # wrap the model with DDP
        print(f"Rank:{dist.get_rank()}, model initialized, input size: {input_size}")
    else:
        print(f"Model initialized, input size: {input_size}")
    return model


def continue_before_process(redis_client, lock_key):
    # Cleanup rank 0 performs right before skipping to the next batch
    redis_client.set(lock_key, 2)  # set the Redis lock key to 2 (batch skipped)
    print("rank0 set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")


def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)
    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1
    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")
    # Steps performed only on rank 0
    if rank == 0:
        # When resuming an interrupted run, comment out the cleanup below
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")
        # When resuming an interrupted run, comment out the cleanup below
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")
        # Make sure the output directories exist
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(photo_dir, exist_ok=True)
        print(f"Final feature list: {features}")
    # Define the optimizer and loss function
    criterion = None  # defined later, right before training
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1  # how many route groups form one training batch
    num_epochs_per_batch = 200  # epochs per batch; adjust as needed
    feature_scaler = None  # feature scaler placeholder
    target_scaler = None  # target scaler placeholder
    # Initialize the Redis client (adjust host, port and db to your environment)
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'
    assemble_size = 1  # how many batches form one assemble group
    batch_idx = -1
    batch_flight_routes = None  # placeholder so every rank has it defined
    # Production route list
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)
    # Debug route list
    s = 38  # 2025-12-08 is a holiday in the Philippines; s=38 picks Manila
    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot[0:])
    route_len_nothot = len(vj_flight_route_list_nothot[s:])

    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")
    # When resuming an interrupted run, uncomment this to reload the saved batch order
    # with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
    #     flight_route_list = pickle.load(f)
    if rank == 0:
        pass  # keeps this branch valid when the save below is commented out while resuming
        # Save the batch order; comment out the two lines below when resuming an interrupted run
        with open(os.path.join(output_dir, 'order.pkl'), "wb") as f:
            pickle.dump(flight_route_list, f)
    chunks = chunk_list_with_index(flight_route_list, group_size)
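    # chunk_list_with_index comes from utils; judging from how `chunks` is consumed below it is
    # assumed to yield (start_index, sublist) pairs. Illustrative example with group_size=1:
    #   chunk_list_with_index(['HAN-SGN', 'SGN-MNL'], 1) -> [(0, ['HAN-SGN']), (1, ['SGN-MNL'])]
    # (the route strings here are made up purely for illustration)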
    # Compute the total number of batches and pre-allocate the scaler list
    if rank == 0:
        total_batches = len(chunks)
        feature_scaler_list = [None] * total_batches  # pre-allocate the list
        # target_scaler_list = [None] * total_batches  # pre-allocate the list
    # When resuming an interrupted run, uncomment this block to reload the saved scaler lists
    # if rank == 0:
    #     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    #     target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    #     if os.path.exists(feature_scaler_path):
    #         # Load the old scaler list
    #         old_feature_scaler_list = joblib.load(feature_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_feature_scaler_list)
    #         # Only overwrite the overlapping part
    #         min_batches = min(old_total_batches, total_batches)
    #         feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches]
    #     if os.path.exists(target_scaler_path):
    #         # Load the old scaler list
    #         old_target_scaler_list = joblib.load(target_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_target_scaler_list)
    #         # Only overwrite the overlapping part
    #         min_batches = min(old_total_batches, total_batches)
    #         target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches]

    # When resuming an interrupted run, look up the interrupted index in the logs and set it here
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    if local_rank == 0:
        batch_starts = [start_idx for start_idx, _ in chunks]
        print(f"rank:{rank}, local_rank:{local_rank}, training-phase start-index order: {batch_starts}")

    # Training phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if necessary
        pass
        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)
        # Validity flag shared by all ranks for this batch
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)  # 1 means the batch is valid
        # Steps performed only on rank 0
        if rank == 0:
            # Rank 0 sets the Redis lock key to 0: data loading has not finished yet
            redis_client.set(lock_key, 0)
            print("rank0 starting data loading...")
            # Use the default MongoDB configuration
            client, db = mongo_con_parse()
            print(f"Group {i}:", group_route_list)
            batch_flight_routes = group_route_list
            # Decide hot vs. non-hot routes from the index position
            if 0 <= i < route_len_hot:
                is_hot = 1
                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
            elif route_len_hot <= i < route_len_hot + route_len_nothot:
                is_hot = 0
                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
            else:
                print("Cannot tell whether this batch is hot or non-hot, skipping it.")
                continue_before_process(redis_client, lock_key)
                continue

            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed: {run_time} s")
            client.close()
            if df_train.empty:
                print("Training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Data preprocessing
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            print("Sample of preprocessed data:\n", df_train_inputs.head())
            total_rows = df_train_inputs.shape[0]
            print(f"Row count: {total_rows}")
            if total_rows == 0:
                print("Preprocessed training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Standardization / normalization
            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs, feature_scaler=None, target_scaler=None)
            # Store this batch's scaler in the list
            batch_idx = i
            print("batch_idx:", batch_idx)
            feature_scaler_list[batch_idx] = feature_scaler
            # target_scaler_list[batch_idx] = target_scaler
            # Persist the scalers after every batch
            feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
            # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
            joblib.dump(feature_scaler_list, feature_scaler_path)
            # joblib.dump(target_scaler_list, target_scaler_path)

            # Build fixed-length sequences
            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
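
            # create_fixed_length_sequences (from utils) is expected to return one fixed-length
            # window per group: judging from the stacking done in distribute_sharded_data below,
            # each element of `sequences` is a tensor of shape [2, 452, 25]
            # (num_periods x input_length x len(features)) and each target has shape [1].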

            # Validity check
            if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
                valid_batch[0] = 0
                print("Warning: current batch is empty, marking it as invalid")

            # Data loading and preprocessing finished: set the Redis lock key to 1
            redis_client.set(lock_key, 1)
            print("rank0 finished data loading, set the Redis lock key to 1")
        else:
            val = None
            # Other ranks wait: loading counts as finished only when the lock key exists with value "1" (or "2" for a skipped batch)
            print(f"rank{rank} waiting for rank0 to finish data loading...")
            while True:
                val = redis_client.get(lock_key)
                if val is not None and val.decode('utf-8') in ["1", "2"]:
                    break
                time.sleep(1)
            if val is not None and val.decode('utf-8') == "2":
                print(f"rank{rank} skipping empty batch {i}")
                time.sleep(3)
                continue
            print(f"rank{rank} detected that data loading is complete, continuing...")
        # Synchronization point: all ranks wait here
        if FLAG_Distributed:
            # Make sure all CUDA work has finished and the cache is released
            print(f"rank{rank} ready synchronize ...")
            torch.cuda.synchronize()

            print(f"rank{rank} ready empty_cache ...")
            torch.cuda.empty_cache()

            print(f"rank{rank} ready barrier ...")
            dist.barrier()  # the device_ids argument was removed here
            # dist.barrier(device_ids=[local_rank])
            print(f"rank{rank} done barrier ...")
        # Broadcast the batch-validity flag
        if FLAG_Distributed:
            dist.broadcast(valid_batch, src=0)
        # All ranks check the batch validity
        if valid_batch.item() == 0:
            print(f"Rank {rank} skipping invalid batch {i}")
            continue  # every rank skips this iteration
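
        # Design note: the validity flag travels as a CUDA tensor so that dist.broadcast can carry
        # rank 0's decision over NCCL; the Redis "2" value only covers the earlier case in which
        # rank 0 bails out with `continue` before reaching this broadcast.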
        # All ranks enter the data distribution step together
        if rank == 0:
            # Shard and distribute the data
            my_sequences, my_targets, my_group_ids = distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed=FLAG_Distributed)
        else:
            # The other ranks receive the data
            my_sequences, my_targets, my_group_ids = distribute_sharded_data([], [], [], world_size, rank, device, flag_distributed=FLAG_Distributed)
        # Check whether every rank received data
        debug_print_shard_info([], my_targets, my_group_ids, rank, local_rank, world_size)
        pre_flag, train_single, val_single, test_single = prepare_data_distribute(my_sequences, my_targets, my_group_ids,
                                                                                  flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size)

        del my_sequences
        del my_targets
        del my_group_ids
        gc.collect()
        if not pre_flag:
            print(f"Rank {rank} skipping batch {i} with invalid data")
            continue
        train_sequences = train_single['sequences']
        train_targets = train_single['targets']
        train_group_ids = train_single['group_ids']
        val_sequences = val_single['sequences']
        val_targets = val_single['targets']
        val_group_ids = val_single['group_ids']
        # test_sequences = test_single['sequences']
        # test_targets = test_single['targets']
        # test_group_ids = test_single['group_ids']
        if FLAG_Distributed:
            dist.barrier()
        # Train the model
        model = train_model_distribute(train_sequences, train_targets, train_group_ids, val_sequences, val_targets, val_group_ids,
                                       model, criterion, optimizer, device, num_epochs=num_epochs_per_batch, batch_size=16, target_scaler=target_scaler,
                                       flag_distributed=FLAG_Distributed, rank=rank, local_rank=local_rank, world_size=world_size,
                                       output_dir=output_dir, photo_dir=photo_dir, batch_idx=batch_idx,
                                       batch_flight_routes=batch_flight_routes, patience=40, delta=0.001)
        del train_single
        del val_single
        del test_single
        gc.collect()
        # Periodically reset the model parameters
        if (i + 1) % assemble_size == 0:
            if FLAG_Distributed:
                dist.barrier()
            del model, optimizer
            torch.cuda.empty_cache()  # release cached GPU memory
            model = initialize_model(device)  # rebuild the model
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)  # rebuild the optimizer
            print(f"Rank {rank}, Reset Model at batch {i} due to performance drop")

    ###############################################################################################################
    # After all batches have been trained
    if rank == 0:
        # pass
        # torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))
        print("Model training is complete and has been saved.")
        csv_file = 'evaluate_results.csv'
        csv_path = os.path.join(output_dir, csv_file)
        # Aggregate the evaluation results
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df = None
        if df is not None:
            # Extract the ground-truth and predicted labels
            y_trues_class_labels = df['Actual_Will_Price_Drop']
            y_preds_class_labels = df['Predicted_Will_Price_Drop']
            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')
    if FLAG_Distributed:
        dist.destroy_process_group()  # explicitly tear down the NCCL process group resources


def distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed):
    # --- Non-distributed mode: return the full data set directly
    if not flag_distributed:
        return sequences, targets, group_ids
    # ================== Stage 1: metadata broadcast ==================
    if rank == 0:
        # Serialize group_ids into a byte stream
        group_bytes = pickle.dumps(group_ids)
        # Convert to a tensor for chunked transfer
        group_tensor = torch.frombuffer(bytearray(group_bytes), dtype=torch.uint8).to(device)
        # Stack the remaining data
        seq_tensor = torch.stack(sequences, dim=0).to(device)  # shape [N, 2, 452, 25]
        tgt_tensor = torch.stack(targets, dim=0).to(device)    # shape [N, 1]
        meta_data = {
            # sequences / targets metadata
            'seq_shape': seq_tensor.shape,
            'tgt_shape': tgt_tensor.shape,
            'seq_dtype': str(seq_tensor.dtype).replace('torch.', ''),  # e.g. "float32"
            'tgt_dtype': str(tgt_tensor.dtype).replace('torch.', ''),

            # group_ids metadata
            'group_shape': group_tensor.shape,
            'group_bytes_len': len(group_bytes),
            'pickle_protocol': pickle.HIGHEST_PROTOCOL
        }
    else:
        meta_data = None
    # Broadcast the metadata (every rank needs it)
    meta_data = broadcast(meta_data, src=0, rank=rank, device=device)

    # ================== Stage 2: chunked transfer ==================
    # Allocate receive buffers on the non-source ranks (rank 0 already holds the tensors)
    if rank != 0:
        seq_dtype = getattr(torch, meta_data['seq_dtype'])  # e.g. meta_data['seq_dtype'] == "float32"
        tgt_dtype = getattr(torch, meta_data['tgt_dtype'])

        group_tensor = torch.zeros(meta_data['group_shape'], dtype=torch.uint8, device=device)
        seq_tensor = torch.zeros(meta_data['seq_shape'], dtype=seq_dtype, device=device)
        tgt_tensor = torch.zeros(meta_data['tgt_shape'], dtype=tgt_dtype, device=device)

    # Broadcast everything, largest payload first
    _chunked_broadcast(seq_tensor, src=0, rank=rank)    # largest tensor first
    _chunked_broadcast(tgt_tensor, src=0, rank=rank)
    _chunked_broadcast(group_tensor, src=0, rank=rank)  # group_ids last
    # ================== Stage 3: data reconstruction ==================
    # Rebuild sequences and targets
    sequences_list = [seq.cpu().clone() for seq in seq_tensor]  # iterating splits along dim 0
    targets_list = [tgt.cpu().clone() for tgt in tgt_tensor]
    # Rebuild group_ids (the key step)
    if rank == 0:
        # Rank 0 reuses the original objects to avoid a redundant serialization round-trip
        group_ids_rebuilt = group_ids
    else:
        # 1. Extract the valid bytes (drop any padding)
        group_bytes = bytes(group_tensor.cpu().numpy().tobytes()[:meta_data['group_bytes_len']])
        # 2. Deserialize
        try:
            group_ids_rebuilt = pickle.loads(group_bytes)
        except pickle.UnpicklingError as e:
            raise RuntimeError(f"Failed to deserialize group_ids: {str(e)}")

        # 3. Structure validation
        _validate_group_structure(group_ids_rebuilt)
    return sequences_list, targets_list, group_ids_rebuilt
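
# NOTE: despite its name, distribute_sharded_data currently broadcasts the *full*
# sequences/targets/group_ids to every rank; the per-rank train/val/test split presumably
# happens inside prepare_data_distribute (imported from train), which receives rank and world_size.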


def broadcast(data, src, rank, device):
    """Safely broadcast arbitrary Python data, keeping the transfer tensors on the correct device."""
    if rank == src:
        # Serialize the data
        data_bytes = pickle.dumps(data)
        data_size = torch.tensor([len(data_bytes)], dtype=torch.long, device=device)
        # Create the payload tensor and move it to the device
        data_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8).to(device)
        # Broadcast the size first
        dist.broadcast(data_size, src=src)
        # Then broadcast the payload
        dist.broadcast(data_tensor, src=src)
        return data
    else:
        # Receive the payload size
        data_size = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(data_size, src=src)
        # Allocate the receive tensor
        data_tensor = torch.empty(data_size.item(), dtype=torch.uint8, device=device)
        dist.broadcast(data_tensor, src=src)
        # Deserialize
        data = pickle.loads(data_tensor.cpu().numpy().tobytes())
        return data
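
# Design note on broadcast(): the receivers do not know the pickled payload length in advance,
# so a fixed-shape int64 tensor carrying the size is broadcast first, and only then is the byte
# buffer broadcast into a tensor allocated to exactly that size.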


def _chunked_broadcast(tensor, src, rank, chunk_size=1024*1024*128):  # chunk_size is in bytes
    """Broadcast a tensor in chunks to improve communication efficiency."""
    # Step 1. Work on a contiguous buffer
    buffer = tensor.detach().contiguous()
    # Step 2. Work out the byte and element budgets
    element_size = buffer.element_size()  # bytes per element (e.g. 4 for float32)
    total_elements = buffer.numel()
    # Maximum number of elements per chunk, derived from the byte budget
    elements_per_chunk = chunk_size // element_size
    # Number of chunks
    num_chunks = (total_elements + elements_per_chunk - 1) // elements_per_chunk

    # Step 3. Broadcast chunk by chunk
    for chunk_idx in range(num_chunks):
        # Element range of the current chunk
        start_element = chunk_idx * elements_per_chunk
        end_element = min((chunk_idx + 1) * elements_per_chunk, total_elements)
        # Slice the current chunk out of the flattened buffer
        chunk = buffer.view(-1).narrow(0, start_element, end_element - start_element)
        # Broadcast this chunk
        dist.broadcast(chunk, src=src)
    # Note: each chunk is one-dimensional, but the strict transfer order, the pre-allocated receive
    # buffers, and the fact that every chunk is a view into the original tensor mean the original
    # shape is fully preserved on every rank.
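
# Worked example (illustrative numbers): a float32 sequence tensor of shape [10000, 2, 452, 25]
# holds 226,000,000 elements (~862 MiB). With chunk_size = 128 MiB the budget is
# 134,217,728 / 4 = 33,554,432 elements per chunk, so the tensor goes out in
# ceil(226,000,000 / 33,554,432) = 7 broadcasts.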


def _validate_group_structure(group_ids):
    """Validate the structural integrity of the group_ids data."""
    assert isinstance(group_ids, list), "Group IDs must be a list"
    if len(group_ids) == 0:
        print("Rebuilt group_ids has length 0")
        return

    sample = group_ids[0]
    assert isinstance(sample, tuple), "Each element must be a tuple"
    assert len(sample) == 9, "Each tuple must have length 9"


def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
    """In a distributed run, print the first 5 samples of each rank's shard, one rank at a time."""
    # Synchronize all processes
    if FLAG_Distributed:
        dist.barrier(device_ids=[local_rank])
    # Print rank by rank so the output does not interleave
    for r in range(world_size):
        if r == rank:
            print(f"\n=== Rank {rank}/{world_size} Data Shard Samples (showing first 5) ===")

            # Print the sequence data
            # print("[Sequences]")
            # for i, seq in enumerate(sequences[:5]):
            #     print(f"Sample {i}: {seq[:3]}...")  # show only the first 3 elements of each
            # Print the target data
            print("\n[Targets]")
            print(targets[:5])
            # Print the group ID distribution
            print("\n[Group IDs]")
            # unique_gids = list(set(group_ids[:50]))  # inspect the group distribution of the first 50 rows
            print(f"First 5 GIDs: {group_ids[:5]}")

            # sys.stdout.flush()  # flush the output immediately
        if FLAG_Distributed:
            dist.barrier(device_ids=[local_rank])  # wait for the current rank to finish printing


if __name__ == "__main__":
    start_train()
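
# A minimal launch sketch (the script file name below is illustrative, not from the source):
#   single GPU / CPU:      python price_drop_train.py
#   4 GPUs on one node:    torchrun --nproc_per_node=4 price_drop_train.py
# torchrun exports LOCAL_RANK, RANK and WORLD_SIZE, which is what turns FLAG_Distributed on.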