
Adjust training, prediction, and validation; take the price-increase amount into account

node04 5 days ago
parent
commit
caa7c82092
3 changed files with 76 additions and 43 deletions
  1. data_preprocess.py (+49, -29)
  2. main_tr_0.py (+1, -1)
  3. result_validate_0.py (+26, -13)

+ 49 - 29
data_preprocess.py

@@ -883,16 +883,16 @@ def preprocess_data_simple(df_input, is_train=False):
     df_input = df_input.drop(columns=['price_change_segment'])
 
     # Record the remaining-seats change only at price-change points; leave all other rows empty (NaN)
-    _price_diff = df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price'].diff()
-    _price_changed = _price_diff.notna() & _price_diff.ne(0)
-    _seats_diff = df_input.groupby(['gid', 'baggage'], group_keys=False)['seats_remaining'].diff()
-    df_input['seats_remaining_change_amount'] = _seats_diff.where(_price_changed).round(0)
-    # Forward-fill, then fill missing values with 0
-    df_input['seats_remaining_change_amount'] = (
-        df_input.groupby(['gid', 'baggage'], group_keys=False)['seats_remaining_change_amount']
-        .ffill()
-        .fillna(0)
-    )
+    # _price_diff = df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price'].diff()
+    # _price_changed = _price_diff.notna() & _price_diff.ne(0)
+    # _seats_diff = df_input.groupby(['gid', 'baggage'], group_keys=False)['seats_remaining'].diff()
+    # df_input['seats_remaining_change_amount'] = _seats_diff.where(_price_changed).round(0)
+    # # Forward-fill, then fill missing values with 0
+    # df_input['seats_remaining_change_amount'] = (
+    #     df_input.groupby(['gid', 'baggage'], group_keys=False)['seats_remaining_change_amount']
+    #     .ffill()
+    #     .fillna(0)
+    # )
     
     adult_price = df_input.pop('Adult_Total_Price')
     hours_until = df_input.pop('Hours_Until_Departure')
@@ -912,7 +912,9 @@ def preprocess_data_simple(df_input, is_train=False):
         prev_pct = df_target.groupby('gid', group_keys=False)['price_change_percent'].shift(1)
         prev_amo = df_target.groupby('gid', group_keys=False)['price_change_amount'].shift(1)
         prev_dur = df_target.groupby('gid', group_keys=False)['price_duration_hours'].shift(1)
-        prev_seats_amo = df_target.groupby('gid', group_keys=False)['seats_remaining_change_amount'].shift(1)
+        # prev_seats_amo = df_target.groupby('gid', group_keys=False)['seats_remaining_change_amount'].shift(1)
+        prev_price = df_target.groupby('gid', group_keys=False)['adult_total_price'].shift(1)
+        prev_seats = df_target.groupby('gid', group_keys=False)['seats_remaining'].shift(1)
         drop_mask = (prev_pct > 0) & (df_target['price_change_percent'] < 0)
         
         df_drop_nodes = df_target.loc[drop_mask, ['gid', 'hours_until_departure']].copy()
@@ -922,7 +924,9 @@ def preprocess_data_simple(df_input, is_train=False):
         df_drop_nodes['high_price_duration_hours'] = prev_dur.loc[drop_mask].astype(float).to_numpy()
         df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask].astype(float).round(4).to_numpy()
         df_drop_nodes['high_price_change_amount'] = prev_amo.loc[drop_mask].astype(float).round(2).to_numpy()
-        df_drop_nodes['high_price_seats_remaining_change_amount'] = prev_seats_amo.loc[drop_mask].astype(float).round(1).to_numpy()
+        # df_drop_nodes['high_price_seats_remaining_change_amount'] = prev_seats_amo.loc[drop_mask].astype(float).round(1).to_numpy()
+        df_drop_nodes['high_price_amount'] = prev_price.loc[drop_mask].astype(float).round(2).to_numpy()
+        df_drop_nodes['high_price_seats_remaining'] = prev_seats.loc[drop_mask].astype(int).to_numpy()
         df_drop_nodes = df_drop_nodes.reset_index(drop=True)
 
         flight_info_cols = [
@@ -939,7 +943,7 @@ def preprocess_data_simple(df_input, is_train=False):
 
         drop_info_cols = ['drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
                           'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount',
-                          'high_price_seats_remaining_change_amount',
+                          'high_price_amount', 'high_price_seats_remaining',
         ]
         # Arrange columns in order and drop gid
         df_drop_nodes = df_drop_nodes[flight_info_cols + drop_info_cols]
@@ -951,7 +955,7 @@ def preprocess_data_simple(df_input, is_train=False):
 
         keep_info_cols = [
             'keep_hours_until_departure', 'keep_price_change_percent', 'keep_price_change_amount', 
-            'keep_price_duration_hours', 'keep_seats_remaining_change_amount',
+            'keep_price_duration_hours', 'keep_price_amount', 'keep_price_seats_remaining',
         ]
         
         if df_no_drop.empty:
@@ -974,7 +978,7 @@ def preprocess_data_simple(df_input, is_train=False):
 
             df_keep_nodes = df_keep_row[
                 ['gid', 'hours_until_departure', 'price_change_percent', 'price_change_amount', 
-                 'price_duration_hours', 'seats_remaining_change_amount']
+                 'price_duration_hours', 'adult_total_price', 'seats_remaining']
             ].copy()
             df_keep_nodes.rename(
                 columns={
@@ -982,7 +986,8 @@ def preprocess_data_simple(df_input, is_train=False):
                     'price_change_percent': 'keep_price_change_percent',
                     'price_change_amount': 'keep_price_change_amount',
                     'price_duration_hours': 'keep_price_duration_hours',
-                    'seats_remaining_change_amount': 'keep_seats_remaining_change_amount',
+                    'adult_total_price': 'keep_price_amount',
+                    'seats_remaining': 'keep_price_seats_remaining',
                 },
                 inplace=True,
             )
@@ -1058,8 +1063,10 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         price_change_amount = row['price_change_amount']
         price_duration_hours = row['price_duration_hours']
         hours_until_departure = row['hours_until_departure']
-        seats_remaining_change_amount = row['seats_remaining_change_amount']
-        
+        # seats_remaining_change_amount = row['seats_remaining_change_amount']
+        price_amount = row['adult_total_price']
+        seats_remaining = row['seats_remaining']
+
         length_drop = 0
         length_keep = 0
 
@@ -1090,12 +1097,18 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                 df_drop_gap = df_drop_nodes_part.loc[
                     pct_vals.notna(),
                     ['drop_hours_until_departure', 'high_price_duration_hours', 'high_price_change_percent', 
-                     'high_price_change_amount', 'high_price_seats_remaining_change_amount']
+                     'high_price_change_amount', 'high_price_amount', 'high_price_seats_remaining']
                 ].copy()
                 df_drop_gap['pct_gap'] = (pct_vals.loc[pct_vals.notna()] - pct_base)
                 df_drop_gap['pct_abs_gap'] = df_drop_gap['pct_gap'].abs()
-                df_drop_gap = df_drop_gap.sort_values(['pct_abs_gap'], ascending=True)
-                df_match = df_drop_gap[df_drop_gap['pct_abs_gap'] <= pct_threshold]
+
+                price_base = pd.to_numeric(price_amount, errors='coerce')
+                high_price_vals = pd.to_numeric(df_drop_gap['high_price_amount'], errors='coerce')
+                df_drop_gap['price_gap'] = high_price_vals - price_base
+                df_drop_gap['price_abs_gap'] = df_drop_gap['price_gap'].abs()
+
+                df_drop_gap = df_drop_gap.sort_values(['pct_abs_gap', 'price_abs_gap'], ascending=[True, True])
+                df_match = df_drop_gap[(df_drop_gap['pct_abs_gap'] <= pct_threshold) & (df_drop_gap['price_abs_gap'] <= 5.0)].copy()
 
                # Price-drop scenarios seen in history after a very similar increase
                 if not df_match.empty:
@@ -1263,15 +1276,21 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                # General decision scenario
                 pct_base_1 = float(price_change_percent)
                 pct_vals_1 = pd.to_numeric(df_keep_nodes_part['keep_price_change_percent'], errors='coerce')
-                df_drop_gap_1 = df_keep_nodes_part.loc[
+                df_keep_gap_1 = df_keep_nodes_part.loc[
                     pct_vals_1.notna(),
                     ['keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent', 
-                     'keep_price_change_amount', 'keep_seats_remaining_change_amount']
+                     'keep_price_change_amount', 'keep_price_amount', 'keep_price_seats_remaining']
                 ].copy()
-                df_drop_gap_1['pct_gap'] = (pct_vals_1.loc[pct_vals_1.notna()] - pct_base_1)
-                df_drop_gap_1['pct_abs_gap'] = df_drop_gap_1['pct_gap'].abs()
-                df_drop_gap_1 = df_drop_gap_1.sort_values(['pct_abs_gap'], ascending=True)
-                df_match_1 = df_drop_gap_1.loc[df_drop_gap_1['pct_abs_gap'] <= pct_threshold_1].copy()
+                df_keep_gap_1['pct_gap'] = (pct_vals_1.loc[pct_vals_1.notna()] - pct_base_1)
+                df_keep_gap_1['pct_abs_gap'] = df_keep_gap_1['pct_gap'].abs()
+                
+                price_base_1 = pd.to_numeric(price_amount, errors='coerce')
+                keep_price_vals_1 = pd.to_numeric(df_keep_gap_1['keep_price_amount'], errors='coerce')
+                df_keep_gap_1['price_gap'] = keep_price_vals_1 - price_base_1
+                df_keep_gap_1['price_abs_gap'] = df_keep_gap_1['price_gap'].abs()
+
+                df_keep_gap_1 = df_keep_gap_1.sort_values(['pct_abs_gap', 'price_abs_gap'], ascending=[True, True])
+                df_match_1 = df_keep_gap_1.loc[(df_keep_gap_1['pct_abs_gap'] <= pct_threshold_1) & (df_keep_gap_1['price_abs_gap'] <= 5.0)].copy()
 
                # Scenarios seen in history where the price stayed low after a similar change
                 if not df_match_1.empty:
@@ -1381,8 +1400,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(60, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
     df_min_hours["valid_end_hour"] = (_dep_hour - pd.to_timedelta(12, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
 
-    order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 'baggage', 'currency', 
-                  'adult_total_price', 'hours_until_departure', 'price_change_percent', 'price_duration_hours',
+    order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 
+                  'baggage', 'seats_remaining', 'currency',
+                  'adult_total_price', 'hours_until_departure', 'price_change_percent', 'price_duration_hours', 
                   'update_hour', 'crawl_date',
                   'valid_begin_hour', 'valid_end_hour',
                   'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist'
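
The heart of this change: historical drop/keep nodes now carry the absolute fare (high_price_amount / keep_price_amount) and the seat count seen at that point, and a current row only matches a node when both the percent gap and the fare gap are small. Below is a minimal, hedged sketch of that logic on made-up toy data; the column names mirror the diff and the 5.0 fare tolerance comes from it, while the toy values and the 0.02 pct_threshold are assumptions.

import pandas as pd

# Toy price history for a single flight (gid); values are invented for illustration.
df = pd.DataFrame({
    'gid': [1, 1, 1, 1],
    'hours_until_departure': [48, 42, 36, 30],
    'adult_total_price': [180.0, 200.0, 185.0, 185.0],
    'seats_remaining': [9, 7, 7, 6],
    'price_change_percent': [0.0, 0.111, -0.075, 0.0],
})

# Previous percent/price/seats per flight, mirroring the new prev_price and prev_seats columns.
prev_pct = df.groupby('gid', group_keys=False)['price_change_percent'].shift(1)
prev_price = df.groupby('gid', group_keys=False)['adult_total_price'].shift(1)
prev_seats = df.groupby('gid', group_keys=False)['seats_remaining'].shift(1)

# A drop node is a row whose price fell right after an increase; record the pre-drop state.
drop_mask = (prev_pct > 0) & (df['price_change_percent'] < 0)
df_drop_nodes = df.loc[drop_mask, ['gid', 'hours_until_departure']].copy()
df_drop_nodes['high_price_change_percent'] = prev_pct.loc[drop_mask]
df_drop_nodes['high_price_amount'] = prev_price.loc[drop_mask].round(2)
df_drop_nodes['high_price_seats_remaining'] = prev_seats.loc[drop_mask].astype(int)

# At prediction time, a current increase matches a node only if both gaps are small.
pct_base, price_base, pct_threshold = 0.10, 198.0, 0.02
df_gap = df_drop_nodes.copy()
df_gap['pct_abs_gap'] = (df_gap['high_price_change_percent'] - pct_base).abs()
df_gap['price_abs_gap'] = (df_gap['high_price_amount'] - price_base).abs()
df_match = (
    df_gap[(df_gap['pct_abs_gap'] <= pct_threshold) & (df_gap['price_abs_gap'] <= 5.0)]
    .sort_values(['pct_abs_gap', 'price_abs_gap'])
)
print(df_match)  # one matching node: the drop at 36 h, preceded by a 200.0 fare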

+ 1 - 1
main_tr_0.py

@@ -49,7 +49,7 @@ def start_train():
     # date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
-    date_begin = "2026-02-03"   # 2025-12-01  2026-01-27  2026-02-03 
+    date_begin = "2025-12-01"   # 2025-12-01  2026-01-27  2026-02-06
 
     print(f"Training date range: {date_begin} to {date_end}")
 
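
With this change the training window starts at 2025-12-01 instead of 2026-02-03, while date_end stays at "yesterday". A tiny sketch of how the window resolves, using only the datetime logic shown above:

from datetime import datetime, timedelta

# date_end is always yesterday; date_begin is pinned to the fixed start of the wider window.
date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
date_begin = "2025-12-01"
print(f"Training date range: {date_begin} to {date_end}")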

+ 26 - 13
result_validate_0.py

@@ -226,9 +226,9 @@ def validate_process_auto(node, interval_hours):
     print()
 
 
-def validate_process_zong(node, enable_min_batch_flag=False, min_batch_time_str=None):
+def validate_process_zong(node, enable_min_max_batch_flag=False, min_batch_time_str=None, max_batch_time_str=None):
     object_dir = "./predictions_0"
-
+    
     output_dir = f"./validate/{node}_zong"
     os.makedirs(output_dir, exist_ok=True)
 
@@ -251,13 +251,22 @@ def validate_process_zong(node, enable_min_batch_flag=False, min_batch_time_str=
     list_df_will_drop = []
 
     min_batch_dt = None
-    if enable_min_batch_flag:
-        if not min_batch_time_str:
-            print("enable_min_batch_flag=True but min_batch_time_str not provided; exiting")
-            return        
-        min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
-        min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
-    
+    max_batch_dt = None
+    if enable_min_max_batch_flag:
+        if not min_batch_time_str and not max_batch_time_str:
+            print("enable_min_max_batch_flag=True but neither min_batch_time_str nor max_batch_time_str provided; exiting")
+            return
+        if min_batch_time_str:
+            min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
+            min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
+        if max_batch_time_str:
+            max_batch_dt = datetime.datetime.strptime(max_batch_time_str, "%Y%m%d%H%M")
+            max_batch_dt = max_batch_dt.replace(minute=0, second=0, microsecond=0)
+
+        if min_batch_dt is not None and max_batch_dt is not None and min_batch_dt > max_batch_dt:
+            print(f"Invalid time range: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}); exiting")
+            return
+
     # Iterate over all prediction files
     for csv_file in csv_files:
         batch_time_str = (
@@ -265,10 +274,10 @@ def validate_process_zong(node, enable_min_batch_flag=False, min_batch_time_str=
         )
         batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
         batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
-        # Skip batches earlier than min_batch_dt
         if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
             continue
-        
+        if max_batch_dt is not None and batch_hour_dt > max_batch_dt:
+            continue
         csv_path = os.path.join(object_dir, csv_file)
         try:
             df_predict = pd.read_csv(csv_path)
@@ -443,8 +452,12 @@ if __name__ == "__main__":
         # validate_process(node, interval_hours, pred_time_str)
         # node = "node0127"
         # validate_process_zong(node)  # Unconditional aggregation
-        node = "node0203"
-        validate_process_zong(node, True, "202602041100")  # Conditional aggregation
+        node = "node0127"
+        validate_process_zong(node, True, None, "202602051400")   # Conditional aggregation
+        # node = "node0203"
+        # validate_process_zong(node, True, "202602041100", "202602051400")  # Conditional aggregation
+        # node = "node0205"
+        # validate_process_zong(node, True, "202602061000")  # Conditional aggregation
     # 1 Automatic validation
     else:
         node = "node0127"
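
validate_process_zong() now takes both a lower and an upper bound on the prediction batch hour and skips CSV batches outside that window. A standalone sketch of the same filtering, assuming the %Y%m%d%H%M batch timestamps used above; the helper name and example values are illustrative only.

import datetime

def batch_in_range(batch_time_str, min_batch_time_str=None, max_batch_time_str=None):
    """Return True if the batch hour falls inside the optional [min, max] window."""
    def to_hour(s):
        # Parse a %Y%m%d%H%M timestamp and truncate it to the hour; None stays None.
        if not s:
            return None
        dt = datetime.datetime.strptime(s, "%Y%m%d%H%M")
        return dt.replace(minute=0, second=0, microsecond=0)

    batch_hour_dt = to_hour(batch_time_str)
    min_dt, max_dt = to_hour(min_batch_time_str), to_hour(max_batch_time_str)
    if min_dt is not None and batch_hour_dt < min_dt:
        return False
    if max_dt is not None and batch_hour_dt > max_dt:
        return False
    return True

# Keep only batches between 2026-02-04 11:00 and 2026-02-05 14:00 (inclusive, on the hour).
print(batch_in_range("202602041230", "202602041100", "202602051400"))  # True
print(batch_in_range("202602051500", "202602041100", "202602051400"))  # False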