import datetime
import torch
import pandas as pd
import numpy as np
import os
from torch.utils.data import DataLoader
from utils import FlightDataset

# Names for the fields of each per-sample group-id tuple, in positional order
# (index 0 -> 'city_pair', ..., index 9 -> 'crawl_date'). Fields beyond the
# tenth, if any, are ignored.
_GROUP_COLUMNS = (
    'city_pair',
    'flight_day',
    'flight_number_1',
    'flight_number_2',
    'from_date',
    'baggage',
    'price',
    'Hours_until_Departure',
    'update_hour',
    'crawl_date',
)

# Probabilities whose absolute value is below this are written out as exactly 0.
_PROBABILITY_FLOOR = 1e-3

# Probability at or above which a row is labelled 1 (predicted price drop).
_DECISION_THRESHOLD = 0.5


def _model_device(model):
    """Return the device holding the model's parameters.

    Inputs must live on the same device as the model. For a parameterless
    model, fall back to CUDA when available, otherwise CPU.
    """
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _run_inference(model, loader):
    """Run ``model`` over every batch of ``loader`` without gradients.

    Returns:
        A tuple ``(y_preds, group_info)``. ``y_preds`` holds the model
        outputs for all batches, stacked along axis 0. ``group_info`` is a
        list with one group-id tuple per sample.
    """
    device = _model_device(model)
    model.eval()
    batch_outputs = []
    group_info = []
    with torch.no_grad():
        for X_batch, batch_group_ids in loader:
            outputs = model(X_batch.to(device))
            batch_outputs.append(outputs.cpu().numpy())
            # batch_group_ids is a sequence of per-field collections, each
            # with one entry per sample; zip(*...) transposes it back into
            # one tuple per sample.
            group_info.extend(zip(*batch_group_ids))
    return np.concatenate(batch_outputs, axis=0), group_info


def _build_results_frame(group_info, y_preds):
    """Assemble the prediction DataFrame from group ids and raw model outputs.

    Column 0 of ``y_preds`` is used as the price-drop probability.
    """
    probabilities = y_preds[:, 0]
    columns = {
        name: [info[idx] for info in group_info]
        for idx, name in enumerate(_GROUP_COLUMNS)
    }
    columns['probability'] = probabilities
    # Labels are derived from the raw (unclamped) probabilities.
    columns['Predicted_Will_Price_Drop'] = (
        probabilities >= _DECISION_THRESHOLD
    ).astype(int)
    # NOTE(review): numeric group-id fields (e.g. price) may arrive as 0-d
    # tensors after DataLoader collation; confirm the CSV renders them as
    # intended.
    results_df = pd.DataFrame(columns)

    # valid_begin_hour = update_hour + 8 hours, formatted as a string.
    # Per the original author's note, this corresponds to 36 hours before
    # departure (on the hour). That mapping is not verifiable here; confirm.
    # The original also mentioned a matching "28 hours before departure"
    # end time, but no such column was ever computed.
    update_hour_dt = pd.to_datetime(results_df['update_hour'])
    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=8)
    results_df.insert(
        loc=results_df.columns.get_loc('probability'),
        column='valid_begin_hour',
        value=valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S'),
    )

    # Clamp negligible probabilities (and NaN, whose comparison is False) to 0.
    prob = results_df['probability']
    results_df['probability'] = prob.where(prob.abs() >= _PROBABILITY_FLOOR, 0)
    return results_df


def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir=".", pred_time_str=""):
    """Predict price-drop probabilities and append them to a CSV file.

    Args:
        model: Model called on each input batch. Column 0 of its output is
            read as the price-drop probability.
        sequences: Input sequences passed to ``FlightDataset``. Any sized
            container works, including numpy arrays and tensors.
        group_ids: Per-sample identifiers passed to ``FlightDataset``. They
            are expected to yield the ten fields named in ``_GROUP_COLUMNS``.
        batch_size: DataLoader batch size.
        target_scaler: Currently unused; kept for interface compatibility.
        output_dir: Directory for the CSV file; created if missing.
        pred_time_str: Suffix for ``future_predictions_{pred_time_str}.csv``.

    Returns:
        The results DataFrame, or ``None`` when ``sequences`` is empty.
        A header row is written only when the CSV file does not already exist.
    """
    # Use len() instead of truthiness, because `not array` raises for
    # multi-element numpy arrays and tensors.
    if sequences is None or len(sequences) == 0:
        print("没有足够的数据进行预测。")
        return

    test_dataset = FlightDataset(sequences, None, group_ids)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    y_preds, group_info = _run_inference(model, test_loader)
    results_df = _build_results_frame(group_info, y_preds)

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    csv_path1 = os.path.join(output_dir, f'future_predictions_{pred_time_str}.csv')
    results_df.to_csv(
        csv_path1,
        mode='a',
        index=False,
        header=not os.path.exists(csv_path1),
        encoding='utf-8-sig',
    )
    print("预测结果已追加")
    return results_df