| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- import warnings
- import os
- import torch
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- import joblib
- import gc
- import pandas as pd
- import numpy as np
- import redis
- import time
- import pickle
- import shutil
- from datetime import datetime, timedelta
- from data_loader import chunk_list, mongo_con_parse, load_train_data
- from data_preprocess import preprocess_data
- from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
- CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Distributed switch: on exactly when LOCAL_RANK is present in the environment.
FLAG_Distributed = 'LOCAL_RANK' in os.environ
# Feature and target definitions.

# Categorical columns; these match the grouping key used for gid.
categorical_features = [
    'city_pair',
    'flight_day',
    'flight_number_1',
    'flight_number_2',
]

# Columns shared by every sample.
common_features = [
    'hours_until_departure',
    'days_to_departure',
    'seats_remaining',
    'is_cross_country',
    'is_transfer',
    'fly_duration',
    'stop_duration',
    'flight_by_hour',
    'flight_by_day',
    'flight_day_of_month',
    'flight_day_of_week',
    'flight_day_of_quarter',
    'flight_day_is_weekend',
    'dep_country_is_holiday',
    'arr_country_is_holiday',
    'flight_day_is_holiday',
    'days_to_holiday',
]

# Price-related columns.
price_features = [
    'adult_total_price',
    'price_change_times_total',
    'price_last_change_hours',
]

# Encoded columns.
encoded_columns = [
    'from_city_num',
    'to_city_num',
    'flight_1_num',
    'flight_2_num',
    'cabin_level',
    'baggage_level',
]

# Final model input columns, in order: encoded, then price, then common.
features = [*encoded_columns, *price_features, *common_features]

# Prediction target: whether the price will drop.
target_vars = ['target_will_price_drop']
# Distributed environment initialization
def init_distributed_backend():
    """Select the compute device, joining the process group when distributed.

    When ``FLAG_Distributed`` is false, no process group is created and the
    device is CUDA if available, otherwise CPU.  When it is true, the process
    is bound to the GPU named by ``LOCAL_RANK`` and an NCCL process group is
    initialized from ``WORLD_SIZE`` / ``RANK`` with a 30-minute timeout.

    Returns:
        torch.device: the device this process should use.

    Raises:
        Exception: whatever process-group initialization raised; it is
            printed and then re-raised unchanged.
    """
    if not FLAG_Distributed:
        # Not in a distributed environment: fall back to the default device.
        device_kind = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device_kind)
        print("use common environment")
        return device

    gpu_index = int(os.environ['LOCAL_RANK'])
    # Important: the device must be bound BEFORE the process group is
    # initialized, so explicitly select this process's GPU first.
    torch.cuda.set_device(gpu_index)
    try:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=int(os.environ['WORLD_SIZE']),
            rank=int(os.environ['RANK']),
            timeout=timedelta(minutes=30),
        )
        print(f"Process group initialized for rank {dist.get_rank()}")
    except Exception as e:
        # Log the failure, then let it propagate to the caller.
        print(f"Failed to initialize process group: {e}")
        raise
    return torch.device("cuda", gpu_index)
# Initialize the model and its related parameters
def initialize_model(device):
    """Placeholder model factory.

    Args:
        device: device the model would be placed on; currently unused.

    Returns:
        None: no model is constructed yet.
    """
    return None
def continue_before_process(redis_client, lock_key, wait_seconds=5):
    """Housekeeping rank 0 performs right before skipping the current batch.

    Sets the Redis lock key to 2 and then pauses for ``wait_seconds``.
    NOTE(review): the value 2 presumably tells the other ranks that this
    batch is being skipped, and the pause presumably gives them time to
    observe it -- confirm against the code that polls this key.

    Args:
        redis_client: object exposing ``set(key, value)`` (a Redis client).
        lock_key: name of the Redis key used as the data-loading lock.
        wait_seconds: how long to sleep after updating the key.  Defaults
            to 5, the previously hard-coded value.
    """
    redis_client.set(lock_key, 2)
    print("rank0 已将 Redis 锁 key 值设置为 2")
    time.sleep(wait_seconds)
    # With the default wait this prints exactly the original message.
    print(f"rank0 {wait_seconds}秒等待结束")
