Jelajahi Sumber

提交包络线+降价潜力方案, 新的验证keep_info的函数

node04 2 minggu lalu
induk
melakukan
413a08c177
7 mengubah file dengan 517 tambahan dan 93 penghapusan
  1. 2 0
      .gitignore
  2. 93 0
      data_loader.py
  3. 183 26
      data_preprocess.py
  4. 68 65
      follow_up.py
  5. 1 1
      main_pe_0.py
  6. 13 1
      main_tr_0.py
  7. 157 0
      result_keep_verify.py

+ 2 - 0
.gitignore

@@ -12,4 +12,6 @@ predictions_2/
 predictions_4/
 validate/
 keep_0/
+logs/
+photo_envelope/
 __pycache__/

+ 93 - 0
data_loader.py

@@ -998,6 +998,99 @@ def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_
             time.sleep(sleep_time)
 
 
+def validate_keep_one_line(db, table_name, city_pair, flight_day, flight_number_1, flight_number_2, baggage, update_hour_str,
+                           limit=0, max_retries=3, base_sleep=1.0):
+    """验证keep_info的一行"""
+    city_pair_split = city_pair.split('-')
+    from_city_code = city_pair_split[0]
+    to_city_code = city_pair_split[1]
+    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d") 
+    baggage_str = f"1-{baggage}"
+    if baggage == 0:
+        baggage_str = "-;-;-;-"
+
+    for attempt in range(1, max_retries + 1):
+        try:
+            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询") 
+            # 构建查询条件
+            query_condition = {
+                "from_city_code": from_city_code,
+                "to_city_code": to_city_code,
+                "search_dep_time": flight_day_str,
+                "segments.baggage": baggage_str,
+                "crawl_date": {"$gte": update_hour_str},
+                "segments.0.flight_number": flight_number_1,
+            }
+            # 如果有第二段
+            if flight_number_2 != "VJ":
+                query_condition["segments.1.flight_number"] = flight_number_2
+            print(f"   查询条件: {query_condition}")
+            # 定义要查询的字段
+            projection = {
+                # "_id": 1,
+                "from_city_code": 1,
+                "search_dep_time": 1,
+                "to_city_code": 1,
+                "currency": 1,
+                "adult_price": 1,
+                "adult_tax": 1,
+                "adult_total_price": 1,
+                "seats_remaining": 1,
+                "segments": 1,
+                "source_website": 1,
+                "crawl_date": 1
+            }
+            # 执行查询
+            cursor = db.get_collection(table_name).find(
+                query_condition,
+                projection=projection  # 添加投影参数
+            ).sort(
+                [
+                    ("crawl_date", 1)
+                ]
+            )
+            if limit > 0:
+                cursor = cursor.limit(limit)
+
+            # 将结果转换为列表
+            results = list(cursor)
+            print(f"✅ 查询成功,找到 {len(results)} 条记录")
+
+            if results:
+                df = pd.DataFrame(results)
+                # 处理特殊的 ObjectId 类型
+                if '_id' in df.columns:
+                    df = df.drop(columns=['_id'])
+                print(f"📊 已转换为 DataFrame,形状: {df.shape}")
+
+                # 1️⃣ 展开 segments
+                print(f"📊 开始扩展segments 稍等...")
+                t1 = time.time()
+                df = expand_segments_columns_optimized(df)
+                t2 = time.time()
+                rt = round(t2 - t1, 3)
+                print(f"用时: {rt} 秒")
+                print(f"📊 已将segments扩展成字段,形状: {df.shape}")
+
+                # 不用排序,因为mongo语句已经排好
+                return df
+
+            else:
+                print("⚠️  查询结果为空")
+                return pd.DataFrame()
+
+        except (ServerSelectionTimeoutError, PyMongoError) as e:
+            print(f"⚠️ Mongo 查询失败: {e}")
+            if attempt == max_retries:
+                print("❌ 达到最大重试次数,放弃")
+                return pd.DataFrame()
+            
+            # 指数退避 + 随机抖动
+            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
+            print(f"⏳ {sleep_time:.2f}s 后重试...")
+            time.sleep(sleep_time)
+
+
 if __name__ == "__main__":
 
     # test_mongo_connection(db)

+ 183 - 26
data_preprocess.py

@@ -996,6 +996,16 @@ def preprocess_data_simple(df_input, is_train=False):
         ]
         df_rise_nodes = df_rise_nodes[flight_info_cols + rise_info_cols]
 
+        # 制作历史包络线
+        envelope_group = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
+        idx_peak = df_input.groupby(envelope_group)['adult_total_price'].idxmax()
+        df_envelope = df_input.loc[idx_peak, envelope_group + [
+            'adult_total_price', 'hours_until_departure'
+        ]].rename(columns={
+            'adult_total_price': 'peak_price',
+            'hours_until_departure': 'peak_hours',
+        }).reset_index(drop=True)
+
         # 对于没有先升后降的gid进行分析
         # gids_with_drop = df_target.loc[drop_mask, 'gid'].unique()
         # df_no_drop = df_target[~df_target['gid'].isin(gids_with_drop)].copy()
@@ -1048,9 +1058,9 @@ def preprocess_data_simple(df_input, is_train=False):
         del df_target
         # del df_no_drop
 
-        return df_input, df_drop_nodes, df_rise_nodes
+        return df_input, df_drop_nodes, df_rise_nodes, df_envelope
 
