|
@@ -0,0 +1,204 @@
|
|
|
|
|
+import warnings
|
|
|
|
|
+import os
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.distributed as dist
|
|
|
|
|
+from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
|
+import joblib
|
|
|
|
|
+import gc
|
|
|
|
|
+import pandas as pd
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import redis
|
|
|
|
|
+import time
|
|
|
|
|
+import pickle
|
|
|
|
|
+import shutil
|
|
|
|
|
+from datetime import datetime, timedelta
|
|
|
|
|
+from data_loader import chunk_list, mongo_con_parse, load_train_data
|
|
|
|
|
+from data_preprocess import preprocess_data
|
|
|
|
|
+from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
|
|
|
|
|
+ CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
|
|
|
|
|
+
|
|
|
|
|
+warnings.filterwarnings('ignore')
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# 根据环境变量的存在设置分布式开关
|
|
|
|
|
+if 'LOCAL_RANK' in os.environ:
|
|
|
|
|
+ FLAG_Distributed = True
|
|
|
|
|
+else:
|
|
|
|
|
+ FLAG_Distributed = False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# 定义特征和参数
|
|
|
|
|
+categorical_features = ['city_pair', 'flight_number_1', 'flight_number_2']
|
|
|
|
|
+other_features = []
|
|
|
|
|
+features = []
|
|
|
|
|
+
|
|
|
|
|
+target_vars = ['target_min_to_price'] # 最低会降到的价格
|
|
|
|
|
+
|
|
|
|
|
+# 分布式环境初始化
|
|
|
|
|
+def init_distributed_backend():
|
|
|
|
|
+ if FLAG_Distributed:
|
|
|
|
|
+ local_rank = int(os.environ['LOCAL_RANK'])
|
|
|
|
|
+ # 关键:绑定设备必须在初始化进程组之前
|
|
|
|
|
+ torch.cuda.set_device(local_rank) # 显式设置当前进程使用的 GPU
|
|
|
|
|
+ try:
|
|
|
|
|
+ dist.init_process_group(
|
|
|
|
|
+ backend='nccl',
|
|
|
|
|
+ init_method='env://',
|
|
|
|
|
+ world_size=int(os.environ['WORLD_SIZE']),
|
|
|
|
|
+ rank=int(os.environ['RANK']),
|
|
|
|
|
+ timeout=timedelta(minutes=30)
|
|
|
|
|
+ )
|
|
|
|
|
+ print(f"Process group initialized for rank {dist.get_rank()}") # 添加日志
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"Failed to initialize process group: {e}") # 捕获异常
|
|
|
|
|
+ raise
|
|
|
|
|
+ device = torch.device("cuda", local_rank)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 如果不在分布式环境中, 使用默认设备
|
|
|
|
|
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
+ print("use common environment")
|
|
|
|
|
+ return device
|
|
|
|
|
+
|
|
|
|
|
+# 初始化模型和相关参数
|
|
|
|
|
+def initialize_model(device):
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+def continue_before_process(redis_client, lock_key):
|
|
|
|
|
+ # rank0 跳出循环前的处理
|
|
|
|
|
+ redis_client.set(lock_key, 2) # 设置 Redis 锁 key 的值为 2
|
|
|
|
|
+ print("rank0 已将 Redis 锁 key 值设置为 2")
|
|
|
|
|
+ time.sleep(5)
|
|
|
|
|
+ print("rank0 5秒等待结束")
|
|
|
|
|
+
|
|
|
|
|
+def start_train():
|
|
|
|
|
+ device = init_distributed_backend()
|
|
|
|
|
+
|
|
|
|
|
+ model = initialize_model(device)
|
|
|
|
|
+
|
|
|
|
|
+ if FLAG_Distributed:
|
|
|
|
|
+ rank = dist.get_rank()
|
|
|
|
|
+ local_rank = int(os.environ.get('LOCAL_RANK'))
|
|
|
|
|
+ world_size = dist.get_world_size()
|
|
|
|
|
+ else:
|
|
|
|
|
+ rank = 0
|
|
|
|
|
+ local_rank = 0
|
|
|
|
|
+ world_size = 1
|
|
|
|
|
+
|
|
|
|
|
+ output_dir = "./data_shards"
|
|
|
|
|
+ photo_dir = "./photo"
|
|
|
|
|
+
|
|
|
|
|
+ date_end = datetime.today().strftime("%Y-%m-%d")
|
|
|
|
|
+ date_begin = (datetime.today() - timedelta(days=10)).strftime("%Y-%m-%d")
|
|
|
|
|
+
|
|
|
|
|
+ # 仅在 rank == 0 时要做的
|
|
|
|
|
+ if rank == 0:
|
|
|
|
|
+ # 如果处理中断, 注释掉以下代码
|
|
|
|
|
+ batch_dir = os.path.join(output_dir, "batches")
|
|
|
|
|
+ try:
|
|
|
|
|
+ shutil.rmtree(batch_dir)
|
|
|
|
|
+ except FileNotFoundError:
|
|
|
|
|
+ print(f"rank:{rank}, {batch_dir} not found")
|
|
|
|
|
+
|
|
|
|
|
+ # 如果处理中断, 注释掉以下代码
|
|
|
|
|
+ csv_file_list = ['evaluate_results.csv']
|
|
|
|
|
+ for csv_file in csv_file_list:
|
|
|
|
|
+ try:
|
|
|
|
|
+ csv_path = os.path.join(output_dir, csv_file)
|
|
|
|
|
+ os.remove(csv_path)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"remove {csv_path}: {str(e)}")
|
|
|
|
|
+
|
|
|
|
|
+ # 确保目录存在
|
|
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
+ os.makedirs(photo_dir, exist_ok=True)
|
|
|
|
|
+
|
|
|
|
|
+ print(f"最终特征列表:{features}")
|
|
|
|
|
+
|
|
|
|
|
+ # 定义优化器和损失函数(只回归)
|
|
|
|
|
+ # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
|
|
|
|
|
+ # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
|
|
|
|
|
+
|
|
|
|
|
+ group_size = 1
|
|
|
|
|
+ num_epochs_per_batch = 200 # 每个批次训练的轮数,可以根据需要调整
|
|
|
|
|
+
|
|
|
|
|
+ # 初始化 Redis 客户端(请根据实际情况修改 host、port、db)
|
|
|
|
|
+ redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
|
|
|
|
|
+ lock_key = "data_loading_lock_11"
|
|
|
|
|
+ barrier_key = 'distributed_barrier_11'
|
|
|
|
|
+
|
|
|
|
|
+ batch_idx = -1
|
|
|
|
|
+
|
|
|
|
|
+ # 主干代码
|
|
|
|
|
+ flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
|
|
|
|
|
+ flight_route_list_len = len(flight_route_list)
|
|
|
|
|
+ route_len_hot = len(vj_flight_route_list_hot)
|
|
|
|
|
+ route_len_nothot = len(vj_flight_route_list_nothot)
|
|
|
|
|
+
|
|
|
|
|
+ # 调试代码
|
|
|
|
|
+ # s = 38 # 菲律宾2025-12-08是节假日 s=38 选到马尼拉
|
|
|
|
|
+ # flight_route_list = vj_flight_route_list_hot[:0] + vj_flight_route_list_nothot[s:]
|
|
|
|
|
+ # flight_route_list_len = len(flight_route_list)
|
|
|
|
|
+ # route_len_hot = len(vj_flight_route_list_hot[:0])
|
|
|
|
|
+ # route_len_nothot = len(vj_flight_route_list_nothot[s:])
|
|
|
|
|
+
|
|
|
|
|
+ if local_rank == 0:
|
|
|
|
|
+ print(f"flight_route_list_len:{flight_route_list_len}")
|
|
|
|
|
+ print(f"route_len_hot:{route_len_hot}")
|
|
|
|
|
+ print(f"route_len_nothot:{route_len_nothot}")
|
|
|
|
|
+
|
|
|
|
|
+ chunks = chunk_list(flight_route_list, group_size)
|
|
|
|
|
+
|
|
|
|
|
+ for idx, group_route_list in enumerate(chunks, start=0):
|
|
|
|
|
+ # 特殊处理,跳过不好的批次
|
|
|
|
|
+ pass
|
|
|
|
|
+ redis_client.set(lock_key, 0)
|
|
|
|
|
+ redis_client.set(barrier_key, 0)
|
|
|
|
|
+ # 所有 Rank 同步的标志变量
|
|
|
|
|
+ valid_batch = torch.tensor([1], dtype=torch.int, device=device) # 1表示有效批次
|
|
|
|
|
+
|
|
|
|
|
+ # 仅在 rank == 0 时要做的
|
|
|
|
|
+ if rank == 0:
|
|
|
|
|
+ # Rank0 设置 Redis 锁 key 的初始值为 0,表示数据加载尚未完成
|
|
|
|
|
+ redis_client.set(lock_key, 0)
|
|
|
|
|
+ print("rank0 开始数据加载...")
|
|
|
|
|
+ # 使用默认配置
|
|
|
|
|
+ client, db = mongo_con_parse()
|
|
|
|
|
+ print(f"第 {idx} 组 :", group_route_list)
|
|
|
|
|
+
|
|
|
|
|
+ # 根据索引位置决定是 热门 还是 冷门
|
|
|
|
|
+ if 0 <= idx < route_len_hot:
|
|
|
|
|
+ is_hot = 1
|
|
|
|
|
+ table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
|
|
|
|
|
+ elif route_len_hot <= idx < route_len_hot + route_len_nothot:
|
|
|
|
|
+ is_hot = 0
|
|
|
|
|
+ table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f"无法确定热门还是冷门, 跳过此批次。")
|
|
|
|
|
+ continue_before_process(redis_client, lock_key)
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 加载训练数据
|
|
|
|
|
+ start_time = time.time()
|
|
|
|
|
+ df_train = load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
|
|
|
|
|
+ end_time = time.time()
|
|
|
|
|
+ run_time = round(end_time - start_time, 3)
|
|
|
|
|
+ print(f"用时: {run_time} 秒")
|
|
|
|
|
+
|
|
|
|
|
+ client.close()
|
|
|
|
|
+
|
|
|
|
|
+ if df_train.empty:
|
|
|
|
|
+ print(f"训练数据为空,跳过此批次。")
|
|
|
|
|
+ continue_before_process(redis_client, lock_key)
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 数据预处理
|
|
|
|
|
+ df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ start_train()
|