import os
import time
import random
from datetime import datetime, timedelta
import gc
from concurrent.futures import ProcessPoolExecutor, as_completed

import pymongo
from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.dates as mdates

from uo_atlas_import import mongo_con_parse
from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new

# CJK-capable font used for all Chinese axis labels / titles in the plots.
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)


def query_groups_of_city_pair(db, city_pair, table_name, min_days=10,
                              max_retries=3, base_sleep=1.0):
    """Find flight-number groups for a city pair that flew frequently recently.

    Aggregates the last 30 days of records and keeps only flight-number
    groups that departed on at least ``min_days`` distinct dates.

    Args:
        db: An open pymongo database handle.
        city_pair: Route key, e.g. ``"HKG-NRT"``.
        table_name: Collection name to aggregate over.
        min_days: Minimum number of distinct departure dates to keep a group.
        max_retries: Mongo retry attempts on transient errors.
        base_sleep: Base seconds for exponential backoff between retries.

    Returns:
        List of dicts ``{"flight_numbers", "days", "flight_dates"}``,
        or ``[]`` after exhausting retries.
    """
    print(f"{city_pair} 查找所有分组")
    date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
    date_end = datetime.today().strftime("%Y-%m-%d")

    # Aggregation pipeline:
    #   1) restrict to the city pair within the 30-day window
    #   2) dedupe to one row per (flight_numbers, from_date)
    #   3) count distinct departure dates per flight_numbers
    #   4) keep groups with >= min_days departure dates
    #   5) sort the collected dates, then the groups themselves
    pipeline = [
        {
            "$match": {
                "citypair": city_pair,
                "from_date": {
                    "$gte": date_begin,
                    "$lte": date_end
                }
            }
        },
        {
            "$group": {
                "_id": {
                    "flight_numbers": "$flight_numbers",
                    "from_date": "$from_date"
                }
            }
        },
        {
            "$group": {
                "_id": "$_id.flight_numbers",
                "days": {"$sum": 1},
                "details": {"$push": "$_id.from_date"}
            }
        },
        {
            "$match": {
                "days": {"$gte": min_days}
            }
        },
        {
            # $sortArray requires MongoDB 5.2+
            "$addFields": {
                "details": {"$sortArray": {"input": "$details", "sortBy": 1}}
            }
        },
        {
            "$sort": {"_id": 1}
        }
    ]

    for attempt in range(1, max_retries + 1):
        try:
            print(f" 第 {attempt}/{max_retries} 次尝试查询")
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            # Reshape into self-describing field names.
            formatted_results = [
                {
                    "flight_numbers": r["_id"],
                    "days": r["days"],
                    "flight_dates": r["details"]
                }
                for r in results
            ]
            return formatted_results
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return []
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)


