Parcourir la source

调整VJ的mongo查询机制, 减小查询负载

node04 il y a 1 mois
Parent
commit
1a248901aa
1 fichier modifié avec 264 ajouts et 179 suppressions
  1. data_loader.py (+264, −179)

+ 264 - 179
data_loader.py

@@ -6,6 +6,7 @@ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
 import pandas as pd
 import os
 import random
+import tempfile
 from concurrent.futures import ProcessPoolExecutor, as_completed
 import numpy as np
 import matplotlib.pyplot as plt
@@ -18,24 +19,41 @@ from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_ho
 font_path = "./simhei.ttf"
 font_prop = font_manager.FontProperties(fname=font_path)
 
-def mongo_con_parse(config=None):
+_MONGO_SHARED_CLIENT = None
+_MONGO_SHARED_DB = None
+_MONGO_SHARED_CFG_KEY = None
+
+
+def mongo_con_parse(config=None, reuse_client=False):
     if config is None:
         config = mongodb_config.copy()
 
+    global _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB, _MONGO_SHARED_CFG_KEY
+
+    cfg_key = (
+        config.get("URI", ""),
+        config.get("host", ""),
+        config.get("port", ""),
+        config.get("db", ""),
+        config.get("user", ""),
+    )
+
+    if reuse_client and _MONGO_SHARED_CLIENT is not None and _MONGO_SHARED_DB is not None and _MONGO_SHARED_CFG_KEY == cfg_key:
+        return _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB
+    
     try:
         if config.get("URI", ""):
             motor_uri = config["URI"]
             client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
             db = client[config['db']]
-            print("motor_uri: ", motor_uri)
         else:
             client = pymongo.MongoClient(
                 config['host'],
                 config['port'],
-                serverSelectionTimeoutMS=15000,  # 6秒
-                connectTimeoutMS=15000,  # 6秒
-                socketTimeoutMS=15000,  # 6秒,
-                retryReads=True,    # 开启重试
+                serverSelectionTimeoutMS=30000,
+                connectTimeoutMS=30000,
+                socketTimeoutMS=30000,
+                retryReads=True,
                 maxPoolSize=50
             )
             db = client[config['db']]
@@ -47,6 +65,12 @@ def mongo_con_parse(config=None):
     except Exception as e:
         print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
         raise
+
+    if reuse_client:
+        _MONGO_SHARED_CLIENT = client
+        _MONGO_SHARED_DB = db
+        _MONGO_SHARED_CFG_KEY = cfg_key
+
     return client, db
 
 
@@ -75,7 +99,6 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
     for attempt in range(1, max_retries + 1):
         try:
             print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
-            # 构建查询条件
             query_condition = {
                 "from_city_code": from_city,
                 "to_city_code": to_city,
@@ -83,17 +106,13 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
                     "$gte": dep_date_begin,
                     "$lte": dep_date_end,
                 },
-                # "segments.baggage": {"$in": ["-;-;-;-", "1-30"]}  # 无行李,30公斤行李
-                "segments.baggage": "-;-;-;-"
             }
-            # 动态添加航班号条件
-            for i, flight_num in enumerate(flight_nums):
-                query_condition[f"segments.{i}.flight_number"] = flight_num
-    
-            print(f"   查询条件: {query_condition}")
-            # 定义要查询的字段
+                
+            baggage_filter = 0
+            # flight_nums_filter = list(flight_nums) if flight_nums else []
+
+            print(f"   查询条件(走索引): {query_condition}")
             projection = {
-                # "_id": 1,
                 "from_city_code": 1,
                 "search_dep_time": 1,
                 "to_city_code": 1,
@@ -106,19 +125,13 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
                 "source_website": 1,
                 "crawl_date": 1
             }
