"""Entry point for (optionally distributed) training-data preparation.

When ``LOCAL_RANK`` is present in the environment the script initialises an
NCCL process group; otherwise it runs as a single process. Rank 0 loads the
training data for each route group from MongoDB and preprocesses it, using
Redis keys to publish its progress.
"""

import gc
import os
import pickle
import shutil
import time
import warnings
from datetime import datetime, timedelta

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import pandas as pd
import numpy as np
import redis

from data_loader import chunk_list, mongo_con_parse, load_train_data
from data_preprocess import preprocess_data
from config import (
    mongodb_config,
    vj_flight_route_list,
    vj_flight_route_list_hot,
    vj_flight_route_list_nothot,
    CLEAN_VJ_HOT_NEAR_INFO_TAB,
    CLEAN_VJ_HOT_FAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_NEAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_FAR_INFO_TAB,
)

warnings.filterwarnings('ignore')

# Distributed mode is switched on by the presence of LOCAL_RANK in the environment.
FLAG_Distributed = 'LOCAL_RANK' in os.environ

# Feature and parameter definitions.
# The categorical features match the grouping condition used for gid.
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']
common_features = [
    'hours_until_departure', 'days_to_departure', 'seats_remaining',
    'is_cross_country', 'is_transfer',
    'fly_duration', 'stop_duration',
    'flight_by_hour', 'flight_by_day', 'flight_day_of_month',
    'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
    'dep_country_is_holiday', 'arr_country_is_holiday',
    'flight_day_is_holiday', 'days_to_holiday',
]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num',
                   'cabin_level', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop


def init_distributed_backend():
    """Initialise the process group (if distributed) and return the device.

    In distributed mode this reads ``LOCAL_RANK``, ``WORLD_SIZE`` and ``RANK``
    from the environment, binds the process to its GPU and initialises an NCCL
    process group with a 30-minute timeout. Otherwise it picks CUDA when
    available and falls back to CPU.

    Returns:
        torch.device: the device this process should use.

    Raises:
        Exception: whatever ``dist.init_process_group`` raised, re-raised
            after being logged.
    """
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Key point: the device must be bound before the process group is initialised.
        torch.cuda.set_device(local_rank)  # explicitly select this process's GPU
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30),
            )
            print(f"Process group initialized for rank {dist.get_rank()}")
        except Exception as e:
            # Log for visibility, then let the failure propagate.
            print(f"Failed to initialize process group: {e}")
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Not in a distributed environment: use the default device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device


def initialize_model(device):
    """Initialise the model and related parameters (currently a stub returning None)."""
    return None


def continue_before_process(redis_client, lock_key):
    """Run rank 0's bookkeeping before it skips the current batch.

    Sets ``lock_key`` to 2 in Redis and then sleeps for five seconds.

    Args:
        redis_client: client object exposing ``set(key, value)``.
        lock_key: name of the Redis key to update.
    """
    redis_client.set(lock_key, 2)
    print("rank0 已将 Redis 锁 key 值设置为 2")
    time.sleep(5)
    print("rank0 5秒等待结束")


def _prepare_output_dirs(output_dir, photo_dir, rank):
    """Remove artefacts of a previous run and make sure the output dirs exist.

    If a previous run was interrupted and should be resumed, skip the two
    clean-up steps below so the existing files are kept.
    """
    batch_dir = os.path.join(output_dir, "batches")
    try:
        shutil.rmtree(batch_dir)
    except FileNotFoundError:
        print(f"rank:{rank}, {batch_dir} not found")

    for csv_file in ['evaluate_results.csv']:
        # Build the path outside the try so the handler can always report it.
        csv_path = os.path.join(output_dir, csv_file)
        try:
            os.remove(csv_path)
        except OSError as e:
            # Best effort: a missing or locked file must not abort the run.
            print(f"remove {csv_path}: {str(e)}")

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(photo_dir, exist_ok=True)


