evaluate.py

import os
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, mean_absolute_error
from matplotlib import font_manager
import matplotlib.pyplot as plt
import seaborn as sns
from utils import FlightDataset


# Distributed model evaluation
def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16, test_loader=None,
                              batch_flight_routes=None, target_scaler=None,
                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', batch_idx=-1,
                              csv_file='evaluate_results.csv', evalute_flag='evaluate', save_mode='a'):
    if test_loader is None:
        if not sequences:
            print("没有足够的数据进行评估。")
            return
        test_dataset = FlightDataset(sequences, targets, group_ids)
        # No sampler is passed here, so in distributed mode every rank iterates the
        # full test set; see the DistributedSampler sketch below.
        test_loader = DataLoader(test_dataset, batch_size=batch_size)
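        # Sketch only (not enabled): one common way to shard evaluation across
        # ranks would be a DistributedSampler, e.g.
        #   from torch.utils.data.distributed import DistributedSampler
        #   sampler = DistributedSampler(test_dataset, num_replicas=world_size,
        #                                rank=rank, shuffle=False)
        #   test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=sampler)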
    batch_fn_str = ' '.join([route.replace('|', ' ') for route in batch_flight_routes]) if batch_flight_routes else ''
    model.eval()
    # Containers for this process's results (converted to tensors later for cross-process gathering)
    y_preds_list = []
    y_trues_list = []
    group_info_list = []
    with torch.no_grad():
        for X_batch, y_batch, group_ids_batch in test_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            # In distributed mode, different processes should be handling different data shards
            outputs = model(X_batch)
            # Collect the batch results, moved to CPU to save GPU memory
            y_preds_list.append(outputs.cpu().numpy())
            y_trues_list.append(y_batch.cpu().numpy())
            # Convert group_ids_batch into serializable tuples, one per sample
            for i in range(len(group_ids_batch[0])):
                group_id = tuple(g[i].item() if isinstance(g, torch.Tensor) else g[i] for g in group_ids_batch)
                group_info_list.append(group_id)
    # Merge this process's results
    y_preds = np.concatenate(y_preds_list, axis=0)
    y_trues = np.concatenate(y_trues_list, axis=0)
    group_info = group_info_list
    # --- Aggregate results across processes ---
    if flag_distributed:
        # Move this rank's predictions back onto the device for collective communication
        y_preds_tensor = torch.tensor(y_preds, device=device)
        y_trues_tensor = torch.tensor(y_trues, device=device)
        # Gather y_preds and y_trues from all processes
        gather_y_preds = [torch.zeros_like(y_preds_tensor) for _ in range(world_size)]
        gather_y_trues = [torch.zeros_like(y_trues_tensor) for _ in range(world_size)]
        dist.all_gather(gather_y_preds, y_preds_tensor)
        dist.all_gather(gather_y_trues, y_trues_tensor)
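        # Note: dist.all_gather assumes every rank contributes tensors of identical
        # shape; with unevenly sized shards the inputs would need padding (or an
        # object-based gather) before this step.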
        # Merge the gathered tensors on rank 0
        if rank == 0:
            y_preds = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_preds], axis=0)
            y_trues = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_trues], axis=0)
        # Encode group_info as strings so it can be gathered as Python objects
        group_info_str = ['|'.join(map(str, info)) for info in group_info]
        gather_group_info = [None for _ in range(world_size)]
        dist.all_gather_object(gather_group_info, group_info_str)
        if rank == 0:
            group_info = []
            for info_list in gather_group_info:
                for info_str in info_list:
                    group_info.append(tuple(info_str.split('|')))
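        # gather_group_info is ordered by rank, so the rebuilt group_info lines up
        # with the rank-ordered concatenation of y_preds / y_trues above.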
    # --- Compute metrics and save results on rank 0 only ---
    if rank == 0:
        # Classification task results
        y_preds_class = y_preds[:, 0]
        y_trues_class = y_trues[:, 0]
        y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
        y_trues_class_labels = y_trues_class.astype(int)
        # Print metrics
        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str, batch_idx=batch_idx, evalute_flag=evalute_flag)
        # Build the results DataFrame
        results_df = pd.DataFrame({
            'city_pair': [info[0] for info in group_info],
            'flight_day': [info[1] for info in group_info],
            'flight_number_1': [info[2] for info in group_info],
            'flight_number_2': [info[3] for info in group_info],
            'from_date': [info[4] for info in group_info],
            'baggage': [info[5] for info in group_info],
            'price': [info[6] for info in group_info],
            'Hours_until_Departure': [info[7] for info in group_info],
            'update_hour': [info[8] for info in group_info],
            'target_amount_of_drop': [info[9] for info in group_info],  # the two target columns only exist when validating during training
            'target_time_to_drop': [info[10] for info in group_info],
            'probability': y_preds_class,
            'Actual_Will_Price_Drop': y_trues_class_labels,
            'Predicted_Will_Price_Drop': y_preds_class_labels,
        })
        # Numeric post-processing: zero out near-zero values and round
        threshold = 1e-3
        numeric_columns = ['probability', 'target_amount_of_drop', 'target_time_to_drop']
        for col in numeric_columns:
            # group_info values arrive as strings after the distributed gather, so
            # coerce to numeric before taking abs() against the threshold
            results_df[col] = pd.to_numeric(results_df[col], errors='coerce')
            results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
            if col in ['target_time_to_drop']:
                results_df[col] = results_df[col].round(0).astype(int)
            if col in ['target_amount_of_drop']:
                results_df[col] = results_df[col].round(2)
        # Save results
        results_df_path = os.path.join(output_dir, csv_file)
        if save_mode == 'a':
            # Append mode: only write the header when the file does not exist yet
            results_df.to_csv(results_df_path, mode='a', index=False, header=not os.path.exists(results_df_path))
        else:
            # Overwrite mode
            results_df.to_csv(results_df_path, mode='w', index=False, header=True)
        print(f"预测结果已保存到 '{results_df_path}'")
        return results_df
    else:
        return None


def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate'):
    accuracy = accuracy_score(y_trues_class_labels, y_preds_class_labels)
    precision = precision_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    recall = recall_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    f1 = f1_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    print(f"分类准确率: {accuracy:.4f}")
    print(f"分类精确率: {precision:.4f}")
    print(f"分类召回率: {recall:.4f}")
    print(f"分类F1值: {f1:.4f}")
    # Indices where both the true label and the prediction are positive
    indices = np.where((y_trues_class_labels == 1) & (y_preds_class_labels == 1))[0]
    if len(indices) == 0:
        print("没有正例")
    # Load a CJK font so the Chinese labels render correctly
    font_path = "./simhei.ttf"
    font_prop = font_manager.FontProperties(fname=font_path)
    # Confusion matrix
    cm = confusion_matrix(y_trues_class_labels, y_preds_class_labels)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['预测:不会降价', '预测:会降价'],
                yticklabels=['实际:不会降价', '实际:会降价'])
    plt.xticks(fontproperties=font_prop)
    plt.yticks(fontproperties=font_prop)
    plt.xlabel('预测情况', fontproperties=font_prop)
    plt.ylabel('实际结果', fontproperties=font_prop)
    plt.title('分类结果的混淆矩阵', fontproperties=font_prop)
    os.makedirs("./photo", exist_ok=True)  # ensure the output directory exists
    plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")
    plt.close()  # release the figure so repeated calls do not accumulate open figures
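

# Usage sketch (illustrative only; the model class and data-loading helper named
# below are hypothetical placeholders, not part of this module):
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = PriceDropModel().to(device)                    # hypothetical model
#   sequences, targets, group_ids = load_eval_data(...)    # hypothetical helper
#   results_df = evaluate_model_distribute(
#       model, device, sequences, targets, group_ids,
#       batch_size=32, flag_distributed=False, rank=0, world_size=1,
#       output_dir='./results', csv_file='evaluate_results.csv', save_mode='w')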