main_tr.py

import warnings
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import joblib
import gc
import pandas as pd
import numpy as np
import redis
import time
import pickle
import shutil
from datetime import datetime, timedelta
from data_loader import chunk_list, mongo_con_parse, load_train_data
from data_preprocess import preprocess_data
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

warnings.filterwarnings('ignore')
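
# Distributed training driver: loads flight-route fare data from MongoDB in
# per-route batches, coordinates the ranks through Redis keys, and preprocesses
# each batch for training.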

# Enable distributed mode based on the presence of the LOCAL_RANK environment variable
if 'LOCAL_RANK' in os.environ:
    FLAG_Distributed = True
else:
    FLAG_Distributed = False

# Feature and parameter definitions
categorical_features = ['city_pair', 'flight_number_1', 'flight_number_2']
other_features = []
features = []
target_vars = ['target_min_to_price']  # the lowest price the fare will drop to


# Initialize the distributed environment
def init_distributed_backend():
    if FLAG_Distributed:
        local_rank = int(os.environ['LOCAL_RANK'])
        # Important: the device must be bound before the process group is initialized
        torch.cuda.set_device(local_rank)  # explicitly set the GPU used by this process
        try:
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=int(os.environ['WORLD_SIZE']),
                rank=int(os.environ['RANK']),
                timeout=timedelta(minutes=30)
            )
            print(f"Process group initialized for rank {dist.get_rank()}")  # log success
        except Exception as e:
            print(f"Failed to initialize process group: {e}")  # log the failure
            raise
        device = torch.device("cuda", local_rank)
    else:
        # Outside a distributed launch, fall back to the default device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("use common environment")
    return device
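
# NOTE: the 'nccl' backend requires one CUDA device per process; torch.cuda.set_device
# above pins each rank to its own GPU before the process group is created.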


# Initialize the model and related parameters
def initialize_model(device):
    return None
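# NOTE: initialize_model is currently a stub that returns None; presumably the real
# model would be built here and wrapped in DDP when FLAG_Distributed is set.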


def continue_before_process(redis_client, lock_key):
    # Clean-up that rank 0 performs before skipping to the next batch
    redis_client.set(lock_key, 2)  # set the Redis lock key to 2
    print("rank0 set the Redis lock key to 2")
    time.sleep(5)
    print("rank0 finished the 5-second wait")


def start_train():
    device = init_distributed_backend()
    model = initialize_model(device)
    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1

    output_dir = "./data_shards"
    photo_dir = "./photo"
    date_end = datetime.today().strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=10)).strftime("%Y-%m-%d")

    # Work that only rank 0 performs
    if rank == 0:
        # Comment out the following block when resuming after an interruption
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")
        # Comment out the following block when resuming after an interruption
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except Exception as e:
                print(f"remove {csv_path}: {str(e)}")
        # Make sure the output directories exist
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(photo_dir, exist_ok=True)

    print(f"Final feature list: {features}")
    # Define the optimizer and loss function (regression only)
    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1
    num_epochs_per_batch = 200  # epochs per batch; adjust as needed

    # Initialize the Redis client (adjust host, port and db to your deployment)
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'
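    # Redis coordination as used in this file: lock_key is reset to 0 before each batch is
    # loaded and set to 2 when rank 0 skips a batch; presumably another value marks a
    # completed load, and the non-zero ranks poll these keys elsewhere.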
    batch_idx = -1

    # Main workflow
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot)
    # Debug code
    # s = 38  # 2025-12-08 is a public holiday in the Philippines; s=38 selects Manila
    # flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot[:0])
    # route_len_nothot = len(vj_flight_route_list_nothot[s:])
    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")

    chunks = chunk_list(flight_route_list, group_size)
    for idx, group_route_list in enumerate(chunks, start=0):
        # Special handling: skip known-bad batches here if needed
        pass

        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)
        # Flag synchronized across all ranks; 1 marks a valid batch
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)

        # Work that only rank 0 performs
        if rank == 0:
            # Rank 0 resets the Redis lock key to 0, meaning data loading has not finished yet
            redis_client.set(lock_key, 0)
            print("rank0 starting data loading...")
            # Use the default connection configuration
            client, db = mongo_con_parse()
            print(f"Group {idx}:", group_route_list)
            # Use the group's index to decide whether it belongs to the hot or not-hot routes
            if 0 <= idx < route_len_hot:
                is_hot = 1
                table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
            elif route_len_hot <= idx < route_len_hot + route_len_nothot:
                is_hot = 0
                table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
            else:
                print("Cannot tell whether this group is hot or not-hot; skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Load the training data
            start_time = time.time()
            df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
            end_time = time.time()
            run_time = round(end_time - start_time, 3)
            print(f"Elapsed: {run_time} s")
            client.close()
            if df_train.empty:
                print("Training data is empty; skipping this batch.")
                continue_before_process(redis_client, lock_key)
                continue
            # Data preprocessing
            df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
            pass
        else:
            pass
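            # NOTE: the non-rank-0 branch currently does nothing; presumably these ranks are
            # meant to wait on lock_key / barrier_key until rank 0 finishes loading the data.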


if __name__ == "__main__":
    start_train()