# utils.py

import gc
import time

import torch
from torch.utils.data import Dataset


# Slice a list of routes into fixed-size groups, keeping each group's start index.
def chunk_list_with_index(lst, group_size):
    return [(i, lst[i:i + group_size]) for i in range(0, len(lst), group_size)]
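# Illustrative example (hypothetical values), showing the (start_index, chunk) pairs produced:
#     chunk_list_with_index([1, 2, 3, 4, 5], group_size=2)
#     -> [(0, [1, 2]), (2, [3, 4]), (4, [5])]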


# Move an existing pandas column so that it sits immediately before the given base column.
def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    if not inplace:
        df = df.copy()
    if base_col_name not in df.columns:
        raise ValueError(f"base_col_name '{base_col_name}' does not exist")
    if insert_col_name not in df.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' does not exist")
    if base_col_name == insert_col_name:
        return df
    insert_idx = df.columns.get_loc(base_col_name)
    col_data = df.pop(insert_col_name)
    df.insert(insert_idx, insert_col_name, col_data)
    return df
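# Illustrative example (hypothetical column names): for a DataFrame with columns ['a', 'b', 'c'],
#     insert_df_col(df, insert_col_name='c', base_col_name='a')
# moves column 'c' in front of 'a', giving the column order ['c', 'a', 'b'].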


# Build fixed-length input sequences (and, in train mode, targets) for each flight group.
def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
    print(">>Start creating sequences")
    start_time = time.time()
    sequences = []
    targets = []
    group_ids = []
    threshold = 28  # cutoff in hours before departure
    # gid groups by city_pair, flight_day, flight_number_1, flight_number_2 (baggage is not included)
    grouped = df.groupby(['gid'])
    for _, df_group in grouped:
        city_pair = df_group['city_pair'].iloc[0]
        flight_day = df_group['flight_day'].iloc[0]
        flight_number_1 = df_group['flight_number_1'].iloc[0]
        flight_number_2 = df_group['flight_number_2'].iloc[0]
        dep_time_str = df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S')
        # Split by baggage allowance
        df_group_bag_30 = df_group[df_group['baggage'] == 30]
        df_group_bag_20 = df_group[df_group['baggage'] == 20]
        # Keep only the training window (28 to 480 hours before departure)
        df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
        df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
        # Condition: both slices must contain exactly input_length rows
        condition_list = [
            len(df_group_bag_30_filtered) == input_length,
            len(df_group_bag_20_filtered) == input_length,
        ]
        if all(condition_list):
            seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
            seq_features_2 = df_group_bag_20_filtered[features].to_numpy()
            # Stack the two feature sequences along dim 0, giving shape (2, input_length, len(features)), e.g. (2, 452, 25)
            combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),
                                             torch.tensor(seq_features_2, dtype=torch.float32)])
            # Append the stacked result to the sequences list
            sequences.append(combined_features)
            if is_train and target_vars:
                seq_targets = df_group_bag_30_filtered[target_vars].iloc[0].to_numpy()
                targets.append(torch.tensor(seq_targets, dtype=torch.float32))
            name_c = [city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str]
            # Take the relevant fields directly from the last row
            last_row = df_group_bag_30_filtered.iloc[-1]
            next_name_li = [str(last_row['baggage']),
                            str(last_row['Adult_Total_Price']),
                            str(last_row['Hours_Until_Departure']),
                            str(last_row['update_hour']),
                            ]
            if is_train:
                next_name_li.append(last_row['target_amount_of_drop'])
                next_name_li.append(last_row['target_time_to_drop'])
            name_c.extend(next_name_li)
            group_ids.append(tuple(name_c))
        del df_group_bag_30_filtered, df_group_bag_20_filtered
        del df_group_bag_30, df_group_bag_20
        del df_group
        gc.collect()
    print(">>Finished creating sequences")
    end_time = time.time()
    run_time = round(end_time - start_time, 3)
    print(f"Elapsed: {run_time} s")
    print(f"Number of sequences generated: {len(sequences)}")
    return sequences, targets, group_ids
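# Output summary (follows from the code above): each element of `sequences` is a float32 tensor
# of shape (2, input_length, len(features)); in train mode `targets` holds one float32 tensor of
# shape (len(target_vars),) per sequence; `group_ids` holds one tuple of identifying metadata per sequence.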


class FlightDataset(Dataset):
    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
        self.X_sequences = X_sequences
        self.y_sequences = y_sequences
        self.group_ids = group_ids
        self.return_group_ids = group_ids is not None

    def __len__(self):
        return len(self.X_sequences)

    def __getitem__(self, idx):
        # Truthiness checks are used so that an empty targets list (inference mode) is treated the same as None.
        if self.return_group_ids:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx], self.group_ids[idx]
            else:
                return self.X_sequences[idx], self.group_ids[idx]
        else:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx]
            else:
                return self.X_sequences[idx]
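# Hedged usage note (the batch_size below is illustrative): the dataset is typically wrapped in a
# torch.utils.data.DataLoader, e.g. DataLoader(FlightDataset(sequences, targets, group_ids), batch_size=32, shuffle=True).
# Each batch then yields (X, y, group_id) in train mode, or (X, group_id) when no targets are supplied.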


class EarlyStoppingDist:
    """Early stopping (distributed)."""

    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pth', rank=0, local_rank=0):
        """
        Args:
            patience (int): How many epochs to wait after the training (or validation) loss stops improving before stopping.
            verbose (bool): Whether to print progress messages.
            delta (float): Minimum decrease in the loss required to count as an improvement.
            path (str): Path where the best model is saved.
            rank (int): Global process rank; messages are printed only on rank 0.
            local_rank (int): Local process rank on the node.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.rank = rank
        self.local_rank = local_rank

    def __call__(self, loss, model):
        if self.best_loss is None:
            self.best_loss = loss
            self.save_checkpoint(loss, model)
        elif loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose and self.rank == 0:
                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(loss, model)
            self.best_loss = loss
            self.counter = 0
        if self.is_nan(loss):
            self.counter += self.patience  # trigger early stopping immediately
            self.early_stop = True

    def is_nan(self, loss):
        """Check whether the loss is NaN (works for plain floats and tensors)."""
        try:
            # NaN is the only value that is not equal to itself.
            return loss != loss
        except Exception:
            # Types that do not support comparison are treated as not NaN.
            return False

    def save_checkpoint(self, loss, model):
        """Save the model."""
        if self.verbose and self.rank == 0:
            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
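

if __name__ == '__main__':
    # Minimal smoke-test sketch with toy data. The shapes, identifiers, loss values, and the
    # 'demo_model.pth' path below are hypothetical and not part of the original training
    # pipeline; they only exercise FlightDataset and EarlyStoppingDist.
    from torch.utils.data import DataLoader

    X = [torch.zeros(2, 4, 3) for _ in range(6)]   # 6 fake sequences of shape (2, 4, 3)
    y = [torch.zeros(2) for _ in range(6)]          # 6 fake two-element targets
    gids = [('CAN-PEK', '2024-01-01', 'F1', 'F2', '2024-01-01 08:00:00') for _ in range(6)]

    loader = DataLoader(FlightDataset(X, y, gids), batch_size=2, shuffle=False)
    for batch_x, batch_y, batch_gid in loader:
        print(batch_x.shape, batch_y.shape, len(batch_gid))

    stopper = EarlyStoppingDist(patience=2, verbose=True, path='demo_model.pth')
    model = torch.nn.Linear(3, 1)
    for fake_loss in [1.0, 0.9, 0.95, 0.96]:
        stopper(fake_loss, model)
        if stopper.early_stop:
            print('Early stopping triggered.')
            break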