Procházet zdrojové kódy

训练时增加变化前后的舱位

node04 před 1 měsícem
rodič
revize
e7ec8491f1
1 změněný soubor, 12 přidání a 7 odebrání
  1. 12 7
      data_process.py

+ 12 - 7
data_process.py

@@ -95,22 +95,25 @@ def preprocess_data_simple(df_input, is_train=False, hourly_time=None):
         prev_amo = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_amount'].shift(1)
         prev_dur = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_duration_hours'].shift(1)
         prev_price = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_total'].shift(1)
-        
+        prev_cabin = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['cabins'].shift(1)
+
         # 对于先升后降(先降再降)的分析
         seg_start_mask = df_target['price_duration_hours'].eq(1)   # 开始变价节点
         drop_mask = seg_start_mask & ((prev_pct > 0) | (prev_pct < 0)) & (df_target['price_change_percent'] < 0)
 
-        df_drop_nodes = df_target.loc[drop_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week']].copy()
+        df_drop_nodes = df_target.loc[drop_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week', 'cabins']].copy()
         df_drop_nodes.rename(columns={'hours_until_departure': 'drop_hours_until_departure'}, inplace=True)
         df_drop_nodes.rename(columns={'days_to_departure': 'drop_days_to_departure'}, inplace=True)
         df_drop_nodes.rename(columns={'update_hour': 'drop_update_hour'}, inplace=True)
         df_drop_nodes.rename(columns={'update_week': 'drop_update_week'}, inplace=True)
+        df_drop_nodes.rename(columns={'cabins': 'drop_cabins'}, inplace=True)
         df_drop_nodes['drop_price_change_percent'] = df_target.loc[drop_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
         df_drop_nodes['drop_price_change_amount'] = df_target.loc[drop_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
         df_drop_nodes['high_price_duration_hours'] = prev_dur.loc[drop_mask].astype(float).to_numpy()
         df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask].astype(float).round(4).to_numpy()
         df_drop_nodes['high_price_change_amount'] = prev_amo.loc[drop_mask].astype(float).round(2).to_numpy()
         df_drop_nodes['high_price_amount'] = prev_price.loc[drop_mask].astype(float).round(2).to_numpy()
+        df_drop_nodes['high_price_cabins'] = prev_cabin.loc[drop_mask].astype(str)
         df_drop_nodes = df_drop_nodes.reset_index(drop=True)
 
         flight_info_cols = [
@@ -121,9 +124,9 @@ def preprocess_data_simple(df_input, is_train=False, hourly_time=None):
         df_drop_nodes = df_drop_nodes.merge(df_gid_info, on=['gid', 'baggage_weight'], how='left')
 
         drop_info_cols = [
-            'drop_update_hour', 'drop_update_week',
+            'drop_update_hour', 'drop_update_week', 'drop_cabins', 
             'drop_days_to_departure', 'drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
-            'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount', 'high_price_amount', 
+            'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount', 'high_price_amount', 'high_price_cabins',
         ]
         # 按顺序排列 去掉gid
         df_drop_nodes = df_drop_nodes[flight_info_cols + ['baggage_weight'] + drop_info_cols]
@@ -132,25 +135,27 @@ def preprocess_data_simple(df_input, is_train=False, hourly_time=None):
         # seg_start_mask = df_target['price_duration_hours'].eq(1)
         rise_mask = seg_start_mask & ((prev_pct > 0) | (prev_pct < 0)) & (df_target['price_change_percent'] > 0)
 
-        df_rise_nodes = df_target.loc[rise_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week']].copy()
+        df_rise_nodes = df_target.loc[rise_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week', 'cabins']].copy()
         df_rise_nodes.rename(columns={'hours_until_departure': 'rise_hours_until_departure'}, inplace=True)
         df_rise_nodes.rename(columns={'days_to_departure': 'rise_days_to_departure'}, inplace=True)
         df_rise_nodes.rename(columns={'update_hour': 'rise_update_hour'}, inplace=True)
         df_rise_nodes.rename(columns={'update_week': 'rise_update_week'}, inplace=True)
+        df_rise_nodes.rename(columns={'cabins': 'rise_cabins'}, inplace=True)
         df_rise_nodes['rise_price_change_percent'] = df_target.loc[rise_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
         df_rise_nodes['rise_price_change_amount'] = df_target.loc[rise_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
         df_rise_nodes['prev_rise_duration_hours'] = prev_dur.loc[rise_mask].astype(float).to_numpy()
         df_rise_nodes['prev_rise_change_percent'] = prev_pct.loc[rise_mask].astype(float).round(4).to_numpy()
         df_rise_nodes['prev_rise_change_amount'] = prev_amo.loc[rise_mask].astype(float).round(2).to_numpy()
         df_rise_nodes['prev_rise_amount'] = prev_price.loc[rise_mask].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_cabins'] = prev_cabin.loc[rise_mask].astype(str)
         df_rise_nodes = df_rise_nodes.reset_index(drop=True)
 
         df_rise_nodes = df_rise_nodes.merge(df_gid_info, on=['gid', 'baggage_weight'], how='left')
         
         rise_info_cols = [
-            'rise_update_hour', 'rise_update_week',
+            'rise_update_hour', 'rise_update_week', 'rise_cabins', 
             'rise_days_to_departure', 'rise_hours_until_departure', 'rise_price_change_percent', 'rise_price_change_amount',
-            'prev_rise_duration_hours', 'prev_rise_change_percent', 'prev_rise_change_amount', 'prev_rise_amount',
+            'prev_rise_duration_hours', 'prev_rise_change_percent', 'prev_rise_change_amount', 'prev_rise_amount', 'prev_rise_cabins'
         ]
         df_rise_nodes = df_rise_nodes[flight_info_cols + ['baggage_weight'] + rise_info_cols]