Explorar el Código

提交预测结果验证代码 修改redis地址

node04 hace 1 mes
padre
commit
7581ed0a09
Se han modificado 4 ficheros con 218 adiciones y 2 borrados
  1. 101 0
      data_loader.py
  2. 1 1
      main_tr.py
  3. 115 0
      result_validate.py
  4. 1 1
      train.py

+ 101 - 0
data_loader.py

@@ -752,6 +752,107 @@ def query_all_flight_number(db, table_name):
     
     return list_flight_number
 
+
def validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour,
                      limit=0, max_retries=3, base_sleep=1.0):
    """Validate one prediction row by querying crawled fare records from MongoDB.

    Finds every crawled fare for the given route / flight / baggage combination
    whose ``crawl_date`` is >= ``valid_begin_hour``, sorted by crawl time, and
    expands the embedded ``segments`` array into flat columns.

    Args:
        db: Open MongoDB database handle.
        city_pair: Route key "FROM-TO"; selects the hot vs. non-hot collection.
        flight_day: Departure day in "YYYY-MM-DD" form.
        flight_number_1: First-segment flight number.
        flight_number_2: Second-segment flight number; the sentinel "VJ" means
            the itinerary has only one segment.
        baggage: Baggage amount; matched as "1-{baggage}" against segments.baggage.
        valid_begin_hour: Inclusive lower bound for crawl_date.
        limit: Optional cap on returned documents (0 = unlimited).
        max_retries: Number of query attempts on Mongo errors.
        base_sleep: Base seconds for the exponential back-off between retries.

    Returns:
        pandas.DataFrame with expanded segment columns; an empty DataFrame when
        the route is unknown, the query matches nothing, or all retries fail.
    """
    # Route tier decides which cleaned collection to search.
    if city_pair in vj_flight_route_list_hot:
        table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB
    elif city_pair in vj_flight_route_list_nothot:
        table_name = CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
    else:
        print(f"城市对{city_pair}不在热门航线与冷门航线, 返回")
        return pd.DataFrame()

    city_pair_split = city_pair.split('-')
    from_city_code = city_pair_split[0]
    to_city_code = city_pair_split[1]
    # Stored search_dep_time uses the compact YYYYMMDD form.
    flight_day_str = datetime.strptime(flight_day, "%Y-%m-%d").strftime("%Y%m%d")
    baggage_str = f"1-{baggage}"

    for attempt in range(1, max_retries + 1):
        try:
            print(f"🔁 第 {attempt}/{max_retries} 次尝试查询")
            # Build the query filter.
            query_condition = {
                "from_city_code": from_city_code,
                "to_city_code": to_city_code,
                "search_dep_time": flight_day_str,
                "segments.baggage": baggage_str,
                "crawl_date": {"$gte": valid_begin_hour},
                "segments.0.flight_number": flight_number_1,
            }
            # Add the second leg only for two-segment itineraries
            # ("VJ" is the placeholder for "no second segment").
            if flight_number_2 != "VJ":
                query_condition["segments.1.flight_number"] = flight_number_2
            print(f"   查询条件: {query_condition}")
            # Restrict returned fields to what downstream code actually uses.
            projection = {
                "from_city_code": 1,
                "search_dep_time": 1,
                "to_city_code": 1,
                "currency": 1,
                "adult_price": 1,
                "adult_tax": 1,
                "adult_total_price": 1,
                "seats_remaining": 1,
                "segments": 1,
                "source_website": 1,
                "crawl_date": 1
            }
            # Execute the query, oldest crawl first.
            cursor = db.get_collection(table_name).find(
                query_condition,
                projection=projection
            ).sort(
                [
                    ("crawl_date", 1)
                ]
            )
            if limit > 0:
                cursor = cursor.limit(limit)

            # Materialize the cursor.
            results = list(cursor)
            print(f"✅ 查询成功,找到 {len(results)} 条记录")

            if results:
                df = pd.DataFrame(results)
                # Drop Mongo's ObjectId column; it is not useful downstream.
                if '_id' in df.columns:
                    df = df.drop(columns=['_id'])
                print(f"📊 已转换为 DataFrame,形状: {df.shape}")

                # Flatten the embedded segments array into flat columns.
                print(f"📊 开始扩展segments 稍等...")
                t1 = time.time()
                df = expand_segments_columns(df)
                t2 = time.time()
                rt = round(t2 - t1, 3)
                print(f"用时: {rt} 秒")
                print(f"📊 已将segments扩展成字段,形状: {df.shape}")

                # No re-sort needed: the Mongo query already sorted by crawl_date.
                return df

            else:
                print("⚠️  查询结果为空")
                return pd.DataFrame()

        except (ServerSelectionTimeoutError, PyMongoError) as e:
            print(f"⚠️ Mongo 查询失败: {e}")
            if attempt == max_retries:
                print("❌ 达到最大重试次数,放弃")
                return pd.DataFrame()

            # Exponential back-off with random jitter.
            sleep_time = base_sleep * (2 ** (attempt - 1)) + random.random()
            print(f"⏳ {sleep_time:.2f}s 后重试...")
            time.sleep(sleep_time)

    # Defensive fallthrough: only reachable when max_retries < 1; previously
    # this path returned None, breaking the documented DataFrame contract.
    return pd.DataFrame()
+
+
 if __name__ == "__main__":
 
     # test_mongo_connection(db)

+ 1 - 1
main_tr.py

@@ -148,7 +148,7 @@ def start_train():
     target_scaler = None      # 初始化目标缩放器
 
     # 初始化 Redis 客户端(请根据实际情况修改 host、port、db)
