Procházet zdrojové kódy

调整预测策略与相关验证

node04 před 1 měsícem
rodič
revize
e89201c7bb
Změněny 3 soubory: 49 přidání a 21 odebrání
  1. 39 11
      main_pe.py
  2. 4 4
      predict.py
  3. 6 6
      result_validate.py

+ 39 - 11
main_pe.py

@@ -41,19 +41,31 @@ def start_predict():
     os.makedirs(photo_dir, exist_ok=True)
 
     # 清空上一次预测结果
-    csv_file_list = ['future_predictions.csv']
+    # csv_file_list = ['future_predictions.csv']
 
-    for csv_file in csv_file_list:
-        try:
-            csv_path = os.path.join(output_dir, csv_file)
-            os.remove(csv_path)
-        except Exception as e:
-            print(f"remove {csv_path} error: {str(e)}")
+    # for csv_file in csv_file_list:
+    #     try:
+    #         csv_path = os.path.join(output_dir, csv_file)
+    #         os.remove(csv_path)
+    #     except Exception as e:
+    #         print(f"remove {csv_path} error: {str(e)}")
 
     model, _ = initialize_model()
 
-    date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
-    date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
+    # 当前时间,取整时
+    current_time = datetime.now() 
+    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
+    pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
+
+    # 预测时间范围,满足起飞时间 在28小时后到40小时后
+    pred_hour_begin = hourly_time + timedelta(hours=28)
+    pred_hour_end = hourly_time + timedelta(hours=40)
+
+    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
+    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
+    
+    # date_end = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
+    # date_begin = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
 
     # 加载 scaler 列表
     feature_scaler_path = os.path.join(output_dir, 'feature_scalers.joblib')
@@ -102,7 +114,7 @@ def start_predict():
         
         # 加载测试数据 (仅仅是时间段取到后天)
         start_time = time.time()
-        df_test = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
         print(f"用时: {run_time} 秒")
@@ -113,6 +125,22 @@ def start_predict():
             print(f"测试数据为空,跳过此批次。")
             continue
 
+        # 按起飞时间过滤
+        # 创建临时字段:seg1_dep_time 的整点时间
+        df_test['seg1_dep_hour'] = df_test['seg1_dep_time'].dt.floor('h')
+        # 使用整点时间进行比较过滤
+        mask = (df_test['seg1_dep_hour'] >= pred_hour_begin) & (df_test['seg1_dep_hour'] <= pred_hour_end)
+        original_count = len(df_test)
+        df_test = df_test[mask].reset_index(drop=True)
+        filtered_count = len(df_test)
+        # 删除临时字段
+        df_test = df_test.drop(columns=['seg1_dep_hour'])
+        print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")
+        
+        if filtered_count == 0:
+            print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
+            continue
+
         # 数据预处理
         df_test_inputs = preprocess_data(df_test, features, categorical_features, is_training=False)
         
@@ -159,7 +187,7 @@ def start_predict():
 
         target_scaler = None
         # 预测未来数据
-        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir)
+        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
 
     print("所有批次的预测结束")
     # 所有批次的预测结束后, 统一过滤处理

+ 4 - 4
predict.py

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from utils import FlightDataset
 
 
-def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir="."):
+def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir=".", pred_time_str=""):
     if not sequences:
         print("没有足够的数据进行预测。")
         return
@@ -49,8 +49,8 @@ def predict_future_distribute(model, sequences, group_ids, batch_size=16, target
     })
 
     # 先转成 datetime
-    update_hour_dt = pd.to_datetime(results_df['update_hour'])   # 起飞前48小时对应时间(整点)
-    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=20)     # 起飞前28小时(48-20=28)对应时间(整点) 
+    update_hour_dt = pd.to_datetime(results_df['update_hour'])   # 起飞前36小时对应时间(整点)
+    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=8)      # 起飞前28小时(36-8=28)对应时间(整点) 
 
     # 在 probability 前新增一列
     results_df.insert(
@@ -65,7 +65,7 @@ def predict_future_distribute(model, sequences, group_ids, batch_size=16, target
     for col in numeric_columns:
         results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
     
-    csv_path1 = os.path.join(output_dir, 'future_predictions.csv')
+    csv_path1 = os.path.join(output_dir, f'future_predictions_{pred_time_str}.csv')
     results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
 
     print("预测结果已追加")

+ 6 - 6
result_validate.py

@@ -4,13 +4,13 @@ import pandas as pd
 from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
 
 
-def validate_process(node, date):
+def validate_process(node, date, pred_time_str):
 
     output_dir = f"./validate/{node}_{date}"
     os.makedirs(output_dir, exist_ok=True)
 
     object_dir = "./data_shards"
-    csv_file = 'future_predictions.csv'  
+    csv_file = f'future_predictions_{pred_time_str}.csv'  
     csv_path = os.path.join(object_dir, csv_file)
 
     try:
@@ -105,13 +105,13 @@ def validate_process(node, date):
     client.close()
 
     timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    save_scv = f"result_validate_{node}_{date}_{timestamp_str}.csv"
-
+    save_scv = f"result_validate_{node}_{pred_time_str}_{timestamp_str}.csv"
+    
     output_path = os.path.join(output_dir, save_scv)
     df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
     print(f"保存完成: {output_path}")
 
 
 if __name__ == "__main__":
-    node, date = "node0105", "0107"
-    validate_process(node, date)
+    node, date, pred_time_str = "node0108", "0109", "202601091100"
+    validate_process(node, date, pred_time_str)