data_preprocess.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. import pandas as pd
  2. import numpy as np
  3. import bisect
  4. import gc
  5. import os
  6. from datetime import datetime, timedelta
  7. from sklearn.preprocessing import StandardScaler
  8. from config import city_to_country, vj_city_code_map, vi_flight_number_map, build_country_holidays
  9. from utils import insert_df_col
# Module-level holiday lookup built once at import time: maps a country code to
# a collection of holiday dates (consumed below via membership tests such as
# `date in COUNTRY_HOLIDAYS.get(country, set())`).
COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
  11. def preprocess_data_cycle(df_input, interval_hours=8, feature_length=240, target_length=24, is_training=True):
  12. # df_input_part = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['hours_until_departure'] < current_n_hours)].copy()
  13. df_input = preprocess_data_first_half(df_input)
  14. # 创建一个空列表来存储所有处理后的数据部分
  15. list_df_parts = []
  16. crop_lower_limit_list = [4] # [4, 28, 52, 76, 100]
  17. for crop_lower_limit in crop_lower_limit_list:
  18. target_n_hours = crop_lower_limit + target_length
  19. feature_n_hours = target_n_hours + interval_hours
  20. crop_upper_limit = feature_n_hours + feature_length
  21. df_input_part = preprocess_data(df_input, is_training=is_training, crop_upper_limit=crop_upper_limit, feature_n_hours=feature_n_hours,
  22. target_n_hours=target_n_hours, crop_lower_limit=crop_lower_limit)
  23. # 将处理后的部分添加到列表中
  24. list_df_parts.append(df_input_part)
  25. if not is_training:
  26. break
  27. # 合并所有处理后的数据部分
  28. if list_df_parts:
  29. df_combined = pd.concat(list_df_parts, ignore_index=True)
  30. return df_combined
  31. else:
  32. return pd.DataFrame() # 如果没有数据,返回空DataFrame
  33. def preprocess_data_first_half(df_input):
  34. '''前半部分'''
  35. print(">>> 开始数据预处理")
  36. # 生成 城市对
  37. df_input['city_pair'] = (
  38. df_input['from_city_code'].astype(str) + "-" + df_input['to_city_code'].astype(str)
  39. )
  40. # 城市码映射成数字
  41. df_input['from_city_num'] = df_input['from_city_code'].map(vj_city_code_map)
  42. df_input['to_city_num'] = df_input['to_city_code'].map(vj_city_code_map)
  43. missing_from = (
  44. df_input.loc[df_input['from_city_num'].isna(), 'from_city_code']
  45. .unique()
  46. )
  47. missing_to = (
  48. df_input.loc[df_input['to_city_num'].isna(), 'to_city_code']
  49. .unique()
  50. )
  51. if missing_from:
  52. print("未映射的 from_city:", missing_from)
  53. if missing_to:
  54. print("未映射的 to_city:", missing_to)
  55. # 把 city_pair、from_city_code、from_city_num, to_city_code, to_city_num 放到前几列
  56. cols = df_input.columns.tolist()
  57. # 删除已存在的几列(保证顺序正确)
  58. for c in ['city_pair', 'from_city_code', 'from_city_num', 'to_city_code', 'to_city_num']:
  59. cols.remove(c)
  60. # 这几列插入到最前面
  61. df_input = df_input[['city_pair', 'from_city_code', 'from_city_num', 'to_city_code', 'to_city_num'] + cols]
  62. pass
  63. # 转格式
  64. df_input['search_dep_time'] = pd.to_datetime(
  65. df_input['search_dep_time'],
  66. format='%Y%m%d',
  67. errors='coerce'
  68. ).dt.strftime('%Y-%m-%d')
  69. # 重命名起飞日期
  70. df_input.rename(columns={'search_dep_time': 'flight_day'}, inplace=True)
  71. # 重命名航班号
  72. df_input.rename(
  73. columns={
  74. 'seg1_flight_number': 'flight_number_1',
  75. 'seg2_flight_number': 'flight_number_2'
  76. },
  77. inplace=True
  78. )
  79. # 分开填充
  80. df_input['flight_number_1'] = df_input['flight_number_1'].fillna('VJ')
  81. df_input['flight_number_2'] = df_input['flight_number_2'].fillna('VJ')
  82. # 航班号转数字
  83. df_input['flight_1_num'] = df_input['flight_number_1'].map(vi_flight_number_map)
  84. df_input['flight_2_num'] = df_input['flight_number_2'].map(vi_flight_number_map)
  85. missing_flight_1 = (
  86. df_input.loc[df_input['flight_1_num'].isna(), 'flight_number_1']
  87. .unique()
  88. )
  89. missing_flight_2 = (
  90. df_input.loc[df_input['flight_2_num'].isna(), 'flight_number_2']
  91. .unique()
  92. )
  93. if missing_flight_1:
  94. print("未映射的 flight_1:", missing_flight_1)
  95. if missing_flight_2:
  96. print("未映射的 flight_2:", missing_flight_2)
  97. # flight_1_num 放在 seg1_dep_air_port 之前
  98. insert_df_col(df_input, 'flight_1_num', 'seg1_dep_air_port')
  99. # flight_2_num 放在 seg2_dep_air_port 之前
  100. insert_df_col(df_input, 'flight_2_num', 'seg2_dep_air_port')
  101. df_input['baggage_level'] = (df_input['baggage'] == 30).astype(int) # 30--> 1 20--> 0
  102. # baggage_level 放在 flight_number_2 之前
  103. insert_df_col(df_input, 'baggage_level', 'flight_number_2')
  104. df_input['Adult_Total_Price'] = df_input['adult_total_price']
  105. # Adult_Total_Price 放在 seats_remaining 之前 保存缩放前的原始值
  106. insert_df_col(df_input, 'Adult_Total_Price', 'seats_remaining')
  107. df_input['Hours_Until_Departure'] = df_input['hours_until_departure']
  108. # Hours_Until_Departure 放在 days_to_departure 之前 保存缩放前的原始值
  109. insert_df_col(df_input, 'Hours_Until_Departure', 'days_to_departure')
  110. pass
  111. # gid:基于指定字段的分组标记(整数)
  112. df_input['gid'] = (
  113. df_input
  114. .groupby(
  115. ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2'], # 'baggage' 先不进分组
  116. sort=False
  117. )
  118. .ngroup()
  119. )
  120. return df_input