-    redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+    redis_client = redis.Redis(host='192.168.20.237', port=6379, db=0)
     lock_key = "data_loading_lock_11"
     barrier_key = 'distributed_barrier_11'
 

+ 115 - 0
result_validate.py

@@ -0,0 +1,115 @@
+import os
+import datetime
+import pandas as pd
+from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
+
+
def _no_data_metrics():
    """Default metrics for a prediction row with no usable crawl data."""
    return {
        'drop_flag': 0,
        'first_drop_amount': pd.NA,
        'first_drop_hours': pd.NA,
        'last_hours_util': pd.NA,
        'last_update_hour': pd.NA,
        'list_change_price': [],
        'list_change_hours': [],
    }


def _compute_row_metrics(df_val_f):
    """Derive price-drop metrics from the original (non-filled) crawl rows.

    Args:
        df_val_f: DataFrame already filtered to ``is_filled == 0`` and sorted
            by crawl time (sorting is done by the Mongo query upstream).

    Returns:
        dict with the same keys as ``_no_data_metrics()``.
    """
    if df_val_f.empty:
        return _no_data_metrics()

    metrics = _no_data_metrics()

    # Last surviving original observation.
    last_row = df_val_f.iloc[-1]
    metrics['last_hours_util'] = last_row['hours_until_departure']
    metrics['last_update_hour'] = last_row['update_hour']

    # Keep only the rows where the total price actually changed.
    df_price_changes = df_val_f.loc[
        df_val_f["adult_total_price"].shift() != df_val_f["adult_total_price"]
    ].copy()

    # Magnitude of each change (first row has no predecessor -> 0).
    df_price_changes['change_amount'] = df_price_changes['adult_total_price'].diff().fillna(0)

    # First price drop bigger than 10 (change_amount < -10).
    first_negative_change = df_price_changes[df_price_changes['change_amount'] < -10].head(1)
    if not first_negative_change.empty:
        metrics['drop_flag'] = 1
        metrics['first_drop_amount'] = first_negative_change['change_amount'].iloc[0].round(2)
        metrics['first_drop_hours'] = first_negative_change['hours_until_departure'].iloc[0]

    metrics['list_change_price'] = df_price_changes['adult_total_price'].tolist()
    metrics['list_change_hours'] = df_price_changes['hours_until_departure'].tolist()
    return metrics


def validate_process(node, date):
    """Validate the rows of future_predictions.csv against crawled fare data.

    For every prediction row, queries Mongo for the real crawled prices,
    computes price-drop metrics, appends them as new columns, and writes a
    timestamped result CSV under ./validate/{node}_{date}/.

    Args:
        node: Node identifier used in the output directory / file name.
        date: Date tag used in the output directory / file name.
    """
    output_dir = f"./validate/{node}_{date}"
    os.makedirs(output_dir, exist_ok=True)

    object_dir = "./data_shards"
    csv_file = 'future_predictions.csv'
    csv_path = os.path.join(object_dir, csv_file)

    try:
        df_predict = pd.read_csv(csv_path)
    except Exception as e:
        # Missing/unreadable file degrades to the empty-DataFrame early exit below.
        print(f"read {csv_path} error: {str(e)}")
        df_predict = pd.DataFrame()

    if df_predict.empty:
        print(f"预测数据为空")
        return

    client, db = mongo_con_parse()

    count = 0
    for idx, row in df_predict.iterrows():
        df_val = validate_one_line(
            db,
            row['city_pair'],
            row['flight_day'],
            row['flight_number_1'],
            row['flight_number_2'],
            row['baggage'],
            row['valid_begin_hour'],
        )
        if not df_val.empty:
            df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
            # Keep only original crawl rows; discard the gap-filled ones.
            df_val_f = df_val_f[df_val_f['is_filled'] == 0]
            metrics = _compute_row_metrics(df_val_f)
        else:
            metrics = _no_data_metrics()

        safe_sep = "; "

        df_predict.at[idx, 'change_prices'] = safe_sep.join(map(str, metrics['list_change_price']))
        df_predict.at[idx, 'change_hours'] = safe_sep.join(map(str, metrics['list_change_hours']))
        df_predict.at[idx, 'last_hours_util'] = metrics['last_hours_util']
        df_predict.at[idx, 'last_update_hour'] = metrics['last_update_hour']
        # Drops are negative deltas; flip the sign so the column reads positive.
        df_predict.at[idx, 'first_drop_amount'] = metrics['first_drop_amount'] * -1
        df_predict.at[idx, 'first_drop_hours'] = metrics['first_drop_hours']
        df_predict.at[idx, 'drop_flag'] = metrics['drop_flag']

        count += 1
        if count % 5 == 0:
            print(f"cal count: {count}")

    print(f"计算结束")
    client.close()

    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    save_csv = f"result_validate_{node}_{date}_{timestamp_str}.csv"

    output_path = os.path.join(output_dir, save_csv)
    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"保存完成: {output_path}")
+
+
if __name__ == "__main__":
    # Manual run: validate the predictions produced by node0105 on 0107.
    validate_process(node="node0105", date="0107")

+ 1 - 1
train.py

@@ -153,7 +153,7 @@ def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=Fals
 
     # 等待其他进程生成数据,并同步
     if flag_distributed:
-        redis_client = redis.Redis(host='192.168.20.222', port=6379, db=0)
+        redis_client = redis.Redis(host='192.168.20.237', port=6379, db=0)
         barrier_key = 'distributed_barrier_11'
 
         # 等待所有进程都到达 barrier