Przeglądaj źródła

调整导入与预测相关

node04 1 miesiąc temu
rodzic
commit
6ae8a105c3
4 zmienione pliki: 25 dodań i 18 usunięć
  1. 3 3
      data_process.py
  2. 2 2
      main_pe.py
  3. 2 2
      main_tr.py
  4. 18 11
      uo_atlas_import.py

+ 3 - 3
data_process.py

@@ -175,7 +175,7 @@ def preprocess_data_simple(df_input, is_train=False, hourly_time=None):
     return df_input, None, None, None
  
 
-def predict_data_simple(df_input, city_pair, output_dir, predict_dir=".", pred_time_str=""):
+def predict_data_simple(df_input, city_pair, object_dir, predict_dir=".", pred_time_str=""):
     if df_input is None or df_input.empty:
         return pd.DataFrame()
 
@@ -195,14 +195,14 @@ def predict_data_simple(df_input, city_pair, output_dir, predict_dir=".", pred_t
     )
 
     # 读历史降价场景
-    drop_info_csv_path = os.path.join(output_dir, f'{city_pair}_drop_info.csv')
+    drop_info_csv_path = os.path.join(object_dir, f'{city_pair}_drop_info.csv')
     if os.path.exists(drop_info_csv_path):
         df_drop_nodes = pd.read_csv(drop_info_csv_path)
     else:
         df_drop_nodes = pd.DataFrame()
 
     # 读历史升价场景
-    rise_info_csv_path = os.path.join(output_dir, f'{city_pair}_rise_info.csv')
+    rise_info_csv_path = os.path.join(object_dir, f'{city_pair}_rise_info.csv')
     if os.path.exists(rise_info_csv_path):
         df_rise_nodes = pd.read_csv(rise_info_csv_path)
     else:

+ 2 - 2
main_pe.py

@@ -9,7 +9,7 @@ from data_process import preprocess_data_simple, predict_data_simple
 def start_predict():
     print(f"开始预测")
 
-    output_dir = "./data_shards"
+    object_dir = "./data_shards"
     predict_dir = "./predictions"
 
     os.makedirs(predict_dir, exist_ok=True)
@@ -85,7 +85,7 @@ def start_predict():
 
         df_test_inputs, _, _, _,  = preprocess_data_simple(df_test, is_train=False, hourly_time=hourly_time)
 
-        df_predict = predict_data_simple(df_test_inputs, uo_city_pair, output_dir, predict_dir, hourly_time_str)
+        df_predict = predict_data_simple(df_test_inputs, uo_city_pair, object_dir, predict_dir, hourly_time_str)
         
         del df_test_inputs
         del df_predict

+ 2 - 2
main_tr.py

@@ -20,8 +20,8 @@ def start_train():
     cpu_cores = os.cpu_count()  # 你的系统是72
     max_workers = min(8, cpu_cores)  # 最大不超过8个进程
 
-    from_date_end = (datetime.today() - timedelta(days=0)).strftime("%Y-%m-%d")  # 截止日改为今
-    from_date_begin = "2026-03-27"
+    from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")  # 截止日改为昨
+    from_date_begin = "2026-03-17"  # 2026-03-17 2026-03-30
 
     print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
 

+ 18 - 11
uo_atlas_import.py

@@ -1,7 +1,7 @@
 import os
 import random
 import time
-from datetime import datetime
+from datetime import datetime, timedelta
 import pymongo
 from pymongo.errors import PyMongoError, ServerSelectionTimeoutError, BulkWriteError
 # import pandas as pd
@@ -217,27 +217,34 @@ def main_import_process(create_at_begin, create_at_end):
 
     uo_city_pairs = uo_city_pairs_new.copy()
 
+    # 调试分支
+    # uo_city_pairs = uo_city_pairs[47:48]
+
     for idx, city_pair in enumerate(uo_city_pairs):
         atlas_client, atlas_db = mongo_con_parse(atlas_config)
         mongo_client, mongo_db = mongo_con_parse(mongo_config)
+        
         print(f"开始处理航线 {idx+1}/{len(uo_city_pairs)}: {city_pair}")    
         import_flight_range_status(atlas_db, mongo_db, city_pair, create_at_begin_stamp, create_at_end_stamp)
         print(f"结束处理航线 {idx+1}/{len(uo_city_pairs)}: {city_pair}")
+
         atlas_client.close()
         mongo_client.close()
     pass
     print("整体结束")
-    print()
+    
 
 if __name__ == "__main__":
-    create_at_begin = "2026-03-27 10:00:00"
-    create_at_end = "2026-03-27 15:59:59"
-    main_import_process(create_at_begin, create_at_end)
+    print(f"本次导入开始时间: {datetime.now()}")
     
-    # try:
-    #     client, db = mongo_con_parse(mongo_atlas_config)
-    #     print(f"✅ 数据库连接创建成功")
-    # except Exception as e:
-    #     print(f"❌ 数据库连接创建失败: {e}")
-    #     db = None
+    current_time = datetime.now()
+    create_at_end = current_time.strftime("%Y-%m-%d %H:%M:%S")
+    create_at_begin = (current_time - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
+
+    # create_at_begin = "2026-03-21 00:00:00"
+    # create_at_end = "2026-03-29 23:59:59"
+
+    main_import_process(create_at_begin, create_at_end)
     
+    print(f"本次导入结束时间: {datetime.now()}")
+    print()