def _select_near_table(idx, route_len_hot, route_len_nothot):
    """Return ``(is_hot, table_name)`` for batch ``idx``, or None if out of range.

    Indices below ``route_len_hot`` map to the hot table, the following
    ``route_len_nothot`` indices map to the not-hot table.
    """
    # NOTE(review): idx is the chunk index; it presumably only equals the route
    # index while group_size == 1 - revisit if group_size changes.
    if 0 <= idx < route_len_hot:
        return 1, CLEAN_VJ_HOT_NEAR_INFO_TAB
    if route_len_hot <= idx < route_len_hot + route_len_nothot:
        return 0, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    return None


def start_train():
    """Set up the runtime and walk over the route groups batch by batch.

    For every group, rank 0 opens a MongoDB connection, loads the training
    data for the last 15 days, and preprocesses it. Empty or unclassifiable
    batches are skipped after ``continue_before_process`` has been called.
    Other ranks currently have no per-batch work.
    """
    device = init_distributed_backend()
    model = initialize_model(device)  # stub: currently None and unused

    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1

    output_dir = "./data_shards"
    photo_dir = "./photo"

    # Sample the clock once so both ends of the window share the same "today".
    today = datetime.today()
    date_end = today.strftime("%Y-%m-%d")
    date_begin = (today - timedelta(days=15)).strftime("%Y-%m-%d")

    # Work that only rank 0 performs.
    if rank == 0:
        _prepare_output_dirs(output_dir, photo_dir, rank)
        print(f"最终特征列表:{features}")

    # Optimiser and loss definition (regression only), currently disabled:
    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)

    group_size = 1
    num_epochs_per_batch = 200  # epochs per batch; tune as needed (unused so far)

    # Redis client; adjust host, port and db to the actual deployment.
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'

    batch_idx = -1

    # Main (production) route selection, currently disabled:
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)

    # Debug route selection.
    # 2025-12-08 is a holiday in the Philippines; s=38 selects Manila.
    s = 38
    hot_routes = vj_flight_route_list_hot[0:]
    nothot_routes = vj_flight_route_list_nothot[s:]
    flight_route_list = hot_routes + nothot_routes
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(hot_routes)
    route_len_nothot = len(nothot_routes)

    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")

    chunks = chunk_list(flight_route_list, group_size)

    for idx, group_route_list in enumerate(chunks, start=0):
        # Hook for special handling: skipping known-bad batches (none at present).

        # NOTE(review): every rank resets both keys here and rank 0 resets
        # lock_key again below - confirm a slow rank cannot clobber a value
        # rank 0 has already published.
        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)

        # Flag intended to be synchronised across all ranks; 1 means a valid batch.
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)

        # Non-zero ranks have no per-batch work yet.
        if rank != 0:
            continue

        # Rank 0 sets the lock key to 0: data loading has not finished yet.
        redis_client.set(lock_key, 0)
        print("rank0 开始数据加载...")

        # Connect with the default configuration; the finally clause guarantees
        # the connection is released on every exit path, including skips.
        client, db = mongo_con_parse()
        try:
            print(f"第 {idx} 组 :", group_route_list)

            # Decide between the hot and not-hot table from the batch index.
            selection = _select_near_table(idx, route_len_hot, route_len_nothot)
            if selection is None:
                print("无法确定热门还是冷门, 跳过此批次。")
                continue_before_process(redis_client, lock_key)
                continue
            is_hot, table_name = selection

            # Load the training data and report how long it took.
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name,
                                       date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"用时: {run_time} 秒")
        finally:
            client.close()

        if df_train.empty:
            print("训练数据为空,跳过此批次。")
            continue_before_process(redis_client, lock_key)
            continue

        # Preprocess the data.
        df_train_inputs = preprocess_data(df_train, features, categorical_features,
                                          is_training=True)
        print("预处理后数据样本:\n", df_train_inputs.head())

        total_rows = df_train_inputs.shape[0]


if __name__ == "__main__":
    start_train()