data_loader.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232
  1. import gc
  2. import time
  3. from datetime import datetime, timedelta
  4. import pymongo
  5. from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
  6. import pandas as pd
  7. import os
  8. import random
  9. import tempfile
  10. from concurrent.futures import ProcessPoolExecutor, as_completed
  11. import numpy as np
  12. import matplotlib.pyplot as plt
  13. from matplotlib import font_manager
  14. import matplotlib.dates as mdates
  15. from config import mongodb_config, vj_flight_route_list, vj_flight_route_list_hot, vj_flight_route_list_nothot, \
  16. CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
# Font used to render Chinese labels in matplotlib charts (TTF shipped next to this module).
font_path = "./simhei.ttf"
font_prop = font_manager.FontProperties(fname=font_path)
# Module-level cache for a shared MongoDB client; populated by
# mongo_con_parse(reuse_client=True) and keyed on the connection config.
_MONGO_SHARED_CLIENT = None
_MONGO_SHARED_DB = None
_MONGO_SHARED_CFG_KEY = None
def mongo_con_parse(config=None, reuse_client=False):
    """Build (or reuse) a MongoDB ``(client, db)`` pair.

    Parameters
    ----------
    config : dict | None
        Connection settings; defaults to a copy of ``mongodb_config``.
        ``URI`` (a full connection string) takes priority; otherwise
        ``host``/``port``/``db`` and optionally ``user``/``pwd`` are used.
    reuse_client : bool
        When True, cache the created client/db at module level and return
        the cached pair on later calls made with an equivalent config.

    Returns
    -------
    tuple
        ``(pymongo.MongoClient, Database)``.

    Raises
    ------
    Exception
        Re-raises whatever pymongo raises while constructing the client.
    """
    if config is None:
        config = mongodb_config.copy()
    global _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB, _MONGO_SHARED_CFG_KEY
    # Cache key: a change in any of these fields invalidates the shared client.
    cfg_key = (
        config.get("URI", ""),
        config.get("host", ""),
        config.get("port", ""),
        config.get("db", ""),
        config.get("user", ""),
    )
    if reuse_client and _MONGO_SHARED_CLIENT is not None and _MONGO_SHARED_DB is not None and _MONGO_SHARED_CFG_KEY == cfg_key:
        return _MONGO_SHARED_CLIENT, _MONGO_SHARED_DB
    try:
        if config.get("URI", ""):
            # URI form: pool size bumped to 100 for the shared-client use case.
            motor_uri = config["URI"]
            client = pymongo.MongoClient(motor_uri, maxPoolSize=100)
            db = client[config['db']]
        else:
            client = pymongo.MongoClient(
                config['host'],
                config['port'],
                serverSelectionTimeoutMS=30000,
                connectTimeoutMS=30000,
                socketTimeoutMS=30000,
                retryReads=True,
                maxPoolSize=50
            )
            db = client[config['db']]
            if config.get('user'):
                # NOTE(review): Database.authenticate was removed in PyMongo 4.x;
                # this line presumably runs against PyMongo 3.x — confirm, or pass
                # username/password to MongoClient instead when upgrading.
                db.authenticate(config['user'], config['pwd'])
        print(f"✅ MongoDB 连接对象创建成功")
    except Exception as e:
        print(f"❌ 创建 MongoDB 连接对象时发生错误: {e}")
        raise
    if reuse_client:
        _MONGO_SHARED_CLIENT = client
        _MONGO_SHARED_DB = db
        _MONGO_SHARED_CFG_KEY = cfg_key
    return client, db
  62. def test_mongo_connection(db):
  63. try:
  64. # 获取客户端对象
  65. client = db.client
  66. # 方法1:使用 server_info() 测试连接
  67. info = client.server_info()
  68. print(f"✅ MongoDB 连接测试成功!")
  69. print(f" 服务器版本: {info.get('version')}")
  70. print(f" 数据库: {db.name}")
  71. return True
  72. except Exception as e:
  73. print(f"❌ 数据库连接测试失败: {e}")
  74. return False
