# data_loader.py
import gc
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import pymongo
from matplotlib import font_manager
from pymongo.errors import PyMongoError, ServerSelectionTimeoutError

from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new
from uo_atlas_import import mongo_con_parse
  15. font_path = "./simhei.ttf"
  16. font_prop = font_manager.FontProperties(fname=font_path)
  17. def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
  18. """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
  19. print(f"{city_pair} 查找所有分组")
  20. date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
  21. date_end = datetime.today().strftime("%Y-%m-%d")
  22. # 聚合查询管道
  23. pipeline = [
  24. {
  25. "$match": {
  26. "citypair": city_pair,
  27. "from_date": {
  28. "$gte": date_begin,
  29. "$lte": date_end
  30. }
  31. }
  32. },
  33. {
  34. "$group": {
  35. "_id": {
  36. "flight_numbers": "$flight_numbers",
  37. "from_date": "$from_date"
  38. }
  39. }
  40. },
  41. {
  42. "$group": {
  43. "_id": "$_id.flight_numbers",
  44. "days": {"$sum": 1},
  45. "details": {"$push": "$_id.from_date"}
  46. }
  47. },
  48. {
  49. "$match": {
  50. "days": {"$gte": min_days}
  51. }
  52. },
  53. {
  54. "$addFields": {
  55. "details": {"$sortArray": {"input": "$details", "sortBy": 1}}
  56. }
  57. },
  58. {
  59. "$sort": {"_id": 1}
  60. }
  61. ]
  62. for attempt in range(1, max_retries + 1):
  63. try:
  64. print(f" 第 {attempt}/{max_retries} 次尝试查询")
  65. # 执行聚合查询
  66. collection = db[table_name]
  67. results = list(collection.aggregate(pipeline))
  68. # 格式化结果,使字段名更清晰
  69. formatted_results = [
  70. {
  71. "flight_numbers": r["_id"],
  72. "days": r["days"],
  73. "flight_dates": r["details"]
  74. }
  75. for r in results
  76. ]
  77. return formatted_results
  78. except (ServerSelectionTimeoutError, PyMongoError) as e:
  79. print(f"⚠️ Mongo 查询失败: {e}")
  80. if attempt == max_retries:
  81. print("❌ 达到最大重试次数,放弃")
  82. return []
  83. # 指数退避 + 随机抖动
  84. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  85. print(f"⏳ {sleep_time:.2f}s 后重试...")
  86. time.sleep(sleep_time)
  87. def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end,
  88. limit=0, max_retries=3, base_sleep=1.0):
  89. for attempt in range(1, max_retries + 1):
  90. try:
  91. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  92. # 构建查询条件
  93. projection = {
  94. # "_id": 0 # 一般会关掉
  95. "citypair": 1,
  96. "flight_numbers": 1,
  97. "from_date": 1,
  98. "from_time": 1,
  99. "create_time": 1,
  100. "baggage_weight": 1,
  101. "cabins": 1,
  102. "ticket_amount": 1,
  103. "currency": 1,
  104. "price_base": 1,
  105. "price_tax": 1,
  106. "price_total": 1
  107. }
  108. pipeline = [
  109. {
  110. "$match": {
  111. "citypair": city_pair,
  112. "flight_numbers": flight_numbers,
  113. "baggage_weight": {"$in": [0, 20]},
  114. "from_date": {
  115. "$gte": from_date_begin,
  116. "$lte": from_date_end
  117. }
  118. }
  119. },
  120. {
  121. "$project": projection # 就是这里
  122. },
  123. {
  124. "$sort": {
  125. "from_date": 1,
  126. "baggage_weight": 1,
  127. "create_time": 1
  128. }
  129. }
  130. ]
  131. # print(f" 查询条件: {pipeline}")
  132. # 执行查询
  133. collection = db[table_name]
  134. results = list(collection.aggregate(pipeline))
  135. print(f"✅ 查询成功,找到 {len(results)} 条记录")
  136. if results:
  137. df = pd.DataFrame(results)
  138. if '_id' in df.columns:
  139. df = df.drop(columns=['_id'])
  140. if 'from_time' in df.columns and 'from_date' in df.columns:
  141. key_cols = ['citypair', 'flight_numbers', 'from_date']
  142. group_cols = [col for col in key_cols if col in df.columns]
  143. from_time_raw = df['from_time']
  144. from_time_str = from_time_raw.fillna('').astype(str).str.strip()
  145. extracted_time = from_time_str.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0]
  146. valid_time_mask = from_time_str.ne('') & extracted_time.notna()
  147. if valid_time_mask.any():
  148. missing_mask = from_time_raw.isna() | from_time_str.eq('')
  149. if group_cols:
  150. mode_by_group = (
  151. df.loc[valid_time_mask, group_cols]
  152. .assign(_mode_time=extracted_time.loc[valid_time_mask].values)
  153. .groupby(group_cols, dropna=False)['_mode_time']
  154. .agg(lambda s: s.value_counts().idxmax())
  155. .reset_index()
  156. )
  157. df = df.merge(mode_by_group, on=group_cols, how='left')
  158. fill_mask = missing_mask & df['_mode_time'].notna()
  159. if fill_mask.any():
  160. df.loc[fill_mask, 'from_time'] = (
  161. df.loc[fill_mask, 'from_date'].astype(str).str.strip() + ' ' + df.loc[fill_mask, '_mode_time']
  162. )
  163. df = df.drop(columns=['_mode_time'])
  164. remaining_missing_mask = df['from_time'].isna() | df['from_time'].astype(str).str.strip().eq('')
  165. if remaining_missing_mask.any():
  166. more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
  167. df.loc[remaining_missing_mask, 'from_time'] = (
  168. df.loc[remaining_missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
  169. )
  170. pass
  171. else:
  172. more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
  173. if missing_mask.any():
  174. df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
  175. else:
  176. print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
  177. return pd.DataFrame()
  178. print(f"📊 已转换为 DataFrame,形状: {df.shape}")
  179. return df
  180. else:
  181. print("⚠️ 查询结果为空")
  182. return pd.DataFrame()
  183. except (ServerSelectionTimeoutError, PyMongoError) as e:
  184. print(f"⚠️ Mongo 查询失败: {e}")
  185. if attempt == max_retries:
  186. print("❌ 达到最大重试次数,放弃")
  187. return pd.DataFrame()
  188. # 指数退避 + 随机抖动
  189. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  190. print(f"⏳ {sleep_time:.2f}s 后重试...")
  191. time.sleep(sleep_time)
def plot_c1_trend(df, output_dir="."):
    """Plot price_total over update_hour for one flight, one step-line per
    baggage_weight class, and save the figure as a PNG under
    ``output_dir/<citypair>/``.

    Expects the hour-filled frame produced upstream (columns: citypair,
    flight_numbers, from_time, update_hour, baggage_weight, price_total,
    hours_until_departure).

    NOTE(review): mutates the caller's ``df`` in place when coercing
    ``update_hour`` to datetime — pass a copy if that matters.
    """
    # Colors / linestyles cycled per baggage class.
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']
    # Ensure the time axis is datetime-typed.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    # Representative identifiers; the frame is assumed to hold a single
    # flight/date, so the mode is just "the" value.
    city_pair = df['citypair'].mode().iloc[0]
    flight_numbers = df['flight_numbers'].mode().iloc[0]
    from_time = df['from_time'].mode().iloc[0]
    # One sub-directory per route for the saved images.
    output_dir_temp = os.path.join(output_dir, city_pair)
    os.makedirs(output_dir_temp, exist_ok=True)
    # Create the figure object.
    fig = plt.figure(figsize=(14, 8))
    # Draw one series per baggage_weight class.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
        # Chronological order within the class.
        df_g = group.sort_values('update_hour').reset_index(drop=True)
        # Price change points: rows whose price differs from the previous row,
        # plus the first and last rows so the line spans the full range.
        change_points = df_g.loc[
            (df_g['price_total'] != df_g['price_total'].shift(1)) |
            (df_g.index == 0) |
            (df_g.index == len(df_g) - 1)  # endpoint
        ].drop_duplicates(subset=['update_hour'])
        # Step line (flat segments with jumps) plus hollow markers.
        plt.step(
            change_points['update_hour'],
            change_points['price_total'],
            where='post',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            label=f"Baggage {baggage_value}"
        )
        plt.scatter(
            change_points['update_hour'],
            change_points['price_total'],
            s=30,
            facecolors='white',
            edgecolors=colors[i % len(colors)],
            linewidths=2,
            zorder=3,
        )
        # Annotate each change point with (hours-to-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['price_total']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['price_total']),
                xytext=(0, 0),  # anchored at the point (zero offset)
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,  # small font to limit clutter
                color='gray',
                alpha=0.8,
                rotation=25,
            )
        del change_points
        del df_g
    # Tidy up rotated date tick labels.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # Major ticks daily; unlabeled minor tick at noon.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))  # one major tick per day
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # month-day labels
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))  # minor tick at 12:00
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))  # suppress minor labels
    # Legend outside the axes.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    safe_flight = flight_numbers.replace(",", "_")
    # assumes from_time is datetime-like (strftime) — it is when the frame
    # went through the hourly-fill step; verify for other callers.
    safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_temp, save_file)
    # Save before closing the figure.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    # Close to release figure memory.
    plt.close(fig)
  281. def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
  282. """补齐成小时粒度数据"""
  283. df = df.copy()
  284. # 1. 转 datetime
  285. df['create_time'] = pd.to_datetime(df['create_time'])
  286. df['from_time'] = pd.to_datetime(df['from_time'])
  287. # 添加一个用于分组的小时字段
  288. df['update_hour'] = df['create_time'].dt.floor('h')
  289. # 2. 排序规则:同一小时内,按原始时间戳排序
  290. # 假设你想保留最早的一条
  291. df = df.sort_values(['update_hour', 'create_time'])
  292. # 3. 按小时去重,保留该小时内最早(最晚)的一条
  293. df = df.drop_duplicates(subset=['update_hour'], keep='last') # keep='first' keep='last'
  294. # 4. 标记原始数据
  295. df['is_filled'] = 0
  296. # 5. 排序 + 设索引
  297. df = df.sort_values('update_hour').set_index('update_hour')
  298. # 6. 构造完整小时轴
  299. start_of_hour = df.index.min() # 默认 第一天 最早 开始
  300. if head_fill == 1:
  301. start_of_hour = df.index.min().normalize() # 强制 第一天 00:00 开始
  302. end_of_hour = df.index.max() # 默认 最后一天 最晚 结束
  303. if rear_fill == 1:
  304. end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23) # 强制 最后一天 23:00 结束
  305. elif rear_fill == 2:
  306. if 'from_time' in df.columns:
  307. last_dep_time = df['from_time'].iloc[-1]
  308. if pd.notna(last_dep_time):
  309. # 对齐到整点小时(向下取整)
  310. end_of_hour = last_dep_time.floor('h')
  311. full_index = pd.date_range(
  312. start=start_of_hour,
  313. end=end_of_hour,
  314. freq='1h'
  315. )
  316. # 7. 按小时补齐
  317. df = df.reindex(full_index)
  318. # 先恢复 dtype(关键!)
  319. df = df.infer_objects(copy=False)
  320. # 8. 新增出来的行标记为 1
  321. df['is_filled'] = df['is_filled'].fillna(1)
  322. # 9. 前向填充
  323. df = df.ffill()
  324. # 10. 还原整型字段
  325. int_cols = [
  326. 'ticket_amount',
  327. 'baggage_weight',
  328. 'is_filled',
  329. ]
  330. for col in int_cols:
  331. if col in df.columns:
  332. df[col] = df[col].astype('int64')
  333. # 10.5 价格字段统一保留两位小数
  334. price_cols = [
  335. 'price_base',
  336. 'price_tax',
  337. 'price_total'
  338. ]
  339. for col in price_cols:
  340. if col in df.columns:
  341. df[col] = df[col].astype('float64').round(2)
  342. # 10.6 新增:距离起飞还有多少小时
  343. if 'from_time' in df.columns:
  344. # 创建临时字段(整点)
  345. df['from_hour'] = df['from_time'].dt.floor('h')
  346. # 计算小时差 df.index 此时就是 update_hour
  347. df['hours_until_departure'] = (
  348. (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
  349. ).astype('int64')
  350. # 新增:距离起飞还有多少天
  351. df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
  352. # 删除临时字段
  353. df = df.drop(columns=['from_hour'])
  354. # 11. 写回 update_hour
  355. df['update_hour'] = df.index
  356. # 12. 恢复普通索引
  357. df = df.reset_index(drop=True)
  358. return df
  359. def process_flight_numbers(args):
  360. process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
  361. print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
  362. # 为每个进程创建独立的数据库连接
  363. try:
  364. client, db = mongo_con_parse(db_config)
  365. print(f"[进程{process_id}] ✅ 数据库连接创建成功")
  366. except Exception as e:
  367. print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
  368. return pd.DataFrame()
  369. try:
  370. # 查询
  371. df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
  372. if df1.empty:
  373. return pd.DataFrame()
  374. common_dep_dates = df1['from_date'].unique()
  375. common_baggages = df1['baggage_weight'].unique()
  376. list_mid = []
  377. for dep_date in common_dep_dates:
  378. # 起飞日期筛选
  379. df_d1 = df1[df1["from_date"] == dep_date].copy()
  380. list_f1 = []
  381. for baggage in common_baggages:
  382. # 行李配额筛选
  383. df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
  384. if df_b1.empty:
  385. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
  386. continue
  387. df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
  388. list_f1.append(df_f1)
  389. del df_f1
  390. del df_b1
  391. if list_f1:
  392. df_c1 = pd.concat(list_f1, ignore_index=True)
  393. if plot_flag:
  394. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
  395. plot_c1_trend(df_c1, output_dir)
  396. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
  397. else:
  398. df_c1 = pd.DataFrame()
  399. if plot_flag:
  400. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
  401. del list_f1
  402. list_mid.append(df_c1)
  403. del df_c1
  404. del df_d1
  405. if list_mid:
  406. df_mid = pd.concat(list_mid, ignore_index=True)
  407. print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
  408. else:
  409. df_mid = pd.DataFrame()
  410. print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
  411. del list_mid
  412. del df1
  413. gc.collect()
  414. print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
  415. return df_mid
  416. except Exception as e:
  417. print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
  418. return pd.DataFrame()
  419. finally:
  420. # 确保关闭数据库连接
  421. try:
  422. client.close()
  423. print(f"[进程{process_id}] ✅ 数据库连接已关闭")
  424. except:
  425. pass
  426. def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.',
  427. use_multiprocess=False, max_workers=None):
  428. list_all = []
  429. print(f"开始处理航线: {city_pair}")
  430. main_client, main_db = mongo_con_parse(db_config)
  431. all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
  432. main_client.close()
  433. all_groups_len = len(all_groups)
  434. print(f"该航线共有{all_groups_len}组航班号")
  435. if use_multiprocess and all_groups_len > 1:
  436. print(f"启用多进程处理,最大进程数: {max_workers}")
  437. # 多进程处理
  438. process_args = []
  439. process_id = 0
  440. for each_group in all_groups:
  441. flight_numbers = each_group.get("flight_numbers", "未知")
  442. process_id += 1
  443. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  444. process_args.append(args)
  445. with ProcessPoolExecutor(max_workers=max_workers) as executor:
  446. future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
  447. for future in as_completed(future_to_group):
  448. each_group = future_to_group[future]
  449. flight_numbers = each_group.get("flight_numbers", "未知")
  450. try:
  451. df_mid = future.result()
  452. if not df_mid.empty:
  453. list_all.append(df_mid)
  454. print(f"✅ 航班号:{flight_numbers} 处理完成")
  455. else:
  456. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  457. except Exception as e:
  458. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  459. pass
  460. else:
  461. print("使用单进程处理")
  462. process_id = 0
  463. for each_group in all_groups:
  464. flight_numbers = each_group.get("flight_numbers", "未知")
  465. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  466. try:
  467. df_mid = process_flight_numbers(args)
  468. if not df_mid.empty:
  469. list_all.append(df_mid)
  470. print(f"✅ 航班号:{flight_numbers} 处理完成")
  471. else:
  472. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  473. except Exception as e:
  474. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  475. print(f"结束处理航线: {city_pair}")
  476. if list_all:
  477. df_all = pd.concat(list_all, ignore_index=True)
  478. else:
  479. df_all = pd.DataFrame()
  480. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  481. del list_all
  482. gc.collect()
  483. return df_all
  484. if __name__ == "__main__":
  485. cpu_cores = os.cpu_count() # 你的系统是72
  486. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  487. output_dir = f"./photo"
  488. os.makedirs(output_dir, exist_ok=True)
  489. from_date_begin = "2026-03-17"
  490. from_date_end = "2026-03-26"
  491. uo_city_pairs = uo_city_pairs_new.copy()
  492. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  493. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
  494. # 使用默认配置
  495. # client, db = mongo_con_parse()
  496. print(f"第 {idx} 组 :", uo_city_pair)
  497. start_time = time.time()
  498. load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  499. plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
  500. end_time = time.time()
  501. run_time = round(end_time - start_time, 3)
  502. print(f"用时: {run_time} 秒")