# main_tr.py
import gc
import os
import time
from datetime import datetime, timedelta

import pandas as pd

from config import mongo_config, uo_city_pairs_new
from data_loader import load_data
from data_process import preprocess_data_simple
from utils import merge_and_overwrite_csv
  10. def start_train():
  11. print(f"开始训练")
  12. output_dir = "./data_shards"
  13. # 确保目录存在
  14. os.makedirs(output_dir, exist_ok=True)
  15. cpu_cores = os.cpu_count() # 你的系统是72
  16. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  17. from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d") # 截止日改为昨天
  18. # from_date_begin = "2026-03-17" # 2026-03-17 2026-04-07 2026-04-09
  19. from_date_begin = "2026-04-07"
  20. print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
  21. uo_city_pairs = uo_city_pairs_new.copy()
  22. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  23. # 如果临时处理中断,从日志里找到 中断的索引 修改它
  24. resume_idx = 0
  25. uo_city_pair_list = uo_city_pair_list[resume_idx:]
  26. # 打印训练阶段起始索引顺序
  27. max_len = len(uo_city_pair_list) + resume_idx
  28. print(f"训练阶段起始索引顺序:{resume_idx} ~ {max_len - 1}")
  29. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
  30. print(f"第 {idx} 组 :", uo_city_pair)
  31. # 加载训练数据
  32. start_time = time.time()
  33. df_train = load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  34. use_multiprocess=True, max_workers=max_workers)
  35. end_time = time.time()
  36. run_time = round(end_time - start_time, 3)
  37. print(f"用时: {run_time} 秒")
  38. if df_train.empty:
  39. print(f"训练数据为空,跳过此批次。")
  40. continue
  41. _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
  42. dedup_cols = ['citypair', 'flight_numbers', 'from_date', 'baggage_weight']
  43. drop_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_drop_info.csv')
  44. if df_drop_nodes.empty:
  45. print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
  46. else:
  47. merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
  48. print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
  49. rise_info_csv_path = os.path.join(output_dir, f'{uo_city_pair}_rise_info.csv')
  50. if df_rise_nodes.empty:
  51. print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
  52. else:
  53. merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
  54. print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
  55. envelope_csv_path = os.path.join(output_dir, f'{uo_city_pair}_envelope_info.csv')
  56. if df_envelope.empty:
  57. print(f"df_envelope 为空,跳过保存: {envelope_csv_path}")
  58. else:
  59. merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
  60. print(f"本批次训练已保存csv文件: {envelope_csv_path}")
  61. del df_drop_nodes
  62. del df_rise_nodes
  63. del df_envelope
  64. gc.collect()
  65. time.sleep(1)
  66. print(f"所有批次训练已完成")
  67. if __name__ == "__main__":
  68. start_train()