data_loader.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. import os
  2. import time
  3. import random
  4. import atexit
  5. from datetime import datetime, timedelta
  6. import gc
  7. from concurrent.futures import ProcessPoolExecutor, as_completed
  8. import pymongo
  9. from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
  10. import numpy as np
  11. import pandas as pd
  12. import matplotlib.pyplot as plt
  13. from matplotlib import font_manager
  14. import matplotlib.dates as mdates
  15. from uo_atlas_import import mongo_con_parse
  16. from config import mongo_config, mongo_table_uo, uo_city_pairs_old, uo_city_pairs_new
# Font used for Chinese axis labels / titles in matplotlib figures
# (bundled SimHei TTF next to this script).
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)
# Per-worker Mongo connection cache: populated by init_worker_mongo()
# (ProcessPoolExecutor initializer) and torn down by _close_worker_mongo().
_worker_client = None
_worker_db = None
  21. def _close_worker_mongo():
  22. global _worker_client, _worker_db
  23. if _worker_client is not None:
  24. try:
  25. _worker_client.close()
  26. except Exception:
  27. pass
  28. _worker_client = None
  29. _worker_db = None
  30. def init_worker_mongo(db_config):
  31. global _worker_client, _worker_db
  32. try:
  33. _worker_client, _worker_db = mongo_con_parse(db_config)
  34. atexit.register(_close_worker_mongo)
  35. print("[worker] ✅ 数据库连接创建成功")
  36. except Exception as e:
  37. _worker_client = None
  38. _worker_db = None
  39. print(f"[worker] ❌ 数据库连接创建失败: {e}")
  40. def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
  41. """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
  42. print(f"{city_pair} 查找所有分组")
  43. date_begin = (datetime.today() - timedelta(days=30)).strftime("%Y-%m-%d")
  44. date_end = datetime.today().strftime("%Y-%m-%d")
  45. for attempt in range(1, max_retries + 1):
  46. try:
  47. print(f" 第 {attempt}/{max_retries} 次尝试查询")
  48. collection = db[table_name]
  49. query = {
  50. "citypair": city_pair,
  51. "from_date": {
  52. "$gte": date_begin,
  53. "$lte": date_end
  54. }
  55. }
  56. projection = {
  57. "_id": 0,
  58. "flight_numbers": 1,
  59. "from_date": 1
  60. }
  61. raw_rows = list(collection.find(query, projection))
  62. if not raw_rows:
  63. return []
  64. df = pd.DataFrame(raw_rows)
  65. if df.empty or 'flight_numbers' not in df.columns or 'from_date' not in df.columns:
  66. return []
  67. df = df.dropna(subset=['flight_numbers', 'from_date'])
  68. if df.empty:
  69. return []
  70. df = df.drop_duplicates(subset=['flight_numbers', 'from_date'])
  71. df_grouped = (
  72. df.groupby('flight_numbers', as_index=False)
  73. .agg(days=('from_date', 'size'), flight_dates=('from_date', lambda s: sorted(s.tolist())))
  74. )
  75. df_grouped = df_grouped[df_grouped['days'] >= min_days].sort_values('flight_numbers').reset_index(drop=True)
  76. if df_grouped.empty:
  77. return []
  78. formatted_results = df_grouped[['flight_numbers', 'days', 'flight_dates']].to_dict(orient='records')
  79. return formatted_results
  80. except (ServerSelectionTimeoutError, PyMongoError) as e:
  81. print(f"⚠️ Mongo 查询失败: {e}")
  82. if attempt == max_retries:
  83. print("❌ 达到最大重试次数,放弃")
  84. return []
  85. # 指数退避 + 随机抖动
  86. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  87. print(f"⏳ {sleep_time:.2f}s 后重试...")
  88. time.sleep(sleep_time)
def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_date_begin, from_date_end,
                              limit=0, max_retries=3, base_sleep=1.0):
    """Fetch all price snapshots for one flight group in a departure-date range.

    Aggregates records with baggage_weight in {0, 20}, sorted by
    (from_date, baggage_weight, create_time), and repairs missing
    ``from_time`` values from the per-(citypair, flight, date) mode.

    Parameters
    ----------
    db, table_name : Mongo database handle and collection name.
    city_pair, flight_numbers : exact-match filter values.
    from_date_begin, from_date_end : inclusive date-string bounds.
    limit : unused (kept for signature compatibility).
    max_retries, base_sleep : retry count and backoff base for Mongo errors.

    Returns an (unindexed) DataFrame, or an empty DataFrame on no data,
    on unrecoverable from_time extraction failure, or after retries.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Fields returned per document (note: "_id" is NOT suppressed here;
            # it is dropped from the DataFrame below instead).
            projection = {
                # "_id": 0  # usually switched off; left enabled on purpose
                "citypair": 1,
                "flight_numbers": 1,
                "from_date": 1,
                "from_time": 1,
                "create_time": 1,
                "baggage_weight": 1,
                "cabins": 1,
                "ticket_amount": 1,
                "currency": 1,
                "price_base": 1,
                "price_tax": 1,
                "price_total": 1
            }
            pipeline = [
                {
                    "$match": {
                        "citypair": city_pair,
                        "flight_numbers": flight_numbers,
                        "baggage_weight": {"$in": [0, 20]},
                        "from_date": {
                            "$gte": from_date_begin,
                            "$lte": from_date_end
                        }
                    }
                },
                {
                    "$project": projection  # field whitelist defined above
                },
                {
                    "$sort": {
                        "from_date": 1,
                        "baggage_weight": 1,
                        "create_time": 1
                    }
                }
            ]
            # Run the aggregation.
            collection = db[table_name]
            results = list(collection.aggregate(pipeline))
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                # --- Repair missing/blank from_time values ---------------------
                if 'from_time' in df.columns and 'from_date' in df.columns:
                    key_cols = ['citypair', 'flight_numbers', 'from_date']
                    group_cols = [col for col in key_cols if col in df.columns]
                    from_time_raw = df['from_time']
                    from_time_str = from_time_raw.fillna('').astype(str).str.strip()
                    # Pull a trailing HH:MM:SS out of each from_time string.
                    extracted_time = from_time_str.str.extract(r'(\d{2}:\d{2}:\d{2})$')[0]
                    valid_time_mask = from_time_str.ne('') & extracted_time.notna()
                    if valid_time_mask.any():
                        missing_mask = from_time_raw.isna() | from_time_str.eq('')
                        if group_cols:
                            # Most frequent time-of-day per (citypair, flight, date).
                            mode_by_group = (
                                df.loc[valid_time_mask, group_cols]
                                .assign(_mode_time=extracted_time.loc[valid_time_mask].values)
                                .groupby(group_cols, dropna=False)['_mode_time']
                                .agg(lambda s: s.value_counts().idxmax())
                                .reset_index()
                            )
                            # One row per group key, so this left-merge keeps df's length
                            # (index is reset to a RangeIndex of the same size, keeping
                            # the masks above positionally aligned).
                            df = df.merge(mode_by_group, on=group_cols, how='left')
                            fill_mask = missing_mask & df['_mode_time'].notna()
                            if fill_mask.any():
                                df.loc[fill_mask, 'from_time'] = (
                                    df.loc[fill_mask, 'from_date'].astype(str).str.strip() + ' ' + df.loc[fill_mask, '_mode_time']
                                )
                            df = df.drop(columns=['_mode_time'])
                            # Rows whose group had no valid time fall back to the
                            # global modal time-of-day.
                            remaining_missing_mask = df['from_time'].isna() | df['from_time'].astype(str).str.strip().eq('')
                            if remaining_missing_mask.any():
                                more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
                                df.loc[remaining_missing_mask, 'from_time'] = (
                                    df.loc[remaining_missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
                                )
                            pass
                        else:
                            # No group keys available: use the global modal time.
                            more_time = extracted_time.loc[valid_time_mask].value_counts().idxmax()
                            if missing_mask.any():
                                df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
                    else:
                        # Not a single parsable departure time: drop the batch.
                        print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
                        return pd.DataFrame()
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff + random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
  194. def plot_c1_trend(df, output_dir="."):
  195. """
  196. 根据传入的 dataframe 绘制 price_total 随 update_hour 的趋势图,
  197. 并按照 baggage 分类进行分组绘制。
  198. """
  199. # 颜色与线型配置(按顺序循环使用)
  200. colors = ['green', 'blue', 'red', 'brown']
  201. linestyles = ['--', '--', '--', '--']
  202. # 确保时间字段为 datetime 类型
  203. if not hasattr(df['update_hour'], 'dt'):
  204. df['update_hour'] = pd.to_datetime(df['update_hour'])
  205. city_pair = df['citypair'].mode().iloc[0]
  206. flight_numbers = df['flight_numbers'].mode().iloc[0]
  207. from_time = df['from_time'].mode().iloc[0]
  208. output_dir_temp = os.path.join(output_dir, city_pair)
  209. os.makedirs(output_dir_temp, exist_ok=True)
  210. df = df[df['baggage_weight'] == 0] # 只保留无行李的
  211. # 创建图表对象
  212. fig = plt.figure(figsize=(14, 8))
  213. # 按 baggage_weight 分类绘制
  214. for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
  215. # 按时间排序
  216. df_g = group.sort_values('update_hour').reset_index(drop=True)
  217. # 找价格变化点:与前一行不同的价格即为变化点
  218. # keep first row + change rows + last row
  219. change_points = df_g.loc[
  220. (df_g['price_total'] != df_g['price_total'].shift(1)) |
  221. (df_g.index == 0) |
  222. (df_g.index == len(df_g) - 1) # 终点
  223. ].drop_duplicates(subset=['update_hour'])
  224. # 绘制阶梯线(平缓-突变)+ 变化点
  225. plt.step(
  226. change_points['update_hour'],
  227. change_points['price_total'],
  228. where='post',
  229. color=colors[i % len(colors)],
  230. linestyle=linestyles[i % len(linestyles)],
  231. linewidth=2,
  232. label=f"Baggage {baggage_value}"
  233. )
  234. plt.scatter(
  235. change_points['update_hour'],
  236. change_points['price_total'],
  237. s=30,
  238. facecolors='white',
  239. edgecolors=colors[i % len(colors)],
  240. linewidths=2,
  241. zorder=3,
  242. )
  243. # 添加注释 (小时数, 价格, 舱位)
  244. # 点密集时自动抽样,避免文字严重重叠
  245. n_points = len(change_points)
  246. max_labels = 30
  247. step = max(1, int(np.ceil(n_points / max_labels)))
  248. label_points = change_points.iloc[::step].copy()
  249. # 确保最后一个点始终有注释
  250. if n_points > 0 and label_points.index[-1] != change_points.index[-1]:
  251. label_points = pd.concat([label_points, change_points.tail(1)])
  252. rotation_angle = 45 if n_points > max_labels else 25
  253. label_fontsize = 4 if n_points > max_labels else 5
  254. for _, row in change_points.iterrows():
  255. text = f"({row['hours_until_departure']}, {row['price_total']}, {row['cabins']})"
  256. plt.annotate(
  257. text,
  258. xy=(row['update_hour'], row['price_total']),
  259. xytext=(0, 0), # 向右偏移
  260. textcoords="offset points",
  261. ha='left',
  262. va='center',
  263. fontsize=label_fontsize, # 字体稍小
  264. color='gray',
  265. alpha=0.8,
  266. rotation=rotation_angle,
  267. )
  268. del change_points
  269. del df_g
  270. # 自动优化日期显示
  271. plt.gcf().autofmt_xdate()
  272. plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
  273. plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
  274. plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
  275. fontsize=14, fontweight='bold', fontproperties=font_prop)
  276. # 设置 x 轴刻度为每天
  277. ax = plt.gca()
  278. ax.xaxis.set_major_locator(mdates.DayLocator(interval=1)) # 每天一个主刻度
  279. ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d')) # 显示月-日
  280. ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12])) # 指定在12:00显示副刻度
  281. ax.xaxis.set_minor_formatter(mdates.DateFormatter('')) # 输出空字符串
  282. # 添加图例
  283. plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
  284. plt.grid(True, alpha=0.3)
  285. plt.tight_layout()
  286. safe_flight = flight_numbers.replace(",", "_")
  287. safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
  288. save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
  289. output_path = os.path.join(output_dir_temp, save_file)
  290. # 保存图片(在显示之前)
  291. plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
  292. # 关闭图形释放内存
  293. plt.close(fig)
  294. def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
  295. """补齐成小时粒度数据"""
  296. df = df.copy()
  297. # 1. 转 datetime
  298. df['create_time'] = pd.to_datetime(df['create_time'])
  299. df['from_time'] = pd.to_datetime(df['from_time'])
  300. # 添加一个用于分组的小时字段
  301. df['update_hour'] = df['create_time'].dt.floor('h')
  302. # 2. 排序规则:同一小时内,按原始时间戳排序
  303. # 假设你想保留最早的一条
  304. df = df.sort_values(['update_hour', 'create_time'])
  305. # 3. 按小时去重,保留该小时内最早(最晚)的一条
  306. df = df.drop_duplicates(subset=['update_hour'], keep='last') # keep='first' keep='last'
  307. # 4. 标记原始数据
  308. df['is_filled'] = 0
  309. # 5. 排序 + 设索引
  310. df = df.sort_values('update_hour').set_index('update_hour')
  311. # 6. 构造完整小时轴
  312. start_of_hour = df.index.min() # 默认 第一天 最早 开始
  313. if head_fill == 1:
  314. start_of_hour = df.index.min().normalize() # 强制 第一天 00:00 开始
  315. end_of_hour = df.index.max() # 默认 最后一天 最晚 结束
  316. if rear_fill == 1:
  317. end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23) # 强制 最后一天 23:00 结束
  318. elif rear_fill == 2:
  319. if 'from_time' in df.columns:
  320. last_dep_time = df['from_time'].iloc[-1]
  321. if pd.notna(last_dep_time):
  322. # 对齐到整点小时(向下取整)
  323. end_of_hour = last_dep_time.floor('h')
  324. full_index = pd.date_range(
  325. start=start_of_hour,
  326. end=end_of_hour,
  327. freq='1h'
  328. )
  329. # 7. 按小时补齐
  330. df = df.reindex(full_index)
  331. # 先恢复 dtype(关键!)
  332. df = df.infer_objects(copy=False)
  333. # 8. 新增出来的行标记为 1
  334. df['is_filled'] = df['is_filled'].fillna(1)
  335. # 9. 前向填充
  336. df = df.ffill()
  337. # 10. 还原整型字段
  338. int_cols = [
  339. 'ticket_amount',
  340. 'baggage_weight',
  341. 'is_filled',
  342. ]
  343. for col in int_cols:
  344. if col in df.columns:
  345. df[col] = df[col].astype('int64')
  346. # 10.5 价格字段统一保留两位小数
  347. price_cols = [
  348. 'price_base',
  349. 'price_tax',
  350. 'price_total'
  351. ]
  352. for col in price_cols:
  353. if col in df.columns:
  354. df[col] = df[col].astype('float64').round(2)
  355. # 10.6 新增:距离起飞还有多少小时
  356. if 'from_time' in df.columns:
  357. # 创建临时字段(整点)
  358. df['from_hour'] = df['from_time'].dt.floor('h')
  359. # 计算小时差 df.index 此时就是 update_hour
  360. df['hours_until_departure'] = (
  361. (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
  362. ).astype('int64')
  363. # 新增:距离起飞还有多少天
  364. df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
  365. # 删除临时字段
  366. df = df.drop(columns=['from_hour'])
  367. # 11. 写回 update_hour
  368. df['update_hour'] = df.index
  369. # 12. 恢复普通索引
  370. df = df.reset_index(drop=True)
  371. return df
  372. def process_flight_numbers(args):
  373. process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
  374. print(f"[进程{process_id}] 开始处理航班号: {flight_numbers}")
  375. local_client = None
  376. db = _worker_db
  377. if db is None:
  378. try:
  379. local_client, db = mongo_con_parse(db_config)
  380. print(f"[进程{process_id}] ✅ 数据库连接创建成功")
  381. except Exception as e:
  382. print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
  383. return pd.DataFrame()
  384. try:
  385. df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
  386. if df1.empty:
  387. return pd.DataFrame()
  388. common_dep_dates = df1['from_date'].unique()
  389. common_baggages = df1['baggage_weight'].unique()
  390. list_mid = []
  391. for dep_date in common_dep_dates:
  392. # 起飞日期筛选
  393. df_d1 = df1[df1["from_date"] == dep_date].copy()
  394. list_f1 = []
  395. for baggage in common_baggages:
  396. # 行李配额筛选
  397. df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
  398. if df_b1.empty:
  399. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
  400. continue
  401. df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
  402. list_f1.append(df_f1)
  403. del df_f1
  404. del df_b1
  405. if list_f1:
  406. df_c1 = pd.concat(list_f1, ignore_index=True)
  407. if plot_flag:
  408. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
  409. plot_c1_trend(df_c1, output_dir)
  410. print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
  411. else:
  412. df_c1 = pd.DataFrame()
  413. if plot_flag:
  414. print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
  415. del list_f1
  416. list_mid.append(df_c1)
  417. del df_c1
  418. del df_d1
  419. if list_mid:
  420. df_mid = pd.concat(list_mid, ignore_index=True)
  421. print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
  422. else:
  423. df_mid = pd.DataFrame()
  424. print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
  425. del list_mid
  426. del df1
  427. gc.collect()
  428. print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
  429. return df_mid
  430. except Exception as e:
  431. print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
  432. return pd.DataFrame()
  433. finally:
  434. if local_client is not None:
  435. try:
  436. local_client.close()
  437. print(f"[进程{process_id}] ✅ 数据库连接已关闭")
  438. except Exception:
  439. pass
  440. def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.',
  441. use_multiprocess=False, max_workers=None):
  442. list_all = []
  443. print(f"开始处理航线: {city_pair}")
  444. main_client, main_db = mongo_con_parse(db_config)
  445. all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
  446. main_client.close()
  447. all_groups_len = len(all_groups)
  448. print(f"该航线共有{all_groups_len}组航班号")
  449. if use_multiprocess and all_groups_len > 1:
  450. print(f"启用多进程处理,最大进程数: {max_workers}")
  451. # 多进程处理
  452. process_args = []
  453. process_id = 0
  454. for each_group in all_groups:
  455. flight_numbers = each_group.get("flight_numbers", "未知")
  456. process_id += 1
  457. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  458. process_args.append(args)
  459. with ProcessPoolExecutor(max_workers=max_workers, initializer=init_worker_mongo, initargs=(db_config,)) as executor:
  460. future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
  461. for future in as_completed(future_to_group):
  462. each_group = future_to_group[future]
  463. flight_numbers = each_group.get("flight_numbers", "未知")
  464. try:
  465. df_mid = future.result()
  466. if not df_mid.empty:
  467. list_all.append(df_mid)
  468. print(f"✅ 航班号:{flight_numbers} 处理完成")
  469. else:
  470. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  471. except Exception as e:
  472. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  473. pass
  474. else:
  475. print("使用单进程处理")
  476. process_id = 0
  477. for each_group in all_groups:
  478. flight_numbers = each_group.get("flight_numbers", "未知")
  479. args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
  480. try:
  481. df_mid = process_flight_numbers(args)
  482. if not df_mid.empty:
  483. list_all.append(df_mid)
  484. print(f"✅ 航班号:{flight_numbers} 处理完成")
  485. else:
  486. print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
  487. except Exception as e:
  488. print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
  489. print(f"结束处理航线: {city_pair}")
  490. if list_all:
  491. df_all = pd.concat(list_all, ignore_index=True)
  492. else:
  493. df_all = pd.DataFrame()
  494. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  495. del list_all
  496. gc.collect()
  497. return df_all
  498. def validate_keep_one_line(db, table_name, city_pair, flight_numbers, baggage_weight, from_date, entry_price, del_batch_std_str,
  499. limit=0, max_retries=3, base_sleep=1.0):
  500. """验证keep_info的一行"""
  501. for attempt in range(1, max_retries + 1):
  502. try:
  503. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  504. collection = db[table_name]
  505. projection = {
  506. # "_id": 0 # 一般会关掉
  507. "citypair": 1,
  508. "flight_numbers": 1,
  509. "from_date": 1,
  510. "from_time": 1,
  511. "create_time": 1,
  512. "baggage_weight": 1,
  513. "cabins": 1,
  514. "ticket_amount": 1,
  515. "currency": 1,
  516. "price_base": 1,
  517. "price_tax": 1,
  518. "price_total": 1
  519. }
  520. query = {
  521. "citypair": city_pair,
  522. "flight_numbers": flight_numbers,
  523. "baggage_weight": baggage_weight,
  524. "from_date": from_date,
  525. "create_time": {"$lte": del_batch_std_str}
  526. }
  527. raw_rows = list(collection.find(query, projection))
  528. if not raw_rows:
  529. return pd.DataFrame()
  530. df = pd.DataFrame(raw_rows)
  531. if df.empty or 'flight_numbers' not in df.columns or 'from_date' not in df.columns or 'create_time' not in df.columns:
  532. return pd.DataFrame()
  533. df['_create_time_dt'] = pd.to_datetime(df['create_time'], errors='coerce')
  534. df = df[df['_create_time_dt'].notna()].sort_values('_create_time_dt').reset_index(drop=True)
  535. if df.empty:
  536. return pd.DataFrame()
  537. entry_price_num = pd.to_numeric(entry_price, errors='coerce')
  538. if pd.isna(entry_price_num):
  539. return pd.DataFrame()
  540. df['_price_total_num'] = pd.to_numeric(df['price_total'], errors='coerce')
  541. matched_idx = df.index[df['_price_total_num'] == entry_price_num]
  542. if len(matched_idx) == 0:
  543. return pd.DataFrame()
  544. start_idx = matched_idx[-1]
  545. df = df.iloc[start_idx:].reset_index(drop=True)
  546. df = df.drop(columns=['_create_time_dt', '_price_total_num'])
  547. return df
  548. except (ServerSelectionTimeoutError, PyMongoError) as e:
  549. print(f"⚠️ Mongo 查询失败: {e}")
  550. if attempt == max_retries:
  551. print("❌ 达到最大重试次数,放弃")
  552. return pd.DataFrame()
  553. # 指数退避 + 随机抖动
  554. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  555. print(f"⏳ {sleep_time:.2f}s 后重试...")
  556. time.sleep(sleep_time)
  557. if __name__ == "__main__":
  558. cpu_cores = os.cpu_count() # 你的系统是72
  559. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  560. output_dir = f"./photo"
  561. os.makedirs(output_dir, exist_ok=True)
  562. from_date_begin = "2026-04-21"
  563. from_date_end = "2026-05-10"
  564. uo_city_pairs = uo_city_pairs_new.copy()
  565. uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]
  566. for idx, uo_city_pair in enumerate(uo_city_pair_list, start=1):
  567. # 使用默认配置
  568. # client, db = mongo_con_parse()
  569. print(f"第 {idx} 组 :", uo_city_pair)
  570. start_time = time.time()
  571. load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
  572. plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
  573. end_time = time.time()
  574. run_time = round(end_time - start_time, 3)
  575. print(f"用时: {run_time} 秒")