# evaluate_validate_pnl.py
# Accuracy and profit/loss-ratio statistics for validation results.
  1. import argparse
  2. import os
  3. import pandas as pd
  4. def _safe_div(numer, denom):
  5. if denom == 0:
  6. return pd.NA
  7. return round(numer / denom, 4)
  8. def evaluate_validate_pnl(csv_path, output_path=None):
  9. df = pd.read_csv(csv_path)
  10. if df.empty:
  11. print("输入文件为空")
  12. return
  13. if "will_price_drop" not in df.columns:
  14. print("缺少 will_price_drop 字段")
  15. return
  16. if "drop_flag_window" not in df.columns:
  17. if "drop_flag" in df.columns:
  18. print("缺少 drop_flag_window,使用 drop_flag 作为替代口径")
  19. df["drop_flag_window"] = df["drop_flag"]
  20. else:
  21. print("缺少 drop_flag_window 字段,请先用更新后的验证脚本生成")
  22. return
  23. df_signal = df[df["will_price_drop"] == 1].copy()
  24. if df_signal.empty:
  25. print("信号样本为空 (will_price_drop==1)")
  26. return
  27. df_signal["drop_flag_window"] = pd.to_numeric(
  28. df_signal["drop_flag_window"], errors="coerce"
  29. ).fillna(0).astype(int)
  30. df_signal["pnl"] = pd.to_numeric(df_signal.get("pnl"), errors="coerce")
  31. df_signal["pnl_pct"] = pd.to_numeric(df_signal.get("pnl_pct"), errors="coerce")
  32. valid_pnl_mask = df_signal["pnl"].notna()
  33. y_true = df_signal["drop_flag_window"].astype(int)
  34. y_pred = df_signal["will_price_drop"].astype(int)
  35. tp = int(((y_true == 1) & (y_pred == 1)).sum()) # 正阳
  36. tn = int(((y_true == 0) & (y_pred == 0)).sum()) # 正阴
  37. fp = int(((y_true == 0) & (y_pred == 1)).sum()) # 假阳
  38. fn = int(((y_true == 1) & (y_pred == 0)).sum()) # 假阴
  39. accuracy = _safe_div(tp + tn, tp + tn + fp + fn)
  40. precision = _safe_div(tp, tp + fp)
  41. recall = _safe_div(tp, tp + fn)
  42. f1 = (
  43. pd.NA
  44. if precision is pd.NA or recall is pd.NA or (precision + recall) == 0
  45. else round(2 * precision * recall / (precision + recall), 4)
  46. )
  47. pnl_series = df_signal.loc[valid_pnl_mask, "pnl"]
  48. pnl_pct_series = df_signal.loc[valid_pnl_mask, "pnl_pct"]
  49. win_series = pnl_series[pnl_series > 0] # 盈利单
  50. loss_series = pnl_series[pnl_series < 0] # 亏损单
  51. flat_series = pnl_series[pnl_series == 0] # 平价单
  52. win_rate = _safe_div(len(win_series), len(pnl_series)) # 盈利单数占比
  53. avg_win = round(win_series.mean(), 4) if not win_series.empty else pd.NA # 盈利单平均每单盈利
  54. avg_loss = round(abs(loss_series.mean()), 4) if not loss_series.empty else pd.NA # 亏损单平均每单亏损
  55. pl_ratio_avg = (
  56. pd.NA if avg_loss is pd.NA or avg_loss == 0 else round(avg_win / avg_loss, 4)
  57. ) # 平均每单盈亏比
  58. sum_win = round(win_series.sum(), 4) if not win_series.empty else 0.0 # 盈利单总盈利
  59. sum_loss = round(abs(loss_series.sum()), 4) if not loss_series.empty else 0.0 # 亏损单总亏损
  60. pl_ratio_sum = pd.NA if sum_loss == 0 else round(sum_win / sum_loss, 4) # 盈利单总盈利与亏损单总亏损的盈亏比
  61. summary = {
  62. "rows_total": int(len(df)),
  63. "rows_signal": int(len(df_signal)),
  64. "rows_with_pnl": int(valid_pnl_mask.sum()),
  65. "rows_pnl_missing": int((~valid_pnl_mask).sum()),
  66. "tp": tp,
  67. "fp": fp,
  68. "tn": tn,
  69. "fn": fn,
  70. "accuracy": accuracy,
  71. "precision": precision,
  72. "recall": recall,
  73. "f1": f1,
  74. "win_rate": win_rate,
  75. "avg_win": avg_win,
  76. "avg_loss": avg_loss,
  77. "profit_loss_ratio_avg": pl_ratio_avg,
  78. "profit_loss_ratio_sum": pl_ratio_sum,
  79. "pnl_sum": round(pnl_series.sum(), 4) if not pnl_series.empty else pd.NA, # 汇总盈亏
  80. "pnl_pct_mean": round(pnl_pct_series.mean(), 4) if not pnl_pct_series.empty else pd.NA, # 汇总盈亏百分比
  81. "wins": int(len(win_series)), # 盈利单数
  82. "losses": int(len(loss_series)), # 亏损单数
  83. "flats": int(len(flat_series)), # 平价单数
  84. }
  85. summary_df = pd.DataFrame([summary])
  86. print(summary_df.to_string(index=False))
  87. if output_path is None:
  88. base, _ = os.path.splitext(csv_path) # 与验证文件在同一目录
  89. output_path = f"{base}_summary.csv"
  90. summary_df.to_csv(output_path, index=False, encoding="utf-8-sig")
  91. print(f"盈亏汇总已保存: {output_path}")
  92. if __name__ == "__main__":
  93. # 临时添加参数用于调试
  94. # import sys
  95. # if len(sys.argv) == 1:
  96. # sys.argv = [
  97. # sys.argv[0],
  98. # # "/home/node04/yuzhou/jiangcang_vj/validate/node0205_zong/result_validate_node0205_zong_20260211100622.csv", # 替换为实际路径
  99. # "/home/node04/yuzhou/jiangcang_vj/validate/node0210_zong/result_validate_node0210_zong_20260211101300.csv", # 替换为实际路径
  100. # # "--output", "debug_output.csv"
  101. # ]
  102. parser = argparse.ArgumentParser(description="验证结果的准确率与盈亏比统计")
  103. parser.add_argument("csv_path", help="result_validate_*.csv 路径")
  104. parser.add_argument("--output", default=None, help="汇总输出 CSV 路径")
  105. args = parser.parse_args()
  106. evaluate_validate_pnl(args.csv_path, args.output)