Sfoglia il codice sorgente

完善预测代码

node04 1 giorno fa
parent
commit
de0fdbb4c2
4 ha cambiato i file con 244 aggiunte e 6 eliminazioni
  1. 233 1
      data_process.py
  2. 8 2
      main_pe.py
  3. 1 1
      main_tr.py
  4. 2 2
      uo_atlas_import.py

+ 233 - 1
data_process.py

@@ -267,4 +267,236 @@ def predict_data_simple(df_input, city_pair, output_dir, predict_dir=".", pred_t
             m = all_prices['source'] == 'rise'
             df_rise_nodes.loc[all_prices.loc[m, 'row_id'], 'relative_position'] = all_prices.loc[m, 'relative_position'].values
     
-    pass
+    pass
+    # =====================================================================
+
+    df_min_hours['simple_will_price_drop'] = 0
+    df_min_hours['simple_drop_in_hours'] = 0
+    df_min_hours['simple_drop_in_hours_prob'] = 0.0
+    df_min_hours['simple_drop_in_hours_dist'] = ''   # 空串 表示未知
+    df_min_hours['flag_dist'] = ''
+    df_min_hours['drop_price_change_upper'] = 0.0
+    df_min_hours['drop_price_change_lower'] = 0.0
+    df_min_hours['drop_price_sample_size'] = 0
+    df_min_hours['rise_price_change_upper'] = 0.0
+    df_min_hours['rise_price_change_lower'] = 0.0
+    df_min_hours['rise_price_sample_size'] = 0
+
+    # 这个阈值取多少?
+    # pct_threshold = 0.01
+    # pct_threshold_1 = 0.01
+
+    for idx, row in df_min_hours.iterrows(): 
+        city_pair = row['citypair']
+        flight_numbers = row['flight_numbers']
+        baggage_weight = row['baggage_weight']
+        days_to_departure = row['days_to_departure']
+        hours_until_departure = row['hours_until_departure']
+        price_change_percent = row['price_change_percent']
+        price_change_amount = row['price_change_amount']
+        price_duration_hours = row['price_duration_hours']
+        price_amount = row['price_total']
+
+        length_drop = 0
+        length_rise = 0
+
+        # 针对历史上发生的 >降价
+        if not df_drop_nodes.empty:
+            # 对准航线 航班号 行李配额
+            df_drop_nodes_part = df_drop_nodes[
+                (df_drop_nodes['citypair'] == city_pair) &
+                (df_drop_nodes['flight_numbers'] == flight_numbers) &
+                (df_drop_nodes['baggage_weight'] == baggage_weight)
+            ]
+            # 降价前 增量阈值、当前阈值 的匹配
+            if not df_drop_nodes_part.empty and pd.notna(price_change_amount):   
+                
+                pca_base = float(price_change_amount)
+                pca_vals = pd.to_numeric(df_drop_nodes_part['high_price_change_amount'], errors='coerce')
+                df_drop_gap = df_drop_nodes_part.loc[
+                    pca_vals.notna(),
+                    ['drop_days_to_departure', 'drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount', 
+                     'high_price_duration_hours', 'high_price_change_percent', 'high_price_change_amount', 'high_price_amount', 'relative_position'
+                     ]
+                ].copy()
+                df_drop_gap['pca_gap'] = (pca_vals.loc[pca_vals.notna()] - pca_base)
+                df_drop_gap['pca_abs_gap'] = df_drop_gap['pca_gap'].abs()
+
+                price_base = pd.to_numeric(price_amount, errors='coerce')
+                high_price_vals = pd.to_numeric(df_drop_gap['high_price_amount'], errors='coerce')
+                df_drop_gap['price_gap'] = high_price_vals - price_base
+                df_drop_gap['price_abs_gap'] = df_drop_gap['price_gap'].abs()
+
+                df_drop_gap = df_drop_gap.sort_values(['price_abs_gap', 'pca_abs_gap'], ascending=[True, True])
+                df_match = df_drop_gap[(df_drop_gap['price_abs_gap'] <= 5.0) & (df_drop_gap['pca_abs_gap'] <= 10.0)].copy()
+
+                # 历史上出现的极近似的增长(下降)幅度后的降价场景
+                if not df_match.empty:
+                    dur_base = pd.to_numeric(price_duration_hours, errors='coerce')
+                    # hud_base = pd.to_numeric(hours_until_departure, errors='coerce')
+                    dtd_base = pd.to_numeric(days_to_departure, errors='coerce')
+
+                    if pd.notna(dur_base) and pd.notna(dtd_base): 
+                        df_match_chk = df_match.copy()
+
+                        drop_dtd_vals = pd.to_numeric(df_match_chk['drop_days_to_departure'], errors='coerce')
+                        df_match_chk = df_match_chk.loc[drop_dtd_vals.notna()].copy()
+                        df_match_chk = df_match_chk.loc[(drop_dtd_vals.loc[drop_dtd_vals.notna()] - float(dtd_base)).abs() <= 3].copy()
+
+                        # 距离起飞天数也对的上
+                        if not df_match_chk.empty:
+                            length_drop = df_match_chk.shape[0]
+                            df_min_hours.loc[idx, 'drop_price_sample_size'] = length_drop
+
+                            drop_price_change_upper = df_match_chk['drop_price_change_amount'].max()   # 降价上限
+                            drop_price_change_lower = df_match_chk['drop_price_change_amount'].min()   # 降价下限
+                            df_min_hours.loc[idx, 'drop_price_change_upper'] = round(drop_price_change_upper, 2)
+                            df_min_hours.loc[idx, 'drop_price_change_lower'] = round(drop_price_change_lower, 2)
+
+                            remaining_hours = (
+                                pd.to_numeric(df_match_chk['high_price_duration_hours'], errors='coerce') - float(dur_base)
+                            ).clip(lower=0)
+                            remaining_hours = remaining_hours.round().astype(int)
+
+                            counts = remaining_hours.value_counts().sort_index()
+                            probs = (counts / counts.sum()).round(4)
+
+                            top_hours = int(probs.idxmax())
+                            top_prob = float(probs.max())
+
+                            dist_items = list(zip(probs.index.tolist(), probs.tolist()))
+                            dist_items = dist_items[:10]
+                            dist_str = ' '.join([f"{int(h)}h->{float(p)}" for h, p in dist_items])
+
+                            df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
+                            df_min_hours.loc[idx, 'simple_drop_in_hours'] = top_hours
+                            df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 1
+                            df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = dist_str
+                            df_min_hours.loc[idx, 'flag_dist'] = 'd0'
+                    pass
+                pass
+        
+        # 针对历史上发生的 <升价
+        if not df_rise_nodes.empty:
+            # 对准航线 航班号 行李配额
+            df_rise_nodes_part = df_rise_nodes[
+                (df_rise_nodes['citypair'] == city_pair) &
+                (df_rise_nodes['flight_numbers'] == flight_numbers) &
+                (df_rise_nodes['baggage_weight'] == baggage_weight)
+            ]
+            # 升价前 增量阈值、当前阈值 的匹配
+            if not df_rise_nodes_part.empty and pd.notna(price_change_amount):
+                pca_base_1 = float(price_change_amount)
+                pca_vals_1 = pd.to_numeric(df_rise_nodes_part['prev_rise_change_amount'], errors='coerce')
+                df_rise_gap_1 = df_rise_nodes_part.loc[
+                    pca_vals_1.notna(),
+                    ['rise_days_to_departure', 'rise_hours_until_departure', 'rise_price_change_percent', 'rise_price_change_amount',
+                     'prev_rise_duration_hours', 'prev_rise_change_percent', 'prev_rise_change_amount', 'prev_rise_amount', 'relative_position']
+                ].copy()
+                df_rise_gap_1['pca_gap'] = (pca_vals_1.loc[pca_vals_1.notna()] - pca_base_1)
+                df_rise_gap_1['pca_abs_gap'] = df_rise_gap_1['pca_gap'].abs()
+
+                price_base_1 = pd.to_numeric(price_amount, errors='coerce')
+                rise_price_vals_1 = pd.to_numeric(df_rise_gap_1['prev_rise_amount'], errors='coerce')
+                df_rise_gap_1['price_gap'] = rise_price_vals_1 - price_base_1
+                df_rise_gap_1['price_abs_gap'] = df_rise_gap_1['price_gap'].abs()
+
+                df_rise_gap_1 = df_rise_gap_1.sort_values(['price_abs_gap', 'pca_abs_gap'], ascending=[True, True])
+                df_match_1 = df_rise_gap_1.loc[(df_rise_gap_1['price_abs_gap'] <= 5.0) & (df_rise_gap_1['pca_abs_gap'] <= 10.0)].copy()
+
+                # 历史上出现的极近似的增长(下降)幅度后的升价场景
+                if not df_match_1.empty:
+                    dur_base_1 = pd.to_numeric(price_duration_hours, errors='coerce')
+                    # hud_base_1 = pd.to_numeric(hours_until_departure, errors='coerce')
+                    dtd_base_1 = pd.to_numeric(days_to_departure, errors='coerce')
+
+                    if pd.notna(dur_base_1) and pd.notna(dtd_base_1): 
+                        df_match_chk_1 = df_match_1.copy()
+                        
+                        drop_dtd_vals_1 = pd.to_numeric(df_match_chk_1['rise_days_to_departure'], errors='coerce')
+                        df_match_chk_1 = df_match_chk_1.loc[drop_dtd_vals_1.notna()].copy()
+                        df_match_chk_1 = df_match_chk_1.loc[(drop_dtd_vals_1.loc[drop_dtd_vals_1.notna()] - float(dtd_base_1)).abs() <= 3].copy()
+
+                        # 距离起飞天数也对的上
+                        if not df_match_chk_1.empty:
+                            length_rise = df_match_chk_1.shape[0]
+                            df_min_hours.loc[idx, 'rise_price_sample_size'] = length_rise
+
+                            rise_price_change_upper = df_match_chk_1['rise_price_change_amount'].max()   # 涨价上限
+                            rise_price_change_lower = df_match_chk_1['rise_price_change_amount'].min()   # 涨价下限
+                            df_min_hours.loc[idx, 'rise_price_change_upper'] = round(rise_price_change_upper, 2)
+                            df_min_hours.loc[idx, 'rise_price_change_lower'] = round(rise_price_change_lower, 2)
+
+                            # 可以明确的判定不降价
+                            if length_drop == 0:
+                                df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
+                                df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
+                                df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
+                                # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'r0'
+                                df_min_hours.loc[idx, 'flag_dist'] = 'r0'
+                            # 分歧判定
+                            else:
+                                drop_prob = round(length_drop / (length_rise + length_drop), 2)
+                                # 依旧保持之前的降价判定,概率修改
+                                if drop_prob >= 0.4:
+                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
+                                    # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'd1'
+                                    df_min_hours.loc[idx, 'flag_dist'] = 'd1'
+                                # 改判不降价,概率修改
+                                else:
+                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
+                                    # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'r1'
+                                    df_min_hours.loc[idx, 'flag_dist'] = 'r1'
+
+                                df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = drop_prob
+
+    print("判定循环结束")
+
+    _dep_hour = pd.to_datetime(df_min_hours["from_time"], errors="coerce").dt.floor("h")
+    df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(360, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
+    df_min_hours["valid_end_hour"] = (_dep_hour - pd.to_timedelta(24, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
+
+    # 要展示在预测表里的字段
+    order_cols = [
+        "citypair", "flight_numbers", "baggage_weight", "from_date", "from_time",
+        "cabins", "ticket_amount", "currency", 
+        "price_total", 'relative_position', 'days_to_departure', 'hours_until_departure', 
+        'price_change_amount', 'price_change_percent', 'price_duration_hours', 
+        "update_hour", "update_week", 
+        'valid_begin_hour', 'valid_end_hour',
+        'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist',
+        'flag_dist',
+        'drop_price_change_upper', 'drop_price_change_lower', 'drop_price_sample_size',
+        'rise_price_change_upper', 'rise_price_change_lower', 'rise_price_sample_size',
+    ]
+    df_predict = df_min_hours[order_cols]
+    df_predict = df_predict.rename(columns={
+            'simple_will_price_drop': 'will_price_drop',
+            'simple_drop_in_hours': 'drop_in_hours',
+            'simple_drop_in_hours_prob': 'drop_in_hours_prob',
+            'simple_drop_in_hours_dist': 'drop_in_hours_dist',
+        }
+    )
+
+    # 排序
+    df_predict = df_predict.sort_values(
+        by=['citypair', 'flight_numbers', 'baggage_weight', 'from_date'],
+        kind='mergesort',
+        na_position='last',
+    ).reset_index(drop=True)
+
+    total_cnt = len(df_predict)
+    if "will_price_drop" in df_predict.columns:
+        _wpd = pd.to_numeric(df_predict["will_price_drop"], errors="coerce")
+        drop_1_cnt = int((_wpd == 1).sum())
+        drop_0_cnt = int((_wpd == 0).sum())
+    else:
+        drop_1_cnt = 0
+        drop_0_cnt = 0
+    print(f"will_price_drop 分类数量统计: 1(会降)={drop_1_cnt}, 0(不降)={drop_0_cnt}, 总数={total_cnt}")
+
+    csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
+    df_predict.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
+
+    print("预测结果已追加")
+    return df_predict

+ 8 - 2
main_pe.py

@@ -87,8 +87,14 @@ def start_predict():
 
         df_predict = predict_data_simple(df_test_inputs, uo_city_pair, output_dir, predict_dir, hourly_time_str)
         
-        pass
-
+        del df_test_inputs
+        del df_predict
+        print(f"第 {idx} 组 预测完成")
+        print()
+        time.sleep(1)
+
+    print("所有批次的预测结束")
+    print()
 
 
 if __name__ == "__main__":

+ 1 - 1
main_tr.py

@@ -21,7 +21,7 @@ def start_train():
     max_workers = min(8, cpu_cores)  # 最大不超过8个进程
 
     from_date_end = (datetime.today() - timedelta(days=0)).strftime("%Y-%m-%d")  # 截止日改为今天
-    from_date_begin = "2026-03-17"
+    from_date_begin = "2026-03-27"
 
     print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")
 

+ 2 - 2
uo_atlas_import.py

@@ -230,8 +230,8 @@ def main_import_process(create_at_begin, create_at_end):
     print()
 
 if __name__ == "__main__":
-    create_at_begin = "2026-03-26 00:00:00"
-    create_at_end = "2026-03-26 15:59:59"
+    create_at_begin = "2026-03-27 10:00:00"
+    create_at_end = "2026-03-27 15:59:59"
     main_import_process(create_at_begin, create_at_end)
     
     # try: