predict.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import datetime
  2. import torch
  3. import pandas as pd
  4. import numpy as np
  5. import os
  6. from torch.utils.data import DataLoader
  7. from utils import FlightDataset
  8. def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, interval_hours=8, predict_dir=".", pred_time_str=""):
  9. if not sequences:
  10. print("没有足够的数据进行预测。")
  11. return
  12. test_dataset = FlightDataset(sequences, None, group_ids)
  13. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  14. model.eval()
  15. y_preds = []
  16. group_info = []
  17. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. with torch.no_grad():
  19. for X_batch, group_ids in test_loader:
  20. X_batch = X_batch.to(device)
  21. outputs = model(X_batch)
  22. y_preds.extend(outputs.cpu().numpy())
  23. for i in range(len(group_ids[0])):
  24. group_id = tuple(group_ids_elem[i] for group_ids_elem in group_ids)
  25. group_info.append(group_id)
  26. y_preds = np.array(y_preds)
  27. y_preds_class = y_preds[:, 0]
  28. y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
  29. results_df = pd.DataFrame({
  30. 'city_pair': [info[0] for info in group_info],
  31. 'flight_day': [info[1] for info in group_info],
  32. 'flight_number_1': [info[2] for info in group_info],
  33. 'flight_number_2': [info[3] for info in group_info],
  34. 'from_date': [info[4] for info in group_info],
  35. 'baggage': [info[5] for info in group_info],
  36. 'price': [info[6] for info in group_info],
  37. 'Hours_until_Departure': [info[7] for info in group_info],
  38. 'update_hour': [info[8] for info in group_info],
  39. 'crawl_date': [info[9] for info in group_info],
  40. 'probability': y_preds_class,
  41. 'Predicted_Will_Price_Drop': y_preds_class_labels,
  42. })
  43. # 先转成 datetime
  44. update_hour_dt = pd.to_datetime(results_df['update_hour']) # 起飞前36小时对应时间(整点)
  45. valid_begin_dt = update_hour_dt + pd.Timedelta(hours=interval_hours) # 起飞前28小时(36-8=28)(32-4=28)(30-2=28)对应时间(整点)
  46. valid_end_dt = valid_begin_dt + pd.Timedelta(hours=24) # 起飞前4小时(28-24=4)
  47. # 统一格式化
  48. valid_begin_str = valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
  49. valid_end_str = valid_end_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
  50. # probability 列的位置
  51. prob_col_idx = results_df.columns.get_loc('probability')
  52. # interval_hours(统一数值)
  53. results_df.insert(
  54. loc=prob_col_idx,
  55. column='interval_hours',
  56. value=interval_hours
  57. )
  58. # valid_begin_hour
  59. results_df.insert(
  60. loc=prob_col_idx + 1, # 原 probability 列的位置 加1
  61. column='valid_begin_hour',
  62. value=valid_begin_str
  63. )
  64. # valid_end_hour
  65. results_df.insert(
  66. loc=prob_col_idx + 2, # 原 probability 列的位置 加2
  67. column='valid_end_hour',
  68. value=valid_end_str
  69. )
  70. # 数值处理
  71. threshold = 1e-3
  72. numeric_columns = ['probability']
  73. for col in numeric_columns:
  74. results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
  75. csv_path1 = os.path.join(predict_dir, f'future_predictions_{pred_time_str}.csv')
  76. results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
  77. print("预测结果已追加")
  78. return results_df