import os

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix, mean_absolute_error)
from matplotlib import font_manager
import matplotlib.pyplot as plt
import seaborn as sns

from utils import FlightDataset


# Distributed model evaluation
def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16,
                              test_loader=None, batch_flight_routes=None, target_scaler=None,
                              flag_distributed=False, rank=0, local_rank=0, world_size=1,
                              output_dir='.', batch_idx=-1, csv_file='evaluate_results.csv',
                              evalute_flag='evaluate', save_mode='a'):
    """Evaluate the classification head on a test set, optionally aggregating predictions
    across processes with torch.distributed, and save per-sample results to CSV (rank 0 only)."""
    if test_loader is None:
        if not sequences:
            print("Not enough data for evaluation.")
            return
        test_dataset = FlightDataset(sequences, targets, group_ids)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Flatten the flight routes into a single string used in the saved figure name.
    batch_fn_str = ' '.join([route.replace('|', ' ') for route in batch_flight_routes]) if batch_flight_routes else ''

    model.eval()

    # Containers for the local results of this process.
    y_preds_list = []
    y_trues_list = []
    group_info_list = []

    with torch.no_grad():
        for X_batch, y_batch, group_ids_batch in test_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            # In distributed mode each process is expected to see its own data shard.
            outputs = model(X_batch)

            # Collect the batch results on the CPU to save GPU memory.
            y_preds_list.append(outputs.cpu().numpy())
            y_trues_list.append(y_batch.cpu().numpy())

            # Convert group_info into serialisable tuples (one per sample).
            for i in range(len(group_ids_batch[0])):
                group_id = tuple(g[i].item() if isinstance(g, torch.Tensor) else g[i]
                                 for g in group_ids_batch)
                group_info_list.append(group_id)

    # Merge the results of the current process.
    y_preds = np.concatenate(y_preds_list, axis=0)
    y_trues = np.concatenate(y_trues_list, axis=0)
    group_info = group_info_list

    # --- Aggregate results across processes ---
    if flag_distributed:
        y_preds_tensor = torch.tensor(y_preds, device=device)
        y_trues_tensor = torch.tensor(y_trues, device=device)

        # Gather y_preds and y_trues from every process. Note that all_gather assumes
        # each rank contributes tensors of identical shape, i.e. the test data is
        # evenly sharded across processes.
        gather_y_preds = [torch.zeros_like(y_preds_tensor) for _ in range(world_size)]
        gather_y_trues = [torch.zeros_like(y_trues_tensor) for _ in range(world_size)]
        dist.all_gather(gather_y_preds, y_preds_tensor)
        dist.all_gather(gather_y_trues, y_trues_tensor)

        # Merge the gathered results on rank 0.
        if rank == 0:
            y_preds = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_preds], axis=0)
            y_trues = np.concatenate([tensor.cpu().numpy() for tensor in gather_y_trues], axis=0)

        # Encode group_info as strings so it can be gathered as Python objects.
        group_info_str = ['|'.join(map(str, info)) for info in group_info]
        gather_group_info = [None for _ in range(world_size)]
        dist.all_gather_object(gather_group_info, group_info_str)

        if rank == 0:
            group_info = []
            for info_list in gather_group_info:
                for info_str in info_list:
                    group_info.append(tuple(info_str.split('|')))

    # --- Compute metrics and save results on rank 0 only ---
    if rank == 0:
        # Classification outputs: column 0 holds the price-drop probability/label.
        y_preds_class = y_preds[:, 0]
        y_trues_class = y_trues[:, 0]
        y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
        y_trues_class_labels = y_trues_class.astype(int)

        # Print metrics and plot the confusion matrix.
        printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str,
                      batch_idx=batch_idx, evalute_flag=evalute_flag)

        # Build the per-sample results DataFrame.
        results_df = pd.DataFrame({
            'city_pair': [info[0] for info in group_info],
            'flight_day': [info[1] for info in group_info],
            'flight_number_1': [info[2] for info in group_info],
            'flight_number_2': [info[3] for info in group_info],
            'from_date': [info[4] for info in group_info],
            'baggage': [info[5] for info in group_info],
            'price': [info[6] for info in group_info],
            'Hours_until_Departure': [info[7] for info in group_info],
            'update_hour': [info[8] for info in group_info],
            'probability': y_preds_class,
            'Actual_Will_Price_Drop': y_trues_class_labels,
            'Predicted_Will_Price_Drop': y_preds_class_labels,
        })

        # Zero out numeric values smaller than the threshold.
        threshold = 1e-3
        numeric_columns = ['probability',
                           # 'Actual_Amount_Of_Drop', 'Predicted_Amount_Of_Drop',
                           # 'Actual_Time_To_Drop', 'Predicted_Time_To_Drop'
                           ]
        for col in numeric_columns:
            results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)

        # Save the results.
        results_df_path = os.path.join(output_dir, csv_file)
        if save_mode == 'a':
            # Append mode: only write the header if the file does not exist yet.
            results_df.to_csv(results_df_path, mode='a', index=False,
                              header=not os.path.exists(results_df_path))
        else:
            # Overwrite mode.
            results_df.to_csv(results_df_path, mode='w', index=False, header=True)
        print(f"Predictions saved to '{results_df_path}'")
        return results_df
    else:
        return None


def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1,
                  evalute_flag='evaluate'):
    """Print classification metrics and save a confusion-matrix heatmap."""
    accuracy = accuracy_score(y_trues_class_labels, y_preds_class_labels)
    precision = precision_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    recall = recall_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    f1 = f1_score(y_trues_class_labels, y_preds_class_labels, zero_division=0)
    print(f"Classification accuracy:  {accuracy:.4f}")
    print(f"Classification precision: {precision:.4f}")
    print(f"Classification recall:    {recall:.4f}")
    print(f"Classification F1:        {f1:.4f}")

    # Indices where both the label and the prediction are positive (true positives).
    indices = np.where((y_trues_class_labels == 1) & (y_preds_class_labels == 1))[0]
    if len(indices) == 0:
        print("No true positives.")

    # Use the bundled SimHei font if available, otherwise fall back to the default font.
    font_path = "./simhei.ttf"
    if os.path.exists(font_path):
        font_prop = font_manager.FontProperties(fname=font_path)
    else:
        font_prop = font_manager.FontProperties()

    # Confusion matrix heatmap.
    cm = confusion_matrix(y_trues_class_labels, y_preds_class_labels)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted: no drop', 'Predicted: drop'],
                yticklabels=['Actual: no drop', 'Actual: drop'])
    plt.xticks(fontproperties=font_prop)
    plt.yticks(fontproperties=font_prop)
    plt.xlabel('Prediction', fontproperties=font_prop)
    plt.ylabel('Actual', fontproperties=font_prop)
    plt.title('Confusion matrix of classification results', fontproperties=font_prop)
    os.makedirs("./photo", exist_ok=True)
    plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")
    plt.close()
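

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): runs the evaluator above in
# single-process mode with a toy model and a stand-in dataset, just to show the
# shapes the function expects. The toy model, the sequence dimensions, and the
# nine dummy group-id fields are illustrative assumptions; a real run would use
# FlightDataset, a trained model, and (optionally) torch.distributed instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import Dataset

    class _ToyFlightDataset(Dataset):
        """Stand-in for FlightDataset: yields (sequence, target, 9-field group id)."""
        def __init__(self, n=32, seq_len=8, n_features=4):
            torch.manual_seed(0)
            self.X = torch.randn(n, seq_len, n_features)
            self.y = (torch.rand(n, 1) > 0.5).float()  # column 0 = will-price-drop label
            # Nine dummy group-id fields matching the columns built in evaluate_model_distribute.
            self.groups = [tuple(f"sample{i}_field{k}" for k in range(9)) for i in range(n)]

        def __len__(self):
            return len(self.groups)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx], self.groups[idx]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Toy classifier producing one probability per sequence (input size = seq_len * n_features).
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 4, 1), nn.Sigmoid()).to(device)
    toy_loader = DataLoader(_ToyFlightDataset(), batch_size=16)

    evaluate_model_distribute(
        toy_model, device,
        sequences=None, targets=None, group_ids=None,  # ignored when test_loader is supplied
        test_loader=toy_loader,
        flag_distributed=False, rank=0, world_size=1,
        output_dir=".", csv_file="evaluate_results_demo.csv",  # illustrative output name
        save_mode="w",
    )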