utils.py

import gc
import time

import torch
from torch.utils.data import Dataset


# Split a route list into fixed-size chunks, attaching each chunk's starting index.
def chunk_list_with_index(lst, group_size):
    return [(i, lst[i:i + group_size]) for i in range(0, len(lst), group_size)]
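

# Illustrative usage (a sketch; the route codes below are made-up examples):
#   chunk_list_with_index(['PEK-SHA', 'PEK-CAN', 'PEK-SZX', 'PEK-CTU', 'PEK-XIY'], 2)
#   -> [(0, ['PEK-SHA', 'PEK-CAN']), (2, ['PEK-SZX', 'PEK-CTU']), (4, ['PEK-XIY'])]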


# Move an existing pandas column so that it sits immediately before base_col_name.
def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    if not inplace:
        df = df.copy()
    if base_col_name not in df.columns:
        raise ValueError(f"base_col_name '{base_col_name}' does not exist")
    if insert_col_name not in df.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' does not exist")
    if base_col_name == insert_col_name:
        return df
    insert_idx = df.columns.get_loc(base_col_name)
    col_data = df.pop(insert_col_name)
    df.insert(insert_idx, insert_col_name, col_data)
    return df
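

# Illustrative usage (a sketch, assuming `import pandas as pd`; the column names are made up):
#   df = pd.DataFrame({'a': [1], 'b': [2], 'c': [3]})
#   insert_df_col(df, insert_col_name='c', base_col_name='a')
#   list(df.columns) -> ['c', 'a', 'b']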


# The actual sequence-building step.
def create_fixed_length_sequences(df, features, target_vars, threshold=48, input_length=432, is_train=True):
    print(">> Start building sequences")
    start_time = time.time()
    sequences = []
    targets = []
    group_ids = []
    # 'gid' groups by city_pair, flight_day, flight_number_1 and flight_number_2; baggage is not part of the key.
    grouped = df.groupby(['gid'])
    for _, df_group in grouped:
        city_pair = df_group['city_pair'].iloc[0]
        flight_day = df_group['flight_day'].iloc[0]
        flight_number_1 = df_group['flight_number_1'].iloc[0]
        flight_number_2 = df_group['flight_number_2'].iloc[0]
        dep_time_str = df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S')
        # Split by baggage allowance.
        df_group_bag_30 = df_group[df_group['baggage'] == 30]
        df_group_bag_20 = df_group[df_group['baggage'] == 20]
        # Keep only the training window (threshold to threshold + input_length hours before departure, i.e. 48-480 by default).
        df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
        df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
        # Both filtered series must span the full input_length.
        condition_list = [
            len(df_group_bag_30_filtered) == input_length,
            len(df_group_bag_20_filtered) == input_length,
        ]
        if all(condition_list):
            seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
            seq_features_2 = df_group_bag_20_filtered[features].to_numpy()
            # Stack the two feature sequences along dim 0, giving shape (2, input_length, n_features), e.g. (2, 432, 25).
            combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),
                                             torch.tensor(seq_features_2, dtype=torch.float32)])
            # Append the stacked tensor to the sequences list.
            sequences.append(combined_features)
            if is_train and target_vars:
                seq_targets = df_group_bag_30_filtered[target_vars].iloc[0].to_numpy()
                targets.append(torch.tensor(seq_targets, dtype=torch.float32))
            name_c = [city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str]
            # Take the remaining metadata directly from the last row of the window.
            last_row = df_group_bag_30_filtered.iloc[-1]
            next_name_li = [str(last_row['baggage']),
                            str(last_row['Adult_Total_Price']),
                            str(last_row['Hours_Until_Departure']),
                            str(last_row['update_hour']),
                            ]
            if is_train:
                next_name_li.append(last_row['target_amount_of_drop'])
                next_name_li.append(last_row['target_time_to_drop'])
            name_c.extend(next_name_li)
            group_ids.append(tuple(name_c))
        del df_group_bag_30_filtered, df_group_bag_20_filtered
        del df_group_bag_30, df_group_bag_20
        del df_group
        gc.collect()
    print(">> Finished building sequences")
    end_time = time.time()
    run_time = round(end_time - start_time, 3)
    print(f"Elapsed: {run_time} s")
    print(f"Number of sequences generated: {len(sequences)}")
    return sequences, targets, group_ids
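

# Illustrative call (a sketch: the two feature columns below are only examples, the
# real list holds all model inputs, and `df` is the prepared flight dataframe):
#   features = ['Adult_Total_Price', 'Hours_Until_Departure']  # plus the remaining feature columns
#   target_vars = ['target_amount_of_drop', 'target_time_to_drop']
#   sequences, targets, group_ids = create_fixed_length_sequences(
#       df, features, target_vars, threshold=48, input_length=432, is_train=True)
#   # each element of `sequences` is a tensor of shape (2, 432, len(features))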


class FlightDataset(Dataset):
    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
        self.X_sequences = X_sequences
        self.y_sequences = y_sequences
        self.group_ids = group_ids
        self.return_group_ids = group_ids is not None

    def __len__(self):
        return len(self.X_sequences)

    def __getitem__(self, idx):
        if self.return_group_ids:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx], self.group_ids[idx]
            else:
                return self.X_sequences[idx], self.group_ids[idx]
        else:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx]
            else:
                return self.X_sequences[idx]
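

# Illustrative wiring into a DataLoader (a sketch; the batch size is arbitrary and the
# default collate assumes the group-id tuples contain collatable values):
#   from torch.utils.data import DataLoader
#   train_ds = FlightDataset(sequences, targets, group_ids)
#   train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
#   for X_batch, y_batch, gid_batch in train_loader:
#       pass  # X_batch has shape (batch_size, 2, input_length, n_features)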


class EarlyStoppingDist:
    """Early stopping (aware of distributed training ranks)."""

    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pth', rank=0, local_rank=0):
        """
        Args:
            patience (int): number of epochs to wait after the (training or validation) loss stops improving before stopping
            verbose (bool): whether to print progress messages
            delta (float): minimum decrease in the loss to count as an improvement
            path (str): path used to save the best model
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.rank = rank
        self.local_rank = local_rank

    def __call__(self, loss, model):
        # A NaN loss triggers early stopping immediately and must not overwrite the saved best model.
        if self.is_nan(loss):
            self.counter += self.patience
            self.early_stop = True
            return
        if self.best_loss is None:
            self.best_loss = loss
            self.save_checkpoint(loss, model)
        elif loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose and self.rank == 0:
                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(loss, model)
            self.best_loss = loss
            self.counter = 0

    def is_nan(self, loss):
        """Check whether the loss is NaN (works for plain floats and scalar tensors)."""
        try:
            # NaN is never equal to itself, regardless of type.
            return loss != loss
        except Exception:
            # Types that do not support comparison are treated as non-NaN.
            return False

    def save_checkpoint(self, loss, model):
        """Save the model weights."""
        if self.verbose and self.rank == 0:
            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
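

# Minimal self-check (a sketch, not part of the training pipeline): exercises
# EarlyStoppingDist with a dummy model and a made-up loss curve. The checkpoint
# path and the toy model below are assumptions for this demo only.
if __name__ == '__main__':
    import torch.nn as nn

    demo_model = nn.Linear(4, 2)
    stopper = EarlyStoppingDist(patience=2, verbose=True, path='demo_best_model.pth')
    for demo_loss in [1.0, 0.8, 0.9, 0.95, 0.99]:
        stopper(demo_loss, demo_model)
        if stopper.early_stop:
            print('Early stopping triggered')
            break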