-    return df_input, None, None
+    return df_input, None, None, None
 
 
 def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".", pred_time_str=""):
@@ -1089,25 +1099,137 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     else:
         df_rise_nodes = pd.DataFrame()
 
+    # ==================== 跨航班日包络线 + 降价潜力 ====================
+    print(">>> 构建跨航班日价格包络线")
+    flight_key = ['city_pair', 'flight_number_1', 'flight_number_2']
+    day_key = flight_key + ['flight_day']
+
+    # 1. 历史侧:加载训练阶段的峰值数据
+    envelope_csv_path = os.path.join(output_dir, f'{group_route_str}_envelope_info.csv')
+    if os.path.exists(envelope_csv_path):
+        df_hist = pd.read_csv(envelope_csv_path)
+        df_hist = df_hist[day_key + ['peak_price', 'peak_hours']]
+        df_hist['source'] = 'hist'
+    else:
+        df_hist = pd.DataFrame()
+
+    # 2. 未来侧:当前在售价格
+    df_future = df_min_hours[day_key + ['adult_total_price', 'hours_until_departure']].copy().rename(
+        columns={'adult_total_price': 'peak_price', 'hours_until_departure': 'peak_hours'}
+    )
+    df_future['source'] = 'future'
+
+    # 3. 合并包络线数据点
+    df_envelope_all = pd.concat(
+        [x for x in [df_hist, df_future] if not x.empty], ignore_index=True
+    ).drop_duplicates(subset=day_key, keep='last')
+
+    # 4. 包络线统计 + 找高点起飞日
+    df_envelope_agg = df_envelope_all.groupby(flight_key).agg(
+        envelope_max=('peak_price', 'max'),               # 峰值最大 
+        envelope_min=('peak_price', 'min'),               # 峰值最小
+        envelope_mean=('peak_price', 'mean'),             # 峰值平均
+        envelope_count=('peak_price', 'count'),           # 峰值统计总数
+        envelope_avg_peak_hours=('peak_hours', 'mean'),   # 峰值发生的距离起飞小时数, 做一下平均
+    ).reset_index()
+
+    # 对数值列保留两位小数
+    df_envelope_agg[['envelope_mean', 'envelope_avg_peak_hours']] = df_envelope_agg[['envelope_mean', 'envelope_avg_peak_hours']].round(2)
+
+    idx_top = df_envelope_all.groupby(flight_key)['peak_price'].idxmax()
+    df_top = df_envelope_all.loc[idx_top, flight_key + ['flight_day', 'peak_price', 'peak_hours']].rename(
+        columns={'flight_day': 'target_flight_day', 'peak_price': 'target_price', 'peak_hours': 'target_peak_hours'}
+    )
+    df_envelope_agg = df_envelope_agg.merge(df_top, on=flight_key, how='left')
+
+    # 5. 合并到 df_min_hours
+    df_min_hours = df_min_hours.merge(df_envelope_agg, on=flight_key, how='left')
+    price_range = (df_min_hours['envelope_max'] - df_min_hours['envelope_min']).replace(0, 1)    # 计算当前价格在包络区间的百分位
+    df_min_hours['envelope_position'] = (
+        (df_min_hours['adult_total_price'] - df_min_hours['envelope_min']) / price_range
+    ).clip(0, 1).round(4)
+    df_min_hours['is_envelope_peak'] = (df_min_hours['envelope_position'] >= 0.75).astype(int)   # 0.95 -> 0.75
+    df_min_hours['is_target_day'] = (df_min_hours['flight_day'] == df_min_hours['target_flight_day']).astype(int)
+
+    # ==================== 目标二:降价潜力评分 ====================
+    # 用“上涨后回落倾向”替代简单计数:drop / (drop + rise)
+    # drop_count 来自 _drop_info.csv(上涨段后转跌),rise_count 来自 _rise_info.csv(上涨段后继续涨)
+    df_min_hours['drop_potential'] = 0.0
+
+    # 先保证相关列一定存在,避免后续选列 KeyError
+    # df_min_hours['drop_freq_count'] = 0.0
+    # df_min_hours['rise_freq_count'] = 0.0
+
+    df_drop_freq = pd.DataFrame(columns=flight_key + ['drop_freq_count'])
+    df_rise_freq = pd.DataFrame(columns=flight_key + ['rise_freq_count'])
+
+    if not df_drop_nodes.empty:
+        df_drop_freq = (
+            df_drop_nodes.groupby(flight_key)
+            .size()
+            .reset_index(name='drop_freq_count')
+        )
+
+    if not df_rise_nodes.empty:
+        df_rise_freq = (
+            df_rise_nodes.groupby(flight_key)
+            .size()
+            .reset_index(name='rise_freq_count')
+        )
+
+    if (not df_drop_freq.empty) or (not df_rise_freq.empty):
+        df_min_hours = df_min_hours.merge(df_drop_freq, on=flight_key, how='left')
+        df_min_hours = df_min_hours.merge(df_rise_freq, on=flight_key, how='left')
+
+        df_min_hours['drop_freq_count'] = df_min_hours['drop_freq_count'].fillna(0).astype(float)
+        df_min_hours['rise_freq_count'] = df_min_hours['rise_freq_count'].fillna(0).astype(float)
+        
+        # 轻微平滑,避免样本很少时出现 0/0 或过度极端
+        alpha = 1.0
+        denom = df_min_hours['drop_freq_count'] + df_min_hours['rise_freq_count'] + 2.0 * alpha
+        df_min_hours['drop_potential'] = (
+            (df_min_hours['drop_freq_count'] + alpha) / denom.replace(0, np.nan)
+        ).fillna(0.0).clip(0, 1).round(4)
+        
+    # ==================== 综合评分:包络高位 × 降价潜力 ====================
+    # target_score = 包络位置(越高越好)× 降价潜力(越高越好)
+    thres_ep = 0.7
+    thres_dp = 0.3
+    df_min_hours['target_score'] = (
+        df_min_hours['envelope_position'] * thres_ep + df_min_hours['drop_potential'] * thres_dp
+    ).round(4)
+
+    # 综合评分阈值:大于阈值的都认为值得投放
+    target_score_threshold = 0.75
+    # df_min_hours['target_score_threshold'] = target_score_threshold
+    df_min_hours['is_good_target'] = (df_min_hours['target_score'] >= target_score_threshold).astype(int)
+
+    print(f">>> 包络线+降价潜力评分完成")
+    del df_hist, df_future, df_envelope_all, df_envelope_agg, df_top, df_drop_freq, df_rise_freq
+    
+    df_min_hours = df_min_hours[df_min_hours['is_good_target'] == 1].reset_index(drop=True)   # 保留值得投放的 
+
+    # =====================================================================
+
     df_min_hours['simple_will_price_drop'] = 0   
     df_min_hours['simple_drop_in_hours'] = 0
     df_min_hours['simple_drop_in_hours_prob'] = 0.0
     df_min_hours['simple_drop_in_hours_dist'] = ''   # 空串 表示未知
     df_min_hours['flag_dist'] = ''
     df_min_hours['drop_price_change_upper'] = 0.0
