# main_tr.py
  1. import os
  2. import time
  3. import gc
  4. import pandas as pd
  5. from datetime import datetime, timedelta
  6. from config import mongo_config, uo_city_pairs_new
  7. from data_loader import load_data
  8. from data_process import preprocess_data_simple
  9. from utils import merge_and_overwrite_csv
  10. def start_train():
  11. print(f"开始训练")
  12. output_dir = "./data_shards"
  13. # 确保目录存在
  14. os.makedirs(output_dir, exist_ok=True)
  15. cpu_cores = os.cpu_count() # 你的系统是72
  16. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  17. from_date_end = (datetime.today() + timedelta(days=1)).strftime("%Y-%m-%d") # 截止日改为明天
  18. from_date_begin = from_date_end
  19. # from_date_begin = "2026-03-17" # 2026-03-17 2026-04-30 2026-05-06
  20. # from_date_begin = "2026-05-06"
  21. print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
  22. uo_city_pairs = uo_city_pairs_new.copy()
  23. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  24. # 如果临时处理中断,从日志里找到 中断的索引 修改它
  25. resume_idx = 0
  26. uo_city_pair_list = uo_city_pair_list[resume_idx:]
  27. # 打印训练阶段起始索引顺序
  28. max_len = len(uo_city_pair_list) + resume_idx
  29. print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")
  30. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
  31. print(f"第 {idx} 组 :", uo_city_pair)
  32. # 加载训练数据
  33. start_time = time.time()
  34. df_train = load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  35. use_multiprocess=True, max_workers=max_workers)
  36. end_time = time.time()
  37. run_time = round(end_time - start_time, 3)
  38. print(f"用时: {run_time} 秒")
  39. if df_train.empty:
  40. print(f"训练数据为空,跳过此批次。")
  41. continue
  42. _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
  43. dedup_cols = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']
  44. drop_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv')
  45. if df_drop_nodes.empty:
  46. print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
  47. else:
  48. merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
  49. print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
  50. rise_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv')
  51. if df_rise_nodes.empty:
  52. print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
  53. else:
  54. merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
  55. print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
  56. envelope_csv_path = os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv')
  57. if df_envelope.empty:
  58. print(f"df_envelope 为空,跳过保存: {envelope_csv_path}")
  59. else:
  60. merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
  61. print(f"本批次训练已保存csv文件: {envelope_csv_path}")
  62. del df_drop_nodes
  63. del df_rise_nodes
  64. del df_envelope
  65. gc.collect()
  66. time.sleep(1)
  67. print(f"所有批次训练已完成")
  68. print()
# Script entry point: kick off the full training run when executed directly.
if __name__ == "__main__":
    start_train()