import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta

from utils import chunk_list_with_index, create_fixed_length_sequences
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')

# Toggle distributed mode based on whether the LOCAL_RANK environment variable is present
# (it is set by torchrun / torch.distributed.launch).
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
# These categorical features must stay consistent with the grouping keys used to build gid.
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']
common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining',
                   'is_cross_country', 'is_transfer', 'fly_duration', 'stop_duration',
                   'flight_by_hour', 'flight_by_day',
                   'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop


# Initialize the distributed environment
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Key point: the device must be bound before the process group is initialized.
        torch.cuda.set_device(local_rank)  # explicitly set the GPU used by this process
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")
        except Exception as e:
            print(f"Failed to initialize process group: {e}")
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Not in a distributed environment: use the default device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device


# Initialize the model and related parameters (stub in this snippet; the model is built elsewhere)
def initialize_model(device):
    return None


def continue_before_process(redis_client, lock_key):
    # Work rank 0 does before skipping to the next batch:
    # set the Redis lock key to 2 so the other ranks know this batch was abandoned.
    redis_client.set(lock_key, 2)
    print("rank0 has set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")


def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)

    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1

    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=18)).strftime("%Y-%m-%d")

    # Steps performed only on rank 0
    if rank == 0:
        # If resuming from an interruption, comment out the code below.
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")

        # If resuming from an interruption, comment out the code below.
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")

        # Make sure the output directories exist
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(photo_dir, exist_ok=True)
print(f"最终特征列表:{features}") # 定义优化器和损失函数(只回归) # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5) # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5) group_size = 1 # 每几组作为一个批次 num_epochs_per_batch = 200 # 每个批次训练的轮数,可以根据需要调整 feature_scaler = None # 初始化特征缩放器 target_scaler = None # 初始化目标缩放器 # 初始化 Redis 客户端(请根据实际情况修改 host、port、db) redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0) lock_key = "data_loading_lock_11" barrier_key = 'distributed_barrier_11' batch_idx = -1 # 主干代码 # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot # flight_route_list_len = len(flight_route_list) # route_len_hot = len(vj_flight_route_list_hot) # route_len_nothot = len(vj_flight_route_list_nothot) # 调试代码 s = 38 # 菲律宾2025-12-08是节假日 s=38 选到马尼拉 flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:] flight_route_list_len = len(flight_route_list) route_len_hot = len(vj_flight_route_list_hot[:0]) route_len_nothot = len(vj_flight_route_list_nothot[s:]) if local_rank == 0: print(f"flight_route_list_len:{flight_route_list_len}") print(f"route_len_hot:{route_len_hot}") print(f"route_len_nothot:{route_len_nothot}") # 如果处理中断,打开注释加载批次顺序 # with open(os.path.join(output_dir, f'order.pkl'), "rb") as f: # flight_route_list = pickle.load(f) if rank == 0: pass # 保存批次顺序, 如果处理临时中断, 将这段代码注释掉 with open(os.path.join(output_dir, f'order.pkl'), "wb") as f: pickle.dump(flight_route_list, f) chunks = chunk_list_with_index(flight_route_list, group_size) # 新增部分:计算总批次数并初始化 scaler 列表 if rank == 0: total_batches = len(chunks) feature_scaler_list = [None] * total_batches # 预分配列表空间 # target_scaler_list = [None] * total_batches # 预分配列表空间 # 中断时,打开下面注释, 临时加载一下 scaler 列表 # if rank == 0: # feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib') # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib') # if os.path.exists(feature_scaler_path): # # 加载旧的scaler列表 # old_feature_scaler_list = joblib.load(feature_scaler_path) # # 计算旧的总批次数 # old_total_batches = len(old_feature_scaler_list) # # 只替换重叠部分 # min_batches = min(old_total_batches, total_batches) # feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches] # if os.path.exists(target_scaler_path): # # 加载旧的scaler列表 # old_target_scaler_list = joblib.load(target_scaler_path) # # 计算旧的总批次数 # old_total_batches = len(old_target_scaler_list) # # 只替换重叠部分 # min_batches = min(old_total_batches, total_batches) # target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches] # 如果临时处理中断,从日志里找到 中断的索引 修改它 resume_chunk_idx = 0 chunks = chunks[resume_chunk_idx:] if local_rank == 0: batch_starts = [start_idx for start_idx, _ in chunks] print(f"rank:{rank}, local_rank:{local_rank}, 训练阶段起始索引顺序:{batch_starts}") # 训练阶段 for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx): # 特殊处理,跳过不好的批次 pass redis_client.set(lock_key, 0) redis_client.set(barrier_key, 0) # 所有 Rank 同步的标志变量 valid_batch = torch.tensor([1], dtype=torch.int, device=device) # 1表示有效批次 # 仅在 rank == 0 时要做的 if rank == 0: # Rank0 设置 Redis 锁 key 的初始值为 0,表示数据加载尚未完成 redis_client.set(lock_key, 0) print("rank0 开始数据加载...") # 使用默认配置 client, db = mongo_con_parse() print(f"第 {i} 组 :", group_route_list) # 根据索引位置决定是 热门 还是 冷门 if 0 <= i < route_len_hot: is_hot = 1 table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB elif route_len_hot <= i < route_len_hot + route_len_nothot: is_hot = 0 table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB else: print(f"无法确定热门还是冷门, 跳过此批次。") continue_before_process(redis_client, lock_key) 
                continue

            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed time: {run_time} s")
            client.close()

            if df_train.empty:
                print("Training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Data preprocessing
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            print("Sample of preprocessed data:\n", df_train_inputs.head())
            total_rows = df_train_inputs.shape[0]
            print(f"Row count: {total_rows}")
            if total_rows == 0:
                print("Preprocessed training data is empty, skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue

            # Standardization and normalization
            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs,
                                                                             feature_scaler=None,
                                                                             target_scaler=None)

            # Store the scalers in the list
            batch_idx = i
            print("batch_idx:", batch_idx)
            feature_scaler_list[batch_idx] = feature_scaler
            # target_scaler_list[batch_idx] = target_scaler

            # Save the scaler list after every batch
            feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
            # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
            joblib.dump(feature_scaler_list, feature_scaler_path)
            # joblib.dump(target_scaler_list, target_scaler_path)

            # Build fixed-length sequences
            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars,
                                                                          input_length=452)
            pass
        else:
            # Placeholder for the non-rank-0 branch (empty in this snippet)
            pass


if __name__ == "__main__":
    start_train()
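
# Usage notes (a sketch of launch commands, not part of the original script; the file name
# train.py is an assumption):
#   Distributed run, e.g. 2 GPUs on one node -- torchrun exports LOCAL_RANK/RANK/WORLD_SIZE,
#   so FLAG_Distributed becomes True and the NCCL process group is initialized:
#       torchrun --nproc_per_node=2 train.py
#   Single-process run -- LOCAL_RANK is absent, so the script falls back to one CUDA device or CPU:
#       python train.py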