def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin, dep_date_end, flight_nums,
                              limit=0, max_retries=3, base_sleep=1.0, thread_id=0):
    """Query one of the 4 cleaned tables for a departure-date range, with retries.

    Parameters
    ----------
    db : pymongo Database
    table_name : str
        Collection to query (one of the CLEAN_VJ_* tables).
    from_city, to_city : str
        City codes matched against from_city_code / to_city_code.
    dep_date_begin, dep_date_end : str
        Inclusive search_dep_time bounds (presumably 'YYYYMMDD' strings —
        matches how callers format them; confirm against the collection).
    flight_nums : list | None
        Currently unused here (per-flight filtering happens later in
        _filter_df_by_flight_nums); kept for interface compatibility.
    limit : int
        If > 0, truncate the resulting DataFrame to this many rows.
    max_retries, base_sleep : retry count and exponential-backoff base.
    thread_id : unused; kept for interface compatibility.

    Returns
    -------
    pandas.DataFrame
        Expanded rows (segments flattened), filtered to baggage == 0;
        empty DataFrame when nothing is found or retries are exhausted.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Equality on the city pair plus a range on search_dep_time —
            # shaped to match the compound index hinted below.
            query_condition = {
                "from_city_code": from_city,
                "to_city_code": to_city,
                "search_dep_time": {
                    "$gte": dep_date_begin,
                    "$lte": dep_date_end,
                },
            }
            # Only fare rows with no checked baggage are kept downstream.
            baggage_filter = 0
            # flight_nums_filter = list(flight_nums) if flight_nums else []
            print(f" 查询条件(走索引): {query_condition}")
            # Projection keeps the payload small; _id is dropped later.
            projection = {
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            cursor = (
                db.get_collection(table_name)
                .find(query_condition, projection=projection)
                .batch_size(5000)
                # Force the compound index so the planner never falls back to a scan.
                .hint('from_city_code_1_to_city_code_1_search_dep_time_1')
            )
            # Materialize the cursor in one pass.
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")
            if results:
                df = pd.DataFrame(results)
                # Drop the Mongo ObjectId column — not serializable/needed.
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
                # Flatten the nested segments list into seg1_/seg2_ columns.
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns_optimized(df)  # optimized variant
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
                # Keep only zero-baggage fares.
                if "baggage" in df.columns:
                    df = df[df["baggage"] == baggage_filter]
                # Per-flight filtering / sorting intentionally deferred to the
                # worker side (see _filter_df_by_flight_nums / process_flight_group):
                # for i, flight_num in enumerate(flight_nums_filter):
                #     if flight_num is None or flight_num == "":
                #         continue
                #     col = f"seg{i + 1}_flight_number"
                #     if col not in df.columns:
                #         return pd.DataFrame()
                #     df = df[df[col].astype("string") == str(flight_num)]
                # sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df.columns]
                # if sort_cols:
                #     df = df.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
                if limit > 0:
                    df = df.head(limit).reset_index(drop=True)
                return df
            else:
                print("⚠️ 查询结果为空")
                return pd.DataFrame()
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
  157. # def expand_segments_columns(df):
  158. # """展开 segments"""
  159. # df = df.copy()
  160. # # 定义要展开的列
  161. # seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
  162. # seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
  163. # # 定义 apply 函数一次返回字典
  164. # def extract_segments(row):
  165. # segments = row.get('segments')
  166. # result = {}
  167. # # 默认缺失使用 pd.NA(对字符串友好)
  168. # missing = pd.NA
  169. # if isinstance(segments, list):
  170. # # 第一段
  171. # if len(segments) >= 1 and isinstance(segments[0], dict):
  172. # for col in seg1_cols:
  173. # result[f'seg1_{col}'] = segments[0].get(col)
  174. # else:
  175. # for col in seg1_cols:
  176. # result[f'seg1_{col}'] = missing
  177. # # 第二段
  178. # if len(segments) >= 2 and isinstance(segments[1], dict):
  179. # for col in seg2_cols:
  180. # result[f'seg2_{col}'] = segments[1].get(col)
  181. # else:
  182. # for col in seg2_cols:
  183. # result[f'seg2_{col}'] = missing
  184. # else:
  185. # # segments 不是 list,全都置空
  186. # for col in seg1_cols:
  187. # result[f'seg1_{col}'] = missing
  188. # for col in seg2_cols:
  189. # result[f'seg2_{col}'] = missing
  190. # return pd.Series(result)
  191. # # 一次 apply
  192. # df_segments = df.apply(extract_segments, axis=1)
  193. # # 拼回原 df
  194. # df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_segments], axis=1)
  195. # # 统一转换时间字段为 datetime
  196. # time_cols = [
  197. # 'seg1_dep_time', 'seg1_arr_time',
  198. # 'seg2_dep_time', 'seg2_arr_time'
  199. # ]
  200. # for col in time_cols:
  201. # if col in df.columns:
  202. # df[col] = pd.to_datetime(
  203. # df[col],
  204. # format='%Y%m%d%H%M%S',
  205. # errors='coerce'
  206. # )
  207. # # 站点来源 -> 是否近期
  208. # df['source_website'] = np.where(
  209. # df['source_website'].str.contains('7_30'),
  210. # 0, # 远期 -> 0
  211. # np.where(df['source_website'].str.contains('0_7'),
  212. # 1, # 近期 -> 1
  213. # df['source_website']) # 其他情况保持原值
  214. # )
  215. # # 行李配额字符 -> 数字
  216. # conditions = [
  217. # df['seg1_baggage'] == '-;-;-;-',
  218. # df['seg1_baggage'] == '1-20',
  219. # df['seg1_baggage'] == '1-30',
  220. # df['seg1_baggage'] == '1-40',
  221. # ]
  222. # choices = [0, 20, 30, 40]
  223. # df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
  224. # # 重命名字段
  225. # df = df.rename(columns={
  226. # 'seg1_cabin': 'cabin',
  227. # 'seg1_baggage': 'baggage',
  228. # 'source_website': 'is_near',
  229. # })
  230. # return df
  231. def expand_segments_columns_optimized(df):
  232. """优化版的展开segments函数(避免逐行apply)"""
  233. if df.empty:
  234. return df
  235. df = df.copy()
  236. # 直接操作segments列表,避免逐行apply
  237. if 'segments' in df.columns:
  238. # 提取第一段信息
  239. seg1_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time', 'cabin', 'baggage']
  240. # 提取第二段信息
  241. seg2_cols = ['flight_number', 'dep_air_port', 'dep_time', 'arr_air_port', 'arr_time']
  242. # 使用列表推导式替代apply,大幅提升性能
  243. seg1_data = []
  244. seg2_data = []
  245. for segments in df['segments']:
  246. seg1_dict = {}
  247. seg2_dict = {}
  248. if isinstance(segments, list) and len(segments) >= 1 and isinstance(segments[0], dict):
  249. for col in seg1_cols:
  250. seg1_dict[f'seg1_{col}'] = segments[0].get(col)
  251. else:
  252. for col in seg1_cols:
  253. seg1_dict[f'seg1_{col}'] = pd.NA
  254. if isinstance(segments, list) and len(segments) >= 2 and isinstance(segments[1], dict):
  255. for col in seg2_cols:
  256. seg2_dict[f'seg2_{col}'] = segments[1].get(col)
  257. else:
  258. for col in seg2_cols:
  259. seg2_dict[f'seg2_{col}'] = pd.NA
  260. seg1_data.append(seg1_dict)
  261. seg2_data.append(seg2_dict)
  262. # 创建DataFrame
  263. df_seg1 = pd.DataFrame(seg1_data, index=df.index)
  264. df_seg2 = pd.DataFrame(seg2_data, index=df.index)
  265. # 合并到原DataFrame
  266. df = pd.concat([df.drop(columns=['segments'], errors='ignore'), df_seg1, df_seg2], axis=1)
  267. # 后续处理保持不变
  268. time_cols = ['seg1_dep_time', 'seg1_arr_time', 'seg2_dep_time', 'seg2_arr_time']
  269. for col in time_cols:
  270. if col in df.columns:
  271. df[col] = pd.to_datetime(df[col], format='%Y%m%d%H%M%S', errors='coerce')
  272. df['source_website'] = np.where(
  273. df['source_website'].str.contains('7_30'), 0,
  274. np.where(df['source_website'].str.contains('0_7'), 1, df['source_website'])
  275. )
  276. conditions = [
  277. df['seg1_baggage'] == '-;-;-;-',
  278. df['seg1_baggage'] == '1-20',
  279. df['seg1_baggage'] == '1-30',
  280. df['seg1_baggage'] == '1-40',
  281. ]
  282. choices = [0, 20, 30, 40]
  283. df['seg1_baggage'] = np.select(conditions, choices, default=df['seg1_baggage'])
  284. df = df.rename(columns={
  285. 'seg1_cabin': 'cabin',
  286. 'seg1_baggage': 'baggage',
  287. 'source_website': 'is_near',
  288. })
  289. return df
def fill_hourly_crawl_date(df, head_fill=0, rear_fill=0):
    """Resample crawl snapshots onto a complete hourly grid, forward-filling gaps.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``crawl_date``; other recognized columns are handled
        opportunistically (seg1_dep_time, price and integer columns).
    head_fill : int
        1 -> extend the grid back to 00:00 of the first day; otherwise start
        at the first observed hour.
    rear_fill : int
        0 -> end at the last observed hour;
        1 -> extend to 23:00 of the last day;
        2 -> extend to the flight's departure hour (floor of the last row's
        seg1_dep_time), when that column exists and is not NaT.

    Returns
    -------
    pandas.DataFrame
        One row per hour with ``update_hour`` column restored, ``is_filled``
        marking synthesized rows (1) vs observed rows (0), integer columns
        cast back to int64, prices rounded to 2 decimals, and
        ``hours_until_departure`` / ``days_to_departure`` derived when
        seg1_dep_time is available.
    """
    df = df.copy()
    # 1. Normalize to datetime and derive the hour bucket used for grouping.
    df['crawl_date'] = pd.to_datetime(df['crawl_date'])
    df['update_hour'] = df['crawl_date'].dt.floor('h')
    # 2. Within one hour, order by the raw timestamp...
    df = df.sort_values(['update_hour', 'crawl_date'])
    # 3. ...and keep the latest snapshot of that hour.
    df = df.drop_duplicates(subset=['update_hour'], keep='last')  # keep='first' keep='last'
    # df = df.drop(columns=['crawl_date'])
    # df = df.drop(columns=['_id'])
    # 4. Mark rows that were actually observed.
    df['is_filled'] = 0
    # 5. Sort and index by the hour bucket.
    df = df.sort_values('update_hour').set_index('update_hour')
    # 6. Build the full hourly axis according to head_fill / rear_fill.
    start_of_day = df.index.min()  # default: first observed hour
    if head_fill == 1:
        start_of_day = df.index.min().normalize()  # force 00:00 of the first day
    end_of_day = df.index.max()  # default: last observed hour
    if rear_fill == 1:
        end_of_day = df.index.max().normalize() + pd.Timedelta(hours=23)  # force 23:00 of the last day
    elif rear_fill == 2:
        if 'seg1_dep_time' in df.columns:
            last_dep_time = df['seg1_dep_time'].iloc[-1]
            if pd.notna(last_dep_time):
                # Align to the departure hour (floor).
                end_of_day = last_dep_time.floor('h')
    full_index = pd.date_range(
        start=start_of_day,
        end=end_of_day,
        freq='1h'
    )
    # 7. Reindex onto the full hourly grid (new rows are all-NaN).
    df = df.reindex(full_index)
    # Restore dtypes before filling — important so ffill keeps column types.
    df = df.infer_objects(copy=False)
    # 8. Rows created by the reindex get is_filled = 1.
    df['is_filled'] = df['is_filled'].fillna(1)
    # 9. Forward-fill every column from the last observed snapshot.
    df = df.ffill()
    # 10. Cast integer-like columns back (ffill/reindex promotes them to float).
    int_cols = [
        'seats_remaining',
        'is_near',
        'baggage',
        'is_filled',
    ]
    for col in int_cols:
        if col in df.columns:
            df[col] = df[col].astype('int64')
    # 10.5 Round price columns to two decimals.
    price_cols = [
        'adult_price',
        'adult_tax',
        'adult_total_price'
    ]
    for col in price_cols:
        if col in df.columns:
            df[col] = df[col].astype('float64').round(2)
    # 10.6 Hours remaining until departure (index is the hour bucket here).
    if 'seg1_dep_time' in df.columns:
        # Temporary column: departure aligned to the hour.
        df['seg1_dep_hour'] = df['seg1_dep_time'].dt.floor('h')
        df['hours_until_departure'] = (
            (df['seg1_dep_hour'] - df.index) / pd.Timedelta(hours=1)
        ).astype('int64')
        # Whole days remaining until departure.
        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
        df = df.drop(columns=['seg1_dep_hour'])
    # 11. Materialize the hour bucket back into a column.
    df['update_hour'] = df.index
    # 12. Return with a plain RangeIndex.
    df = df.reset_index(drop=True)
    return df
def query_groups_of_city_code(db, from_city, to_city, table_name, min_days=10, max_retries=3, base_sleep=1.0):
    """List flight-number groups for a city pair over the last 31 days.

    Groups rows by the ordered combination of leg flight numbers and keeps
    only combinations that departed on at least ``min_days`` distinct days.
    To keep load off Mongo, only a light find + projection is issued; the
    per-day counting and per-combination aggregation run in pandas.

    Returns
    -------
    list[dict]
        Each dict: ``flight_numbers`` (list of leg numbers), ``days``
        (distinct departure days), ``details`` (per-day record counts).
        Empty list when nothing matches or retries are exhausted.
    """
    print(f"{from_city}-{to_city} 查找所有分组")
    # Rolling 31-day window ending today, formatted like search_dep_time.
    date_begin = (datetime.today() - timedelta(days=31)).strftime("%Y%m%d")
    date_end = datetime.today().strftime("%Y%m%d")
    query = {
        "from_city_code": from_city,
        "to_city_code": to_city,
        "search_dep_time": {"$gte": date_begin, "$lte": date_end},
    }
    # Only the fields needed for grouping — keeps the transfer tiny.
    projection = {
        "_id": 0,
        "search_dep_time": 1,
        "segments.flight_number": 1,
    }
    def _extract_flight_numbers(segments):
        # Pull the flight_number of every well-formed leg, in order.
        if not isinstance(segments, list):
            return []
        out = []
        for seg in segments:
            if not isinstance(seg, dict):
                continue
            fn = seg.get("flight_number")
            if fn:
                out.append(fn)
        return out
    for attempt in range(1, max_retries + 1):
        try:
            print(f" 第 {attempt}/{max_retries} 次尝试查询")
            collection = db[table_name]
            # Same compound index as the range queries.
            cursor = collection.find(query, projection=projection).batch_size(5000).hint('from_city_code_1_to_city_code_1_search_dep_time_1')
            docs = list(cursor)
            if not docs:
                return []
            df = pd.DataFrame.from_records(docs)
            if df.empty or "segments" not in df.columns or "search_dep_time" not in df.columns:
                return []
            df["flight_numbers"] = df["segments"].apply(_extract_flight_numbers)
            # First/second leg numbers as separate sortable columns ("" when absent).
            df["fn1"] = df["flight_numbers"].str[0].fillna("")
            df["fn2"] = df["flight_numbers"].str[1].fillna("")
            # Stable string key for the whole combination.
            df["flight_numbers_key"] = df["flight_numbers"].apply(lambda xs: ",".join(xs) if xs else "")
            # Records per (combination, departure day), ordered fn1 -> fn2 -> day.
            day_counts = (
                df.groupby(["flight_numbers_key", "fn1", "fn2", "search_dep_time"], dropna=False)
                .size()
                .reset_index(name="count")
                .sort_values(["fn1", "fn2", "search_dep_time"], kind="mergesort")
                .reset_index(drop=True)
            )
            keys = ["flight_numbers_key", "fn1", "fn2"]
            # Distinct departure days per combination.
            df_days = day_counts.groupby(keys, sort=False).size().reset_index(name="days")
            # Per-day breakdown carried along as a list of records.
            df_details = (
                day_counts.groupby(keys, sort=False)
                .apply(lambda g: g[["search_dep_time", "count"]].to_dict("records"))
                .reset_index(name="details")
            )
            df_result = df_days.merge(df_details, on=keys, how="inner")
            # Keep combinations flying at least min_days distinct days.
            df_result = df_result[df_result["days"] >= min_days].sort_values(["fn1", "fn2"], kind="mergesort")
            formatted_results = []
            for _, row in df_result.iterrows():
                flight_numbers = row["flight_numbers_key"].split(",") if row["flight_numbers_key"] else []
                formatted_results.append(
                    {
                        "flight_numbers": flight_numbers,
                        "days": int(row["days"]),
                        "details": row["details"],
                    }
                )
            # Drop the intermediates eagerly — these frames can be large.
            del df_result
            del df_details
            del df_days
            del df
            # gc.collect()
            return formatted_results
        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return []
            # Exponential backoff with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
def plot_c12_trend(df, output_dir="."):
    """Plot adult_total_price over update_hour, one line per baggage class.

    Only price *change points* (plus the first and last row of each baggage
    group) are drawn to keep the chart readable. The figure is saved as
    ``<output_dir>/<route>/<route> <flights> <dep_time>.png`` and closed;
    nothing is returned.
    """
    # output_dir_photo = output_dir
    # Color / linestyle palettes cycled per baggage group.
    colors = ['blue', 'red', 'brown']
    linestyles = ['--', '--', '--']
    # Ensure update_hour is datetime (mutates the caller's frame in place).
    if not hasattr(df['update_hour'], 'dt'):
        df['update_hour'] = pd.to_datetime(df['update_hour'])
    # Title/filename metadata taken as the most frequent value per column.
    from_city = df['from_city_code'].mode().iloc[0]
    to_city = df['to_city_code'].mode().iloc[0]
    flight_number_1 = df['seg1_flight_number'].mode().iloc[0]
    # .get(0, "") tolerates an all-NA second leg (empty mode Series).
    flight_number_2 = df['seg2_flight_number'].mode().get(0, "")
    dep_time = df['seg1_dep_time'].mode().iloc[0]
    route = f"{from_city}-{to_city}"
    flight_number = f"{flight_number_1},{flight_number_2}" if flight_number_2 else f"{flight_number_1}"
    output_dir_photo = os.path.join(output_dir, route)
    os.makedirs(output_dir_photo, exist_ok=True)
    fig = plt.figure(figsize=(14, 8))
    # One series per baggage class.
    for i, (baggage_value, group) in enumerate(df.groupby('baggage')):
        g = group.sort_values('update_hour').reset_index(drop=True)
        # Change points: rows whose price differs from the previous row,
        # plus the first and last row so the line spans the full window.
        change_points = g.loc[
            (g['adult_total_price'] != g['adult_total_price'].shift(1)) |
            (g.index == 0) |
            (g.index == len(g) - 1)  # endpoint
        ].drop_duplicates(subset=['update_hour'])
        plt.plot(
            change_points['update_hour'],
            change_points['adult_total_price'],
            marker='o',
            color=colors[i % len(colors)],
            linestyle=linestyles[i % len(linestyles)],
            linewidth=2, markersize=6,
            markerfacecolor='white', markeredgewidth=2,
            label=f"Baggage {baggage_value}"
        )
        # Annotate each point with (hours-to-departure, price).
        for _, row in change_points.iterrows():
            text = f"({row['hours_until_departure']}, {row['adult_total_price']})"
            plt.annotate(
                text,
                xy=(row['update_hour'], row['adult_total_price']),
                xytext=(0, 0),
                textcoords="offset points",
                ha='left',
                va='center',
                fontsize=5,  # deliberately small to reduce clutter
                color='gray',
                alpha=0.8,
                rotation=25,
            )
    # Let matplotlib slant/space the date tick labels automatically.
    plt.gcf().autofmt_xdate()
    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
    plt.title(f'价格变化趋势 - 航线: {route} 航班号: {flight_number}\n起飞时间: {dep_time}',
              fontsize=14, fontweight='bold', fontproperties=font_prop)
    # X axis: one major tick per day, a silent minor tick at 12:00.
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))
    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))  # minor ticks unlabeled
    # ax.tick_params(axis='x', which='minor', labelsize=8, rotation=30)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    safe_flight = flight_number.replace(",", "_")
    safe_dep_time = dep_time.strftime("%Y-%m-%d %H%M%S")
    save_file = f"{route} {safe_flight} {safe_dep_time}.png"
    output_path = os.path.join(output_dir_photo, save_file)
    # Save before closing; close frees figure memory in long batch runs.
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
# Per-process caches for the far (_DF1) and near (_DF2) route DataFrames.
# Populated either by _init_route_cache_worker (worker processes) or
# directly by load_train_data in the single-process path.
_ROUTE_CACHE_DF1 = None
_ROUTE_CACHE_DF2 = None
  545. def _init_route_cache_worker(df1_pickle_path, df2_pickle_path):
  546. global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
  547. _ROUTE_CACHE_DF1 = pd.read_pickle(df1_pickle_path)
  548. _ROUTE_CACHE_DF2 = pd.read_pickle(df2_pickle_path)
  549. def _filter_df_by_flight_nums(df, flight_nums):
  550. if df is None or df.empty:
  551. return pd.DataFrame()
  552. out = df
  553. flight_nums_filter = list(flight_nums) if flight_nums else []
  554. for i, flight_num in enumerate(flight_nums_filter):
  555. if flight_num is None or flight_num == "":
  556. continue
  557. col = f"seg{i + 1}_flight_number"
  558. if col not in out.columns:
  559. return out.iloc[0:0].copy()
  560. out = out[out[col].astype("string") == str(flight_num)]
  561. if out.empty:
  562. return out
  563. return out
def process_flight_group(args):
    """Worker entry point for one flight-number combination.

    Filters the process-local cached far/near DataFrames (see
    _init_route_cache_worker) down to this combination, then for every
    departure date and baggage class merges far+near rows, fills them to an
    hourly series via fill_hourly_crawl_date, and optionally plots the
    result. Returns the concatenation of all per-date frames, or an empty
    DataFrame on missing data / any exception.

    Parameters (packed in *args* for executor.submit compatibility):
        process_id : int  — used only for log prefixes.
        each_group : dict — must contain "flight_numbers".
        is_train   : bool — training uses near-table departure dates;
                            prediction uses far-table dates.
        plot_flag  : bool — draw the per-date trend chart when True.
        output_dir : str  — chart output root.
    """
    process_id, each_group, is_train, plot_flag, output_dir = args
    flight_nums = each_group.get("flight_numbers")
    # details = each_group.get("details")
    print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")
    try:
        # Narrow the cached far (df1) / near (df2) tables to this combination.
        df1 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF1, flight_nums)
        df2 = _filter_df_by_flight_nums(_ROUTE_CACHE_DF2, flight_nums)
        # Stable sort so downstream dedup/ffill sees a deterministic order.
        sort_cols = [c for c in ["search_dep_time", "baggage", "crawl_date"] if c in df1.columns]
        if sort_cols:
            df1 = df1.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
            df2 = df2.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
        if df1.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
            return pd.DataFrame()
        if df2.empty:
            print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
            return pd.DataFrame()
        # Departure dates and baggage classes are driven by the near table.
        # NOTE(review): this df2.empty branch is unreachable — the empty case
        # already returned above; kept for byte-for-byte fidelity.
        if df2.empty:
            common_dep_dates = []
            common_baggages = []
        else:
            common_dep_dates = df2['search_dep_time'].unique()
            common_baggages = df2['baggage'].unique()
        # Prediction mode: departure dates come from the far table instead.
        if not is_train:
            common_dep_dates = df1['search_dep_time'].unique()
        list_mid = []
        for dep_date in common_dep_dates:
            # Per-departure-date slice of the far table; normalize the four
            # segment timestamps to each column's modal value so far/near
            # rows of the same flight align on identical times.
            df_d1 = df1[df1["search_dep_time"] == dep_date].copy()
            if not df_d1.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_1 = df_d1[col].mode()
                    if mode_series_1.empty:
                        zong_1 = pd.NaT
                    else:
                        zong_1 = mode_series_1.iloc[0]
                    df_d1[col] = zong_1
            # Same normalization for the near table slice.
            df_d2 = df2[df2["search_dep_time"] == dep_date].copy()
            if not df_d2.empty:
                for col in ["seg1_dep_time", "seg1_arr_time", "seg2_dep_time", "seg2_arr_time"]:
                    mode_series_2 = df_d2[col].mode()
                    if mode_series_2.empty:
                        zong_2 = pd.NaT
                    else:
                        zong_2 = mode_series_2.iloc[0]
                    df_d2[col] = zong_2
            list_12 = []
            for baggage in common_baggages:
                # Per-baggage-class slices.
                df_b1 = df_d1[df_d1["baggage"] == baggage].copy()
                df_b2 = df_d2[df_d2["baggage"] == baggage].copy()
                # Skip when neither table has rows for this class.
                if df_b1.empty and df_b2.empty:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
                    continue
                # Unify string dtype so concat doesn't mix object/string columns.
                cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
                        "seg2_flight_number", "seg2_dep_air_port", "seg2_arr_air_port"]
                df_b1[cols] = df_b1[cols].astype("string")
                df_b2[cols] = df_b2[cols].astype("string")
                df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
                # Fill to a complete hourly series ending at the departure hour.
                df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
                # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
                list_12.append(df_b12)
                del df_b12
                del df_b2
                del df_b1
            if list_12:
                df_c12 = pd.concat(list_12, ignore_index=True)
                if plot_flag:
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
                    plot_c12_trend(df_c12, output_dir)
                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
            else:
                df_c12 = pd.DataFrame()
                if plot_flag:
                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
            del list_12
            list_mid.append(df_c12)
            del df_c12
            del df_d1
            del df_d2
            # print(f"结束处理起飞日期: {dep_date}")
        if list_mid:
            df_mid = pd.concat(list_mid, ignore_index=True)
            print(f"[进程{process_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
        else:
            df_mid = pd.DataFrame()
            print(f"[进程{process_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
        # Aggressively drop intermediates — worker memory matters here.
        del list_mid
        del df1
        del df2
        gc.collect()
        print(f"[进程{process_id}] 结束处理航班号: {flight_nums}")
        return df_mid
    except Exception as e:
        # Best-effort worker: log and return empty rather than kill the pool.
        print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
        return pd.DataFrame()
    finally:
        pass
  668. def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, is_train=True, plot_flag=False,
  669. use_multiprocess=False, max_workers=None):
  670. """加载训练数据(支持多进程)"""
  671. timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
  672. date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d") # 查询时的格式
  673. date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
  674. list_all = []
  675. # 每一航线对
  676. for flight_route in flight_route_list:
  677. from_city = flight_route.split('-')[0]
  678. to_city = flight_route.split('-')[1]
  679. route = f"{from_city}-{to_city}"
  680. print(f"开始处理航线: {route}")
  681. # 在主进程中查询航班号分组(避免多进程重复查询)
  682. main_client, main_db = mongo_con_parse(db_config, reuse_client=True)
  683. all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
  684. all_groups_len = len(all_groups)
  685. print(f"该航线共有{all_groups_len}个航班号")
  686. if all_groups_len == 0:
  687. continue
  688. # 查询远期表
  689. if is_hot == 1:
  690. df1 = query_flight_range_status(main_db, CLEAN_VJ_HOT_FAR_INFO_TAB, from_city, to_city,
  691. date_begin_s, date_end_s, None)
  692. else:
  693. df1 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_FAR_INFO_TAB, from_city, to_city,
  694. date_begin_s, date_end_s, None)
  695. # 保证远期表里有数据
  696. if df1.empty:
  697. print(f"[主进程] 航线:{route} 远期表无数据, 跳过")
  698. # main_client.close()
  699. return pd.DataFrame()
  700. # 查询近期表
  701. if is_hot == 1:
  702. df2 = query_flight_range_status(main_db, CLEAN_VJ_HOT_NEAR_INFO_TAB, from_city, to_city,
  703. date_begin_s, date_end_s, None)
  704. else:
  705. df2 = query_flight_range_status(main_db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, from_city, to_city,
  706. date_begin_s, date_end_s, None)
  707. # 保证近期表里有数据
  708. if df2.empty:
  709. print(f"[主进程] 航线:{route} 近期表无数据, 跳过")
  710. # main_client.close()
  711. return pd.DataFrame()
  712. # main_client.close()
  713. os.makedirs(output_dir, exist_ok=True)
  714. safe_route = route.replace("-", "_")
  715. df1_fd, df1_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_far_", suffix=".pkl", dir=output_dir)
  716. df2_fd, df2_cache_path = tempfile.mkstemp(prefix=f"route_{safe_route}_{timestamp_str}_near_", suffix=".pkl", dir=output_dir)
  717. os.close(df1_fd)
  718. os.close(df2_fd)
  719. df1.to_pickle(df1_cache_path)
  720. df2.to_pickle(df2_cache_path)
  721. try:
  722. if use_multiprocess and all_groups_len > 1:
  723. print(f"启用多进程处理,最大进程数: {max_workers}")
  724. process_args = []
  725. process_id = 0
  726. for each_group in all_groups:
  727. process_id += 1
  728. args = (process_id, each_group, is_train, plot_flag, output_dir)
  729. process_args.append(args)
  730. with ProcessPoolExecutor(
  731. max_workers=max_workers,
  732. initializer=_init_route_cache_worker,
  733. initargs=(df1_cache_path, df2_cache_path),
  734. ) as executor:
  735. future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
  736. for future in as_completed(future_to_group):
  737. each_group = future_to_group[future]
  738. flight_nums = each_group.get("flight_numbers", "未知")
  739. try:
  740. df_mid = future.result()
  741. if not df_mid.empty:
  742. list_all.append(df_mid)
  743. print(f"✅ 航班号:{flight_nums} 处理完成")
  744. else:
  745. print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
  746. except Exception as e:
  747. print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
  748. else:
  749. # 单进程处理(进程编号为0)
  750. print("使用单进程处理")
  751. global _ROUTE_CACHE_DF1, _ROUTE_CACHE_DF2
  752. _ROUTE_CACHE_DF1 = df1
  753. _ROUTE_CACHE_DF2 = df2
  754. process_id = 0
  755. for each_group in all_groups:
  756. args = (process_id, each_group, is_train, plot_flag, output_dir)
  757. flight_nums = each_group.get("flight_numbers", "未知")
  758. try:
  759. df_mid = process_flight_group(args)
  760. if not df_mid.empty:
  761. list_all.append(df_mid)
  762. print(f"✅ 航班号:{flight_nums} 处理完成")
  763. else:
  764. print(f"⚠️ 航班号:{flight_nums} 处理结果为空")
  765. except Exception as e:
  766. print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
  767. finally:
  768. try:
  769. os.remove(df1_cache_path)
  770. except Exception:
  771. pass
  772. try:
  773. os.remove(df2_cache_path)
  774. except Exception:
  775. pass
  776. print(f"结束处理航线: {from_city}-{to_city}")
  777. if list_all:
  778. df_all = pd.concat(list_all, ignore_index=True)
  779. else:
  780. df_all = pd.DataFrame()
  781. print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
  782. del list_all
  783. gc.collect()
  784. return df_all
  785. def query_all_flight_number(db, table_name):
  786. print(f"{table_name} 查找所有航班号")
  787. pipeline = [
  788. {
  789. "$project": {
  790. "flight_numbers": "$segments.flight_number"
  791. }
  792. },
  793. {
  794. "$group": {
  795. "_id": "$flight_numbers",
  796. "count": { "$sum": 1 }
  797. }
  798. },
  799. ]
  800. # 执行聚合查询
  801. collection = db[table_name]
  802. results = list(collection.aggregate(pipeline))
  803. list_flight_number = []
  804. for item in results:
  805. item_li = item.get("_id", [])
  806. list_flight_number.extend(item_li)
  807. list_flight_number = list(set(list_flight_number))
  808. return list_flight_number
  809. def validate_one_line(db, table_name, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour,
  810. limit=0, max_retries=3, base_sleep=1.0):
  811. """验证预测结果的一行"""
  812. # if city_pair in vj_flight_route_list_hot:
  813. # table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
  814. # elif city_pair in vj_flight_route_list_nothot:
  815. # table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  816. # else:
  817. # print(f"城市对{city_pair}不在热门航线与冷门航线, 返回")
  818. # return pd.DataFrame()
  819. city_pair_split = city_pair.split('-')
  820. from_city_code = city_pair_split[0]
  821. to_city_code = city_pair_split[1]
  822. flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
  823. baggage_str = f"1-{baggage}"
  824. if baggage == 0:
  825. baggage_str = "-;-;-;-"
  826. for attempt in range(1, max_retries + 1):
  827. try:
  828. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  829. # 构建查询条件
  830. query_condition = {
  831. "from_city_code": from_city_code,
  832. "to_city_code": to_city_code,
  833. "search_dep_time": flight_day_str,
  834. "segments.baggage": baggage_str,
  835. "crawl_date": {"$gte": valid_begin_hour},
  836. "segments.0.flight_number": flight_number_1,
  837. }
  838. # 如果有第二段
  839. if flight_number_2 != "VJ":
  840. query_condition["segments.1.flight_number"] = flight_number_2
  841. print(f" 查询条件: {query_condition}")
  842. # 定义要查询的字段
  843. projection = {
  844. # "_id": 1,
  845. "from_city_code": 1,
  846. "search_dep_time": 1,
  847. "to_city_code": 1,
  848. "currency": 1,
  849. "adult_price": 1,
  850. "adult_tax": 1,
  851. "adult_total_price": 1,
  852. "seats_remaining": 1,
  853. "segments": 1,
  854. "source_website": 1,
  855. "crawl_date": 1
  856. }
  857. # 执行查询
  858. cursor = db.get_collection(table_name).find(
  859. query_condition,
  860. projection=projection # 添加投影参数
  861. ).sort(
  862. [
  863. ("crawl_date", 1)
  864. ]
  865. )
  866. if limit > 0:
  867. cursor = cursor.limit(limit)
  868. # 将结果转换为列表
  869. results = list(cursor)
  870. print(f"✅ 查询成功,找到 {len(results)} 条记录")
  871. if results:
  872. df = pd.DataFrame(results)
  873. # 处理特殊的 ObjectId 类型
  874. if '_id' in df.columns:
  875. df = df.drop(columns=['_id'])
  876. print(f"📊 已转换为 DataFrame,形状: {df.shape}")
  877. # 1️⃣ 展开 segments
  878. print(f"📊 开始扩展segments 稍等...")
  879. t1 = time.time()
  880. df = expand_segments_columns_optimized(df)
  881. t2 = time.time()
  882. rt = round(t2 - t1, 3)
  883. print(f"用时: {rt} 秒")
  884. print(f"📊 已将segments扩展成字段,形状: {df.shape}")
  885. # 不用排序,因为mongo语句已经排好
  886. return df
  887. else:
  888. print("⚠️ 查询结果为空")
  889. return pd.DataFrame()
  890. except (ServerSelectionTimeoutError, PyMongoError) as e:
  891. print(f"⚠️ Mongo 查询失败: {e}")
  892. if attempt == max_retries:
  893. print("❌ 达到最大重试次数,放弃")
  894. return pd.DataFrame()
  895. # 指数退避 + 随机抖动
  896. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  897. print(f"⏳ {sleep_time:.2f}s 后重试...")
  898. time.sleep(sleep_time)
  899. def validate_keep_one_line(db, table_name, city_pair, flight_day, flight_number_1, flight_number_2, baggage, update_hour_str, del_batch_std_str,
  900. limit=0, max_retries=3, base_sleep=1.0):
  901. """验证keep_info的一行"""
  902. city_pair_split = city_pair.split('-')
  903. from_city_code = city_pair_split[0]
  904. to_city_code = city_pair_split[1]
  905. flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
  906. baggage_str = f"1-{baggage}"
  907. if baggage == 0:
  908. baggage_str = "-;-;-;-"
  909. for attempt in range(1, max_retries + 1):
  910. try:
  911. print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
  912. # 构建查询条件
  913. query_condition = {
  914. "from_city_code": from_city_code,
  915. "to_city_code": to_city_code,
  916. "search_dep_time": flight_day_str,
  917. "segments.baggage": baggage_str,
  918. "crawl_date": {"$gte": update_hour_str, "$lt": del_batch_std_str},
  919. "segments.0.flight_number": flight_number_1,
  920. }
  921. # 如果有第二段
  922. if flight_number_2 != "VJ":
  923. query_condition["segments.1.flight_number"] = flight_number_2
  924. print(f" 查询条件: {query_condition}")
  925. # 定义要查询的字段
  926. projection = {
  927. # "_id": 1,
  928. "from_city_code": 1,
  929. "search_dep_time": 1,
  930. "to_city_code": 1,
  931. "currency": 1,
  932. "adult_price": 1,
  933. "adult_tax": 1,
  934. "adult_total_price": 1,
  935. "seats_remaining": 1,
  936. "segments": 1,
  937. "source_website": 1,
  938. "crawl_date": 1
  939. }
  940. # 执行查询
  941. cursor = db.get_collection(table_name).find(
  942. query_condition,
  943. projection=projection # 添加投影参数
  944. ).sort(
  945. [
  946. ("crawl_date", 1)
  947. ]
  948. )
  949. if limit > 0:
  950. cursor = cursor.limit(limit)
  951. # 将结果转换为列表
  952. results = list(cursor)
  953. print(f"✅ 查询成功,找到 {len(results)} 条记录")
  954. if results:
  955. df = pd.DataFrame(results)
  956. # 处理特殊的 ObjectId 类型
  957. if '_id' in df.columns:
  958. df = df.drop(columns=['_id'])
  959. print(f"📊 已转换为 DataFrame,形状: {df.shape}")
  960. # 1️⃣ 展开 segments
  961. print(f"📊 开始扩展segments 稍等...")
  962. t1 = time.time()
  963. df = expand_segments_columns_optimized(df)
  964. t2 = time.time()
  965. rt = round(t2 - t1, 3)
  966. print(f"用时: {rt} 秒")
  967. print(f"📊 已将segments扩展成字段,形状: {df.shape}")
  968. # 不用排序,因为mongo语句已经排好
  969. return df
  970. else:
  971. print("⚠️ 查询结果为空")
  972. return pd.DataFrame()
  973. except (ServerSelectionTimeoutError, PyMongoError) as e:
  974. print(f"⚠️ Mongo 查询失败: {e}")
  975. if attempt == max_retries:
  976. print("❌ 达到最大重试次数,放弃")
  977. return pd.DataFrame()
  978. # 指数退避 + 随机抖动
  979. sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
  980. print(f"⏳ {sleep_time:.2f}s 后重试...")
  981. time.sleep(sleep_time)
  982. if __name__ == "__main__":
  983. # test_mongo_connection(db)
  984. from utils import chunk_list_with_index
  985. cpu_cores = os.cpu_count() # 你的系统是72
  986. max_workers = min(8, cpu_cores) # 最大不超过8个进程
  987. output_dir = f"./output"
  988. os.makedirs(output_dir, exist_ok=True)
  989. # 加载热门航线数据
  990. date_begin = "2026-01-01"
  991. date_end = datetime.today().strftime("%Y-%m-%d")
  992. flight_route_list = vj_flight_route_list_hot[:] # 热门 vj_flight_route_list_hot 冷门 vj_flight_route_list_nothot
  993. # flight_route_list = ["SGN-NGO"] # 测试段
  994. table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB 冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
  995. is_hot = 1 # 1 热门 0 冷门
  996. group_size = 1
  997. chunks = chunk_list_with_index(flight_route_list, group_size)
  998. for idx, (_, group_route_list) in enumerate(chunks, 1):
  999. # 使用默认配置
  1000. # client, db = mongo_con_parse()
  1001. print(f"第 {idx} 组 :", group_route_list)
  1002. start_time = time.time()
  1003. load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=True,
  1004. use_multiprocess=True, max_workers=max_workers)
  1005. end_time = time.time()
  1006. run_time = round(end_time - start_time, 3)
  1007. print(f"用时: {run_time} 秒")
  1008. # client.close()
  1009. time.sleep(3)
  1010. print("整体结束")
  1011. # client, db = mongo_con_parse()
  1012. # list_flight_number_1 = query_all_flight_number(db, CLEAN_VJ_HOT_NEAR_INFO_TAB)
  1013. # list_flight_number_2 = query_all_flight_number(db, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB)
  1014. # list_flight_number_all = list_flight_number_1 + list_flight_number_2
  1015. # list_flight_number_all = list(set(list_flight_number_all))
  1016. # list_flight_number_all.sort()
  1017. # print(list_flight_number_all)
  1018. # print(len(list_flight_number_all))
  1019. # flight_map = {v: i for i, v in enumerate(list_flight_number_all, start=1)}
  1020. # print(flight_map)