Procházet zdrojové kódy

新写一份VJ预测准确度验证

node04 před 1 měsícem
rodič
revize
e0fa586524
2 změnil soubory, kde provedl 205 přidání a 2 odebrání
  1. 1 0
      descending_cabin_task.py
  2. 204 2
      result_keep_verify.py

+ 1 - 0
descending_cabin_task.py

@@ -585,6 +585,7 @@ def sync_policy(payload):
     }
     # print(json.dumps(payload, ensure_ascii=False, indent=2))
     response = requests.post(POLICY_URL, headers=headers, json=payload, timeout=30)
+    # print(response.text[:1000])
     resp_json = response.json()
     """
     {

+ 204 - 2
result_keep_verify.py

@@ -171,8 +171,210 @@ def verify_process(min_batch_time_str, max_batch_time_str):
 
     print("检验结束")
     print()
+
+
+def verify_process_2(min_batch_time_str, max_batch_time_str):
+
+    object_dir = "/home/node04/descending_cabin_files"
+
+    output_dir = f"./validate/keep"
+    os.makedirs(output_dir, exist_ok=True)
+
+    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    save_scv = f"result_keep_verify_{timestamp_str}.csv"
+    output_path = os.path.join(output_dir, save_scv)
+
+    # 检查目录是否存在
+    if not os.path.exists(object_dir):
+        print(f"目录不存在: {object_dir}")
+        return
+                    
+    # 获取所有以 keep_info_end_ 开头的 CSV 文件
+    csv_files = []
+    for file in os.listdir(object_dir):
+        if file.startswith("keep_info_end_") and file.endswith(".csv"):
+            csv_files.append(file)
+    
+    if not csv_files:
+        print(f"在 {object_dir} 中没有找到 keep_info_end_ 开头的 CSV 文件")
+        return
+    
+    csv_files.sort()
+
+    min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
+    min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
+    max_batch_dt = datetime.datetime.strptime(max_batch_time_str, "%Y%m%d%H%M")
+    max_batch_dt = max_batch_dt.replace(minute=0, second=0, microsecond=0)
+
+    if min_batch_dt is not None and max_batch_dt is not None and min_batch_dt > max_batch_dt:
+        print(f"时间范围非法: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}),退出")
+        return
+    
+    list_df = []
+
+    # 从所有的 keep_info_end_ 文件中
+    for csv_file in csv_files:
+        batch_time_str = csv_file.replace("keep_info_end_", "").replace(".csv", "")
+        batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M%S")
+        batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
+
+        if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
+            continue
+        if max_batch_dt is not None and batch_hour_dt > max_batch_dt:
+            continue
+
+        # 读取 CSV 文件
+        csv_path = os.path.join(object_dir, csv_file)
+        try:
+            df_keep_info = pd.read_csv(csv_path)
+        except Exception as e:
+            print(f"read {csv_path} error: {str(e)}")
+            continue
+
+        if df_keep_info.empty:
+            print(f"keep_info数据为空: {csv_file}")
+            continue
         
+        df_keep_info["batch_time_str"] = batch_hour_dt.strftime("%Y%m%d%H%M")
+        # df_keep_info["src_file"] = csv_file
+        list_df.append(df_keep_info)
+        del df_keep_info
+
+    if not list_df:
+        print("时间范围内没有可用 keep_info_end_ 数据")
+        return
+    
+    df_keep_all = pd.concat(list_df, ignore_index=True)
+    del list_df
+    
+    sort_cols = ["city_pair", "flight_day", "flight_number_1", "flight_number_2", "into_update_hour"]
+    df_keep_all = df_keep_all.sort_values(sort_cols, kind="mergesort").reset_index(drop=True)
+    df_keep_all["gid"] = df_keep_all.groupby(sort_cols, sort=False).ngroup().astype("int64") + 1
+
+    client, db = mongo_con_parse()
+    list_base_row = []
+
+    for gid, df_gid in df_keep_all.groupby("gid", sort=False):
+        city_pair = df_gid["city_pair"].iloc[0]
+        flight_day = df_gid["flight_day"].iloc[0]
+        flight_number_1 = df_gid["flight_number_1"].iloc[0]
+        flight_number_2 = df_gid["flight_number_2"].iloc[0]
+        into_update_hour = df_gid["into_update_hour"].iloc[0]
+        valid_end_hour = df_gid["valid_end_hour"].iloc[0]
+
+        into_update_dt = pd.to_datetime(
+            df_gid.get("into_update_hour"), format="%Y-%m-%d %H:%M:%S", errors="coerce"
+        ).min()
+        batch_dt = pd.to_datetime(
+            df_gid.get("batch_time_str"), format="%Y%m%d%H%M", errors="coerce"
+        ).max()
+
+        valid_end_dt = pd.to_datetime(valid_end_hour, format="%Y-%m-%d %H:%M:%S", errors="coerce")
+        
+        flag = 0   # 等待(弹出)标记
+        if batch_dt >= valid_end_dt:
+            flag = 2     # 超时标记      
+
+        if pd.isna(into_update_dt) or pd.isna(batch_dt):
+            print(f"gid={gid} 时间字段解析失败,跳过")
+            continue
+
+        crawl_date_begin = (batch_dt + pd.Timedelta(hours=0)).strftime("%Y-%m-%d %H:%M:%S")
+        crawl_date_end = (batch_dt + pd.Timedelta(hours=8)).strftime("%Y-%m-%d %H:%M:%S")
+
+        if city_pair in vj_flight_route_list_hot:
+            table_name_far = CLEAN_VJ_HOT_FAR_INFO_TAB
+            table_name_near = CLEAN_VJ_HOT_NEAR_INFO_TAB
+        elif city_pair in vj_flight_route_list_nothot:
+            table_name_far = CLEAN_VJ_NOTHOT_FAR_INFO_TAB
+            table_name_near = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
+        else:
+            print(f"gid={gid} 城市对{city_pair}不在热门/冷门列表,跳过")
+            continue
+
+        baggage = 0
+        df_query_far = validate_keep_one_line(
+            db,
+            table_name_far,
+            city_pair,
+            flight_day,
+            flight_number_1,
+            flight_number_2,
+            baggage,
+            crawl_date_begin,
+            crawl_date_end,
+        )
+        df_query_near = validate_keep_one_line(
+            db,
+            table_name_near,
+            city_pair,
+            flight_day,
+            flight_number_1,
+            flight_number_2,
+            baggage,
+            crawl_date_begin,
+            crawl_date_end,
+        )
+        df_query = pd.concat([df_query_far, df_query_near], ignore_index=True)
+
+        df_g1 = df_gid.copy()
+        df_g2 = df_query.copy()
+
+        df_g1["_batch_dt"] = pd.to_datetime(
+            df_g1.get("batch_time_str"), format="%Y%m%d%H%M", errors="coerce"
+        )
+
+        last_price = float(df_g1.iloc[-1]["adult_total_price"])
+        df_last_price = df_g1[df_g1["adult_total_price"] == last_price]
+        base_row = df_last_price.iloc[0]
+        # base_pos = int(df_last_price.index[0])
+        base_dt = base_row["_batch_dt"]
+        base_price = float(base_row["adult_total_price"])
+       
+        # drop_pos = pd.NA
+        drop_crawl_date = pd.NA
+        drop_price = pd.NA
+        price_diff = 0.0
+        time_diff_hours = 0.0
+
+        if not df_g2.empty:
+            df_g2["crawl_dt"] = pd.to_datetime(df_g2.get("crawl_date"), errors="coerce")
+            mask_drop = df_g2["adult_total_price"] < base_price
+            if mask_drop.any():
+                drop_row = df_g2.loc[mask_drop].iloc[0]
+                # drop_pos = int(drop_row.name)
+                drop_crawl_date = drop_row.get("crawl_date")
+                drop_price = float(drop_row["adult_total_price"])
+                price_diff = round(base_price - drop_price, 2)
+                time_diff_hours = round(
+                    float((drop_row["crawl_dt"] - base_dt) / pd.Timedelta(hours=1)),
+                    2,
+                )
+                flag = 1  # 发生降价标记
+        
+        base_row_cp = base_row.copy()
+        base_row_cp["end_batch_dt"] = batch_dt
+        base_row_cp["drop_crawl_date"] = drop_crawl_date
+        base_row_cp["drop_price"] = drop_price
+        base_row_cp["price_diff"] = price_diff
+        base_row_cp["time_diff_hours"] = time_diff_hours
+        base_row_cp["flag"] = flag
+        list_base_row.append(base_row_cp)
+
+        del df_g1
+        del df_g2
+        del df_last_price
+        del df_query_far
+        del df_query_near
+        del df_query
+
+    client.close()
+
+    df_base = pd.DataFrame(list_base_row)
+    df_base.to_csv(output_path, header=True, index=False, encoding="utf-8-sig")
+    print(f"输出: {output_path}")
+    return
 
 if __name__ == "__main__":
-    verify_process("202603161800", "202603180800")
-    pass
+    # verify_process("202604021500", "202604030900")
+    verify_process_2("202604021700", "202604031600")