-            # 执行查询
-            cursor = db.get_collection(table_name).find(
-                query_condition,
-                projection=projection  # 添加投影参数
-            ).sort(
-                [
-                    ("search_dep_time", 1),  # 多级排序要用列表+元组的格式
-                    ("segments.0.baggage", 1),
-                    ("crawl_date", 1)
-                ]
+            
+            cursor = (
+                db.get_collection(table_name)
+                .find(query_condition, projection=projection)
+                .batch_size(5000)
+                .hint('from_city_code_1_to_city_code_1_search_dep_time_1')
             )
-            if limit > 0:
-                cursor = cursor.limit(limit)
 
             # 将结果转换为列表
             results = list(cursor)
@@ -140,7 +153,24 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
                 print(f"用时: {rt} 秒")
                 print(f"📊 已将segments扩展成字段,形状: {df.shape}")
 
-                # 不用排序,因为mongo语句已经排好
+                if "baggage" in df.columns:
+                    df = df[df["baggage"] == baggage_filter]
+                
+                # for i, flight_num in enumerate(flight_nums_filter):
+                #     if flight_num is None or flight_num == "":
+                #         continue
+                #     col = f"seg{i + 1}_flight_number"
+                #     if col not in df.columns:
+                #         return pd.DataFrame()
+                #     df = df[df[col].astype("string") == str(flight_num)]
+                
+                # sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df.columns]
+                # if sort_cols:
+                #     df = df.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
+                
+                if limit > 0:
+                    df = df.head(limit).reset_index(drop=True)
+                    
                 return df
 
             else:
@@ -425,104 +455,95 @@ def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
     return df
 
 
-def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=20, max_retries=3, base_sleep=1.0):
+def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=10, max_retries=3, base_sleep=1.0):
     """
     从一组城市对中查找所有分组(航班号与起飞时间)的组合
     按:第一段航班号 → 第二段航班号 → 起飞时间 排序
-    (失败自动重试) 保证2个月内至少有20天起飞的航线
+    (失败自动重试) 保证1个月内至少有10天起飞的航线
+
+    说明:为降低 Mongo 的聚合负担,这里只做轻量 find + 投影,把“按天统计/按航班组合汇总”的逻辑放到 pandas 侧处理。
     """
     print(f"{from_city}-{to_city} 查找所有分组")
-    date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
+    date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y%m%d")
     date_end = datetime.today().strftime("%Y%m%d")
-    pipeline = [
-        # 1️⃣ 先筛选城市对
-        {
-            "$match": {
-                "from_city_code": from_city,
-                "to_city_code": to_city,
-                "search_dep_time": {
-                    "$gte": date_begin,
-                    "$lte": date_end
-                }
-            }
-        },
-        # 2️⃣ 投影字段 + 拆第一、第二段航班号用于排序
-        {
-            "$project": {
-                "flight_numbers": "$segments.flight_number",
-                "search_dep_time": 1,
-                "fn1": {"$arrayElemAt": ["$segments.flight_number", 0]},
-                "fn2": {"$arrayElemAt": ["$segments.flight_number", 1]}
-            }
-        },
-        # 3️⃣ 第一级分组:组合 + 每一天
-        {
-            "$group": {
-                "_id": {
-                    "flight_numbers": "$flight_numbers",
-                    "search_dep_time": "$search_dep_time",
-                    "fn1": "$fn1",
-                    "fn2": "$fn2"
-                },
-                "count": {"$sum": 1}
-            }
-        },
-        # 关键修复点:这里先按【时间】排好序!
-        {
-            "$sort": {
-                "_id.fn1": 1,
-                "_id.fn2": 1,
-                "_id.search_dep_time": 1  # 确保 push 进去时是按天递增
-            }
-        },
-        # 4️⃣ 第二级分组:只按【航班组合】聚合 → 统计“有多少天”
-        {
-            "$group": {
-                "_id": {
-                    "flight_numbers": "$_id.flight_numbers",
-                    "fn1": "$_id.fn1",
-                    "fn2": "$_id.fn2"
-                },
-                "days": {"$sum": 1},  # 不同起飞天数
-                "details": {
-                    "$push": {
-                        "search_dep_time": "$_id.search_dep_time",
-                        "count": "$count"
-                    }
-                }
-            }
-        },
-        # 5️⃣ 关键:按“天数阈值”过滤
-        {
-            "$match": {
-                "days": {"$gte": min_days}
-            }
-        },
-        # 6️⃣ ✅ 按“第一段 → 第二段”排序
-        {
-            "$sort": {
-                "_id.fn1": 1,
-                "_id.fn2": 1,
-            }
-        }
-    ]
+    
+    query = {
+        "from_city_code": from_city,
+        "to_city_code": to_city,
+        "search_dep_time": {"$gte": date_begin, "$lte": date_end},
+    }
+    projection = {
+        "_id": 0,
+        "search_dep_time": 1,
+        "segments.flight_number": 1,
+    }
+
+    def _extract_flight_numbers(segments):
+        if not isinstance(segments, list):
+            return []
+        out = []
+        for seg in segments:
+            if not isinstance(seg, dict):
+                continue
+            fn = seg.get("flight_number")
+            if fn:
+                out.append(fn)
+        return out
+    
     for attempt in range(1, max_retries + 1):
         try:
             print(f"  第 {attempt}/{max_retries} 次尝试查询")
 
-            # 执行聚合查询
             collection = db[table_name]