def start_train():
    """Drive per-route data loading and preprocessing for training.

    Steps visible in this function:
      1. Initialize the device / process group and the (placeholder) model.
      2. On rank 0, delete the previous ``batches`` directory and result CSV
         under ``./data_shards`` and create the output directories.
      3. Build the route list (currently the debug selection) and split it
         into chunks of ``group_size`` via ``chunk_list``.
      4. For every chunk, reset the Redis lock/barrier keys; on rank 0 load
         the last 15 days of data for the chunk from MongoDB and preprocess
         it.  Empty or unclassifiable chunks are skipped after calling
         ``continue_before_process``.

    The training step itself is not implemented here; several locals
    (``model``, ``world_size``, ``num_epochs_per_batch``, ``batch_idx``,
    ``valid_batch``, ``total_rows``) are prepared but not consumed yet.
    """
    device = init_distributed_backend()
    model = initialize_model(device)

    if FLAG_Distributed:
        rank = dist.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK'))
        world_size = dist.get_world_size()
    else:
        rank = 0
        local_rank = 0
        world_size = 1

    output_dir = "./data_shards"
    photo_dir = "./photo"

    # Read the clock once so both ends of the 15-day window are derived from
    # the same instant (two separate today() calls could straddle midnight).
    today = datetime.today()
    date_end = today.strftime("%Y-%m-%d")
    date_begin = (today - timedelta(days=15)).strftime("%Y-%m-%d")

    # Work done only on rank 0.
    # NOTE(review): the original indentation was lost; directory creation and
    # the feature-list print are assumed to be rank-0 only -- confirm.
    if rank == 0:
        # When resuming an interrupted run, comment out this cleanup.
        batch_dir = os.path.join(output_dir, "batches")
        try:
            shutil.rmtree(batch_dir)
        except FileNotFoundError:
            print(f"rank:{rank}, {batch_dir} not found")

        # When resuming an interrupted run, comment out this cleanup.
        csv_file_list = ['evaluate_results.csv']
        for csv_file in csv_file_list:
            # Built outside the try so the except message can always use it.
            csv_path = os.path.join(output_dir, csv_file)
            try:
                os.remove(csv_path)
            except OSError as e:
                # Best effort: a missing or undeletable file is only reported.
                print(f"remove {csv_path}: {str(e)}")

        # Make sure the output directories exist.
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(photo_dir, exist_ok=True)
        print(f"最终特征列表:{features}")

    # Optimizer and loss definition (regression only) -- not enabled yet:
    # criterion = RegressionLoss(loss_func_flag="Quantile", quantile=0.5)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    group_size = 1
    num_epochs_per_batch = 200  # epochs per batch; tune as needed

    # Redis client (adjust host, port and db to the actual deployment).
    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
    lock_key = "data_loading_lock_11"
    barrier_key = 'distributed_barrier_11'
    batch_idx = -1

    # Main configuration (disabled while debugging):
    # flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    # flight_route_list_len = len(flight_route_list)
    # route_len_hot = len(vj_flight_route_list_hot)
    # route_len_nothot = len(vj_flight_route_list_nothot)

    # Debug configuration: 2025-12-08 is a holiday in the Philippines;
    # s=38 selects the Manila route.
    s = 38
    hot_routes = vj_flight_route_list_hot[0:]
    nothot_routes = vj_flight_route_list_nothot[s:]
    flight_route_list = hot_routes + nothot_routes
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(hot_routes)
    route_len_nothot = len(nothot_routes)

    if local_rank == 0:
        print(f"flight_route_list_len:{flight_route_list_len}")
        print(f"route_len_hot:{route_len_hot}")
        print(f"route_len_nothot:{route_len_nothot}")

    chunks = chunk_list(flight_route_list, group_size)
    for idx, group_route_list in enumerate(chunks, start=0):
        # Placeholder: special handling to skip known-bad batches goes here.

        redis_client.set(lock_key, 0)
        redis_client.set(barrier_key, 0)

        # Flag meant to be synchronized across ranks; 1 means a valid batch.
        valid_batch = torch.tensor([1], dtype=torch.int, device=device)

        # Only rank 0 loads the data; other ranks have nothing to do here yet.
        if rank != 0:
            continue

        # Lock value 0 means data loading has not finished yet.
        redis_client.set(lock_key, 0)
        print("rank0 开始数据加载...")
        print(f"第 {idx} 组 :", group_route_list)

        # Decide hot vs. not-hot from the chunk index.
        # NOTE(review): comparing the chunk index against route counts looks
        # valid only while group_size == 1 -- confirm before changing it.
        if 0 <= idx < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= idx < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("无法确定热门还是冷门, 跳过此批次。")
            continue_before_process(redis_client, lock_key)
            continue

        # Open the Mongo connection only once the chunk is known to be
        # loadable, and always close it -- including when loading raises.
        client, db = mongo_con_parse()
        try:
            start_time = time.time()
            df_train = load_train_data(
                db, group_route_list, table_name,
                date_begin, date_end, output_dir, is_hot,
            )
            run_time = round(time.time() - start_time, 3)
            print(f"用时: {run_time} 秒")
        finally:
            client.close()

        if df_train.empty:
            print("训练数据为空,跳过此批次。")
            continue_before_process(redis_client, lock_key)
            continue

        # Data preprocessing.
        df_train_inputs = preprocess_data(df_train, features, categorical_features, is_training=True)
        print("预处理后数据样本:\n", df_train_inputs.head())
        total_rows = df_train_inputs.shape[0]
# Script entry point: start training only when this file is run directly.
if __name__ == "__main__":
    start_train()
|