def preprocess_data(df_input, is_training=True, crop_upper_limit=480, feature_n_hours=36, target_n_hours=28, crop_lower_limit=4):
    """Second half of preprocessing: window cropping, price-history features,
    price-zone features, flight-time/holiday features and (in training mode)
    the target label columns.

    Parameters
    ----------
    df_input : pd.DataFrame
        Output of preprocess_data_first_half (must contain gid, baggage,
        hours_until_departure, adult_total_price, seg* airport/time columns,
        update_hour, is_filled, ...).
    is_training : bool
        True: derive target_* columns from the target window.
        False: append target_* columns filled with zeros.
    crop_upper_limit, crop_lower_limit : int
        Keep rows with crop_lower_limit <= hours_until_departure < crop_upper_limit.
    feature_n_hours : int
        Boundary (hours before departure) used to pick the reference price
        ("price_at_n_hours") from the feature side of the window.
    target_n_hours : int
        Upper bound of the target window scanned for the first price drop.

    Returns
    -------
    pd.DataFrame
        Feature frame with columns ordered per `order_columns`; an empty
        DataFrame when training mode finds no rows in the target window.
    """
    print(f"裁剪范围: [{crop_lower_limit}, {crop_upper_limit}], 间隔窗口: [{target_n_hours}, {feature_n_hours}]")
    # Crop the time window: keep rows within [crop_lower_limit, crop_upper_limit)
    # hours before departure (defaults: within 480h and at least 4h out).
    df_input = df_input[(df_input['hours_until_departure'] < crop_upper_limit) &
                        (df_input['hours_until_departure'] >= crop_lower_limit)].reset_index(drop=True)
    # Within each (gid, baggage): descending hours_until_departure,
    # i.e. chronological order (earliest observation first).
    df_input = df_input.sort_values(
        by=['gid', 'baggage', 'hours_until_departure'],
        ascending=[True, True, False]
    ).reset_index(drop=True)
    # Minimum absolute price move that counts as a real change.
    VALID_DROP_MIN = 5
    # Price-change masks, computed per (gid, baggage) group.
    g = df_input.groupby(['gid', 'baggage'])
    diff = g['adult_total_price'].transform('diff')
    # change_mask = diff.abs() >= VALID_DROP_MIN  # moves that are too small are ignored
    decrease_mask = diff <= -VALID_DROP_MIN  # price drop (too-small moves ignored)
    increase_mask = diff >= VALID_DROP_MIN   # price rise (too-small moves ignored)
    # Direction of the price event at each row: +1 rise, -1 drop, 0 no event.
    df_input['_price_event_dir'] = np.where(increase_mask, 1, np.where(decrease_mask, -1, 0))

    # Count consecutive rise/drop streaks within one (gid, baggage) group.
    # Rows without an event keep the previous streak values via ffill; leading
    # NaNs (before any event) become 0.
    def _calc_price_streaks(df_group):
        dirs = df_group['_price_event_dir'].to_numpy()
        n = len(dirs)
        inc = np.full(n, np.nan)
        dec = np.full(n, np.nan)
        last_dir = 0
        inc_cnt = 0
        dec_cnt = 0
        for i, d in enumerate(dirs):
            if d == 1:
                # Extend the rise streak only if the previous event was a rise.
                inc_cnt = inc_cnt + 1 if last_dir == 1 else 1
                dec_cnt = 0
                last_dir = 1
                inc[i] = inc_cnt
                dec[i] = dec_cnt
            elif d == -1:
                dec_cnt = dec_cnt + 1 if last_dir == -1 else 1
                inc_cnt = 0
                last_dir = -1
                inc[i] = inc_cnt
                dec[i] = dec_cnt
        inc_s = pd.Series(inc, index=df_group.index).ffill().fillna(0).astype(int)
        dec_s = pd.Series(dec, index=df_group.index).ffill().fillna(0).astype(int)
        return pd.DataFrame(
            {
                'price_increase_times_consecutive': inc_s,
                'price_decrease_times_consecutive': dec_s,
            },
            index=df_group.index,
        )

    streak_df = df_input.groupby(['gid', 'baggage'], sort=False, group_keys=False).apply(_calc_price_streaks)
    df_input = df_input.join(streak_df)
    df_input.drop(columns=['_price_event_dir'], inplace=True)
    # Total number of price changes (disabled):
    # df_input['price_change_times_total'] = (
    #     change_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
    # )
    # Cumulative number of price drops seen so far in each (gid, baggage).
    df_input['price_decrease_times_total'] = (
        decrease_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
    )
    # Cumulative number of price rises.
    df_input['price_increase_times_total'] = (
        increase_mask.groupby([df_input['gid'], df_input['baggage']]).cumsum()
    )
    # hours_until_departure at the last price change (disabled):
    # last_change_hour = (
    #     df_input['hours_until_departure']
    #     .where(change_mask)
    #     .groupby([df_input['gid'], df_input['baggage']])
    #     .ffill()  # forward fill
    # )
    # hours_until_departure at the most recent price drop (forward-filled).
    last_decrease_hour = (
        df_input['hours_until_departure']
        .where(decrease_mask)
        .groupby([df_input['gid'], df_input['baggage']])
        .ffill()  # forward fill
    )
    # hours_until_departure at the most recent price rise (forward-filled).
    last_increase_hour = (
        df_input['hours_until_departure']
        .where(increase_mask)
        .groupby([df_input['gid'], df_input['baggage']])
        .ffill()  # forward fill
    )
    # Hours elapsed since the last price change (disabled):
    # df_input['price_last_change_hours'] = (
    #     last_change_hour - df_input['hours_until_departure']
    # ).fillna(0)
    # Hours elapsed since the last price drop (0 if none yet).
    df_input['price_last_decrease_hours'] = (
        last_decrease_hour - df_input['hours_until_departure']
    ).fillna(0)
    # Hours elapsed since the last price rise (0 if none yet).
    df_input['price_last_increase_hours'] = (
        last_increase_hour - df_input['hours_until_departure']
    ).fillna(0)
    pass
    # New columns to insert just before seats_remaining.
    new_cols = [
        # 'price_change_times_total',
        # 'price_last_change_hours',
        'price_decrease_times_total',
        'price_decrease_times_consecutive',
        'price_last_decrease_hours',
        'price_increase_times_total',
        'price_increase_times_consecutive',
        'price_last_increase_hours',
    ]
    # Current column order.
    cols = df_input.columns.tolist()
    # Locate seats_remaining.
    idx = cols.index('seats_remaining')
    # Rebuild the order with the new columns spliced in.
    new_order = cols[:idx] + new_cols + cols[idx:]
    # De-duplicate (in case a column already sat at its original position).
    new_order = list(dict.fromkeys(new_order))
    # Re-order the DataFrame.
    df_input = df_input[new_order]
    pass
    print(">>> 计算价格区间特征")
    # 1. Per-(gid, baggage) price statistics for the zone features below.
    # g = df_input.groupby(['gid', 'baggage'])
    price_stats = df_input.groupby(['gid', 'baggage'])['adult_total_price'].agg(
        min_price='min',
        max_price='max',
        mean_price='mean',
        std_price='std'
    ).reset_index()
    # Merge the statistics back onto every row.
    df_input = df_input.merge(price_stats, on=['gid', 'baggage'], how='left')
    # 2. Absolute-price zone split (disabled — superseded by the
    #    frequency-weighted classification below):
    # # high: above mean + 1 std
    # df_input['price_absolute_high'] = (df_input['adult_total_price'] >
    #     (df_input['mean_price'] + df_input['std_price'])).astype(int)
    # # mid-high: between mean and mean + 1 std
    # df_input['price_absolute_mid_high'] = ((df_input['adult_total_price'] > df_input['mean_price']) &
    #     (df_input['adult_total_price'] <= (df_input['mean_price'] + df_input['std_price']))).astype(int)
    # # mid-low: between mean - 1 std and mean
    # df_input['price_absolute_mid_low'] = ((df_input['adult_total_price'] > (df_input['mean_price'] - df_input['std_price'])) &
    #     (df_input['adult_total_price'] <= df_input['mean_price'])).astype(int)
    # # low: below mean - 1 std
    # df_input['price_absolute_low'] = (df_input['adult_total_price'] <= (df_input['mean_price'] - df_input['std_price'])).astype(int)
    # 3. Frequency-weighted price percentiles.
    # How often each distinct price occurs within a (gid, baggage) group.
    price_freq = df_input.groupby(['gid', 'baggage', 'adult_total_price']).size().reset_index(name='price_frequency')
    df_input = df_input.merge(price_freq, on=['gid', 'baggage', 'adult_total_price'], how='left')

    # Frequency-weighted percentile of adult_total_price within one group:
    # sorts by price, accumulates frequencies, and picks the first price whose
    # cumulative frequency crosses each percentile threshold.
    def weighted_percentile(group):
        if len(group) == 0:
            return pd.Series([np.nan] * 4, index=['price_weighted_percentile_25',
                                                  'price_weighted_percentile_50',
                                                  'price_weighted_percentile_75',
                                                  'price_weighted_percentile_90'])
        # Sort by price and accumulate frequency.
        group = group.sort_values('adult_total_price')
        group['cum_freq'] = group['price_frequency'].cumsum()
        total_freq = group['price_frequency'].sum()
        # Compute the weighted percentiles.
        percentiles = []
        for p in [0.25, 0.5, 0.75, 0.9]:
            threshold = total_freq * p
            # First price whose cumulative frequency reaches the threshold.
            mask = group['cum_freq'] >= threshold
            if mask.any():
                percentile_value = group.loc[mask.idxmax(), 'adult_total_price']
            else:
                percentile_value = group['adult_total_price'].max()
            percentiles.append(percentile_value)
        return pd.Series(percentiles, index=['price_weighted_percentile_25',
                                             'price_weighted_percentile_50',
                                             'price_weighted_percentile_75',
                                             'price_weighted_percentile_90'])

    # Weighted percentiles per (gid, baggage), merged back onto every row.
    weighted_percentiles = df_input.groupby(['gid', 'baggage']).apply(weighted_percentile).reset_index()
    df_input = df_input.merge(weighted_percentiles, on=['gid', 'baggage'], how='left')
    # 4. Combined absolute-price + frequency zone classification.
    freq_median = df_input.groupby(['gid', 'baggage'])['price_frequency'].transform('median')
    # Price relative to the 90th weighted percentile — separates levels of "high".
    df_input['price_relative_to_90p'] = df_input['adult_total_price'] / df_input['price_weighted_percentile_90']
    # Tolerance bands (1% of each percentile) so near-identical prices do not
    # land in different zones.
    # tolerance_90p = df_input['price_weighted_percentile_90'] * 0.01
    tolerance_75p = df_input['price_weighted_percentile_75'] * 0.01
    tolerance_50p = df_input['price_weighted_percentile_50'] * 0.01
    tolerance_25p = df_input['price_weighted_percentile_25'] * 0.01
    # Zone masks, defined so that they do not overlap (each later mask
    # excludes all earlier ones):
    # 4.1 Abnormal high: far above the 90th percentile (>1.5x) AND very rare
    #     (frequency below a third of the group median).
    price_abnormal_high_mask = ((df_input['price_relative_to_90p'] > 1.5) &
                                (df_input['price_frequency'] < freq_median * 0.33))
    # 4.2 Real high: strictly above the 90th percentile with below-median frequency.
    price_real_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_90']) &
                            (df_input['price_frequency'] < freq_median) &
                            ~price_abnormal_high_mask)
    # 4.3 Normal high: near/above the 75th percentile (with tolerance).
    price_normal_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_75'] - tolerance_75p) &
                              ~price_real_high_mask & ~price_abnormal_high_mask)
    # 4.4 Mid-high: between the 50th and 75th percentiles (with tolerance).
    price_mid_high_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_50'] - tolerance_50p) &
                           (df_input['adult_total_price'] <= df_input['price_weighted_percentile_75'] + tolerance_75p) &
                           ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
    # 4.5 Mid-low: between the 25th and 50th percentiles (with tolerance).
    price_mid_low_mask = ((df_input['adult_total_price'] > df_input['price_weighted_percentile_25'] - tolerance_25p) &
                          (df_input['adult_total_price'] <= df_input['price_weighted_percentile_50'] + tolerance_50p) &
                          ~price_mid_high_mask & ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
    # 4.6 Low: at or below the 25th percentile.
    price_low_mask = ((df_input['adult_total_price'] <= df_input['price_weighted_percentile_25']) &
                      ~price_mid_low_mask & ~price_mid_high_mask & ~price_normal_high_mask & ~price_real_high_mask & ~price_abnormal_high_mask)
    # np.select guarantees mutual exclusivity (first matching mask wins).
    price_zone_masks = [
        price_abnormal_high_mask,  # abnormal high (level 5)
        price_real_high_mask,      # real high (level 4)
        price_normal_high_mask,    # normal high (level 3)
        price_mid_high_mask,       # mid-high (level 2)
        price_mid_low_mask,        # mid-low (level 1)
        price_low_mask,            # low (level 0)
    ]
    price_zone_values = [5, 4, 3, 2, 1, 0]  # 5: abnormal high ... 0: low
    # Rows matching no mask fall back to mid-high (2).
    price_zone_result = np.select(price_zone_masks, price_zone_values, default=2)
    # 4.8 Final zone label.
    df_input['price_zone_comprehensive'] = price_zone_result
    # 5. Price-anomaly features.
    # Standardised deviation from the group mean.
    df_input['price_z_score'] = (df_input['adult_total_price'] - df_input['mean_price']) / df_input['std_price']
    # Anomaly score: absolute z-score.
    df_input['price_anomaly_score'] = np.abs(df_input['price_z_score'])
    # 6. Price stability: coefficient of variation (std / mean).
    df_input['price_coefficient_variation'] = df_input['std_price'] / df_input['mean_price']
    # 7. Price trend: current price's position within the group's [min, max] range.
    df_input['price_relative_position'] = (df_input['adult_total_price'] - df_input['min_price']) / (df_input['max_price'] - df_input['min_price'])
    df_input['price_relative_position'] = df_input['price_relative_position'].fillna(0.5)  # fallback when max == min
    # Drop intermediate columns and free the temporaries.
    df_input.drop(columns=['price_frequency', 'price_z_score', 'price_relative_to_90p'], inplace=True, errors='ignore')
    del price_freq
    del price_stats
    del weighted_percentiles
    del freq_median
    print(">>> 改进版价格区间特征计算完成")
    # First-segment airport pair, e.g. "SGN-HAN".
    df_input['airport_pair_1'] = (
        df_input['seg1_dep_air_port'].astype(str) + "-" + df_input['seg1_arr_air_port'].astype(str)
    )
    # Drop the raw first-segment airport codes.
    df_input.drop(columns=['seg1_dep_air_port', 'seg1_arr_air_port'], inplace=True)
    # airport_pair_1 goes before seg1_dep_time.
    insert_df_col(df_input, 'airport_pair_1', 'seg1_dep_time')
    # Second-segment airport pair, with 'NA' fallback for direct flights.
    df_input['airport_pair_2'] = np.where(
        df_input['seg2_dep_air_port'].isna() | df_input['seg2_arr_air_port'].isna(),
        'NA',
        df_input['seg2_dep_air_port'].astype(str) + "-" +
        df_input['seg2_arr_air_port'].astype(str)
    )
    # Drop the raw second-segment airport codes.
    df_input.drop(columns=['seg2_dep_air_port', 'seg2_arr_air_port'], inplace=True)
    # airport_pair_2 goes before seg2_dep_time.
    insert_df_col(df_input, 'airport_pair_2', 'seg2_dep_time')
    # Transfer flag: bare 'VJ' means no second segment (see first-half fillna).
    df_input['is_transfer'] = np.where(df_input['flight_number_2'] == 'VJ', 0, 1)
    # is_transfer goes before flight_number_2.
    insert_df_col(df_input, 'is_transfer', 'flight_number_2')
    # Rename departure / arrival timestamp columns.
    df_input.rename(
        columns={
            'seg1_dep_time': 'dep_time_1',
            'seg1_arr_time': 'arr_time_1',
            'seg2_dep_time': 'dep_time_2',
            'seg2_arr_time': 'arr_time_2',
        },
        inplace=True
    )
    # First-segment flight duration in hours.
    df_input['fly_duration_1'] = (
        (df_input['arr_time_1'] - df_input['dep_time_1'])
        .dt.total_seconds() / 3600
    ).round(2)
    # Second-segment flight duration (0 when there is no transfer).
    df_input['fly_duration_2'] = (
        (df_input['arr_time_2'] - df_input['dep_time_2'])
        .dt.total_seconds() / 3600
    ).fillna(0).round(2)
    # Total flight duration.
    df_input['fly_duration'] = (
        df_input['fly_duration_1'] + df_input['fly_duration_2']
    ).round(2)
    # Layover duration (0 when there is no transfer).
    df_input['stop_duration'] = (
        (df_input['dep_time_2'] - df_input['arr_time_1'])
        .dt.total_seconds() / 3600
    ).fillna(0).round(2)
    # Clip to avoid negatives (disabled):
    # for c in ['fly_duration_1', 'fly_duration_2', 'fly_duration', 'stop_duration']:
    #     df_input[c] = df_input[c].clip(lower=0)
    # Keep consistent with is_transfer (disabled):
    # df_input.loc[df_input['is_transfer'] == 0, ['fly_duration_2', 'stop_duration']] = 0
    # Insert all duration columns before is_filled in one pass.
    insert_before = 'is_filled'
    new_cols = [
        'fly_duration_1',
        'fly_duration_2',
        'fly_duration',
        'stop_duration'
    ]
    cols = df_input.columns.tolist()
    idx = cols.index(insert_before)
    # Remove from their old positions.
    cols = [c for c in cols if c not in new_cols]
    # Insert at the new position (order preserved) via empty-slice assignment.
    cols[idx:idx] = new_cols
    df_input = df_input[cols]
    # Derive several calendar fields from the first departure time.
    dep_t1 = df_input['dep_time_1']
    # Departure hour (0-23).
    df_input['flight_by_hour'] = dep_t1.dt.hour
    # Day of month (1-31).
    df_input['flight_by_day'] = dep_t1.dt.day
    # Month (1-12).
    df_input['flight_day_of_month'] = dep_t1.dt.month
    # Weekday (0 = Monday, 6 = Sunday).
    df_input['flight_day_of_week'] = dep_t1.dt.weekday
    # Quarter (1-4).
    df_input['flight_day_of_quarter'] = dep_t1.dt.quarter
    # Weekend flag (Saturday / Sunday).
    df_input['flight_day_is_weekend'] = dep_t1.dt.weekday.isin([5, 6]).astype(int)
    # Country codes for both endpoints.
    df_input['dep_country'] = df_input['from_city_code'].map(city_to_country)
    df_input['arr_country'] = df_input['to_city_code'].map(city_to_country)
    # Overall departure time is dep_time_1.
    df_input['global_dep_time'] = df_input['dep_time_1']
    # Overall arrival time: arr_time_2 when there is a transfer, else arr_time_1.
    df_input['global_arr_time'] = df_input['arr_time_2'].fillna(df_input['arr_time_1'])
    # Is the departure date a holiday in the departure country?
    df_input['dep_country_is_holiday'] = df_input.apply(
        lambda r: r['global_dep_time'].date()
        in COUNTRY_HOLIDAYS.get(r['dep_country'], set()),
        axis=1
    ).astype(int)
    # Is the arrival date a holiday in the arrival country?
    df_input['arr_country_is_holiday'] = df_input.apply(
        lambda r: r['global_arr_time'].date()
        in COUNTRY_HOLIDAYS.get(r['arr_country'], set()),
        axis=1
    ).astype(int)
    # Holiday on either side.
    df_input['any_country_is_holiday'] = (
        df_input[['dep_country_is_holiday', 'arr_country_is_holiday']]
        .max(axis=1)
    )
    # International route flag.
    df_input['is_cross_country'] = (
        df_input['dep_country'] != df_input['arr_country']
    ).astype(int)

    # Days from cur_date to the next holiday (inclusive of today) in `country`;
    # NaN when the country or date is missing or no future holiday exists.
    def days_to_next_holiday(country, cur_date):
        if pd.isna(country) or pd.isna(cur_date):
            return np.nan
        holidays = COUNTRY_HOLIDAYS.get(country)
        if not holidays:
            return np.nan
        # Future holidays (including today), sorted ascending.
        future_holidays = sorted([d for d in holidays if d >= cur_date])
        if not future_holidays:
            return np.nan
        next_holiday = future_holidays[0]  # nearest future holiday
        delta_days = (next_holiday - cur_date).days
        return delta_days

    # NOTE(review): uses the crawl timestamp (update_hour) as "today" —
    # presumably a datetime-like column; confirm upstream dtype.
    df_input['days_to_holiday'] = df_input.apply(
        lambda r: days_to_next_holiday(
            r['dep_country'],
            r['update_hour'].date()
        ),
        axis=1
    )
    # Uniform fallback when there is no future holiday (disabled):
    # df_input['days_to_holiday'] = df_input['days_to_holiday'].fillna(999)
    # days_to_holiday goes before update_hour.
    insert_df_col(df_input, 'days_to_holiday', 'update_hour')
    # Training mode: derive the target label columns.
    if is_training:
        print(">>> 训练模式:计算 target 相关列")
        print(f"\n>>> 开始处理 对应区间: n_hours = {target_n_hours}")
        target_lower_limit = crop_lower_limit
        target_upper_limit = target_n_hours
        # Target window: [lower, upper) hours before departure, 30kg baggage only.
        mask_targets = (df_input['hours_until_departure'] >= target_lower_limit) & (df_input['hours_until_departure'] < target_upper_limit) & (df_input['baggage'] == 30)
        df_targets = df_input.loc[mask_targets].copy()
        targets_amout = df_targets.shape[0]
        print(f"当前 目标区间数据量: {targets_amout}, 区间: [{target_lower_limit}, {target_upper_limit})")
        if targets_amout == 0:
            print(f">>> n_hours = {target_n_hours} 无有效数据,跳过")
            return pd.DataFrame()
        print(">>> 计算 price_at_n_hours")
        # Reference price: last observation per gid on the feature side of the
        # window (>= feature_n_hours out, 30kg baggage) — typically lands
        # around 36/32/30 hours before departure.
        df_input_object = df_input[(df_input['hours_until_departure'] >= feature_n_hours) & (df_input['baggage'] == 30)].copy()
        df_last = df_input_object.groupby('gid', observed=True).last().reset_index()
        # Extract and rename the price column.
        df_last_price_at_n_hours = df_last[['gid', 'adult_total_price']].rename(columns={'adult_total_price': 'price_at_n_hours'})
        print(">>> price_at_n_hours计算完成,示例:")
        print(df_last_price_at_n_hours.head(5))
        # New drop-detection approach: sort chronologically first.
        df_targets = df_targets.sort_values(
            ['gid', 'hours_until_departure'],
            ascending=[True, False]
        )
        # Price change within each gid.
        g = df_targets.groupby('gid', group_keys=False)
        df_targets['price_diff'] = g['adult_total_price'].diff()
        # VALID_DROP_MIN = 5
        # LOWER_HOUR = 4
        # UPPER_HOUR = 28
        valid_drop_mask = (
            (df_targets['price_diff'] <= -VALID_DROP_MIN)
            # (df_targets['hours_until_departure'] >= LOWER_HOUR) &
            # (df_targets['hours_until_departure'] <= UPPER_HOUR)
        )
        # Rows with a qualifying drop.
        df_valid_drops = df_targets.loc[valid_drop_mask]
        # First qualifying drop per gid.
        df_first_price_drop = (
            df_valid_drops
            .groupby('gid', as_index=False)
            .first()
        )
        # Keep/rename only the relevant columns; time_to_price_drop is the
        # hours_until_departure at which the first drop occurred.
        df_first_price_drop = df_first_price_drop[
            ['gid', 'hours_until_departure', 'adult_total_price', 'price_diff']
        ].rename(columns={
            'hours_until_departure': 'time_to_price_drop',
            'adult_total_price': 'price_at_d_hours',
            'price_diff': 'amount_of_price_drop',
        })
        # Flip the drop amount to a positive magnitude (more intuitive).
        df_first_price_drop['amount_of_price_drop'] = (-df_first_price_drop['amount_of_price_drop']).round(2)
        pass
        # (A superseded merge-based drop computation — ~35 commented-out lines
        # building df_gid_info from price_at_n_hours comparisons — was removed
        # here; see VCS history if it is ever needed again.)
        # Left-join the first-drop info; gids without a drop get NaNs, turned
        # into will_price_drop = 0 / amount = 0 / time = 0 below.
        df_gid_info = df_last_price_at_n_hours.merge(df_first_price_drop, on='gid', how='left')
        df_gid_info['will_price_drop'] = df_gid_info['time_to_price_drop'].notnull().astype(int)
        df_gid_info['amount_of_price_drop'] = df_gid_info['amount_of_price_drop'].fillna(0)
        df_gid_info['time_to_price_drop'] = df_gid_info['time_to_price_drop'].fillna(0)
        pass
        # Free intermediates.
        del df_input_object
        del df_last
        del df_last_price_at_n_hours
        del df_first_price_drop
        del df_valid_drops
        del df_targets
        gc.collect()
        # Merge the target variables onto the feature frame.
        print(">>> 将目标变量信息合并到 df_input")
        df_input = df_input.merge(df_gid_info[['gid', 'will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']], on='gid', how='left')
        # Fill NaNs (gids absent from df_gid_info) with 0.
        df_input[['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']] = df_input[
            ['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']].fillna(0)
        df_input = df_input.rename(columns={
            'will_price_drop': 'target_will_price_drop',
            'amount_of_price_drop': 'target_amount_of_drop',
            'time_to_price_drop': 'target_time_to_drop'
        })
        # (A disabled per-gid min-price merge — ~8 commented-out lines — was
        # removed here; see VCS history.)
        print(">>> 合并后 df_input 样例:")
        print(df_input[['gid', 'hours_until_departure', 'adult_total_price', 'target_will_price_drop', 'target_amount_of_drop', 'target_time_to_drop']].head(5))
    # Prediction mode: append zero-filled target columns.
    else:
        print(">>> 预测模式:补齐 target 相关列(全部置 0)")
        df_input['target_will_price_drop'] = 0
        df_input['target_amount_of_drop'] = 0.0
        df_input['target_time_to_drop'] = 0
    # Final, fixed column order of the returned frame.
    order_columns = [
        "city_pair", "from_city_code", "from_city_num", "to_city_code", "to_city_num", "flight_day",
        "seats_remaining", "baggage", "baggage_level",
        "price_decrease_times_total", "price_decrease_times_consecutive", "price_last_decrease_hours",
        "price_increase_times_total", "price_increase_times_consecutive", "price_last_increase_hours",
        "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_amount_of_drop", "target_time_to_drop",
        "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "crawl_date", "gid",
        "flight_number_1", "flight_1_num", "airport_pair_1", "dep_time_1", "arr_time_1", "fly_duration_1",
        "flight_by_hour", "flight_by_day", "flight_day_of_month", "flight_day_of_week", "flight_day_of_quarter", "flight_day_is_weekend", "is_transfer",
        "flight_number_2", "flight_2_num", "airport_pair_2", "dep_time_2", "arr_time_2", "fly_duration_2", "fly_duration", "stop_duration",
        "global_dep_time", "dep_country", "dep_country_is_holiday", "is_cross_country",
        "global_arr_time", "arr_country", "arr_country_is_holiday", "any_country_is_holiday",
        "price_weighted_percentile_25", "price_weighted_percentile_50", "price_weighted_percentile_75", "price_weighted_percentile_90",
        "price_zone_comprehensive", "price_relative_position",
    ]
    df_input = df_input[order_columns]
    return df_input
  655. def standardization(df, feature_scaler, target_scaler=None, is_training=True, is_val=False, feature_length=240):
  656. print(">>> 开始标准化处理")
  657. # 准备走标准化的特征
  658. scaler_features = ['adult_total_price', 'fly_duration', 'stop_duration',
  659. 'price_weighted_percentile_25', 'price_weighted_percentile_50',
  660. 'price_weighted_percentile_75', 'price_weighted_percentile_90']
  661. if is_training:
  662. print(">>> 特征数据标准化开始")
  663. if feature_scaler is None:
  664. feature_scaler = StandardScaler()
  665. if not is_val:
  666. feature_scaler.fit(df[scaler_features])
  667. df[scaler_features] = feature_scaler.transform(df[scaler_features])
  668. print(">>> 特征数据标准化完成")
  669. else:
  670. df[scaler_features] = feature_scaler.transform(df[scaler_features])
  671. print(">>> 预测模式下特征标准化处理完成")
  672. # 准备走归一化的特征
  673. # 事先定义好每个特征的合理范围
  674. fixed_ranges = {
  675. 'hours_until_departure': (0, 480), # 0-20天
  676. 'from_city_num': (0, 38),
  677. 'to_city_num': (0, 38),
  678. 'flight_1_num': (0, 341),
  679. 'flight_2_num': (0, 341),
  680. 'seats_remaining': (1, 5),
  681. # 'price_change_times_total': (0, 30), # 假设价格变更次数不会超过30次
  682. # 'price_last_change_hours': (0, 480),
  683. 'price_decrease_times_total': (0, 20), # 假设价格下降次数不会超过20次
  684. 'price_decrease_times_consecutive': (0, 10), # 假设价格连续下降次数不会超过10次
  685. 'price_last_decrease_hours': (0, feature_length), #(0-240小时)
  686. 'price_increase_times_total': (0, 20), # 假设价格上升次数不会超过20次
  687. 'price_increase_times_consecutive': (0, 10), # 假设价格连续上升次数不会超过10次
  688. 'price_last_increase_hours': (0, feature_length), #(0-240小时)
  689. 'price_zone_comprehensive': (0, 5),
  690. 'days_to_departure': (0, 30),
  691. 'days_to_holiday': (0, 120), # 最长的越南节假日间隔120天
  692. 'flight_by_hour': (0, 23),
  693. 'flight_by_day': (1, 31),
  694. 'flight_day_of_month': (1, 12),
  695. 'flight_day_of_week': (0, 6),
  696. 'flight_day_of_quarter': (1, 4),
  697. }
  698. normal_features = list(fixed_ranges.keys())
  699. print(">>> 归一化特征列: ", normal_features)
  700. print(">>> 基于固定范围的特征数据归一化开始")
  701. for col in normal_features:
  702. if col in df.columns:
  703. # 核心归一化公式: (x - min) / (max - min)
  704. col_min, col_max = fixed_ranges[col]
  705. df[col] = (df[col] - col_min) / (col_max - col_min)
  706. # 添加裁剪,将超出范围的值强制限制在[0,1]区间
  707. df[col] = df[col].clip(0, 1)
  708. print(">>> 基于固定范围的特征数据归一化完成")
  709. return df, feature_scaler, target_scaler
def preprocess_data_simple(df_input, is_train=False):
    """Build per-(gid, baggage) price-change features and, in training mode,
    mine historical "drop" and "keep" node tables from the 18-54h window.

    Parameters
    ----------
    df_input : pandas.DataFrame
        Raw rows; expected to contain at least 'gid', 'baggage',
        'hours_until_departure', 'adult_total_price', 'is_filled',
        'Adult_Total_Price' and 'Hours_Until_Departure'.
    is_train : bool
        When True, also return the drop/keep node tables; when False,
        gap-filled rows (is_filled != 0) are removed first.

    Returns
    -------
    (df_input, df_drop_nodes, df_keep_nodes) in training mode,
    (df_input, None, None) otherwise.
    """
    df_input = preprocess_data_first_half(df_input)
    # Within each (gid, baggage) pair, sort by time: hours_until_departure
    # descending corresponds to chronological order (countdown shrinks).
    df_input = df_input.sort_values(
        by=['gid', 'baggage', 'hours_until_departure'],
        ascending=[True, True, False]
    ).reset_index(drop=True)
    df_input = df_input[df_input['hours_until_departure'] <= 480]
    # Only the 30kg-baggage fare series is analyzed here.
    df_input = df_input[df_input['baggage'] == 30]
    # Keep genuinely crawled rows rather than gap-filled ones.
    if not is_train:
        df_input = df_input[df_input['is_filled'] == 0]
    # Price change amount vs. the previous time point; zero diffs are
    # forward-filled with the last non-zero change, so a flat stretch keeps
    # the value of the move that started it.
    df_input['price_change_amount'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
        .apply(lambda s: s.diff().replace(0, np.nan).ffill().fillna(0)).round(2)
    )
    # Price change as a rate relative to the previous time point (same
    # forward-fill treatment as above).
    df_input['price_change_percent'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
        .apply(lambda s: s.pct_change().replace(0, np.nan).ffill().fillna(0)).round(4)
    )
    # Step 1: label maximal runs of identical price_change_amount as segments.
    df_input['price_change_segment'] = (
        df_input.groupby(['gid', 'baggage'], group_keys=False)['price_change_amount']
        .apply(lambda s: (s != s.shift()).cumsum())
    )
    # Step 2: 1-based position within the current segment, i.e. how long the
    # current price level has lasted.
    # NOTE(review): treated as "hours" downstream — assumes one row per hour;
    # confirm against the upstream resampling.
    df_input['price_duration_hours'] = (
        df_input.groupby(['gid', 'baggage', 'price_change_segment'], group_keys=False)
        .cumcount()
        .add(1)
    )
    # Drop the temporary helper column.
    df_input = df_input.drop(columns=['price_change_segment'])
    # Move the raw (unscaled) copies to the end of the column order.
    adult_price = df_input.pop('Adult_Total_Price')
    hours_until = df_input.pop('Hours_Until_Departure')
    df_input['Adult_Total_Price'] = adult_price
    df_input['Hours_Until_Departure'] = hours_until
    df_input['Baggage'] = df_input['baggage']
    # Training path: mine historical drop/keep patterns.
    if is_train:
        # Restrict to the 18-54h pre-departure decision window.
        df_target = df_input[(df_input['hours_until_departure'] >= 18) & (df_input['hours_until_departure'] <= 54)].copy()
        df_target = df_target.sort_values(
            by=['gid', 'hours_until_departure'],
            ascending=[True, False]
        ).reset_index(drop=True)
        # Rise-then-fall analysis: previous-row state per gid.
        prev_pct = df_target.groupby('gid', group_keys=False)['price_change_percent'].shift(1)
        prev_amo = df_target.groupby('gid', group_keys=False)['price_change_amount'].shift(1)
        prev_dur = df_target.groupby('gid', group_keys=False)['price_duration_hours'].shift(1)
        # A "drop node": the previous move was an increase and the current
        # move is a decrease.
        drop_mask = (prev_pct > 0) & (df_target['price_change_percent'] < 0)
        df_drop_nodes = df_target.loc[drop_mask, ['gid', 'hours_until_departure']].copy()
        df_drop_nodes.rename(columns={'hours_until_departure': 'drop_hours_until_departure'}, inplace=True)
        # .to_numpy() aligns by position, so df_drop_nodes and the masked
        # selections must stay in the same row order.
        df_drop_nodes['drop_price_change_percent'] = df_target.loc[drop_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
        df_drop_nodes['drop_price_change_amount'] = df_target.loc[drop_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
        df_drop_nodes['high_price_duration_hours'] = prev_dur.loc[drop_mask].astype(float).to_numpy()
        df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask].astype(float).round(4).to_numpy()
        df_drop_nodes['high_price_change_amount'] = prev_amo.loc[drop_mask].astype(float).round(2).to_numpy()
        df_drop_nodes = df_drop_nodes.reset_index(drop=True)
        # Static flight descriptors attached to each node (only those columns
        # actually present in df_target are kept).
        flight_info_cols = [
            'city_pair',
            'flight_number_1', 'seg1_dep_air_port', 'seg1_dep_time', 'seg1_arr_air_port', 'seg1_arr_time',
            'flight_number_2', 'seg2_dep_air_port', 'seg2_dep_time', 'seg2_arr_air_port', 'seg2_arr_time',
            'currency', 'baggage', 'flight_day',
        ]
        flight_info_cols = [c for c in flight_info_cols if c in df_target.columns]
        df_gid_info = df_target[['gid'] + flight_info_cols].drop_duplicates(subset=['gid']).reset_index(drop=True)
        df_drop_nodes = df_drop_nodes.merge(df_gid_info, on='gid', how='left')
        drop_info_cols = ['drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
                          'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount'
                          ]
        # Fix the column order and drop 'gid'.
        df_drop_nodes = df_drop_nodes[flight_info_cols + drop_info_cols]
        # Analyze gids that never showed a rise-then-fall event.
        gids_with_drop = df_target.loc[drop_mask, 'gid'].unique()
        df_no_drop = df_target[~df_target['gid'].isin(gids_with_drop)].copy()
        keep_info_cols = [
            'keep_hours_until_departure', 'keep_price_change_percent', 'keep_price_change_amount', 'keep_price_duration_hours'
        ]
        if df_no_drop.empty:
            # No such gids: return an empty, correctly-shaped table.
            df_keep_nodes = pd.DataFrame(columns=flight_info_cols + keep_info_cols)
        else:
            df_no_drop = df_no_drop.sort_values(
                by=['gid', 'hours_until_departure'],
                ascending=[True, False]
            ).reset_index(drop=True)
            # Segment runs of identical price_change_percent per gid.
            df_no_drop['keep_segment'] = df_no_drop.groupby('gid')['price_change_percent'].transform(
                lambda s: (s != s.shift()).cumsum()
            )
            # One representative row per segment: its last (latest) row.
            df_keep_row = (
                df_no_drop.groupby(['gid', 'keep_segment'], as_index=False)
                .tail(1)
                .reset_index(drop=True)
            )
            df_keep_nodes = df_keep_row[
                ['gid', 'hours_until_departure', 'price_change_percent', 'price_change_amount', 'price_duration_hours']
            ].copy()
            df_keep_nodes.rename(
                columns={
                    'hours_until_departure': 'keep_hours_until_departure',
                    'price_change_percent': 'keep_price_change_percent',
                    'price_change_amount': 'keep_price_change_amount',
                    'price_duration_hours': 'keep_price_duration_hours',
                },
                inplace=True,
            )
            df_keep_nodes = df_keep_nodes.merge(df_gid_info, on='gid', how='left')
            df_keep_nodes = df_keep_nodes[flight_info_cols + keep_info_cols]
            del df_keep_row
        # Free the large intermediates before returning.
        del df_gid_info
        del df_target
        del df_no_drop
        return df_input, df_drop_nodes, df_keep_nodes
    return df_input, None, None
def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".", pred_time_str=""):
    """Rule-based price-drop prediction from mined historical drop/keep tables.

    For the latest observation of every gid inside the 18-54h pre-departure
    window, look up historical "rise-then-fall" (drop) and "no-fall" (keep)
    records for the same flight number(s) and set per row:
      * simple_will_price_drop: 1 = drop expected, 0 = no drop, -1 = unknown
      * simple_drop_in_hours / _prob / _dist: most likely hours until the
        drop, its probability, and a truncated probability distribution.

    Side effects: reads '<group_route_str>_drop_info.csv' and
    '<group_route_str>_keep_info.csv' from output_dir; appends the result to
    'future_predictions_<pred_time_str>.csv' in predict_dir.

    Returns the prediction DataFrame (empty if df_input is None/empty).
    """
    if df_input is None or df_input.empty:
        return pd.DataFrame()
    df_sorted = df_input.sort_values(
        by=['gid', 'hours_until_departure'],
        ascending=[True, False],
    ).reset_index(drop=True)
    df_sorted = df_sorted[
        df_sorted['hours_until_departure'].between(18, 54)
    ].reset_index(drop=True)
    # One row per gid: with the descending sort, keep='last' retains the row
    # with the smallest hours_until_departure (latest observation).
    df_min_hours = (
        df_sorted.drop_duplicates(subset=['gid'], keep='last')
        .reset_index(drop=True)
    )
    # Ensure hours_until_departure is within [18, 54]
    # df_min_hours = df_min_hours[
    #     df_min_hours['hours_until_departure'].between(18, 54)
    # ].reset_index(drop=True)
    # Historical node tables are optional; fall back to empty frames.
    drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
    if os.path.exists(drop_info_csv_path):
        df_drop_nodes = pd.read_csv(drop_info_csv_path)
    else:
        df_drop_nodes = pd.DataFrame()
    keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
    if os.path.exists(keep_info_csv_path):
        df_keep_nodes = pd.read_csv(keep_info_csv_path)
    else:
        df_keep_nodes = pd.DataFrame()
    df_min_hours['simple_will_price_drop'] = -1  # -1 means "unknown"
    df_min_hours['simple_drop_in_hours'] = 0
    df_min_hours['simple_drop_in_hours_prob'] = 0.0
    df_min_hours['simple_drop_in_hours_dist'] = ''
    # TODO(review): matching thresholds are heuristic — tune empirically.
    pct_threshold = 0.01
    # pct_threshold = 2
    pct_threshold_1 = 0.001
    pct_threshold_c = 0.001
    for idx, row in df_min_hours.iterrows():
        city_pair = row['city_pair']
        flight_number_1 = row['flight_number_1']
        flight_number_2 = row['flight_number_2']
        price_change_percent = row['price_change_percent']
        price_duration_hours = row['price_duration_hours']
        hours_until_departure = row['hours_until_departure']
        # Case 1: historical high -> low (rise-then-fall) records.
        if not df_drop_nodes.empty:
            # Match the flight number(s) across different departure dates.
            # NOTE(review): 'VJ' looks like a "no second segment" placeholder
            # — confirm against the upstream data source.
            if flight_number_2 and flight_number_2 != 'VJ':
                df_drop_nodes_part = df_drop_nodes[
                    (df_drop_nodes['city_pair'] == city_pair) &
                    (df_drop_nodes['flight_number_1'] == flight_number_1) &
                    (df_drop_nodes['flight_number_2'] == flight_number_2)
                ]
            else:
                df_drop_nodes_part = df_drop_nodes[
                    (df_drop_nodes['city_pair'] == city_pair) &
                    (df_drop_nodes['flight_number_1'] == flight_number_1)
                ]
            # Match on the size of the pre-drop increase and use the
            # historical high-price durations to derive a drop-time
            # probability distribution.
            if not df_drop_nodes_part.empty and pd.notna(price_change_percent):
                # Discard records whose preceding increase was too small.
                df_drop_nodes_part = df_drop_nodes_part[df_drop_nodes_part['high_price_change_percent'] >= 0.1]
                # pct_vals = df_drop_nodes_part['high_price_change_percent'].replace([np.inf, -np.inf], np.nan).dropna()
                # # Keep only the 10%-90% percentile band
                # if not pct_vals.empty:
                #     q10 = float(pct_vals.quantile(0.10))
                #     q90 = float(pct_vals.quantile(0.90))
                #     df_drop_nodes_part = df_drop_nodes_part[
                #         df_drop_nodes_part['high_price_change_percent'].between(q10, q90)
                #     ]
                #     if df_drop_nodes_part.empty:
                #         continue
                pct_diff = (df_drop_nodes_part['high_price_change_percent'] - float(price_change_percent)).abs()
                df_match = df_drop_nodes_part.loc[pct_diff <= pct_threshold, ['high_price_duration_hours', 'high_price_change_percent']].copy()
                if not df_match.empty and pd.notna(price_duration_hours):
                    # Hours the current high price is still expected to last.
                    remaining_hours = (df_match['high_price_duration_hours'] - float(price_duration_hours)).clip(lower=0)
                    remaining_hours = remaining_hours.round().astype(int)
                    counts = remaining_hours.value_counts().sort_index()
                    probs = (counts / counts.sum()).round(4)
                    # Mode of the distribution = predicted hours-to-drop.
                    top_hours = int(probs.idxmax())
                    top_prob = float(probs.max())
                    dist_items = list(zip(probs.index.tolist(), probs.tolist()))
                    dist_items = dist_items[:10]
                    dist_str = ' | '.join([f"{int(h)}:{float(p)}" for h, p in dist_items])
                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
                    df_min_hours.loc[idx, 'simple_drop_in_hours'] = top_hours
                    df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = top_prob
                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = dist_str
                    continue  # drop decided; skip the keep-node checks
        # Case 2: histories that stayed low, stayed high, went low -> high,
        # or fell repeatedly (no rise-then-fall event).
        if not df_keep_nodes.empty:
            # Match the flight number(s) across different departure dates.
            if flight_number_2 and flight_number_2 != 'VJ':
                df_keep_nodes_part = df_keep_nodes[
                    (df_keep_nodes['city_pair'] == city_pair) &
                    (df_keep_nodes['flight_number_1'] == flight_number_1) &
                    (df_keep_nodes['flight_number_2'] == flight_number_2)
                ]
            else:
                df_keep_nodes_part = df_keep_nodes[
                    (df_keep_nodes['city_pair'] == city_pair) &
                    (df_keep_nodes['flight_number_1'] == flight_number_1)
                ]
            if not df_keep_nodes_part.empty and pd.notna(price_change_percent):
                # pct_vals_1 = df_keep_nodes_part['keep_price_change_percent'].replace([np.inf, -np.inf], np.nan).dropna()
                # # Keep only the 10%-90% percentile band
                # if not pct_vals_1.empty:
                #     q10_1 = float(pct_vals_1.quantile(0.10))
                #     q90_1 = float(pct_vals_1.quantile(0.90))
                #     df_keep_nodes_part = df_keep_nodes_part[
                #         df_keep_nodes_part['keep_price_change_percent'].between(q10_1, q90_1)
                #     ]
                #     if df_keep_nodes_part.empty:
                #         continue
                # Special case: the current move is already a decrease — look
                # for historical runs of consecutive decreases.
                if price_change_percent < 0:
                    df_tmp = df_keep_nodes_part.copy()
                    # Ensure in-group ordering (redundant if already sorted upstream).
                    df_tmp = df_tmp.sort_values(
                        by=["flight_day", "keep_hours_until_departure"],
                        ascending=[True, False]
                    )
                    # Flag decreases.
                    df_tmp["is_negative"] = df_tmp["keep_price_change_percent"] < 0
                    if df_tmp["is_negative"].any():
                        # Mark the start of each "negative run":
                        # is_negative is True and the previous row (within the
                        # same flight_day) was not negative.
                        df_tmp["neg_block_id"] = (
                            df_tmp["is_negative"]
                            & ~df_tmp.groupby("flight_day")["is_negative"].shift(fill_value=False)
                        ).groupby(df_tmp["flight_day"]).cumsum()
                        # Position within the negative run (1-based).
                        df_tmp["neg_rank_in_block"] = (
                            df_tmp.groupby(["flight_day", "neg_block_id"])
                            .cumcount() + 1
                        )
                        # Length of each consecutive-negative run.
                        df_tmp["neg_block_size"] = (
                            df_tmp.groupby(["flight_day", "neg_block_id"])["is_negative"]
                            .transform("sum")
                        )
                        # Keep only rows that:
                        # 1) are decreases, and
                        # 2) are not the last row of their negative run
                        df_continuous_price_drop = df_tmp[
                            (df_tmp["is_negative"]) &
                            (df_tmp["neg_rank_in_block"] < df_tmp["neg_block_size"])
                        ].drop(
                            columns=[
                                "is_negative",
                                "neg_block_id",
                                "neg_rank_in_block",
                                "neg_block_size",
                            ]
                        )
                        pct_diff_c = (df_continuous_price_drop['keep_price_change_percent'] - float(price_change_percent)).abs()
                        df_match_c = df_continuous_price_drop.loc[pct_diff_c <= pct_threshold_c, ['flight_day', 'keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent']].copy()
                        # Consecutive-decrease condition satisfied.
                        if not df_match_c.empty and pd.notna(price_duration_hours):
                            vals_c = df_match_c['keep_price_duration_hours'].replace([np.inf, -np.inf], np.nan).dropna()
                            if not vals_c.empty:
                                min_val = vals_c.min()
                                if min_val <= float(price_duration_hours):
                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
                                    df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                                    df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.5
                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = ''
                                    continue
                # General case: match the current change percent against keep
                # records and compare the adjusted durations.
                pct_diff_1 = (df_keep_nodes_part['keep_price_change_percent'] - float(price_change_percent)).abs()
                df_match_1 = df_keep_nodes_part.loc[pct_diff_1 <= pct_threshold_1, ['flight_day', 'keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent']].copy()
                if not df_match_1.empty and pd.notna(price_duration_hours):
                    df_match_1['hours_delta'] = hours_until_departure - df_match_1['keep_hours_until_departure']
                    df_match_1['modify_keep_price_duration_hours'] = df_match_1['keep_price_duration_hours'] - df_match_1['hours_delta']
                    df_match_1 = df_match_1[df_match_1['modify_keep_price_duration_hours'] > 0]
                    # Where does price_duration_hours sit within the adjusted
                    # historical durations?
                    vals = df_match_1['modify_keep_price_duration_hours'].replace([np.inf, -np.inf], np.nan).dropna()
                    if not vals.empty:
                        q10_11 = float(vals.quantile(0.10))
                        # q90_11 = float(vals.quantile(0.90))
                        if q10_11 <= float(price_duration_hours):
                            df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
                            df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                            df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
                            df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = ''
    df_min_hours = df_min_hours.rename(columns={'seg1_dep_time': 'from_time'})
    # Prediction timestamp and the validity window
    # [departure - 54h, departure - 18h], floored to the hour.
    _pred_dt = pd.to_datetime(str(pred_time_str), format="%Y%m%d%H%M", errors="coerce")
    df_min_hours["update_hour"] = _pred_dt
    _dep_hour = pd.to_datetime(df_min_hours["from_time"], errors="coerce").dt.floor("h")
    df_min_hours["valid_begin_hour"] = _dep_hour - pd.to_timedelta(54, unit="h")
    df_min_hours["valid_end_hour"] = _dep_hour - pd.to_timedelta(18, unit="h")
    order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 'baggage', 'currency',
                  'adult_total_price', 'hours_until_departure', 'price_change_percent', 'price_duration_hours',
                  'update_hour', 'crawl_date',
                  'valid_begin_hour', 'valid_end_hour',
                  'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist'
                  ]
    df_predict = df_min_hours[order_cols]
    df_predict = df_predict.rename(columns={
        'simple_will_price_drop': 'will_price_drop',
        'simple_drop_in_hours': 'drop_in_hours',
        'simple_drop_in_hours_prob': 'drop_in_hours_prob',
        'simple_drop_in_hours_dist': 'drop_in_hours_dist',
    }
    )
    # Append to the prediction CSV; write the header only on first creation.
    csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
    df_predict.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
    print("预测结果已追加")
    return df_predict