| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255 |
- import gc
- import time
- from datetime import datetime, timedelta
- import pymongo
- from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
- import pandas as pd
- import os
- import random
- import tempfile
- from concurrent.futures import ProcessPoolExecutor, as_completed
- import numpy as np
- import matplotlib.pyplot as plt
- from matplotlib import font_manager
- import matplotlib.dates as mdates
- from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
- CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
# Font used for Chinese axis labels / titles in matplotlib plots
# (the .ttf file must sit next to this script).
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)

# Process-wide cached MongoDB client/db handle, reused by
# mongo_con_parse(reuse_client=True) when called with the same config key.
_MONGO_SHARED_CLIENT = None
_MONGO_SHARED_DB = None
_MONGO_SHARED_CFG_KEY = None
def mongo_con_parse(config=None, reuse_client=False):
    """Create (or reuse) a MongoDB client and database handle.

    Args:
        config: dict with either a "URI" key or host/port/db (+ optional
            user/pwd) keys; defaults to a copy of mongodb_config.
        reuse_client: when True, return the process-wide cached client/db if
            one was already created with the same configuration key, and cache
            the newly created pair for later calls.

    Returns:
        (client, db) tuple.

    Raises:
        Re-raises any exception from MongoClient construction/authentication
        after logging it.
    """
    if config is None:
        config = mongodb_config.copy()
    global _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB, _MONGO_SHARED_CFG_KEY
    # Identity of a connection for cache purposes: URI + host/port/db/user.
    cfg_key = (
        config.get("URI", ""),
        config.get("host", ""),
        config.get("port", ""),
        config.get("db", ""),
        config.get("user", ""),
    )
    if reuse_client and _MONGO_SHARED_CLIENT is not None and _MONGO_SHARED_DB is not None and _MONGO_SHARED_CFG_KEY == cfg_key:
        return _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB

    try:
        if config.get("URI", ""):
            # Full connection string takes precedence over host/port fields.
            motor_uri = config["URI"]
            client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
            db = client[config['db']]
        else:
            client = pymongo.MongoClient(
                config['host'],
                config['port'],
                serverSelectionTimeoutMS=30000,
                connectTimeoutMS=30000,
                socketTimeoutMS=30000,
                retryReads=True,
                maxPoolSize=50
            )
            db = client[config['db']]
            if config.get('user'):
                # NOTE(review): Database.authenticate() was removed in
                # PyMongo 4.x — confirm the pinned pymongo version is < 4,
                # otherwise credentials must go into the MongoClient call/URI.
                db.authenticate(config['user'], config['pwd'])
        print(f"✅ MongoDB 连接对象创建成功")
    except Exception as e:
        print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
        raise
    if reuse_client:
        _MONGO_SHARED_CLIENT = client
        _MONGO_SHARED_DB = db
        _MONGO_SHARED_CFG_KEY = cfg_key
    return client, db
def test_mongo_connection(db):
    """Return True when the MongoDB server behind *db* answers a ping.

    Calls server_info() on the underlying client and logs the server version
    and database name; any failure is logged and reported as False.
    """
    try:
        # server_info() forces a round-trip to the server, so it doubles
        # as a connectivity check.
        info = db.client.server_info()
        print(f"✅ MongoDB 连接测试成功!")
        print(f" 服务器版本: {info.get('version')}")
        print(f" 数据库: {db.name}")
    except Exception as e:
        print(f"❌ 数据库连接测试失败: {e}")
        return False
    return True
