model.py

import torch
import torch.nn as nn


# Classification model (Transformer)
class PriceDropClassifiTransModel(nn.Module):
    def __init__(self, input_size, num_periods=2, hidden_size=128, num_layers=3,
                 output_size=1, dropout=0.3, conv_out_channels=64, kernel_size=3, num_heads=8):
        super(PriceDropClassifiTransModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_periods = num_periods

        # Convolutional layer
        self.conv1 = nn.Conv1d(
            in_channels=input_size * num_periods,
            out_channels=conv_out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.relu = nn.ReLU()

        # Transformer encoder
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=conv_out_channels,
            # d_model=input_size * num_periods,  # d_model must match the input feature dimension and be divisible by num_heads
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layer,
            num_layers=num_layers
        )

        # Attention mechanism over time steps
        self.attention_layer = nn.Sequential(
            nn.Linear(conv_out_channels, hidden_size),
            # nn.Linear(input_size * num_periods, hidden_size),
            # nn.Conv1d(conv_out_channels, hidden_size),
            # nn.Tanh(),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

        # Classification output layer
        self.fc_classification = nn.Linear(conv_out_channels, 1)

    def forward(self, x):
        """
        Expects x of shape [batch_size, num_periods, seq_length, input_size].
        """
        batch_size, num_periods, seq_length, input_size = x.size()

        # x = x[:, 0, :, :].view(batch_size, 1, input_size, seq_length)
        # Reshape the input to [batch_size, num_periods * input_size, seq_length]
        x = x.permute(0, 1, 3, 2).contiguous()  # [batch_size, num_periods, input_size, seq_length]
        x = x.view(batch_size, num_periods * input_size, seq_length)  # [batch_size, num_periods * input_size, seq_length]
        # x = x.view(batch_size, 1 * input_size, seq_length)

        # Convolution + activation
        x = self.conv1(x)  # [batch_size, conv_out_channels, seq_length]
        x = self.relu(x)

        # Transpose to the [seq_length, batch_size, d_model] layout the Transformer expects
        x = x.permute(2, 0, 1)  # [seq_length, batch_size, conv_out_channels]

        # Transformer encoder
        x = self.transformer_encoder(x)  # [seq_length, batch_size, conv_out_channels]

        # Attention scores and weights across time steps
        attention_scores = self.attention_layer(x)  # [seq_length, batch_size, 1]
        attention_weights = torch.softmax(attention_scores, dim=0)  # [seq_length, batch_size, 1]

        # Weighted sum over all time steps
        context_vector = torch.sum(attention_weights * x, dim=0)  # [batch_size, conv_out_channels]

        # Alternative: take only the last time step's output
        # context_vector = x[-1, :, :]  # [batch_size, conv_out_channels]

        # Classification output
        classification_output = torch.sigmoid(self.fc_classification(context_vector))  # [batch_size, 1]

        # Debug check of the output range:
        # print(f"Before clamp: min: {classification_output.min().item()}, max: {classification_output.max().item()}")
        # Clamp the output to [1e-4, 1 - 1e-4] to avoid numerical extremes
        # classification_output = torch.clamp(classification_output, min=1e-4, max=1 - 1e-4)
        return classification_output
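

# Minimal usage sketch (not part of the original file): instantiates the model with
# illustrative, assumed dimensions (input_size=16, seq_length=30, batch_size=4) and
# runs a forward pass on random data to confirm the expected [batch_size, 1] output.
if __name__ == "__main__":
    model = PriceDropClassifiTransModel(input_size=16, num_periods=2)
    dummy = torch.randn(4, 2, 30, 16)  # [batch_size, num_periods, seq_length, input_size]
    with torch.no_grad():
        probs = model(dummy)
    print(probs.shape)  # torch.Size([4, 1]), sigmoid probabilities in (0, 1)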