ソースを参照

增加data_loader内容

node04 5 日 前
コミット
a02330c112
1 ファイル変更141 行追加3 行削除
  1. 141 3
      data_loader.py

+ 141 - 3
data_loader.py

@@ -65,7 +65,7 @@ def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retrie
             # 格式化结果,使字段名更清晰
             formatted_results = [
                 {
-                    "flight_number": r["_id"],
+                    "flight_numbers": r["_id"],
                     "days": r["days"],
                     "flight_dates": r["details"]
                 }
@@ -86,14 +86,152 @@ def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retrie
             time.sleep(sleep_time)
 
 
-def load_data(db_config, city_pair, from_date_begin, from_date_end):
+def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end, 
+                              limit=0, max_retries=3, base_sleep=1.0):
+    for attempt in range(1, max_retries + 1):
+        try:
+            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
+            # 构建查询条件
+            projection = {
+                # "_id": 0   # 一般会关掉
+                "citypair": 1,
+                "flight_numbers": 1,
+                "from_date": 1,
+                "from_time": 1,
+                "create_time": 1,
+                "baggage_weight": 1,
+                "cabins": 1,
+                "ticket_amount": 1,
+                "currency": 1,
+                "price_total": 1
+            }            
+            pipeline = [
+                {
+                    "$match": {
+                        "citypair": city_pair,
+                        "flight_numbers": flight_numbers,
+                        "baggage_weight": {"$in": [0, 20]},
+                        "from_date": {
+                            "$gte": from_date_begin,
+                            "$lte": from_date_end
+                        }
+                    }
+                },
+                {
+                    "$project": projection   # 就是这里
+                },
+                {
+                    "$sort": {
+                        "from_date": 1,
+                        "baggage_weight": 1, 
+                        "create_time": 1
+                    }
+                }
+            ]
+            # print(f"   查询条件: {pipeline}")
+            # 执行查询
+            collection = db[table_name]
+            results = list(collection.aggregate(pipeline))
+            print(f"✅ 查询成功,找到 {len(results)} 条记录")
+
+            if results:
+                df = pd.DataFrame(results)
+                if '_id' in df.columns:
+                    df = df.drop(columns=['_id'])
+                
+                if 'from_time' in df.columns and 'from_date' in df.columns:
+                    from_time_raw = df['from_time']
+                    from_time_str = from_time_raw.fillna('').astype(str).str.strip()
+                    non_empty = from_time_str[from_time_str.ne('')]  # 找到原始 from_time 非空的记录
+
+                    extracted_time = non_empty.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0].dropna()
+                    if not extracted_time.empty:
+                        more_time = extracted_time.value_counts().idxmax()   # 按众数分配给其它行 构造from_time
+                        missing_mask = from_time_raw.isna() | from_time_str.eq('')
+                        if missing_mask.any():
+                            df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
+                    else:
+                        # 无法得到起飞日期的抛弃
+                        return pd.DataFrame()
+
+                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
+                return df
+
+            else:
+                print("⚠️  查询结果为空")
+                return pd.DataFrame()
+
+        except (ServerSelectionTimeoutError, PyMongoError) as e:
+            print(f"⚠️ Mongo 查询失败: {e}")
+            if attempt == max_retries:
+                print("❌ 达到最大重试次数,放弃")
+                return pd.DataFrame()
+            
+            # 指数退避 + 随机抖动
+            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
+            print(f"⏳ {sleep_time:.2f}s 后重试...")
+            time.sleep(sleep_time)
+
+
+def fill_hourly_create_time(df):
+    """补齐成小时粒度数据"""
+    pass
+
+def process_flight_numbers(args):
+    process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
+    print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
+    
+    # 为每个进程创建独立的数据库连接
+    try:
+        client, db = mongo_con_parse(db_config)
+        print(f"[进程{process_id}] ✅ 数据库连接创建成功")
+    except Exception as e:
+        print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
+        return pd.DataFrame()
+    
+    try:
+        # 查询
+        df_1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
+        
+        df_f1 = fill_hourly_create_time(df_1)
+
+        
+    except Exception as e:
+        print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
+        return pd.DataFrame()
+    finally:
+        # 确保关闭数据库连接
+        try:
+            client.close()
+            print(f"[进程{process_id}] ✅ 数据库连接已关闭")
+        except:
+            pass
+
+
+def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.', 
+              use_multiprocess=False, max_workers=None):
 
     print(f"开始处理航线: {city_pair}")
     main_client, main_db = mongo_con_parse(db_config)
     all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
     main_client.close()
     
-    pass
+    all_groups_len = len(all_groups)
+    print(f"该航线共有{all_groups_len}组航班号")
+
+    print("使用单进程处理")
+    process_id = 0
+
+    for each_group in all_groups:
+        flight_numbers = each_group.get("flight_numbers", "未知")
+        args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
+        try:
+            df_mid = process_flight_numbers(args)
+            pass
+
+        except Exception as e:
+            print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
+
 
 
 if __name__ == "__main__":