# data_loader.py
  1. import os
  2. import time
  3. import random
  4. from datetime import datetime, timedelta
  5. import gc
  6. from concurrent.futures import ProcessPoolExecutor, as_completed
  7. import pymongo
  8. from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
  9. import pandas as pd
  10. import matplotlib.pyplot as plt
  11. from matplotlib import font_manager
  12. import matplotlib.dates as mdates
  13. from uo_atlas_import import mongo_con_parse
  14. from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new
# Font used so matplotlib can render the Chinese axis labels/titles below.
# NOTE(review): path is relative to the current working directory — assumes
# simhei.ttf ships next to this script; confirm at deploy time.
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)
  17. def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
  18. """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
  19. print(f"{city_pair} 查找所有分组")
  20. date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
  21. date_end = datetime.today().strftime("%Y-%m-%d")
  22. # 聚合查询管道
  23. pipeline = [
  24. {
  25. "$match": {
  26. "citypair": city_pair,
  27. "from_date": {
  28. "$gte": date_begin,
  29. "$lte": date_end
  30. }
  31. }
  32. },
  33. {
  34. "$group": {
  35. "_id": {
  36. "flight_numbers": "$flight_numbers",
  37. "from_date": "$from_date"
  38. }
  39. }
  40. },
  41. {
  42. "$group": {
  43. "_id": "$_id.flight_numbers",
  44. "days": {"$sum": 1},
  45. "details": {"$push": "$_id.from_date"}
  46. }
  47. },
  48. {
  49. "$match": {
  50. "days": {"$gte": min_days}
  51. }
  52. },
  53. {
  54. "$addFields": {
  55. "details": {"$sortArray": {"input": "$details", "sortBy": 1}}
  56. }
  57. },
  58. {
  59. "$sort": {"_id": 1}
  60. }
  61. ]
  62. for attempt in range(1, max_retries + 1):
  63. try:
  64. print(f" 第 {attempt}/{max_retries} 次尝试查询")
  65. # 执行聚合查询
  66. collection = db[table_name]
  67. results = list(collection.aggregate(pipeline))
  68. # 格式化结果,使字段名更清晰
  69. formatted_results = [
  70. {
  71. "flight_numbers": r["_id"],
  72. "days": r["days"],
  73. "flight_dates": r["details"]
  74. }
  75. for r in results
  76. ]
  77. return formatted_results
  78. except (ServerSelectionTimeoutError, PyMongoError) as e:
  79. print(f"⚠️ Mongo 查询失败: {e}")
  80. if attempt == max_retries:
  81. print("❌ 达到最大重试次数,放弃")
  82. return []
  83. # 指数退避 + 随机抖动
  84. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  85. print(f"⏳ {sleep_time:.2f}s 后重试...")
  86. time.sleep(sleep_time)
  87. def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end,
  88. limit=0, max_retries=3, base_sleep=1.0):
  89. for attempt in range(1, max_retries + 1):
  90. try:
  91. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  92. # 构建查询条件
  93. projection = {
  94. # "_id": 0 # 一般会关掉
  95. "citypair": 1,
  96. "flight_numbers": 1,
  97. "from_date": 1,
  98. "from_time": 1,
  99. "create_time": 1,
  100. "baggage_weight": 1,
  101. "cabins": 1,
  102. "ticket_amount": 1,
  103. "currency": 1,
  104. "price_base": 1,
  105. "price_tax": 1,
  106. "price_total": 1
  107. }
  108. pipeline = [
  109. {
  110. "$match": {
  111. "citypair": city_pair,
  112. "flight_numbers": flight_numbers,
  113. "baggage_weight": {"$in": [0, 20]},
  114. "from_date": {
  115. "$gte": from_date_begin,
  116. "$lte": from_date_end
  117. }
  118. }
  119. },
  120. {
  121. "$project": projection # 就是这里
  122. },
  123. {
  124. "$sort": {
  125. "from_date": 1,
  126. "baggage_weight": 1,
  127. "create_time": 1
  128. }
  129. }
  130. ]
  131. # print(f" 查询条件: {pipeline}")
  132. # 执行查询
  133. collection = db[table_name]
  134. results = list(collection.aggregate(pipeline))
  135. print(f"✅ 查询成功,找到 {len(results)} 条记录")
  136. if results:
  137. df = pd.DataFrame(results)
  138. if '_id' in df.columns:
  139. df = df.drop(columns=['_id'])
  140. if 'from_time' in df.columns and 'from_date' in df.columns:
  141. from_time_raw = df['from_time']
  142. from_time_str = from_time_raw.fillna('').astype(str).str.strip()
  143. non_empty = from_time_str[from_time_str.ne('')] # 找到原始 from_time 非空的记录
  144. extracted_time = non_empty.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0].dropna()
  145. if not extracted_time.empty:
  146. more_time = extracted_time.value_counts().idxmax() # 按众数分配给其它行 构造from_time
  147. missing_mask = from_time_raw.isna() | from_time_str.eq('')
  148. if missing_mask.any():
  149. df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
  150. else:
  151. # 无法得到起飞日期的抛弃
  152. print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
  153. return pd.DataFrame()
  154. print(f"📊 已转换为 DataFrame,形状: {df.shape}")
  155. return df
  156. else:
  157. print("⚠️ 查询结果为空")
  158. return pd.DataFrame()
  159. except (ServerSelectionTimeoutError, PyMongoError) as e:
  160. print(f"⚠️ Mongo 查询失败: {e}")
  161. if attempt == max_retries:
  162. print("❌ 达到最大重试次数,放弃")
  163. return pd.DataFrame()
  164. # 指数退避 + 随机抖动
  165. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  166. print(f"⏳ {sleep_time:.2f}s 后重试...")
  167. time.sleep(sleep_time)
def plot_c1_trend(df, output_dir="."):
    """
    Plot price_total against update_hour for one flight / departure date,
    drawing one step line per baggage_weight group, and save the figure as
    a PNG under ``output_dir/<citypair>/``.

    Expects columns: update_hour, citypair, flight_numbers, from_time,
    baggage_weight, price_total, hours_until_departure. ``from_time`` is
    assumed datetime-like (``strftime`` is called on it below) — TODO confirm
    upstream always converts it.
    """
    # Color / linestyle pools, cycled per baggage group.
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']
    # Ensure the time column is datetime typed.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    # Mode picks the (single) dominant value for the chart labels/paths.
    city_pair = df['citypair'].mode().iloc[0]
    flight_numbers = df['flight_numbers'].mode().iloc[0]
    from_time = df['from_time'].mode().iloc[0]
    output_dir_temp = os.path.join(output_dir, city_pair)
    os.makedirs(output_dir_temp, exist_ok=True)
    # Create the figure object.
    fig = plt.figure(figsize=(14, 8))
    # One trace per baggage_weight category.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
        # Sort chronologically.
        df_g = group.sort_values('update_hour').reset_index(drop=True)
        # Price change points: any row whose price differs from the previous
        # row, always keeping the first and last rows as anchors.
        change_points = df_g.loc[
            (df_g['price_total'] != df_g['price_total'].shift(1)) |
            (df_g.index == 0) |
            (df_g.index == len(df_g) - 1)  # endpoint
        ].drop_duplicates(subset=['update_hour'])
        # Step line (flat segments with jumps) plus hollow markers.
        plt.step(
            change_points['update_hour'],
            change_points['price_total'],
            where='post',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            label=f"Baggage {baggage_value}"
        )
        plt.scatter(
            change_points['update_hour'],
            change_points['price_total'],
            s=30,
            facecolors='white',
            edgecolors=colors[i % len(colors)],
            linewidths=2,
            zorder=3,
        )
        # Annotate each change point with (hours-until-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['price_total']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['price_total']),
                xytext=(0, 0),  # no offset from the point itself
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,  # keep annotations small
                color='gray',
                alpha=0.8,
                rotation=25,
            )
        del change_points
        del df_g
    # Let matplotlib rotate/space the date tick labels.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # X axis: one major tick per day, silent minor tick at noon.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))  # one major tick per day
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # show month-day
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))  # minor tick at 12:00
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))  # render empty label
    # Legend, grid, layout.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    # Build a filesystem-safe name: commas in flight numbers become underscores.
    safe_flight = flight_numbers.replace(",", "_")
    safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_temp, save_file)
    # Save the image (before any show call).
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    # Close the figure to release memory.
    plt.close(fig)
  257. def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
  258. """补齐成小时粒度数据"""
  259. df = df.copy()
  260. # 1. 转 datetime
  261. df['create_time'] = pd.to_datetime(df['create_time'])
  262. df['from_time'] = pd.to_datetime(df['from_time'])
  263. # 添加一个用于分组的小时字段
  264. df['update_hour'] = df['create_time'].dt.floor('h')
  265. # 2. 排序规则:同一小时内,按原始时间戳排序
  266. # 假设你想保留最早的一条
  267. df = df.sort_values(['update_hour', 'create_time'])
  268. # 3. 按小时去重,保留该小时内最早(最晚)的一条
  269. df = df.drop_duplicates(subset=['update_hour'], keep='last') # keep='first' keep='last'
  270. # 4. 标记原始数据
  271. df['is_filled'] = 0
  272. # 5. 排序 + 设索引
  273. df = df.sort_values('update_hour').set_index('update_hour')
  274. # 6. 构造完整小时轴
  275. start_of_hour = df.index.min() # 默认 第一天 最早 开始
  276. if head_fill == 1:
  277. start_of_hour = df.index.min().normalize() # 强制 第一天 00:00 开始
  278. end_of_hour = df.index.max() # 默认 最后一天 最晚 结束
  279. if rear_fill == 1:
  280. end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23) # 强制 最后一天 23:00 结束
  281. elif rear_fill == 2:
  282. if 'from_time' in df.columns:
  283. last_dep_time = df['from_time'].iloc[-1]
  284. if pd.notna(last_dep_time):
  285. # 对齐到整点小时(向下取整)
  286. end_of_hour = last_dep_time.floor('h')
  287. full_index = pd.date_range(
  288. start=start_of_hour,
  289. end=end_of_hour,
  290. freq='1h'
  291. )
  292. # 7. 按小时补齐
  293. df = df.reindex(full_index)
  294. # 先恢复 dtype(关键!)
  295. df = df.infer_objects(copy=False)
  296. # 8. 新增出来的行标记为 1
  297. df['is_filled'] = df['is_filled'].fillna(1)
  298. # 9. 前向填充
  299. df = df.ffill()
  300. # 10. 还原整型字段
  301. int_cols = [
  302. 'ticket_amount',
  303. 'baggage_weight',
  304. 'is_filled',
  305. ]
  306. for col in int_cols:
  307. if col in df.columns:
  308. df[col] = df[col].astype('int64')
  309. # 10.5 价格字段统一保留两位小数
  310. price_cols = [
  311. 'price_base',
  312. 'price_tax',
  313. 'price_total'
  314. ]
  315. for col in price_cols:
  316. if col in df.columns:
  317. df[col] = df[col].astype('float64').round(2)
  318. # 10.6 新增:距离起飞还有多少小时
  319. if 'from_time' in df.columns:
  320. # 创建临时字段(整点)
  321. df['from_hour'] = df['from_time'].dt.floor('h')
  322. # 计算小时差 df.index 此时就是 update_hour
  323. df['hours_until_departure'] = (
  324. (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
  325. ).astype('int64')
  326. # 新增:距离起飞还有多少天
  327. df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
  328. # 删除临时字段
  329. df = df.drop(columns=['from_hour'])
  330. # 11. 写回 update_hour
  331. df['update_hour'] = df.index
  332. # 12. 恢复普通索引
  333. df = df.reset_index(drop=True)
  334. return df
  335. def process_flight_numbers(args):
  336. process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
  337. print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
  338. # 为每个进程创建独立的数据库连接
  339. try:
  340. client, db = mongo_con_parse(db_config)
  341. print(f"[进程{process_id}] ✅ 数据库连接创建成功")
  342. except Exception as e:
  343. print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
  344. return pd.DataFrame()
  345. try:
  346. # 查询
  347. df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
  348. if df1.empty:
  349. return pd.DataFrame()
  350. common_dep_dates = df1['from_date'].unique()
  351. common_baggages = df1['baggage_weight'].unique()
  352. list_mid = []
  353. for dep_date in common_dep_dates:
  354. # 起飞日期筛选
  355. df_d1 = df1[df1["from_date"] == dep_date].copy()
  356. list_f1 = []
  357. for baggage in common_baggages:
  358. # 行李配额筛选
  359. df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
  360. if df_b1.empty:
  361. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
  362. continue
  363. df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
  364. list_f1.append(df_f1)
  365. del df_f1
  366. del df_b1
  367. if list_f1:
  368. df_c1 = pd.concat(list_f1, ignore_index=True)
  369. if plot_flag:
  370. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
  371. plot_c1_trend(df_c1, output_dir)
  372. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
  373. else:
  374. df_c1 = pd.DataFrame()
  375. if plot_flag:
  376. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
  377. del list_f1
  378. list_mid.append(df_c1)
  379. del df_c1
  380. del df_d1
  381. if list_mid:
  382. df_mid = pd.concat(list_mid, ignore_index=True)
  383. print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
  384. else:
  385. df_mid = pd.DataFrame()
  386. print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
  387. del list_mid
  388. del df1
  389. gc.collect()
  390. print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
  391. return df_mid
  392. except Exception as e:
  393. print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
  394. return pd.DataFrame()
  395. finally:
  396. # 确保关闭数据库连接
  397. try:
  398. client.close()
  399. print(f"[进程{process_id}] ✅ 数据库连接已关闭")
  400. except:
  401. pass
  402. def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.',
  403. use_multiprocess=False, max_workers=None):
  404. list_all = []
  405. print(f"开始处理航线: {city_pair}")
  406. main_client, main_db = mongo_con_parse(db_config)
  407. all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
  408. main_client.close()
  409. all_groups_len = len(all_groups)
  410. print(f"该航线共有{all_groups_len}组航班号")
  411. if use_multiprocess and all_groups_len > 1:
  412. print(f"启用多进程处理,最大进程数: {max_workers}")
  413. # 多进程处理
  414. process_args = []
  415. process_id = 0
  416. for each_group in all_groups:
  417. flight_numbers = each_group.get("flight_numbers", "未知")
  418. process_id += 1
  419. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  420. process_args.append(args)
  421. with ProcessPoolExecutor(max_workers=max_workers) as executor:
  422. future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
  423. for future in as_completed(future_to_group):
  424. each_group = future_to_group[future]
  425. flight_numbers = each_group.get("flight_numbers", "未知")
  426. try:
  427. df_mid = future.result()
  428. if not df_mid.empty:
  429. list_all.append(df_mid)
  430. print(f"✅ 航班号:{flight_numbers} 处理完成")
  431. else:
  432. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  433. except Exception as e:
  434. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  435. pass
  436. else:
  437. print("使用单进程处理")
  438. process_id = 0
  439. for each_group in all_groups:
  440. flight_numbers = each_group.get("flight_numbers", "未知")
  441. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  442. try:
  443. df_mid = process_flight_numbers(args)
  444. if not df_mid.empty:
  445. list_all.append(df_mid)
  446. print(f"✅ 航班号:{flight_numbers} 处理完成")
  447. else:
  448. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  449. except Exception as e:
  450. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  451. print(f"结束处理航线: {city_pair}")
  452. if list_all:
  453. df_all = pd.concat(list_all, ignore_index=True)
  454. else:
  455. df_all = pd.DataFrame()
  456. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  457. del list_all
  458. gc.collect()
  459. return df_all
  460. if __name__ == "__main__":
  461. cpu_cores = os.cpu_count() # 你的系统是72
  462. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  463. output_dir = f"./photo"
  464. os.makedirs(output_dir, exist_ok=True)
  465. from_date_begin = "2026-03-17"
  466. from_date_end = "2026-03-26"
  467. uo_city_pairs = uo_city_pairs_new.copy()
  468. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  469. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
  470. # 使用默认配置
  471. # client, db = mongo_con_parse()
  472. print(f"第 {idx} 组 :", uo_city_pair)
  473. start_time = time.time()
  474. load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  475. plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
  476. end_time = time.time()
  477. run_time = round(end_time - start_time, 3)
  478. print(f"用时: {run_time} 秒")