| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import datetime
- import torch
- import pandas as pd
- import numpy as np
- import os
- from torch.utils.data import DataLoader
- from utils import FlightDataset
def _model_device(model):
    """Return the device that holds *model*'s parameters.

    Inputs must live on the same device as the model. If the model exposes no
    parameters, fall back to CUDA when available, otherwise CPU (the original
    selection rule).
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _is_empty(sequences):
    """Return True when *sequences* is None or has no elements.

    Uses ``len`` rather than truthiness so numpy arrays work; ``not array`` is
    ambiguous or an error for arrays.
    """
    if sequences is None:
        return True
    try:
        return len(sequences) == 0
    except TypeError:
        # Objects without __len__: fall back to plain truthiness.
        return not sequences


def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir=".", pred_time_str=""):
    """Predict price drops for *sequences* and append the results to a CSV file.

    Args:
        model: torch module called as ``model(X_batch)``. Column 0 of its
            output is used as the price-drop probability (assumes a 2-D
            ``(batch, k)`` output -- TODO confirm against the model).
        sequences: model inputs, wrapped in ``FlightDataset``. If None or
            empty, nothing is predicted.
        group_ids: per-sequence identifier tuples passed to ``FlightDataset``.
            Their first ten fields become the leading CSV columns (see
            ``group_columns`` below).
        batch_size: DataLoader batch size.
        target_scaler: currently unused; kept for interface compatibility.
        output_dir: directory for the CSV file; created if it does not exist.
        pred_time_str: suffix for the file name
            ``future_predictions_{pred_time_str}.csv``.

    Returns:
        The DataFrame that was appended to the CSV, or None when there was
        nothing to predict.
    """
    if _is_empty(sequences):
        print("没有足够的数据进行预测。")
        return None

    # Names of the leading group-id fields, in tuple order.
    group_columns = (
        'city_pair', 'flight_day', 'flight_number_1', 'flight_number_2',
        'from_date', 'baggage', 'price', 'Hours_until_Departure',
        'update_hour', 'crawl_date',
    )

    test_dataset = FlightDataset(sequences, None, group_ids)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    model.eval()
    # Move inputs to wherever the model actually lives, not simply "CUDA if
    # available" (that crashes when the model was left on the CPU).
    device = _model_device(model)

    y_preds = []
    group_info = []
    with torch.no_grad():
        for X_batch, batch_group_ids in test_loader:
            outputs = model(X_batch.to(device))
            y_preds.extend(outputs.cpu().numpy())
            # Collation yields one sequence per id field; transpose it back
            # to one identifier tuple per sample.
            group_info.extend(zip(*batch_group_ids))

    if not group_info:
        # The dataset produced no batches: treat it like empty input.
        print("没有足够的数据进行预测。")
        return None

    y_preds = np.array(y_preds)
    y_preds_class = y_preds[:, 0]
    y_preds_class_labels = (y_preds_class >= 0.5).astype(int)

    columns = {
        name: [info[idx] for info in group_info]
        for idx, name in enumerate(group_columns)
    }
    columns['probability'] = y_preds_class
    columns['Predicted_Will_Price_Drop'] = y_preds_class_labels
    results_df = pd.DataFrame(columns)

    # Per the original comments: update_hour is the (on-the-hour) time 36 h
    # before departure; adding 8 h gives the time 28 h before departure,
    # which is reported as the start of the prediction's validity window.
    update_hour_dt = pd.to_datetime(results_df['update_hour'])
    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=8)
    # Insert the new column immediately before 'probability'.
    results_df.insert(
        loc=results_df.columns.get_loc('probability'),
        column='valid_begin_hour',
        value=valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S'),
    )

    # Clamp tiny values (|x| < threshold, or NaN) to exactly 0. Labels were
    # derived above, before this clamping.
    threshold = 1e-3
    for col in ('probability',):
        results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, f'future_predictions_{pred_time_str}.csv')
    # Append mode: write the header only when the file is created.
    write_header = not os.path.exists(csv_path)
    results_df.to_csv(csv_path, mode='a', index=False, header=write_header, encoding='utf-8-sig')
    print("预测结果已追加")
    return results_df
|