# data_loader.py
import gc
import time
from datetime import datetime, timedelta
import pymongo
from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
import pandas as pd
import os
import random
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib.dates as mdates
from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB

font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)


def mongo_con_parse(config=None):
    if config is None:
        config = mongodb_config.copy()
    try:
        if config.get("URI", ""):
            motor_uri = config["URI"]
            client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
            db = client[config['db']]
            print("motor_uri: ", motor_uri)
        else:
            client = pymongo.MongoClient(
                config['host'],
                config['port'],
                serverSelectionTimeoutMS=15000,  # 15 seconds
                connectTimeoutMS=15000,          # 15 seconds
                socketTimeoutMS=15000,           # 15 seconds
                retryReads=True,                 # enable read retries
                maxPoolSize=50
            )
            db = client[config['db']]
            if config.get('user'):
                # Note: Database.authenticate() only exists in PyMongo 3.x;
                # with PyMongo 4+ pass username/password to MongoClient instead.
                db.authenticate(config['user'], config['pwd'])
        print(f"✅ MongoDB connection object created")
    except Exception as e:
        print(f"❌ Error while creating the MongoDB connection object: {e}")
        raise
    return client, db


def test_mongo_connection(db):
    try:
        # Get the client object
        client = db.client
        # Method 1: use server_info() to test the connection
        info = client.server_info()
        print(f"✅ MongoDB connection test succeeded!")
        print(f"   Server version: {info.get('version')}")
        print(f"   Database: {db.name}")
        return True
    except Exception as e:
        print(f"❌ Database connection test failed: {e}")
        return False
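
# Minimal usage sketch for the two helpers above (assumes `mongodb_config` in config.py
# carries either a "URI" key or host/port/db/user/pwd keys; everything else is illustrative):
#     client, db = mongo_con_parse(mongodb_config)
#     if test_mongo_connection(db):
#         ...  # run queries against db
#     client.close()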


def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums,
                              limit=0, max_retries=3, base_sleep=1.0, thread_id=0):
    """
    Query data from the given table (one of the 4 kinds) for a range of departure dates.
    Retries automatically on failure.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 Query attempt {attempt}/{max_retries}")
            # Build the query condition
            query_condition = {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": dep_date_begin,
                    "$lte": dep_date_end,
                },
                "segments.baggage": {"$in": ["1-20", "1-30"]}  # only 20 kg and 30 kg baggage allowances
            }
            # Dynamically add one flight-number condition per segment
            for i, flight_num in enumerate(flight_nums):
                query_condition[f"segments.{i}.flight_number"] = flight_num
            print(f"   Query condition: {query_condition}")
            # Fields to return
            projection = {
                # "_id": 1,
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Run the query
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection  # add the projection
            ).sort(
                [
                    ("search_dep_time", 1),  # a multi-key sort takes a list of tuples
                    ("segments.0.baggage", 1),
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            # Materialize the cursor
            results = list(cursor)
            print(f"✅ Query succeeded, found {len(results)} records")
            if results:
                df = pd.DataFrame(results)
                # Drop the ObjectId column
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 Converted to DataFrame, shape: {df.shape}")
                # 1️⃣ Expand segments
                print(f"📊 Expanding segments, please wait...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)  # use the optimized version
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"Elapsed: {rt} s")
                print(f"📊 Segments expanded into columns, shape: {df.shape}")
                # No further sorting needed: the Mongo query already sorted
                return df
            else:
                print("⚠️ Query returned no results")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo query failed: {e}")
            if attempt == max_retries:
                print("❌ Maximum retries reached, giving up")
                return pd.DataFrame()
            # Exponential backoff + random jitter
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ Retrying in {sleep_time:.2f}s...")
            time.sleep(sleep_time)
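
# For reference, with a connecting itinerary such as flight_nums = ["VJ123", "VJ456"]
# (illustrative numbers), the loop above adds positional conditions
#     {"segments.0.flight_number": "VJ123", "segments.1.flight_number": "VJ456"}
# so only documents whose segments appear in exactly that order match.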


# def expand_segments_columns(df):
#     """Expand segments (row-wise apply version, superseded by the optimized version below)."""
#     df = df.copy()
#     # Columns to expand
#     seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
#     seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
#     # apply function that returns everything in one dict
#     def extract_segments(row):
#         segments = row.get('segments')
#         result = {}
#         # Use pd.NA for missing values (string friendly)
#         missing = pd.NA
#         if isinstance(segments, list):
#             # First segment
#             if len(segments) >= 1 and isinstance(segments[0], dict):
#                 for col in seg1_cols:
#                     result[f'seg1_{col}'] = segments[0].get(col)
#             else:
#                 for col in seg1_cols:
#                     result[f'seg1_{col}'] = missing
#             # Second segment
#             if len(segments) >= 2 and isinstance(segments[1], dict):
#                 for col in seg2_cols:
#                     result[f'seg2_{col}'] = segments[1].get(col)
#             else:
#                 for col in seg2_cols:
#                     result[f'seg2_{col}'] = missing
#         else:
#             # segments is not a list: set everything to missing
#             for col in seg1_cols:
#                 result[f'seg1_{col}'] = missing
#             for col in seg2_cols:
#                 result[f'seg2_{col}'] = missing
#         return pd.Series(result)
#     # Single apply pass
#     df_segments = df.apply(extract_segments, axis=1)
#     # Concatenate back onto the original df
#     df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
#     # Convert the time columns to datetime
#     time_cols = [
#         'seg1_dep_time', 'seg1_arr_time',
#         'seg2_dep_time', 'seg2_arr_time'
#     ]
#     for col in time_cols:
#         if col in df.columns:
#             df[col] = pd.to_datetime(
#                 df[col],
#                 format='%Y%m%d%H%M%S',
#                 errors='coerce'
#             )
#     # Source website -> near-term flag
#     df['source_website'] = np.where(
#         df['source_website'].str.contains('7_30'),
#         0,  # far term -> 0
#         np.where(df['source_website'].str.contains('0_7'),
#                  1,  # near term -> 1
#                  df['source_website'])  # otherwise keep the original value
#     )
#     # Baggage allowance string -> number
#     conditions = [
#         df['seg1_baggage'] == '-;-;-;-',
#         df['seg1_baggage'] == '1-20',
#         df['seg1_baggage'] == '1-30',
#         df['seg1_baggage'] == '1-40',
#     ]
#     choices = [0, 20, 30, 40]
#     df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
#     # Rename columns
#     df = df.rename(columns={
#         'seg1_cabin': 'cabin',
#         'seg1_baggage': 'baggage',
#         'source_website': 'is_near',
#     })
#     return df


def expand_segments_columns_optimized(df):
    """Optimized segments expansion (avoids a row-wise apply)."""
    if df.empty:
        return df
    df = df.copy()
    # Work on the segments lists directly instead of applying row by row
    if 'segments' in df.columns:
        # Columns to extract from the first segment
        seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
        # Columns to extract from the second segment
        seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
        # A plain loop building lists of dicts is much faster than apply here
        seg1_data = []
        seg2_data = []
        for segments in df['segments']:
            seg1_dict = {}
            seg2_dict = {}
            if isinstance(segments, list) and len(segments) >= 1 and isinstance(segments[0], dict):
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = segments[0].get(col)
            else:
                for col in seg1_cols:
                    seg1_dict[f'seg1_{col}'] = pd.NA
            if isinstance(segments, list) and len(segments) >= 2 and isinstance(segments[1], dict):
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = segments[1].get(col)
            else:
                for col in seg2_cols:
                    seg2_dict[f'seg2_{col}'] = pd.NA
            seg1_data.append(seg1_dict)
            seg2_data.append(seg2_dict)
        # Build one DataFrame per segment
        df_seg1 = pd.DataFrame(seg1_data, index=df.index)
        df_seg2 = pd.DataFrame(seg2_data, index=df.index)
        # Merge back into the original DataFrame
        df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_seg1, df_seg2], axis=1)
    # The remaining processing matches the legacy version above
    time_cols = ['seg1_dep_time', 'seg1_arr_time', 'seg2_dep_time', 'seg2_arr_time']
    for col in time_cols:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], format='%Y%m%d%H%M%S', errors='coerce')
    # Source website -> near-term flag (far term "7_30" -> 0, near term "0_7" -> 1, otherwise keep)
    df['source_website'] = np.where(
        df['source_website'].str.contains('7_30'), 0,
        np.where(df['source_website'].str.contains('0_7'), 1, df['source_website'])
    )
    # Baggage allowance string -> number
    conditions = [
        df['seg1_baggage'] == '-;-;-;-',
        df['seg1_baggage'] == '1-20',
        df['seg1_baggage'] == '1-30',
        df['seg1_baggage'] == '1-40',
    ]
    choices = [0, 20, 30, 40]
    df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
    # Rename columns
    df = df.rename(columns={
        'seg1_cabin': 'cabin',
        'seg1_baggage': 'baggage',
        'source_website': 'is_near',
    })
    return df
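
# Illustrative self-check for the expansion above (hand-built record, not real data):
#     _demo = pd.DataFrame({
#         "segments": [[{"flight_number": "VJ123", "dep_air_port": "SGN", "dep_time": "20250101080000",
#                        "arr_air_port": "HAN", "arr_time": "20250101100000", "cabin": "Y", "baggage": "1-20"}]],
#         "source_website": ["vj_0_7"],
#     })
#     # expand_segments_columns_optimized(_demo) should yield seg1_* columns plus NA seg2_* columns,
#     # map baggage "1-20" -> 20, and set is_near == 1 for the "0_7" source.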


def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
    """Fill the data out to an hourly granularity."""
    df = df.copy()
    # 1. Convert to datetime
    df['crawl_date'] = pd.to_datetime(df['crawl_date'])
    # Add an hour-level column for grouping
    df['update_hour'] = df['crawl_date'].dt.floor('h')
    # 2. Sort rule: within the same hour, order by the original timestamp
    #    (we want to keep the earliest record)
    df = df.sort_values(['update_hour', 'crawl_date'])
    # 3. Deduplicate by hour, keeping the earliest record in each hour
    df = df.drop_duplicates(subset=['update_hour'], keep='first')
    # Drop the original timestamp column
    # df = df.drop(columns=['crawl_date'])
    # df = df.drop(columns=['_id'])
    # 4. Mark original rows
    df['is_filled'] = 0
    # 5. Sort and set the index
    df = df.sort_values('update_hour').set_index('update_hour')
    # 6. Build the full hourly axis
    start_of_day = df.index.min()  # default: start at the earliest hour of the first day
    if head_fill == 1:
        start_of_day = df.index.min().normalize()  # force start at 00:00 of the first day
    end_of_day = df.index.max()  # default: end at the latest hour of the last day
    if rear_fill == 1:
        end_of_day = df.index.max().normalize() + pd.Timedelta(hours=23)  # force end at 23:00 of the last day
    elif rear_fill == 2:
        if 'seg1_dep_time' in df.columns:
            last_dep_time = df['seg1_dep_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align to the hour (floor)
                end_of_day = last_dep_time.floor('h')
    full_index = pd.date_range(
        start=start_of_day,
        end=end_of_day,
        freq='1h'
    )
    # 7. Reindex onto the hourly axis
    df = df.reindex(full_index)
    # Restore dtypes first (important!)
    df = df.infer_objects(copy=False)
    # 8. Mark the newly added rows with 1
    df['is_filled'] = df['is_filled'].fillna(1)
    # 9. Forward fill
    df = df.ffill()
    # 10. Restore integer columns
    int_cols = [
        'seats_remaining',
        'is_near',
        'baggage',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')
    # 10.5 Round the price columns to two decimals
    price_cols = [
        'adult_price',
        'adult_tax',
        'adult_total_price'
    ]
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)
    # 10.6 New: hours until departure
    if 'seg1_dep_time' in df.columns:
        # Temporary column (on the hour)
        df['seg1_dep_hour'] = df['seg1_dep_time'].dt.floor('h')
        # Hour difference; df.index is update_hour at this point
        df['hours_until_departure'] = (
            (df['seg1_dep_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        # New: days until departure
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        # Drop the temporary column
        df = df.drop(columns=['seg1_dep_hour'])
    # 11. Write update_hour back as a column
    df['update_hour'] = df.index
    # 12. Restore a plain index
    df = df.reset_index(drop=True)
    return df
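
# Behaviour sketch for fill_hourly_crawl_date (illustrative): crawls at 08:10 and 10:40 on the same
# day collapse to one row per hour (08:00, 09:00, 10:00 with rear_fill=0); the forward-filled 09:00
# row gets is_filled == 1, and hours_until_departure / days_to_departure count down toward the
# hour-floored seg1_dep_time.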


def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=20, max_retries=3, base_sleep=1.0):
    """
    Find every group (flight-number combination + departure date) for a city pair,
    sorted by: first-segment flight number -> second-segment flight number -> departure date.
    Retries automatically on failure; only keeps combinations that depart on at least
    `min_days` days within the last two months.
    """
    print(f"{from_city}-{to_city}: searching for all groups")
    date_begin = (datetime.today() - timedelta(days=60)).strftime("%Y%m%d")
    date_end = datetime.today().strftime("%Y%m%d")
    pipeline = [
        # 1️⃣ Filter on the city pair first
        {
            "$match": {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": date_begin,
                    "$lte": date_end
                }
            }
        },
        # 2️⃣ Project fields + pull out the first/second flight numbers for sorting
        {
            "$project": {
                "flight_numbers": "$segments.flight_number",
                "search_dep_time": 1,
                "fn1": {"$arrayElemAt": ["$segments.flight_number", 0]},
                "fn2": {"$arrayElemAt": ["$segments.flight_number", 1]}
            }
        },
        # 3️⃣ First-level grouping: combination + each day
        {
            "$group": {
                "_id": {
                    "flight_numbers": "$flight_numbers",
                    "search_dep_time": "$search_dep_time",
                    "fn1": "$fn1",
                    "fn2": "$fn2"
                },
                "count": {"$sum": 1}
            }
        },
        # Key fix: sort by date here first!
        {
            "$sort": {
                "_id.fn1": 1,
                "_id.fn2": 1,
                "_id.search_dep_time": 1  # ensures the $push below receives days in ascending order
            }
        },
        # 4️⃣ Second-level grouping: aggregate by flight combination only -> count how many days it has
        {
            "$group": {
                "_id": {
                    "flight_numbers": "$_id.flight_numbers",
                    "fn1": "$_id.fn1",
                    "fn2": "$_id.fn2"
                },
                "days": {"$sum": 1},  # number of distinct departure days
                "details": {
                    "$push": {
                        "search_dep_time": "$_id.search_dep_time",
                        "count": "$count"
                    }
                }
            }
        },
        # 5️⃣ Key step: filter on the day-count threshold
        {
            "$match": {
                "days": {"$gte": min_days}
            }
        },
        # 6️⃣ ✅ Sort by first segment -> second segment
        {
            "$sort": {
                "_id.fn1": 1,
                "_id.fn2": 1,
            }
        }
    ]
    for attempt in range(1, max_retries + 1):
        try:
            print(f"   Query attempt {attempt}/{max_retries}")
            # Run the aggregation
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            # Format the results, lifting the fields out of _id
            formatted_results = []
            for item in results:
                formatted_item = {
                    "flight_numbers": item["_id"]["flight_numbers"],
                    "days": item["days"],       # how many days this combination appears on
                    "details": item["details"]  # per-day count details
                }
                formatted_results.append(formatted_item)
            return formatted_results
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo query failed: {e}")
            if attempt == max_retries:
                print("❌ Maximum retries reached, giving up")
                return []
            # Exponential backoff + random jitter
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ Retrying in {sleep_time:.2f}s...")
            time.sleep(sleep_time)
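
# Shape of the list returned above (values are illustrative):
#     [{"flight_numbers": ["VJ123", "VJ456"], "days": 45,
#       "details": [{"search_dep_time": "20250101", "count": 12}, ...]}, ...]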


def plot_c12_trend(df, output_dir="."):
    """
    Plot adult_total_price against update_hour for the given DataFrame,
    drawing one line per baggage category.
    """
    # output_dir_photo = output_dir
    # Colors and line styles (cycled through in order)
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']
    # Make sure the time column is datetime
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    from_city = df['from_city_code'].mode().iloc[0]
    to_city = df['to_city_code'].mode().iloc[0]
    flight_number_1 = df['seg1_flight_number'].mode().iloc[0]
    flight_number_2 = df['seg2_flight_number'].mode().get(0, "")
    dep_time = df['seg1_dep_time'].mode().iloc[0]
    route = f"{from_city}-{to_city}"
    flight_number = f"{flight_number_1},{flight_number_2}" if flight_number_2 else f"{flight_number_1}"
    output_dir_photo = os.path.join(output_dir, route)
    os.makedirs(output_dir_photo, exist_ok=True)
    # Create the figure
    fig = plt.figure(figsize=(14, 8))
    # Plot one series per baggage category
    for i, (baggage_value, group) in enumerate(df.groupby('baggage')):
        # Sort by time
        g = group.sort_values('update_hour').reset_index(drop=True)
        # Price change points: rows whose price differs from the previous row,
        # keeping the first row, every change row and the last row
        change_points = g.loc[
            (g['adult_total_price'] != g['adult_total_price'].shift(1)) |
            (g.index == 0) |
            (g.index == len(g) - 1)  # end point
        ].drop_duplicates(subset=['update_hour'])
        # Draw markers and line
        plt.plot(
            change_points['update_hour'],
            change_points['adult_total_price'],
            marker='o',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2, markersize=6,
            markerfacecolor='white', markeredgewidth=2,
            label=f"Baggage {baggage_value}"
        )
        # Annotate each point with (hours until departure, price)
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['adult_total_price']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['adult_total_price']),
                xytext=(0, 0),  # no offset
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,  # slightly smaller font
                color='gray',
                alpha=0.8,
                rotation=25,
            )
    # Auto-format the date labels
    plt.gcf().autofmt_xdate()
    plt.xlabel('Time', fontsize=12, fontproperties=font_prop)
    plt.ylabel('Price', fontsize=12, fontproperties=font_prop)
    plt.title(f'Price trend - route: {route} flights: {flight_number}\nDeparture time: {dep_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # Configure the x-axis ticks
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))    # one major tick per day
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # show month-day
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))  # minor tick at 12:00
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))       # empty minor labels
    # ax.tick_params(axis='x', which='minor', labelsize=8, rotation=30)
    # Add the legend
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    safe_flight = flight_number.replace(",", "_")
    safe_dep_time = dep_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{route} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_photo, save_file)
    # Save the figure (before any show)
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    # Close the figure to free memory
    plt.close(fig)


def process_flight_group(args):
    """Worker for a single flight-number group (uses its own database connection)."""
    thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
    flight_nums = each_group.get("flight_numbers")
    details = each_group.get("details")
    print(f"[thread {thread_id}] start processing flights: {flight_nums}")
    # Each thread gets its own database connection
    try:
        client, db = mongo_con_parse(db_config)
        print(f"[thread {thread_id}] ✅ database connection created")
    except Exception as e:
        print(f"[thread {thread_id}] ❌ failed to create database connection: {e}")
        return pd.DataFrame()
    try:
        # Query the far-term table
        if is_hot == 1:
            df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, flight_nums)
        else:
            df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, flight_nums)
        # Make sure the far-term table has data
        if df1.empty:
            print(f"[thread {thread_id}] flights {flight_nums}: no data in the far-term table, skipping")
            return pd.DataFrame()
        # Query the near-term table
        if is_hot == 1:
            df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, flight_nums)
        else:
            df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
                                            date_begin_s, date_end_s, flight_nums)
        # Make sure the near-term table has data
        if df2.empty:
            print(f"[thread {thread_id}] flights {flight_nums}: no data in the near-term table, skipping")
            return pd.DataFrame()
        # Departure dates and baggage allowances are driven by the near-term table
        if df2.empty:
            common_dep_dates = []
            common_baggages = []
        else:
            common_dep_dates = df2['search_dep_time'].unique()
            common_baggages = df2['baggage'].unique()
        list_mid = []
        for dep_date in common_dep_dates:
            # Filter by departure date
            df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
            if not df_d1.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_1 = df_d1[col].mode()
                    if mode_series_1.empty:
                        zong_1 = pd.NaT
                    else:
                        zong_1 = mode_series_1.iloc[0]
                    df_d1[col] = zong_1
            df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
            if not df_d2.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_2 = df_d2[col].mode()
                    if mode_series_2.empty:
                        zong_2 = pd.NaT
                    else:
                        zong_2 = mode_series_2.iloc[0]
                    df_d2[col] = zong_2
            list_12 = []
            for baggage in common_baggages:
                # Filter by baggage allowance
                df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
                df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
                # Check that at least one side has data before merging
                if df_b1.empty and df_b2.empty:
                    print(f"[thread {thread_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage}: both far-term and near-term tables are empty, skipping")
                    continue
                cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
                        "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
                df_b1[cols] = df_b1[cols].astype("string")
                df_b2[cols] = df_b2[cols].astype("string")
                df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage}: far-term and near-term tables merged, shape: {df_b12.shape}")
                df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage}: merged and filled into a complete hourly series, shape: {df_b12.shape}")
                list_12.append(df_b12)
                del df_b12
                del df_b2
                del df_b1
            if list_12:
                df_c12 = pd.concat(list_12, ignore_index=True)
                if plot_flag:
                    print(f"[thread {thread_id}] ✅ dep_date:{dep_date}: all baggage data merged, total shape: {df_c12.shape}")
                    plot_c12_trend(df_c12, output_dir)
                    print(f"[thread {thread_id}] ✅ dep_date:{dep_date}: all baggage data plotted")
            else:
                df_c12 = pd.DataFrame()
                if plot_flag:
                    print(f"[thread {thread_id}] ⚠️ dep_date:{dep_date}: merged baggage data is empty")
            del list_12
            list_mid.append(df_c12)
            del df_c12
            del df_d1
            del df_d2
            # print(f"finished processing departure date: {dep_date}")
        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[thread {thread_id}] ✅ flights {flight_nums}: all departure-date data merged, total shape: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[thread {thread_id}] ⚠️ flights {flight_nums}: merged departure-date data is empty")
        del list_mid
        del df1
        del df2
        gc.collect()
        print(f"[thread {thread_id}] finished processing flights: {flight_nums}")
        return df_mid
    except Exception as e:
        print(f"[thread {thread_id}] ❌ exception while processing flights {flight_nums}: {e}")
        return pd.DataFrame()
    finally:
        # Make sure the database connection is closed
        try:
            client.close()
            print(f"[thread {thread_id}] ✅ database connection closed")
        except Exception:
            pass


def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, plot_flag=False,
                    use_multithread=False, max_workers=None):
    """Load training data (optionally multithreaded)."""
    timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
    date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d")  # format used by the queries
    date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
    list_all = []
    # Process each route (city pair)
    for flight_route in flight_route_list:
        from_city = flight_route.split('-')[0]
        to_city = flight_route.split('-')[1]
        route = f"{from_city}-{to_city}"
        print(f"start processing route: {route}")
        # Query the flight-number groups in the main thread (avoids duplicate queries in the workers)
        main_client, main_db = mongo_con_parse(db_config)
        all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
        main_client.close()
        all_groups_len = len(all_groups)
        print(f"this route has {all_groups_len} flight-number combinations")
        if use_multithread and all_groups_len > 1:
            print(f"multithreading enabled, max workers: {max_workers}")
            # Multithreaded processing
            thread_args = []
            thread_id = 0
            for each_group in all_groups:
                thread_id += 1
                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
                thread_args.append(args)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_group = {executor.submit(process_flight_group, args): each_group
                                   for args, each_group in zip(thread_args, all_groups)}
                for future in as_completed(future_to_group):
                    each_group = future_to_group[future]
                    flight_nums = each_group.get("flight_numbers", "unknown")
                    try:
                        df_mid = future.result()
                        if not df_mid.empty:
                            list_all.append(df_mid)
                            print(f"✅ flights {flight_nums}: processed")
                        else:
                            print(f"⚠️ flights {flight_nums}: empty result")
                    except Exception as e:
                        print(f"❌ flights {flight_nums}: exception during processing: {e}")
        else:
            # Single-threaded processing (thread id 0)
            print("using single-threaded processing")
            thread_id = 0
            for each_group in all_groups:
                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
                flight_nums = each_group.get("flight_numbers", "unknown")
                try:
                    df_mid = process_flight_group(args)
                    if not df_mid.empty:
                        list_all.append(df_mid)
                        print(f"✅ flights {flight_nums}: processed")
                    else:
                        print(f"⚠️ flights {flight_nums}: empty result")
                except Exception as e:
                    print(f"❌ flights {flight_nums}: exception during processing: {e}")
        print(f"finished processing route: {from_city}-{to_city}")
    if list_all:
        df_all = pd.concat(list_all, ignore_index=True)
    else:
        df_all = pd.DataFrame()
    print(f"batch data loading finished, total shape: {df_all.shape}")
    del list_all
    gc.collect()
    return df_all


def query_all_flight_number(db, table_name):
    print(f"{table_name}: searching for all flight numbers")
    pipeline = [
        {
            "$project": {
                "flight_numbers": "$segments.flight_number"
            }
        },
        {
            "$group": {
                "_id": "$flight_numbers",
                "count": {"$sum": 1}
            }
        },
    ]
    # Run the aggregation
    collection = db[table_name]
    results = list(collection.aggregate(pipeline))
    list_flight_number = []
    for item in results:
        item_li = item.get("_id", [])
        list_flight_number.extend(item_li)
    list_flight_number = list(set(list_flight_number))
    return list_flight_number


def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour,
                      limit=0, max_retries=3, base_sleep=1.0):
    """Validate one row of the prediction results."""
    if city_pair in vj_flight_route_list_hot:
        table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
    elif city_pair in vj_flight_route_list_nothot:
        table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    else:
        print(f"city pair {city_pair} is in neither the hot nor the cold route list, returning")
        return pd.DataFrame()
    city_pair_split = city_pair.split('-')
    from_city_code = city_pair_split[0]
    to_city_code = city_pair_split[1]
    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
    baggage_str = f"1-{baggage}"
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 Query attempt {attempt}/{max_retries}")
            # Build the query condition
            query_condition = {
                "from_city_code": from_city_code,
                "to_city_code": to_city_code,
                "search_dep_time": flight_day_str,
                "segments.baggage": baggage_str,
                "crawl_date": {"$gte": valid_begin_hour},
                "segments.0.flight_number": flight_number_1,
            }
            # If there is a second segment
            if flight_number_2 != "VJ":
                query_condition["segments.1.flight_number"] = flight_number_2
            print(f"   Query condition: {query_condition}")
            # Fields to return
            projection = {
                # "_id": 1,
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Run the query
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection  # add the projection
            ).sort(
                [
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)
            # Materialize the cursor
            results = list(cursor)
            print(f"✅ Query succeeded, found {len(results)} records")
            if results:
                df = pd.DataFrame(results)
                # Drop the ObjectId column
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 Converted to DataFrame, shape: {df.shape}")
                # 1️⃣ Expand segments
                print(f"📊 Expanding segments, please wait...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"Elapsed: {rt} s")
                print(f"📊 Segments expanded into columns, shape: {df.shape}")
                # No further sorting needed: the Mongo query already sorted
                return df
            else:
                print("⚠️ Query returned no results")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo query failed: {e}")
            if attempt == max_retries:
                print("❌ Maximum retries reached, giving up")
                return pd.DataFrame()
            # Exponential backoff + random jitter
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ Retrying in {sleep_time:.2f}s...")
            time.sleep(sleep_time)
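
# Usage sketch (illustrative values; valid_begin_hour must use the same string format as crawl_date):
#     df = validate_one_line(db, "SGN-HAN", "2025-12-20", "VJ123", "VJ", 20, "2025-12-18 08:00:00")
#     # flight_number_2 == "VJ" is treated as "no second segment" by the check above.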


if __name__ == "__main__":
    # test_mongo_connection(db)
    from utils import chunk_list_with_index
    cpu_cores = os.cpu_count()  # 72 on this machine
    max_workers = min(16, cpu_cores)  # cap at 16 threads
    output_dir = f"./output"
    os.makedirs(output_dir, exist_ok=True)
    # Load hot-route data
    date_begin = "2025-12-07"
    date_end = datetime.today().strftime("%Y-%m-%d")
    flight_route_list = vj_flight_route_list_hot[0:]  # hot: vj_flight_route_list_hot, cold: vj_flight_route_list_nothot
    table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # hot: CLEAN_VJ_HOT_NEAR_INFO_TAB, cold: CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    is_hot = 1  # 1 = hot, 0 = cold
    group_size = 1
    chunks = chunk_list_with_index(flight_route_list, group_size)
    for idx, (_, group_route_list) in enumerate(chunks, 1):
        # Uses the default config
        # client, db = mongo_con_parse()
        print(f"group {idx}:", group_route_list)
        start_time = time.time()
        load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=False,
                        use_multithread=False, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"elapsed: {run_time} s")
        # client.close()
        time.sleep(3)
    print("all done")

    # client, db = mongo_con_parse()
    # list_flight_number_1 = query_all_flight_number(db, CLEAN_VJ_HOT_NEAR_INFO_TAB)
    # list_flight_number_2 = query_all_flight_number(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB)
    # list_flight_number_all = list_flight_number_1 + list_flight_number_2
    # list_flight_number_all = list(set(list_flight_number_all))
    # list_flight_number_all.sort()
    # print(list_flight_number_all)
    # print(len(list_flight_number_all))
    # flight_map = {v: i for i, v in enumerate(list_flight_number_all, start=1)}
    # print(flight_map)