-    df_min_hours['drop_price_change_mode'] = 0.0
+    # df_min_hours['drop_price_change_mode'] = 0.0
     df_min_hours['drop_price_change_lower'] = 0.0
     df_min_hours['drop_price_sample_size'] = 0
     df_min_hours['rise_price_change_upper'] = 0.0
-    df_min_hours['rise_price_change_mode'] = 0.0
+    # df_min_hours['rise_price_change_mode'] = 0.0
     df_min_hours['rise_price_change_lower'] = 0.0
     df_min_hours['rise_price_sample_size'] = 0
 
     # 这个阈值取多少?
-    pct_threshold = 0.001
+    pct_threshold = 0.01
     # pct_threshold = 2
-    pct_threshold_1 = 0.001
-    pct_threshold_c = 0.001
+    pct_threshold_1 = 0.01
+    # pct_threshold_c = 0.001
 
     for idx, row in df_min_hours.iterrows(): 
         city_pair = row['city_pair']
@@ -1180,9 +1302,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                         # df_match_chk = df_match_chk.loc[dur_vals.notna()].copy()
                         # df_match_chk = df_match_chk.loc[(dur_vals.loc[dur_vals.notna()] - float(dur_base)).abs() <= 36].copy()
 
-                        drop_hud_vals = pd.to_numeric(df_match_chk['drop_hours_until_departure'], errors='coerce')
-                        df_match_chk = df_match_chk.loc[drop_hud_vals.notna()].copy()
-                        df_match_chk = df_match_chk.loc[(drop_hud_vals.loc[drop_hud_vals.notna()] - float(hud_base)).abs() <= 24].copy()
+                        # drop_hud_vals = pd.to_numeric(df_match_chk['drop_hours_until_departure'], errors='coerce')
+                        # df_match_chk = df_match_chk.loc[drop_hud_vals.notna()].copy()
+                        # df_match_chk = df_match_chk.loc[(drop_hud_vals.loc[drop_hud_vals.notna()] - float(hud_base)).abs() <= 24].copy()
 
                         # seats_vals = pd.to_numeric(df_match_chk['high_price_seats_remaining_change_amount'], errors='coerce')
                         # df_match_chk = df_match_chk.loc[seats_vals.notna()].copy()
@@ -1198,9 +1320,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                             df_min_hours.loc[idx, 'drop_price_change_upper'] = round(drop_price_change_upper, 2)
                             df_min_hours.loc[idx, 'drop_price_change_lower'] = round(drop_price_change_lower, 2)
 
