Jelajahi Sumber

训练与预测细节调整

node04 3 minggu lalu
induk
melakukan
615e053f61
5 mengubah file dengan 43 tambahan dan 14 penghapusan
  1. 1 1
      data_loader.py
  2. 4 2
      main_pe.py
  3. 1 1
      main_tr.py
  4. 28 6
      predict.py
  5. 9 4
      result_validate.py

+ 1 - 1
data_loader.py

@@ -808,7 +808,7 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
         print(f"该航线共有{all_groups_len}个航班号")
         
         if use_multiprocess and all_groups_len > 1:
-            print(f"启用多线程处理,最大线程数: {max_workers}")
+            print(f"启用多进程处理,最大进程数: {max_workers}")
             # 多进程处理
             process_args = []
             process_id = 0

+ 4 - 2
main_pe.py

@@ -71,9 +71,10 @@ def start_predict(interval_hours):
 
     # 当前时间,取整时
     current_time = datetime.now() 
+    current_time_str = current_time.strftime("%Y%m%d%H%M")
     hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
     pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
-    print(f"预测时间(取整): {pred_time_str}")
+    print(f"预测时间:{current_time_str}, (取整): {pred_time_str}")
 
     current_n_hours = 36
     if interval_hours == 4:
@@ -222,7 +223,8 @@ def start_predict(interval_hours):
 
         target_scaler = None
         # 预测未来数据
-        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, predict_dir=predict_dir, pred_time_str=pred_time_str)
+        predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, 
+                                  interval_hours=interval_hours, predict_dir=predict_dir, pred_time_str=pred_time_str)
 
     print("所有批次的预测结束")
     print()

+ 1 - 1
main_tr.py

@@ -464,7 +464,7 @@ def start_train():
             y_trues_class_labels = df['Actual_Will_Price_Drop']
             y_preds_class_labels = df['Predicted_Will_Price_Drop']
 
-            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='')
+            printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='validate', batch_idx='', photo_dir=photo_dir)
 
     if FLAG_Distributed:
         dist.destroy_process_group()   # 显式调用 destroy_process_group 来清理 NCCL 的进程组资源

+ 28 - 6
predict.py

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from utils import FlightDataset
 
 
-def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, predict_dir=".", pred_time_str=""):
+def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, interval_hours=8, predict_dir=".", pred_time_str=""):
     if not sequences:
         print("没有足够的数据进行预测。")
         return
@@ -49,14 +49,36 @@ def predict_future_distribute(model, sequences, group_ids, batch_size=16, target
     })
 
     # 先转成 datetime
-    update_hour_dt = pd.to_datetime(results_df['update_hour'])   # 起飞前36小时对应时间(整点)
-    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=8)      # 起飞前28小时(36-8=28)对应时间(整点) 
+    update_hour_dt = pd.to_datetime(results_df['update_hour'])                # 起飞前36小时对应时间(整点)
+    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=interval_hours)      # 起飞前28小时(36-8=28)(32-4=28)(30-2=28)对应时间(整点) 
+    valid_end_dt = valid_begin_dt + pd.Timedelta(hours=24)                    # 起飞前4小时(28-24=4) 
 
-    # 在 probability 前新增一列
+    # 统一格式化
+    valid_begin_str = valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
+    valid_end_str = valid_end_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
+
+    # probability 列的位置
+    prob_col_idx = results_df.columns.get_loc('probability')
+    
+    # interval_hours(统一数值)
+    results_df.insert(
+        loc=prob_col_idx,
+        column='interval_hours',
+        value=interval_hours
+    )
+
+    # valid_begin_hour
     results_df.insert(
-        loc=results_df.columns.get_loc('probability'),
+        loc=prob_col_idx + 1,   # 原 probability 列的位置 加1
         column='valid_begin_hour',
-        value=valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
+        value=valid_begin_str
+    )
+
+    # valid_end_hour
+    results_df.insert(
+        loc=prob_col_idx + 2,   # 原 probability 列的位置 加2
+        column='valid_end_hour',
+        value=valid_end_str
     )
     
     # 数值处理

+ 9 - 4
result_validate.py

@@ -4,7 +4,7 @@ import pandas as pd
 from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
 
 
-def validate_process(node, pred_time_str):
+def validate_process(node, interval_hours, pred_time_str):
 
     date = pred_time_str[4:8]
 
@@ -12,6 +12,11 @@ def validate_process(node, pred_time_str):
     os.makedirs(output_dir, exist_ok=True)
 
     object_dir = "./predictions"
+    if interval_hours == 4:
+        object_dir = "./predictions_4"
+    elif interval_hours == 2:
+        object_dir = "./predictions_2"
+
     csv_file = f'future_predictions_{pred_time_str}.csv'  
     csv_path = os.path.join(object_dir, csv_file)
 
@@ -107,7 +112,7 @@ def validate_process(node, pred_time_str):
     client.close()
 
     timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    save_scv = f"result_validate_{node}_{pred_time_str}_{timestamp_str}.csv"
+    save_scv = f"result_validate_{node}_{interval_hours}_{pred_time_str}_{timestamp_str}.csv"
     
     output_path = os.path.join(output_dir, save_scv)
     df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
@@ -115,5 +120,5 @@ def validate_process(node, pred_time_str):
 
 
 if __name__ == "__main__":
-    node, pred_time_str = "node0108", "202601121000"
-    validate_process(node, pred_time_str)
+    node, interval_hours, pred_time_str = "node0112", 8, "202601141600"
+    validate_process(node, interval_hours, pred_time_str)