# main_tr_0.py

import os
import time
import pandas as pd
from datetime import datetime, timedelta
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
from data_loader import load_train_data
from data_preprocess import preprocess_data_simple
from utils import chunk_list_with_index
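
# chunk_list_with_index is imported from utils and is not shown in this file. From the
# way it is used below (each chunk unpacks to a (start_idx, group) pair), a minimal
# sketch of its assumed behavior is:
#
#   def chunk_list_with_index(items, size):
#       return [(i, items[i:i + size]) for i in range(0, len(items), size)]
#
# i.e. it splits the route list into fixed-size groups and pairs each group with the
# index of its first element in the combined route list.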

def start_train():
    print("Starting training")
    output_dir = "./data_shards_0"
    # photo_dir = "./photo_0"
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(photo_dir, exist_ok=True)
    cpu_cores = os.cpu_count()  # 72 on this machine
    max_workers = min(8, cpu_cores)  # cap at 8 worker processes
    # date_end = datetime.today().strftime("%Y-%m-%d")
    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y-%m-%d")
    # date_begin = "2025-12-01"
    print(f"Training date range: {date_begin} to {date_end}")
    # Main route list (not-hot routes are excluded via the [:0] slice)
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])
    print(f"flight_route_list_len: {flight_route_list_len}")
    print(f"route_len_hot: {route_len_hot}")
    print(f"route_len_nothot: {route_len_nothot}")
    group_size = 1  # how many routes form one batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
    # If processing is interrupted, find the interrupted index in the logs and update this value
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Training-stage start index order: {batch_starts}")
    # Training loop
    for i, (start_idx, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed
        # client, db = mongo_con_parse()
        print(f"Batch {i}:", group_route_list)
        # batch_flight_routes = group_route_list
        group_route_str = ','.join(group_route_list)
        # Decide hot vs. not-hot from the chunk's start index in the route list
        # (start_idx rather than the loop counter i, so the check stays correct
        # when group_size > 1)
        if 0 <= start_idx < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= start_idx < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot tell whether this batch is hot or not-hot; skipping it.")
            continue
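        # load_train_data (defined in data_loader, not shown here) is assumed to query
        # MongoDB for the group's routes over [date_begin, date_end] and return the
        # result as a single pandas DataFrame, fanning the per-route reads out across
        # max_workers processes when use_multiprocess=True.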
        # Load the training data
        start_time = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        if df_train.empty:
            print("Training data is empty; skipping this batch.")
            continue
        _, df_drop_nodes = preprocess_data_simple(df_train, is_train=True)
        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes is empty; skipping save: {drop_info_csv_path}")
            continue
        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
        key_cols = [c for c in dedup_cols if c in df_drop_nodes.columns]
        # Later runs: if some flight_day values (together with the flight key columns)
        # in this df_drop_nodes already appear in the historical CSV, treat those days
        # as already processed and append nothing for them; only append rows whose
        # (flight key, flight_day) combination is absent from the history.
        if os.path.exists(drop_info_csv_path):
            df_old = pd.read_csv(drop_info_csv_path, encoding='utf-8-sig')
            if key_cols and all(c in df_old.columns for c in key_cols):
                df_old_keys = df_old[key_cols].drop_duplicates()
                df_new = df_drop_nodes.merge(df_old_keys, on=key_cols, how='left', indicator=True)
                df_new = df_new[df_new['_merge'] == 'left_only'].drop(columns=['_merge'])
            else:
                df_new = df_drop_nodes.copy()
            df_merged = pd.concat([df_old, df_new], ignore_index=True)
        # First run: keep everything as-is; no deduplication on dedup_cols
        else:
            df_merged = df_drop_nodes.copy()
        sort_cols = [c for c in dedup_cols if c in df_merged.columns]
        if sort_cols:
            df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
        df_merged.to_csv(drop_info_csv_path, index=False, encoding='utf-8-sig')
        print(f"Saved this batch's CSV: {drop_info_csv_path}")
        time.sleep(1)
    print("All training batches finished")
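
# The indicator anti-join used in start_train, in isolation. A toy illustration with
# hypothetical frames, kept here only to document the pattern:
#
#   old = pd.DataFrame({'flight_day': ['2025-01-01']})
#   new = pd.DataFrame({'flight_day': ['2025-01-01', '2025-01-02'], 'node': ['a', 'b']})
#   keys = old[['flight_day']].drop_duplicates()
#   fresh = new.merge(keys, on=['flight_day'], how='left', indicator=True)
#   fresh = fresh[fresh['_merge'] == 'left_only'].drop(columns=['_merge'])
#
# fresh keeps only the 2025-01-02 row, so re-running the script never re-appends
# flight days that already exist in the historical CSV.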

if __name__ == "__main__":
    start_train()