ソースを参照

提交验证脚本

node04 1 週間 前
コミット
918c9bce3f
2 ファイル変更240 行追加7 行削除
  1. 7 7
      result_validate.py
  2. 233 0
      result_validate_0.py

+ 7 - 7
result_validate.py

@@ -113,7 +113,7 @@ def validate_process(node, interval_hours, pred_time_str):
     client.close()
 
     timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    save_scv = f"result_validate_{node}_{interval_hours}_{pred_time_str}_{timestamp_str}.csv"
+    save_scv = f"result_validate_{node}_{pred_time_str}_{timestamp_str}.csv"
     
     output_path = os.path.join(output_dir, save_scv)
     df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
@@ -279,12 +279,12 @@ def validate_process_auto(node, interval_hours):
     client.close()
 
     timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    save_scv = f"result_validate_{node}_{interval_hours}_{last_valid_time_str}_{timestamp_str}.csv"
+    save_scv = f"result_validate_{node}_{last_valid_time_str}_{timestamp_str}.csv"
     
     output_path = os.path.join(output_dir, save_scv)
     df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
     print(f"保存完成: {output_path}")
-    print(f"验证完成: {node} {interval_hours} {last_valid_time_str}")    
+    print(f"验证完成: {node} {last_valid_time_str}")    
     print()
 
 if __name__ == "__main__":
@@ -296,14 +296,14 @@ if __name__ == "__main__":
 
     # 0 手动验证
     if interval_hours == 0:
-        node, interval_hours, pred_time_str = "node0112", 8, "202601151600"
+        node, interval_hours, pred_time_str = "node0112_8", 8, "202601151600"
         validate_process(node, interval_hours, pred_time_str)
     # 自动验证
     else:
         # 这个node可以手动去改
-        node = "node0112"
+        node = "node0117_8"
         if interval_hours == 4:
-            node = "node0114"
+            node = "node0117_4"
         if interval_hours == 2:
-            node = "node0115"    
+            node = "node0117_2"    
         validate_process_auto(node, interval_hours)

+ 233 - 0
result_validate_0.py

