| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567 |
- import os
- import time
- import random
- from datetime import datetime, timedelta
- import gc
- from concurrent.futures import ProcessPoolExecutor, as_completed
- import pymongo
- from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
- import pandas as pd
- import matplotlib.pyplot as plt
- from matplotlib import font_manager
- import matplotlib.dates as mdates
- from uo_atlas_import import mongo_con_parse
- from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new
- font_path = "./simhei.ttf"
- font_prop = font_manager.FontProperties(fname=font_path)
- def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
- """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
- print(f"{city_pair} 查找所有分组")
- date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
- date_end = datetime.today().strftime("%Y-%m-%d")
-
- # 聚合查询管道
- pipeline = [
- {
- "$match": {
- "citypair": city_pair,
- "from_date": {
- "$gte": date_begin,
- "$lte": date_end
- }
- }
- },
- {
- "$group": {
- "_id": {
- "flight_numbers": "$flight_numbers",
- "from_date": "$from_date"
- }
- }
- },
- {
- "$group": {
- "_id": "$_id.flight_numbers",
- "days": {"$sum": 1},
- "details": {"$push": "$_id.from_date"}
- }
- },
- {
- "$match": {
- "days": {"$gte": min_days}
- }
- },
- {
- "$addFields": {
- "details": {"$sortArray": {"input": "$details", "sortBy": 1}}
- }
- },
- {
- "$sort": {"_id": 1}
- }
- ]
- for attempt in range(1, max_retries + 1):
- try:
- print(f" 第 {attempt}/{max_retries} 次尝试查询")
- # 执行聚合查询
- collection = db[table_name]
- results = list(collection.aggregate(pipeline))
- # 格式化结果,使字段名更清晰
- formatted_results = [
- {
- "flight_numbers": r["_id"],
- "days": r["days"],
- "flight_dates": r["details"]
- }
- for r in results
- ]
-
- return formatted_results
- except (ServerSelectionTimeoutError, PyMongoError) as e:
- print(f"⚠️ Mongo 查询失败: {e}")
- if attempt == max_retries:
- print("❌ 达到最大重试次数,放弃")
- return []
- # 指数退避 + 随机抖动
- sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
- print(f"⏳ {sleep_time:.2f}s 后重试...")
- time.sleep(sleep_time)
- def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end,
- limit=0, max_retries=3, base_sleep=1.0):
- for attempt in range(1, max_retries + 1):
- try:
- print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
- # 构建查询条件
- projection = {
- # "_id": 0 # 一般会关掉
- "citypair": 1,
- "flight_numbers": 1,
- "from_date": 1,
- "from_time": 1,
- "create_time": 1,
- "baggage_weight": 1,
- "cabins": 1,
- "ticket_amount": 1,
- "currency": 1,
- "price_base": 1,
- "price_tax": 1,
- "price_total": 1
- }
- pipeline = [
- {
- "$match": {
- "citypair": city_pair,
- "flight_numbers": flight_numbers,
- "baggage_weight": {"$in": [0, 20]},
- "from_date": {
- "$gte": from_date_begin,
- "$lte": from_date_end
- }
- }
- },
- {
- "$project": projection # 就是这里
- },
- {
- "$sort": {
- "from_date": 1,
- "baggage_weight": 1,
- "create_time": 1
- }
- }
- ]
- # print(f" 查询条件: {pipeline}")
- # 执行查询
- collection = db[table_name]
- results = list(collection.aggregate(pipeline))
- print(f"✅ 查询成功,找到 {len(results)} 条记录")
- if results:
- df = pd.DataFrame(results)
- if '_id' in df.columns:
- df = df.drop(columns=['_id'])
-
- if 'from_time' in df.columns and 'from_date' in df.columns:
- from_time_raw = df['from_time']
- from_time_str = from_time_raw.fillna('').astype(str).str.strip()
- non_empty = from_time_str[from_time_str.ne('')] # 找到原始 from_time 非空的记录
- extracted_time = non_empty.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0].dropna()
- if not extracted_time.empty:
- more_time = extracted_time.value_counts().idxmax() # 按众数分配给其它行 构造from_time
- missing_mask = from_time_raw.isna() | from_time_str.eq('')
- if missing_mask.any():
- df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
- else:
- # 无法得到起飞日期的抛弃
- print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
- return pd.DataFrame()
- print(f"📊 已转换为 DataFrame,形状: {df.shape}")
- return df
- else:
- print("⚠️ 查询结果为空")
- return pd.DataFrame()
- except (ServerSelectionTimeoutError, PyMongoError) as e:
- print(f"⚠️ Mongo 查询失败: {e}")
- if attempt == max_retries:
- print("❌ 达到最大重试次数,放弃")
- return pd.DataFrame()
-
- # 指数退避 + 随机抖动
- sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
- print(f"⏳ {sleep_time:.2f}s 后重试...")
- time.sleep(sleep_time)
- def plot_c1_trend(df, output_dir="."):
- """
- 根据传入的 dataframe 绘制 price_total 随 update_hour 的趋势图,
- 并按照 baggage 分类进行分组绘制。
- """
- # 颜色与线型配置(按顺序循环使用)
- colors = ['green', 'blue', 'red', 'brown']
- linestyles = ['--', '--', '--', '--']
- # 确保时间字段为 datetime 类型
- if not hasattr(df['update_hour'], 'dt'):
- df['update_hour'] = pd.to_datetime(df['update_hour'])
-
- city_pair = df['citypair'].mode().iloc[0]
- flight_numbers = df['flight_numbers'].mode().iloc[0]
- from_time = df['from_time'].mode().iloc[0]
-
- output_dir_temp = os.path.join(output_dir, city_pair)
- os.makedirs(output_dir_temp, exist_ok=True)
- # 创建图表对象
- fig = plt.figure(figsize=(14, 8))
- # 按 baggage_weight 分类绘制
- for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
- # 按时间排序
- df_g = group.sort_values('update_hour').reset_index(drop=True)
- # 找价格变化点:与前一行不同的价格即为变化点
- # keep first row + change rows + last row
- change_points = df_g.loc[
- (df_g['price_total'] != df_g['price_total'].shift(1)) |
- (df_g.index == 0) |
- (df_g.index == len(df_g) - 1) # 终点
- ].drop_duplicates(subset=['update_hour'])
- # 绘制阶梯线(平缓-突变)+ 变化点
- plt.step(
- change_points['update_hour'],
- change_points['price_total'],
- where='post',
- color=colors[i % len(colors)],
- linestyle=linestyles[i % len(linestyles)],
- linewidth=2,
- label=f"Baggage {baggage_value}"
- )
-
- plt.scatter(
- change_points['update_hour'],
- change_points['price_total'],
- s=30,
- facecolors='white',
- edgecolors=colors[i % len(colors)],
- linewidths=2,
- zorder=3,
- )
- # 添加注释 (小时数, 价格)
- for _, row in change_points.iterrows():
- text = f"({row['hours_until_departure']}, {row['price_total']})"
- plt.annotate(
- text,
- xy=(row['update_hour'], row['price_total']),
- xytext=(0, 0), # 向右偏移
- textcoords="offset points",
- ha='left',
- va='center',
- fontsize=5, # 字体稍小
- color='gray',
- alpha=0.8,
- rotation=25,
- )
-
- del change_points
- del df_g
- # 自动优化日期显示
- plt.gcf().autofmt_xdate()
- plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
- plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
- plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
- fontsize=14, fontweight='bold', fontproperties=font_prop)
- # 设置 x 轴刻度为每天
- ax = plt.gca()
- ax.xaxis.set_major_locator(mdates.DayLocator(interval=1)) # 每天一个主刻度
- ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d')) # 显示月-日
- ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12])) # 指定在12:00显示副刻度
- ax.xaxis.set_minor_formatter(mdates.DateFormatter('')) # 输出空字符串
- # 添加图例
- plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
- plt.grid(True, alpha=0.3)
- plt.tight_layout()
- safe_flight = flight_numbers.replace(",", "_")
- safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
- save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
- output_path = os.path.join(output_dir_temp, save_file)
- # 保存图片(在显示之前)
- plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
- # 关闭图形释放内存
- plt.close(fig)
-
- def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
- """补齐成小时粒度数据"""
- df = df.copy()
- # 1. 转 datetime
- df['create_time'] = pd.to_datetime(df['create_time'])
- df['from_time'] = pd.to_datetime(df['from_time'])
- # 添加一个用于分组的小时字段
- df['update_hour'] = df['create_time'].dt.floor('h')
- # 2. 排序规则:同一小时内,按原始时间戳排序
- # 假设你想保留最早的一条
- df = df.sort_values(['update_hour', 'create_time'])
- # 3. 按小时去重,保留该小时内最早(最晚)的一条
- df = df.drop_duplicates(subset=['update_hour'], keep='last') # keep='first' keep='last'
- # 4. 标记原始数据
- df['is_filled'] = 0
- # 5. 排序 + 设索引
- df = df.sort_values('update_hour').set_index('update_hour')
- # 6. 构造完整小时轴
- start_of_hour = df.index.min() # 默认 第一天 最早 开始
- if head_fill == 1:
- start_of_hour = df.index.min().normalize() # 强制 第一天 00:00 开始
- end_of_hour = df.index.max() # 默认 最后一天 最晚 结束
- if rear_fill == 1:
- end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23) # 强制 最后一天 23:00 结束
- elif rear_fill == 2:
- if 'from_time' in df.columns:
- last_dep_time = df['from_time'].iloc[-1]
- if pd.notna(last_dep_time):
- # 对齐到整点小时(向下取整)
- end_of_hour = last_dep_time.floor('h')
- full_index = pd.date_range(
- start=start_of_hour,
- end=end_of_hour,
- freq='1h'
- )
- # 7. 按小时补齐
- df = df.reindex(full_index)
- # 先恢复 dtype(关键!)
- df = df.infer_objects(copy=False)
- # 8. 新增出来的行标记为 1
- df['is_filled'] = df['is_filled'].fillna(1)
- # 9. 前向填充
- df = df.ffill()
- # 10. 还原整型字段
- int_cols = [
- 'ticket_amount',
- 'baggage_weight',
- 'is_filled',
- ]
- for col in int_cols:
- if col in df.columns:
- df[col] = df[col].astype('int64')
- # 10.5 价格字段统一保留两位小数
- price_cols = [
- 'price_base',
- 'price_tax',
- 'price_total'
- ]
- for col in price_cols:
- if col in df.columns:
- df[col] = df[col].astype('float64').round(2)
-
- # 10.6 新增:距离起飞还有多少小时
- if 'from_time' in df.columns:
- # 创建临时字段(整点)
- df['from_hour'] = df['from_time'].dt.floor('h')
- # 计算小时差 df.index 此时就是 update_hour
- df['hours_until_departure'] = (
- (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
- ).astype('int64')
- # 新增:距离起飞还有多少天
- df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
- # 删除临时字段
- df = df.drop(columns=['from_hour'])
-
- # 11. 写回 update_hour
- df['update_hour'] = df.index
- # 12. 恢复普通索引
- df = df.reset_index(drop=True)
- return df
- def process_flight_numbers(args):
- process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
- print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
-
- # 为每个进程创建独立的数据库连接
- try:
- client, db = mongo_con_parse(db_config)
- print(f"[进程{process_id}] ✅ 数据库连接创建成功")
- except Exception as e:
- print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
- return pd.DataFrame()
-
- try:
- # 查询
- df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
- if df1.empty:
- return pd.DataFrame()
-
- common_dep_dates = df1['from_date'].unique()
- common_baggages = df1['baggage_weight'].unique()
- list_mid = []
- for dep_date in common_dep_dates:
- # 起飞日期筛选
- df_d1 = df1[df1["from_date"] == dep_date].copy()
- list_f1 = []
- for baggage in common_baggages:
- # 行李配额筛选
- df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
- if df_b1.empty:
- print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
- continue
- df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
- list_f1.append(df_f1)
- del df_f1
- del df_b1
- if list_f1:
- df_c1 = pd.concat(list_f1, ignore_index=True)
- if plot_flag:
- print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
- plot_c1_trend(df_c1, output_dir)
- print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
- else:
- df_c1 = pd.DataFrame()
- if plot_flag:
- print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
-
- del list_f1
- list_mid.append(df_c1)
- del df_c1
- del df_d1
-
- if list_mid:
- df_mid = pd.concat(list_mid, ignore_index=True)
- print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
- else:
- df_mid = pd.DataFrame()
- print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
-
- del list_mid
- del df1
- gc.collect()
- print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
- return df_mid
-
- except Exception as e:
- print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
- return pd.DataFrame()
- finally:
- # 确保关闭数据库连接
- try:
- client.close()
- print(f"[进程{process_id}] ✅ 数据库连接已关闭")
- except:
- pass
- def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.',
- use_multiprocess=False, max_workers=None):
- list_all = []
-
- print(f"开始处理航线: {city_pair}")
- main_client, main_db = mongo_con_parse(db_config)
- all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
- main_client.close()
-
- all_groups_len = len(all_groups)
- print(f"该航线共有{all_groups_len}组航班号")
- if use_multiprocess and all_groups_len > 1:
- print(f"启用多进程处理,最大进程数: {max_workers}")
- # 多进程处理
- process_args = []
- process_id = 0
- for each_group in all_groups:
- flight_numbers = each_group.get("flight_numbers", "未知")
- process_id += 1
- args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
- process_args.append(args)
- with ProcessPoolExecutor(max_workers=max_workers) as executor:
- future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
- for future in as_completed(future_to_group):
- each_group = future_to_group[future]
- flight_numbers = each_group.get("flight_numbers", "未知")
- try:
- df_mid = future.result()
- if not df_mid.empty:
- list_all.append(df_mid)
- print(f"✅ 航班号:{flight_numbers} 处理完成")
- else:
- print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
- except Exception as e:
- print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
- pass
- else:
- print("使用单进程处理")
- process_id = 0
- for each_group in all_groups:
- flight_numbers = each_group.get("flight_numbers", "未知")
- args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
- try:
- df_mid = process_flight_numbers(args)
- if not df_mid.empty:
- list_all.append(df_mid)
- print(f"✅ 航班号:{flight_numbers} 处理完成")
- else:
- print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
- except Exception as e:
- print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
- print(f"结束处理航线: {city_pair}")
- if list_all:
- df_all = pd.concat(list_all, ignore_index=True)
- else:
- df_all = pd.DataFrame()
-
- print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
- del list_all
- gc.collect()
- return df_all
- if __name__ == "__main__":
-
- cpu_cores = os.cpu_count() # 你的系统是72
- max_workers = min(8, cpu_cores) # 最大不超过8个进程
- output_dir = f"./photo"
- os.makedirs(output_dir, exist_ok=True)
- from_date_begin = "2026-03-17"
- from_date_end = "2026-03-26"
- uo_city_pairs = uo_city_pairs_new.copy()
- uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
- for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
- # 使用默认配置
- # client, db = mongo_con_parse()
- print(f"第 {idx} 组 :", uo_city_pair)
- start_time = time.time()
- load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
- plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
- end_time = time.time()
- run_time = round(end_time - start_time, 3)
- print(f"用时: {run_time} 秒")
|