-                            drop_mode_values = df_match_chk['drop_price_change_amount'].mode()  # 降价众数
-                            if len(drop_mode_values) > 0:
-                                df_min_hours.loc[idx, 'drop_price_change_mode'] = round(float(drop_mode_values[0]), 2)
+                            # drop_mode_values = df_match_chk['drop_price_change_amount'].mode()  # 降价众数
+                            # if len(drop_mode_values) > 0:
+                            #     df_min_hours.loc[idx, 'drop_price_change_mode'] = round(float(drop_mode_values[0]), 2)
 
                             remaining_hours = (
                                 pd.to_numeric(df_match_chk['high_price_duration_hours'], errors='coerce') - float(dur_base)
@@ -1379,9 +1501,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                         # df_match_chk_1 = df_match_chk_1.loc[dur_vals_1.notna()].copy()
                         # df_match_chk_1 = df_match_chk_1.loc[(dur_vals_1.loc[dur_vals_1.notna()] - float(dur_base_1)).abs() <= 24].copy()
 
-                        rise_hud_vals_1 = pd.to_numeric(df_match_chk_1['rise_hours_until_departure'], errors='coerce')
-                        df_match_chk_1 = df_match_chk_1.loc[rise_hud_vals_1.notna()].copy()
-                        df_match_chk_1 = df_match_chk_1.loc[(rise_hud_vals_1.loc[rise_hud_vals_1.notna()] - float(hud_base_1)).abs() <= 24].copy()
+                        # rise_hud_vals_1 = pd.to_numeric(df_match_chk_1['rise_hours_until_departure'], errors='coerce')
+                        # df_match_chk_1 = df_match_chk_1.loc[rise_hud_vals_1.notna()].copy()
+                        # df_match_chk_1 = df_match_chk_1.loc[(rise_hud_vals_1.loc[rise_hud_vals_1.notna()] - float(hud_base_1)).abs() <= 24].copy()
 
                         # seats_vals_1 = pd.to_numeric(df_match_chk_1['rise_seats_remaining_change_amount'], errors='coerce')
                         # df_match_chk_1 = df_match_chk_1.loc[seats_vals_1.notna()].copy()
@@ -1397,9 +1519,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                             df_min_hours.loc[idx, 'rise_price_change_upper'] = round(rise_price_change_upper, 2)
                             df_min_hours.loc[idx, 'rise_price_change_lower'] = round(rise_price_change_lower, 2)
 
-                            rise_mode_values = df_match_chk_1['rise_price_change_amount'].mode()  # 涨价众数
-                            if len(rise_mode_values) > 0:
-                                df_min_hours.loc[idx, 'rise_price_change_mode'] = round(float(rise_mode_values[0]), 2)
+                            # rise_mode_values = df_match_chk_1['rise_price_change_amount'].mode()  # 涨价众数
+                            # if len(rise_mode_values) > 0:
+                            #     df_min_hours.loc[idx, 'rise_price_change_mode'] = round(float(rise_mode_values[0]), 2)
 
                             # 可以明确的判定不降价
                             if length_drop == 0:
@@ -1478,7 +1600,27 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                     #             df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
                     #             df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'n1'
                 pass
-    print("判定过程结束")
+    print("判定循环结束")
+    # 按航班号统一其降价/涨价的上限与下限, 上限统一取最大, 下限统一取最小
+    # _grp_cols = ['city_pair', 'flight_number_1', 'flight_number_2']
+    # _g = df_min_hours.groupby(_grp_cols, dropna=False)
+    # df_min_hours['drop_price_change_upper'] = pd.to_numeric(
+    #     _g['drop_price_change_upper'].transform('max'),
+    #     errors='coerce'
+    # ).fillna(0.0).round(2)
+    # df_min_hours['drop_price_change_lower'] = pd.to_numeric(
+    #     _g['drop_price_change_lower'].transform('min'),
+    #     errors='coerce'
+    # ).fillna(0.0).round(2)
+    # df_min_hours['rise_price_change_upper'] = pd.to_numeric(
+    #     _g['rise_price_change_upper'].transform('max'),
+    #     errors='coerce'
+    # ).fillna(0.0).round(2)
+    # df_min_hours['rise_price_change_lower'] = pd.to_numeric(
+    #     _g['rise_price_change_lower'].transform('min'),
+    #     errors='coerce'
+    # ).fillna(0.0).round(2)
+    
     df_min_hours = df_min_hours.rename(columns={'seg1_dep_time': 'from_time'})
     _pred_dt = pd.to_datetime(str(pred_time_str), format="%Y%m%d%H%M", errors="coerce")
     df_min_hours["update_hour"] = _pred_dt.strftime("%Y-%m-%d %H:%M:%S")
@@ -1494,8 +1636,13 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                   'valid_begin_hour', 'valid_end_hour',
                   'simple_will_price_drop', 'simple_drop_in_hours', 'simple_drop_in_hours_prob', 'simple_drop_in_hours_dist',
                   'flag_dist',
-                  'drop_price_change_upper', 'drop_price_change_mode', 'drop_price_change_lower', 'drop_price_sample_size',
-                  'rise_price_change_upper', 'rise_price_change_mode', 'rise_price_change_lower', 'rise_price_sample_size',
+                  'drop_price_change_upper', 'drop_price_change_lower', 'drop_price_sample_size',
+                  'rise_price_change_upper', 'rise_price_change_lower', 'rise_price_sample_size',
+                  'envelope_max', 'envelope_min', 'envelope_mean', 'envelope_count',
+                  'envelope_avg_peak_hours', 'envelope_position', 'is_envelope_peak',         # 包络线特征
+                  'target_flight_day', 'target_price', 'target_peak_hours', 'is_target_day',  # 高点起飞日(纯包络线高点)
+                  'drop_freq_count', 'drop_potential',                                        # 降价潜力 
+                  'target_score', 'is_good_target',                                           # 综合目标评分(高点 × 降价潜力 = 最终投放目标) 
                  ]
     df_predict = df_min_hours[order_cols]
     df_predict = df_predict.rename(columns={
@@ -1513,15 +1660,25 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         na_position='last',
     ).reset_index(drop=True)
 
-    # 时间段过滤 过滤掉异常时间(update_hour 早于 crawl_date)因为现在有实时验价, 不做8小时之内的过滤
+    # 时间段过滤 过滤掉异常时间(update_hour 早于 crawl_date)
     update_dt = pd.to_datetime(df_predict["update_hour"], errors="coerce")
     crawl_dt = pd.to_datetime(df_predict["crawl_date"], errors="coerce")
     dt_diff = update_dt - crawl_dt
     df_predict = df_predict.loc[
-        # (dt_diff >= pd.Timedelta(0)) & (dt_diff <= pd.Timedelta(hours=8))
-        (dt_diff >= pd.Timedelta(0))
+        (dt_diff >= pd.Timedelta(0)) & (dt_diff <= pd.Timedelta(hours=12))
+        # (dt_diff >= pd.Timedelta(0))
     ].reset_index(drop=True)
-    print("更新时间过滤")
+    print("更新时间过滤完成")
+
+    total_cnt = len(df_predict)
+    if "will_price_drop" in df_predict.columns:
+        _wpd = pd.to_numeric(df_predict["will_price_drop"], errors="coerce")
+        drop_1_cnt = int((_wpd == 1).sum())
+        drop_0_cnt = int((_wpd == 0).sum())
+    else:
+        drop_1_cnt = 0
+        drop_0_cnt = 0
+    print(f"will_price_drop 分类数量统计: 1(会降)={drop_1_cnt}, 0(不降)={drop_0_cnt}, 总数={total_cnt}")
 
     csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
     df_predict.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')

+ 68 - 65
follow_up.py

@@ -31,7 +31,7 @@ def follow_up_handle():
     csv_files.sort()
     
     # 调试分支
-    # target_time = "202603011600"
+    # target_time = "202603131400"
     # matching_files = [f for f in csv_files if target_time in f]
     # if matching_files:
     #     last_csv_file = matching_files[0]
@@ -64,9 +64,9 @@ def follow_up_handle():
     df_last_predict_will_drop = df_last_predict_will_drop.drop_duplicates(
         subset=key_cols, keep="last"
     ).reset_index(drop=True)
-    df_last_predict_not_drop = df_last_predict_not_drop.drop_duplicates(
-        subset=key_cols, keep="last"
-    ).reset_index(drop=True)
+    # df_last_predict_not_drop = df_last_predict_not_drop.drop_duplicates(
+    #     subset=key_cols, keep="last"
+    # ).reset_index(drop=True)
 
     # 读取维护表
     if os.path.exists(keep_info_path):
@@ -120,6 +120,7 @@ def follow_up_handle():
     # 初始化维护表
     if df_keep_info.empty:
         df_keep_info = df_last_predict_will_drop.copy()
+        df_keep_info["into_update_hour"] = df_keep_info['update_hour']
         df_keep_info["keep_flag"] = 1
         # df_keep_info["last_predict_time"] = target_time
 
@@ -159,7 +160,7 @@ def follow_up_handle():
         
         for c in key_cols:
             df_last_predict_will_drop[c] = df_last_predict_will_drop[c].astype(str)
-            df_last_predict_not_drop[c] = df_last_predict_not_drop[c].astype(str)
+            # df_last_predict_not_drop[c] = df_last_predict_not_drop[c].astype(str)
             df_keep_info[c] = df_keep_info[c].astype(str)
 
         df_keep_info = df_keep_info.drop_duplicates(subset=key_cols, keep="last").reset_index(drop=True)
@@ -179,6 +180,7 @@ def follow_up_handle():
         )
         # keep_flag 设为 1
         if not df_to_add.empty:
