| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- import os
- import time
- import gc
- import pandas as pd
- from datetime import datetime, timedelta
- from config import mongo_config, uo_city_pairs_new
- from data_loader import load_data
- from data_process import preprocess_data_simple
- from utils import merge_and_overwrite_csv
def _save_nodes(df, csv_path, label, dedup_cols):
    """Merge ``df`` into ``csv_path`` via ``merge_and_overwrite_csv``, or skip if empty.

    Args:
        df: frame produced by ``preprocess_data_simple`` for one batch.
        csv_path: target CSV file path.
        label: name printed in the "empty, skipping" log line (kept so the
            log output matches the original per-frame messages exactly).
        dedup_cols: columns forwarded to ``merge_and_overwrite_csv``;
            presumably used as the de-duplication key — confirm in utils.
    """
    if df.empty:
        print(f"{label} 为空,跳过保存: {csv_path}")
        return
    merge_and_overwrite_csv(df, csv_path, dedup_cols)
    print(f"本批次训练已保存csv文件: {csv_path}")


def start_train(
    *,
    resume_idx=0,
    from_date_begin="2026-03-17",
    output_dir="./data_shards",
    max_workers_cap=8,
):
    """Load, preprocess and persist training data for every city pair.

    For each entry of ``uo_city_pairs_new`` (starting at ``resume_idx``), load
    data from ``from_date_begin`` through yesterday with ``load_data``, run
    ``preprocess_data_simple(..., is_train=True)``, and merge the resulting
    drop / rise / envelope frames into per-pair CSV files in ``output_dir``.

    Args:
        resume_idx: index of the first city pair to process. After an
            interrupted run, pass the last index printed in the log
            ("第 N 组") to resume from there.
        from_date_begin: start date passed to ``load_data``; presumably the
            same "%Y-%m-%d" format as the computed end date — confirm.
        output_dir: directory for the CSV shards; created if missing.
        max_workers_cap: upper bound on ``max_workers`` passed to
            ``load_data`` (further limited by the CPU count).

    Raises:
        ValueError: if ``resume_idx`` is negative.
    """
    if resume_idx < 0:
        # A negative slice would start from the end of the list while
        # enumerate() labels the batches with negative indices.
        raise ValueError(f"resume_idx must be >= 0, got {resume_idx}")

    print("开始训练")
    os.makedirs(output_dir, exist_ok=True)

    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to 1 so min() does not raise TypeError.
    cpu_cores = os.cpu_count() or 1
    max_workers = min(max_workers_cap, cpu_cores)

    # The date range ends yesterday (today minus one day).
    from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")

    # Insert a dash after the first three characters of each code
    # (assumes entries look like "AAABBB" -> "AAA-BBB"; verify in config).
    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs_new]
    uo_city_pair_list = uo_city_pair_list[resume_idx:]

    # Log the absolute index range that will be processed.
    max_len = len(uo_city_pair_list) + resume_idx
    print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")

    dedup_cols = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']

    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
        print(f"第 {idx} 组 :", uo_city_pair)

        # Load this pair's training data and report how long it took.
        start_time = time.perf_counter()
        df_train = load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
                             use_multiprocess=True, max_workers=max_workers)
        run_time = round(time.perf_counter() - start_time, 3)
        print(f"用时: {run_time} 秒")
        if df_train.empty:
            print("训练数据为空,跳过此批次。")
            continue

        # The first result is not used here, but it is bound to a name so it
        # can be released explicitly below instead of lingering in `_`.
        df_processed, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(
            df_train, is_train=True)

        _save_nodes(df_drop_nodes,
                    os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv'),
                    "df_drop_nodes", dedup_cols)
        _save_nodes(df_rise_nodes,
                    os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv'),
                    "df_rise_nodes", dedup_cols)
        _save_nodes(df_envelope,
                    os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv'),
                    "df_envelope", dedup_cols)

        # Drop every reference to this batch's frames, including the raw
        # input and the unused processed frame. Otherwise they stay alive
        # while the next load_data() call allocates the next batch.
        del df_train, df_processed, df_drop_nodes, df_rise_nodes, df_envelope
        gc.collect()
        # NOTE(review): fixed 1 s pause between batches; its purpose is not
        # visible here (possibly throttling the data source) — confirm.
        time.sleep(1)

    print("所有批次训练已完成")
if __name__ == "__main__":
    # Script entry point: run the full training pass with its default settings.
    start_train()
|