| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import argparse
- import os
- import pandas as pd
- def _safe_div(numer, denom):
- if denom == 0:
- return pd.NA
- return round(numer / denom, 4)
def _numeric_column(frame, name):
    """Return column ``name`` of ``frame`` coerced to numbers; all-NaN when the column is absent."""
    if name not in frame.columns:
        return pd.Series(float("nan"), index=frame.index, dtype="float64")
    return pd.to_numeric(frame[name], errors="coerce")


def _confusion_metrics(pred, label):
    """Compute tp/fp/tn/fn plus accuracy/precision/recall/f1 from aligned 0/1 Series.

    Values other than 0 and 1 are counted in neither the positive nor the
    negative class. Ratio metrics are ``pd.NA`` when their denominator is 0.
    """
    pred_pos, pred_neg = pred == 1, pred == 0
    label_pos, label_neg = label == 1, label == 0
    tp = int((label_pos & pred_pos).sum())  # true positive
    tn = int((label_neg & pred_neg).sum())  # true negative
    fp = int((label_neg & pred_pos).sum())  # false positive
    fn = int((label_pos & pred_neg).sum())  # false negative

    accuracy = _safe_div(tp + tn, tp + tn + fp + fn)
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    if precision is pd.NA or recall is pd.NA or (precision + recall) == 0:
        f1 = pd.NA
    else:
        f1 = round(2 * precision * recall / (precision + recall), 4)

    return {
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def _pnl_metrics(pnl_series, pnl_pct_series):
    """Compute win/loss statistics for trades whose ``pnl`` is known (NaN-free ``pnl_series``)."""
    win_series = pnl_series[pnl_series > 0]  # profitable trades
    loss_series = pnl_series[pnl_series < 0]  # losing trades
    flat_series = pnl_series[pnl_series == 0]  # break-even trades

    # Share of trades that were profitable.
    win_rate = _safe_div(len(win_series), len(pnl_series))
    # Average profit per winning trade / average (absolute) loss per losing trade.
    avg_win = round(win_series.mean(), 4) if not win_series.empty else pd.NA
    avg_loss = round(abs(loss_series.mean()), 4) if not loss_series.empty else pd.NA
    # Per-trade profit/loss ratio. avg_win is guarded too: with losses but no
    # wins it is pd.NA, so the ratio is made NA explicitly instead of feeding
    # pd.NA arithmetic into round().
    if avg_win is pd.NA or avg_loss is pd.NA or avg_loss == 0:
        pl_ratio_avg = pd.NA
    else:
        pl_ratio_avg = round(avg_win / avg_loss, 4)
    # Total profit of winners vs. total (absolute) loss of losers.
    sum_win = round(win_series.sum(), 4) if not win_series.empty else 0.0
    sum_loss = round(abs(loss_series.sum()), 4) if not loss_series.empty else 0.0
    pl_ratio_sum = pd.NA if sum_loss == 0 else round(sum_win / sum_loss, 4)

    return {
        "win_rate": win_rate,
        "avg_win": avg_win,
        "avg_loss": avg_loss,
        "profit_loss_ratio_avg": pl_ratio_avg,
        "profit_loss_ratio_sum": pl_ratio_sum,
        "pnl_sum": round(pnl_series.sum(), 4) if not pnl_series.empty else pd.NA,  # total PnL
        "pnl_pct_mean": (
            round(pnl_pct_series.mean(), 4) if not pnl_pct_series.empty else pd.NA
        ),  # mean PnL percentage
        "wins": int(len(win_series)),  # number of profitable trades
        "losses": int(len(loss_series)),  # number of losing trades
        "flats": int(len(flat_series)),  # number of break-even trades
    }


def evaluate_validate_pnl(csv_path, output_path=None):
    """Summarize prediction quality and PnL of a validation-result CSV.

    The prediction column is ``will_price_drop`` (``1`` marks a signal). The
    realized label is ``drop_flag_window``; when that column is absent,
    ``drop_flag`` is used instead. Non-numeric or missing predictions and
    labels are treated as 0.

    Classification metrics (tp/fp/tn/fn, accuracy, precision, recall, f1) are
    computed over *all* rows, so that tn, fn and recall reflect the rows the
    model did not flag. PnL statistics are computed over signal rows
    (``will_price_drop == 1``) that have a numeric ``pnl``.

    The one-row summary is printed and written as UTF-8 with BOM to
    ``output_path``. The default output path is
    ``<csv_path without extension>_summary.csv``, next to the input.

    Args:
        csv_path: Path to a ``result_validate_*.csv`` file.
        output_path: Optional destination for the summary CSV.

    Returns:
        The one-row summary ``DataFrame``. Returns ``None`` (after printing a
        message) when the input is empty, lacks required columns, or has no
        signal rows.
    """
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header at all; report it like a header-only file.
        df = pd.DataFrame()
    if df.empty:
        print("输入文件为空")
        return None
    if "will_price_drop" not in df.columns:
        print("缺少 will_price_drop 字段")
        return None
    if "drop_flag_window" not in df.columns:
        if "drop_flag" in df.columns:
            print("缺少 drop_flag_window,使用 drop_flag 作为替代口径")
            df["drop_flag_window"] = df["drop_flag"]
        else:
            print("缺少 drop_flag_window 字段,请先用更新后的验证脚本生成")
            return None

    # Coerce once on every row; unparsable/missing values become 0 (negative).
    pred = pd.to_numeric(df["will_price_drop"], errors="coerce").fillna(0)
    label = (
        pd.to_numeric(df["drop_flag_window"], errors="coerce").fillna(0).astype(int)
    )

    signal_mask = pred == 1
    if not signal_mask.any():
        print("信号样本为空 (will_price_drop==1)")
        return None

    df_signal = df.loc[signal_mask]
    pnl = _numeric_column(df_signal, "pnl")
    pnl_pct = _numeric_column(df_signal, "pnl_pct")
    valid_pnl_mask = pnl.notna()

    # Key order defines the CSV column order.
    summary = {
        "rows_total": int(len(df)),
        "rows_signal": int(signal_mask.sum()),
        "rows_with_pnl": int(valid_pnl_mask.sum()),
        "rows_pnl_missing": int((~valid_pnl_mask).sum()),
        **_confusion_metrics(pred, label),
        **_pnl_metrics(pnl[valid_pnl_mask], pnl_pct[valid_pnl_mask]),
    }
    summary_df = pd.DataFrame([summary])
    print(summary_df.to_string(index=False))

    if output_path is None:
        base, _ = os.path.splitext(csv_path)  # same directory as the input file
        output_path = f"{base}_summary.csv"
    summary_df.to_csv(output_path, index=False, encoding="utf-8-sig")
    print(f"盈亏汇总已保存: {output_path}")
    return summary_df
if __name__ == "__main__":
    # Command-line entry point:
    #   python <this_script>.py result_validate_xxx.csv [--output summary.csv]
    # Without --output, the summary is written next to the input CSV.
    parser = argparse.ArgumentParser(description="验证结果的准确率与盈亏比统计")
    parser.add_argument("csv_path", help="result_validate_*.csv 路径")
    parser.add_argument("--output", default=None, help="汇总输出 CSV 路径")
    args = parser.parse_args()
    evaluate_validate_pnl(args.csv_path, args.output)
|