Pārlūkot izejas kodu

提交uo预测结果验证第一版

node04 1 mēnesi atpakaļ
vecāks
revīzija
c1cb459532
5 mainīti faili ar 232 papildinājumiem un 4 dzēšanām
  1. 1 0
      .gitignore
  2. 70 0
      data_loader.py
  3. 157 0
      result_keep_verify.py
  4. 2 2
      run_uo.sh
  5. 2 2
      uo_atlas_import.py

+ 1 - 0
.gitignore

@@ -35,6 +35,7 @@ photo/
 data_shards/
 predictions/
 keep/
+validate/
 
 # 字体文件(体积大,不适合版本控制)
 *.ttf

+ 70 - 0
data_loader.py

@@ -577,6 +577,76 @@ def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=Tru
     return df_all
 
 
def validate_keep_one_line(db, table_name, city_pair, flight_numbers, baggage_weight, from_date, entry_price, update_hour_str, del_batch_std_str,
                           limit=0, max_retries=3, base_sleep=1.0):
    """Fetch the price history slice for one keep_info row, starting at its entry price.

    Queries ``db[table_name]`` for every record matching the flight key
    (city pair / flight numbers / baggage weight / departure date) whose
    ``create_time`` is <= ``del_batch_std_str``, sorts the records
    chronologically, and returns the slice starting at the LAST record whose
    ``price_total`` equals ``entry_price`` (the entry point plus everything
    observed after it).

    Args:
        db: Mongo database handle (anything supporting ``db[table_name]``).
        table_name: collection name to query.
        city_pair, flight_numbers, baggage_weight, from_date: flight key fields.
        entry_price: ``price_total`` value identifying the entry record.
        update_hour_str: unused here; kept for call-site compatibility.
        del_batch_std_str: inclusive upper bound on ``create_time``
            ('%Y-%m-%d %H:%M:%S' string; compared as a string by Mongo).
        limit: if > 0, cap the number of documents fetched from the cursor.
        max_retries: number of query attempts before giving up.
        base_sleep: base seconds for the exponential backoff between retries.

    Returns:
        pd.DataFrame: the matched slice, or an empty DataFrame when nothing
        matches, ``entry_price`` is not numeric, or all retries fail.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")

            collection = db[table_name]
            # Inclusion projection; note that _id is still returned (Mongo default).
            projection = {
                "citypair": 1,
                "flight_numbers": 1,
                "from_date": 1,
                "from_time": 1,
                "create_time": 1,
                "baggage_weight": 1,
                "cabins": 1,
                "ticket_amount": 1,
                "currency": 1,
                "price_base": 1,
                "price_tax": 1,
                "price_total": 1
            }
            query = {
                "citypair": city_pair,
                "flight_numbers": flight_numbers,
                "baggage_weight": baggage_weight,
                "from_date": from_date,
                "create_time": {"$lte": del_batch_std_str}
            }
            cursor = collection.find(query, projection)
            if limit and limit > 0:
                # Fix: honor the previously-unused `limit` parameter to bound the fetch.
                cursor = cursor.limit(limit)
            raw_rows = list(cursor)
            if not raw_rows:
                return pd.DataFrame()

            df = pd.DataFrame(raw_rows)
            # Bail out if any key column is missing from the result set.
            required_cols = ('flight_numbers', 'from_date', 'create_time')
            if df.empty or any(col not in df.columns for col in required_cols):
                return pd.DataFrame()

            # Sort chronologically; rows with unparsable timestamps are dropped.
            df['_create_time_dt'] = pd.to_datetime(df['create_time'], errors='coerce')
            df = df[df['_create_time_dt'].notna()].sort_values('_create_time_dt').reset_index(drop=True)
            if df.empty:
                return pd.DataFrame()

            entry_price_num = pd.to_numeric(entry_price, errors='coerce')
            if pd.isna(entry_price_num):
                return pd.DataFrame()

            # Locate the LAST record whose total price equals the entry price.
            df['_price_total_num'] = pd.to_numeric(df['price_total'], errors='coerce')
            matched_idx = df.index[df['_price_total_num'] == entry_price_num]
            if len(matched_idx) == 0:
                return pd.DataFrame()

            start_idx = matched_idx[-1]
            df = df.iloc[start_idx:].reset_index(drop=True)
            return df.drop(columns=['_create_time_dt', '_price_total_num'])

        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()

            # Exponential backoff plus random jitter before retrying.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)
+
+
 if __name__ == "__main__":
     
     cpu_cores = os.cpu_count()  # 你的系统是72

+ 157 - 0
result_keep_verify.py

@@ -0,0 +1,157 @@
+import os
+import datetime
+import pandas as pd
+from data_loader import mongo_con_parse, validate_keep_one_line
+from config import mongo_config, mongo_table_uo
+
+
def _validate_keep_info_df(df_keep_info_part):
    """Fill ``price_diff`` / ``time_diff_hours`` for removed keep_info rows.

    For each row, re-queries the raw records from the entry price onward via
    ``validate_keep_one_line`` and records the FIRST subsequent record whose
    price drops below the entry price:
      * price_diff      = entry_price - first lower price (2 decimals)
      * time_diff_hours = hours from ``into_update_hour`` to that record's
                          ``create_time`` (2 decimals)
    Rows without such a drop keep 0 / 0.

    Mutates ``df_keep_info_part`` in place and returns it.
    """
    client, db = mongo_con_parse(mongo_config)
    count = 0

    # Make sure the output columns exist before the per-row writes below.
    if "price_diff" not in df_keep_info_part.columns:
        df_keep_info_part["price_diff"] = 0
    if "time_diff_hours" not in df_keep_info_part.columns:
        df_keep_info_part["time_diff_hours"] = 0

    try:
        for idx, row in df_keep_info_part.iterrows():
            # Default: no price drop observed for this row.
            df_keep_info_part.at[idx, "price_diff"] = 0
            df_keep_info_part.at[idx, "time_diff_hours"] = 0

            city_pair = row['citypair']
            flight_numbers = row['flight_numbers']
            baggage_weight = row['baggage_weight']
            from_date = row['from_date']

            into_update_hour = row['into_update_hour']
            into_update_dt = pd.to_datetime(into_update_hour, format='%Y-%m-%d %H:%M:%S')
            # Normalize the batch time to the timestamp format stored in Mongo.
            del_batch_time_str = row['del_batch_time_str']
            del_batch_dt = pd.to_datetime(del_batch_time_str, format='%Y%m%d%H%M')
            del_batch_std_str = del_batch_dt.strftime('%Y-%m-%d %H:%M:%S')

            entry_price = pd.to_numeric(row.get('price_total'), errors='coerce')

            df_query = validate_keep_one_line(db, mongo_table_uo, city_pair, flight_numbers,
                                              baggage_weight, from_date, entry_price,
                                              into_update_hour, del_batch_std_str)

            if (not df_query.empty) and pd.notna(entry_price):
                if ("price_total" in df_query.columns) and ("create_time" in df_query.columns):
                    df_query["price_total"] = pd.to_numeric(df_query["price_total"], errors="coerce")
                    df_query["create_dt"] = pd.to_datetime(df_query["create_time"], errors="coerce")
                    df_query = (
                        df_query.dropna(subset=["price_total", "create_dt"])
                        .sort_values("create_dt")
                        .reset_index(drop=True)
                    )
                    mask_drop = df_query["price_total"] < entry_price
                    if mask_drop.any():
                        # First record that undercuts the entry price.
                        first_row = df_query.loc[mask_drop].iloc[0]
                        price_diff = entry_price - first_row["price_total"]
                        time_diff_hours = (first_row["create_dt"] - into_update_dt) / pd.Timedelta(hours=1)
                        df_keep_info_part.at[idx, "price_diff"] = round(float(price_diff), 2)
                        df_keep_info_part.at[idx, "time_diff_hours"] = round(float(time_diff_hours), 2)

            del df_query

            count += 1
            if count % 5 == 0:
                print(f"cal count: {count}")

        print(f"计算结束")
    finally:
        # Bug fix: close the Mongo client even when a row raises mid-loop;
        # the original leaked the connection on any exception.
        client.close()

    return df_keep_info_part
+
+
def verify_process(min_batch_time_str, max_batch_time_str):
    """Verify removed keep_info rows for every batch file in a time window.

    Scans ./keep for ``keep_info_<YYYYMMDDHHMM>.csv`` files whose batch hour
    falls within [min_batch_time_str, max_batch_time_str] (both 'YYYYMMDDHHMM',
    floored to the hour), re-validates the rows flagged ``keep_flag == -1``
    against Mongo, classifies each row into a ``status_flag``
    (1 = price dropped after removal, 2 = removal at/after valid_end_hour,
    0 = other), and appends the results to
    ./validate/keep/result_keep_verify_<timestamp>.csv.
    """
    object_dir = "./keep"

    output_dir = f"./validate/keep"
    os.makedirs(output_dir, exist_ok=True)

    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_keep_verify_{timestamp_str}.csv"
    output_path = os.path.join(output_dir, save_csv)

    # The source directory must exist.
    if not os.path.exists(object_dir):
        print(f"目录不存在: {object_dir}")
        return

    # Collect all CSV files named keep_info_*.csv.
    csv_files = [f for f in os.listdir(object_dir)
                 if f.startswith("keep_info_") and f.endswith(".csv")]

    if not csv_files:
        print(f"在 {object_dir} 中没有找到 keep_info_ 开头的 CSV 文件")
        return

    csv_files.sort()

    # Floor both window bounds to the hour (strptime raises on bad input,
    # so the values below are never None).
    min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M").replace(minute=0, second=0, microsecond=0)
    max_batch_dt = datetime.datetime.strptime(max_batch_time_str, "%Y%m%d%H%M").replace(minute=0, second=0, microsecond=0)

    if min_batch_dt > max_batch_dt:
        print(f"时间范围非法: min_batch_time_str({min_batch_time_str}) > max_batch_time_str({max_batch_time_str}),退出")
        return

    # Process every keep_info batch file inside the window.
    for csv_file in csv_files:
        batch_time_str = csv_file.replace("keep_info_", "").replace(".csv", "")
        # Bug fix: a filename whose suffix is not a valid timestamp used to
        # raise ValueError and abort the whole run; skip it instead.
        try:
            batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
        except ValueError:
            print(f"跳过无法解析批次时间的文件: {csv_file}")
            continue
        batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)

        if batch_hour_dt < min_batch_dt or batch_hour_dt > max_batch_dt:
            continue

        # Load the batch CSV; treat read failures as an empty batch.
        csv_path = os.path.join(object_dir, csv_file)
        try:
            df_keep_info = pd.read_csv(csv_path)
        except Exception as e:
            print(f"read {csv_path} error: {str(e)}")
            df_keep_info = pd.DataFrame()

        if df_keep_info.empty:
            print(f"keep_info数据为空: {csv_file}")
            continue

        # Robustness: a malformed file without keep_flag would raise KeyError.
        if 'keep_flag' not in df_keep_info.columns:
            print(f"keep_flag列缺失, 跳过: {csv_file}")
            continue

        # Rows flagged -1 were removed from keeping and need verification.
        df_keep_info_del = df_keep_info[df_keep_info['keep_flag'] == -1].reset_index(drop=True)
        df_keep_info_del['del_batch_time_str'] = batch_time_str
        df_keep_info_del = _validate_keep_info_df(df_keep_info_del)

        # Derive status_flag from the price change and from the removal time
        # compared with the end of the validation window.
        price_diff_num = pd.to_numeric(df_keep_info_del.get("price_diff"), errors="coerce").fillna(0)
        del_batch_dt = pd.to_datetime(
            df_keep_info_del.get("del_batch_time_str"), format="%Y%m%d%H%M", errors="coerce"
        )
        valid_end_dt = pd.to_datetime(
            df_keep_info_del.get("valid_end_hour"), format="%Y-%m-%d %H:%M:%S", errors="coerce"
        )
        status_flag = pd.Series(0, index=df_keep_info_del.index, dtype="int64")  # 0 = other
        status_flag.loc[price_diff_num > 0] = 1   # 1 = price dropped after removal
        mask_zero = price_diff_num == 0
        mask_time_ok = mask_zero & del_batch_dt.notna() & valid_end_dt.notna() & (del_batch_dt >= valid_end_dt)
        status_flag.loc[mask_time_ok] = 2   # 2 = removal at/after the validation end (timeout)
        df_keep_info_del["status_flag"] = status_flag

        # Append to one output file; write the header only on first creation.
        write_header = not os.path.exists(output_path)
        df_keep_info_del.to_csv(output_path, mode="a", header=write_header, index=False, encoding="utf-8-sig")
        del df_keep_info_del
        print(f"批次:{batch_time_str} 检验结束")

    print("检验结束")
    print()
+
+
+if __name__ == "__main__":
+    verify_process("202604071700", "202604081400")
+    pass

+ 2 - 2
run_uo.sh

@@ -14,7 +14,7 @@ log "=== 脚本开始执行 ==="
 START_TIME=$(date +%s)
 
 # 启动第一个任务(后台执行)
-/home/node04/anaconda3/bin/python main_pe.py >> $LOG_DIR/prediction.log 2>&1 &
+/home/node04/anaconda3/bin/python -u main_pe.py >> $LOG_DIR/prediction.log 2>&1 &
 PID=$!
 
 log "main_pe.py 已启动,PID=$PID"
@@ -41,6 +41,6 @@ fi
 
 log "开始执行 follow_up.py"
 
-/home/node04/anaconda3/bin/python follow_up.py >> $LOG_DIR/keep.log 2>&1
+/home/node04/anaconda3/bin/python -u follow_up.py >> $LOG_DIR/keep.log 2>&1
 
 log "=== 脚本执行结束 ==="

+ 2 - 2
uo_atlas_import.py

@@ -241,8 +241,8 @@ if __name__ == "__main__":
     create_at_end = current_time.strftime("%Y-%m-%d %H:%M:%S")
     create_at_begin = (current_time - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S")
 
-    # create_at_begin = "2026-03-30 00:00:00"
-    # create_at_end = "2026-03-31 08:59:59"
+    # create_at_begin = "2026-04-07 00:00:00"
+    # create_at_end = "2026-04-07 10:59:59"
 
     main_import_process(create_at_begin, create_at_end)