-            results = list(collection.aggregate(pipeline))
+            cursor = collection.find(query, projection=projection).batch_size(5000).hint('from_city_code_1_to_city_code_1_search_dep_time_1')
+            docs = list(cursor)
+            if not docs:
+                return []
+            
+            df = pd.DataFrame.from_records(docs)
+            if df.empty or "segments" not in df.columns or "search_dep_time" not in df.columns:
+                return []
+            
+            df["flight_numbers"] = df["segments"].apply(_extract_flight_numbers)
+            df["fn1"] = df["flight_numbers"].str[0].fillna("")
+            df["fn2"] = df["flight_numbers"].str[1].fillna("")
+            df["flight_numbers_key"] = df["flight_numbers"].apply(lambda xs: ",".join(xs) if xs else "")
+
+            day_counts = (
+                df.groupby(["flight_numbers_key", "fn1", "fn2", "search_dep_time"], dropna=False)
+                .size()
+                .reset_index(name="count")
+                .sort_values(["fn1", "fn2", "search_dep_time"], kind="mergesort")
+                .reset_index(drop=True)
+            )
+
+            keys = ["flight_numbers_key", "fn1", "fn2"]
+            df_days = day_counts.groupby(keys, sort=False).size().reset_index(name="days")
+            df_details = (
+                day_counts.groupby(keys, sort=False)
+                .apply(lambda g: g[["search_dep_time", "count"]].to_dict("records"))
+                .reset_index(name="details")
+            )
+
+            df_result = df_days.merge(df_details, on=keys, how="inner")
+            df_result = df_result[df_result["days"] >= min_days].sort_values(["fn1", "fn2"], kind="mergesort")
 
-            # 格式化结果,将 _id 中的字段提取到外层
             formatted_results = []
-            for item in results:
-                formatted_item = {
-                    "flight_numbers": item["_id"]["flight_numbers"],
-                    "days": item["days"],       # 这个组合一共有多少天
-                    "details": item["details"]  # 每一天的 count 明细
-                }
-                formatted_results.append(formatted_item)
+            for _, row in df_result.iterrows():
+                flight_numbers = row["flight_numbers_key"].split(",") if row["flight_numbers_key"] else []
+                formatted_results.append(
+                    {
+                        "flight_numbers": flight_numbers,
+                        "days": int(row["days"]),
+                        "details": row["details"],
+                    }
+                )
+
+            del df_result
+            del df_details
+            del df_days
+            del df
+            # gc.collect()
 
             return formatted_results
 
@@ -640,45 +661,55 @@ def plot_c12_trend(df, output_dir="."):
     plt.close(fig)
 
 
+_ROUTE_CACHE_DF1 = None
+_ROUTE_CACHE_DF2 = None
+
+
+def _init_route_cache_worker(df1_pickle_path, df2_pickle_path):
+    global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
+    _ROUTE_CACHE_DF1 = pd.read_pickle(df1_pickle_path)
+    _ROUTE_CACHE_DF2 = pd.read_pickle(df2_pickle_path)
+
+
+def _filter_df_by_flight_nums(df, flight_nums):
+    if df is None or df.empty:
+        return pd.DataFrame()
+    
+    out = df
+    flight_nums_filter = list(flight_nums) if flight_nums else []
+    for i, flight_num in enumerate(flight_nums_filter):
+        if flight_num is None or flight_num == "":
+            continue
+        col = f"seg{i + 1}_flight_number"
+        if col not in out.columns:
+            return out.iloc[0:0].copy()
+        out = out[out[col].astype("string") == str(flight_num)]
+        if out.empty:
+            return out
+    
+    return out
+
 def process_flight_group(args):
-    """处理单个航班号的进程函数(独立数据库连接)"""
-    process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir = args
+    """处理单个航班号的进程函数(基于主进程缓存的数据做 pandas 过滤与处理)"""
+    process_id, each_group, is_train, plot_flag, output_dir = args
     flight_nums = each_group.get("flight_numbers")
-    details = each_group.get("details")
+    # details = each_group.get("details")
 
     print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")
 
-    # 为每个进程创建独立的数据库连接
     try:
-        client, db = mongo_con_parse(db_config)
-        print(f"[进程{process_id}] ✅ 数据库连接创建成功")
-    except Exception as e:
-        print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
-        return pd.DataFrame()
+        df1 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF1, flight_nums)
+        df2 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF2, flight_nums)
+
+        sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df1.columns]
+        if sort_cols:
+            df1 = df1.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
+            df2 = df2.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
 
-    try:
-        # 查询远期表
-        if is_hot == 1:
-            df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
-                                            date_begin_s, date_end_s, flight_nums)
-        else:
-            df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
-                                            date_begin_s, date_end_s, flight_nums)
-        
-        # 保证远期表里有数据
         if df1.empty:
             print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
             return pd.DataFrame()
         
-        # 查询近期表
-        if is_hot == 1:
-            df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
-                                            date_begin_s, date_end_s, flight_nums)
-        else:
-            df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
-                                            date_begin_s, date_end_s, flight_nums)
-            
-        # 保证近期表里有数据
         if df2.empty:
             print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
             return pd.DataFrame()
