Browse Source

调整预测相关策略

node04 2 weeks ago
parent
commit
4f4e83b4f2
5 changed files with 239 additions and 138 deletions
  1. 12 6
      data_loader.py
  2. 202 112
      data_preprocess.py
  3. 4 3
      follow_up.py
  4. 11 7
      main_pe_0.py
  5. 10 10
      main_tr_0.py

+ 12 - 6
data_loader.py

@@ -83,7 +83,8 @@ def query_flight_range_status(db, table_name, from_city, to_city, dep_date_begin
                     "$gte": dep_date_begin,
                     "$lte": dep_date_end,
                 },
-                "segments.baggage": {"$in": ["1-20", "1-30"]}  # 只查20公斤和30公斤行李的
+                # "segments.baggage": {"$in": ["-;-;-;-", "1-30"]}  # 无行李,30公斤行李
+                "segments.baggage": "-;-;-;-"
             }
             # 动态添加航班号条件
             for i, flight_num in enumerate(flight_nums):
@@ -641,7 +642,7 @@ def plot_c12_trend(df, output_dir="."):
 
 def process_flight_group(args):
     """处理单个航班号的进程函数(独立数据库连接)"""
-    process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
+    process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir = args
     flight_nums = each_group.get("flight_numbers")
     details = each_group.get("details")
 
@@ -690,6 +691,10 @@ def process_flight_group(args):
             common_dep_dates = df2['search_dep_time'].unique()
             common_baggages = df2['baggage'].unique()
 
+        # 如果是预测,起飞天数以远期表为主
+        if not is_train:
+            common_dep_dates = df1['search_dep_time'].unique()
+
         list_mid = []
         for dep_date in common_dep_dates:
             # 起飞日期筛选
@@ -784,7 +789,7 @@ def process_flight_group(args):
             pass
 
 
-def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, plot_flag=False,
+def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, is_train=True, plot_flag=False,
                     use_multiprocess=False, max_workers=None):
     """加载训练数据(支持多进程)"""
     timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
@@ -814,7 +819,7 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
             process_id = 0
             for each_group in all_groups:
                 process_id += 1
-                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
                 process_args.append(args)
             
             with ProcessPoolExecutor(max_workers=max_workers) as executor:
@@ -838,7 +843,7 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
             print("使用单进程处理")
             process_id = 0
             for each_group in all_groups:
-                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, is_train, plot_flag, output_dir)
                 flight_nums = each_group.get("flight_numbers", "未知")
                 try:
                     df_mid = process_flight_group(args)
@@ -1005,10 +1010,11 @@ if __name__ == "__main__":
     os.makedirs(output_dir, exist_ok=True)
 
     # 加载热门航线数据
-    date_begin = "2026-01-20"
+    date_begin = "2026-01-01"
     date_end = datetime.today().strftime("%Y-%m-%d")
 
     flight_route_list = vj_flight_route_list_hot[:]  # 热门 vj_flight_route_list_hot  冷门 vj_flight_route_list_nothot
+    # flight_route_list = ["SGN-NGO"]  # 测试段
     table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB  冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
     is_hot = 1   # 1 热门 0 冷门
     group_size = 1

+ 202 - 112
data_preprocess.py

@@ -845,7 +845,7 @@ def preprocess_data_simple(df_input, is_train=False):
     ).reset_index(drop=True)
 
     df_input = df_input[df_input['hours_until_departure'] <= 480]
-    df_input = df_input[df_input['baggage'] == 30]
+    df_input = df_input[df_input['baggage'] == 0]   # 只保留无行李的
 
     # 在hours_until_departure 的末尾 保留真实的而不是补齐的数据
     if not is_train:
@@ -854,16 +854,39 @@ def preprocess_data_simple(df_input, is_train=False):
         )
         df_input = df_input[~((df_input['is_filled'] == 1) & (_tail_filled == 1))]
     
+    # 价格变化最小量阈值
+    price_change_amount_threshold = 5
+    df_input['_raw_price_diff'] = df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price'].diff()
+
     # 计算价格变化量
+    # df_input['price_change_amount'] = (
+    #     df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
+    #     .apply(lambda s: s.diff().replace(0, np.nan).ffill().fillna(0)).round(2)
+    # )
     df_input['price_change_amount'] = (
-        df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
-        .apply(lambda s: s.diff().replace(0, np.nan).ffill().fillna(0)).round(2)
+        df_input['_raw_price_diff']
+        .mask(df_input['_raw_price_diff'].abs() < price_change_amount_threshold, 0)
+        .replace(0, np.nan)
+        .groupby([df_input['gid'], df_input['baggage']], group_keys=False)
+        .ffill()
+        .fillna(0)
+        .round(2)
     )
 
     # 计算价格变化百分比(相对于上一时间点的变化率)
+    # df_input['price_change_percent'] = (
+    #     df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
+    #     .apply(lambda s: s.pct_change().replace(0, np.nan).ffill().fillna(0)).round(4)
+    # )
     df_input['price_change_percent'] = (
         df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price']
-        .apply(lambda s: s.pct_change().replace(0, np.nan).ffill().fillna(0)).round(4)
+        .pct_change()
+        .mask(df_input['_raw_price_diff'].abs() < price_change_amount_threshold, 0)
+        .replace(0, np.nan)
+        .groupby([df_input['gid'], df_input['baggage']], group_keys=False)
+        .ffill()
+        .fillna(0)
+        .round(4)
     )
 
     # 第一步:标记价格变化段
@@ -880,7 +903,8 @@ def preprocess_data_simple(df_input, is_train=False):
     )
     
     # 可选:删除临时列
-    df_input = df_input.drop(columns=['price_change_segment'])
+    # df_input = df_input.drop(columns=['price_change_segment'])
+    df_input = df_input.drop(columns=['price_change_segment', '_raw_price_diff'])
 
     # 仅在价格变化点记录余票变化量;其它非价格变化点置空(NaN) 
     # _price_diff = df_input.groupby(['gid', 'baggage'], group_keys=False)['adult_total_price'].diff()
@@ -902,7 +926,7 @@ def preprocess_data_simple(df_input, is_train=False):
 
     # 训练过程
     if is_train:
-        df_target = df_input[(df_input['hours_until_departure'] >= 12) & (df_input['hours_until_departure'] <= 60)].copy()
+        df_target = df_input[(df_input['hours_until_departure'] >= 12) & (df_input['hours_until_departure'] <= 360)].copy()   # 扩展至360小时(15天) 
         df_target = df_target.sort_values(
             by=['gid', 'hours_until_departure'],
             ascending=[True, False]
@@ -947,61 +971,84 @@ def preprocess_data_simple(df_input, is_train=False):
         ]
         # 按顺序排列 去掉gid
         df_drop_nodes = df_drop_nodes[flight_info_cols + drop_info_cols]
-        df_drop_nodes = df_drop_nodes[df_drop_nodes['drop_price_change_percent'] <= -0.01]   # 太低的降幅不计
+        # df_drop_nodes = df_drop_nodes[df_drop_nodes['drop_price_change_percent'] <= -0.01]   # 太低的降幅不计
+
+        # 对于“上涨后再次上涨”的分析(连续两个正向变价段)
+        seg_start_mask = df_target['price_duration_hours'].eq(1)
+        rise_mask = seg_start_mask & (prev_pct > 0) & (df_target['price_change_percent'] > 0)
+
+        df_rise_nodes = df_target.loc[rise_mask, ['gid', 'hours_until_departure']].copy()
+        df_rise_nodes.rename(columns={'hours_until_departure': 'rise_hours_until_departure'}, inplace=True)
+        df_rise_nodes['rise_price_change_percent'] = df_target.loc[rise_mask, 'price_change_percent'].astype(float).round(4).to_numpy()
+        df_rise_nodes['rise_price_change_amount'] = df_target.loc[rise_mask, 'price_change_amount'].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_duration_hours'] = prev_dur.loc[rise_mask].astype(float).to_numpy()
+        df_rise_nodes['prev_rise_change_percent'] = prev_pct.loc[rise_mask].astype(float).round(4).to_numpy()
+        df_rise_nodes['prev_rise_change_amount'] = prev_amo.loc[rise_mask].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_amount'] = prev_price.loc[rise_mask].astype(float).round(2).to_numpy()
+        df_rise_nodes['prev_rise_seats_remaining'] = prev_seats.loc[rise_mask].astype(int).to_numpy()
+        df_rise_nodes = df_rise_nodes.reset_index(drop=True)
+
+        df_rise_nodes = df_rise_nodes.merge(df_gid_info, on='gid', how='left')
+        rise_info_cols = [
+            'rise_hours_until_departure', 'rise_price_change_percent', 'rise_price_change_amount',
+            'prev_rise_duration_hours', 'prev_rise_change_percent', 'prev_rise_change_amount',
+            'prev_rise_amount', 'prev_rise_seats_remaining',
+        ]
+        df_rise_nodes = df_rise_nodes[flight_info_cols + rise_info_cols]
 
         # 对于没有先升后降的gid进行分析
-        gids_with_drop = df_target.loc[drop_mask, 'gid'].unique()
-        df_no_drop = df_target[~df_target['gid'].isin(gids_with_drop)].copy()
+        # gids_with_drop = df_target.loc[drop_mask, 'gid'].unique()
+        # df_no_drop = df_target[~df_target['gid'].isin(gids_with_drop)].copy()
 
-        keep_info_cols = [
-            'keep_hours_until_departure', 'keep_price_change_percent', 'keep_price_change_amount', 
-            'keep_price_duration_hours', 'keep_price_amount', 'keep_price_seats_remaining',
-        ]
+        # keep_info_cols = [
+        #     'keep_hours_until_departure', 'keep_price_change_percent', 'keep_price_change_amount', 
+        #     'keep_price_duration_hours', 'keep_price_amount', 'keep_price_seats_remaining',
+        # ]
         
-        if df_no_drop.empty:
-            df_keep_nodes = pd.DataFrame(columns=flight_info_cols + keep_info_cols)
-        else:
-            df_no_drop = df_no_drop.sort_values(
-                by=['gid', 'hours_until_departure'],
-                ascending=[True, False]
-            ).reset_index(drop=True)
-
-            df_no_drop['keep_segment'] = df_no_drop.groupby('gid')['price_change_percent'].transform(
-                lambda s: (s != s.shift()).cumsum()
-            )
-
-            df_keep_row = (
-                df_no_drop.groupby(['gid', 'keep_segment'], as_index=False)
-                .tail(1)
-                .reset_index(drop=True)
-            )
-
-            df_keep_nodes = df_keep_row[
-                ['gid', 'hours_until_departure', 'price_change_percent', 'price_change_amount', 
-                 'price_duration_hours', 'adult_total_price', 'seats_remaining']
-            ].copy()
-            df_keep_nodes.rename(
-                columns={
-                    'hours_until_departure': 'keep_hours_until_departure',
-                    'price_change_percent': 'keep_price_change_percent',
-                    'price_change_amount': 'keep_price_change_amount',
-                    'price_duration_hours': 'keep_price_duration_hours',
-                    'adult_total_price': 'keep_price_amount',
-                    'seats_remaining': 'keep_price_seats_remaining',
-                },
-                inplace=True,
-            )
-
-            df_keep_nodes = df_keep_nodes.merge(df_gid_info, on='gid', how='left')
-            df_keep_nodes = df_keep_nodes[flight_info_cols + keep_info_cols]
-
-            del df_keep_row
+        # if df_no_drop.empty:
+        #     df_keep_nodes = pd.DataFrame(columns=flight_info_cols + keep_info_cols)
+        # else:
+        #     df_no_drop = df_no_drop.sort_values(
+        #         by=['gid', 'hours_until_departure'],
+        #         ascending=[True, False]
+        #     ).reset_index(drop=True)
+
+        #     df_no_drop['keep_segment'] = df_no_drop.groupby('gid')['price_change_percent'].transform(
+        #         lambda s: (s != s.shift()).cumsum()
+        #     )
+
+        #     df_keep_row = (
+        #         df_no_drop.groupby(['gid', 'keep_segment'], as_index=False)
+        #         .tail(1)
+        #         .reset_index(drop=True)
+        #     )
+
+        #     df_keep_nodes = df_keep_row[
+        #         ['gid', 'hours_until_departure', 'price_change_percent', 'price_change_amount', 
+        #          'price_duration_hours', 'adult_total_price', 'seats_remaining']
+        #     ].copy()
+        #     df_keep_nodes.rename(
+        #         columns={
+        #             'hours_until_departure': 'keep_hours_until_departure',
+        #             'price_change_percent': 'keep_price_change_percent',
+        #             'price_change_amount': 'keep_price_change_amount',
+        #             'price_duration_hours': 'keep_price_duration_hours',
+        #             'adult_total_price': 'keep_price_amount',
+        #             'seats_remaining': 'keep_price_seats_remaining',
+        #         },
+        #         inplace=True,
+        #     )
+
+        #     df_keep_nodes = df_keep_nodes.merge(df_gid_info, on='gid', how='left')
+        #     df_keep_nodes = df_keep_nodes[flight_info_cols + keep_info_cols]
+
+        #     del df_keep_row
         
         del df_gid_info
         del df_target
-        del df_no_drop
+        # del df_no_drop
 
-        return df_input, df_drop_nodes, df_keep_nodes
+        return df_input, df_drop_nodes, df_rise_nodes
 
     return df_input, None, None
 
@@ -1016,7 +1063,7 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     ).reset_index(drop=True)
 
     df_sorted = df_sorted[
-        df_sorted['hours_until_departure'].between(12, 60)
+        df_sorted['hours_until_departure'].between(12, 360)
     ].reset_index(drop=True)
 
     # 每个 gid 取 hours_until_departure 最小的一条
@@ -1025,9 +1072,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         .reset_index(drop=True)
     )
 
-    # 确保 hours_until_departure 在 [12, 60] 的 范围内
+    # 确保 hours_until_departure 在 [12, 360] 的 范围内
     # df_min_hours = df_min_hours[
-    #     df_min_hours['hours_until_departure'].between(12, 60)
+    #     df_min_hours['hours_until_departure'].between(12, 360)
     # ].reset_index(drop=True)
 
     drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
@@ -1036,17 +1083,26 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     else:
         df_drop_nodes = pd.DataFrame()
 
-    keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
-    if os.path.exists(keep_info_csv_path):
-        df_keep_nodes = pd.read_csv(keep_info_csv_path)
+    rise_info_csv_path = os.path.join(output_dir, f'{group_route_str}_rise_info.csv')
+    if os.path.exists(rise_info_csv_path):
+        df_rise_nodes = pd.read_csv(rise_info_csv_path)
     else:
-        df_keep_nodes = pd.DataFrame()
+        df_rise_nodes = pd.DataFrame()
 
     df_min_hours['simple_will_price_drop'] = 0   
     df_min_hours['simple_drop_in_hours'] = 0
     df_min_hours['simple_drop_in_hours_prob'] = 0.0
     df_min_hours['simple_drop_in_hours_dist'] = ''   # 空串 表示未知
-    
+    df_min_hours['flag_dist'] = ''
+    df_min_hours['drop_price_change_upper'] = 0.0
+    df_min_hours['drop_price_change_mode'] = 0.0
+    df_min_hours['drop_price_change_lower'] = 0.0
+    df_min_hours['drop_price_sample_size'] = 0
+    df_min_hours['rise_price_change_upper'] = 0.0
+    df_min_hours['rise_price_change_mode'] = 0.0
+    df_min_hours['rise_price_change_lower'] = 0.0
+    df_min_hours['rise_price_sample_size'] = 0
+
     # 这个阈值取多少?
     pct_threshold = 0.001
     # pct_threshold = 2
@@ -1068,7 +1124,8 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         seats_remaining = row['seats_remaining']
 
         length_drop = 0
-        length_keep = 0
+        length_rise = 0
+        # length_keep = 0
 
         # 针对历史上发生的 高价->低价
         if not df_drop_nodes.empty:
@@ -1088,7 +1145,7 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
             # 降价前 增幅阈值的匹配 与 高价历史持续时间 得出降价时间的概率
             if not df_drop_nodes_part.empty and pd.notna(price_change_percent):   
                 # 增幅太小的去掉
-                df_drop_nodes_part = df_drop_nodes_part[df_drop_nodes_part['high_price_change_percent'] >= 0.01]
+                # df_drop_nodes_part = df_drop_nodes_part[df_drop_nodes_part['high_price_change_percent'] >= 0.01]
                 # pct_diff = (df_drop_nodes_part['high_price_change_percent'] - float(price_change_percent)).abs()
                 # df_match = df_drop_nodes_part.loc[pct_diff <= pct_threshold, ['high_price_duration_hours', 'high_price_change_percent']].copy()
                 
@@ -1096,7 +1153,8 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                 pct_vals = pd.to_numeric(df_drop_nodes_part['high_price_change_percent'], errors='coerce')
                 df_drop_gap = df_drop_nodes_part.loc[
                     pct_vals.notna(),
-                    ['drop_hours_until_departure', 'high_price_duration_hours', 'high_price_change_percent', 
+                    ['drop_hours_until_departure', 'drop_price_change_percent', 'drop_price_change_amount',
+                     'high_price_duration_hours', 'high_price_change_percent', 
                      'high_price_change_amount', 'high_price_amount', 'high_price_seats_remaining']
                 ].copy()
                 df_drop_gap['pct_gap'] = (pct_vals.loc[pct_vals.notna()] - pct_base)
@@ -1118,13 +1176,13 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
 
                     if pd.notna(dur_base) and pd.notna(hud_base):  # and pd.notna(seats_base)
                         df_match_chk = df_match.copy()
-                        dur_vals = pd.to_numeric(df_match_chk['high_price_duration_hours'], errors='coerce')
-                        df_match_chk = df_match_chk.loc[dur_vals.notna()].copy()
-                        df_match_chk = df_match_chk.loc[(dur_vals.loc[dur_vals.notna()] - float(dur_base)).abs() <= 36].copy()
+                        # dur_vals = pd.to_numeric(df_match_chk['high_price_duration_hours'], errors='coerce')
+                        # df_match_chk = df_match_chk.loc[dur_vals.notna()].copy()
+                        # df_match_chk = df_match_chk.loc[(dur_vals.loc[dur_vals.notna()] - float(dur_base)).abs() <= 36].copy()
 
                         drop_hud_vals = pd.to_numeric(df_match_chk['drop_hours_until_departure'], errors='coerce')
                         df_match_chk = df_match_chk.loc[drop_hud_vals.notna()].copy()
-                        df_match_chk = df_match_chk.loc[(drop_hud_vals.loc[drop_hud_vals.notna()] - float(hud_base)).abs() <= 18].copy()
+                        df_match_chk = df_match_chk.loc[(drop_hud_vals.loc[drop_hud_vals.notna()] - float(hud_base)).abs() <= 24].copy()
 
                         # seats_vals = pd.to_numeric(df_match_chk['high_price_seats_remaining_change_amount'], errors='coerce')
                         # df_match_chk = df_match_chk.loc[seats_vals.notna()].copy()
@@ -1132,6 +1190,18 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
 
                         # 持续时间、距离起飞时间、座位变化都匹配上
                         if not df_match_chk.empty:
+                            length_drop = df_match_chk.shape[0]
+                            df_min_hours.loc[idx, 'drop_price_sample_size'] = length_drop
+
+                            drop_price_change_upper = df_match_chk['drop_price_change_amount'].max()   # 降价上限
+                            drop_price_change_lower = df_match_chk['drop_price_change_amount'].min()   # 降价下限
+                            df_min_hours.loc[idx, 'drop_price_change_upper'] = round(drop_price_change_upper, 2)
+                            df_min_hours.loc[idx, 'drop_price_change_lower'] = round(drop_price_change_lower, 2)
+
+                            drop_mode_values = df_match_chk['drop_price_change_amount'].mode()  # 降价众数
+                            if len(drop_mode_values) > 0:
+                                df_min_hours.loc[idx, 'drop_price_change_mode'] = round(float(drop_mode_values[0]), 2)
+
                             remaining_hours = (
                                 pd.to_numeric(df_match_chk['high_price_duration_hours'], errors='coerce') - float(dur_base)
                             ).clip(lower=0)
@@ -1151,8 +1221,8 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                             df_min_hours.loc[idx, 'simple_drop_in_hours'] = top_hours
                             df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 1
                             df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = dist_str
+                            df_min_hours.loc[idx, 'flag_dist'] = 'd0'
 
-                            length_drop = df_match_chk.shape[0]
                             # continue   # 已经判定降价 后面不再做
                 
                 # 历史上未出现的极近似的增长幅度后的降价场景
@@ -1187,22 +1257,22 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                     #             df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = '0h->0.3'
                     #             continue  # 已经判定降价 后面不再做
                             
-        # 针对历史上发生 一直低价、一直高价、低价->高价、连续低价 等 
-        if not df_keep_nodes.empty:
+        # 针对历史上发生的 连续涨价 
+        if not df_rise_nodes.empty:
             # 对准航班号, 不同起飞日期
             if flight_number_2 and flight_number_2 != 'VJ':
-                df_keep_nodes_part = df_keep_nodes[
-                    (df_keep_nodes['city_pair'] == city_pair) &
-                    (df_keep_nodes['flight_number_1'] == flight_number_1) &
-                    (df_keep_nodes['flight_number_2'] == flight_number_2)
+                df_rise_nodes_part = df_rise_nodes[
+                    (df_rise_nodes['city_pair'] == city_pair) &
+                    (df_rise_nodes['flight_number_1'] == flight_number_1) &
+                    (df_rise_nodes['flight_number_2'] == flight_number_2)
                 ]
             else:
-                df_keep_nodes_part = df_keep_nodes[
-                    (df_keep_nodes['city_pair'] == city_pair) &
-                    (df_keep_nodes['flight_number_1'] == flight_number_1)
+                df_rise_nodes_part = df_rise_nodes[
+                    (df_rise_nodes['city_pair'] == city_pair) &
+                    (df_rise_nodes['flight_number_1'] == flight_number_1)
                 ]
 
-            if not df_keep_nodes_part.empty and pd.notna(price_change_percent):
+            if not df_rise_nodes_part.empty and pd.notna(price_change_percent):
                 # pct_vals_1 = df_keep_nodes_part['keep_price_change_percent'].replace([np.inf, -np.inf], np.nan).dropna()
                 # # 保留百分位 10% ~ 90% 之间的 数据
                 # if not pct_vals_1.empty:
@@ -1275,68 +1345,84 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
 
                 # 一般判定场景
                 pct_base_1 = float(price_change_percent)
-                pct_vals_1 = pd.to_numeric(df_keep_nodes_part['keep_price_change_percent'], errors='coerce')
-                df_keep_gap_1 = df_keep_nodes_part.loc[
+                pct_vals_1 = pd.to_numeric(df_rise_nodes_part['prev_rise_change_percent'], errors='coerce')
+                df_rise_gap_1 = df_rise_nodes_part.loc[
                     pct_vals_1.notna(),
-                    ['keep_hours_until_departure', 'keep_price_duration_hours', 'keep_price_change_percent', 
-                     'keep_price_change_amount', 'keep_price_amount', 'keep_price_seats_remaining']
+                    ['rise_hours_until_departure', 'rise_price_change_percent', 'rise_price_change_amount',
+                     'prev_rise_duration_hours', 'prev_rise_change_percent', 
+                     'prev_rise_change_amount', 'prev_rise_amount', 'prev_rise_seats_remaining']
                 ].copy()
-                df_keep_gap_1['pct_gap'] = (pct_vals_1.loc[pct_vals_1.notna()] - pct_base_1)
-                df_keep_gap_1['pct_abs_gap'] = df_keep_gap_1['pct_gap'].abs()
+                df_rise_gap_1['pct_gap'] = (pct_vals_1.loc[pct_vals_1.notna()] - pct_base_1)
+                df_rise_gap_1['pct_abs_gap'] = df_rise_gap_1['pct_gap'].abs()
                 
                 price_base_1 = pd.to_numeric(price_amount, errors='coerce')
-                keep_price_vals_1 = pd.to_numeric(df_keep_gap_1['keep_price_amount'], errors='coerce')
-                df_keep_gap_1['price_gap'] = keep_price_vals_1 - price_base_1
-                df_keep_gap_1['price_abs_gap'] = df_keep_gap_1['price_gap'].abs()
+                rise_price_vals_1 = pd.to_numeric(df_rise_gap_1['prev_rise_amount'], errors='coerce')
+                df_rise_gap_1['price_gap'] = rise_price_vals_1 - price_base_1
+                df_rise_gap_1['price_abs_gap'] = df_rise_gap_1['price_gap'].abs()
 
-                df_keep_gap_1 = df_keep_gap_1.sort_values(['pct_abs_gap', 'price_abs_gap'], ascending=[True, True])
-                df_match_1 = df_keep_gap_1.loc[(df_keep_gap_1['pct_abs_gap'] <= pct_threshold_1) & (df_keep_gap_1['price_abs_gap'] <= 10.0)].copy()
+                df_rise_gap_1 = df_rise_gap_1.sort_values(['pct_abs_gap', 'price_abs_gap'], ascending=[True, True])
+                df_match_1 = df_rise_gap_1.loc[(df_rise_gap_1['pct_abs_gap'] <= pct_threshold_1) & (df_rise_gap_1['price_abs_gap'] <= 10.0)].copy()
 
-                # 历史上出现过近似变化幅度后保持低价场景
+                # 历史上出现过近似变化幅度后继续涨价场景
                 if not df_match_1.empty:
-                    df_match_1['hours_delta'] = hours_until_departure - df_match_1['keep_hours_until_departure']
-                    df_match_1['modify_keep_price_duration_hours'] = df_match_1['keep_price_duration_hours'] - df_match_1['hours_delta']
-                    # df_match_1 = df_match_1[df_match_1['modify_keep_price_duration_hours'] > 0]
+                    # df_match_1['hours_delta'] = hours_until_departure - df_match_1['rise_hours_until_departure']
+                    # df_match_1['modify_rise_price_duration_hours'] = df_match_1['rise_price_duration_hours'] - df_match_1['hours_delta']
+                    # df_match_1 = df_match_1[df_match_1['modify_rise_price_duration_hours'] > 0]
 
-                    dur_base_1 = pd.to_numeric(price_duration_hours, errors='coerce')
-                    # hud_base_1 = pd.to_numeric(hours_until_departure, errors='coerce')
+                    # dur_base_1 = pd.to_numeric(price_duration_hours, errors='coerce')
+                    hud_base_1 = pd.to_numeric(hours_until_departure, errors='coerce')
                     # seats_base_1 = pd.to_numeric(seats_remaining_change_amount, errors='coerce')
 
-                    if pd.notna(dur_base_1):   #  and pd.notna(seats_base_1)
+                    if pd.notna(hud_base_1):   #  and pd.notna(seats_base_1)
                         df_match_chk_1 = df_match_1.copy()
-                        dur_vals_1 = pd.to_numeric(df_match_chk_1['modify_keep_price_duration_hours'], errors='coerce')
-                        df_match_chk_1 = df_match_chk_1.loc[dur_vals_1.notna()].copy()
-                        df_match_chk_1 = df_match_chk_1.loc[(dur_vals_1.loc[dur_vals_1.notna()] - float(dur_base_1)).abs() <= 24].copy()
+                        # dur_vals_1 = pd.to_numeric(df_match_chk_1['modify_rise_price_duration_hours'], errors='coerce')
+                        # df_match_chk_1 = df_match_chk_1.loc[dur_vals_1.notna()].copy()
+                        # df_match_chk_1 = df_match_chk_1.loc[(dur_vals_1.loc[dur_vals_1.notna()] - float(dur_base_1)).abs() <= 24].copy()
 
-                        # drop_hud_vals_1 = pd.to_numeric(df_match_chk_1['keep_hours_until_departure'], errors='coerce')
-                        # df_match_chk_1 = df_match_chk_1.loc[drop_hud_vals_1.notna()].copy()
-                        # df_match_chk_1 = df_match_chk_1.loc[(drop_hud_vals_1.loc[drop_hud_vals_1.notna()] - float(hud_base_1)).abs() <= 20].copy()
+                        rise_hud_vals_1 = pd.to_numeric(df_match_chk_1['rise_hours_until_departure'], errors='coerce')
+                        df_match_chk_1 = df_match_chk_1.loc[rise_hud_vals_1.notna()].copy()
+                        df_match_chk_1 = df_match_chk_1.loc[(rise_hud_vals_1.loc[rise_hud_vals_1.notna()] - float(hud_base_1)).abs() <= 24].copy()
 
-                        # seats_vals_1 = pd.to_numeric(df_match_chk_1['keep_seats_remaining_change_amount'], errors='coerce')
+                        # seats_vals_1 = pd.to_numeric(df_match_chk_1['rise_seats_remaining_change_amount'], errors='coerce')
                         # df_match_chk_1 = df_match_chk_1.loc[seats_vals_1.notna()].copy()
                         # df_match_chk_1 = df_match_chk_1.loc[seats_vals_1.loc[seats_vals_1.notna()] == float(seats_base_1)].copy()
 
                         # 持续时间、距离起飞时间、座位变化都匹配上
                         if not df_match_chk_1.empty:
-                            length_keep = df_match_chk_1.shape[0]
+                            length_rise = df_match_chk_1.shape[0]
+                            df_min_hours.loc[idx, 'rise_price_sample_size'] = length_rise
+                            
+                            rise_price_change_upper = df_match_chk_1['rise_price_change_amount'].max()   # 涨价上限
+                            rise_price_change_lower = df_match_chk_1['rise_price_change_amount'].min()   # 涨价下限
+                            df_min_hours.loc[idx, 'rise_price_change_upper'] = round(rise_price_change_upper, 2)
+                            df_min_hours.loc[idx, 'rise_price_change_lower'] = round(rise_price_change_lower, 2)
+
+                            rise_mode_values = df_match_chk_1['rise_price_change_amount'].mode()  # 涨价众数
+                            if len(rise_mode_values) > 0:
+                                df_min_hours.loc[idx, 'rise_price_change_mode'] = round(float(rise_mode_values[0]), 2)
+
                             # 可以明确的判定不降价
                             if length_drop == 0:
                                 df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
-                                df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'k0'
+                                # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'r0'
+                                df_min_hours.loc[idx, 'flag_dist'] = 'r0'
                             # 分歧判定
                             else:
-                                drop_prob = round(length_drop / (length_keep + length_drop), 2)
+                                drop_prob = round(length_drop / (length_rise + length_drop), 2)
                                 # 依旧保持之前的降价判定,概率修改
                                 if drop_prob >= 0.4:
                                     df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
-                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'd1'
+                                    # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'd1'
+                                    df_min_hours.loc[idx, 'flag_dist'] = 'd1'
                                 # 改判不降价,概率修改
                                 else:
                                     df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
-                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'k1'
-                                df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
+                                    # df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'r1'
+                                    df_min_hours.loc[idx, 'flag_dist'] = 'r1'
+
+                                # df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = drop_prob
                                 
                             # elif length_keep == length_drop:   # 不降价与降价相同, 取0.5概率
@@ -1397,15 +1483,19 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     _pred_dt = pd.to_datetime(str(pred_time_str), format="%Y%m%d%H%M", errors="coerce")
     df_min_hours["update_hour"] = _pred_dt.strftime("%Y-%m-%d %H:%M:%S")
     _dep_hour = pd.to_datetime(df_min_hours["from_time"], errors="coerce").dt.floor("h")
-    df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(60, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
+    df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(360, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
     df_min_hours["valid_end_hour"] = (_dep_hour - pd.to_timedelta(12, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
 
+    # 要展示在预测表里的字段
     order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 
                   'baggage', 'seats_remaining', 'currency',
                   'adult_total_price', 'hours_until_departure', 'price_change_percent', 'price_duration_hours', 
                   'update_hour', 'crawl_date',
                   'valid_begin_hour', 'valid_end_hour',
-                  'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist'
+                  'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist',
+                  'flag_dist',
+                  'drop_price_change_upper', 'drop_price_change_mode', 'drop_price_change_lower', 'drop_price_sample_size',
+                  'rise_price_change_upper', 'rise_price_change_mode', 'rise_price_change_lower', 'rise_price_sample_size',
                  ]
     df_predict = df_min_hours[order_cols]
     df_predict = df_predict.rename(columns={

+ 4 - 3
follow_up.py

@@ -328,7 +328,8 @@ def follow_up_handle():
     pass
 
 if __name__ == "__main__":
+    time.sleep(2)
     follow_up_handle()
-    time.sleep(10)
-    from descending_cabin_task import main as descending_cabin_task_main
-    descending_cabin_task_main()
+    # time.sleep(10)
+    # from descending_cabin_task import main as descending_cabin_task_main
+    # descending_cabin_task_main()

+ 11 - 7
main_pe_0.py

@@ -38,20 +38,24 @@ def start_predict():
         except Exception as e:
             print(f"remove {csv_path} info: {str(e)}")
 
-    # 预测时间范围,满足起飞时间 在12小时后到60小时后
+    # 预测时间范围,满足起飞时间 在12小时后到360小时后
     pred_hour_begin = hourly_time + timedelta(hours=12)
-    pred_hour_end = hourly_time + timedelta(hours=60)
+    pred_hour_end = hourly_time + timedelta(hours=360)
 
     pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
     pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")
 
     print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")
 
-    # 主干代码 (排除冷门航线)
-    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
+    # 主干代码 (热门航线 + 冷门航线)
+    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
     flight_route_list_len = len(flight_route_list)
     route_len_hot = len(vj_flight_route_list_hot)
-    route_len_nothot = len(vj_flight_route_list_nothot[:0])
+    route_len_nothot = len(vj_flight_route_list_nothot)
+
+    print(f"flight_route_list_len:{flight_route_list_len}")
+    print(f"route_len_hot:{route_len_hot}")
+    print(f"route_len_nothot:{route_len_nothot}")
 
     group_size = 1              # 每几组作为一个批次
 
@@ -83,9 +87,9 @@ def start_predict():
             print(f"无法确定热门还是冷门, 跳过此批次。")
             continue
 
-        # 加载测试数据 (仅仅是时间段取到后天)
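+        # Load prediction data: reuses load_train_data, but over the prediction departure-date window (pred_date_begin..pred_date_end) and with is_train=False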
         start_time = time.time()
-        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
+        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot, is_train=False,
                                   use_multiprocess=True, max_workers=max_workers)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)

+ 10 - 10
main_tr_0.py

@@ -49,15 +49,15 @@ def start_train():
     # date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
-    date_begin = "2026-02-24"   # 2025-12-01 2026-02-11 2026-02-24 2026-03-02
+    date_begin = "2026-01-01"   # 2025-12-01 2026-02-11 2026-02-24 2026-03-02
 
     print(f"训练时间范围: {date_begin} 到 {date_end}")
 
-    # 主干代码 (排除冷门航线)
-    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot[:0]
+    # 主干代码 (热门航线 + 冷门航线)
+    flight_route_list = vj_flight_route_list_hot + vj_flight_route_list_nothot
     flight_route_list_len = len(flight_route_list)
     route_len_hot = len(vj_flight_route_list_hot)
-    route_len_nothot = len(vj_flight_route_list_nothot[:0])
+    route_len_nothot = len(vj_flight_route_list_nothot)
 
     print(f"flight_route_list_len:{flight_route_list_len}")
     print(f"route_len_hot:{route_len_hot}")
@@ -105,7 +105,7 @@ def start_train():
             print(f"训练数据为空,跳过此批次。")
             continue
 
-        _, df_drop_nodes, df_keep_nodes = preprocess_data_simple(df_train, is_train=True)
+        _, df_drop_nodes, df_rise_nodes = preprocess_data_simple(df_train, is_train=True)
 
         dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
 
@@ -116,12 +116,12 @@ def start_train():
             merge_and_overwrite_csv(df_drop_nodes, drop_info_csv_path, dedup_cols)
             print(f"本批次训练已保存csv文件: {drop_info_csv_path}")
 
-        keep_info_csv_path = os.path.join(output_dir, f'{group_route_str}_keep_info.csv')
-        if df_keep_nodes.empty:
-            print(f"df_keep_nodes 为空,跳过保存: {keep_info_csv_path}")
+        rise_info_csv_path = os.path.join(output_dir, f'{group_route_str}_rise_info.csv')
+        if df_rise_nodes.empty:
+            print(f"df_rise_nodes 为空,跳过保存: {rise_info_csv_path}")
         else:
-            merge_and_overwrite_csv(df_keep_nodes, keep_info_csv_path, dedup_cols)
-            print(f"本批次训练已保存csv文件: {keep_info_csv_path}")
+            merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
+            print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
 
         time.sleep(1)