Jelajahi Sumber

调整mongo查询机制, 减小负载

node04 1 bulan lalu
induk
melakukan
b8b226753b
1 mengubah file dengan 75 tambahan dan 63 penghapusan
  1. data_loader.py — 75 additions, 63 deletions
      data_loader.py

+ 75 - 63
data_loader.py

@@ -1,6 +1,7 @@
 import os
 import time
 import random
+import atexit
 from datetime import datetime, timedelta
 import gc
 from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -16,6 +17,32 @@ from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pair
 font_path = "./simhei.ttf"
 font_prop = font_manager.FontProperties(fname=font_path)
 
+_worker_client = None
+_worker_db = None
+
+
+def _close_worker_mongo():
+    global _worker_client, _worker_db
+    if _worker_client is not None:
+        try:
+            _worker_client.close()
+        except Exception:
+            pass
+    _worker_client = None
+    _worker_db = None
+
+
+def init_worker_mongo(db_config):
+    global _worker_client, _worker_db
+    try:
+        _worker_client, _worker_db = mongo_con_parse(db_config)
+        atexit.register(_close_worker_mongo)
+        print("[worker] ✅ 数据库连接创建成功")
+    except Exception as e:
+        _worker_client = None
+        _worker_db = None
+        print(f"[worker] ❌ 数据库连接创建失败: {e}")
+
 
 def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
     """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
@@ -23,64 +50,48 @@ def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retrie
     date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
     date_end = datetime.today().strftime("%Y-%m-%d")
     
-    # 聚合查询管道
-    pipeline = [
-        {
-            "$match": {
+    for attempt in range(1, max_retries + 1):
+        try:
+            print(f"  第 {attempt}/{max_retries} 次尝试查询")
+
+            collection = db[table_name]
+            query = {
                 "citypair": city_pair,
                 "from_date": {
                     "$gte": date_begin,
                     "$lte": date_end
                 }
             }
-        },
-        {
-            "$group": {
-                "_id": {
-                    "flight_numbers": "$flight_numbers",
-                    "from_date": "$from_date"
-                }
-            }
-        },
-        {
-            "$group": {
-                "_id": "$_id.flight_numbers",
-                "days": {"$sum": 1},
-                "details": {"$push": "$_id.from_date"}
-            }
-        },
-        {
-            "$match": {
-                "days": {"$gte": min_days}
-            }
-        },
-        {
-            "$addFields": {
-                "details": {"$sortArray": {"input": "$details", "sortBy": 1}}
+            projection = {
+                "_id": 0,
+                "flight_numbers": 1,
+                "from_date": 1
             }
-        },
-        {
-            "$sort": {"_id": 1}
-        }
-    ]
-    for attempt in range(1, max_retries + 1):
-        try:
-            print(f"  第 {attempt}/{max_retries} 次尝试查询")
 
-            # 执行聚合查询
-            collection = db[table_name]
-            results = list(collection.aggregate(pipeline))
+            raw_rows = list(collection.find(query, projection))
+            if not raw_rows:
+                return []
 
-            # 格式化结果,使字段名更清晰
-            formatted_results = [
-                {
-                    "flight_numbers": r["_id"],
-                    "days": r["days"],
-                    "flight_dates": r["details"]
-                }
-                for r in results
-            ]
+            df = pd.DataFrame(raw_rows)
+            if df.empty or 'flight_numbers' not in df.columns or 'from_date' not in df.columns:
+                return []
+            
+            df = df.dropna(subset=['flight_numbers', 'from_date'])
+            if df.empty:
+                return []
             
+            df = df.drop_duplicates(subset=['flight_numbers', 'from_date'])
+
+            df_grouped = (
+                df.groupby('flight_numbers', as_index=False)
+                .agg(days=('from_date', 'size'), flight_dates=('from_date', lambda s: sorted(s.tolist())))
+            )
+            df_grouped = df_grouped[df_grouped['days'] >= min_days].sort_values('flight_numbers').reset_index(drop=True)
+
+            if df_grouped.empty:
+                return []
+
+            formatted_results = df_grouped[['flight_numbers', 'days', 'flight_dates']].to_dict(orient='records')
             return formatted_results
 
         except (ServerSelectionTimeoutError, PyMongoError) as e:
@@ -420,16 +431,17 @@ def process_flight_numbers(args):
     process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
     print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
     
-    # 为每个进程创建独立的数据库连接
-    try:
-        client, db = mongo_con_parse(db_config)
-        print(f"[进程{process_id}] ✅ 数据库连接创建成功")
-    except Exception as e:
-        print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
-        return pd.DataFrame()
+    local_client = None
+    db = _worker_db
+    if db is None:
+        try:
+            local_client, db = mongo_con_parse(db_config)
+            print(f"[进程{process_id}] ✅ 数据库连接创建成功")
+        except Exception as e:
+            print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
+            return pd.DataFrame()
     
     try:
-        # 查询
         df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
 
         if df1.empty:
@@ -488,12 +500,12 @@ def process_flight_numbers(args):
         print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
         return pd.DataFrame()
     finally:
-        # 确保关闭数据库连接
-        try:
-            client.close()
-            print(f"[进程{process_id}] ✅ 数据库连接已关闭")
-        except:
-            pass
+        if local_client is not None:
+            try:
+                local_client.close()
+                print(f"[进程{process_id}] ✅ 数据库连接已关闭")
+            except Exception:
+                pass
 
 
 def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.', 
@@ -519,7 +531,7 @@ def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=Tru
             args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
             process_args.append(args)
 
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        with ProcessPoolExecutor(max_workers=max_workers, initializer=init_worker_mongo, initargs=(db_config,)) as executor:
             future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
             for future in as_completed(future_to_group):
                 each_group = future_to_group[future]