# main_tr.py — training entry script
  1. import os
  2. import time
  3. import gc
  4. import pandas as pd
  5. from datetime import datetime, timedelta
  6. from config import mongo_config, uo_city_pairs_new
  7. from data_loader import load_data
  8. from data_process import preprocess_data_simple
  9. from utils import merge_and_overwrite_csv
  10. def start_train():
  11. print(f"开始训练")
  12. output_dir = "./data_shards"
  13. # 确保目录存在
  14. os.makedirs(output_dir, exist_ok=True)
  15. cpu_cores = os.cpu_count() # 你的系统是72
  16. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  17. from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
  18. from_date_begin = "2026-03-17"
  19. print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
  20. uo_city_pairs = uo_city_pairs_new.copy()
  21. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  22. # 如果临时处理中断,从日志里找到 中断的索引 修改它
  23. resume_idx = 0
  24. uo_city_pair_list = uo_city_pair_list[resume_idx:]
  25. # 打印训练阶段起始索引顺序
  26. max_len = len(uo_city_pair_list) + resume_idx
  27. print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")
  28. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
  29. print(f"第 {idx} 组 :", uo_city_pair)
  30. # 加载训练数据
  31. start_time = time.time()
  32. df_train = load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  33. use_multiprocess=True, max_workers=max_workers)
  34. end_time = time.time()
  35. run_time = round(end_time - start_time, 3)
  36. print(f"用时: {run_time} 秒")
  37. if df_train.empty:
  38. print(f"训练数据为空,跳过此批次。")
  39. continue
  40. _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
  41. dedup_cols = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']
  42. drop_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv')
  43. if df_drop_nodes.empty:
  44. print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
  45. else:
  46. merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
  47. print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
  48. rise_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv')
  49. if df_rise_nodes.empty:
  50. print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
  51. else:
  52. merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
  53. print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
  54. envelope_csv_path = os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv')
  55. if df_envelope.empty:
  56. print(f"df_envelope 为空,跳过保存: {envelope_csv_path}")
  57. else:
  58. merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
  59. print(f"本批次训练已保存csv文件: {envelope_csv_path}")
  60. del df_drop_nodes
  61. del df_rise_nodes
  62. del df_envelope
  63. gc.collect()
  64. time.sleep(1)
  65. print(f"所有批次训练已完成")
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    start_train()