import os
import time
from datetime import datetime, timedelta

import pandas as pd

from config import (
    mongodb_config,
    vj_flight_route_list_hot,
    vj_flight_route_list_nothot,
    CLEAN_VJ_HOT_NEAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_NEAR_INFO_TAB,
)
from data_loader import load_train_data
from data_preprocess import preprocess_data_simple
from utils import chunk_list_with_index


def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
    """Merge new rows into the cumulative CSV at *csv_path* and rewrite it.

    Translated rationale from the original comments: on later training runs,
    if a flight_day (together with the flight key columns) in df_new already
    appears in the historical CSV, that day is considered already processed
    and none of its rows are appended again; only rows whose key combination
    is absent from history are added.

    Args:
        df_new: DataFrame of freshly produced rows.
        csv_path: Path of the cumulative CSV file (created if missing).
        dedup_cols: Candidate key columns; only those present in the frames
            are actually used for deduplication and sorting.
    """
    key_cols = [c for c in dedup_cols if c in df_new.columns]
    if os.path.exists(csv_path):
        df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
        if key_cols and all(c in df_old.columns for c in key_cols):
            # Anti-join: keep only rows of df_new whose key combination does
            # not already exist in the historical file.
            df_old_keys = df_old[key_cols].drop_duplicates()
            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
        else:
            # No usable key columns in both frames — fall back to appending all rows.
            df_add = df_new.copy()
        df_merged = pd.concat([df_old, df_add], ignore_index=True)
    else:
        # First training run: keep everything, no deduplication needed.
        df_merged = df_new.copy()
    sort_cols = [c for c in dedup_cols if c in df_merged.columns]
    if sort_cols:
        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
    df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')


def start_train():
    """Run the batch training pipeline over the configured flight routes.

    For each chunk of routes: select the hot / not-hot source table based on
    the chunk index, load training data from MongoDB, preprocess it, and merge
    the dropped / kept node info into per-route cumulative CSV files.
    """
    print("开始训练")
    output_dir = "./data_shards_0"
    # photo_dir = "./photo_0"
    # Ensure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(photo_dir, exist_ok=True)

    # os.cpu_count() may return None (per the stdlib docs); guard against it
    # so min() does not raise TypeError.
    cpu_cores = os.cpu_count() or 1
    max_workers = min(8, cpu_cores)  # cap at 8 worker processes

    # date_end = datetime.today().strftime("%Y-%m-%d")
    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
    date_begin = "2025-12-01"
    print(f"训练时间范围: {date_begin} 到 {date_end}")

    # Main route list (not-hot routes currently excluded via the empty slice).
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])
    print(f"flight_route_list_len:{flight_route_list_len}")
    print(f"route_len_hot:{route_len_hot}")
    print(f"route_len_nothot:{route_len_nothot}")

    group_size = 1  # number of route groups per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)
    # If processing was interrupted, find the failed index in the log and set it here.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"训练阶段起始索引顺序:{batch_starts}")

    # Training phase
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Special handling: skip known-bad batches here if needed.
        # client, db = mongo_con_parse()
        print(f"第 {i} 组 :", group_route_list)
        # batch_flight_routes = group_route_list
        group_route_str = ','.join(group_route_list)

        # Decide hot vs. not-hot from the batch index position.
        # NOTE(review): this mapping assumes group_size == 1 so that the batch
        # index i aligns one-to-one with the route index — confirm before
        # increasing group_size.
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("无法确定热门还是冷门, 跳过此批次。")
            continue

        # Load training data and time the load.
        start_time = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name,
                                   date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")
        if df_train.empty:
            print("训练数据为空,跳过此批次。")
            continue

        _, df_drop_nodes, df_keep_nodes = preprocess_data_simple(df_train, is_train=True)
        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']

        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes 为空,跳过保存: {drop_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
            print(f"本批次训练已保存csv文件: {drop_info_csv_path}")

        keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
        if df_keep_nodes.empty:
            print(f"df_keep_nodes 为空,跳过保存: {keep_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_keep_nodes, keep_info_csv_path, dedup_cols)
            print(f"本批次训练已保存csv文件: {keep_info_csv_path}")

        time.sleep(1)

    print("所有批次训练已完成")


if __name__ == "__main__":
    start_train()