main_tr_0.py

import os
import time
import pandas as pd
from datetime import datetime, timedelta
from config import (
    mongodb_config,
    vj_flight_route_list_hot,
    vj_flight_route_list_nothot,
    CLEAN_VJ_HOT_NEAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_NEAR_INFO_TAB,
)
from data_loader import load_train_data
from data_preprocess import preprocess_data_simple
from utils import chunk_list_with_index


def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
    key_cols = [c for c in dedup_cols if c in df_new.columns]
    # Subsequent training runs: if a flight_day (together with its flight key) in this
    # df_new already appears in the historical CSV df_old, treat that day as already
    # processed and append none of its rows; only append flight_days (per flight key)
    # that are absent from the history.
    if os.path.exists(csv_path):
        df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
        if key_cols and all(c in df_old.columns for c in key_cols):
            df_old_keys = df_old[key_cols].drop_duplicates()
            # Left anti-join: indicator=True adds a _merge column to the result;
            # 'left_only' keeps the rows that exist only in the left table (df_new).
            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
        else:
            df_add = df_new.copy()
        df_merged = pd.concat([df_old, df_add], ignore_index=True)
    else:
        # First training run: keep everything as-is, no deduplication needed.
        df_merged = df_new.copy()
    sort_cols = [c for c in dedup_cols if c in df_merged.columns]
    if sort_cols:
        # Re-sort so rows belonging to the same flight group sit together again.
        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
    df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')
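

# A minimal sketch of the dedup behavior above on toy data. Illustration only: the
# column values are hypothetical, the CSV goes to a throwaway temp dir, and this
# function is not called anywhere; run it manually to see the effect.
def _demo_merge_and_overwrite():
    import tempfile
    path = os.path.join(tempfile.mkdtemp(), "demo.csv")
    keys = ['city_pair', 'flight_day']
    run1 = pd.DataFrame({'city_pair': ['SGN-HAN'],
                         'flight_day': ['2026-01-27'],
                         'price': [100]})
    merge_and_overwrite_csv(run1, path, keys)  # first run: written as-is
    run2 = pd.DataFrame({'city_pair': ['SGN-HAN', 'SGN-HAN'],
                         'flight_day': ['2026-01-27', '2026-01-28'],
                         'price': [120, 90]})
    merge_and_overwrite_csv(run2, path, keys)  # only the 2026-01-28 row is appended
    print(pd.read_csv(path))  # two rows; the already-seen day keeps its old price (100)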


def start_train():
    print("Starting training")
    output_dir = "./data_shards_0"
    # photo_dir = "./photo_0"
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(photo_dir, exist_ok=True)
    cpu_cores = os.cpu_count()  # this machine has 72 cores
    max_workers = min(8, cpu_cores)  # cap at 8 worker processes
    # date_end = datetime.today().strftime("%Y-%m-%d")
    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
    date_begin = "2026-01-27"  # previously used: 2025-12-01, 2026-01-27
    print(f"Training date range: {date_begin} to {date_end}")
    # Trunk route set (cold routes excluded: the [:0] slice deliberately selects none;
    # widen the slice to re-enable them).
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])
    print(f"flight_route_list_len: {flight_route_list_len}")
    print(f"route_len_hot: {route_len_hot}")
    print(f"route_len_nothot: {route_len_nothot}")
    group_size = 1  # number of route groups per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
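    # chunk_list_with_index is imported from utils and its implementation is not shown
    # here. For reading the loop below, assume a contract like this sketch:
    #     def chunk_list_with_index(items, group_size):
    #         return [(i, items[i:i + group_size]) for i in range(0, len(items), group_size)]
    # i.e. it yields (start_index, sublist) pairs, which matches how batch_starts and
    # the loop unpack it.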
    # If a run is interrupted, find the interrupted index in the logs and set it here to resume.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Training batch start indexes: {batch_starts}")
    # Training loop
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip bad batches here if needed.
        # client, db = mongo_con_parse()
        print(f"Group {i}:", group_route_list)
        # batch_flight_routes = group_route_list
        group_route_str = ','.join(group_route_list)
        # Decide hot vs. cold by index position.
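        # Note: this comparison treats the chunk index i as a route index, which only
        # holds while group_size == 1; with larger groups the boundary check would
        # need the chunk's start_idx (currently discarded in the unpacking) instead.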
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot determine hot vs. cold, skipping this batch.")
            continue
        # Load the training data.
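        # load_train_data comes from data_loader (not shown here); judging from this
        # call site, it queries MongoDB for the given routes, table and date range,
        # optionally fanning out over max_workers processes, and returns a pandas
        # DataFrame (empty when nothing matched).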
        start_time = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        if df_train.empty:
            print("Training data is empty, skipping this batch.")
            continue
        _, df_drop_nodes, df_keep_nodes = preprocess_data_simple(df_train, is_train=True)
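        # preprocess_data_simple is from data_preprocess (not shown); as used here it
        # returns a 3-tuple, of which only the dropped-node and kept-node frames are
        # consumed below.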
        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes is empty, skipping save: {drop_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
            print(f"Saved CSV for this batch: {drop_info_csv_path}")
        keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
        if df_keep_nodes.empty:
            print(f"df_keep_nodes is empty, skipping save: {keep_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_keep_nodes, keep_info_csv_path, dedup_cols)
            print(f"Saved CSV for this batch: {keep_info_csv_path}")
        time.sleep(1)
    print("All batches have finished training")


if __name__ == "__main__":
    start_train()