| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import datetime
- import torch
- import pandas as pd
- import numpy as np
- import os
- from torch.utils.data import DataLoader
- from utils import FlightDataset
# Names of the per-sample group-id fields, in the positional order the loader
# yields them; each becomes one leading column of the results frame.
_GROUP_ID_COLUMNS = (
    'city_pair',
    'flight_day',
    'flight_number_1',
    'flight_number_2',
    'from_date',
    'baggage',
    'price',
    'Hours_until_Departure',
    'update_hour',
    'crawl_date',
)


def _infer_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    # Inputs must live where the weights live; choosing "cuda if available"
    # independently of the model crashes when the model was left on CPU.
    first_param = None
    if hasattr(model, 'parameters'):
        first_param = next(iter(model.parameters()), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _to_python_column(column):
    """Convert one collated group-id field into a plain Python sequence."""
    # When numeric fields are collated into tensors (PyTorch's default collate
    # does this), unwrap them so the DataFrame/CSV hold plain numbers rather
    # than 0-d tensor objects. Non-tensor columns (e.g. lists of str) pass through.
    if isinstance(column, torch.Tensor):
        return column.tolist()
    return column


def _run_inference(model, loader, device):
    """Run ``model`` over ``loader``; return (2-D output array, list of group-id tuples)."""
    output_batches = []
    group_info = []
    model.eval()
    with torch.no_grad():
        # assumes each batch is (inputs, list of group-id fields), one field per
        # column, each of batch length — TODO confirm against FlightDataset.
        for X_batch, batch_group_ids in loader:
            outputs = model(X_batch.to(device))
            output_batches.append(outputs.cpu().numpy())
            columns = [_to_python_column(col) for col in batch_group_ids]
            # Transpose field-major batch columns into one tuple per sample.
            group_info.extend(zip(*columns))
    if not output_batches:
        return np.empty((0, 1)), group_info
    preds = np.concatenate(output_batches, axis=0)
    # Normalise to (n_samples, n_outputs) so 1-D model outputs also work.
    return preds.reshape(len(preds), -1), group_info


def _build_results_frame(preds, group_info, interval_hours):
    """Assemble the per-sample results DataFrame from model outputs and group ids."""
    # Column 0 of the model output is treated as the price-drop probability;
    # presumably the model already applies a sigmoid — TODO confirm.
    probability = preds[:, 0]
    labels = (probability >= 0.5).astype(int)

    results_df = pd.DataFrame({
        name: [info[idx] for info in group_info]
        for idx, name in enumerate(_GROUP_ID_COLUMNS)
    })

    # Per the original author's notes: update_hour is an on-the-hour timestamp
    # before departure, and update_hour + interval_hours is meant to land 28h
    # before departure (36-8, 32-4, 30-2). The validity window then spans 24h,
    # ending 4h before departure.
    update_hour_dt = pd.to_datetime(results_df['update_hour'])
    valid_begin_dt = update_hour_dt + pd.Timedelta(hours=interval_hours)
    valid_end_dt = valid_begin_dt + pd.Timedelta(hours=24)

    # Assignment order fixes the output column order: group ids, then
    # interval_hours, valid_begin_hour, valid_end_hour, probability, label.
    results_df['interval_hours'] = interval_hours
    results_df['valid_begin_hour'] = valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
    results_df['valid_end_hour'] = valid_end_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
    results_df['probability'] = probability
    results_df['Predicted_Will_Price_Drop'] = labels

    # Flush tiny probabilities to exactly 0 (NaN also becomes 0, since the
    # comparison is False for NaN).
    threshold = 1e-3
    prob = results_df['probability']
    results_df['probability'] = prob.where(prob.abs() >= threshold, 0)
    return results_df


def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, interval_hours=8, predict_dir=".", pred_time_str=""):
    """Predict price-drop probabilities for ``sequences`` and append them to a CSV.

    Builds a ``FlightDataset`` from ``sequences`` and ``group_ids`` (no
    targets). Runs ``model`` in eval mode without gradients on the device
    holding its parameters. Appends one row per sample to
    ``<predict_dir>/future_predictions_<pred_time_str>.csv``; the header is
    written only when that file is new or empty.

    Args:
        model: torch module producing one output row per sample; column 0 is
            used as the price-drop probability.
        sequences: model inputs, passed straight to ``FlightDataset``.
        group_ids: identifier data passed to ``FlightDataset``; the loader is
            expected to yield at least ten fields per sample, emitted as the
            ``_GROUP_ID_COLUMNS`` columns (presumably one tuple per sample —
            TODO confirm).
        batch_size: DataLoader batch size.
        target_scaler: unused here; kept for interface compatibility.
        interval_hours: hours added to ``update_hour`` to get
            ``valid_begin_hour``; also written as the ``interval_hours`` column.
        predict_dir: output directory (created if missing).
        pred_time_str: suffix used in the CSV file name.

    Returns:
        The results DataFrame, or None when ``sequences`` is empty or None.
    """
    # len()-based check so numpy-array inputs don't raise on truthiness;
    # fall back to truthiness for objects without len() (including None).
    try:
        is_empty = len(sequences) == 0
    except TypeError:
        is_empty = not sequences
    if is_empty:
        print("没有足够的数据进行预测。")
        return None

    test_dataset = FlightDataset(sequences, None, group_ids)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    device = _infer_device(model)
    preds, group_info = _run_inference(model, test_loader, device)
    results_df = _build_results_frame(preds, group_info, interval_hours)

    if predict_dir:
        os.makedirs(predict_dir, exist_ok=True)
    csv_path = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
    # Write the header for a new file and also for an existing empty one, so
    # appends never produce a headerless CSV.
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    results_df.to_csv(csv_path, mode='a', index=False, header=write_header, encoding='utf-8-sig')
    print("预测结果已追加")
    return results_df
|