import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
# import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta

from utils import chunk_list_with_index, create_fixed_length_sequences
from model import PriceDropClassifiTransModel
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from train import prepare_data_distribute, train_model_distribute
from evaluate import printScore_cc
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')

# Enable distributed mode based on the presence of the LOCAL_RANK environment variable
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']  # must match the gid grouping keys
common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining',
                   'is_cross_country', 'is_transfer', 'fly_duration', 'stop_duration',
                   'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week',
                   'flight_day_of_quarter', 'flight_day_is_weekend',
                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday',
                   'days_to_holiday',
                   ]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop


# Initialize the distributed environment
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Important: the device must be bound before the process group is initialized
        torch.cuda.set_device(local_rank)  # explicitly pin this process to its GPU
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")
        except Exception as e:
            print(f"Failed to initialize process group: {e}")
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Not running under a distributed launcher: fall back to the default device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device


# Initialize the model and related parameters
def initialize_model(device):
    input_size = len(features)
    model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3,
                                        output_size=1, dropout=0.2)
    model.to(device)
    if FLAG_Distributed:
        model = DDP(model, device_ids=[device], find_unused_parameters=True)  # wrap the model with DDP
    if FLAG_Distributed:
        print(f"Rank:{dist.get_rank()}, model initialized, input size: {input_size}")
    else:
        print(f"Model initialized, input size: {input_size}")
    return model


def continue_before_process(redis_client, lock_key):
    # Cleanup rank 0 performs before skipping to the next batch
    redis_client.set(lock_key, 2)  # set the Redis lock key to 2
    print("rank0 has set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")
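
# Redis lock protocol used by start_train() below (values written to lock_key):
#   0 -> rank 0 is still loading/preprocessing the current batch; other ranks keep polling
#   1 -> data for the batch is ready; all ranks proceed to the shard exchange
#   2 -> rank 0 skipped the batch (empty or invalid data); other ranks skip it as well
# This summary is derived from how the value is set in continue_before_process() and in the
# rank-0 branch of the training loop.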

def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)

    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1

    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=41)).strftime("%Y-%m-%d")

    # Work that only rank == 0 performs
    if rank == 0:
        # If resuming from an interruption, comment out the following cleanup
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")
        # If resuming from an interruption, comment out the following cleanup
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")

    # Make sure the output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)
    print(f"Final feature list: {features}")

    # Optimizer and loss function
    criterion = None  # defined later, right before training
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1  # how many route groups form one batch
    num_epochs_per_batch = 200  # epochs per batch; adjust as needed
    feature_scaler = None  # feature scaler placeholder
    target_scaler = None  # target scaler placeholder

    # Redis client (adjust host/port/db to your environment)
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'

    assemble_size = 1  # how many batches form one assemble group
    batch_idx = -1
    batch_flight_routes = None  # placeholder so non-zero ranks always have a definition

    # Production route list
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)

    # Debug route list
    s = 38  # 2025-12-08 is a holiday in the Philippines; s=38 selects Manila
    flight_route_list = vj_flight_route_list_hot[0:] + vj_flight_route_list_nothot[s:]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot[0:])
    route_len_nothot = len(vj_flight_route_list_nothot[s:])
    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")

    # If resuming from an interruption, uncomment to reload the saved batch order
    # with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
    #     flight_route_list = pickle.load(f)

    if rank == 0:
        # Save the batch order; comment this out when resuming from an interruption
        with open(os.path.join(output_dir, 'order.pkl'), "wb") as f:
            pickle.dump(flight_route_list, f)

    chunks = chunk_list_with_index(flight_route_list, group_size)

    # Compute the total number of batches and pre-allocate the scaler list
    if rank == 0:
        total_batches = len(chunks)
        feature_scaler_list = [None] * total_batches  # pre-allocated slots
        # target_scaler_list = [None] * total_batches  # pre-allocated slots

    # When resuming, uncomment the block below to reload the previously saved scaler lists
    # if rank == 0:
    #     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    #     target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    #     if os.path.exists(feature_scaler_path):
    #         # Load the old scaler list
    #         old_feature_scaler_list = joblib.load(feature_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_feature_scaler_list)
    #         # Only overwrite the overlapping portion
    #         min_batches = min(old_total_batches, total_batches)
    #         feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches]
    #     if os.path.exists(target_scaler_path):
    #         # Load the old scaler list
    #         old_target_scaler_list = joblib.load(target_scaler_path)
    #         # Old total number of batches
    #         old_total_batches = len(old_target_scaler_list)
    #         # Only overwrite the overlapping portion
    #         min_batches = min(old_total_batches, total_batches)
    #         target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches]

    # When resuming, set this to the chunk index where the run was interrupted (see the logs)
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    if local_rank == 0:
        batch_starts = [start_idx for start_idx, _ in chunks]
        print(f"rank:{rank}, local_rank:{local_rank}, training batch start indices: {batch_starts}")

    # ----------------------------- Training phase -----------------------------
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling could go here to skip known-bad batches
        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)
        # Flag tensor shared across all ranks: 1 marks the batch as valid
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)
        # Work that only rank == 0 performs
        if rank == 0:
            # Rank 0 sets the Redis lock key to 0: data loading has not finished yet
            redis_client.set(lock_key, 0)
            print("rank0 starting data loading...")
            # Use the default MongoDB configuration
            client, db = mongo_con_parse()
            print(f"Group {i}:", group_route_list)
            batch_flight_routes = group_route_list

            # Decide hot vs. non-hot based on the index position
            if 0 <= i < route_len_hot:
                is_hot = 1
                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
            elif route_len_hot <= i < route_len_hot + route_len_nothot:
                is_hot = 0
                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
            else:
                print("Cannot determine hot or non-hot, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed: {run_time} s")
            client.close()
            if df_train.empty:
                print("Training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Preprocess the data
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            print("Sample of preprocessed data:\n", df_train_inputs.head())
            total_rows = df_train_inputs.shape[0]
            print(f"Row count: {total_rows}")
            if total_rows == 0:
                print("Preprocessed training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Standardization / normalization
            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs,
                                                                             feature_scaler=None,
                                                                             target_scaler=None)
            # Store the scaler in the list
            batch_idx = i
            print("batch_idx:", batch_idx)
            feature_scaler_list[batch_idx] = feature_scaler
            # target_scaler_list[batch_idx] = target_scaler
            # Persist the scalers after every batch
            feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
            # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
            joblib.dump(feature_scaler_list, feature_scaler_path)
            # joblib.dump(target_scaler_list, target_scaler_path)

            # Build fixed-length sequences
            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars,
                                                                          input_length=452)
            # Validity check
            if len(sequences) == 0 or len(targets) == 0 or len(group_ids) == 0:
                valid_batch[0] = 0
                print("Warning: this batch has no data, marking it as invalid")

            # Data loading and preprocessing done: set the Redis lock key to 1
            redis_client.set(lock_key, 1)
            print("rank0 finished data loading and set the Redis lock key to 1")
        else:
            val = None
            # Other ranks wait: loading is done only when the lock key exists and its value is "1"
            print(f"rank{rank} waiting for rank0 to finish data loading...")
            while True:
                val = redis_client.get(lock_key)
                if val is not None and val.decode('utf-8') in ["1", "2"]:
                    break
                time.sleep(1)
            if val is not None and val.decode('utf-8') == "2":
                print(f"rank{rank} skipping empty batch {i}")
                time.sleep(3)
                continue
            print(f"rank{rank} detected that data loading finished, continuing...")

        # Synchronization point: all ranks wait here
        if FLAG_Distributed:
            # Make sure all CUDA work has finished and caches are released
            print(f"rank{rank} ready synchronize ...")
            torch.cuda.synchronize()
            print(f"rank{rank} ready empty_cache ...")
            torch.cuda.empty_cache()
            print(f"rank{rank} ready barrier ...")
            dist.barrier()
            # device_ids argument removed
            # dist.barrier(device_ids=[local_rank])
            print(f"rank{rank} done barrier ...")

        # Broadcast the batch-validity flag
        if FLAG_Distributed:
            dist.broadcast(valid_batch, src=0)
        # Every rank checks whether the batch is valid
        if valid_batch.item() == 0:
            print(f"Rank {rank} skipping invalid batch {i}")
            continue  # all ranks skip the current iteration

        # All ranks enter the data distribution step together
        if rank == 0:
            # Shard and distribute the data
            my_sequences, my_targets, my_group_ids = distribute_sharded_data(sequences, targets, group_ids,
                                                                             world_size, rank, device,
                                                                             flag_distributed=FLAG_Distributed)
        else:
            # Other ranks receive the data
            my_sequences, my_targets, my_group_ids = distribute_sharded_data([], [], [],
                                                                             world_size, rank, device,
                                                                             flag_distributed=FLAG_Distributed)
        # Check whether each rank actually received data
        debug_print_shard_info([], my_targets, my_group_ids, rank, local_rank, world_size)
        pre_flag, train_single, val_single, test_single = prepare_data_distribute(my_sequences, my_targets,
                                                                                  my_group_ids,
                                                                                  flag_distributed=FLAG_Distributed,
                                                                                  rank=rank, local_rank=local_rank,
                                                                                  world_size=world_size)
        del my_sequences
        del my_targets
        del my_group_ids
        gc.collect()

        if not pre_flag:
            print(f"Rank {rank} skipping batch {i} with invalid data")
            continue

        train_sequences = train_single['sequences']
        train_targets = train_single['targets']
        train_group_ids = train_single['group_ids']
        val_sequences = val_single['sequences']
        val_targets = val_single['targets']
        val_group_ids = val_single['group_ids']
        # test_sequences = test_single['sequences']
        # test_targets = test_single['targets']
        # test_group_ids = test_single['group_ids']

        if FLAG_Distributed:
            dist.barrier()

        # Train the model
        model = train_model_distribute(train_sequences, train_targets, train_group_ids,
                                       val_sequences, val_targets, val_group_ids,
                                       model, criterion, optimizer, device,
                                       num_epochs=num_epochs_per_batch, batch_size=16,
                                       target_scaler=target_scaler,
                                       flag_distributed=FLAG_Distributed,
                                       rank=rank, local_rank=local_rank, world_size=world_size,
                                       output_dir=output_dir, photo_dir=photo_dir,
                                       batch_idx=batch_idx, batch_flight_routes=batch_flight_routes,
                                       patience=40, delta=0.001)
        del train_single
        del val_single
        del test_single
        gc.collect()

        # Reset the model parameters
        if (i + 1) % assemble_size == 0:
            if FLAG_Distributed:
                dist.barrier()
            del model, optimizer
            torch.cuda.empty_cache()  # release cached GPU memory
            model = initialize_model(device)  # rebuild the model
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)  # rebuild the optimizer
            print(f"Rank {rank}, Reset Model at batch {i} due to performance drop")

    ###############################################################################################################
    # After all batches have been trained
    if rank == 0:
        # torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))
        print("Model training finished.")
        csv_file = 'evaluate_results.csv'
        csv_path = os.path.join(output_dir, csv_file)
        # Aggregate the evaluation results
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df = None
        if df is not None:
            # Extract the ground-truth and predicted labels
            y_trues_class_labels = df['Actual_Will_Price_Drop']
            y_preds_class_labels = df['Predicted_Will_Price_Drop']
            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')

    if FLAG_Distributed:
        dist.destroy_process_group()  # explicitly tear down the NCCL process group
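
# Note on the helper below: despite its name, distribute_sharded_data does not split the data
# itself. Rank 0 broadcasts the full sequence/target tensors (in chunks) plus the pickled
# group_ids to every rank; the per-rank train/val/test split is assumed to happen afterwards
# in prepare_data_distribute (imported from train.py). The transfer runs in three phases:
#   1. broadcast a small metadata dict (shapes, dtypes, byte length of the pickled group_ids)
#   2. chunked broadcast of the pre-allocated tensors, largest first
#   3. rebuild Python lists / unpickle group_ids on the receiving ranks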

def distribute_sharded_data(sequences, targets, group_ids, world_size, rank, device, flag_distributed):
    # --- Non-distributed mode: return the full dataset unchanged
    if not flag_distributed:
        return sequences, targets, group_ids

    # ================== Phase 1: metadata broadcast ==================
    if rank == 0:
        # Serialize group_ids into a byte stream
        group_bytes = pickle.dumps(group_ids)
        # Convert to a tensor for chunked transfer
        group_tensor = torch.frombuffer(bytearray(group_bytes), dtype=torch.uint8).to(device)
        # Stack the remaining data
        seq_tensor = torch.stack(sequences, dim=0).to(device)  # shape [N, 2, 452, 25]
        tgt_tensor = torch.stack(targets, dim=0).to(device)    # shape [N, 1]
        meta_data = {
            # sequences/targets metadata
            'seq_shape': seq_tensor.shape,
            'tgt_shape': tgt_tensor.shape,
            'seq_dtype': str(seq_tensor.dtype).replace('torch.', ''),  # e.g. "float32"
            'tgt_dtype': str(tgt_tensor.dtype).replace('torch.', ''),
            # group_ids metadata
            'group_shape': group_tensor.shape,
            'group_bytes_len': len(group_bytes),
            'pickle_protocol': pickle.HIGHEST_PROTOCOL
        }
    else:
        meta_data = None

    # Broadcast the metadata (every rank needs it)
    meta_data = broadcast(meta_data, src=0, rank=rank, device=device)

    # ================== Phase 2: chunked transfer ==================
    # Prepare the receive buffers on every rank
    if rank == 0:
        pass  # rank 0 already holds the populated tensors
    else:
        seq_dtype = getattr(torch, meta_data['seq_dtype'])  # e.g. meta_data['seq_dtype'] == "float32"
        tgt_dtype = getattr(torch, meta_data['tgt_dtype'])
        group_tensor = torch.zeros(meta_data['group_shape'], dtype=torch.uint8, device=device)
        seq_tensor = torch.zeros(meta_data['seq_shape'], dtype=seq_dtype, device=device)
        tgt_tensor = torch.zeros(meta_data['tgt_shape'], dtype=tgt_dtype, device=device)

    # Transfer all buffers, largest first
    _chunked_broadcast(seq_tensor, src=0, rank=rank)  # largest tensor first
    _chunked_broadcast(tgt_tensor, src=0, rank=rank)
    _chunked_broadcast(group_tensor, src=0, rank=rank)  # group_ids last

    # ================== Phase 3: data reconstruction ==================
    # Rebuild sequences and targets
    sequences_list = [seq.cpu().clone() for seq in seq_tensor]  # iterating splits along dim 0
    targets_list = [tgt.cpu().clone() for tgt in tgt_tensor]

    # Rebuild group_ids (the critical step)
    if rank == 0:
        # Rank 0 reuses the original objects and avoids a redundant deserialization
        group_ids_rebuilt = group_ids
    else:
        # 1. Extract the valid bytes (drop any padding)
        group_bytes = bytes(group_tensor.cpu().numpy().tobytes()[:meta_data['group_bytes_len']])
        # 2. Deserialize
        try:
            group_ids_rebuilt = pickle.loads(group_bytes)
        except pickle.UnpicklingError as e:
            raise RuntimeError(f"Failed to deserialize group_ids: {str(e)}")
        # 3. Structural validation
        _validate_group_structure(group_ids_rebuilt)

    return sequences_list, targets_list, group_ids_rebuilt


def broadcast(data, src, rank, device):
    """Safely broadcast an arbitrary picklable object, keeping the tensors on the right device."""
    if rank == src:
        # Serialize the payload
        data_bytes = pickle.dumps(data)
        data_size = torch.tensor([len(data_bytes)], dtype=torch.long, device=device)
        # Build the byte tensor and move it to the device
        data_tensor = torch.frombuffer(bytearray(data_bytes), dtype=torch.uint8).to(device)
        # Broadcast the size first
        dist.broadcast(data_size, src=src)
        # Then broadcast the payload
        dist.broadcast(data_tensor, src=src)
        return data
    else:
        # Receive the payload size
        data_size = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(data_size, src=src)
        # Allocate the receive buffer
        data_tensor = torch.empty(data_size.item(), dtype=torch.uint8, device=device)
        dist.broadcast(data_tensor, src=src)
        # Deserialize
        data = pickle.loads(data_tensor.cpu().numpy().tobytes())
        return data
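
# Minimal usage sketch for broadcast(), assuming the process group is already initialized
# (as done in init_distributed_backend) and `rank`/`device` come from the caller:
#
#     meta = {"num_batches": 10} if rank == 0 else None
#     meta = broadcast(meta, src=0, rank=rank, device=device)
#     # after the call every rank holds {"num_batches": 10}
#
# Every rank must call it the same number of times with the same src, otherwise the
# collective broadcasts inside will deadlock.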

def _chunked_broadcast(tensor, src, rank, chunk_size=1024 * 1024 * 128):  # chunk_size is in bytes
    """Broadcast a tensor in chunks to keep individual messages at a manageable size."""
    # Step 1. Get a contiguous buffer
    buffer = tensor.detach().contiguous()
    # Step 2. Work out the sizes
    element_size = buffer.element_size()  # bytes per element (e.g. 4 for float32)
    total_elements = buffer.numel()
    # Maximum number of elements per chunk, derived from the byte budget
    elements_per_chunk = chunk_size // element_size
    # Number of chunks
    num_chunks = (total_elements + elements_per_chunk - 1) // elements_per_chunk
    # Step 3. Broadcast chunk by chunk
    for chunk_idx in range(num_chunks):
        # Element range of the current chunk
        start_element = chunk_idx * elements_per_chunk
        end_element = min((chunk_idx + 1) * elements_per_chunk, total_elements)
        # Slice the current chunk out of the flattened buffer
        chunk = buffer.view(-1).narrow(0, start_element, end_element - start_element)
        # Broadcast this chunk
        dist.broadcast(chunk, src=src)
    # Note: each chunk is one-dimensional, but the original tensor shape is recovered because of
    # (1) the strict transfer order, (2) the pre-allocated buffers on the receiving side, and
    # (3) the chunks being views into the flattened receive buffer, so writes land in the
    # original storage.


def _validate_group_structure(group_ids):
    """Validate the structural integrity of the reconstructed group_ids."""
    assert isinstance(group_ids, list), "Group IDs must be a list"
    if len(group_ids) == 0:
        print("Reconstructed group_ids has length 0")
        return
    sample = group_ids[0]
    assert isinstance(sample, tuple), "Each element must be a tuple"
    assert len(sample) == 9, "Each tuple must have length 9"


def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
    """Print the first 5 samples of each rank's data shard, one rank at a time."""
    # Synchronize all processes
    if FLAG_Distributed:
        dist.barrier(device_ids=[local_rank])
    # Print rank by rank to avoid interleaved output
    for r in range(world_size):
        if r == rank:
            print(f"\n=== Rank {rank}/{world_size} Data Shard Samples (showing first 5) ===")
            # Sequence data
            # print("[Sequences]")
            # for i, seq in enumerate(sequences[:5]):
            #     print(f"Sample {i}: {seq[:3]}...")  # only show the first 3 elements
            # Target data
            print("\n[Targets]")
            print(targets[:5])
            # Group ID distribution
            print("\n[Group IDs]")
            # unique_gids = list(set(group_ids[:50]))  # inspect the group distribution of the first 50
            print(f"First 5 GIDs: {group_ids[:5]}")
            # sys.stdout.flush()  # force immediate output
        if FLAG_Distributed:
            dist.barrier(device_ids=[local_rank])  # wait until the current rank has finished printing


if __name__ == "__main__":
    start_train()
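
# Example launch commands (a sketch; the script name and GPU count are illustrative and
# should be adapted to your setup):
#
#   # distributed run: torchrun sets LOCAL_RANK/RANK/WORLD_SIZE, which turns FLAG_Distributed on
#   torchrun --nproc_per_node=2 this_script.py
#
#   # single-process run: LOCAL_RANK is absent, so the plain CUDA/CPU code path is used
#   python this_script.py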