@@ -0,0 +1,233 @@
+import argparse
+import datetime
+import os
+import pandas as pd
+from data_loader import mongo_con_parse, validate_one_line, fill_hourly_crawl_date
+
+
+def _validate_predict_df(df_predict):
+    client, db = mongo_con_parse()
+
+    count = 0
+
+    for idx, row in df_predict.iterrows():
+        city_pair = row['city_pair']
+        flight_day = row['flight_day']
+        flight_number_1 = row['flight_number_1']
+        flight_number_2 = row['flight_number_2']
+        baggage = row['baggage']
+        valid_begin_hour = row['valid_begin_hour']
+        valid_begin_dt = pd.to_datetime(valid_begin_hour, format='%Y-%m-%d %H:%M:%S')
+        # valid_end_hour = row['valid_end_hour']
+        # valid_end_dt = pd.to_datetime(valid_end_hour, format='%Y-%m-%d %H:%M:%S')
+        update_hour = row['update_hour']
+        update_dt = pd.to_datetime(update_hour, format='%Y-%m-%d %H:%M:%S')
+        valid_begin_hour_modify = max(
+            valid_begin_dt,
+            update_dt
+        ).strftime('%Y-%m-%d %H:%M:%S')
+        df_val= validate_one_line(db, city_pair, flight_day, flight_number_1, flight_number_2, baggage, valid_begin_hour_modify)
+        # 有可能在当前验证时刻,数据库里没有在valid_begin_hour之后的数据
+        if not df_val.empty:
+            df_val_f = fill_hourly_crawl_date(df_val, rear_fill=2)
+            df_val_f = df_val_f[df_val_f['is_filled']==0]     # 只要原始数据,不要补齐的
+            # df_val_f = df_val_f[df_val_f['update_hour'] <= valid_end_dt]
+            if df_val_f.empty:
+                drop_flag = 0
+                first_drop_amount = pd.NA
+                first_drop_hours = pd.NA
+                last_hours_util = pd.NA
+                last_update_hour = pd.NA
+                list_change_price = []
+                list_change_hours = []
+            else:
+                # 有效数据的最后一行
+                last_row = df_val_f.iloc[-1]
+                last_hours_util = last_row['hours_until_departure']
+                last_update_hour = last_row['update_hour']
+
+                # 价格变化过滤
+                df_price_changes = df_val_f.loc[
+                    df_val_f["adult_total_price"].shift() != df_val_f["adult_total_price"]
+                ].copy()
+
+                # 价格变化幅度
+                df_price_changes['change_amount'] = df_price_changes['adult_total_price'].diff().fillna(0)
+
+                # 找到第一个 change_amount 小于 -5 的行
+                first_negative_change = df_price_changes[df_price_changes['change_amount'] < -5].head(1)
+
+                # 提取所需的值
+                if not first_negative_change.empty:
+                    drop_flag = 1
+                    first_drop_amount = first_negative_change['change_amount'].iloc[0].round(2)
+                    first_drop_hours = first_negative_change['hours_until_departure'].iloc[0]
+                else:
+                    drop_flag = 0
+                    first_drop_amount = pd.NA
+                    first_drop_hours = pd.NA
+
+                list_change_price = df_price_changes['adult_total_price'].tolist()
+                list_change_hours = df_price_changes['hours_until_departure'].tolist()
+
+        else:
+            drop_flag = 0
+            first_drop_amount = pd.NA
+            first_drop_hours = pd.NA
+            last_hours_util = pd.NA
+            last_update_hour = pd.NA
+            list_change_price = []
+            list_change_hours = []
+        
+        safe_sep = "; "
+
+        df_predict.at[idx, 'change_prices'] = safe_sep.join(map(str, list_change_price))
+        df_predict.at[idx, 'change_hours'] = safe_sep.join(map(str, list_change_hours))
+        df_predict.at[idx, 'last_hours_util'] = last_hours_util
+        df_predict.at[idx, 'last_update_hour'] = last_update_hour
+        df_predict.at[idx, 'first_drop_amount'] = first_drop_amount * -1  # 负数转正数
+        df_predict.at[idx, 'first_drop_hours'] = first_drop_hours
+        df_predict.at[idx, 'drop_flag'] = drop_flag
+
+        count += 1
+        if count % 5 == 0:
+            print(f"cal count: {count}")
+
+    print(f"计算结束")
+    client.close()
+
+    return df_predict
+
+
+def validate_process(node, interval_hours, pred_time_str):
+    '''手动验证脚本'''
+    date = pred_time_str[4:8]
+    output_dir = f"./validate/{node}_{date}"
+    os.makedirs(output_dir, exist_ok=True)
+
+    object_dir = "./predictions_0"
+    
+    csv_file = f'future_predictions_{pred_time_str}.csv'  
+    csv_path = os.path.join(object_dir, csv_file)
+
+    try:
+        df_predict = pd.read_csv(csv_path)
+    except Exception as e:
+        print(f"read {csv_path} error: {str(e)}")
+        df_predict = pd.DataFrame()
+    
+    if df_predict.empty:
+        print(f"预测数据为空")
+        return
+    
+    df_predict = _validate_predict_df(df_predict)
+
+    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    save_scv = f"result_validate_{node}_{pred_time_str}_{timestamp_str}.csv"
+    
+    output_path = os.path.join(output_dir, save_scv)
+    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
+    print(f"保存完成: {output_path}")
+
+
+def validate_process_auto(node, interval_hours):
+    '''自动验证脚本'''
+    # 当前时间,取整时
+    current_time = datetime.datetime.now() 
+    current_time_str = current_time.strftime("%Y%m%d%H%M")
+    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
+    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
+    print(f"验证时间:{current_time_str}, (取整): {hourly_time_str}")
+
+    output_dir = f"./validate/{node}"
+    os.makedirs(output_dir, exist_ok=True)
+
+    object_dir = "./predictions_0"
+
+    # 检查目录是否存在
+    if not os.path.exists(object_dir):
+        print(f"目录不存在: {object_dir}")
+        return
+                    
+    # 获取所有以 future_predictions_ 开头的 CSV 文件
+    csv_files = []
+    for file in os.listdir(object_dir):
+        if file.startswith("future_predictions_") and file.endswith(".csv"):
+            csv_files.append(file)
+    
+    if not csv_files:
+        print(f"在 {object_dir} 中没有找到 future_predictions_ 开头的 CSV 文件")
+        return
+
+    # 提取时间戳并转换为 datetime 对象
+    file_times = []
+    for file in csv_files:
+        # 提取时间戳部分:future_predictions_202601151600.csv -> 202601151600
+        timestamp_str = file.replace("future_predictions_", "").replace(".csv", "")
+        try:
+            # 将时间戳转换为 datetime 对象
+            file_time = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M")
+            file_times.append((file, file_time))
+        except ValueError as e:
+            print(f"文件 {file} 的时间戳格式错误: {e}")
+            continue
+    
+    if not file_times:
+        print("没有找到有效的时间戳文件")
+        return
+
+    # 目标验证文件(当前整点减50小时)
+    target_time = hourly_time - datetime.timedelta(hours=50)                          
+    target_time_str = target_time.strftime("%Y%m%d%H%M")
+    print(f"目标验证时间: {target_time_str}")
+
+    valid_files = [(f, t) for f, t in file_times if t == target_time]
+
+    if not valid_files:
+        print(f"没有找到目标对应时间 {target_time.strftime('%Y%m%d%H%M')} 的文件")
+        return
+
+    valid_file, valid_time = valid_files[0]
+    valid_time_str = valid_time.strftime("%Y%m%d%H%M")
+    print(f"找到符合条件的文件: {valid_file} (时间: {valid_time_str})")
+    
+    csv_path = os.path.join(object_dir, valid_file)
+
+    # 开始验证
+    try:
+        df_predict = pd.read_csv(csv_path)
+    except Exception as e:
+        print(f"read {csv_path} error: {str(e)}")
+        df_predict = pd.DataFrame()
+    
+    if df_predict.empty:
+        print(f"预测数据为空")
+        return
+    
+    df_predict = _validate_predict_df(df_predict)
+
+    timestamp_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+    save_scv = f"result_validate_{node}_{valid_time_str}_{timestamp_str}.csv"
+
+    output_path = os.path.join(output_dir, save_scv)
+    df_predict.to_csv(output_path, index=False, encoding="utf-8-sig")
+    print(f"保存完成: {output_path}")
+    print(f"验证完成: {node} {valid_time_str}")    
+    print()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='验证脚本')
+    parser.add_argument('--interval', type=int, choices=[1], 
+                        default=0, help='间隔小时数(1,)')
+    args = parser.parse_args() 
+    interval_hours = args.interval
+
+    # 0 手动验证
+    if interval_hours == 0:
+        node, pred_time_str = "node0127", "202601281700"
+        validate_process(node, interval_hours, pred_time_str)
+    # 1 自动验证
+    else:
+        node = "node0122"
+        validate_process_auto(node, interval_hours)