# main_tr_0.py
  1. import os
  2. import time
  3. import gc
  4. import pandas as pd
  5. from datetime import datetime, timedelta
  6. from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  7. from data_loader import load_train_data
  8. from data_preprocess import preprocess_data_simple
  9. from utils import chunk_list_with_index
  10. def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
  11. key_cols = [c for c in dedup_cols if c in df_new.columns]
  12. # 若干天后的训练:如果本次 df_new 里某些 flight_day(连同航班键)在历史 CSV df_old 里已经出现过,就认为这一天已经处理过了,
  13. # 本次不再追加这一天的任何节点;只追加“历史里不存在的 flight_day(同航班键)”的数据
  14. if os.path.exists(csv_path):
  15. df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
  16. if key_cols and all(c in df_old.columns for c in key_cols):
  17. df_old_keys = df_old[key_cols].drop_duplicates()
  18. df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True) # indicator=True 会在结果df中添加一个_merge列
  19. df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge']) # left_only 只在左表(df_new)中存在
  20. else:
  21. df_add = df_new.copy()
  22. df_merged = pd.concat([df_old, df_add], ignore_index=True)
  23. # 第一次训练:直接保留,不做去重
  24. else:
  25. df_merged = df_new.copy()
  26. sort_cols = [c for c in dedup_cols if c in df_merged.columns]
  27. if sort_cols:
  28. df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True) # 重新分组排序
  29. df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')
def start_train():
    """Batch-train over all VJ flight routes (hot + not-hot).

    For each batch of routes: load raw rows from MongoDB for a fixed date
    window via ``load_train_data``, preprocess them with
    ``preprocess_data_simple`` into drop/rise/envelope node tables, and merge
    each non-empty table into a per-route CSV shard under ``output_dir``
    using ``merge_and_overwrite_csv`` (already-seen flight-day keys are kept,
    only new ones appended).
    """
    print(f"开始训练")
    output_dir = "./data_shards_0"
    # photo_dir = "./photo_0"
    # Make sure the shard output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(photo_dir, exist_ok=True)
    cpu_cores = os.cpu_count()  # per original note, this host has 72 cores
    max_workers = min(8, cpu_cores)  # cap at 8 worker processes
    # date_end = datetime.today().strftime("%Y-%m-%d")
    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
    date_begin = "2026-01-01"  # 2026-01-01 2026-03-11 2026-03-16 2026-03-18
    print(f"训练时间范围: {date_begin} 到 {date_end}")
    # Main route list: hot routes first, then not-hot routes (order matters
    # for the index-based hot/not-hot decision below).
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot)
    print(f"flight_route_list_len:{flight_route_list_len}")
    print(f"route_len_hot:{route_len_hot}")
    print(f"route_len_nothot:{route_len_nothot}")
    group_size = 1  # number of routes per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
    # If a run was interrupted, set this to the failed chunk index from the
    # logs to resume from that point.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"训练阶段起始索引顺序:{batch_starts}")
    # Training phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed.
        # client, db = mongo_con_parse()
        print(f"第 {i} 组 :", group_route_list)
        # batch_flight_routes = group_route_list
        group_route_str = ','.join(group_route_list)
        # Decide hot vs not-hot from the batch index position.
        # NOTE(review): this assumes group_size == 1 so the chunk index equals
        # the route index; with group_size > 1 the boundary check below would
        # be wrong — confirm before changing group_size.
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print(f"无法确定热门还是冷门, 跳过此批次。")
            continue
        # Load the training data for this batch (timed).
        start_time = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")
        if df_train.empty:
            print(f"训练数据为空,跳过此批次。")
            continue
        _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
        # Key columns used for dedup/sort when merging into the shard CSVs.
        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
            print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
        rise_info_csv_path = os.path.join(output_dir, f'{group_route_str}_rise_info.csv')
        if df_rise_nodes.empty:
            print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
            print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
        envelope_csv_path = os.path.join(output_dir, f'{group_route_str}_envelope_info.csv')
        if not df_envelope.empty:
            merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
            print(f"本批次训练已保存csv文件: {envelope_csv_path}")
        # Drop the per-batch frames and force a GC pass before the next big
        # load to keep peak memory down; brief pause between batches.
        del df_drop_nodes
        del df_rise_nodes
        del df_envelope
        gc.collect()
        time.sleep(1)
    print(f"所有批次训练已完成")
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    start_train()