predict.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import datetime
  2. import torch
  3. import pandas as pd
  4. import numpy as np
  5. import os
  6. from torch.utils.data import DataLoader
  7. from utils import FlightDataset
  8. def predict_future_distribute(model, sequences, group_ids, batch_size=16, target_scaler=None, output_dir="."):
  9. if not sequences:
  10. print("没有足够的数据进行预测。")
  11. return
  12. test_dataset = FlightDataset(sequences, None, group_ids)
  13. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  14. model.eval()
  15. y_preds = []
  16. group_info = []
  17. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. with torch.no_grad():
  19. for X_batch, group_ids in test_loader:
  20. X_batch = X_batch.to(device)
  21. outputs = model(X_batch)
  22. y_preds.extend(outputs.cpu().numpy())
  23. for i in range(len(group_ids[0])):
  24. group_id = tuple(group_ids_elem[i] for group_ids_elem in group_ids)
  25. group_info.append(group_id)
  26. y_preds = np.array(y_preds)
  27. y_preds_class = y_preds[:, 0]
  28. y_preds_class_labels = (y_preds_class >= 0.5).astype(int)
  29. results_df = pd.DataFrame({
  30. 'city_pair': [info[0] for info in group_info],
  31. 'flight_day': [info[1] for info in group_info],
  32. 'flight_number_1': [info[2] for info in group_info],
  33. 'flight_number_2': [info[3] for info in group_info],
  34. 'from_date': [info[4] for info in group_info],
  35. 'baggage': [info[5] for info in group_info],
  36. 'price': [info[6] for info in group_info],
  37. 'Hours_until_Departure': [info[7] for info in group_info],
  38. 'update_hour': [info[8] for info in group_info],
  39. 'crawl_date': [info[9] for info in group_info],
  40. 'probability': y_preds_class,
  41. 'Predicted_Will_Price_Drop': y_preds_class_labels,
  42. })
  43. # 先转成 datetime
  44. update_hour_dt = pd.to_datetime(results_df['update_hour']) # 起飞前48小时对应时间(整点)
  45. valid_begin_dt = update_hour_dt + pd.Timedelta(hours=20) # 起飞前28小时(48-20=28)对应时间(整点)
  46. # 在 probability 前新增一列
  47. results_df.insert(
  48. loc=results_df.columns.get_loc('probability'),
  49. column='valid_begin_hour',
  50. value=valid_begin_dt.dt.strftime('%Y-%m-%d %H:%M:%S')
  51. )
  52. # 数值处理
  53. threshold = 1e-3
  54. numeric_columns = ['probability']
  55. for col in numeric_columns:
  56. results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
  57. csv_path1 = os.path.join(output_dir, 'future_predictions.csv')
  58. results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
  59. print("预测结果已追加")
  60. return results_df