# data_loader.py
  1. import os
  2. import time
  3. import random
  4. import atexit
  5. from datetime import datetime, timedelta
  6. import gc
  7. from concurrent.futures import ProcessPoolExecutor, as_completed
  8. import pymongo
  9. from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
  10. import pandas as pd
  11. import matplotlib.pyplot as plt
  12. from matplotlib import font_manager
  13. import matplotlib.dates as mdates
  14. from uo_atlas_import import mongo_con_parse
  15. from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new
# Font used for Chinese axis labels / titles in matplotlib plots.
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)

# Per-worker-process Mongo connection cache, populated by init_worker_mongo
# (ProcessPoolExecutor initializer) and torn down by _close_worker_mongo.
_worker_client = None
_worker_db = None
  20. def _close_worker_mongo():
  21. global _worker_client, _worker_db
  22. if _worker_client is not None:
  23. try:
  24. _worker_client.close()
  25. except Exception:
  26. pass
  27. _worker_client = None
  28. _worker_db = None
def init_worker_mongo(db_config):
    """ProcessPoolExecutor initializer: open one Mongo connection per worker.

    On success the connection is cached in the module globals ``_worker_client``
    / ``_worker_db`` and a close hook is registered via atexit; on failure both
    globals are left as None so process_flight_numbers falls back to opening a
    per-task connection.
    """
    global _worker_client, _worker_db
    try:
        _worker_client, _worker_db = mongo_con_parse(db_config)
        # Ensure the cached connection is closed when this worker process exits.
        atexit.register(_close_worker_mongo)
        print("[worker] ✅ 数据库连接创建成功")
    except Exception as e:
        _worker_client = None
        _worker_db = None
        print(f"[worker] ❌ 数据库连接创建失败: {e}")
  39. def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
  40. """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
  41. print(f"{city_pair} 查找所有分组")
  42. date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
  43. date_end = datetime.today().strftime("%Y-%m-%d")
  44. for attempt in range(1, max_retries + 1):
  45. try:
  46. print(f" 第 {attempt}/{max_retries} 次尝试查询")
  47. collection = db[table_name]
  48. query = {
  49. "citypair": city_pair,
  50. "from_date": {
  51. "$gte": date_begin,
  52. "$lte": date_end
  53. }
  54. }
  55. projection = {
  56. "_id": 0,
  57. "flight_numbers": 1,
  58. "from_date": 1
  59. }
  60. raw_rows = list(collection.find(query, projection))
  61. if not raw_rows:
  62. return []
  63. df = pd.DataFrame(raw_rows)
  64. if df.empty or 'flight_numbers' not in df.columns or 'from_date' not in df.columns:
  65. return []
  66. df = df.dropna(subset=['flight_numbers', 'from_date'])
  67. if df.empty:
  68. return []
  69. df = df.drop_duplicates(subset=['flight_numbers', 'from_date'])
  70. df_grouped = (
  71. df.groupby('flight_numbers', as_index=False)
  72. .agg(days=('from_date', 'size'), flight_dates=('from_date', lambda s: sorted(s.tolist())))
  73. )
  74. df_grouped = df_grouped[df_grouped['days'] >= min_days].sort_values('flight_numbers').reset_index(drop=True)
  75. if df_grouped.empty:
  76. return []
  77. formatted_results = df_grouped[['flight_numbers', 'days', 'flight_dates']].to_dict(orient='records')
  78. return formatted_results
  79. except (ServerSelectionTimeoutError, PyMongoError) as e:
  80. print(f"⚠️ Mongo 查询失败: {e}")
  81. if attempt == max_retries:
  82. print("❌ 达到最大重试次数,放弃")
  83. return []
  84. # 指数退避 + 随机抖动
  85. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  86. print(f"⏳ {sleep_time:.2f}s 后重试...")
  87. time.sleep(sleep_time)
def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end,
                              limit=0, max_retries=3, base_sleep=1.0):
    """Query price/status records for one flight number over a departure-date range.

    Matches rows on citypair/flight_numbers with baggage_weight in {0, 20} and
    from_date within [from_date_begin, from_date_end], sorted by from_date,
    baggage_weight, create_time.  Missing ``from_time`` values are back-filled
    with the most frequent departure time observed for the same
    (citypair, flight_numbers, from_date) group, falling back to the overall
    most frequent time.

    Returns a DataFrame; empty on no data, on unusable departure times, or
    after exhausting retries.

    NOTE(review): ``limit`` is accepted but never used — confirm whether a
    $limit pipeline stage was intended.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Build the aggregation pipeline.
            projection = {
                # "_id": 0  # usually excluded; here _id is kept and dropped later
                "citypair": 1,
                "flight_numbers": 1,
                "from_date": 1,
                "from_time": 1,
                "create_time": 1,
                "baggage_weight": 1,
                "cabins": 1,
                "ticket_amount": 1,
                "currency": 1,
                "price_base": 1,
                "price_tax": 1,
                "price_total": 1
            }
            pipeline = [
                {
                    "$match": {
                        "citypair": city_pair,
                        "flight_numbers": flight_numbers,
                        "baggage_weight": {"$in": [0, 20]},
                        "from_date": {
                            "$gte": from_date_begin,
                            "$lte": from_date_end
                        }
                    }
                },
                {
                    "$project": projection  # field selection applied here
                },
                {
                    "$sort": {
                        "from_date": 1,
                        "baggage_weight": 1,
                        "create_time": 1
                    }
                }
            ]
            # print(f" 查询条件: {pipeline}")
            # Execute the query.
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                if 'from_time' in df.columns and 'from_date' in df.columns:
                    key_cols = ['citypair', 'flight_numbers', 'from_date']
                    group_cols = [col for col in key_cols if col in df.columns]
                    from_time_raw = df['from_time']
                    from_time_str = from_time_raw.fillna('').astype(str).str.strip()
                    # Extract a trailing HH:MM:SS token from each from_time string.
                    extracted_time = from_time_str.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0]
                    valid_time_mask = from_time_str.ne('') & extracted_time.notna()
                    if valid_time_mask.any():
                        missing_mask = from_time_raw.isna() | from_time_str.eq('')
                        if group_cols:
                            # Most frequent valid time per (citypair, flight, date) group.
                            mode_by_group = (
                                df.loc[valid_time_mask, group_cols]
                                .assign(_mode_time=extracted_time.loc[valid_time_mask].values)
                                .groupby(group_cols, dropna=False)['_mode_time']
                                .agg(lambda s: s.value_counts().idxmax())
                                .reset_index()
                            )
                            df = df.merge(mode_by_group, on=group_cols, how='left')
                            fill_mask = missing_mask & df['_mode_time'].notna()
                            if fill_mask.any():
                                # Rebuild "date + time" for rows whose group has a mode.
                                df.loc[fill_mask, 'from_time'] = (
                                    df.loc[fill_mask, 'from_date'].astype(str).str.strip() + ' ' + df.loc[fill_mask, '_mode_time']
                                )
                            df = df.drop(columns=['_mode_time'])
                            # Rows whose group had no valid time at all: fall back to the global mode.
                            remaining_missing_mask = df['from_time'].isna() | df['from_time'].astype(str).str.strip().eq('')
                            if remaining_missing_mask.any():
                                more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
                                df.loc[remaining_missing_mask, 'from_time'] = (
                                    df.loc[remaining_missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
                                )
                            pass
                        else:
                            # No grouping keys present: use the single global mode.
                            more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
                            if missing_mask.any():
                                df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
                    else:
                        # Not a single parseable departure time — data is unusable.
                        print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
                        return pd.DataFrame()
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
def plot_c1_trend(df, output_dir="."):
    """
    Plot price_total vs. update_hour for one flight/departure slice, drawing
    one step line per baggage_weight group, and save the figure as a PNG under
    ``<output_dir>/<citypair>/``.

    Expects columns: update_hour, citypair, flight_numbers, from_time,
    baggage_weight, price_total, hours_until_departure (as produced by
    fill_hourly_create_time).
    """
    # Color/linestyle palettes, cycled per baggage group.
    colors = ['green', 'blue', 'red', 'brown']
    linestyles = ['--', '--', '--', '--']
    # Ensure the time axis is datetime-typed.
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    # Use the modal value as the representative label for this slice.
    city_pair = df['citypair'].mode().iloc[0]
    flight_numbers = df['flight_numbers'].mode().iloc[0]
    from_time = df['from_time'].mode().iloc[0]
    output_dir_temp = os.path.join(output_dir, city_pair)
    os.makedirs(output_dir_temp, exist_ok=True)
    fig = plt.figure(figsize=(14, 8))
    # One step line + hollow markers per baggage_weight group.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
        df_g = group.sort_values('update_hour').reset_index(drop=True)
        # Keep only price-change points, plus the first and last rows.
        change_points = df_g.loc[
            (df_g['price_total'] != df_g['price_total'].shift(1)) |
            (df_g.index == 0) |
            (df_g.index == len(df_g) - 1)  # end point
        ].drop_duplicates(subset=['update_hour'])
        # Step line: flat segments with jumps at the change points.
        plt.step(
            change_points['update_hour'],
            change_points['price_total'],
            where='post',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2,
            label=f"Baggage {baggage_value}"
        )
        plt.scatter(
            change_points['update_hour'],
            change_points['price_total'],
            s=30,
            facecolors='white',
            edgecolors=colors[i % len(colors)],
            linewidths=2,
            zorder=3,
        )
        # Annotate each point with (hours-until-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['price_total']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['price_total']),
                xytext=(0, 0),  # offset in points; (0, 0) anchors on the point itself
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,  # small font to limit clutter
                color='gray',
                alpha=0.8,
                rotation=25,
            )
        del change_points
        del df_g
    # Auto-rotate/align date tick labels.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # X axis: one major tick per day (mm-dd), an unlabeled minor tick at 12:00.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))
    # Legend outside the axes on the right.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    safe_flight = flight_numbers.replace(",", "_")
    # Assumes from_time is datetime-like (set by fill_hourly_create_time) — TODO confirm for other callers.
    safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_temp, save_file)
    # Save before closing; close to release figure memory.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
    """Resample one flight/baggage price series onto a full hourly grid.

    Keeps the last record within each hour, reindexes onto a complete hourly
    range, forward-fills the gaps, and adds ``is_filled`` (0 = observed,
    1 = synthesized), ``hours_until_departure``, ``days_to_departure`` and a
    materialized ``update_hour`` column.

    head_fill: 0 = start at the first observed hour; 1 = start at 00:00 of the
        first observed day.
    rear_fill: 0 = end at the last observed hour; 1 = end at 23:00 of the last
        observed day; 2 = end at the departure hour (floored) from ``from_time``.
    """
    df = df.copy()
    # 1. Parse timestamps.
    df['create_time'] = pd.to_datetime(df['create_time'])
    df['from_time'] = pd.to_datetime(df['from_time'])
    # Hour bucket used for grouping/deduplication.
    df['update_hour'] = df['create_time'].dt.floor('h')
    # 2. Within an hour, order by the raw timestamp...
    df = df.sort_values(['update_hour', 'create_time'])
    # 3. ...and keep the latest record of that hour.
    df = df.drop_duplicates(subset=['update_hour'], keep='last')  # keep='first' would keep the earliest
    # 4. Mark observed (non-synthesized) rows.
    df['is_filled'] = 0
    # 5. Sort and index by hour.
    df = df.sort_values('update_hour').set_index('update_hour')
    # 6. Build the full hourly axis; bounds controlled by head_fill/rear_fill.
    start_of_hour = df.index.min()  # default: first observed hour
    if head_fill == 1:
        start_of_hour = df.index.min().normalize()  # force 00:00 of the first day
    end_of_hour = df.index.max()  # default: last observed hour
    if rear_fill == 1:
        end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23)  # force 23:00 of the last day
    elif rear_fill == 2:
        if 'from_time' in df.columns:
            last_dep_time = df['from_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align to the departure hour (floored).
                end_of_hour = last_dep_time.floor('h')
    full_index = pd.date_range(
        start=start_of_hour,
        end=end_of_hour,
        freq='1h'
    )
    # 7. Reindex onto the hourly grid (new rows start all-NaN).
    df = df.reindex(full_index)
    # Restore dtypes before filling (important for the astype calls below).
    df = df.infer_objects(copy=False)
    # 8. Rows created by the reindex are flagged as filled.
    df['is_filled'] = df['is_filled'].fillna(1)
    # 9. Forward-fill values into the synthesized rows.
    df = df.ffill()
    # 10. Restore integer columns.
    int_cols = [
        'ticket_amount',
        'baggage_weight',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')
    # 10.5 Round price columns to 2 decimals.
    price_cols = [
        'price_base',
        'price_tax',
        'price_total'
    ]
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)
    # 10.6 Hours (and days) remaining until departure.
    if 'from_time' in df.columns:
        # Temporary column holding the departure hour (floored).
        df['from_hour'] = df['from_time'].dt.floor('h')
        # df.index here IS the hourly update_hour grid.
        df['hours_until_departure'] = (
            (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        # Whole days remaining until departure.
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        # Drop the temporary column.
        df = df.drop(columns=['from_hour'])
    # 11. Materialize update_hour as a column again...
    df['update_hour'] = df.index
    # 12. ...and return to a plain RangeIndex.
    df = df.reset_index(drop=True)
    return df
def process_flight_numbers(args):
    """Worker entry point: load, hour-fill and optionally plot one flight number.

    ``args`` is the tuple
    (process_id, db_config, city_pair, flight_numbers, from_date_begin,
     from_date_end, is_train, plot_flag, output_dir).
    Uses the pool-initialized worker connection when available, otherwise opens
    (and always closes) a private one.  Returns the concatenated DataFrame over
    all departure dates of this flight number, or an empty DataFrame on failure.

    NOTE(review): ``is_train`` is unpacked but never used here — confirm intent.
    """
    process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
    print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
    local_client = None
    # Prefer the per-worker cached connection created by init_worker_mongo.
    db = _worker_db
    if db is None:
        try:
            local_client, db = mongo_con_parse(db_config)
            print(f"[进程{process_id}] ✅ 数据库连接创建成功")
        except Exception as e:
            print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
            return pd.DataFrame()
    try:
        df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
        if df1.empty:
            return pd.DataFrame()
        common_dep_dates = df1['from_date'].unique()
        common_baggages = df1['baggage_weight'].unique()
        list_mid = []
        for dep_date in common_dep_dates:
            # Slice by departure date.
            df_d1 = df1[df1["from_date"] == dep_date].copy()
            list_f1 = []
            for baggage in common_baggages:
                # Slice by baggage allowance.
                df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
                if df_b1.empty:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
                    continue
                # Fill to hourly granularity, ending at the departure hour (rear_fill=2).
                df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
                list_f1.append(df_f1)
                del df_f1
                del df_b1
            if list_f1:
                df_c1 = pd.concat(list_f1, ignore_index=True)
                if plot_flag:
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
                    plot_c1_trend(df_c1, output_dir)
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
            else:
                df_c1 = pd.DataFrame()
                if plot_flag:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
            del list_f1
            list_mid.append(df_c1)
            del df_c1
            del df_d1
        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
        del list_mid
        del df1
        gc.collect()
        print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
        return df_mid
    except Exception as e:
        print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
        return pd.DataFrame()
    finally:
        # Only close a connection we opened ourselves; the pooled one is
        # closed by the atexit hook registered in init_worker_mongo.
        if local_client is not None:
            try:
                local_client.close()
                print(f"[进程{process_id}] ✅ 数据库连接已关闭")
            except Exception:
                pass
  428. def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.',
  429. use_multiprocess=False, max_workers=None):
  430. list_all = []
  431. print(f"开始处理航线: {city_pair}")
  432. main_client, main_db = mongo_con_parse(db_config)
  433. all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
  434. main_client.close()
  435. all_groups_len = len(all_groups)
  436. print(f"该航线共有{all_groups_len}组航班号")
  437. if use_multiprocess and all_groups_len > 1:
  438. print(f"启用多进程处理,最大进程数: {max_workers}")
  439. # 多进程处理
  440. process_args = []
  441. process_id = 0
  442. for each_group in all_groups:
  443. flight_numbers = each_group.get("flight_numbers", "未知")
  444. process_id += 1
  445. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  446. process_args.append(args)
  447. with ProcessPoolExecutor(max_workers=max_workers, initializer=init_worker_mongo, initargs=(db_config,)) as executor:
  448. future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
  449. for future in as_completed(future_to_group):
  450. each_group = future_to_group[future]
  451. flight_numbers = each_group.get("flight_numbers", "未知")
  452. try:
  453. df_mid = future.result()
  454. if not df_mid.empty:
  455. list_all.append(df_mid)
  456. print(f"✅ 航班号:{flight_numbers} 处理完成")
  457. else:
  458. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  459. except Exception as e:
  460. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  461. pass
  462. else:
  463. print("使用单进程处理")
  464. process_id = 0
  465. for each_group in all_groups:
  466. flight_numbers = each_group.get("flight_numbers", "未知")
  467. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  468. try:
  469. df_mid = process_flight_numbers(args)
  470. if not df_mid.empty:
  471. list_all.append(df_mid)
  472. print(f"✅ 航班号:{flight_numbers} 处理完成")
  473. else:
  474. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  475. except Exception as e:
  476. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  477. print(f"结束处理航线: {city_pair}")
  478. if list_all:
  479. df_all = pd.concat(list_all, ignore_index=True)
  480. else:
  481. df_all = pd.DataFrame()
  482. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  483. del list_all
  484. gc.collect()
  485. return df_all
  486. def validate_keep_one_line(db, table_name, city_pair, flight_numbers, baggage_weight, from_date, entry_price, update_hour_str, del_batch_std_str,
  487. limit=0, max_retries=3, base_sleep=1.0):
  488. """验证keep_info的一行"""
  489. for attempt in range(1, max_retries + 1):
  490. try:
  491. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  492. collection = db[table_name]
  493. projection = {
  494. # "_id": 0 # 一般会关掉
  495. "citypair": 1,
  496. "flight_numbers": 1,
  497. "from_date": 1,
  498. "from_time": 1,
  499. "create_time": 1,
  500. "baggage_weight": 1,
  501. "cabins": 1,
  502. "ticket_amount": 1,
  503. "currency": 1,
  504. "price_base": 1,
  505. "price_tax": 1,
  506. "price_total": 1
  507. }
  508. query = {
  509. "citypair": city_pair,
  510. "flight_numbers": flight_numbers,
  511. "baggage_weight": baggage_weight,
  512. "from_date": from_date,
  513. "create_time": {"$lte": del_batch_std_str}
  514. }
  515. raw_rows = list(collection.find(query, projection))
  516. if not raw_rows:
  517. return pd.DataFrame()
  518. df = pd.DataFrame(raw_rows)
  519. if df.empty or 'flight_numbers' not in df.columns or 'from_date' not in df.columns or 'create_time' not in df.columns:
  520. return pd.DataFrame()
  521. df['_create_time_dt'] = pd.to_datetime(df['create_time'], errors='coerce')
  522. df = df[df['_create_time_dt'].notna()].sort_values('_create_time_dt').reset_index(drop=True)
  523. if df.empty:
  524. return pd.DataFrame()
  525. entry_price_num = pd.to_numeric(entry_price, errors='coerce')
  526. if pd.isna(entry_price_num):
  527. return pd.DataFrame()
  528. df['_price_total_num'] = pd.to_numeric(df['price_total'], errors='coerce')
  529. matched_idx = df.index[df['_price_total_num'] == entry_price_num]
  530. if len(matched_idx) == 0:
  531. return pd.DataFrame()
  532. start_idx = matched_idx[-1]
  533. df = df.iloc[start_idx:].reset_index(drop=True)
  534. df = df.drop(columns=['_create_time_dt', '_price_total_num'])
  535. return df
  536. except (ServerSelectionTimeoutError, PyMongoError) as e:
  537. print(f"⚠️ Mongo 查询失败: {e}")
  538. if attempt == max_retries:
  539. print("❌ 达到最大重试次数,放弃")
  540. return pd.DataFrame()
  541. # 指数退避 + 随机抖动
  542. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  543. print(f"⏳ {sleep_time:.2f}s 后重试...")
  544. time.sleep(sleep_time)
if __name__ == "__main__":
    # Pool size: capped at 8 regardless of available cores (note said host has 72).
    cpu_cores = os.cpu_count()
    max_workers = min(8, cpu_cores)
    output_dir = f"./photo"
    os.makedirs(output_dir, exist_ok=True)
    # Departure-date window to load.
    from_date_begin = "2026-03-17"
    from_date_end = "2026-03-26"
    uo_city_pairs = uo_city_pairs_new.copy()
    # "HKGNRT" -> "HKG-NRT": insert the dash expected by the citypair queries.
    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
        # Connections are opened inside load_data using mongo_config.
        # client, db = mongo_con_parse()
        print(f"第 {idx} 组 :", uo_city_pair)
        start_time = time.time()
        load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
                  plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
        end_time = time.time()
        run_time = round(end_time - start_time, 3)
        print(f"用时: {run_time} 秒")