main_tr.py

import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta
from utils import chunk_list_with_index, create_fixed_length_sequences
from data_loader import mongo_con_parse, load_train_data
from data_preprocess import preprocess_data, standardization
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')

# Toggle distributed mode based on whether a distributed launcher set LOCAL_RANK
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
categorical_features = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2']  # must match the gid grouping keys
common_features = ['hours_until_departure', 'days_to_departure', 'seats_remaining', 'is_cross_country', 'is_transfer',
                   'fly_duration', 'stop_duration',
                   'flight_by_hour', 'flight_by_day', 'flight_day_of_month', 'flight_day_of_week', 'flight_day_of_quarter', 'flight_day_is_weekend',
                   'dep_country_is_holiday', 'arr_country_is_holiday', 'any_country_is_holiday', 'days_to_holiday',
                   ]
price_features = ['adult_total_price', 'price_change_times_total', 'price_last_change_hours']
encoded_columns = ['from_city_num', 'to_city_num', 'flight_1_num', 'flight_2_num', 'baggage_level']
features = encoded_columns + price_features + common_features
target_vars = ['target_will_price_drop']  # whether the price will drop


# Initialize the distributed environment
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Important: the device must be bound before the process group is initialized
        torch.cuda.set_device(local_rank)  # explicitly pin this process to its GPU
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")
        except Exception as e:
            print(f"Failed to initialize process group: {e}")  # log the failure before re-raising
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Not running under a distributed launcher; fall back to the default device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("running in a non-distributed environment")
    return device
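
# Usage note (an assumption about deployment, not stated in this file): the
# distributed branch above relies on the environment variables exported by
# torchrun, so a typical multi-GPU launch would be
#     torchrun --nproc_per_node=4 main_tr.py
# which sets LOCAL_RANK, RANK and WORLD_SIZE for init_process_group.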


# Initialize the model and related parameters (currently a stub)
def initialize_model(device):
    return None
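
# A minimal sketch of what this hook could do under DDP (an illustration only;
# `PriceDropNet` is a hypothetical class, not defined in this repository):
#     def initialize_model(device):
#         model = PriceDropNet(num_features=len(features)).to(device)
#         if FLAG_Distributed:
#             model = DDP(model, device_ids=[device.index])
#         return model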


def continue_before_process(redis_client, lock_key):
    # Cleanup rank 0 performs before skipping the current batch
    redis_client.set(lock_key, 2)  # set the Redis lock key to 2
    print("rank0 set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")
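
# Lock-key protocol as observable in this file (the reader side is not shown):
#   0 -> rank 0 is loading data for the current batch
#   2 -> the current batch was skipped; other ranks should move on as well
# A value of 1 presumably signals "loading finished", but that write happens in
# code outside this excerpt.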


def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)
    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1
    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=18)).strftime("%Y-%m-%d")
    # Work that only rank 0 performs
    if rank == 0:
        # When resuming from an interruption, comment out the following cleanup
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")
        # When resuming from an interruption, comment out the following cleanup
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")
        # Make sure the output directories exist
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(photo_dir, exist_ok=True)
    print(f"Final feature list: {features}")
    # Define the optimizer and loss function (regression only)
    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1  # number of route groups per training batch
    num_epochs_per_batch = 200  # training epochs per batch; adjust as needed
    feature_scaler = None  # feature scaler, fitted per batch
    target_scaler = None  # target scaler, fitted per batch
    # Initialize the Redis client (adjust host, port and db for your deployment)
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'
    batch_idx = -1
    # Production path
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)
    # Debug path
    s = 38  # 2025-12-08 is a holiday in the Philippines; s=38 selects Manila
    flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot[:0])
    route_len_nothot = len(vj_flight_route_list_nothot[s:])
    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")
    # When resuming from an interruption, uncomment this to reload the batch order
    # with open(os.path.join(output_dir, 'order.pkl'), "rb") as f:
    #     flight_route_list = pickle.load(f)
    if rank == 0:
        # Save the batch order; comment this out when resuming from an interruption
        with open(os.path.join(output_dir, 'order.pkl'), "wb") as f:
            pickle.dump(flight_route_list, f)
    chunks = chunk_list_with_index(flight_route_list, group_size)
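    # chunk_list_with_index comes from utils; judging from how `chunks` is consumed
    # below, it is assumed to behave roughly like this sketch (the real helper may
    # differ):
    #     def chunk_list_with_index(items, size):
    #         return [(j, items[j:j + size]) for j in range(0, len(items), size)]
    # i.e. it yields (start_idx, group_route_list) pairs.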
    # Compute the total number of batches and pre-allocate the scaler lists
    if rank == 0:
        total_batches = len(chunks)
        feature_scaler_list = [None] * total_batches  # pre-allocate the list
        # target_scaler_list = [None] * total_batches  # pre-allocate the list
    # When resuming from an interruption, uncomment the block below to reload the
    # previously saved scaler lists
    # if rank == 0:
    #     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
    #     target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
    #     if os.path.exists(feature_scaler_path):
    #         # load the previous scaler list
    #         old_feature_scaler_list = joblib.load(feature_scaler_path)
    #         # previous total number of batches
    #         old_total_batches = len(old_feature_scaler_list)
    #         # only overwrite the overlapping slice
    #         min_batches = min(old_total_batches, total_batches)
    #         feature_scaler_list[:min_batches] = old_feature_scaler_list[:min_batches]
    #     if os.path.exists(target_scaler_path):
    #         # load the previous scaler list
    #         old_target_scaler_list = joblib.load(target_scaler_path)
    #         # previous total number of batches
    #         old_total_batches = len(old_target_scaler_list)
    #         # only overwrite the overlapping slice
    #         min_batches = min(old_total_batches, total_batches)
    #         target_scaler_list[:min_batches] = old_target_scaler_list[:min_batches]
    # When resuming from a temporary interruption, find the interrupted index in
    # the logs and set it here
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    if local_rank == 0:
        batch_starts = [start_idx for start_idx, _ in chunks]
        print(f"rank:{rank}, local_rank:{local_rank}, training start-index order: {batch_starts}")
    # Training phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special-case handling: skip known-bad batches (placeholder)
        pass
        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)
        # Flag tensor synchronized across all ranks
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)  # 1 marks a valid batch
        # Work that only rank 0 performs
        if rank == 0:
            # Rank 0 resets the Redis lock key to 0: data loading not yet finished
            redis_client.set(lock_key, 0)
            print("rank0 starting data loading...")
            # Use the default Mongo configuration
            client, db = mongo_con_parse()
            print(f"Group {i}:", group_route_list)
            # The index position decides whether this group is hot or not-hot
            if 0 <= i < route_len_hot:
                is_hot = 1
                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
            elif route_len_hot <= i < route_len_hot + route_len_nothot:
                is_hot = 0
                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
            else:
                print("Cannot determine hot vs. not-hot; skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed: {run_time} s")
            client.close()
            if df_train.empty:
                print("Training data is empty; skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Preprocess the data
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            print("Sample of preprocessed data:\n", df_train_inputs.head())
            total_rows = df_train_inputs.shape[0]
            print(f"Row count: {total_rows}")
            if total_rows == 0:
                print("Preprocessed training data is empty; skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Standardize and normalize
            df_train_inputs, feature_scaler, target_scaler = standardization(df_train_inputs, feature_scaler=None, target_scaler=None)
            # Store this batch's scalers in the list
            batch_idx = i
            print("batch_idx:", batch_idx)
            feature_scaler_list[batch_idx] = feature_scaler
            # target_scaler_list[batch_idx] = target_scaler
            # Persist the scaler list after every batch
            feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
            # target_scaler_path = os.path.join(output_dir, 'target_scalers.joblib')
            joblib.dump(feature_scaler_list, feature_scaler_path)
            # joblib.dump(target_scaler_list, target_scaler_path)
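            # The per-batch scalers persisted above can later be restored with
            # joblib, e.g. (a sketch, assuming the same output_dir layout):
            #     feature_scaler_list = joblib.load(os.path.join(output_dir, 'feature_scalers.joblib'))
            #     scaler_for_batch = feature_scaler_list[batch_idx]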
            # Build fixed-length sequences
            sequences, targets, group_ids = create_fixed_length_sequences(df_train_inputs, features, target_vars, input_length=452)
            pass  # training on the sequences continues beyond this excerpt
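            # Note: create_fixed_length_sequences is imported from utils; from its
            # call signature it is assumed to window each group's rows into
            # sequences of length input_length=452 and return (sequences, targets,
            # group_ids). The exact padding/stride behaviour is not visible here.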
        else:
            # Non-rank-0 workers: the polling/wait logic is elided in this excerpt
            pass


if __name__ == "__main__":
    start_train()