import os
import time
import gc
import pandas as pd
from datetime import datetime, timedelta

from config import mongo_config, uo_city_pairs_new
from data_loader import load_data
from data_process import preprocess_data_simple
from utils import merge_and_overwrite_csv

# Columns passed to merge_and_overwrite_csv as the de-duplication key when
# merging a batch's node table into its existing per-city-pair CSV shard.
_DEDUP_COLS = ('citypair', 'flight_numbers', 'from_date', 'baggage_weight')


def _save_nodes(df, csv_path, label):
    """Merge ``df`` into the CSV at ``csv_path``, or log and skip when ``df`` is empty.

    Args:
        df: DataFrame produced by ``preprocess_data_simple`` for one city pair.
        csv_path: Destination CSV shard path.
        label: Variable name used in the "empty, skipping" log line.
    """
    if df.empty:
        print(f"{label} 为空,跳过保存: {csv_path}")
        return
    # A fresh list per call, matching the original which rebuilt the list each batch.
    merge_and_overwrite_csv(df, csv_path, list(_DEDUP_COLS))
    print(f"本批次训练已保存csv文件: {csv_path}")


def start_train(resume_idx=0, from_date_begin="2026-03-27", output_dir="./data_shards"):
    """Load, preprocess and persist training node data for every configured city pair.

    For each city pair in ``uo_city_pairs_new`` (formatted as ``"AAA-BBB"``), data
    from ``from_date_begin`` through today is loaded via ``load_data``, passed to
    ``preprocess_data_simple``, and the drop / rise / envelope node tables are
    merged into ``{output_dir}/{pair}_{drop|rise|envelope}_info.csv``.

    Args:
        resume_idx: Index into the city-pair list to start from. Use the last
            index printed in the logs to resume an interrupted run.
        from_date_begin: First date (``YYYY-MM-DD``) of the training window.
        output_dir: Directory for the CSV shards; created if missing.
    """
    print("开始训练")

    os.makedirs(output_dir, exist_ok=True)

    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to 1 so min() does not raise TypeError.
    cpu_cores = os.cpu_count() or 1
    max_workers = min(8, cpu_cores)  # never more than 8 worker processes

    # Window end is today's date.
    from_date_end = datetime.today().strftime("%Y-%m-%d")
    print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")

    # "AAABBB" -> "AAA-BBB"; then skip pairs already processed in a prior run.
    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs_new][resume_idx:]

    max_len = len(uo_city_pair_list) + resume_idx
    print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")

    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
        print(f"第 {idx} 组 :", uo_city_pair)

        start_time = time.perf_counter()
        df_train = load_data(
            mongo_config, uo_city_pair, from_date_begin, from_date_end,
            use_multiprocess=True, max_workers=max_workers,
        )
        run_time = round(time.perf_counter() - start_time, 3)
        print(f"用时: {run_time} 秒")

        # NOTE(review): the None check is defensive; load_data's contract is not visible here.
        if df_train is None or df_train.empty:
            print("训练数据为空,跳过此批次。")
            continue

        df_processed, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(
            df_train, is_train=True
        )

        _save_nodes(df_drop_nodes, os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv'),
                    "df_drop_nodes")
        _save_nodes(df_rise_nodes, os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv'),
                    "df_rise_nodes")
        _save_nodes(df_envelope, os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv'),
                    "df_envelope")

        # Drop every per-batch frame (including the raw and processed data, which
        # the original kept alive) so gc.collect() can actually reclaim them
        # before the next city pair is loaded.
        del df_train, df_processed, df_drop_nodes, df_rise_nodes, df_envelope
        gc.collect()
        time.sleep(1)

    print("所有批次训练已完成")


if __name__ == "__main__":
    start_train()