|
|
@@ -41,19 +41,31 @@ def start_predict():
|
|
|
os.makedirs(photo_dir, exist_ok=True)
|
|
|
|
|
|
# 清空上一次预测结果
|
|
|
- csv_file_list = ['future_predictions.csv']
|
|
|
+ # csv_file_list = ['future_predictions.csv']
|
|
|
|
|
|
- for csv_file in csv_file_list:
|
|
|
- try:
|
|
|
- csv_path = os.path.join(output_dir, csv_file)
|
|
|
- os.remove(csv_path)
|
|
|
- except Exception as e:
|
|
|
- print(f"remove {csv_path} error: {str(e)}")
|
|
|
+ # for csv_file in csv_file_list:
|
|
|
+ # try:
|
|
|
+ # csv_path = os.path.join(output_dir, csv_file)
|
|
|
+ # os.remove(csv_path)
|
|
|
+ # except Exception as e:
|
|
|
+ # print(f"remove {csv_path} error: {str(e)}")
|
|
|
|
|
|
model, _ = initialize_model()
|
|
|
|
|
|
- date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
|
|
|
- date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
|
|
|
+ # 当前时间,取整时
|
|
|
+ current_time = datetime.now()
|
|
|
+ hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
|
|
|
+ pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
|
|
|
+
|
|
|
+ # 预测时间范围,满足起飞时间 在28小时后到40小时后
|
|
|
+ pred_hour_begin = hourly_time + timedelta(hours=28)
|
|
|
+ pred_hour_end = hourly_time + timedelta(hours=40)
|
|
|
+
|
|
|
+ pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
|
|
|
+ pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
|
|
|
+
|
|
|
+ # date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
|
|
|
+ # date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
|
|
|
|
|
|
# 加载 scaler 列表
|
|
|
feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
|
|
|
@@ -102,7 +114,7 @@ def start_predict():
|
|
|
|
|
|
# 加载测试数据 (仅仅是时间段取到后天)
|
|
|
start_time = time.time()
|
|
|
- df_test = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
|
|
|
+ df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot)
|
|
|
end_time = time.time()
|
|
|
run_time = round(end_time - start_time, 3)
|
|
|
print(f"用时: {run_time} 秒")
|
|
|
@@ -113,6 +125,22 @@ def start_predict():
|
|
|
print(f"测试数据为空,跳过此批次。")
|
|
|
continue
|
|
|
|
|
|
+ # 按起飞时间过滤
|
|
|
+ # 创建临时字段:seg1_dep_time 的整点时间
|
|
|
+ df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
|
|
|
+ # 使用整点时间进行比较过滤
|
|
|
+ mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] <= pred_hour_end)
|
|
|
+ original_count = len(df_test)
|
|
|
+ df_test = df_test[mask].reset_index(drop=True)
|
|
|
+ filtered_count = len(df_test)
|
|
|
+ # 删除临时字段
|
|
|
+ df_test = df_test.drop(columns=['seg1_dep_hour'])
|
|
|
+ print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")
|
|
|
+
|
|
|
+ if filtered_count == 0:
|
|
|
+ print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
|
|
|
+ continue
|
|
|
+
|
|
|
# 数据预处理
|
|
|
df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
|
|
|
|
|
|
@@ -159,7 +187,7 @@ def start_predict():
|
|
|
|
|
|
target_scaler = None
|
|
|
# 预测未来数据
|
|
|
- predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir)
|
|
|
+ predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
|
|
|
|
|
|
print("所有批次的预测结束")
|
|
|
# 所有批次的预测结束后, 统一过滤处理
|