data_preprocess.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import pandas as pd
  2. import numpy as np
  3. import bisect
  4. from datetime import datetime, timedelta
  5. from sklearn.preprocessing import StandardScaler
  6. from config import city_to_country, build_country_holidays
  # Per-country holiday lookup, built once at import time by passing the
  # city -> country mapping to build_country_holidays. preprocess_data reads it
  # with COUNTRY_HOLIDAYS.get(country, set()) and tests calendar dates with `in`.
  7. COUNTRY_HOLIDAYS = build_country_holidays(city_to_country)
  8. def preprocess_data(df_train, features, categorical_features, is_training=True):
  9. print(">>> 开始数据预处理")
  10. # 生成 城市对
  11. df_train['city_pair'] = (
  12. df_train['from_city_code'].astype(str) + "-" + df_train['to_city_code'].astype(str)
  13. )
  14. # 把 city_pair、from_city_code、to_city_code 放到前三列
  15. cols = df_train.columns.tolist()
  16. # 删除已存在的三列(保证顺序正确)
  17. for c in ['city_pair', 'from_city_code', 'to_city_code']:
  18. cols.remove(c)
  19. # 这三列插入到最前面
  20. df_train = df_train[['city_pair', 'from_city_code', 'to_city_code'] + cols]
  21. # 转格式
  22. df_train['search_dep_time'] = pd.to_datetime(
  23. df_train['search_dep_time'],
  24. format='%Y%m%d',
  25. errors='coerce'
  26. ).dt.strftime('%Y-%m-%d')
  27. # 重命名起飞日期
  28. df_train.rename(columns={'search_dep_time': 'flight_day'}, inplace=True)
  29. # 重命名航班号
  30. df_train.rename(
  31. columns={
  32. 'seg1_flight_number': 'flight_number_1',
  33. 'seg2_flight_number': 'flight_number_2'
  34. },
  35. inplace=True
  36. )
  37. # 分开填充
  38. df_train['flight_number_1'] = df_train['flight_number_1'].fillna('VJ')
  39. df_train['flight_number_2'] = df_train['flight_number_2'].fillna('VJ')
  40. # 生成第一机场对
  41. df_train['airport_pair_1'] = (
  42. df_train['seg1_dep_air_port'].astype(str) + "-" + df_train['seg1_arr_air_port'].astype(str)
  43. )
  44. # 删除原始第一机场码
  45. df_train.drop(columns=['seg1_dep_air_port', 'seg1_arr_air_port'], inplace=True)
  46. # 第一机场对 放到 seg1_dep_time 列的前面
  47. insert_idx = df_train.columns.get_loc('seg1_dep_time')
  48. airport_pair_1 = df_train.pop('airport_pair_1')
  49. df_train.insert(insert_idx, 'airport_pair_1', airport_pair_1)
  50. # 生成第二机场对(带缺失兜底)
  51. df_train['airport_pair_2'] = np.where(
  52. df_train['seg2_dep_air_port'].isna() | df_train['seg2_arr_air_port'].isna(),
  53. 'NA',
  54. df_train['seg2_dep_air_port'].astype(str) + "-" +
  55. df_train['seg2_arr_air_port'].astype(str)
  56. )
  57. # 删除原始第二机场码
  58. df_train.drop(columns=['seg2_dep_air_port', 'seg2_arr_air_port'], inplace=True)
  59. # 第二机场对 放到 seg2_dep_time 列的前面
  60. insert_idx = df_train.columns.get_loc('seg2_dep_time')
  61. airport_pair_2 = df_train.pop('airport_pair_2')
  62. df_train.insert(insert_idx, 'airport_pair_2', airport_pair_2)
  63. # 是否转乘
  64. df_train['is_transfer'] = np.where(df_train['flight_number_2'] == 'VJ', 0, 1)
  65. insert_idx = df_train.columns.get_loc('flight_number_2')
  66. is_transfer = df_train.pop('is_transfer')
  67. df_train.insert(insert_idx, 'is_transfer', is_transfer)
  68. # 重命名起飞时刻与到达时刻
  69. df_train.rename(
  70. columns={
  71. 'seg1_dep_time': 'dep_time_1',
  72. 'seg1_arr_time': 'arr_time_1',
  73. 'seg2_dep_time': 'dep_time_2',
  74. 'seg2_arr_time': 'arr_time_2',
  75. },
  76. inplace=True
  77. )
  78. # 第一段飞行时长
  79. df_train['fly_duration_1'] = (
  80. (df_train['arr_time_1'] - df_train['dep_time_1'])
  81. .dt.total_seconds() / 3600
  82. ).round(2)
  83. # 第二段飞行时长(无转乘为 0)
  84. df_train['fly_duration_2'] = (
  85. (df_train['arr_time_2'] - df_train['dep_time_2'])
  86. .dt.total_seconds() / 3600
  87. ).fillna(0).round(2)
  88. # 总飞行时长
  89. df_train['fly_duration'] = (
  90. df_train['fly_duration_1'] + df_train['fly_duration_2']
  91. ).round(2)
  92. # 中转停留时长(无转乘为 0)
  93. df_train['stop_duration'] = (
  94. (df_train['dep_time_2'] - df_train['arr_time_1'])
  95. .dt.total_seconds() / 3600
  96. ).fillna(0).round(2)
  97. # 裁剪,防止负数
  98. # for c in ['fly_duration_1', 'fly_duration_2', 'fly_duration', 'stop_duration']:
  99. # df_train[c] = df_train[c].clip(lower=0)
  100. # 和 is_transfer 逻辑保持一致
  101. # df_train.loc[df_train['is_transfer'] == 0, ['fly_duration_2', 'stop_duration']] = 0
  102. # 一次性插到 is_filled 前面
  103. insert_before = 'is_filled'
  104. new_cols = [
  105. 'fly_duration_1',
  106. 'fly_duration_2',
  107. 'fly_duration',
  108. 'stop_duration'
  109. ]
  110. cols = df_train.columns.tolist()
  111. idx = cols.index(insert_before)
  112. # 删除旧位置
  113. cols = [c for c in cols if c not in new_cols]
  114. # 插入新位置(顺序保持)
  115. cols[idx:idx] = new_cols # python独有空切片插入法
  116. df_train = df_train[cols]
  117. # 一次生成多个字段
  118. dep_t1 = df_train['dep_time_1']
  119. # 几点起飞(0–23)
  120. df_train['flight_by_hour'] = dep_t1.dt.hour
  121. # 起飞日期几号(1–31)
  122. df_train['flight_by_day'] = dep_t1.dt.day
  123. # 起飞日期几月(1–12)
  124. df_train['flight_day_of_month'] = dep_t1.dt.month
  125. # 起飞日期周几(0=周一, 6=周日)
  126. df_train['flight_day_of_week'] = dep_t1.dt.weekday
  127. # 起飞日期季度(1–4)
  128. df_train['flight_day_of_quarter'] = dep_t1.dt.quarter
  129. # 是否周末(周六 / 周日)
  130. df_train['flight_day_is_weekend'] = dep_t1.dt.weekday.isin([5, 6]).astype(int)
  131. # 找到对应的国家码
  132. df_train['dep_country'] = df_train['from_city_code'].map(city_to_country)
  133. df_train['arr_country'] = df_train['to_city_code'].map(city_to_country)
  134. # 整体出发时间 就是 dep_time_1
  135. df_train['global_dep_time'] = df_train['dep_time_1']
  136. # 整体到达时间:有转乘用 arr_time_2,否则用 arr_time_1
  137. df_train['global_arr_time'] = df_train['arr_time_2'].fillna(df_train['arr_time_1'])
  138. # 出发日期在出发国家是否节假日
  139. df_train['dep_country_is_holiday'] = df_train.apply(
  140. lambda r: r['global_dep_time'].date()
  141. in COUNTRY_HOLIDAYS.get(r['dep_country'], set()),
  142. axis=1
  143. ).astype(int)
  144. # 到达日期在到达国家是否节假日
  145. df_train['arr_country_is_holiday'] = df_train.apply(
  146. lambda r: r['global_arr_time'].date()
  147. in COUNTRY_HOLIDAYS.get(r['arr_country'], set()),
  148. axis=1
  149. ).astype(int)
  150. # 在任一侧是否节假日
  151. df_train['flight_day_is_holiday'] = (
  152. df_train[['dep_country_is_holiday', 'arr_country_is_holiday']]
  153. .max(axis=1)
  154. )
  155. # 是否跨国航线
  156. df_train['is_cross_country'] = (
  157. df_train['dep_country'] != df_train['arr_country']
  158. ).astype(int)
  159. pass