+            df_to_add['into_update_hour'] = df_to_add['update_hour']
             df_to_add["keep_flag"] = 1
         
         df_keep_with_merge = df_keep_info.reset_index().merge(
@@ -231,70 +233,71 @@ def follow_up_handle():
                     # new_hud = hud - 1
                     new_hud = hud - hud_decrement
                     df_keep_info.loc[mask_need_observe, "hours_until_departure"] = new_hud
-
-                    df_keep_only_keys = df_keep_info.loc[mask_keep_only, key_cols].copy()
-                    df_keep_only_keys["_row_idx"] = df_keep_only_keys.index
-                    # 检查 df_keep_only_keys 是否在 df_last_predict_not_drop 中
-                    df_keep_only_keys = df_keep_only_keys.merge(
-                        df_last_predict_not_drop[key_cols].drop_duplicates(),
-                        on=key_cols,
-                        how="left",
-                        indicator=True,
-                    )
-                    idx_in_not_drop = df_keep_only_keys.loc[
-                        df_keep_only_keys["_merge"] == "both", "_row_idx"
-                    ].tolist()
-                    mask_in_not_drop = df_keep_info.index.isin(idx_in_not_drop)     # 在 df_last_predict_not_drop 中出现 只是will_price_drop为0 未达边界
-                    mask_not_drop_observe = mask_need_observe & mask_in_not_drop    # 判断为不降价的布尔索引数组
-                    mask_boundary_observe = mask_need_observe & ~mask_in_not_drop   # 判断为到达边界的布尔索引数组
-
-                    df_keep_info.loc[mask_not_drop_observe, "keep_flag"] = -1       # 删除标志
-
-                    if mask_boundary_observe.any():
-                        new_hud_full = pd.to_numeric(
-                            df_keep_info["hours_until_departure"], errors="coerce"
-                        )
-                        df_keep_info.loc[mask_boundary_observe, "keep_flag"] = -1    # 默认删除标志
-                        df_keep_info.loc[
-                            mask_boundary_observe & new_hud_full.gt(4), "keep_flag"  # 如果达到边界且hours_until_departure大于4 则给保留标志
-                        ] = 2
+                    df_keep_info.loc[mask_need_observe, "keep_flag"] = -1       # 删除标志
+
+                    # df_keep_only_keys = df_keep_info.loc[mask_keep_only, key_cols].copy()
+                    # df_keep_only_keys["_row_idx"] = df_keep_only_keys.index
+                    # # 检查 df_keep_only_keys 是否在 df_last_predict_not_drop 中
+                    # df_keep_only_keys = df_keep_only_keys.merge(
+                    #     df_last_predict_not_drop[key_cols].drop_duplicates(),
+                    #     on=key_cols,
+                    #     how="left",
+                    #     indicator=True,
+                    # )
+                    # idx_in_not_drop = df_keep_only_keys.loc[
+                    #     df_keep_only_keys["_merge"] == "both", "_row_idx"
+                    # ].tolist()
+                    # mask_in_not_drop = df_keep_info.index.isin(idx_in_not_drop)     # 在 df_last_predict_not_drop 中出现 只是will_price_drop为0 未达边界
+                    # mask_not_drop_observe = mask_need_observe & mask_in_not_drop    # 判断为不降价的布尔索引数组
+                    # mask_boundary_observe = mask_need_observe & ~mask_in_not_drop   # 判断为到达边界的布尔索引数组
+
+                    # df_keep_info.loc[mask_not_drop_observe, "keep_flag"] = -1       # 删除标志
+
+                    # if mask_boundary_observe.any():
+                    #     new_hud_full = pd.to_numeric(
+                    #         df_keep_info["hours_until_departure"], errors="coerce"
+                    #     )
+                    #     df_keep_info.loc[mask_boundary_observe, "keep_flag"] = -1    # 默认删除标志
+                    #     df_keep_info.loc[
+                    #         mask_boundary_observe & new_hud_full.gt(4), "keep_flag"  # 如果达到边界且hours_until_departure大于4 则给保留标志
+                    #     ] = 2
                     
                     pass
         
         # 对于这些边界保持状态(keep_flag为2) 检查其是否在最新一次验价后的文件里存在, 如果不存在 则标记为-1
-        df_temp_2 = df_keep_info.loc[df_keep_info["keep_flag"] == 2, key_cols].copy()
-        if not df_temp_2.empty:
-            end_dir = "/home/node04/descending_cabin_files"
-            end_candidates = []
-            if os.path.isdir(end_dir):
-                for f in os.listdir(end_dir):
-                    if f.startswith("keep_info_end_") and f.endswith(".csv"):
-                        ts = f.replace("keep_info_end_", "").replace(".csv", "")
-                        if ts.isdigit():
-                            end_candidates.append((ts, f))   #(时间戳,文件名)
-            if end_candidates:
-                end_candidates.sort(key=lambda x: x[0])
-                end_last_path = os.path.join(end_dir, end_candidates[-1][1])  # 最新一次验价后的文件
-                try:
-                    df_end_last = pd.read_csv(end_last_path)
-                except Exception:
-                    df_end_last = pd.DataFrame()
-
-                if not df_end_last.empty and all(c in df_end_last.columns for c in key_cols):  # key_cols作为比对条件
-                    df_temp_2["_row_idx"] = df_temp_2.index
-                    df_end_keys = df_end_last[key_cols].drop_duplicates().copy()
-                    for c in key_cols:
-                        df_temp_2[c] = df_temp_2[c].astype(str)
-                        df_end_keys[c] = df_end_keys[c].astype(str)
-                    df_temp_2_with_merge = df_temp_2.merge(
-                        df_end_keys, on=key_cols, how="left", indicator=True
-                    )
-                    # 对于只在 df_temp_2 出现,而不在 df_end_keys 出现的索引,在 df_keep_info 中标记为-1
-                    idx_to_rm_2 = df_temp_2_with_merge.loc[
-                        df_temp_2_with_merge["_merge"] == "left_only", "_row_idx"
-                    ].tolist()
-                    if idx_to_rm_2:
-                        df_keep_info.loc[idx_to_rm_2, "keep_flag"] = -1
+        # df_temp_2 = df_keep_info.loc[df_keep_info["keep_flag"] == 2, key_cols].copy()
+        # if not df_temp_2.empty:
+        #     end_dir = "/home/node04/descending_cabin_files"
+        #     end_candidates = []
+        #     if os.path.isdir(end_dir):
+        #         for f in os.listdir(end_dir):
+        #             if f.startswith("keep_info_end_") and f.endswith(".csv"):
+        #                 ts = f.replace("keep_info_end_", "").replace(".csv", "")
+        #                 if ts.isdigit():
+        #                     end_candidates.append((ts, f))   #(时间戳,文件名)
+        #     if end_candidates:
+        #         end_candidates.sort(key=lambda x: x[0])
+        #         end_last_path = os.path.join(end_dir, end_candidates[-1][1])  # 最新一次验价后的文件
+        #         try:
+        #             df_end_last = pd.read_csv(end_last_path)
+        #         except Exception:
+        #             df_end_last = pd.DataFrame()
+
+        #         if not df_end_last.empty and all(c in df_end_last.columns for c in key_cols):  # key_cols作为比对条件
+        #             df_temp_2["_row_idx"] = df_temp_2.index
+        #             df_end_keys = df_end_last[key_cols].drop_duplicates().copy()
+        #             for c in key_cols:
+        #                 df_temp_2[c] = df_temp_2[c].astype(str)
+        #                 df_end_keys[c] = df_end_keys[c].astype(str)
+        #             df_temp_2_with_merge = df_temp_2.merge(
+        #                 df_end_keys, on=key_cols, how="left", indicator=True
+        #             )
+        #             # 对于只在 df_temp_2 出现,而不在 df_end_keys 出现的索引,在 df_keep_info 中标记为-1
+        #             idx_to_rm_2 = df_temp_2_with_merge.loc[
+        #                 df_temp_2_with_merge["_merge"] == "left_only", "_row_idx"
+        #             ].tolist()
+        #             if idx_to_rm_2:
+        #                 df_keep_info.loc[idx_to_rm_2, "keep_flag"] = -1
 
         # 将 df_to_add 添加到 df_keep_info 之后
         add_rows = len(df_to_add) if "df_to_add" in locals() else 0

+ 1 - 1
main_pe_0.py

@@ -115,7 +115,7 @@ def start_predict():
             print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
             continue
         
-        df_test_inputs, _, _, = preprocess_data_simple(df_test)
+        df_test_inputs, _, _, _, = preprocess_data_simple(df_test)
 
         df_predict = predict_data_simple(df_test_inputs, group_route_str, output_dir, predict_dir, hourly_time_str)
         

+ 13 - 1
main_tr_0.py

@@ -1,5 +1,6 @@
 import os
 import time
+import gc
 import pandas as pd
 from datetime import datetime, timedelta
 from config import mongodb_config, vj_flight_route_list_hot, vj_flight_route_list_nothot, CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
@@ -105,7 +106,7 @@ def start_train():
             print(f"训练数据为空,跳过此批次。")
             continue
 
-        _, df_drop_nodes, df_rise_nodes = preprocess_data_simple(df_train, is_train=True)
+        _, df_drop_nodes, df_rise_nodes, df_envelope = preprocess_data_simple(df_train, is_train=True)
 
         dedup_cols = ['city_pair', 'flight_number_1', 'flight_number_2', 'flight_day']
 
@@ -123,6 +124,17 @@ def start_train():
             merge_and_overwrite_csv(df_rise_nodes, rise_info_csv_path, dedup_cols)
             print(f"本批次训练已保存csv文件: {rise_info_csv_path}")
 
+        envelope_csv_path = os.path.join(output_dir, f'{group_route_str}_envelope_info.csv')
+        if not df_envelope.empty:
+            merge_and_overwrite_csv(df_envelope, envelope_csv_path, dedup_cols)
+            print(f"本批次训练已保存csv文件: {envelope_csv_path}")
+
+        del df_drop_nodes
+        del df_rise_nodes
+        del df_envelope
+
+        gc.collect()
+
         time.sleep(1)
 
     print(f"所有批次训练已完成")

+ 157 - 0
result_keep_verify.py

@@ -0,0 +1,157 @@
+import os
+import datetime
+import pandas as pd
+from data_loader import mongo_con_parse, validate_keep_one_line, fill_hourly_crawl_date
+from config import vj_flight_route_list_hot, vj_flight_route_list_nothot, \
+    CLEAN_VJ_HOT_NEAR_INFO_TAB, CLEAN_VJ_HOT_FAR_INFO_TAB, CLEAN_VJ_NOTHOT_NEAR_INFO_TAB, CLEAN_VJ_NOTHOT_FAR_INFO_TAB
+
+
def _validate_keep_info_df(df_keep_info_part):
    """Validate rows removed from keep_info against later MongoDB crawl data.

    For each row, query the far-term and near-term crawl tables for prices
    observed after the row's ``into_update_hour``.  If any observed price is
    strictly below the row's entry price, record the first such drop:
    ``price_diff`` = entry price - observed price, and ``time_diff_hours`` =
    hours between entry time and the drop.  Rows with no drop keep 0 in both
    columns.

    Args:
        df_keep_info_part: DataFrame holding the keep_info rows to validate.
            Expected columns: city_pair, flight_day, flight_number_1,
            flight_number_2, baggage, into_update_hour, adult_total_price.

    Returns:
        The same DataFrame (mutated in place) with ``price_diff`` and
        ``time_diff_hours`` filled in.
    """
    client, db = mongo_con_parse()
    count = 0

    # Make sure the result columns exist; each row is (re)initialised below.
    if "price_diff" not in df_keep_info_part.columns:
        df_keep_info_part["price_diff"] = 0
    if "time_diff_hours" not in df_keep_info_part.columns:
        df_keep_info_part["time_diff_hours"] = 0

    for idx, row in df_keep_info_part.iterrows():
        # Default: no later price drop found for this row.
        df_keep_info_part.at[idx, "price_diff"] = 0
        df_keep_info_part.at[idx, "time_diff_hours"] = 0

        city_pair = row['city_pair']
        flight_day = row['flight_day']
        flight_number_1 = row['flight_number_1']
        flight_number_2 = row['flight_number_2']
        baggage = row['baggage']
        into_update_hour = row['into_update_hour']
        into_update_dt = pd.to_datetime(into_update_hour, format='%Y-%m-%d %H:%M:%S')

        entry_price = pd.to_numeric(row.get('adult_total_price'), errors='coerce')

        # Pick the crawl tables for this route.  A route in neither list used
        # to fall through and raise NameError on table_name_far; skip instead.
        if city_pair in vj_flight_route_list_hot:
            table_name_far = CLEAN_VJ_HOT_FAR_INFO_TAB
            table_name_near = CLEAN_VJ_HOT_NEAR_INFO_TAB
        elif city_pair in vj_flight_route_list_nothot:
            table_name_far = CLEAN_VJ_NOTHOT_FAR_INFO_TAB
            table_name_near = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
        else:
            print(f"未知航线 city_pair={city_pair},跳过该行")
            count += 1
            continue

        # Query both the far-term and near-term tables, then merge the hits.
        df_query_far = validate_keep_one_line(db, table_name_far, city_pair, flight_day, flight_number_1, flight_number_2, baggage, into_update_hour)
        df_query_near = validate_keep_one_line(db, table_name_near, city_pair, flight_day, flight_number_1, flight_number_2, baggage, into_update_hour)
        df_query = pd.concat([df_query_far, df_query_near]).reset_index(drop=True)

        if (not df_query.empty) and pd.notna(entry_price):
            if ("adult_total_price" in df_query.columns) and ("crawl_date" in df_query.columns):
                df_query["adult_total_price"] = pd.to_numeric(df_query["adult_total_price"], errors="coerce")
                df_query["crawl_dt"] = pd.to_datetime(df_query["crawl_date"], errors="coerce")
                # Drop unparseable rows and order chronologically so that the
                # first match below is the earliest observed drop.
                df_query = (
                    df_query.dropna(subset=["adult_total_price", "crawl_dt"])
                    .sort_values("crawl_dt")
                    .reset_index(drop=True)
                )
                mask_drop = df_query["adult_total_price"] < entry_price
                if mask_drop.any():
                    first_row = df_query.loc[mask_drop].iloc[0]
                    price_diff = entry_price - first_row["adult_total_price"]
                    time_diff_hours = (first_row["crawl_dt"] - into_update_dt) / pd.Timedelta(hours=1)
                    df_keep_info_part.at[idx, "price_diff"] = round(float(price_diff), 2)
                    df_keep_info_part.at[idx, "time_diff_hours"] = round(float(time_diff_hours), 2)

        # Release the per-row query frames before the next iteration.
        del df_query, df_query_far, df_query_near

        count += 1
        if count % 5 == 0:
            print(f"cal count: {count}")

    print(f"计算结束")
    client.close()

    return df_keep_info_part
+
+
def verify_process(min_batch_time_str, max_batch_time_str):
    """Validate deleted keep_info rows for all batches inside a time window.

    Scans ``./keep_0`` for files named ``keep_info_<YYYYMMDDHHMM>.csv``, keeps
    those whose batch hour falls within [min_batch_time_str,
    max_batch_time_str] (both ``%Y%m%d%H%M``, truncated to the hour), runs
    ``_validate_keep_info_df`` on the rows flagged ``keep_flag == -1`` and
    appends the results to a timestamped CSV under ``./validate/keep``.

    Args:
        min_batch_time_str: Inclusive lower bound, format ``%Y%m%d%H%M``.
        max_batch_time_str: Inclusive upper bound, format ``%Y%m%d%H%M``.
    """
    object_dir = "./keep_0"

    output_dir = f"./validate/keep"
    os.makedirs(output_dir, exist_ok=True)

    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_keep_verify_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)

    if not os.path.exists(object_dir):
        print(f"目录不存在: {object_dir}")
        return

    # Collect all keep_info_*.csv files, processed in name (= time) order.
    csv_files = sorted(
        f for f in os.listdir(object_dir)
        if f.startswith("keep_info_") and f.endswith(".csv")
    )
    if not csv_files:
        print(f"在 {object_dir} 中没有找到 keep_info_ 开头的 CSV 文件")
        return

    def _parse_hour(time_str):
        # Parse "%Y%m%d%H%M" and truncate to the hour for range comparison.
        dt = datetime.datetime.strptime(time_str, "%Y%m%d%H%M")
        return dt.replace(minute=0, second=0, microsecond=0)

    min_batch_dt = _parse_hour(min_batch_time_str)
    max_batch_dt = _parse_hour(max_batch_time_str)

    if min_batch_dt > max_batch_dt:
        print(f"时间范围非法: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}),退出")
        return

    for csv_file in csv_files:
        batch_time_str = csv_file.replace("keep_info_", "").replace(".csv", "")
        try:
            batch_hour_dt = _parse_hour(batch_time_str)
        except ValueError:
            # A malformed filename used to abort the whole run; skip it.
            print(f"文件名时间格式非法,跳过: {csv_file}")
            continue

        if batch_hour_dt < min_batch_dt or batch_hour_dt > max_batch_dt:
            continue

        csv_path = os.path.join(object_dir, csv_file)
        try:
            df_keep_info = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df_keep_info = pd.DataFrame()

        if df_keep_info.empty:
            print(f"keep_info数据为空: {csv_file}")
            continue

        # Only rows flagged as deleted (-1) need validation.
        df_keep_info_del = df_keep_info[df_keep_info['keep_flag'] == -1].reset_index(drop=True)
        if df_keep_info_del.empty:
            # Avoid a pointless MongoDB round-trip and an empty CSV append.
            print(f"批次:{batch_time_str} 无待检验数据")
            continue

        df_keep_info_del = _validate_keep_info_df(df_keep_info_del)
        df_keep_info_del['del_batch_time_str'] = batch_time_str

        # Append to the result file, writing the header only on first write.
        write_header = not os.path.exists(output_path)
        df_keep_info_del.to_csv(output_path, mode="a", header=write_header, index=False, encoding="utf-8-sig")
        del df_keep_info_del
        print(f"批次:{batch_time_str} 检验结束")

    print("检验结束")
    print()
+
if __name__ == "__main__":
    # Ad-hoc run over a fixed validation window.
    verify_process("202603121800", "202603131600")