|
|
@@ -6,6 +6,7 @@ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
|
|
|
import pandas as pd
|
|
|
import os
|
|
|
import random
|
|
|
+import tempfile
|
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
@@ -18,24 +19,41 @@ from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_ho
|
|
|
font_path = "./simhei.ttf"
|
|
|
font_prop = font_manager.FontProperties(fname=font_path)
|
|
|
|
|
|
-def mongo_con_parse(config=None):
|
|
|
+_MONGO_SHARED_CLIENT = None
|
|
|
+_MONGO_SHARED_DB = None
|
|
|
+_MONGO_SHARED_CFG_KEY = None
|
|
|
+
|
|
|
+
|
|
|
+def mongo_con_parse(config=None, reuse_client=False):
|
|
|
if config is None:
|
|
|
config = mongodb_config.copy()
|
|
|
|
|
|
+ global _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB, _MONGO_SHARED_CFG_KEY
|
|
|
+
|
|
|
+ cfg_key = (
|
|
|
+ config.get("URI", ""),
|
|
|
+ config.get("host", ""),
|
|
|
+ config.get("port", ""),
|
|
|
+ config.get("db", ""),
|
|
|
+ config.get("user", ""),
|
|
|
+ )
|
|
|
+
|
|
|
+ if reuse_client and _MONGO_SHARED_CLIENT is not None and _MONGO_SHARED_DB is not None and _MONGO_SHARED_CFG_KEY == cfg_key:
|
|
|
+ return _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB
|
|
|
+
|
|
|
try:
|
|
|
if config.get("URI", ""):
|
|
|
motor_uri = config["URI"]
|
|
|
client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
|
|
|
db = client[config['db']]
|
|
|
- print("motor_uri: ", motor_uri)
|
|
|
else:
|
|
|
client = pymongo.MongoClient(
|
|
|
config['host'],
|
|
|
config['port'],
|
|
|
- serverSelectionTimeoutMS=15000, # 6秒
|
|
|
- connectTimeoutMS=15000, # 6秒
|
|
|
- socketTimeoutMS=15000, # 6秒,
|
|
|
- retryReads=True, # 开启重试
|
|
|
+ serverSelectionTimeoutMS=30000,
|
|
|
+ connectTimeoutMS=30000,
|
|
|
+ socketTimeoutMS=30000,
|
|
|
+ retryReads=True,
|
|
|
maxPoolSize=50
|
|
|
)
|
|
|
db = client[config['db']]
|
|
|
@@ -47,6 +65,12 @@ def mongo_con_parse(config=None):
|
|
|
except Exception as e:
|
|
|
print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
|
|
|
raise
|
|
|
+
|
|
|
+ if reuse_client:
|
|
|
+ _MONGO_SHARED_CLIENT = client
|
|
|
+ _MONGO_SHARED_DB = db
|
|
|
+ _MONGO_SHARED_CFG_KEY = cfg_key
|
|
|
+
|
|
|
return client, db
|
|
|
|
|
|
|
|
|
@@ -75,7 +99,6 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
|
|
|
for attempt in range(1, max_retries + 1):
|
|
|
try:
|
|
|
print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
|
|
|
- # 构建查询条件
|
|
|
query_condition = {
|
|
|
"from_city_code": from_city,
|
|
|
"to_city_code": to_city,
|
|
|
@@ -83,17 +106,13 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
|
|
|
"$gte": dep_date_begin,
|
|
|
"$lte": dep_date_end,
|
|
|
},
|
|
|
- # "segments.baggage": {"$in": ["-;-;-;-", "1-30"]} # 无行李,30公斤行李
|
|
|
- "segments.baggage": "-;-;-;-"
|
|
|
}
|
|
|
- # 动态添加航班号条件
|
|
|
- for i, flight_num in enumerate(flight_nums):
|
|
|
- query_condition[f"segments.{i}.flight_number"] = flight_num
|
|
|
-
|
|
|
- print(f" 查询条件: {query_condition}")
|
|
|
- # 定义要查询的字段
|
|
|
+
|
|
|
+ baggage_filter = 0
|
|
|
+ # flight_nums_filter = list(flight_nums) if flight_nums else []
|
|
|
+
|
|
|
+ print(f" 查询条件(走索引): {query_condition}")
|
|
|
projection = {
|
|
|
- # "_id": 1,
|
|
|
"from_city_code": 1,
|
|
|
"search_dep_time": 1,
|
|
|
"to_city_code": 1,
|
|
|
@@ -106,19 +125,13 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
|
|
|
"source_website": 1,
|
|
|
"crawl_date": 1
|
|
|
}
|
|
|
- # 执行查询
|
|
|
- cursor = db.get_collection(table_name).find(
|
|
|
- query_condition,
|
|
|
- projection=projection # 添加投影参数
|
|
|
- ).sort(
|
|
|
- [
|
|
|
- ("search_dep_time", 1), # 多级排序要用列表+元组的格式
|
|
|
- ("segments.0.baggage", 1),
|
|
|
- ("crawl_date", 1)
|
|
|
- ]
|
|
|
+
|
|
|
+ cursor = (
|
|
|
+ db.get_collection(table_name)
|
|
|
+ .find(query_condition, projection=projection)
|
|
|
+ .batch_size(5000)
|
|
|
+ .hint('from_city_code_1_to_city_code_1_search_dep_time_1')
|
|
|
)
|
|
|
- if limit > 0:
|
|
|
- cursor = cursor.limit(limit)
|
|
|
|
|
|
# 将结果转换为列表
|
|
|
results = list(cursor)
|
|
|
@@ -140,7 +153,24 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
|
|
|
print(f"用时: {rt} 秒")
|
|
|
print(f"📊 已将segments扩展成字段,形状: {df.shape}")
|
|
|
|
|
|
- # 不用排序,因为mongo语句已经排好
|
|
|
+ if "baggage" in df.columns:
|
|
|
+ df = df[df["baggage"] == baggage_filter]
|
|
|
+
|
|
|
+ # for i, flight_num in enumerate(flight_nums_filter):
|
|
|
+ # if flight_num is None or flight_num == "":
|
|
|
+ # continue
|
|
|
+ # col = f"seg{i + 1}_flight_number"
|
|
|
+ # if col not in df.columns:
|
|
|
+ # return pd.DataFrame()
|
|
|
+ # df = df[df[col].astype("string") == str(flight_num)]
|
|
|
+
|
|
|
+ # sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df.columns]
|
|
|
+ # if sort_cols:
|
|
|
+ # df = df.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
|
|
|
+
|
|
|
+ if limit > 0:
|
|
|
+ df = df.head(limit).reset_index(drop=True)
|
|
|
+
|
|
|
return df
|
|
|
|
|
|
else:
|
|
|
@@ -425,104 +455,95 @@ def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
|
|
|
return df
|
|
|
|
|
|
|
|
|
-def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=20, max_retries=3, base_sleep=1.0):
|
|
|
+def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=10, max_retries=3, base_sleep=1.0):
|
|
|
"""
|
|
|
从一组城市对中查找所有分组(航班号与起飞时间)的组合
|
|
|
按:第一段航班号 → 第二段航班号 → 起飞时间 排序
|
|
|
- (失败自动重试) 保证2个月内至少有20天起飞的航线
|
|
|
+ (失败自动重试) 保证1个月内至少有10天起飞的航线
|
|
|
+
|
|
|
+ 说明:为降低 Mongo 的聚合负担,这里只做轻量 find + 投影,把“按天统计/按航班组合汇总”的逻辑放到 pandas 侧处理。
|
|
|
"""
|
|
|
print(f"{from_city}-{to_city} 查找所有分组")
|
|
|
- date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
|
|
|
+ date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y%m%d")
|
|
|
date_end = datetime.today().strftime("%Y%m%d")
|
|
|
- pipeline = [
|
|
|
- # 1️⃣ 先筛选城市对
|
|
|
- {
|
|
|
- "$match": {
|
|
|
- "from_city_code": from_city,
|
|
|
- "to_city_code": to_city,
|
|
|
- "search_dep_time": {
|
|
|
- "$gte": date_begin,
|
|
|
- "$lte": date_end
|
|
|
- }
|
|
|
- }
|
|
|
- },
|
|
|
- # 2️⃣ 投影字段 + 拆第一、第二段航班号用于排序
|
|
|
- {
|
|
|
- "$project": {
|
|
|
- "flight_numbers": "$segments.flight_number",
|
|
|
- "search_dep_time": 1,
|
|
|
- "fn1": {"$arrayElemAt": ["$segments.flight_number", 0]},
|
|
|
- "fn2": {"$arrayElemAt": ["$segments.flight_number", 1]}
|
|
|
- }
|
|
|
- },
|
|
|
- # 3️⃣ 第一级分组:组合 + 每一天
|
|
|
- {
|
|
|
- "$group": {
|
|
|
- "_id": {
|
|
|
- "flight_numbers": "$flight_numbers",
|
|
|
- "search_dep_time": "$search_dep_time",
|
|
|
- "fn1": "$fn1",
|
|
|
- "fn2": "$fn2"
|
|
|
- },
|
|
|
- "count": {"$sum": 1}
|
|
|
- }
|
|
|
- },
|
|
|
- # 关键修复点:这里先按【时间】排好序!
|
|
|
- {
|
|
|
- "$sort": {
|
|
|
- "_id.fn1": 1,
|
|
|
- "_id.fn2": 1,
|
|
|
- "_id.search_dep_time": 1 # 确保 push 进去时是按天递增
|
|
|
- }
|
|
|
- },
|
|
|
- # 4️⃣ 第二级分组:只按【航班组合】聚合 → 统计“有多少天”
|
|
|
- {
|
|
|
- "$group": {
|
|
|
- "_id": {
|
|
|
- "flight_numbers": "$_id.flight_numbers",
|
|
|
- "fn1": "$_id.fn1",
|
|
|
- "fn2": "$_id.fn2"
|
|
|
- },
|
|
|
- "days": {"$sum": 1}, # 不同起飞天数
|
|
|
- "details": {
|
|
|
- "$push": {
|
|
|
- "search_dep_time": "$_id.search_dep_time",
|
|
|
- "count": "$count"
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- },
|
|
|
- # 5️⃣ 关键:按“天数阈值”过滤
|
|
|
- {
|
|
|
- "$match": {
|
|
|
- "days": {"$gte": min_days}
|
|
|
- }
|
|
|
- },
|
|
|
- # 6️⃣ ✅ 按“第一段 → 第二段”排序
|
|
|
- {
|
|
|
- "$sort": {
|
|
|
- "_id.fn1": 1,
|
|
|
- "_id.fn2": 1,
|
|
|
- }
|
|
|
- }
|
|
|
- ]
|
|
|
+
|
|
|
+ query = {
|
|
|
+ "from_city_code": from_city,
|
|
|
+ "to_city_code": to_city,
|
|
|
+ "search_dep_time": {"$gte": date_begin, "$lte": date_end},
|
|
|
+ }
|
|
|
+ projection = {
|
|
|
+ "_id": 0,
|
|
|
+ "search_dep_time": 1,
|
|
|
+ "segments.flight_number": 1,
|
|
|
+ }
|
|
|
+
|
|
|
+ def _extract_flight_numbers(segments):
|
|
|
+ if not isinstance(segments, list):
|
|
|
+ return []
|
|
|
+ out = []
|
|
|
+ for seg in segments:
|
|
|
+ if not isinstance(seg, dict):
|
|
|
+ continue
|
|
|
+ fn = seg.get("flight_number")
|
|
|
+ if fn:
|
|
|
+ out.append(fn)
|
|
|
+ return out
|
|
|
+
|
|
|
for attempt in range(1, max_retries + 1):
|
|
|
try:
|
|
|
print(f" 第 {attempt}/{max_retries} 次尝试查询")
|
|
|
|
|
|
- # 执行聚合查询
|
|
|
collection = db[table_name]
|
|
|
- results = list(collection.aggregate(pipeline))
|
|
|
+ cursor = collection.find(query, projection=projection).batch_size(5000).hint('from_city_code_1_to_city_code_1_search_dep_time_1')
|
|
|
+ docs = list(cursor)
|
|
|
+ if not docs:
|
|
|
+ return []
|
|
|
+
|
|
|
+ df = pd.DataFrame.from_records(docs)
|
|
|
+ if df.empty or "segments" not in df.columns or "search_dep_time" not in df.columns:
|
|
|
+ return []
|
|
|
+
|
|
|
+ df["flight_numbers"] = df["segments"].apply(_extract_flight_numbers)
|
|
|
+ df["fn1"] = df["flight_numbers"].str[0].fillna("")
|
|
|
+ df["fn2"] = df["flight_numbers"].str[1].fillna("")
|
|
|
+ df["flight_numbers_key"] = df["flight_numbers"].apply(lambda xs: ",".join(xs) if xs else "")
|
|
|
+
|
|
|
+ day_counts = (
|
|
|
+ df.groupby(["flight_numbers_key", "fn1", "fn2", "search_dep_time"], dropna=False)
|
|
|
+ .size()
|
|
|
+ .reset_index(name="count")
|
|
|
+ .sort_values(["fn1", "fn2", "search_dep_time"], kind="mergesort")
|
|
|
+ .reset_index(drop=True)
|
|
|
+ )
|
|
|
+
|
|
|
+ keys = ["flight_numbers_key", "fn1", "fn2"]
|
|
|
+ df_days = day_counts.groupby(keys, sort=False).size().reset_index(name="days")
|
|
|
+ df_details = (
|
|
|
+ day_counts.groupby(keys, sort=False)
|
|
|
+ .apply(lambda g: g[["search_dep_time", "count"]].to_dict("records"))
|
|
|
+ .reset_index(name="details")
|
|
|
+ )
|
|
|
+
|
|
|
+ df_result = df_days.merge(df_details, on=keys, how="inner")
|
|
|
+ df_result = df_result[df_result["days"] >= min_days].sort_values(["fn1", "fn2"], kind="mergesort")
|
|
|
|
|
|
- # 格式化结果,将 _id 中的字段提取到外层
|
|
|
formatted_results = []
|
|
|
- for item in results:
|
|
|
- formatted_item = {
|
|
|
- "flight_numbers": item["_id"]["flight_numbers"],
|
|
|
- "days": item["days"], # 这个组合一共有多少天
|
|
|
- "details": item["details"] # 每一天的 count 明细
|
|
|
- }
|
|
|
- formatted_results.append(formatted_item)
|
|
|
+ for _, row in df_result.iterrows():
|
|
|
+ flight_numbers = row["flight_numbers_key"].split(",") if row["flight_numbers_key"] else []
|
|
|
+ formatted_results.append(
|
|
|
+ {
|
|
|
+ "flight_numbers": flight_numbers,
|
|
|
+ "days": int(row["days"]),
|
|
|
+ "details": row["details"],
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ del df_result
|
|
|
+ del df_details
|
|
|
+ del df_days
|
|
|
+ del df
|
|
|
+ # gc.collect()
|
|
|
|
|
|
return formatted_results
|
|
|
|
|
|
@@ -640,45 +661,55 @@ def plot_c12_trend(df, output_dir="."):
|
|
|
plt.close(fig)
|
|
|
|
|
|
|
|
|
+_ROUTE_CACHE_DF1 = None
|
|
|
+_ROUTE_CACHE_DF2 = None
|
|
|
+
|
|
|
+
|
|
|
+def _init_route_cache_worker(df1_pickle_path, df2_pickle_path):
|
|
|
+ global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
|
|
|
+ _ROUTE_CACHE_DF1 = pd.read_pickle(df1_pickle_path)
|
|
|
+ _ROUTE_CACHE_DF2 = pd.read_pickle(df2_pickle_path)
|
|
|
+
|
|
|
+
|
|
|
+def _filter_df_by_flight_nums(df, flight_nums):
|
|
|
+ if df is None or df.empty:
|
|
|
+ return pd.DataFrame()
|
|
|
+
|
|
|
+ out = df
|
|
|
+ flight_nums_filter = list(flight_nums) if flight_nums else []
|
|
|
+ for i, flight_num in enumerate(flight_nums_filter):
|
|
|
+ if flight_num is None or flight_num == "":
|
|
|
+ continue
|
|
|
+ col = f"seg{i + 1}_flight_number"
|
|
|
+ if col not in out.columns:
|
|
|
+ return out.iloc[0:0].copy()
|
|
|
+ out = out[out[col].astype("string") == str(flight_num)]
|
|
|
+ if out.empty:
|
|
|
+ return out
|
|
|
+
|
|
|
+ return out
|
|
|
+
|
|
|
def process_flight_group(args):
|
|
|
- """处理单个航班号的进程函数(独立数据库连接)"""
|
|
|
- process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir = args
|
|
|
+ """处理单个航班号的进程函数(基于主进程缓存的数据做 pandas 过滤与处理)"""
|
|
|
+ process_id, each_group, is_train, plot_flag, output_dir = args
|
|
|
flight_nums = each_group.get("flight_numbers")
|
|
|
- details = each_group.get("details")
|
|
|
+ # details = each_group.get("details")
|
|
|
|
|
|
print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")
|
|
|
|
|
|
- # 为每个进程创建独立的数据库连接
|
|
|
try:
|
|
|
- client, db = mongo_con_parse(db_config)
|
|
|
- print(f"[进程{process_id}] ✅ 数据库连接创建成功")
|
|
|
- except Exception as e:
|
|
|
- print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
|
|
|
- return pd.DataFrame()
|
|
|
+ df1 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF1, flight_nums)
|
|
|
+ df2 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF2, flight_nums)
|
|
|
+
|
|
|
+ sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df1.columns]
|
|
|
+ if sort_cols:
|
|
|
+ df1 = df1.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
|
|
|
+ df2 = df2.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
|
|
|
|
|
|
- try:
|
|
|
- # 查询远期表
|
|
|
- if is_hot == 1:
|
|
|
- df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
|
|
|
- date_begin_s, date_end_s, flight_nums)
|
|
|
- else:
|
|
|
- df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
|
|
|
- date_begin_s, date_end_s, flight_nums)
|
|
|
-
|
|
|
- # 保证远期表里有数据
|
|
|
if df1.empty:
|
|
|
print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
- # 查询近期表
|
|
|
- if is_hot == 1:
|
|
|
- df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
|
|
|
- date_begin_s, date_end_s, flight_nums)
|
|
|
- else:
|
|
|
- df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
|
|
|
- date_begin_s, date_end_s, flight_nums)
|
|
|
-
|
|
|
- # 保证近期表里有数据
|
|
|
if df2.empty:
|
|
|
print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
|
|
|
return pd.DataFrame()
|
|
|
@@ -781,12 +812,7 @@ def process_flight_group(args):
|
|
|
print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
|
|
|
return pd.DataFrame()
|
|
|
finally:
|
|
|
- # 确保关闭数据库连接
|
|
|
- try:
|
|
|
- client.close()
|
|
|
- print(f"[进程{process_id}] ✅ 数据库连接已关闭")
|
|
|
- except:
|
|
|
- pass
|
|
|
+ pass
|
|
|
|
|
|
|
|
|
def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, is_train=True, plot_flag=False,
|
|
|
@@ -805,25 +831,70 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
|
|
|
print(f"开始处理航线: {route}")
|
|
|
|
|
|
# 在主进程中查询航班号分组(避免多进程重复查询)
|
|
|
- main_client, main_db = mongo_con_parse(db_config)
|
|
|
+ main_client, main_db = mongo_con_parse(db_config, reuse_client=True)
|
|
|
all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
|
|
|
- main_client.close()
|
|
|
-
|
|
|
+
|
|
|
all_groups_len = len(all_groups)
|
|
|
print(f"该航线共有{all_groups_len}个航班号")
|
|
|
+
|
|
|
+ if all_groups_len == 0:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 查询远期表
|
|
|
+ if is_hot == 1:
|
|
|
+ df1 = query_flight_range_status(main_db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
|
|
|
+ date_begin_s, date_end_s, None)
|
|
|
+ else:
|
|
|
+ df1 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
|
|
|
+ date_begin_s, date_end_s, None)
|
|
|
|
|
|
- if use_multiprocess and all_groups_len > 1:
|
|
|
- print(f"启用多进程处理,最大进程数: {max_workers}")
|
|
|
- # 多进程处理
|
|
|
- process_args = []
|
|
|
- process_id = 0
|
|
|
- for each_group in all_groups:
|
|
|
- process_id += 1
|
|
|
- args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
|
|
|
- process_args.append(args)
|
|
|
+ # 保证远期表里有数据
|
|
|
+ if df1.empty:
|
|
|
+ print(f"[主进程] 航线:{route} 远期表无数据, 跳过")
|
|
|
+ # main_client.close()
|
|
|
+ return pd.DataFrame()
|
|
|
+
|
|
|
+ # 查询近期表
|
|
|
+ if is_hot == 1:
|
|
|
+ df2 = query_flight_range_status(main_db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
|
|
|
+ date_begin_s, date_end_s, None)
|
|
|
+ else:
|
|
|
+ df2 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
|
|
|
+ date_begin_s, date_end_s, None)
|
|
|
|
|
|
- with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
|
|
- future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
|
|
|
+ # 保证近期表里有数据
|
|
|
+ if df2.empty:
|
|
|
+ print(f"[主进程] 航线:{route} 近期表无数据, 跳过")
|
|
|
+ # main_client.close()
|
|
|
+ return pd.DataFrame()
|
|
|
+
|
|
|
+ # main_client.close()
|
|
|
+
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+ safe_route = route.replace("-", "_")
|
|
|
+ df1_fd, df1_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_far_", suffix=".pkl", dir=output_dir)
|
|
|
+ df2_fd, df2_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_near_", suffix=".pkl", dir=output_dir)
|
|
|
+ os.close(df1_fd)
|
|
|
+ os.close(df2_fd)
|
|
|
+ df1.to_pickle(df1_cache_path)
|
|
|
+ df2.to_pickle(df2_cache_path)
|
|
|
+
|
|
|
+ try:
|
|
|
+ if use_multiprocess and all_groups_len > 1:
|
|
|
+ print(f"启用多进程处理,最大进程数: {max_workers}")
|
|
|
+ process_args = []
|
|
|
+ process_id = 0
|
|
|
+ for each_group in all_groups:
|
|
|
+ process_id += 1
|
|
|
+ args = (process_id, each_group, is_train, plot_flag, output_dir)
|
|
|
+ process_args.append(args)
|
|
|
+
|
|
|
+ with ProcessPoolExecutor(
|
|
|
+ max_workers=max_workers,
|
|
|
+ initializer=_init_route_cache_worker,
|
|
|
+ initargs=(df1_cache_path, df2_cache_path),
|
|
|
+ ) as executor:
|
|
|
+ future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
|
|
|
|
|
|
for future in as_completed(future_to_group):
|
|
|
each_group = future_to_group[future]
|
|
|
@@ -838,23 +909,37 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
|
|
|
except Exception as e:
|
|
|
print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
|
|
|
|
|
|
- else:
|
|
|
- # 单进程处理(进程编号为0)
|
|
|
- print("使用单进程处理")
|
|
|
- process_id = 0
|
|
|
- for each_group in all_groups:
|
|
|
- args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
|
|
|
- flight_nums = each_group.get("flight_numbers", "未知")
|
|
|
- try:
|
|
|
- df_mid = process_flight_group(args)
|
|
|
- if not df_mid.empty:
|
|
|
- list_all.append(df_mid)
|
|
|
- print(f"✅ 航班号:{flight_nums} 处理完成")
|
|
|
- else:
|
|
|
- print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
|
|
|
- except Exception as e:
|
|
|
- print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
|
|
|
-
|
|
|
+ else:
|
|
|
+ # 单进程处理(进程编号为0)
|
|
|
+ print("使用单进程处理")
|
|
|
+ global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
|
|
|
+ _ROUTE_CACHE_DF1 = df1
|
|
|
+ _ROUTE_CACHE_DF2 = df2
|
|
|
+
|
|
|
+ process_id = 0
|
|
|
+ for each_group in all_groups:
|
|
|
+ args = (process_id, each_group, is_train, plot_flag, output_dir)
|
|
|
+ flight_nums = each_group.get("flight_numbers", "未知")
|
|
|
+ try:
|
|
|
+ df_mid = process_flight_group(args)
|
|
|
+ if not df_mid.empty:
|
|
|
+ list_all.append(df_mid)
|
|
|
+ print(f"✅ 航班号:{flight_nums} 处理完成")
|
|
|
+ else:
|
|
|
+ print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
|
|
|
+
|
|
|
+ finally:
|
|
|
+ try:
|
|
|
+ os.remove(df1_cache_path)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ try:
|
|
|
+ os.remove(df2_cache_path)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
print(f"结束处理航线: {from_city}-{to_city}")
|
|
|
|
|
|
if list_all:
|