def query_flight_range_status(db, table_name, city_pair, flight_numbers,
                              from_date_begin, from_date_end,
                              limit=0, max_retries=3, base_sleep=1.0):
    """Fetch price/status records for one flight-number group in a date range.

    Records with baggage_weight in {0, 20} are selected, sorted by
    departure date, baggage and crawl time. Missing ``from_time`` values are
    back-filled using the most common departure time seen in the batch.

    Args:
        db: An open pymongo database handle.
        table_name: Collection name to query.
        city_pair: Route key, e.g. ``"HKG-NRT"``.
        flight_numbers: Flight-number group identifier.
        from_date_begin: Inclusive start departure date ("YYYY-MM-DD").
        from_date_end: Inclusive end departure date ("YYYY-MM-DD").
        limit: Unused; kept for interface compatibility.
        max_retries: Mongo retry attempts on transient errors.
        base_sleep: Base seconds for exponential backoff between retries.

    Returns:
        pandas.DataFrame of results; empty DataFrame on no data, on
        unrecoverable errors, or when no valid departure time can be derived.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            projection = {
                # "_id": 0  # usually disabled
                "citypair": 1,
                "flight_numbers": 1,
                "from_date": 1,
                "from_time": 1,
                "create_time": 1,
                "baggage_weight": 1,
                "cabins": 1,
                "ticket_amount": 1,
                "currency": 1,
                "price_base": 1,
                "price_tax": 1,
                "price_total": 1
            }
            pipeline = [
                {
                    "$match": {
                        "citypair": city_pair,
                        "flight_numbers": flight_numbers,
                        "baggage_weight": {"$in": [0, 20]},
                        "from_date": {
                            "$gte": from_date_begin,
                            "$lte": from_date_end
                        }
                    }
                },
                {
                    "$project": projection
                },
                {
                    "$sort": {
                        "from_date": 1,
                        "baggage_weight": 1,
                        "create_time": 1
                    }
                }
            ]
            # print(f" 查询条件: {pipeline}")
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            print(f"✅ 查询成功,找到 {len(results)} 条记录")

            if results:
                df = pd.DataFrame(results)
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                if 'from_time' in df.columns and 'from_date' in df.columns:
                    from_time_raw = df['from_time']
                    from_time_str = from_time_raw.fillna('').astype(str).str.strip()
                    # Rows whose raw from_time is non-empty.
                    non_empty = from_time_str[from_time_str.ne('')]
                    extracted_time = non_empty.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0].dropna()
                    if not extracted_time.empty:
                        # Use the modal departure time to reconstruct
                        # from_time for rows where it is missing.
                        more_time = extracted_time.value_counts().idxmax()
                        missing_mask = from_time_raw.isna() | from_time_str.eq('')
                        if missing_mask.any():
                            df.loc[missing_mask, 'from_time'] = (
                                df.loc[missing_mask, 'from_date'].astype(str).str.strip()
                                + ' ' + more_time
                            )
                    else:
                        # No usable departure time anywhere: drop the batch.
                        print("⚠️ 无法提取有效起飞时间,抛弃该条记录")
                        return pd.DataFrame()
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)


def plot_c1_trend(df, output_dir="."):
    """Plot price_total versus update_hour, one step-line per baggage class.

    Saves a PNG named ``"<citypair> <flights> <departure>.png"`` under
    ``output_dir/<citypair>/``. The figure is closed after saving; nothing
    is returned.

    Args:
        df: DataFrame produced by ``fill_hourly_create_time`` — must contain
            update_hour, citypair, flight_numbers, from_time, baggage_weight,
            price_total and hours_until_departure columns.
        output_dir: Root directory for the per-route image folders.
    """
    # Color / linestyle palettes, cycled per baggage group.
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']

    # Ensure the time axis is datetime-typed.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])

    city_pair = df['citypair'].mode().iloc[0]
    flight_numbers = df['flight_numbers'].mode().iloc[0]
    from_time = df['from_time'].mode().iloc[0]

    output_dir_temp = os.path.join(output_dir, city_pair)
    os.makedirs(output_dir_temp, exist_ok=True)

    fig = plt.figure(figsize=(14, 8))

    # One series per baggage_weight value.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
        df_g = group.sort_values('update_hour').reset_index(drop=True)

        # Change points: rows whose price differs from the previous row,
        # plus the first and last rows so the line spans the full range.
        change_points = df_g.loc[
            (df_g['price_total'] != df_g['price_total'].shift(1))
            | (df_g.index == 0)
            | (df_g.index == len(df_g) - 1)
        ].drop_duplicates(subset=['update_hour'])

        # Step line (flat segments with sudden jumps) plus markers.
        plt.step(
            change_points['update_hour'],
            change_points['price_total'],
            where='post',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            label=f"Baggage {baggage_value}"
        )
        plt.scatter(
            change_points['update_hour'],
            change_points['price_total'],
            s=30,
            facecolors='white',
            edgecolors=colors[i % len(colors)],
            linewidths=2,
            zorder=3,
        )

        # Annotate each change point with (hours-to-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['price_total']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['price_total']),
                xytext=(0, 0),  # no offset from the point
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,
                color='gray',
                alpha=0.8,
                rotation=25,
            )
        del change_points
        del df_g

    # Auto-rotate date tick labels.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)

    # Major ticks: one per day (month-day). Minor ticks at 12:00, unlabeled.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))

    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    safe_flight = flight_numbers.replace(",", "_")
    safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_temp, save_file)

    # Save before anything could display, then free the figure.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)


def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
    """Resample crawl records to an hourly grid, forward-filling gaps.

    Args:
        df: Records for a single (departure date, baggage) slice with
            create_time, from_time and price/amount columns.
        head_fill: 1 to force the grid to start at 00:00 of the first day;
            0 (default) starts at the first observed hour.
        rear_fill: 0 ends at the last observed hour; 1 forces 23:00 of the
            last day; 2 ends at the departure hour (from_time floored).

    Returns:
        New DataFrame on an hourly grid with added columns: update_hour,
        is_filled (1 = synthesized row), hours_until_departure and
        days_to_departure. The input is not mutated.
    """
    df = df.copy()

    # 1. Parse timestamps.
    df['create_time'] = pd.to_datetime(df['create_time'])
    df['from_time'] = pd.to_datetime(df['from_time'])

    # Hour bucket used for grouping/dedup.
    df['update_hour'] = df['create_time'].dt.floor('h')

    # 2. Within each hour, order by the raw crawl timestamp.
    df = df.sort_values(['update_hour', 'create_time'])

    # 3. Keep one row per hour (the latest crawl in that hour).
    df = df.drop_duplicates(subset=['update_hour'], keep='last')  # keep='first' / keep='last'

    # 4. Mark original (non-synthesized) rows.
    df['is_filled'] = 0

    # 5. Sort and index by the hour bucket.
    df = df.sort_values('update_hour').set_index('update_hour')

    # 6. Build the full hourly axis.
    start_of_hour = df.index.min()  # default: first observed hour
    if head_fill == 1:
        start_of_hour = df.index.min().normalize()  # force day start 00:00

    end_of_hour = df.index.max()  # default: last observed hour
    if rear_fill == 1:
        # Force 23:00 on the last day.
        end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23)
    elif rear_fill == 2:
        if 'from_time' in df.columns:
            last_dep_time = df['from_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align to the departure hour (floored).
                end_of_hour = last_dep_time.floor('h')

    full_index = pd.date_range(
        start=start_of_hour,
        end=end_of_hour,
        freq='1h'
    )

    # 7. Reindex onto the hourly grid.
    df = df.reindex(full_index)

    # Restore dtypes first (important before fillna/ffill).
    df = df.infer_objects(copy=False)

    # 8. Newly created rows are flagged as filled.
    df['is_filled'] = df['is_filled'].fillna(1)

    # 9. Forward-fill values into the synthesized rows.
    df = df.ffill()

    # 10. Restore integer columns.
    int_cols = [
        'ticket_amount',
        'baggage_weight',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')

    # 10.5 Round all price columns to 2 decimal places.
    price_cols = [
        'price_base',
        'price_tax',
        'price_total'
    ]
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)

    # 10.6 Hours remaining until departure.
    if 'from_time' in df.columns:
        # Temporary column: departure time floored to the hour.
        df['from_hour'] = df['from_time'].dt.floor('h')
        # df.index is update_hour here.
        df['hours_until_departure'] = (
            (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        # Days remaining until departure.
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        df = df.drop(columns=['from_hour'])

    # 11. Materialize update_hour as a column again.
    df['update_hour'] = df.index

    # 12. Back to a plain integer index.
    df = df.reset_index(drop=True)
    return df


def process_flight_numbers(args):
    """Worker: load, hour-fill and optionally plot one flight-number group.

    Designed as a ProcessPoolExecutor target, so it opens its own Mongo
    connection and takes a single packed argument tuple:
    (process_id, db_config, city_pair, flight_numbers, from_date_begin,
     from_date_end, is_train, plot_flag, output_dir).

    Returns:
        Concatenated hourly DataFrame for every departure date, or an
        empty DataFrame on failure / no data.
    """
    (process_id, db_config, city_pair, flight_numbers,
     from_date_begin, from_date_end, is_train, plot_flag, output_dir) = args

    print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")

    # Each worker process gets its own database connection.
    try:
        client, db = mongo_con_parse(db_config)
        print(f"[进程{process_id}] ✅ 数据库连接创建成功")
    except Exception as e:
        print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
        return pd.DataFrame()

    try:
        df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers,
                                        from_date_begin, from_date_end)
        if df1.empty:
            return pd.DataFrame()

        common_dep_dates = df1['from_date'].unique()
        common_baggages = df1['baggage_weight'].unique()

        list_mid = []
        for dep_date in common_dep_dates:
            # Slice by departure date.
            df_d1 = df1[df1["from_date"] == dep_date].copy()
            list_f1 = []
            for baggage in common_baggages:
                # Slice by baggage allowance.
                df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
                if df_b1.empty:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
                    continue
                df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
                list_f1.append(df_f1)
                del df_f1
                del df_b1

            if list_f1:
                df_c1 = pd.concat(list_f1, ignore_index=True)
                if plot_flag:
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
                    plot_c1_trend(df_c1, output_dir)
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
            else:
                df_c1 = pd.DataFrame()
                if plot_flag:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
            del list_f1

            list_mid.append(df_c1)
            del df_c1
            del df_d1

        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
        del list_mid
        del df1
        gc.collect()

        print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
        return df_mid
    except Exception as e:
        print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
        return pd.DataFrame()
    finally:
        # Always close the per-process database connection.
        try:
            client.close()
            print(f"[进程{process_id}] ✅ 数据库连接已关闭")
        except Exception:
            # Best-effort close; never mask the real outcome.
            pass


def load_data(db_config, city_pair, from_date_begin, from_date_end,
              is_train=True, plot_flag=False, output_dir='.',
              use_multiprocess=False, max_workers=None):
    """Process every qualifying flight-number group for a city pair.

    Looks up groups via ``query_groups_of_city_pair`` and runs
    ``process_flight_numbers`` for each — either sequentially or in a
    ProcessPoolExecutor when ``use_multiprocess`` is set and there is more
    than one group.

    Args:
        db_config: Config passed to ``mongo_con_parse``.
        city_pair: Route key, e.g. ``"HKG-NRT"``.
        from_date_begin: Inclusive start departure date ("YYYY-MM-DD").
        from_date_end: Inclusive end departure date ("YYYY-MM-DD").
        is_train: Forwarded to workers (currently unused there).
        plot_flag: Whether workers should render trend plots.
        output_dir: Root directory for plot output.
        use_multiprocess: Enable parallel processing of groups.
        max_workers: Pool size when multiprocessing.

    Returns:
        Concatenated DataFrame over all groups (empty if nothing loaded).
    """
    list_all = []
    print(f"开始处理航线: {city_pair}")

    # Short-lived connection just to enumerate the groups.
    main_client, main_db = mongo_con_parse(db_config)
    all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
    main_client.close()

    all_groups_len = len(all_groups)
    print(f"该航线共有{all_groups_len}组航班号")

    if use_multiprocess and all_groups_len > 1:
        print(f"启用多进程处理,最大进程数: {max_workers}")
        # Parallel path: one packed argument tuple per group.
        process_args = []
        process_id = 0
        for each_group in all_groups:
            flight_numbers = each_group.get("flight_numbers", "未知")
            process_id += 1
            args = (process_id, db_config, city_pair, flight_numbers,
                    from_date_begin, from_date_end, is_train, plot_flag, output_dir)
            process_args.append(args)

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_group = {
                executor.submit(process_flight_numbers, args): each_group
                for args, each_group in zip(process_args, all_groups)
            }
            for future in as_completed(future_to_group):
                each_group = future_to_group[future]
                flight_numbers = each_group.get("flight_numbers", "未知")
                try:
                    df_mid = future.result()
                    if not df_mid.empty:
                        list_all.append(df_mid)
                        print(f"✅ 航班号:{flight_numbers} 处理完成")
                    else:
                        print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
                except Exception as e:
                    print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
    else:
        print("使用单进程处理")
        process_id = 0
        for each_group in all_groups:
            flight_numbers = each_group.get("flight_numbers", "未知")
            args = (process_id, db_config, city_pair, flight_numbers,
                    from_date_begin, from_date_end, is_train, plot_flag, output_dir)
            try:
                df_mid = process_flight_numbers(args)
                if not df_mid.empty:
                    list_all.append(df_mid)
                    print(f"✅ 航班号:{flight_numbers} 处理完成")
                else:
                    print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
            except Exception as e:
                print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")

    print(f"结束处理航线: {city_pair}")

    if list_all:
        df_all = pd.concat(list_all, ignore_index=True)
    else:
        df_all = pd.DataFrame()
    print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
    del list_all
    gc.collect()
    return df_all


if __name__ == "__main__":
    cpu_cores = os.cpu_count()
    max_workers = min(8, cpu_cores)  # cap the pool at 8 processes

    output_dir = "./photo"
    os.makedirs(output_dir, exist_ok=True)

    from_date_begin = "2026-03-17"
    from_date_end = "2026-03-26"

    # Convert "ABCXYZ" pair codes into "ABC-XYZ" route keys.
    uo_city_pairs = uo_city_pairs_new.copy()
    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]

    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
        # Default config is used per route.
        # client, db = mongo_con_parse()
        print(f"第 {idx} 组 :", uo_city_pair)
        start_time = time.time()
        load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
                  plot_flag=True, output_dir=output_dir,
                  use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")