import gc
import time
from datetime import datetime, timedelta
import pymongo
from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
import pandas as pd
import os
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.dates as mdates
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

# CJK-capable font so matplotlib can render the Chinese labels/titles below.
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)


def mongo_con_parse(config=None):
    """Build a MongoDB client and database handle from a config dict.

    Args:
        config: Mapping with either a ``URI`` key, or ``host``/``port``/``db``
            (optionally ``user``/``pwd``). Defaults to a copy of
            ``mongodb_config``.

    Returns:
        Tuple ``(client, db)``.

    Raises:
        Exception: whatever ``pymongo`` raised while building the client,
            re-raised after logging.
    """
    if config is None:
        config = mongodb_config.copy()
    try:
        if config.get("URI", ""):
            motor_uri = config["URI"]
            client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
            db = client[config['db']]
            print("motor_uri: ", motor_uri)
        else:
            client = pymongo.MongoClient(
                config['host'],
                config['port'],
                serverSelectionTimeoutMS=15000,  # 15 s
                connectTimeoutMS=15000,          # 15 s
                socketTimeoutMS=15000,           # 15 s
                retryReads=True,                 # enable automatic read retries
                maxPoolSize=50
            )
            db = client[config['db']]
            if config.get('user'):
                # NOTE(review): Database.authenticate() was removed in PyMongo 4.x;
                # if the project upgrades, pass username/password to MongoClient
                # instead — confirm the installed pymongo version.
                db.authenticate(config['user'], config['pwd'])
        print(f"✅ MongoDB 连接对象创建成功")
    except Exception as e:
        print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
        raise
    return client, db


def test_mongo_connection(db):
    """Ping the server behind *db* and report version info.

    Returns:
        True when ``server_info()`` succeeds, False otherwise.
    """
    try:
        # Grab the owning client from the database handle.
        client = db.client
        # server_info() forces a round trip, so it is a real connectivity check.
        info = client.server_info()
        print(f"✅ MongoDB 连接测试成功!")
        print(f"   服务器版本: {info.get('version')}")
        print(f"   数据库: {db.name}")
        return True
    except Exception as e:
        print(f"❌ 数据库连接测试失败: {e}")
        return False


def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums,
                              limit=0, max_retries=3, base_sleep=1.0, thread_id=0):
    """Query one of the 4 cleaned tables for a departure-date range.

    Retries automatically with exponential backoff on Mongo errors.

    Args:
        db: pymongo database handle.
        table_name: collection to query.
        from_city / to_city: IATA city codes.
        dep_date_begin / dep_date_end: inclusive ``YYYYMMDD`` bounds.
        flight_nums: per-segment flight numbers (index i matches segments.i).
        limit: cap on returned documents; 0 means no limit.
        max_retries / base_sleep: retry policy.
        thread_id: caller tag (currently unused).

    Returns:
        DataFrame with segments expanded into columns, or an empty
        DataFrame when nothing matched or all retries failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Base filter: city pair, date range, and only 20/30 kg baggage fares.
            query_condition = {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": dep_date_begin,
                    "$lte": dep_date_end,
                },
                "segments.baggage": {"$in": ["1-20", "1-30"]}  # only 20 kg / 30 kg baggage fares
            }
            # Pin each segment's flight number by positional index.
            for i, flight_num in enumerate(flight_nums):
                query_condition[f"segments.{i}.flight_number"] = flight_num
            print(f"   查询条件: {query_condition}")
            # Fields to fetch.
            projection = {
                # "_id": 1,
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Multi-key sort must be a list of (field, direction) tuples.
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection
            ).sort(
                [
                    ("search_dep_time", 1),
                    ("segments.0.baggage", 1),
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                # Drop the ObjectId column — not serializable / not needed.
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
                # No re-sort needed: the Mongo query already sorted.
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff plus random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
def expand_segments_columns_optimized(df):
    """Expand the ``segments`` list column into flat seg1_*/seg2_* columns.

    Optimized version: builds the columns with a plain loop over the raw
    lists instead of a row-wise ``apply``. Also normalizes derived fields:
    segment times -> datetime, ``source_website`` -> ``is_near`` flag,
    baggage strings -> integer kilograms, and renames seg1 cabin/baggage.

    Args:
        df: DataFrame that may contain a ``segments`` column of dicts.

    Returns:
        Transformed copy of *df* (returned unchanged when empty).
    """
    if df.empty:
        return df

    df = df.copy()
    if 'segments' in df.columns:
        # Fields extracted from the first leg.
        seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
        # Fields extracted from the optional second leg.
        seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']

        # Plain loop over the raw lists — much faster than df.apply row-wise.
        seg1_data = []
        seg2_data = []
        for segments in df['segments']:
            seg1_dict = {}
            seg2_dict = {}
            if isinstance(segments, list) and len(segments) >= 1 and isinstance(segments[0], dict):
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = segments[0].get(col)
            else:
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = pd.NA
            if isinstance(segments, list) and len(segments) >= 2 and isinstance(segments[1], dict):
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = segments[1].get(col)
            else:
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = pd.NA
            seg1_data.append(seg1_dict)
            seg2_data.append(seg2_dict)

        df_seg1 = pd.DataFrame(seg1_data, index=df.index)
        df_seg2 = pd.DataFrame(seg2_data, index=df.index)

        # Replace the nested column with the flattened ones.
        df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_seg1, df_seg2], axis=1)

        # Segment times arrive as YYYYMMDDHHMMSS strings.
        time_cols = ['seg1_dep_time', 'seg1_arr_time', 'seg2_dep_time', 'seg2_arr_time']
        for col in time_cols:
            if col in df.columns:
                df[col] = pd.to_datetime(df[col], format='%Y%m%d%H%M%S', errors='coerce')

        # Crawler name encodes the horizon: '7_30' -> far (0), '0_7' -> near (1);
        # anything else keeps its original value.
        df['source_website'] = np.where(
            df['source_website'].str.contains('7_30'), 0,
            np.where(df['source_website'].str.contains('0_7'), 1, df['source_website'])
        )

        # Baggage allowance string -> kilograms.
        conditions = [
            df['seg1_baggage'] == '-;-;-;-',
            df['seg1_baggage'] == '1-20',
            df['seg1_baggage'] == '1-30',
            df['seg1_baggage'] == '1-40',
        ]
        choices = [0, 20, 30, 40]
        df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])

        df = df.rename(columns={
            'seg1_cabin': 'cabin',
            'seg1_baggage': 'baggage',
            'source_website': 'is_near',
        })
    return df


def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
    """Resample crawl snapshots to a complete hourly series.

    Keeps the earliest snapshot per hour, reindexes onto a full hourly
    axis, forward-fills gaps, and derives hours/days-until-departure.

    Args:
        df: DataFrame with a ``crawl_date`` column (parseable to datetime).
        head_fill: 1 -> start the axis at 00:00 of the first day; otherwise
            start at the first observed hour.
        rear_fill: 1 -> extend to 23:00 of the last day; 2 -> extend to the
            departure hour (floor of the last ``seg1_dep_time``); otherwise
            stop at the last observed hour.

    Returns:
        Hourly DataFrame with ``update_hour``, ``is_filled`` (1 = synthetic
        forward-filled row) and, when ``seg1_dep_time`` exists,
        ``hours_until_departure`` / ``days_to_departure``.
    """
    df = df.copy()
    # 1. Parse timestamps and derive the hour bucket used for grouping.
    df['crawl_date'] = pd.to_datetime(df['crawl_date'])
    df['update_hour'] = df['crawl_date'].dt.floor('h')
    # 2. Within each hour, order by the raw timestamp so we can keep the earliest.
    df = df.sort_values(['update_hour', 'crawl_date'])
    # 3. One row per hour: earliest snapshot wins.
    df = df.drop_duplicates(subset=['update_hour'], keep='first')
    # 4. Mark original (non-synthetic) rows.
    df['is_filled'] = 0
    # 5. Sort and index by the hour bucket.
    df = df.sort_values('update_hour').set_index('update_hour')
    # 6. Build the full hourly axis.
    start_of_day = df.index.min()  # default: first observed hour
    if head_fill == 1:
        start_of_day = df.index.min().normalize()  # force 00:00 of the first day
    end_of_day = df.index.max()  # default: last observed hour
    if rear_fill == 1:
        end_of_day = df.index.max().normalize() + pd.Timedelta(hours=23)  # force 23:00 of the last day
    elif rear_fill == 2:
        if 'seg1_dep_time' in df.columns:
            last_dep_time = df['seg1_dep_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align to the departure hour (floor).
                end_of_day = last_dep_time.floor('h')
    full_index = pd.date_range(start=start_of_day, end=end_of_day, freq='1h')
    # 7. Reindex onto the hourly axis (new rows are all-NaN).
    df = df.reindex(full_index)
    # Restore dtypes before filling — important after reindex introduced NaNs.
    df = df.infer_objects(copy=False)
    # 8. Rows created by the reindex are synthetic.
    df['is_filled'] = df['is_filled'].fillna(1)
    # 9. Forward-fill every column from the last real snapshot.
    df = df.ffill()
    # 10. Restore integer dtypes.
    int_cols = [
        'seats_remaining',
        'is_near',
        'baggage',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')
    # 10.5 Prices rounded to 2 decimal places.
    price_cols = ['adult_price', 'adult_tax', 'adult_total_price']
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)
    # 10.6 Hours (and days) remaining until departure.
    if 'seg1_dep_time' in df.columns:
        # Temporary column holding the departure hour (floored).
        df['seg1_dep_hour'] = df['seg1_dep_time'].dt.floor('h')
        # df.index is the update_hour axis here.
        df['hours_until_departure'] = (
            (df['seg1_dep_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        df = df.drop(columns=['seg1_dep_hour'])
    # 11. Materialize update_hour back as a column.
    df['update_hour'] = df.index
    # 12. Back to a plain RangeIndex.
    df = df.reset_index(drop=True)
    return df


def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=20, max_retries=3, base_sleep=1.0):
    """Find all (flight-number combo, departure-day) groups for a city pair.

    Aggregates the last 60 days, sorted by first-leg then second-leg flight
    number then departure day, and keeps only combos flying on at least
    *min_days* distinct days. Retries automatically on Mongo errors.

    Returns:
        List of ``{"flight_numbers": [...], "days": int, "details": [...]}``,
        or ``[]`` when all retries failed.
    """
    print(f"{from_city}-{to_city} 查找所有分组")
    date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
    date_end = datetime.today().strftime("%Y%m%d")
    pipeline = [
        # 1. Filter to the city pair and 60-day window.
        {
            "$match": {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": date_begin,
                    "$lte": date_end
                }
            }
        },
        # 2. Project the fields + split out first/second leg numbers for sorting.
        {
            "$project": {
                "flight_numbers": "$segments.flight_number",
                "search_dep_time": 1,
                "fn1": {"$arrayElemAt": ["$segments.flight_number", 0]},
                "fn2": {"$arrayElemAt": ["$segments.flight_number", 1]}
            }
        },
        # 3. First grouping level: combo + individual day.
        {
            "$group": {
                "_id": {
                    "flight_numbers": "$flight_numbers",
                    "search_dep_time": "$search_dep_time",
                    "fn1": "$fn1",
                    "fn2": "$fn2"
                },
                "count": {"$sum": 1}
            }
        },
        # Sort by day BEFORE the second grouping so $push receives days in order.
        {
            "$sort": {
                "_id.fn1": 1,
                "_id.fn2": 1,
                "_id.search_dep_time": 1
            }
        },
        # 4. Second grouping level: combo only -> count distinct days.
        {
            "$group": {
                "_id": {
                    "flight_numbers": "$_id.flight_numbers",
                    "fn1": "$_id.fn1",
                    "fn2": "$_id.fn2"
                },
                "days": {"$sum": 1},  # number of distinct departure days
                "details": {
                    "$push": {
                        "search_dep_time": "$_id.search_dep_time",
                        "count": "$count"
                    }
                }
            }
        },
        # 5. Keep only combos flying on enough days.
        {
            "$match": {
                "days": {"$gte": min_days}
            }
        },
        # 6. Final order: first leg, then second leg.
        {
            "$sort": {
                "_id.fn1": 1,
                "_id.fn2": 1,
            }
        }
    ]
    for attempt in range(1, max_retries + 1):
        try:
            print(f"   第 {attempt}/{max_retries} 次尝试查询")
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            # Lift _id fields up to the top level of each result.
            formatted_results = []
            for item in results:
                formatted_item = {
                    "flight_numbers": item["_id"]["flight_numbers"],
                    "days": item["days"],       # total distinct days for this combo
                    "details": item["details"]  # per-day count breakdown
                }
                formatted_results.append(formatted_item)
            return formatted_results
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return []
            # Exponential backoff plus random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)


def plot_c12_trend(df, output_dir="."):
    """Plot adult_total_price vs update_hour, one line per baggage class.

    Only price-change points (plus first/last rows) are drawn to keep the
    figure readable. The PNG is saved under ``<output_dir>/<route>/``.
    """
    # Color/linestyle pools, cycled per baggage group.
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']

    # Ensure the time axis is datetime.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])

    from_city = df['from_city_code'].mode().iloc[0]
    to_city = df['to_city_code'].mode().iloc[0]
    flight_number_1 = df['seg1_flight_number'].mode().iloc[0]
    flight_number_2 = df['seg2_flight_number'].mode().get(0, "")
    dep_time = df['seg1_dep_time'].mode().iloc[0]
    route = f"{from_city}-{to_city}"
    flight_number = f"{flight_number_1},{flight_number_2}" if flight_number_2 else f"{flight_number_1}"

    output_dir_photo = os.path.join(output_dir, route)
    os.makedirs(output_dir_photo, exist_ok=True)

    fig = plt.figure(figsize=(14, 8))

    # One series per baggage allowance.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage')):
        g = group.sort_values('update_hour').reset_index(drop=True)
        # Change points: rows whose price differs from the previous row,
        # always keeping the first and last rows as anchors.
        change_points = g.loc[
            (g['adult_total_price'] != g['adult_total_price'].shift(1)) |
            (g.index == 0) |
            (g.index == len(g) - 1)
        ].drop_duplicates(subset=['update_hour'])

        plt.plot(
            change_points['update_hour'],
            change_points['adult_total_price'],
            marker='o',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            markersize=6,
            markerfacecolor='white',
            markeredgewidth=2,
            label=f"Baggage {baggage_value}"
        )
        # Annotate each change point with (hours-to-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['adult_total_price']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['adult_total_price']),
                xytext=(0, 0),  # no offset from the data point
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,
                color='gray',
                alpha=0.8,
                rotation=25,
            )

    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {route} 航班号: {flight_number}\n起飞时间: {dep_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)

    # Major ticks daily; a silent minor tick at noon.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))

    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    safe_flight = flight_number.replace(",", "_")
    safe_dep_time = dep_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{route} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_photo, save_file)
    # Save before closing the figure.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    # Release figure memory.
    plt.close(fig)


def process_flight_group(args):
    """Worker: process one flight-number combo with its own DB connection.

    Designed for ProcessPoolExecutor, so *args* is a single picklable tuple:
    (process_id, db_config, each_group, from_city, to_city, date_begin_s,
     date_end_s, is_hot, plot_flag, output_dir).

    Returns:
        Combined hourly DataFrame for the combo, or an empty DataFrame on
        missing data / connection failure / any processing error.
    """
    process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
    flight_nums = each_group.get("flight_numbers")
    details = each_group.get("details")
    print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")

    # Each worker process owns an independent connection.
    try:
        client, db = mongo_con_parse(db_config)
        print(f"[进程{process_id}] ✅ 数据库连接创建成功")
    except Exception as e:
        print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
        return pd.DataFrame()

    try:
        # Far-horizon table.
        if is_hot == 1:
            df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city, date_begin_s,
                                            date_end_s, flight_nums)
        else:
            df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city, date_begin_s,
                                            date_end_s, flight_nums)
        if df1.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
            return pd.DataFrame()

        # Near-horizon table.
        if is_hot == 1:
            df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city, date_begin_s,
                                            date_end_s, flight_nums)
        else:
            df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city, date_begin_s,
                                            date_end_s, flight_nums)
        if df2.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
            return pd.DataFrame()

        # Departure days and baggage classes are driven by the near-term table.
        # (df2 is guaranteed non-empty here by the early return above.)
        common_dep_dates = df2['search_dep_time'].unique()
        common_baggages = df2['baggage'].unique()

        list_mid = []
        for dep_date in common_dep_dates:
            # Filter each table to this departure day and normalize segment
            # times to the per-day mode (crawls can disagree slightly).
            df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
            if not df_d1.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_1 = df_d1[col].mode()
                    if mode_series_1.empty:
                        zong_1 = pd.NaT
                    else:
                        zong_1 = mode_series_1.iloc[0]
                    df_d1[col] = zong_1
            df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
            if not df_d2.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_2 = df_d2[col].mode()
                    if mode_series_2.empty:
                        zong_2 = pd.NaT
                    else:
                        zong_2 = mode_series_2.iloc[0]
                    df_d2[col] = zong_2

            list_12 = []
            for baggage in common_baggages:
                # Filter by baggage allowance.
                df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
                df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
                # Skip when neither table has rows for this slice.
                if df_b1.empty and df_b2.empty:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
                    continue
                cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
                        "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
                df_b1[cols] = df_b1[cols].astype("string")
                df_b2[cols] = df_b2[cols].astype("string")
                df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
                # Fill to a full hourly series ending at the departure hour.
                df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
                list_12.append(df_b12)
                del df_b12
                del df_b2
                del df_b1

            if list_12:
                df_c12 = pd.concat(list_12, ignore_index=True)
                if plot_flag:
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
                    plot_c12_trend(df_c12, output_dir)
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
            else:
                df_c12 = pd.DataFrame()
                if plot_flag:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
            del list_12
            list_mid.append(df_c12)
            del df_c12
            del df_d1
            del df_d2

        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[进程{process_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[进程{process_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
        del list_mid
        del df1
        del df2
        gc.collect()
        print(f"[进程{process_id}] 结束处理航班号: {flight_nums}")
        return df_mid
    except Exception as e:
        print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
        return pd.DataFrame()
    finally:
        # Always close the per-worker connection.
        try:
            client.close()
            print(f"[进程{process_id}] ✅ 数据库连接已关闭")
        except Exception:
            pass


def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.',
                    is_hot=1, plot_flag=False, use_multiprocess=False, max_workers=None):
    """Load training data for a list of routes (optionally multi-process).

    For each route, queries the flight-number groups once in the parent
    process, then fans each group out to ``process_flight_group``.

    Args:
        db_config: Mongo config dict passed to each worker.
        flight_route_list: route strings like ``"SGN-HAN"``.
        table_name: collection used to discover flight-number groups.
        date_begin / date_end: ``YYYY-MM-DD`` bounds (converted to YYYYMMDD).
        output_dir: where plots are written (when *plot_flag*).
        is_hot: 1 -> hot-route tables, 0 -> cold-route tables.
        plot_flag: draw trend plots per departure day.
        use_multiprocess / max_workers: fan-out policy.

    Returns:
        Concatenated DataFrame over all routes/groups (may be empty).
    """
    date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d")  # query format
    date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
    list_all = []
    for flight_route in flight_route_list:
        from_city = flight_route.split('-')[0]
        to_city = flight_route.split('-')[1]
        route = f"{from_city}-{to_city}"
        print(f"开始处理航线: {route}")

        # Discover groups once in the parent to avoid duplicate queries per worker.
        main_client, main_db = mongo_con_parse(db_config)
        all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
        main_client.close()
        all_groups_len = len(all_groups)
        print(f"该航线共有{all_groups_len}个航班号")

        if use_multiprocess and all_groups_len > 1:
            print(f"启用多进程处理,最大进程数: {max_workers}")
            process_args = []
            process_id = 0
            for each_group in all_groups:
                process_id += 1
                args = (process_id, db_config, each_group, from_city, to_city,
                        date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
                process_args.append(args)
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                future_to_group = {executor.submit(process_flight_group, args): each_group
                                   for args, each_group in zip(process_args, all_groups)}
                for future in as_completed(future_to_group):
                    each_group = future_to_group[future]
                    flight_nums = each_group.get("flight_numbers", "未知")
                    try:
                        df_mid = future.result()
                        if not df_mid.empty:
                            list_all.append(df_mid)
                            print(f"✅ 航班号:{flight_nums} 处理完成")
                        else:
                            print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
                    except Exception as e:
                        print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
        else:
            # Single-process fallback (worker id 0).
            print("使用单进程处理")
            process_id = 0
            for each_group in all_groups:
                args = (process_id, db_config, each_group, from_city, to_city,
                        date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
                flight_nums = each_group.get("flight_numbers", "未知")
                try:
                    df_mid = process_flight_group(args)
                    if not df_mid.empty:
                        list_all.append(df_mid)
                        print(f"✅ 航班号:{flight_nums} 处理完成")
                    else:
                        print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
                except Exception as e:
                    print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
        print(f"结束处理航线: {from_city}-{to_city}")

    if list_all:
        df_all = pd.concat(list_all, ignore_index=True)
    else:
        df_all = pd.DataFrame()
    print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
    del list_all
    gc.collect()
    return df_all


def query_all_flight_number(db, table_name):
    """Return the deduplicated list of all flight numbers in *table_name*."""
    print(f"{table_name} 查找所有航班号")
    pipeline = [
        {
            "$project": {
                "flight_numbers": "$segments.flight_number"
            }
        },
        {
            "$group": {
                "_id": "$flight_numbers",
                "count": {"$sum": 1}
            }
        },
    ]
    collection = db[table_name]
    results = list(collection.aggregate(pipeline))
    list_flight_number = []
    for item in results:
        # _id is the per-document list of segment flight numbers.
        item_li = item.get("_id", [])
        list_flight_number.extend(item_li)
    list_flight_number = list(set(list_flight_number))
    return list_flight_number


def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour,
                      limit=0, max_retries=3, base_sleep=1.0):
    """Validate one prediction row against the near-term table.

    Picks the hot/cold near-term table from the route lists, then fetches
    crawl snapshots from *valid_begin_hour* onward for the given flight day,
    flight numbers and baggage class. ``flight_number_2 == "VJ"`` means
    "no second leg" and is not added to the filter.

    Returns:
        Expanded DataFrame sorted by crawl_date, or an empty DataFrame.
    """
    if city_pair in vj_flight_route_list_hot:
        table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
    elif city_pair in vj_flight_route_list_nothot:
        table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    else:
        print(f"城市对{city_pair}不在热门航线与冷门航线, 返回")
        return pd.DataFrame()

    city_pair_split = city_pair.split('-')
    from_city_code = city_pair_split[0]
    to_city_code = city_pair_split[1]
    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
    baggage_str = f"1-{baggage}"

    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            query_condition = {
                "from_city_code": from_city_code,
                "to_city_code": to_city_code,
                "search_dep_time": flight_day_str,
                "segments.baggage": baggage_str,
                "crawl_date": {"$gte": valid_begin_hour},
                "segments.0.flight_number": flight_number_1,
            }
            # Add the second leg only when one exists.
            if flight_number_2 != "VJ":
                query_condition["segments.1.flight_number"] = flight_number_2
            print(f"   查询条件: {query_condition}")
            projection = {
                # "_id": 1,
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection
            ).sort(
                [
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
                # No re-sort needed: the Mongo query already sorted.
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff plus random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)


if __name__ == "__main__":
    from utils import chunk_list_with_index

    cpu_cores = os.cpu_count()
    max_workers = min(8, cpu_cores)  # never spawn more than 8 worker processes

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    # Load hot-route data.
    date_begin = "2026-01-08"
    date_end = datetime.today().strftime("%Y-%m-%d")
    flight_route_list = vj_flight_route_list_hot[:]  # hot: vj_flight_route_list_hot / cold: vj_flight_route_list_nothot
    table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # hot: CLEAN_VJ_HOT_NEAR_INFO_TAB / cold: CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    is_hot = 1  # 1 = hot, 0 = cold

    group_size = 1
    chunks = chunk_list_with_index(flight_route_list, group_size)
    for idx, (_, group_route_list) in enumerate(chunks, 1):
        print(f"第 {idx} 组 :", group_route_list)
        start_time = time.time()
        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir,
                        is_hot, plot_flag=False, use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")
        time.sleep(3)
    print("整体结束")