|
@@ -1,10 +1,11 @@
|
|
|
import os
|
|
import os
|
|
|
import torch
|
|
import torch
|
|
|
import joblib
|
|
import joblib
|
|
|
-import pandas as pd
|
|
|
|
|
-import numpy as np
|
|
|
|
|
|
|
+# import pandas as pd
|
|
|
|
|
+# import numpy as np
|
|
|
import pickle
|
|
import pickle
|
|
|
import time
|
|
import time
|
|
|
|
|
+import argparse
|
|
|
from datetime import datetime, timedelta
|
|
from datetime import datetime, timedelta
|
|
|
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
|
|
from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
|
|
|
from data_loader import mongo_con_parse, load_train_data
|
|
from data_loader import mongo_con_parse, load_train_data
|
|
@@ -15,10 +16,6 @@ from predict import predict_future_distribute
|
|
|
from main_tr import features, categorical_features, target_vars
|
|
from main_tr import features, categorical_features, target_vars
|
|
|
|
|
|
|
|
|
|
|
|
|
-output_dir = "./data_shards"
|
|
|
|
|
-photo_dir = "./photo"
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
def initialize_model():
|
|
def initialize_model():
|
|
|
input_size = len(features)
|
|
input_size = len(features)
|
|
|
model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
|
|
model = PriceDropClassifiTransModel(input_size, num_periods=2, hidden_size=64, num_layers=3, output_size=1, dropout=0.2)
|
|
@@ -34,11 +31,28 @@ def convert_date_format(date_str):
|
|
|
return dt
|
|
return dt
|
|
|
# return dt.strftime('%Y%m%d%H%M00')
|
|
# return dt.strftime('%Y%m%d%H%M00')
|
|
|
|
|
|
|
|
-def start_predict():
|
|
|
|
|
|
|
+def start_predict(interval_hours):
|
|
|
|
|
+
|
|
|
|
|
+ print(f"开始预测,间隔小时数: {interval_hours}")
|
|
|
|
|
+
|
|
|
|
|
+ output_dir = "./data_shards"
|
|
|
|
|
+ photo_dir = "./photo"
|
|
|
|
|
+ predict_dir = "./predictions"
|
|
|
|
|
+
|
|
|
|
|
+ if interval_hours == 4:
|
|
|
|
|
+ output_dir = "./data_shards_4"
|
|
|
|
|
+ photo_dir = "./photo_4"
|
|
|
|
|
+ predict_dir = "./predictions_4"
|
|
|
|
|
+
|
|
|
|
|
+ elif interval_hours == 2:
|
|
|
|
|
+ output_dir = "./data_shards_2"
|
|
|
|
|
+ photo_dir = "./photo_2"
|
|
|
|
|
+ predict_dir = "./predictions_2"
|
|
|
|
|
|
|
|
# 确保目录存在
|
|
# 确保目录存在
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
os.makedirs(photo_dir, exist_ok=True)
|
|
os.makedirs(photo_dir, exist_ok=True)
|
|
|
|
|
+ os.makedirs(predict_dir, exist_ok=True)
|
|
|
|
|
|
|
|
# 清空上一次预测结果
|
|
# 清空上一次预测结果
|
|
|
# csv_file_list = ['future_predictions.csv']
|
|
# csv_file_list = ['future_predictions.csv']
|
|
@@ -61,9 +75,15 @@ def start_predict():
|
|
|
pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
|
|
pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
|
|
|
print(f"预测时间(取整): {pred_time_str}")
|
|
print(f"预测时间(取整): {pred_time_str}")
|
|
|
|
|
|
|
|
- # 预测时间范围,满足起飞时间 在28小时后到40小时后
|
|
|
|
|
|
|
+ current_n_hours = 36
|
|
|
|
|
+ if interval_hours == 4:
|
|
|
|
|
+ current_n_hours = 32
|
|
|
|
|
+ elif interval_hours == 2:
|
|
|
|
|
+ current_n_hours = 30
|
|
|
|
|
+
|
|
|
|
|
+ # 预测时间范围,满足起飞时间 在28小时后到36/32/30小时后
|
|
|
pred_hour_begin = hourly_time + timedelta(hours=28)
|
|
pred_hour_begin = hourly_time + timedelta(hours=28)
|
|
|
- pred_hour_end = hourly_time + timedelta(hours=40)
|
|
|
|
|
|
|
+ pred_hour_end = hourly_time + timedelta(hours=current_n_hours)
|
|
|
|
|
|
|
|
pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
|
|
pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
|
|
|
pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
|
|
pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
|
|
@@ -134,7 +154,7 @@ def start_predict():
|
|
|
# 创建临时字段:seg1_dep_time 的整点时间
|
|
# 创建临时字段:seg1_dep_time 的整点时间
|
|
|
df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
|
|
df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
|
|
|
# 使用整点时间进行比较过滤
|
|
# 使用整点时间进行比较过滤
|
|
|
- mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] <= pred_hour_end)
|
|
|
|
|
|
|
+ mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] < pred_hour_end)
|
|
|
original_count = len(df_test)
|
|
original_count = len(df_test)
|
|
|
df_test = df_test[mask].reset_index(drop=True)
|
|
df_test = df_test[mask].reset_index(drop=True)
|
|
|
filtered_count = len(df_test)
|
|
filtered_count = len(df_test)
|
|
@@ -147,7 +167,7 @@ def start_predict():
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
# 数据预处理
|
|
# 数据预处理
|
|
|
- df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
|
|
|
|
|
|
|
+ df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False, current_n_hours=current_n_hours)
|
|
|
|
|
|
|
|
total_rows = df_test_inputs.shape[0]
|
|
total_rows = df_test_inputs.shape[0]
|
|
|
print(f"行数: {total_rows}")
|
|
print(f"行数: {total_rows}")
|
|
@@ -165,11 +185,21 @@ def start_predict():
|
|
|
|
|
|
|
|
# 标准化与归一化处理
|
|
# 标准化与归一化处理
|
|
|
df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
|
|
df_test_inputs, feature_scaler, _ = standardization(df_test_inputs, feature_scaler, is_training=False)
|
|
|
-
|
|
|
|
|
print("标准化后数据样本:\n", df_test_inputs.head())
|
|
print("标准化后数据样本:\n", df_test_inputs.head())
|
|
|
|
|
|
|
|
|
|
+ threshold = current_n_hours
|
|
|
|
|
+ input_length = 444
|
|
|
|
|
+
|
|
|
|
|
+ # 确保 threshold 与 input_length 之和为 480
|
|
|
|
|
+ if threshold == 36:
|
|
|
|
|
+ input_length = 444
|
|
|
|
|
+ elif threshold == 32:
|
|
|
|
|
+ input_length = 448
|
|
|
|
|
+ elif threshold == 30:
|
|
|
|
|
+ input_length = 450
|
|
|
|
|
+
|
|
|
# 生成序列
|
|
# 生成序列
|
|
|
- sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, is_train=False)
|
|
|
|
|
|
|
+ sequences, _, group_ids = create_fixed_length_sequences(df_test_inputs, features, target_vars, threshold, input_length, is_train=False)
|
|
|
print(f"序列数量:{len(sequences)}")
|
|
print(f"序列数量:{len(sequences)}")
|
|
|
|
|
|
|
|
#----- 新增:智能模型加载 -----#
|
|
#----- 新增:智能模型加载 -----#
|
|
@@ -192,7 +222,7 @@ def start_predict():
|
|
|
|
|
|
|
|
target_scaler = None
|
|
target_scaler = None
|
|
|
# 预测未来数据
|
|
# 预测未来数据
|
|
|
- predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
|
|
|
|
|
|
|
+ predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, predict_dir=predict_dir, pred_time_str=pred_time_str)
|
|
|
|
|
|
|
|
print("所有批次的预测结束")
|
|
print("所有批次的预测结束")
|
|
|
print()
|
|
print()
|
|
@@ -213,4 +243,8 @@ def start_predict():
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
- start_predict()
|
|
|
|
|
|
|
+ parser = argparse.ArgumentParser(description='预测脚本')
|
|
|
|
|
+ parser.add_argument('--interval', type=int, choices=[2, 4, 8],
|
|
|
|
|
+ default=8, help='间隔小时数(2, 4, 8)')
|
|
|
|
|
+ args = parser.parse_args()
|
|
|
|
|
+ start_predict(args.interval)
|