@@ -781,12 +812,7 @@ def process_flight_group(args):
         print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
         return pd.DataFrame()
     finally:
-        # 确保关闭数据库连接
-        try:
-            client.close()
-            print(f"[进程{process_id}] ✅ 数据库连接已关闭")
-        except:
-            pass
+        pass
 
 
 def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, is_train=True, plot_flag=False,
@@ -805,25 +831,70 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
         print(f"开始处理航线: {route}")
 
         # 在主进程中查询航班号分组(避免多进程重复查询)
-        main_client, main_db = mongo_con_parse(db_config)
+        main_client, main_db = mongo_con_parse(db_config, reuse_client=True)
         all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
-        main_client.close()
-
+        
         all_groups_len = len(all_groups)
         print(f"该航线共有{all_groups_len}个航班号")
+
+        if all_groups_len == 0:
+            continue
+
+        # 查询远期表
+        if is_hot == 1:
+            df1 = query_flight_range_status(main_db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, None)
+        else:
+            df1 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, None)
         
-        if use_multiprocess and all_groups_len > 1:
-            print(f"启用多进程处理,最大进程数: {max_workers}")
-            # 多进程处理
-            process_args = []
-            process_id = 0
-            for each_group in all_groups:
-                process_id += 1
-                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
-                process_args.append(args)
+        # 保证远期表里有数据
+        if df1.empty:
+            print(f"[主进程] 航线:{route} 远期表无数据, 跳过")
+            # main_client.close()
+            return pd.DataFrame()
+        
+        # 查询近期表
+        if is_hot == 1:
+            df2 = query_flight_range_status(main_db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, None)
+        else:
+            df2 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
+                                            date_begin_s, date_end_s, None)
             
-            with ProcessPoolExecutor(max_workers=max_workers) as executor:
-                future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
+        # 保证近期表里有数据
+        if df2.empty:
+            print(f"[主进程] 航线:{route} 近期表无数据, 跳过")
+            # main_client.close()
+            return pd.DataFrame()
+        
+        # main_client.close()
+        
+        os.makedirs(output_dir, exist_ok=True)
+        safe_route = route.replace("-", "_")
+        df1_fd, df1_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_far_", suffix=".pkl", dir=output_dir)
+        df2_fd, df2_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_near_", suffix=".pkl", dir=output_dir)
+        os.close(df1_fd)
+        os.close(df2_fd)
+        df1.to_pickle(df1_cache_path)
+        df2.to_pickle(df2_cache_path)
+
+        try:
+            if use_multiprocess and all_groups_len > 1:
+                print(f"启用多进程处理,最大进程数: {max_workers}")
+                process_args = []
+                process_id = 0
+                for each_group in all_groups:
+                    process_id += 1
+                    args = (process_id, each_group, is_train, plot_flag, output_dir)
+                    process_args.append(args)
+                
+                with ProcessPoolExecutor(
+                    max_workers=max_workers,
+                    initializer=_init_route_cache_worker,
+                    initargs=(df1_cache_path, df2_cache_path),
+                ) as executor:
+                    future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
                 
                 for future in as_completed(future_to_group):
                     each_group = future_to_group[future]
@@ -838,23 +909,37 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
                     except Exception as e:
                         print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
 
-        else:
-            # 单进程处理(进程编号为0)
-            print("使用单进程处理")
-            process_id = 0
-            for each_group in all_groups:
-                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
-                flight_nums = each_group.get("flight_numbers", "未知")
-                try:
-                    df_mid = process_flight_group(args)
-                    if not df_mid.empty:
-                        list_all.append(df_mid)
-                        print(f"✅ 航班号:{flight_nums} 处理完成")
-                    else:
-                        print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
-                except Exception as e:
-                    print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
-                
+            else:
+                # 单进程处理(进程编号为0)
+                print("使用单进程处理")
+                global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
+                _ROUTE_CACHE_DF1 = df1
+                _ROUTE_CACHE_DF2 = df2
+
+                process_id = 0
+                for each_group in all_groups:
+                    args = (process_id, each_group, is_train, plot_flag, output_dir)
+                    flight_nums = each_group.get("flight_numbers", "未知")
+                    try:
+                        df_mid = process_flight_group(args)
+                        if not df_mid.empty:
+                            list_all.append(df_mid)
+                            print(f"✅ 航班号:{flight_nums} 处理完成")
+                        else:
+                            print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
+                    except Exception as e:
+                        print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
+
+        finally:
+            try:
+                os.remove(df1_cache_path)
+            except Exception:
+                pass
+            try:
+                os.remove(df2_cache_path)
+            except Exception:
+                pass
+            
         print(f"结束处理航线: {from_city}-{to_city}")
 
     if list_all: