| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import os
- import numpy as np
- import pandas as pd
- import torch
- import torch.distributed as dist
- from torch.utils.data import DataLoader
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, mean_absolute_error
- from matplotlib import font_manager
- import matplotlib.pyplot as plt
- import seaborn as sns
- from utils import FlightDataset
- # 分布式模型评估
def _group_rows(group_ids_batch):
    """Convert one batch of column-wise group ids into per-sample tuples.

    Each element of ``group_ids_batch`` holds one field for every sample in
    the batch. Tensor columns are converted to plain Python scalars so the
    resulting tuples can be pickled and written to CSV.
    """
    columns = [
        col.tolist() if isinstance(col, torch.Tensor) else col
        for col in group_ids_batch
    ]
    return [tuple(row) for row in zip(*columns)]


def evaluate_model_distribute(model, device, sequences, targets, group_ids, batch_size=16, test_loader=None,
                              batch_flight_routes=None, target_scaler=None,
                              flag_distributed=False, rank=0, local_rank=0, world_size=1, output_dir='.', batch_idx=-1,
                              csv_file='evaluate_results.csv', evalute_flag='evaluate', save_mode='a'):
    """Run the model over a test set, report metrics and save predictions.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        device: Device the input batches are moved to before the forward pass.
        sequences, targets, group_ids: Raw data used to build a
            ``FlightDataset`` when ``test_loader`` is not supplied.
        batch_size: Batch size of the loader built from the raw data.
        test_loader: Optional ready-made iterable of
            ``(X_batch, y_batch, group_ids_batch)`` triples.
        batch_flight_routes: Optional route strings; joined (with ``|``
            replaced by spaces) into the confusion-matrix file name.
        target_scaler: Accepted for interface compatibility; not used here.
        flag_distributed: When true, results of all ranks are gathered.
        rank: Rank of this process; only rank 0 computes metrics and saves.
        local_rank: Accepted for interface compatibility; not used here.
        world_size: Number of processes taking part in the gather.
        output_dir: Directory the CSV is written to (created if missing).
        batch_idx: Index forwarded to ``printScore_cc`` for the plot name.
        csv_file: File name of the CSV inside ``output_dir``.
        evalute_flag: Tag forwarded to ``printScore_cc`` for the plot name.
        save_mode: ``'a'`` appends to the CSV, anything else overwrites it.

    Returns:
        The results ``DataFrame`` on rank 0, or ``None`` on other ranks and
        when there is nothing to evaluate.

    Column 0 of the model output is treated as the price-drop probability
    (thresholded at 0.5), and every group tuple must provide at least the
    11 fields listed in ``group_columns`` below.
    """
    if test_loader is None:
        # Explicit length check: plain truthiness is ambiguous for arrays.
        # NOTE(review): in distributed mode an early return on one rank skips
        # the gather below -- confirm callers never hit this on a subset of ranks.
        if sequences is None or len(sequences) == 0:
            print("没有足够的数据进行评估。")
            return None
        test_dataset = FlightDataset(sequences, targets, group_ids)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

    batch_fn_str = ' '.join(route.replace('|', ' ') for route in batch_flight_routes) if batch_flight_routes else ''

    model.eval()
    pred_chunks = []
    true_chunks = []
    group_info = []
    with torch.no_grad():
        for X_batch, y_batch, group_ids_batch in test_loader:
            outputs = model(X_batch.to(device))
            # Keep results on the CPU so they do not pile up in GPU memory.
            pred_chunks.append(outputs.cpu().numpy())
            true_chunks.append(y_batch.cpu().numpy())
            group_info.extend(_group_rows(group_ids_batch))

    # --- Gather the results of every rank ---
    if flag_distributed:
        # Object gather copes with ranks holding different numbers of samples
        # (tensor all_gather needs identical shapes) and keeps the original
        # Python types of the group fields instead of stringifying them.
        gathered = [None] * world_size
        dist.all_gather_object(gathered, (pred_chunks, true_chunks, group_info))
        if rank == 0:
            pred_chunks = [chunk for part in gathered for chunk in part[0]]
            true_chunks = [chunk for part in gathered for chunk in part[1]]
            group_info = [row for part in gathered for row in part[2]]

    # --- Only rank 0 computes metrics and saves results ---
    if rank != 0:
        return None

    if not pred_chunks:
        # The loader produced no batches; concatenating would raise.
        print("没有足够的数据进行评估。")
        return None

    y_preds = np.concatenate(pred_chunks, axis=0)
    y_trues = np.concatenate(true_chunks, axis=0)

    # Classification results
    y_preds_class = y_preds[:, 0]
    y_trues_class = y_trues[:, 0]
    y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
    y_trues_class_labels = y_trues_class.astype(int)

    printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str=batch_fn_str,
                  batch_idx=batch_idx, evalute_flag=evalute_flag)

    # The last two target columns are only available for validation during training.
    group_columns = (
        'city_pair', 'flight_day', 'flight_number_1', 'flight_number_2',
        'from_date', 'baggage', 'price', 'Hours_until_Departure',
        'update_hour', 'target_amount_of_drop', 'target_time_to_drop',
    )
    frame_data = {
        name: [info[pos] for info in group_info]
        for pos, name in enumerate(group_columns)
    }
    frame_data['probability'] = y_preds_class
    frame_data['Actual_Will_Price_Drop'] = y_trues_class_labels
    frame_data['Predicted_Will_Price_Drop'] = y_preds_class_labels
    results_df = pd.DataFrame(frame_data)

    # Snap near-zero values to exactly 0, then round the target columns.
    threshold = 1e-3
    for col in ('probability', 'target_amount_of_drop', 'target_time_to_drop'):
        results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
    results_df['target_time_to_drop'] = results_df['target_time_to_drop'].round(0).astype(int)
    results_df['target_amount_of_drop'] = results_df['target_amount_of_drop'].round(2)

    # Save results
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    results_df_path = os.path.join(output_dir, csv_file)
    if save_mode == 'a':
        # Append; write the header only when the file does not exist yet.
        results_df.to_csv(results_df_path, mode='a', index=False, header=not os.path.exists(results_df_path))
    else:
        # Overwrite
        results_df.to_csv(results_df_path, mode='w', index=False, header=True)
    print(f"预测结果已保存到 '{results_df_path}'")

    return results_df
def printScore_cc(y_trues_class_labels, y_preds_class_labels, batch_fn_str='', batch_idx=-1, evalute_flag='evaluate'):
    """Print binary-classification metrics and save a confusion-matrix plot.

    Args:
        y_trues_class_labels: Ground-truth 0/1 labels.
        y_preds_class_labels: Predicted 0/1 labels.
        batch_fn_str: Text appended to the image file name.
        batch_idx: Index embedded in the image file name.
        evalute_flag: Prefix of the image file name.

    Prints accuracy, precision, recall and F1, prints a notice when no sample
    is both actually and predicted positive, and writes the heatmap to
    ``./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png``.
    """
    # Work on arrays so the element-wise comparisons below also hold for lists.
    y_true = np.asarray(y_trues_class_labels)
    y_pred = np.asarray(y_preds_class_labels)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    print(f"分类准确率: {accuracy:.4f}")
    print(f"分类精确率: {precision:.4f}")
    print(f"分类召回率: {recall:.4f}")
    print(f"分类F1值: {f1:.4f}")

    # Report when there is not a single true positive.
    if not np.any((y_true == 1) & (y_pred == 1)):
        print("没有正例")

    # Use the bundled CJK font when present; otherwise fall back to the
    # default font so saving the figure does not fail on a missing file.
    font_path = "./simhei.ttf"
    if os.path.exists(font_path):
        font_prop = font_manager.FontProperties(fname=font_path)
    else:
        font_prop = font_manager.FontProperties()

    # Fix the label set so the matrix is always 2x2 and matches the two tick
    # labels, even when only one class occurs in the data.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fig = plt.figure(figsize=(6, 5))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['预测:不会降价', '预测:会降价'],
                    yticklabels=['实际:不会降价', '实际:会降价'])
        plt.xticks(fontproperties=font_prop)
        plt.yticks(fontproperties=font_prop)
        plt.xlabel('预测情况', fontproperties=font_prop)
        plt.ylabel('实际结果', fontproperties=font_prop)
        plt.title('分类结果的混淆矩阵', fontproperties=font_prop)
        os.makedirs("./photo", exist_ok=True)
        plt.savefig(f"./photo/{evalute_flag}_confusion_matrix_{batch_idx}_{batch_fn_str}.png")
    finally:
        # Release the figure; otherwise figures accumulate across calls.
        plt.close(fig)
|