import os
import time
from datetime import datetime, timedelta

import pandas as pd

from config import (
    mongodb_config,
    vj_flight_route_list_hot,
    vj_flight_route_list_nothot,
    CLEAN_VJ_HOT_NEAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_NEAR_INFO_TAB,
)
from data_loader import load_train_data
from data_preprocess import preprocess_data_simple
from utils import chunk_list_with_index
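
# For reference, a minimal sketch of what chunk_list_with_index is assumed to do,
# inferred from its usage below (it must return (start_index, group) pairs that
# can be unpacked in the training loop). The real implementation lives in utils
# and may differ; _chunk_list_with_index_sketch is hypothetical and never called.
def _chunk_list_with_index_sketch(items, group_size):
    # Split items into consecutive groups of group_size, pairing each group
    # with the index of its first element in the original list.
    return [(start, items[start:start + group_size])
            for start in range(0, len(items), group_size)]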

def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
    key_cols = [c for c in dedup_cols if c in df_new.columns]
    # Incremental runs: if a flight_day (together with the flight key columns)
    # from df_new already appears in the historical CSV df_old, that day is
    # treated as already processed and none of its rows are appended again;
    # only rows whose flight_day (per flight key) is absent from the history
    # are added.
    if os.path.exists(csv_path):
        df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
        if key_cols and all(c in df_old.columns for c in key_cols):
            df_old_keys = df_old[key_cols].drop_duplicates()
            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
        else:
            df_add = df_new.copy()
        df_merged = pd.concat([df_old, df_add], ignore_index=True)
    # First training run: keep everything as-is, no deduplication.
    else:
        df_merged = df_new.copy()
    sort_cols = [c for c in dedup_cols if c in df_merged.columns]
    if sort_cols:
        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
    df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')
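
# A minimal, self-contained illustration of the day-level dedup above. The
# column values and the helper name _demo_merge_and_overwrite_csv are
# hypothetical, and the function is never called by the training flow.
def _demo_merge_and_overwrite_csv():
    import tempfile

    dedup_cols = ['city_pair', 'flight_day']
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'demo.csv')
        day1 = pd.DataFrame({'city_pair': ['SGN-HAN'],
                             'flight_day': ['2025-12-01'], 'price': [100]})
        merge_and_overwrite_csv(day1, path, dedup_cols)        # first run: written as-is
        day1_again = pd.DataFrame({'city_pair': ['SGN-HAN'],
                                   'flight_day': ['2025-12-01'], 'price': [120]})
        merge_and_overwrite_csv(day1_again, path, dedup_cols)  # existing key: skipped
        day2 = pd.DataFrame({'city_pair': ['SGN-HAN'],
                             'flight_day': ['2025-12-02'], 'price': [110]})
        merge_and_overwrite_csv(day2, path, dedup_cols)        # new flight_day: appended
        print(pd.read_csv(path))  # two rows, prices 100 and 110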

def start_train():
    print("Training started")
    output_dir = "./data_shards_0"
    # photo_dir = "./photo_0"
    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    # os.makedirs(photo_dir, exist_ok=True)
    cpu_cores = os.cpu_count()  # 72 on the target machine
    max_workers = min(8, cpu_cores)  # cap at 8 worker processes
    # date_end = datetime.today().strftime("%Y-%m-%d")
    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
    date_begin = "2025-12-01"
    print(f"Training date range: {date_begin} to {date_end}")
    # Main route list; the [:0] slice currently excludes all cold routes.
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    flight_route_list_len = len(flight_route_list)
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])
    print(f"flight_route_list_len: {flight_route_list_len}")
    print(f"route_len_hot: {route_len_hot}")
    print(f"route_len_nothot: {route_len_nothot}")
    group_size = 1  # number of routes per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)

    # If a run is interrupted, look up the interrupted batch index in the logs
    # and set it here to resume from that batch.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]
    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"Batch start indices for this run: {batch_starts}")
    # Training loop
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        # Ad-hoc handling: skip known-bad batches here if needed.
        # client, db = mongo_con_parse()
        print(f"Batch {i}:", group_route_list)
        # batch_flight_routes = group_route_list
        group_route_str = ','.join(group_route_list)
        # Decide hot vs. cold from the batch index position.
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("Cannot tell whether this batch is hot or cold; skipping it.")
            continue
        # Load the training data.
        start_time = time.time()
        df_train = load_train_data(mongodb_config, group_route_list, table_name,
                                   date_begin, date_end, output_dir, is_hot,
                                   use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"Elapsed: {run_time} s")
        if df_train.empty:
            print("Training data is empty; skipping this batch.")
            continue
        _, df_drop_nodes, df_keep_nodes = preprocess_data_simple(df_train, is_train=True)
        dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
        drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
        if df_drop_nodes.empty:
            print(f"df_drop_nodes is empty; skipping save: {drop_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
            print(f"Saved this batch's CSV: {drop_info_csv_path}")
        keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
        if df_keep_nodes.empty:
            print(f"df_keep_nodes is empty; skipping save: {keep_info_csv_path}")
        else:
            merge_and_overwrite_csv(df_keep_nodes, keep_info_csv_path, dedup_cols)
            print(f"Saved this batch's CSV: {keep_info_csv_path}")
        time.sleep(1)
    print("All batches trained.")


if __name__ == "__main__":
    start_train()