def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums,
                              limit=0, max_retries=3, base_sleep=1.0, thread_id=0):
    """
    Query one of the four cleaned tables for a city pair over a departure-date
    range (automatically retries on Mongo errors).

    Args:
        db: pymongo Database handle.
        table_name: collection to read from.
        from_city / to_city: city codes; together with search_dep_time they
            match the compound index hinted below.
        dep_date_begin / dep_date_end: inclusive "search_dep_time" bounds
            ("%Y%m%d" strings, same format as stored).
        flight_nums: currently unused (per-flight filtering was moved to
            _filter_df_by_flight_nums); kept for interface compatibility.
        limit: when > 0, truncate the result to the first *limit* rows.
        max_retries / base_sleep: retry count and exponential-backoff base (s).
        thread_id: unused; kept for interface compatibility.

    Returns:
        DataFrame with segments expanded to seg1_*/seg2_* columns and filtered
        to baggage == 0; empty DataFrame when nothing matched or all retries
        failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Equality on the city pair + range on search_dep_time: exactly
            # the shape of the hinted compound index.
            query_condition = {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": dep_date_begin,
                    "$lte": dep_date_end,
                },
            }

            # Only rows whose expanded numeric baggage equals this are kept.
            baggage_filter = 0
            # flight_nums_filter = list(flight_nums) if flight_nums else []
            print(f" 查询条件(走索引): {query_condition}")
            projection = {
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }

            # Force the compound index; a large batch_size keeps network
            # round-trips low on big scans.
            cursor = (
                db.get_collection(table_name)
                .find(query_condition, projection=projection)
                .batch_size(5000)
                .hint('from_city_code_1_to_city_code_1_search_dep_time_1')
            )
            # Materialize the cursor.
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                # Drop Mongo's ObjectId column.
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                # Expand the nested `segments` array into flat columns.
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)  # optimized (non-apply) version
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
                if "baggage" in df.columns:
                    df = df[df["baggage"] == baggage_filter]

                # for i, flight_num in enumerate(flight_nums_filter):
                #     if flight_num is None or flight_num == "":
                #         continue
                #     col = f"seg{i + 1}_flight_number"
                #     if col not in df.columns:
                #         return pd.DataFrame()
                #     df = df[df[col].astype("string") == str(flight_num)]

                # sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df.columns]
                # if sort_cols:
                #     df = df.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)

                if limit > 0:
                    df = df.head(limit).reset_index(drop=True)

                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()

        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()

            # Exponential backoff plus random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
- # def expand_segments_columns(df):
- # """展开 segments"""
- # df = df.copy()
- # # 定义要展开的列
- # seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
- # seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
- # # 定义 apply 函数一次返回字典
- # def extract_segments(row):
- # segments = row.get('segments')
- # result = {}
- # # 默认缺失使用 pd.NA(对字符串友好)
- # missing = pd.NA
- # if isinstance(segments, list):
- # # 第一段
- # if len(segments) >= 1 and isinstance(segments[0], dict):
- # for col in seg1_cols:
- # result[f'seg1_{col}'] = segments[0].get(col)
- # else:
- # for col in seg1_cols:
- # result[f'seg1_{col}'] = missing
- # # 第二段
- # if len(segments) >= 2 and isinstance(segments[1], dict):
- # for col in seg2_cols:
- # result[f'seg2_{col}'] = segments[1].get(col)
- # else:
- # for col in seg2_cols:
- # result[f'seg2_{col}'] = missing
- # else:
- # # segments 不是 list,全都置空
- # for col in seg1_cols:
- # result[f'seg1_{col}'] = missing
- # for col in seg2_cols:
- # result[f'seg2_{col}'] = missing
- # return pd.Series(result)
- # # 一次 apply
- # df_segments = df.apply(extract_segments, axis=1)
- # # 拼回原 df
- # df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
- # # 统一转换时间字段为 datetime
- # time_cols = [
- # 'seg1_dep_time', 'seg1_arr_time',
- # 'seg2_dep_time', 'seg2_arr_time'
- # ]
- # for col in time_cols:
- # if col in df.columns:
- # df[col] = pd.to_datetime(
- # df[col],
- # format='%Y%m%d%H%M%S',
- # errors='coerce'
- # )
- # # 站点来源 -> 是否近期
- # df['source_website'] = np.where(
- # df['source_website'].str.contains('7_30'),
- # 0, # 远期 -> 0
- # np.where(df['source_website'].str.contains('0_7'),
- # 1, # 近期 -> 1
- # df['source_website']) # 其他情况保持原值
- # )
- # # 行李配额字符 -> 数字
- # conditions = [
- # df['seg1_baggage'] == '-;-;-;-',
- # df['seg1_baggage'] == '1-20',
- # df['seg1_baggage'] == '1-30',
- # df['seg1_baggage'] == '1-40',
- # ]
- # choices = [0, 20, 30, 40]
- # df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
- # # 重命名字段
- # df = df.rename(columns={
- # 'seg1_cabin': 'cabin',
- # 'seg1_baggage': 'baggage',
- # 'source_website': 'is_near',
- # })
- # return df
def expand_segments_columns_optimized(df):
    """Flatten the `segments` list column into seg1_*/seg2_* scalar columns.

    Optimized replacement for the apply-based version (kept above in
    comments): one plain Python loop over the column builds list-of-dicts,
    which is far cheaper than DataFrame.apply(axis=1).

    Post-processing on the expanded frame:
      * parses the four dep/arr time columns ("%Y%m%d%H%M%S") to datetimes,
      * maps source_website containing "7_30" -> 0 and "0_7" -> 1 (other
        values pass through unchanged),
      * maps seg1 baggage strings ('-;-;-;-', '1-20', '1-30', '1-40') to the
        numeric quota (0/20/30/40), other strings pass through,
      * renames seg1_cabin -> cabin, seg1_baggage -> baggage,
        source_website -> is_near.

    NOTE(review): `.str.contains(...)` propagates NaN for missing
    source_website values, which np.where treats as truthy — confirm the
    column is always populated upstream.
    """
    if df.empty:
        return df

    df = df.copy()
    # Operate on the raw lists directly instead of a row-wise apply.
    if 'segments' in df.columns:
        # Fields taken from the first segment (includes cabin/baggage).
        seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
        # Fields taken from the (optional) second segment.
        seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']

        # List-of-dicts accumulation replaces apply for a large speedup.
        seg1_data = []
        seg2_data = []
        for segments in df['segments']:
            seg1_dict = {}
            seg2_dict = {}
            if isinstance(segments, list) and len(segments) >= 1 and isinstance(segments[0], dict):
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = segments[0].get(col)
            else:
                # Malformed/missing first segment: fill with NA.
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = pd.NA

            if isinstance(segments, list) and len(segments) >= 2 and isinstance(segments[1], dict):
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = segments[1].get(col)
            else:
                # Direct flights have no second leg: fill with NA.
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = pd.NA

            seg1_data.append(seg1_dict)
            seg2_data.append(seg2_dict)
        # Preserve the original index so concat aligns row-for-row.
        df_seg1 = pd.DataFrame(seg1_data, index=df.index)
        df_seg2 = pd.DataFrame(seg2_data, index=df.index)
        # Replace `segments` with the flattened columns.
        df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_seg1, df_seg2], axis=1)
        # Remaining post-processing is unchanged from the legacy version.
        time_cols = ['seg1_dep_time', 'seg1_arr_time', 'seg2_dep_time', 'seg2_arr_time']
        for col in time_cols:
            if col in df.columns:
                df[col] = pd.to_datetime(df[col], format='%Y%m%d%H%M%S', errors='coerce')

        # Site name encodes the crawl window: "7_30" far-term -> 0,
        # "0_7" near-term -> 1, anything else stays as-is.
        df['source_website'] = np.where(
            df['source_website'].str.contains('7_30'), 0,
            np.where(df['source_website'].str.contains('0_7'), 1, df['source_website'])
        )
        # Baggage-quota strings -> numbers (unknown strings pass through).
        conditions = [
            df['seg1_baggage'] == '-;-;-;-',
            df['seg1_baggage'] == '1-20',
            df['seg1_baggage'] == '1-30',
            df['seg1_baggage'] == '1-40',
        ]
        choices = [0, 20, 30, 40]
        df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
        df = df.rename(columns={
            'seg1_cabin': 'cabin',
            'seg1_baggage': 'baggage',
            'source_website': 'is_near',
        })

    return df
def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
    """Resample one crawl history to a complete hourly series with forward fill.

    Args:
        df: rows for a single flight/baggage bucket; must contain `crawl_date`
            (and `seg1_dep_time` for rear_fill=2 / the countdown columns).
        head_fill: 1 -> extend the axis back to 00:00 of the first day;
                   otherwise start at the first observed hour.
        rear_fill: 1 -> extend forward to 23:00 of the last day;
                   2 -> extend forward to the departure hour (floored);
                   otherwise stop at the last observed hour.

    Returns:
        DataFrame with a plain RangeIndex, one row per hour, plus columns:
        is_filled (1 for synthesized rows), update_hour,
        hours_until_departure and days_to_departure.
    """
    df = df.copy()
    # 1. Parse timestamps and derive the hourly bucket used for grouping.
    df['crawl_date'] = pd.to_datetime(df['crawl_date'])
    df['update_hour'] = df['crawl_date'].dt.floor('h')
    # 2. Within an hour, order rows by their exact timestamp …
    df = df.sort_values(['update_hour', 'crawl_date'])
    # 3. … and keep only the latest observation of each hour.
    df = df.drop_duplicates(subset=['update_hour'], keep='last')  # keep='first' keep='last'
    # Raw timestamp/_id columns intentionally retained.
    # df = df.drop(columns=['crawl_date'])
    # df = df.drop(columns=['_id'])
    # 4. Mark the rows that were really observed.
    df['is_filled'] = 0
    # 5. Sort and index by the hour so reindex() can insert the gaps.
    df = df.sort_values('update_hour').set_index('update_hour')
    # 6. Build the complete hour axis according to head_fill/rear_fill.
    start_of_day = df.index.min()  # default: first observed hour
    if head_fill == 1:
        start_of_day = df.index.min().normalize()  # force 00:00 of the first day
    end_of_day = df.index.max()  # default: last observed hour
    if rear_fill == 1:
        end_of_day = df.index.max().normalize() + pd.Timedelta(hours=23)  # force 23:00 of the last day
    elif rear_fill == 2:
        if 'seg1_dep_time' in df.columns:
            last_dep_time = df['seg1_dep_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align the end of the axis to the departure hour (floored).
                end_of_day = last_dep_time.floor('h')
    full_index = pd.date_range(
        start=start_of_day,
        end=end_of_day,
        freq='1h'
    )
    # 7. Insert the missing hours as all-NaN rows.
    df = df.reindex(full_index)
    # Restore dtypes before fillna/ffill (avoids object-dtype fallout).
    df = df.infer_objects(copy=False)
    # 8. Rows created by the reindex get is_filled = 1.
    df['is_filled'] = df['is_filled'].fillna(1)
    # 9. Carry the last observation forward into the gaps.
    df = df.ffill()
    # 10. ffill/reindex may have upcast these to float — force back to int64.
    int_cols = [
        'seats_remaining',
        'is_near',
        'baggage',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')
    # 10.5 Normalize the price columns to two decimals.
    price_cols = [
        'adult_price',
        'adult_tax',
        'adult_total_price'
    ]
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)
    # 10.6 Countdown columns: hours/days remaining until departure.
    if 'seg1_dep_time' in df.columns:
        # Temporary column: departure time floored to the hour.
        df['seg1_dep_hour'] = df['seg1_dep_time'].dt.floor('h')
        # The index IS the update_hour at this point.
        df['hours_until_departure'] = (
            (df['seg1_dep_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        # Whole days remaining (floor division of hours).
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        # Drop the temporary column.
        df = df.drop(columns=['seg1_dep_hour'])
    # 11. Materialize the hour axis as a column again …
    df['update_hour'] = df.index
    # 12. … and return to a plain RangeIndex.
    df = df.reset_index(drop=True)
    return df
def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=10, max_retries=3, base_sleep=1.0):
    """
    Find every flight-number combination flown on a city pair.

    Ordered by: first-leg flight number -> second-leg flight number ->
    departure day. Retries automatically on Mongo errors. Only combinations
    that departed on at least *min_days* distinct days within the last month
    are returned.

    To keep the aggregation load off Mongo this runs only a light find +
    projection and does the per-day / per-combination summarisation in pandas.

    Returns:
        list of dicts:
        {"flight_numbers": [...], "days": int,
         "details": [{"search_dep_time": "...", "count": int}, ...]}
        ([] on no data or retry exhaustion).
    """
    print(f"{from_city}-{to_city} 查找所有分组")
    # Rolling one-month window of departure days ("%Y%m%d" strings).
    date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y%m%d")
    date_end = datetime.today().strftime("%Y%m%d")

    query = {
        "from_city_code": from_city,
        "to_city_code": to_city,
        "search_dep_time": {"$gte": date_begin, "$lte": date_end},
    }
    # Only the two fields the pandas side needs.
    projection = {
        "_id": 0,
        "search_dep_time": 1,
        "segments.flight_number": 1,
    }
    def _extract_flight_numbers(segments):
        # Collect the non-empty flight numbers from a raw segments list;
        # anything malformed yields [].
        if not isinstance(segments, list):
            return []
        out = []
        for seg in segments:
            if not isinstance(seg, dict):
                continue
            fn = seg.get("flight_number")
            if fn:
                out.append(fn)
        return out

    for attempt in range(1, max_retries + 1):
        try:
            print(f" 第 {attempt}/{max_retries} 次尝试查询")
            collection = db[table_name]
            # Force the compound index; big batches keep round-trips low.
            cursor = collection.find(query, projection=projection).batch_size(5000).hint('from_city_code_1_to_city_code_1_search_dep_time_1')
            docs = list(cursor)
            if not docs:
                return []

            df = pd.DataFrame.from_records(docs)
            if df.empty or "segments" not in df.columns or "search_dep_time" not in df.columns:
                return []

            # Per document: per-leg flight numbers plus a combined string key.
            df["flight_numbers"] = df["segments"].apply(_extract_flight_numbers)
            df["fn1"] = df["flight_numbers"].str[0].fillna("")
            df["fn2"] = df["flight_numbers"].str[1].fillna("")
            df["flight_numbers_key"] = df["flight_numbers"].apply(lambda xs: ",".join(xs) if xs else "")
            # Document count per (combination, departure day), stably sorted.
            day_counts = (
                df.groupby(["flight_numbers_key", "fn1", "fn2", "search_dep_time"], dropna=False)
                .size()
                .reset_index(name="count")
                .sort_values(["fn1", "fn2", "search_dep_time"], kind="mergesort")
                .reset_index(drop=True)
            )
            keys = ["flight_numbers_key", "fn1", "fn2"]
            # Distinct departure days per combination …
            df_days = day_counts.groupby(keys, sort=False).size().reset_index(name="days")
            # … plus the per-day detail records for each combination.
            df_details = (
                day_counts.groupby(keys, sort=False)
                .apply(lambda g: g[["search_dep_time", "count"]].to_dict("records"))
                .reset_index(name="details")
            )
            df_result = df_days.merge(df_details, on=keys, how="inner")
            # Keep combinations flown on >= min_days distinct days.
            df_result = df_result[df_result["days"] >= min_days].sort_values(["fn1", "fn2"], kind="mergesort")
            formatted_results = []
            for _, row in df_result.iterrows():
                flight_numbers = row["flight_numbers_key"].split(",") if row["flight_numbers_key"] else []
                formatted_results.append(
                    {
                        "flight_numbers": flight_numbers,
                        "days": int(row["days"]),
                        "details": row["details"],
                    }
                )
            # Release the intermediate frames before returning.
            del df_result
            del df_details
            del df_days
            del df
            # gc.collect()
            return formatted_results
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return []
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
def plot_c12_trend(df, output_dir="."):
    """
    Plot adult_total_price over update_hour for one flight/departure as a step
    chart (one series per baggage class, change points marked and annotated)
    and save it as a PNG under <output_dir>/<route>/.

    Expects the columns produced by the expand/fill pipeline: update_hour,
    adult_total_price, baggage, from_city_code, to_city_code,
    seg1_flight_number, seg2_flight_number, seg1_dep_time,
    hours_until_departure, seats_remaining.
    """
    # output_dir_photo = output_dir
    # Colors / line styles cycled per baggage class.
    colors = ['blue', 'red', 'brown']
    linestyles = ['--', '--', '--']
    # Make sure the time axis really is datetime.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    # Identify the flight from the most frequent value of each id column.
    from_city = df['from_city_code'].mode().iloc[0]
    to_city = df['to_city_code'].mode().iloc[0]
    flight_number_1 = df['seg1_flight_number'].mode().iloc[0]
    # .get(0, "") tolerates an all-NA second leg (empty mode Series).
    flight_number_2 = df['seg2_flight_number'].mode().get(0, "")
    dep_time = df['seg1_dep_time'].mode().iloc[0]
    route = f"{from_city}-{to_city}"
    flight_number = f"{flight_number_1},{flight_number_2}" if flight_number_2 else f"{flight_number_1}"
    output_dir_photo = os.path.join(output_dir, route)
    os.makedirs(output_dir_photo, exist_ok=True)
    # Fresh figure per call (closed at the end).
    fig = plt.figure(figsize=(14, 8))
    # One series per baggage class.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage')):
        # Chronological order within the class.
        g = group.sort_values('update_hour').reset_index(drop=True)
        # Change points: rows whose price differs from the previous row,
        # always keeping the first and last row.
        change_points = g.loc[
            (g['adult_total_price'] != g['adult_total_price'].shift(1)) |
            (g.index == 0) |
            (g.index == len(g) - 1)  # endpoint
        ].drop_duplicates(subset=['update_hour'])
        # Step line: price is constant within an interval and jumps at the
        # next change point.
        plt.step(
            change_points['update_hour'],
            change_points['adult_total_price'],
            where='post',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            label=f"Baggage {baggage_value}"
        )
        # Overlay markers on the change points only.
        plt.plot(
            change_points['update_hour'],
            change_points['adult_total_price'],
            linestyle='None',
            marker='o',
            color=colors[i % len(colors)],
            markersize=6,
            markerfacecolor='white',
            markeredgewidth=2,
        )

        # Annotations: (hours-to-departure, price, seats-remaining).
        # When points are dense, label only every step-th point to limit overlap.
        n_points = len(change_points)
        max_labels = 30
        step = max(1, int(np.ceil(n_points / max_labels)))
        label_points = change_points.iloc[::step].copy()
        # The final point is always labeled.
        if n_points > 0 and label_points.index[-1] != change_points.index[-1]:
            label_points = pd.concat([label_points, change_points.tail(1)])
        rotation_angle = 45 if n_points > max_labels else 25
        label_fontsize = 4 if n_points > max_labels else 5
        for _, row in label_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['adult_total_price']}, {row['seats_remaining']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['adult_total_price']),
                xytext=(0, 0),  # anchored directly on the point
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=label_fontsize,
                color='gray',
                alpha=0.8,
                rotation=rotation_angle,
            )
    # Auto-rotate/position the date tick labels.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {route} 航班号: {flight_number}\n起飞时间: {dep_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # X axis: one major tick per day, minor tick at noon (unlabeled).
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))  # daily major ticks
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # month-day labels
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))  # minor tick at 12:00
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))  # minor ticks unlabeled
    # ax.tick_params(axis='x', which='minor', labelsize=8, rotation=30)
    # Legend outside the axes on the right.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    safe_flight = flight_number.replace(",", "_")
    safe_dep_time = dep_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{route} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_photo, save_file)
    # Save before any show() call could clear the figure.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    # Close the figure to release memory (this runs in worker processes).
    plt.close(fig)
# Per-process caches of the current route's far-term (DF1) and near-term (DF2)
# tables; populated once per worker by the pool initializer below, or set
# directly by load_train_data in the single-process path.
_ROUTE_CACHE_DF1 = None
_ROUTE_CACHE_DF2 = None
def _init_route_cache_worker(df1_pickle_path, df2_pickle_path):
    """ProcessPoolExecutor initializer: load both route DataFrames once per worker."""
    global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
    _ROUTE_CACHE_DF1 = pd.read_pickle(df1_pickle_path)
    _ROUTE_CACHE_DF2 = pd.read_pickle(df2_pickle_path)
- def _filter_df_by_flight_nums(df, flight_nums):
- if df is None or df.empty:
- return pd.DataFrame()
-
- out = df
- flight_nums_filter = list(flight_nums) if flight_nums else []
- for i, flight_num in enumerate(flight_nums_filter):
- if flight_num is None or flight_num == "":
- continue
- col = f"seg{i + 1}_flight_number"
- if col not in out.columns:
- return out.iloc[0:0].copy()
- out = out[out[col].astype("string") == str(flight_num)]
- if out.empty:
- return out
-
- return out
def process_flight_group(args):
    """Worker entry: process one flight-number group from the per-process caches.

    Filters the cached far-term (_ROUTE_CACHE_DF1) and near-term
    (_ROUTE_CACHE_DF2) frames down to this group's flight numbers, then for
    every departure date and baggage class merges the two tables, fills the
    result to a complete hourly series, and optionally plots the price trend.

    Args:
        args: (process_id, each_group, is_train, plot_flag, output_dir);
              each_group is a dict produced by query_groups_of_city_code().

    Returns:
        Concatenated hourly DataFrame for the whole group; empty DataFrame
        when either table has no rows for the group or on any exception.
    """
    process_id, each_group, is_train, plot_flag, output_dir = args
    flight_nums = each_group.get("flight_numbers")
    # details = each_group.get("details")
    print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")
    try:
        df1 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF1, flight_nums)
        df2 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF2, flight_nums)
        # Stable (mergesort) ordering so downstream dedup/ffill is deterministic.
        sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df1.columns]
        if sort_cols:
            df1 = df1.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
            df2 = df2.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
        if df1.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
            return pd.DataFrame()

        if df2.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
            return pd.DataFrame()

        # Departure dates and baggage classes are driven by the near-term table.
        # NOTE(review): df2 cannot be empty here (early return above), so the
        # first branch is dead code kept for safety.
        if df2.empty:
            common_dep_dates = []
            common_baggages = []
        else:
            common_dep_dates = df2['search_dep_time'].unique()
            common_baggages = df2['baggage'].unique()
        # For prediction, departure dates come from the far-term table instead.
        if not is_train:
            common_dep_dates = df1['search_dep_time'].unique()
        list_mid = []
        for dep_date in common_dep_dates:
            # Restrict both tables to this departure date.
            df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
            if not df_d1.empty:
                # Normalize each time column to its per-day modal value
                # (crawl noise can produce slightly differing timestamps).
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_1 = df_d1[col].mode()
                    if mode_series_1.empty:
                        zong_1 = pd.NaT
                    else:
                        zong_1 = mode_series_1.iloc[0]
                    df_d1[col] = zong_1
            df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
            if not df_d2.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_2 = df_d2[col].mode()
                    if mode_series_2.empty:
                        zong_2 = pd.NaT
                    else:
                        zong_2 = mode_series_2.iloc[0]
                    df_d2[col] = zong_2
            list_12 = []
            for baggage in common_baggages:
                # Restrict to this baggage class.
                df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
                df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
                # Skip when neither table has rows for this combination.
                if df_b1.empty and df_b2.empty:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
                    continue
                # Unify the id-like columns to pandas "string" before concat
                # so the merged column dtypes stay consistent.
                cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
                        "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
                df_b1[cols] = df_b1[cols].astype("string")
                df_b2[cols] = df_b2[cols].astype("string")
                df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
                # Fill to a complete hourly series up to the departure hour.
                df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
                list_12.append(df_b12)
                del df_b12
                del df_b2
                del df_b1
            if list_12:
                df_c12 = pd.concat(list_12, ignore_index=True)
                if plot_flag:
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
                    plot_c12_trend(df_c12, output_dir)
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
            else:
                df_c12 = pd.DataFrame()
                if plot_flag:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
            del list_12
            list_mid.append(df_c12)
            del df_c12
            del df_d1
            del df_d2
            # print(f"结束处理起飞日期: {dep_date}")
        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[进程{process_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[进程{process_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")

        # Free the large intermediates promptly (worker processes are reused).
        del list_mid
        del df1
        del df2
        gc.collect()
        print(f"[进程{process_id}] 结束处理航班号: {flight_nums}")
        return df_mid

    except Exception as e:
        print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
        return pd.DataFrame()
    finally:
        pass
def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, is_train=True, plot_flag=False,
                    use_multiprocess=False, max_workers=None):
    """Load training data for every route in *flight_route_list* (optionally in parallel).

    For each route the flight-number groups are queried once in the main
    process, the far/near tables are fetched and cached to temporary pickle
    files, and each group is processed either by a ProcessPoolExecutor (each
    worker loads the pickles once via its initializer) or sequentially
    in-process.

    Args:
        db_config: MongoDB settings forwarded to mongo_con_parse().
        flight_route_list: iterable of "FROM-TO" city-pair strings.
        table_name: collection used to enumerate flight-number groups.
        date_begin / date_end: inclusive departure-date window, "%Y-%m-%d".
        output_dir: directory for temp pickles and (optional) trend plots.
        is_hot: 1 -> hot-route tables, otherwise not-hot tables.
        is_train: forwarded to process_flight_group (training vs prediction).
        plot_flag: draw a trend plot per group when True.
        use_multiprocess / max_workers: parallel-processing switches.

    Returns:
        One DataFrame concatenating all routes/groups (empty when nothing matched).
    """
    timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
    date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d")  # query-side date format
    date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
    list_all = []
    # Process every route pair in turn.
    for flight_route in flight_route_list:
        from_city = flight_route.split('-')[0]
        to_city = flight_route.split('-')[1]
        route = f"{from_city}-{to_city}"
        print(f"开始处理航线: {route}")
        # Query the flight-number groups once in the main process
        # (avoids each worker re-running the same aggregation).
        main_client, main_db = mongo_con_parse(db_config, reuse_client=True)
        all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)

        all_groups_len = len(all_groups)
        print(f"该航线共有{all_groups_len}个航班号")
        if all_groups_len == 0:
            continue
        # Far-term table for this route.
        if is_hot == 1:
            df1 = query_flight_range_status(main_db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, None)
        else:
            df1 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, None)

        # BUGFIX: this used to `return pd.DataFrame()`, aborting the whole
        # multi-route batch and discarding results already collected from
        # earlier routes; the log message says "跳过" (skip), so only this
        # route is skipped now.
        if df1.empty:
            print(f"[主进程] 航线:{route} 远期表无数据, 跳过")
            continue

        # Near-term table for this route.
        if is_hot == 1:
            df2 = query_flight_range_status(main_db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, None)
        else:
            df2 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, None)

        # BUGFIX: same as above — skip the route instead of returning.
        if df2.empty:
            print(f"[主进程] 航线:{route} 近期表无数据, 跳过")
            continue

        # main_client.close()

        # Cache both tables to temp pickles so worker processes can load
        # them once each (cheaper than pickling them into every task).
        os.makedirs(output_dir, exist_ok=True)
        safe_route = route.replace("-", "_")
        df1_fd, df1_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_far_", suffix=".pkl", dir=output_dir)
        df2_fd, df2_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_near_", suffix=".pkl", dir=output_dir)
        os.close(df1_fd)
        os.close(df2_fd)
        df1.to_pickle(df1_cache_path)
        df2.to_pickle(df2_cache_path)
        try:
            if use_multiprocess and all_groups_len > 1:
                print(f"启用多进程处理,最大进程数: {max_workers}")
                process_args = []
                process_id = 0
                for each_group in all_groups:
                    process_id += 1
                    args = (process_id, each_group, is_train, plot_flag, output_dir)
                    process_args.append(args)

                with ProcessPoolExecutor(
                    max_workers=max_workers,
                    initializer=_init_route_cache_worker,
                    initargs=(df1_cache_path, df2_cache_path),
                ) as executor:
                    future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}

                    for future in as_completed(future_to_group):
                        each_group = future_to_group[future]
                        flight_nums = each_group.get("flight_numbers", "未知")
                        try:
                            df_mid = future.result()
                            if not df_mid.empty:
                                list_all.append(df_mid)
                                print(f"✅ 航班号:{flight_nums} 处理完成")
                            else:
                                print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
                        except Exception as e:
                            print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
            else:
                # Single-process path (process id 0): reuse the in-memory
                # frames directly via the module-level cache.
                print("使用单进程处理")
                global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
                _ROUTE_CACHE_DF1 = df1
                _ROUTE_CACHE_DF2 = df2
                process_id = 0
                for each_group in all_groups:
                    args = (process_id, each_group, is_train, plot_flag, output_dir)
                    flight_nums = each_group.get("flight_numbers", "未知")
                    try:
                        df_mid = process_flight_group(args)
                        if not df_mid.empty:
                            list_all.append(df_mid)
                            print(f"✅ 航班号:{flight_nums} 处理完成")
                        else:
                            print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
                    except Exception as e:
                        print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
        finally:
            # Always remove the temp pickles, even when processing raised.
            try:
                os.remove(df1_cache_path)
            except Exception:
                pass
            try:
                os.remove(df2_cache_path)
            except Exception:
                pass

        print(f"结束处理航线: {from_city}-{to_city}")
    if list_all:
        df_all = pd.concat(list_all, ignore_index=True)
    else:
        df_all = pd.DataFrame()
    print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
    del list_all
    gc.collect()
    return df_all
def query_all_flight_number(db, table_name):
    """Return every distinct flight number occurring in *table_name*'s segments.

    Aggregates identical flight-number sequences server-side (so the cursor
    stays small), then flattens and de-duplicates the numbers client-side.
    The result list is unordered.
    """
    print(f"{table_name} 查找所有航班号")
    pipeline = [
        # Keep only the per-segment flight numbers of each document.
        {"$project": {"flight_numbers": "$segments.flight_number"}},
        # One group per distinct flight-number sequence.
        {"$group": {"_id": "$flight_numbers", "count": {"$sum": 1}}},
    ]
    unique_numbers = set()
    for doc in db[table_name].aggregate(pipeline):
        unique_numbers.update(doc.get("_id", []))
    return list(unique_numbers)
def validate_one_line(db, table_name, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour,
                      limit=0, max_retries=3, base_sleep=1.0):
    """Fetch the crawl history needed to validate one prediction row.

    Queries *table_name* for the given city pair / departure day / baggage
    class / flight number(s), restricted to crawl_date >= valid_begin_hour,
    sorted by crawl_date ascending, and returns the segment-expanded frame.

    Args:
        db: pymongo Database handle.
        table_name: collection to read (chosen by the caller).
        city_pair: "FROM-TO" city-pair string.
        flight_day: departure day, "%Y-%m-%d".
        flight_number_1: first-leg flight number.
        flight_number_2: second-leg flight number; the literal "VJ" means
            "no second leg" and disables the segments.1 filter.
        baggage: numeric quota; 0 maps to the raw "-;-;-;-" string,
            otherwise "1-<baggage>".
        valid_begin_hour: inclusive lower bound on crawl_date.
        limit: when > 0, cap the number of returned documents.
        max_retries / base_sleep: retry count and exponential-backoff base (s).

    Returns:
        Expanded DataFrame (already sorted by Mongo); empty DataFrame on no
        data or retry exhaustion.
    """

    # Table selection now happens in the caller; this block used to derive
    # table_name from the hot/not-hot route lists.
    # if city_pair in vj_flight_route_list_hot:
    #     table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
    # elif city_pair in vj_flight_route_list_nothot:
    #     table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    # else:
    #     print(f"城市对{city_pair}不在热门航线与冷门航线, 返回")
    #     return pd.DataFrame()

    city_pair_split = city_pair.split('-')
    from_city_code = city_pair_split[0]
    to_city_code = city_pair_split[1]
    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
    # Map the numeric quota back to the raw baggage string stored in segments.
    baggage_str = f"1-{baggage}"
    if baggage == 0:
        baggage_str = "-;-;-;-"
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Exact-match condition on route, day, baggage, crawl window and
            # the first-leg flight number (positional segments.0 match).
            query_condition = {
                "from_city_code": from_city_code,
                "to_city_code": to_city_code,
                "search_dep_time": flight_day_str,
                "segments.baggage": baggage_str,
                "crawl_date": {"$gte": valid_begin_hour},
                "segments.0.flight_number": flight_number_1,
            }
            # Add the second-leg constraint only for connecting flights.
            if flight_number_2 != "VJ":
                query_condition["segments.1.flight_number"] = flight_number_2
            print(f" 查询条件: {query_condition}")
            # Fields to pull back.
            projection = {
                # "_id": 1,
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Run the query, oldest crawl first.
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection  # projection keeps documents small
            ).sort(
                [
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            # Materialize the cursor.
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                # Drop Mongo's ObjectId column.
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                # Expand the nested `segments` array into flat columns.
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
                # No client-side sort needed: Mongo already ordered by crawl_date.
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()

            # Exponential backoff plus random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
def validate_keep_one_line(db, table_name, city_pair, flight_day, flight_number_1, flight_number_2, baggage, update_hour_str, del_batch_std_str,
                           limit=0, max_retries=3, base_sleep=1.0):
    """Fetch one keep-info validation slice of crawl records as a DataFrame.

    Same query shape as :func:`validate_one_line`, but restricts
    ``crawl_date`` to the half-open window
    ``[update_hour_str, del_batch_std_str)``.

    Args:
        db: MongoDB database handle.
        table_name: Collection name to query.
        city_pair: "FROM-TO" city-code pair, e.g. "SGN-HAN".
        flight_day: Departure day as "YYYY-MM-DD".
        flight_number_1: First-segment flight number.
        flight_number_2: Second-segment flight number; the sentinel "VJ"
            means the itinerary has no second segment.
        baggage: Checked-baggage weight; 0 means no checked baggage.
        update_hour_str: Inclusive lower bound for ``crawl_date``.
        del_batch_std_str: Exclusive upper bound for ``crawl_date``.
        limit: Maximum documents to return; 0 means unlimited.
        max_retries: Number of Mongo query attempts before giving up.
        base_sleep: Base seconds for the exponential backoff between retries.

    Returns:
        pandas.DataFrame with expanded segment columns; empty DataFrame when
        nothing matched or every retry failed.
    """
    city_pair_split = city_pair.split('-')
    from_city_code = city_pair_split[0]
    to_city_code = city_pair_split[1]
    # Normalize "YYYY-MM-DD" -> "YYYYMMDD" to match the stored search_dep_time.
    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
    # Zero baggage is stored with a placeholder pattern rather than "1-0".
    baggage_str = f"1-{baggage}"
    if baggage == 0:
        baggage_str = "-;-;-;-"
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Build the query filter.
            query_condition = {
                "from_city_code": from_city_code,
                "to_city_code": to_city_code,
                "search_dep_time": flight_day_str,
                "segments.baggage": baggage_str,
                "crawl_date": {"$gte": update_hour_str, "$lt": del_batch_std_str},
                "segments.0.flight_number": flight_number_1,
            }
            # "VJ" marks a direct flight; otherwise constrain the second leg too.
            if flight_number_2 != "VJ":
                query_condition["segments.1.flight_number"] = flight_number_2
            print(f" 查询条件: {query_condition}")
            # Restrict returned fields to what downstream processing needs.
            projection = {
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Execute the query, pre-sorted so callers need not re-sort.
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection
            ).sort(
                [
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if not results:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
            df = pd.DataFrame(results)
            # ObjectId is of no use downstream; drop it.
            if '_id' in df.columns:
                df = df.drop(columns=['_id'])
            print(f"📊 已转换为 DataFrame,形状: {df.shape}")
            # Expand the nested segments array into flat columns.
            print(f"📊 开始扩展segments 稍等...")
            t1 = time.time()
            df = expand_segments_columns_optimized(df)
            rt = round(time.time() - t1, 3)
            print(f"用时: {rt} 秒")
            print(f"📊 已将segments扩展成字段,形状: {df.shape}")
            # Already sorted by the Mongo query; return as-is.
            return df
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff plus random jitter to avoid retry storms.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
    # Defensive: only reachable when max_retries < 1 — never return None.
    return pd.DataFrame()
if __name__ == "__main__":
    from utils import chunk_list_with_index

    # Cap the worker pool at 8 processes regardless of available cores.
    max_workers = min(8, os.cpu_count())

    output_dir = "./photo_0"
    os.makedirs(output_dir, exist_ok=True)

    # Training window: fixed start date through today.
    date_begin = "2026-04-21"
    date_end = datetime.today().strftime("%Y-%m-%d")

    # Hot-route configuration; for the cold-route run switch to
    # vj_flight_route_list_nothot / CLEAN_VJ_NOTHOT_NEAR_INFO_TAB / is_hot=0.
    flight_route_list = list(vj_flight_route_list_hot)
    table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
    is_hot = 1

    # Process routes one batch at a time (group_size=1 → one route per batch).
    group_size = 1
    chunks = chunk_list_with_index(flight_route_list, group_size)
    for idx, (_, group_route_list) in enumerate(chunks, 1):
        print(f"第 {idx} 组 :", group_route_list)
        start_time = time.time()
        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
                        plot_flag=True, use_multiprocess=True, max_workers=max_workers)
        run_time = round(time.time() - start_time, 3)
        print(f"用时: {run_time} 秒")
        # Short pause between batches before opening the next connection.
        time.sleep(3)
    print("整体结束")
-
|