import os
import time
from datetime import datetime, timedelta

from config import (
    mongodb_config,
    vj_flight_route_list_hot,
    vj_flight_route_list_nothot,
    CLEAN_VJ_HOT_NEAR_INFO_TAB,
    CLEAN_VJ_NOTHOT_NEAR_INFO_TAB,
)
from data_loader import load_train_data
from data_preprocess import preprocess_data_simple, predict_data_simple
from utils import chunk_list_with_index


def start_predict():
    """Run one prediction pass over the configured flight routes.

    Steps, in order:

    1. Ensure the shard and prediction directories exist.
    2. Truncate the current time to the hour and delete any prediction CSV
       already written for that same hour.
    3. Build the departure window ``[hour + 12h, hour + 60h)``.
    4. For each route batch, load its data for the window's date range,
       keep only rows whose ``seg1_dep_time`` (floored to the hour) falls
       inside the window, preprocess them and hand them to
       ``predict_data_simple``.

    Batches with no data, or with no departures inside the window, are
    skipped. Progress is reported with ``print``. Returns ``None``.
    """
    print("开始预测")

    output_dir = "./data_shards_0"
    predict_dir = "./predictions_0"

    # Make sure the working directories exist.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(predict_dir, exist_ok=True)

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single worker instead of crashing inside min().
    cpu_cores = os.cpu_count() or 1
    max_workers = min(4, cpu_cores)  # never use more than 4 processes

    # Current time, plus the same time truncated to the hour.
    current_time = datetime.now()
    current_time_str = current_time.strftime("%Y%m%d%H%M")
    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
    print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")

    # Remove the previous prediction output produced within the same hour.
    csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
    for csv_file in csv_file_list:
        csv_path = os.path.join(predict_dir, csv_file)
        try:
            os.remove(csv_path)
        except OSError as e:
            # Best effort: a missing file (first run this hour) is expected.
            print(f"remove {csv_path} info: {str(e)}")

    # Departure window: from 12 hours to 60 hours after the truncated hour.
    pred_hour_begin = hourly_time + timedelta(hours=12)
    pred_hour_end = hourly_time + timedelta(hours=60)

    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")

    print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")

    # Main route list. The [:0] slice deliberately excludes every non-hot
    # route, so route_len_nothot is 0 and only hot routes are processed.
    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
    route_len_hot = len(vj_flight_route_list_hot)
    route_len_nothot = len(vj_flight_route_list_nothot[:0])

    group_size = 1  # number of routes per batch
    chunks = chunk_list_with_index(flight_route_list, group_size)

    # To resume from a later batch, change this starting index.
    resume_chunk_idx = 0
    chunks = chunks[resume_chunk_idx:]

    batch_starts = [start_idx for start_idx, _ in chunks]
    print(f"预测阶段起始索引顺序:{batch_starts}")

    # Prediction phase.
    for i, (_, group_route_list) in enumerate(chunks, start=resume_chunk_idx):
        print(f"第 {i} 组 :", group_route_list)
        group_route_str = ','.join(group_route_list)

        # Choose the hot or non-hot table from the batch index.
        # NOTE(review): i is the batch index, which equals the route index
        # only while group_size == 1 — revisit if group_size changes.
        if 0 <= i < route_len_hot:
            is_hot = 1
            table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif route_len_hot <= i < route_len_hot + route_len_nothot:
            is_hot = 0
            table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print("无法确定热门还是冷门, 跳过此批次。")
            continue

        # Load the data for this batch, limited to the window's date range.
        start_time = time.time()
        df_test = load_train_data(
            mongodb_config,
            group_route_list,
            table_name,
            pred_date_begin,
            pred_date_end,
            output_dir,
            is_hot,
            use_multiprocess=True,
            max_workers=max_workers,
        )
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")

        if df_test.empty:
            print("测试数据为空,跳过此批次。")
            continue

        # Filter by departure time, compared at whole-hour granularity.
        # The floored hours live in a standalone Series, so no temporary
        # column has to be added to and dropped from the frame.
        dep_hour = df_test['seg1_dep_time'].dt.floor('h')
        mask = (dep_hour >= pred_hour_begin) & (dep_hour < pred_hour_end)

        original_count = len(df_test)
        df_test = df_test[mask].reset_index(drop=True)
        filtered_count = len(df_test)

        print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")

        if filtered_count == 0:
            print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
            continue

        df_test_inputs, _, _ = preprocess_data_simple(df_test)
        # The return value is not used here, so it is not bound to a name.
        predict_data_simple(
            df_test_inputs, group_route_str, output_dir, predict_dir, hourly_time_str
        )

        # Drop the reference to the batch inputs before the next batch loads.
        del df_test_inputs

        print(f"第 {i} 组 预测完成")
        print()
        time.sleep(1)

    print("所有批次的预测结束")
    print()


if __name__ == "__main__":
    start_predict()