Эх сурвалжийг харах

在计算价格分位时带上分组

node04 4 өдөр өмнө
parent
commit
611b03c087
2 файл өөрчлөгдсөн, 49 нэмэгдсэн, 16 устгасан
  1. 46 13
      data_preprocess.py
  2. 3 3
      main_tr_0.py

+ 46 - 13
data_preprocess.py

@@ -1124,7 +1124,17 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     else:
         df_rise_nodes = pd.DataFrame()
     
-    # 联合价格分布 ==========================================================
+    # 联合价格分布(按航班分组计算)==========================================================
+    flight_key = ['city_pair', 'flight_number_1', 'flight_number_2']
+    group_cols = [
+        c for c in flight_key
+        if (
+            c in df_min_hours.columns
+            or (not df_drop_nodes.empty and c in df_drop_nodes.columns)
+            or (not df_rise_nodes.empty and c in df_rise_nodes.columns)
+        )
+    ]
+
     # 统一初始化
     df_min_hours['relative_position'] = np.nan
     if not df_drop_nodes.empty:
@@ -1136,39 +1146,62 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
 
     # 当前待预测
     if not df_min_hours.empty and 'adult_total_price' in df_min_hours.columns:
-        cur = df_min_hours[['adult_total_price']].copy()
+        min_group_cols = [c for c in group_cols if c in df_min_hours.columns]
+        cur = df_min_hours[min_group_cols + ['flight_day', 'adult_total_price']].copy()
+        for c in group_cols:
+            if c not in cur.columns:
+                cur[c] = np.nan
         cur['price'] = pd.to_numeric(cur['adult_total_price'], errors='coerce')
         cur['source'] = 'min'
         cur['row_id'] = cur.index
-        parts.append(cur[['price', 'source', 'row_id']])
+        parts.append(cur[group_cols + ['flight_day', 'price', 'source', 'row_id']])
 
     # 历史降价
     if not df_drop_nodes.empty and 'high_price_amount' in df_drop_nodes.columns:
-        drop = df_drop_nodes[['high_price_amount']].copy()
+        drop_group_cols = [c for c in group_cols if c in df_drop_nodes.columns]
+        drop = df_drop_nodes[drop_group_cols + ['flight_day', 'high_price_amount']].copy()
+        for c in group_cols:
+            if c not in drop.columns:
+                drop[c] = np.nan
         drop['price'] = pd.to_numeric(drop['high_price_amount'], errors='coerce')
         drop['source'] = 'drop'
         drop['row_id'] = drop.index
-        parts.append(drop[['price', 'source', 'row_id']])
+        parts.append(drop[group_cols + ['flight_day', 'price', 'source', 'row_id']])
 
     # 历史升价
     if not df_rise_nodes.empty and 'prev_rise_amount' in df_rise_nodes.columns:
-        rise = df_rise_nodes[['prev_rise_amount']].copy()
+        rise_group_cols = [c for c in group_cols if c in df_rise_nodes.columns]
+        rise = df_rise_nodes[rise_group_cols + ['flight_day', 'prev_rise_amount']].copy()
+        for c in group_cols:
+            if c not in rise.columns:
+                rise[c] = np.nan
         rise['price'] = pd.to_numeric(rise['prev_rise_amount'], errors='coerce')
         rise['source'] = 'rise'
         rise['row_id'] = rise.index
-        parts.append(rise[['price', 'source', 'row_id']])
+        parts.append(rise[group_cols + ['flight_day', 'price', 'source', 'row_id']])
     
     if parts:
         all_prices = pd.concat(parts, ignore_index=True)
         all_prices = all_prices.dropna(subset=['price']).reset_index(drop=True)
 
-        # 计算价格百分位
-        dense_rank = all_prices['price'].rank(method='dense')
-        max_rank = dense_rank.max()
-        if pd.notna(max_rank) and max_rank > 1:
-            all_prices['relative_position'] = (dense_rank - 1) / (max_rank - 1)
+        # 计算价格百分位(优先按分组计算,无法分组时回退全局)
+        if group_cols:
+            all_prices['dense_rank'] = all_prices.groupby(group_cols, dropna=False)['price'].rank(method='dense')
+            all_prices['max_rank'] = all_prices.groupby(group_cols, dropna=False)['dense_rank'].transform('max')
+            all_prices['relative_position'] = np.where(
+                all_prices['max_rank'] > 1,
+                (all_prices['dense_rank'] - 1) / (all_prices['max_rank'] - 1),
+                1.0,
+            )
+            all_prices = all_prices.drop(columns=['dense_rank', 'max_rank'])
         else:
-            all_prices['relative_position'] = 1.0
+            dense_rank = all_prices['price'].rank(method='dense')
+            max_rank = dense_rank.max()
+            if pd.notna(max_rank) and max_rank > 1:
+                all_prices['relative_position'] = (dense_rank - 1) / (max_rank - 1)
+            else:
+                all_prices['relative_position'] = 1.0
+        
         all_prices['relative_position'] = all_prices['relative_position'].round(4)
 
         # 回填到三个表

+ 3 - 3
main_tr_0.py

@@ -48,9 +48,9 @@ def start_train():
     max_workers = min(8, cpu_cores)  # 最大不超过8个进程
 
     # date_end = datetime.today().strftime("%Y-%m-%d")
-    date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
-    # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
-    date_begin = "2026-04-01"   # 2026-03-01  2026-04-28
+    date_end = (datetime.today() + timedelta(days=1)).strftime("%Y-%m-%d")   # 截止到明天
+    date_begin = date_end
+    # date_begin = "2026-05-07"   # 2026-03-01  2026-04-28 2026-05-07
 
     print(f"训练时间范围: {date_begin} 到 {date_end}")