data_loader.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. import gc
  2. import time
  3. from datetime import datetime
  4. import pymongo
  5. from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
  6. import pandas as pd
  7. import os
  8. import random
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from matplotlib import font_manager
  12. import matplotlib.dates as mdates
  13. from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
  14. CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
  15. font_path = "./simhei.ttf"
  16. font_prop = font_manager.FontProperties(fname=font_path)
  17. def mongo_con_parse(config=None):
  18. if config is None:
  19. config = mongodb_config.copy()
  20. try:
  21. if config.get("URI", ""):
  22. motor_uri = config["URI"]
  23. client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
  24. db = client[config['db']]
  25. print("motor_uri: ", motor_uri)
  26. else:
  27. client = pymongo.MongoClient(
  28. config['host'],
  29. config['port'],
  30. serverSelectionTimeoutMS=6000, # 6秒
  31. connectTimeoutMS=6000, # 6秒
  32. socketTimeoutMS=6000, # 6秒,
  33. retryReads=True, # 开启重试
  34. maxPoolSize=50
  35. )
  36. db = client[config['db']]
  37. if config.get('user'):
  38. db.authenticate(config['user'], config['pwd'])
  39. print(f"✅ MongoDB 连接对象创建成功")
  40. except Exception as e:
  41. print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
  42. raise
  43. return client, db
  44. def test_mongo_connection(db):
  45. try:
  46. # 获取客户端对象
  47. client = db.client
  48. # 方法1:使用 server_info() 测试连接
  49. info = client.server_info()
  50. print(f"✅ MongoDB 连接测试成功!")
  51. print(f" 服务器版本: {info.get('version')}")
  52. print(f" 数据库: {db.name}")
  53. return True
  54. except Exception as e:
  55. print(f"❌ 数据库连接测试失败: {e}")
  56. return False
  57. def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums,
  58. limit=0, max_retries=3, base_sleep=1.0):
  59. """
  60. 从指定表(4类)查询数据(指定起飞天的范围) (失败自动重试)
  61. """
  62. for attempt in range(1, max_retries + 1):
  63. try:
  64. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  65. # 构建查询条件
  66. query_condition = {
  67. "from_city_code": from_city,
  68. "to_city_code": to_city,
  69. "search_dep_time": {
  70. "$gte": dep_date_begin,
  71. "$lte": dep_date_end,
  72. },
  73. "segments.baggage": {"$in": ["1-20", "1-30"]} # 只查20公斤和30公斤行李的
  74. }
  75. # 动态添加航班号条件
  76. for i, flight_num in enumerate(flight_nums):
  77. query_condition[f"segments.{i}.flight_number"] = flight_num
  78. print(f" 查询条件: {query_condition}")
  79. # 定义要查询的字段
  80. projection = {
  81. # "_id": 1,
  82. "from_city_code": 1,
  83. "search_dep_time": 1,
  84. "to_city_code": 1,
  85. "currency": 1,
  86. "adult_price": 1,
  87. "adult_tax": 1,
  88. "adult_total_price": 1,
  89. "seats_remaining": 1,
  90. "segments": 1,
  91. "source_website": 1,
  92. "crawl_date": 1
  93. }
  94. # 执行查询
  95. cursor = db.get_collection(table_name).find(
  96. query_condition,
  97. projection=projection # 添加投影参数
  98. ).sort(
  99. [
  100. ("search_dep_time", 1), # 多级排序要用列表+元组的格式
  101. ("segments.0.baggage", 1),
  102. ("crawl_date", 1)
  103. ]
  104. )
  105. if limit > 0:
  106. cursor = cursor.limit(limit)
  107. # 将结果转换为列表
  108. results = list(cursor)
  109. print(f"✅ 查询成功,找到 {len(results)} 条记录")
  110. if results:
  111. df = pd.DataFrame(results)
  112. # 处理特殊的 ObjectId 类型
  113. if '_id' in df.columns:
  114. df = df.drop(columns=['_id'])
  115. print(f"📊 已转换为 DataFrame,形状: {df.shape}")
  116. # 1️⃣ 展开 segments
  117. print(f"📊 开始扩展segments 稍等...")
  118. t1 = time.time()
  119. df = expand_segments_columns(df)
  120. t2 = time.time()
  121. rt = round(t2 - t1, 3)
  122. print(f"用时: {rt} 秒")
  123. print(f"📊 已将segments扩展成字段,形状: {df.shape}")
  124. # 不用排序,因为mongo语句已经排好
  125. return df
  126. else:
  127. print("⚠️ 查询结果为空")
  128. return pd.DataFrame()
  129. except (ServerSelectionTimeoutError, PyMongoError) as e:
  130. print(f"⚠️ Mongo 查询失败: {e}")
  131. if attempt == max_retries:
  132. print("❌ 达到最大重试次数,放弃")
  133. return pd.DataFrame()
  134. # 指数退避 + 随机抖动
  135. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  136. print(f"⏳ {sleep_time:.2f}s 后重试...")
  137. time.sleep(sleep_time)
  138. def expand_segments_columns(df):
  139. """展开 segments"""
  140. df = df.copy()
  141. # 定义要展开的列
  142. seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
  143. seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
  144. # 定义 apply 函数一次返回字典
  145. def extract_segments(row):
  146. segments = row.get('segments')
  147. result = {}
  148. # 默认缺失使用 pd.NA(对字符串友好)
  149. missing = pd.NA
  150. if isinstance(segments, list):
  151. # 第一段
  152. if len(segments) >= 1 and isinstance(segments[0], dict):
  153. for col in seg1_cols:
  154. result[f'seg1_{col}'] = segments[0].get(col)
  155. else:
  156. for col in seg1_cols:
  157. result[f'seg1_{col}'] = missing
  158. # 第二段
  159. if len(segments) >= 2 and isinstance(segments[1], dict):
  160. for col in seg2_cols:
  161. result[f'seg2_{col}'] = segments[1].get(col)
  162. else:
  163. for col in seg2_cols:
  164. result[f'seg2_{col}'] = missing
  165. else:
  166. # segments 不是 list,全都置空
  167. for col in seg1_cols:
  168. result[f'seg1_{col}'] = missing
  169. for col in seg2_cols:
  170. result[f'seg2_{col}'] = missing
  171. return pd.Series(result)
  172. # 一次 apply
  173. df_segments = df.apply(extract_segments, axis=1)
  174. # 拼回原 df
  175. df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
  176. # 统一转换时间字段为 datetime
  177. time_cols = [
  178. 'seg1_dep_time', 'seg1_arr_time',
  179. 'seg2_dep_time', 'seg2_arr_time'
  180. ]
  181. for col in time_cols:
  182. if col in df.columns:
  183. df[col] = pd.to_datetime(
  184. df[col],
  185. format='%Y%m%d%H%M%S',
  186. errors='coerce'
  187. )
  188. # 站点来源 -> 是否近期
  189. df['source_website'] = np.where(
  190. df['source_website'].str.contains('7_30'),
  191. 0, # 远期 -> 0
  192. np.where(df['source_website'].str.contains('0_7'),
  193. 1, # 近期 -> 1
  194. df['source_website']) # 其他情况保持原值
  195. )
  196. # 行李配额字符 -> 数字
  197. conditions = [
  198. df['seg1_baggage'] == '-;-;-;-',
  199. df['seg1_baggage'] == '1-20',
  200. df['seg1_baggage'] == '1-30',
  201. df['seg1_baggage'] == '1-40',
  202. ]
  203. choices = [0, 20, 30, 40]
  204. df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
  205. # 重命名字段
  206. df = df.rename(columns={
  207. 'seg1_cabin': 'cabin',
  208. 'seg1_baggage': 'baggage',
  209. 'source_website': 'is_near',
  210. })
  211. return df
  212. def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
  213. """补齐成小时粒度数据"""
  214. df = df.copy()
  215. # 1. 转 datetime
  216. df['crawl_date'] = pd.to_datetime(df['crawl_date'])
  217. # 添加一个用于分组的小时字段
  218. df['update_hour'] = df['crawl_date'].dt.floor('h')
  219. # 2. 排序规则:同一小时内,按原始时间戳排序
  220. # 假设你想保留最早的一条
  221. df = df.sort_values(['update_hour', 'crawl_date'])
  222. # 3. 按小时去重,保留该小时内最早的一条
  223. df = df.drop_duplicates(subset=['update_hour'], keep='first')
  224. # 删除原始时间戳列
  225. # df = df.drop(columns=['crawl_date'])
  226. # df = df.drop(columns=['_id'])
  227. # 4. 标记原始数据
  228. df['is_filled'] = 0
  229. # 5. 排序 + 设索引
  230. df = df.sort_values('update_hour').set_index('update_hour')
  231. # 6. 构造完整小时轴
  232. start_of_day = df.index.min() # 默认 第一天 最早 开始
  233. if head_fill == 1:
  234. start_of_day = df.index.min().normalize() # 强制 第一天 00:00 开始
  235. end_of_day = df.index.max() # 默认 最后一天 最晚 结束
  236. if rear_fill == 1:
  237. end_of_day = df.index.max().normalize() + pd.Timedelta(hours=23) # 强制 最后一天 23:00 结束
  238. elif rear_fill == 2:
  239. if 'seg1_dep_time' in df.columns:
  240. last_dep_time = df['seg1_dep_time'].iloc[-1]
  241. if pd.notna(last_dep_time):
  242. # 对齐到整点小时(向下取整)
  243. end_of_day = last_dep_time.floor('h')
  244. full_index = pd.date_range(
  245. start=start_of_day,
  246. end=end_of_day,
  247. freq='1h'
  248. )
  249. # 7. 按小时补齐
  250. df = df.reindex(full_index)
  251. # 先恢复 dtype(关键!)
  252. df = df.infer_objects(copy=False)
  253. # 8. 新增出来的行标记为 1
  254. df['is_filled'] = df['is_filled'].fillna(1)
  255. # 9. 前向填充
  256. df = df.ffill()
  257. # 10. 还原整型字段
  258. int_cols = [
  259. 'seats_remaining',
  260. 'is_near',
  261. 'baggage',
  262. 'is_filled',
  263. ]
  264. for col in int_cols:
  265. if col in df.columns:
  266. df[col] = df[col].astype('int64')
  267. # 10.5 价格字段统一保留两位小数
  268. price_cols = [
  269. 'adult_price',
  270. 'adult_tax',
  271. 'adult_total_price'
  272. ]
  273. for col in price_cols:
  274. if col in df.columns:
  275. df[col] = df[col].astype('float64').round(2)
  276. # 10.6 新增:距离起飞还有多少小时
  277. if 'seg1_dep_time' in df.columns:
  278. # 创建临时字段(整点)
  279. df['seg1_dep_hour'] = df['seg1_dep_time'].dt.floor('h')
  280. # 计算小时差 df.index 此时就是 update_hour
  281. df['hours_until_departure'] = (
  282. (df['seg1_dep_hour'] - df.index) / pd.Timedelta(hours=1)
  283. ).astype('int64')
  284. # 删除临时字段
  285. df = df.drop(columns=['seg1_dep_hour'])
  286. # 11. 写回 update_hour
  287. df['update_hour'] = df.index
  288. # 12. 恢复普通索引
  289. df = df.reset_index(drop=True)
  290. return df
  291. def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=15, max_retries=3, base_sleep=1.0):
  292. """
  293. 从一组城市对中查找所有分组(航班号与起飞时间)的组合
  294. 按:第一段航班号 → 第二段航班号 → 起飞时间 排序
  295. (失败自动重试)
  296. """
  297. print(f"{from_city}-{to_city} 查找所有分组")
  298. pipeline = [
  299. # 1️⃣ 先筛选城市对
  300. {
  301. "$match": {
  302. "from_city_code": from_city,
  303. "to_city_code": to_city
  304. }
  305. },
  306. # 2️⃣ 投影字段 + 拆第一、第二段航班号用于排序
  307. {
  308. "$project": {
  309. "flight_numbers": "$segments.flight_number",
  310. "search_dep_time": 1,
  311. "fn1": {"$arrayElemAt": ["$segments.flight_number", 0]},
  312. "fn2": {"$arrayElemAt": ["$segments.flight_number", 1]}
  313. }
  314. },
  315. # 3️⃣ 第一级分组:组合 + 每一天
  316. {
  317. "$group": {
  318. "_id": {
  319. "flight_numbers": "$flight_numbers",
  320. "search_dep_time": "$search_dep_time",
  321. "fn1": "$fn1",
  322. "fn2": "$fn2"
  323. },
  324. "count": {"$sum": 1}
  325. }
  326. },
  327. # 关键修复点:这里先按【时间】排好序!
  328. {
  329. "$sort": {
  330. "_id.fn1": 1,
  331. "_id.fn2": 1,
  332. "_id.search_dep_time": 1 # 确保 push 进去时是按天递增
  333. }
  334. },
  335. # 4️⃣ 第二级分组:只按【航班组合】聚合 → 统计“有多少天”
  336. {
  337. "$group": {
  338. "_id": {
  339. "flight_numbers": "$_id.flight_numbers",
  340. "fn1": "$_id.fn1",
  341. "fn2": "$_id.fn2"
  342. },
  343. "days": {"$sum": 1}, # 不同起飞天数
  344. "details": {
  345. "$push": {
  346. "search_dep_time": "$_id.search_dep_time",
  347. "count": "$count"
  348. }
  349. }
  350. }
  351. },
  352. # 5️⃣ 关键:按“天数阈值”过滤
  353. {
  354. "$match": {
  355. "days": {"$gte": min_days}
  356. }
  357. },
  358. # 6️⃣ ✅ 按“第一段 → 第二段”排序
  359. {
  360. "$sort": {
  361. "_id.fn1": 1,
  362. "_id.fn2": 1,
  363. }
  364. }
  365. ]
  366. for attempt in range(1, max_retries + 1):
  367. try:
  368. print(f" 第 {attempt}/{max_retries} 次尝试查询")
  369. # 执行聚合查询
  370. collection = db[table_name]
  371. results = list(collection.aggregate(pipeline))
  372. # 格式化结果,将 _id 中的字段提取到外层
  373. formatted_results = []
  374. for item in results:
  375. formatted_item = {
  376. "flight_numbers": item["_id"]["flight_numbers"],
  377. "days": item["days"], # 这个组合一共有多少天
  378. "details": item["details"] # 每一天的 count 明细
  379. }
  380. formatted_results.append(formatted_item)
  381. return formatted_results
  382. except (ServerSelectionTimeoutError, PyMongoError) as e:
  383. print(f"⚠️ Mongo 查询失败: {e}")
  384. if attempt == max_retries:
  385. print("❌ 达到最大重试次数,放弃")
  386. return []
  387. # 指数退避 + 随机抖动
  388. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  389. print(f"⏳ {sleep_time:.2f}s 后重试...")
  390. time.sleep(sleep_time)
  391. def plot_c12_trend(df, output_dir="."):
  392. """
  393. 根据传入的 dataframe 绘制 adult_total_price 随 update_hour 的趋势图,
  394. 并按照 baggage 分类进行分组绘制。
  395. """
  396. # output_dir_photo = output_dir
  397. # 颜色与线型配置(按顺序循环使用)
  398. colors = ['green', 'blue', 'red', 'brown']
  399. linestyles = ['--', '--', '--', '--']
  400. # 确保时间字段为 datetime 类型
  401. if not hasattr(df['update_hour'], 'dt'):
  402. df['update_hour'] = pd.to_datetime(df['update_hour'])
  403. from_city = df['from_city_code'].mode().iloc[0]
  404. to_city = df['to_city_code'].mode().iloc[0]
  405. flight_number_1 = df['seg1_flight_number'].mode().iloc[0]
  406. flight_number_2 = df['seg2_flight_number'].mode().get(0, "")
  407. dep_time = df['seg1_dep_time'].mode().iloc[0]
  408. route = f"{from_city}-{to_city}"
  409. flight_number = f"{flight_number_1},{flight_number_2}" if flight_number_2 else f"{flight_number_1}"
  410. output_dir_photo = os.path.join(output_dir, route)
  411. os.makedirs(output_dir_photo, exist_ok=True)
  412. # 创建图表对象
  413. fig = plt.figure(figsize=(14, 8))
  414. # 按 baggage 分类绘制
  415. for i, (baggage_value, group) in enumerate(df.groupby('baggage')):
  416. # 按时间排序
  417. g = group.sort_values('update_hour').reset_index(drop=True)
  418. # 找价格变化点:与前一行不同的价格即为变化点
  419. # keep first row + change rows + last row
  420. change_points = g.loc[
  421. (g['adult_total_price'] != g['adult_total_price'].shift(1)) |
  422. (g.index == 0) |
  423. (g.index == len(g) - 1) # 终点
  424. ].drop_duplicates(subset=['update_hour'])
  425. # 绘制点和线条
  426. plt.plot(
  427. change_points['update_hour'],
  428. change_points['adult_total_price'],
  429. marker='o',
  430. color=colors[i % len(colors)],
  431. linestyle=linestyles[i % len(linestyles)],
  432. linewidth=2, markersize=6,
  433. markerfacecolor='white', markeredgewidth=2,
  434. label=f"Baggage {baggage_value}"
  435. )
  436. # 添加注释 (小时数, 价格)
  437. for _, row in change_points.iterrows():
  438. text = f"({row['hours_until_departure']}, {row['adult_total_price']})"
  439. plt.annotate(
  440. text,
  441. xy=(row['update_hour'], row['adult_total_price']),
  442. xytext=(0, 0), # 向右偏移
  443. textcoords="offset points",
  444. ha='left',
  445. va='center',
  446. fontsize=5, # 字体稍小
  447. color='gray',
  448. alpha=0.8,
  449. rotation=25,
  450. )
  451. # 自动优化日期显示
  452. plt.gcf().autofmt_xdate()
  453. plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
  454. plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
  455. plt.title(f'价格变化趋势 - 航线: {route} 航班号: {flight_number}\n起飞时间: {dep_time}',
  456. fontsize=14, fontweight='bold', fontproperties=font_prop)
  457. # 设置 x 轴刻度为每天
  458. ax = plt.gca()
  459. ax.xaxis.set_major_locator(mdates.DayLocator(interval=1)) # 每天一个主刻度
  460. ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d')) # 显示月-日
  461. ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12])) # 指定在12:00显示副刻度
  462. ax.xaxis.set_minor_formatter(mdates.DateFormatter('')) # 输出空字符串
  463. # ax.tick_params(axis='x', which='minor', labelsize=8, rotation=30)
  464. # 添加图例
  465. plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
  466. plt.grid(True, alpha=0.3)
  467. plt.tight_layout()
  468. safe_flight = flight_number.replace(",", "_")
  469. safe_dep_time = dep_time.strftime("%Y-%m-%d %H%M%S")
  470. save_file = f"{route} {safe_flight} {safe_dep_time}.png"
  471. output_path = os.path.join(output_dir_photo, save_file)
  472. # 保存图片(在显示之前)
  473. plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
  474. # 关闭图形释放内存
  475. plt.close(fig)
  476. def load_train_data(db, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1):
  477. """加载训练数据"""
  478. timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
  479. date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d") # 查询时的格式
  480. date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
  481. list_all = []
  482. # 每一航线对
  483. for flight_route in flight_route_list:
  484. from_city = flight_route.split('-')[0]
  485. to_city = flight_route.split('-')[1]
  486. route = f"{from_city}-{to_city}"
  487. print(f"开始处理航线: {route}")
  488. all_groups = query_groups_of_city_code(db, from_city, to_city, table_name)
  489. # 每一组航班号
  490. for each_group in all_groups:
  491. flight_nums = each_group.get("flight_numbers")
  492. print(f"开始处理航班号: {flight_nums}")
  493. details = each_group.get("details")
  494. # 查远期表
  495. if is_hot == 1:
  496. df1 = query_flight_range_status(db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
  497. date_begin_s, date_end_s, flight_nums)
  498. else:
  499. df1 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
  500. date_begin_s, date_end_s, flight_nums)
  501. # 保证远期表里有数据
  502. if df1.empty:
  503. print(f"航班号:{flight_nums} 远期表无数据, 跳过")
  504. continue
  505. # 查近期表
  506. if is_hot == 1:
  507. df2 = query_flight_range_status(db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
  508. date_begin_s, date_end_s, flight_nums)
  509. else:
  510. df2 = query_flight_range_status(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
  511. date_begin_s, date_end_s, flight_nums)
  512. # 保证近期表里有数据
  513. if df2.empty:
  514. print(f"航班号:{flight_nums} 近期表无数据, 跳过")
  515. continue
  516. # 起飞天数、行李配额以近期表的为主
  517. if df2.empty:
  518. common_dep_dates = []
  519. common_baggages = []
  520. else:
  521. common_dep_dates = df2['search_dep_time'].unique()
  522. common_baggages = df2['baggage'].unique()
  523. list_mid = []
  524. for dep_date in common_dep_dates:
  525. # 起飞日期筛选
  526. df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
  527. if not df_d1.empty:
  528. for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
  529. mode_series_1 = df_d1[col].mode()
  530. if mode_series_1.empty:
  531. # 如果整个列都是 NaT,则众数为空,直接赋 NaT
  532. zong_1 = pd.NaT
  533. else:
  534. zong_1 = mode_series_1.iloc[0]
  535. df_d1[col] = zong_1
  536. df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
  537. if not df_d2.empty:
  538. for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
  539. mode_series_2 = df_d2[col].mode()
  540. if mode_series_2.empty:
  541. # 如果整个列都是 NaT,则众数为空,直接赋 NaT
  542. zong_2 = pd.NaT
  543. else:
  544. zong_2 = mode_series_2.iloc[0]
  545. df_d2[col] = zong_2
  546. list_12 = []
  547. for baggage in common_baggages:
  548. # 行李配额筛选
  549. df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
  550. df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
  551. # 合并前检查是否都有数据
  552. if df_b1.empty and df_b2.empty:
  553. print(f"⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
  554. continue
  555. cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
  556. "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
  557. # df_b1 = df_b1.copy()
  558. # df_b2 = df_b2.copy()
  559. df_b1[cols] = df_b1[cols].astype("string")
  560. df_b2[cols] = df_b2[cols].astype("string")
  561. df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
  562. print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
  563. df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
  564. print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
  565. # print(df_b12.dtypes)
  566. list_12.append(df_b12)
  567. del df_b12
  568. del df_b2
  569. del df_b1
  570. if list_12:
  571. df_c12 = pd.concat(list_12, ignore_index=True)
  572. print(f"✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
  573. # plot_c12_trend(df_c12, output_dir)
  574. # print(f"✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
  575. else:
  576. df_c12 = pd.DataFrame()
  577. print(f"⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
  578. del list_12
  579. list_mid.append(df_c12)
  580. del df_c12
  581. del df_d1
  582. del df_d2
  583. print(f"结束处理起飞日期: {dep_date}")
  584. if list_mid:
  585. df_mid = pd.concat(list_mid, ignore_index=True)
  586. print(f"✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
  587. else:
  588. df_mid = pd.DataFrame()
  589. print(f"⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
  590. del list_mid
  591. list_all.append(df_mid)
  592. del df1
  593. del df2
  594. # output_path = os.path.join(output_dir, f"./{route}_{timestamp_str}.csv")
  595. # df_mid.to_csv(output_path, index=False, encoding="utf-8-sig", mode="a", header=not os.path.exists(output_path))
  596. del df_mid
  597. gc.collect()
  598. print(f"结束处理航班号: {flight_nums}")
  599. print(f"结束处理航线: {from_city}-{to_city}")
  600. df_all = pd.concat(list_all, ignore_index=True)
  601. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  602. del list_all
  603. gc.collect()
  604. return df_all
  605. def chunk_list(lst, group_size):
  606. return [lst[i:i + group_size] for i in range(0, len(lst), group_size)]
  607. if __name__ == "__main__":
  608. # test_mongo_connection(db)
  609. output_dir = f"./output"
  610. os.makedirs(output_dir, exist_ok=True)
  611. # 加载热门航线数据
  612. date_begin = "2025-11-20"
  613. date_end = datetime.today().strftime("%Y-%m-%d")
  614. flight_route_list = vj_flight_route_list_hot[0:] # 热门 vj_flight_route_list_hot 冷门 vj_flight_route_list_nothot
  615. table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB 冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  616. is_hot = 1 # 1 热门 0 冷门
  617. group_size = 1
  618. chunks = chunk_list(flight_route_list, group_size)
  619. for idx, group_route_list in enumerate(chunks, 1):
  620. # 使用默认配置
  621. client, db = mongo_con_parse()
  622. print(f"第 {idx} 组 :", group_route_list)
  623. start_time = time.time()
  624. load_train_data(db, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
  625. end_time = time.time()
  626. run_time = round(end_time - start_time, 3)
  627. print(f"用时: {run_time} 秒")
  628. client.close()
  629. time.sleep(3)
  630. print("整体结束")