import gc
import time

import torch
from torch.utils.data import Dataset


def chunk_list_with_index(lst, group_size):
    """Split ``lst`` into consecutive slices and keep each slice's start index.

    Args:
        lst: Any sliceable sequence (e.g. a list of routes).
        group_size: Maximum number of items per slice.

    Returns:
        A list of ``(start_index, slice)`` tuples; the last slice may be
        shorter than ``group_size``.
    """
    return [(i, lst[i:i + group_size]) for i in range(0, len(lst), group_size)]


def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    """Move an existing column so it sits directly before another column.

    Args:
        df: pandas DataFrame that contains both columns.
        insert_col_name: Name of the column to move.
        base_col_name: Name of the column the moved column is placed before.
        inplace: When true (default) ``df`` itself is modified; otherwise a
            copy is modified and the caller's frame is left untouched.

    Returns:
        The modified DataFrame (``df`` itself when ``inplace`` is true).

    Raises:
        ValueError: If either column name is not present in ``df``.
    """
    if not inplace:
        df = df.copy()

    if base_col_name not in df.columns:
        raise ValueError(f"base_col_name '{base_col_name}' 不存在")
    if insert_col_name not in df.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' 不存在")

    if base_col_name == insert_col_name:
        return df

    # Pop first, then look up the target position: removing a column that
    # sits left of the base column shifts the base column by one, so an
    # index taken before the pop would place the column *after* the base.
    col_data = df.pop(insert_col_name)
    insert_idx = df.columns.get_loc(base_col_name)
    df.insert(insert_idx, insert_col_name, col_data)
    return df


def _select_baggage_window(df_group, baggage, lower, upper):
    """Return rows of one baggage allowance with ``lower <= hours < upper``."""
    rows = df_group[df_group['baggage'] == baggage]
    hours = rows['Hours_Until_Departure']
    return rows[(hours >= lower) & (hours < upper)]


def create_fixed_length_sequences(df, features, target_vars, input_length=452,
                                  is_train=True, threshold=28):
    """Build fixed-length feature sequences, one per ``gid`` group.

    For every ``gid`` group the rows are split by baggage allowance (30 and
    20) and restricted to ``threshold <= Hours_Until_Departure <
    threshold + input_length``. A group is kept only when both baggage
    subsets contain exactly ``input_length`` rows in that window.

    Args:
        df: DataFrame holding at least the columns ``gid``, ``city_pair``,
            ``flight_day``, ``flight_number_1``, ``flight_number_2``,
            ``dep_time_1`` (values must support ``strftime``), ``baggage``,
            ``Hours_Until_Departure``, ``Adult_Total_Price``, ``update_hour``
            plus everything named in ``features`` / ``target_vars``.
        features: Column names used as model inputs.
        target_vars: Column names used as targets; may be empty/``None``.
        input_length: Required number of rows per baggage subset.
        is_train: When true and ``target_vars`` is non-empty, targets are
            taken from the first window row of the baggage-30 subset.
        threshold: Lower bound (inclusive) of ``Hours_Until_Departure``;
            defaults to the previously hard-coded value of 28.

    Returns:
        ``(sequences, targets, group_ids)``:
        * ``sequences`` - list of float32 tensors of shape
          ``(2, input_length, len(features))`` (baggage 30 first, then 20);
        * ``targets`` - list of float32 tensors (empty when not training);
        * ``group_ids`` - list of tuples describing each kept group.
    """
    print(">>开始创建序列")
    start_time = time.time()

    sequences = []
    targets = []
    group_ids = []
    upper = threshold + input_length

    # gid is built from city_pair, flight_day, flight_number_1 and
    # flight_number_2; baggage is deliberately not part of it.
    for _, df_group in df.groupby(['gid']):
        bag_30 = _select_baggage_window(df_group, 30, threshold, upper)
        bag_20 = _select_baggage_window(df_group, 20, threshold, upper)

        # Both baggage subsets must fill the whole window.
        if len(bag_30) != input_length or len(bag_20) != input_length:
            continue

        # Stack the two feature matrices along a new leading dimension.
        sequences.append(torch.stack([
            torch.tensor(bag_30[features].to_numpy(), dtype=torch.float32),
            torch.tensor(bag_20[features].to_numpy(), dtype=torch.float32),
        ]))

        if is_train and target_vars:
            seq_targets = bag_30[target_vars].iloc[0].to_numpy()
            targets.append(torch.tensor(seq_targets, dtype=torch.float32))

        # Identify the group by its first row plus the last window row of
        # the baggage-30 subset.
        first_row = df_group.iloc[0]
        last_row = bag_30.iloc[-1]
        group_ids.append((
            first_row['city_pair'],
            first_row['flight_day'],
            first_row['flight_number_1'],
            first_row['flight_number_2'],
            df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S'),
            str(last_row['baggage']),
            str(last_row['Adult_Total_Price']),
            str(last_row['Hours_Until_Departure']),
            str(last_row['update_hour']),
        ))

    # Release the per-group intermediates in one pass after the loop.
    gc.collect()

    print(">>结束创建序列")
    run_time = round(time.time() - start_time, 3)
    print(f"用时: {run_time} 秒")
    print(f"生成的序列数量:{len(sequences)}")

    return sequences, targets, group_ids


class FlightDataset(Dataset):
    """Dataset over pre-built sequences with optional targets and group ids.

    ``__getitem__`` returns, depending on what was supplied:
    ``X``, ``(X, y)``, ``(X, group_id)`` or ``(X, y, group_id)``.
    """

    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
        """Store the sequences.

        Args:
            X_sequences: Indexable collection of input sequences.
            y_sequences: Optional indexable collection of targets; ``None``
                or an empty collection means "no targets".
            group_ids: Optional indexable collection of group identifiers.
        """
        self.X_sequences = X_sequences
        self.y_sequences = y_sequences
        self.group_ids = group_ids
        self.return_group_ids = group_ids is not None

    @property
    def _has_targets(self):
        """Whether targets are available.

        Uses an explicit length check rather than truthiness so that a
        tensor or ndarray of targets does not raise an "ambiguous truth
        value" error, while an empty list still counts as "no targets".
        """
        return self.y_sequences is not None and len(self.y_sequences) > 0

    def __len__(self):
        return len(self.X_sequences)

    def __getitem__(self, idx):
        item = [self.X_sequences[idx]]
        if self._has_targets:
            item.append(self.y_sequences[idx])
        if self.return_group_ids:
            item.append(self.group_ids[idx])
        # A lone X is returned bare, everything else as a tuple.
        return item[0] if len(item) == 1 else tuple(item)


class EarlyStoppingDist:
    """Early stopping helper (distributed variant)."""

    def __init__(self, patience=10, verbose=False, delta=0,
                 path='best_model.pth', rank=0, local_rank=0):
        """
        Args:
            patience (int): Number of non-improving calls tolerated before
                ``early_stop`` is set.
            verbose (bool): Whether to print progress messages (rank 0 only).
            delta (float): Minimum decrease of the loss that counts as an
                improvement.
            path (str): Where the best model's ``state_dict`` is saved.
            rank (int): Process rank; only used to gate and label messages.
            local_rank (int): Local rank; only used to label messages.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.rank = rank
        self.local_rank = local_rank

    def __call__(self, loss, model):
        """Record ``loss`` and checkpoint ``model`` when it improved.

        A NaN loss always triggers early stopping and never replaces an
        already recorded best loss/checkpoint.
        """
        nan_loss = self.is_nan(loss)

        if self.best_loss is None:
            # First call: always record and checkpoint so ``path`` exists.
            self.best_loss = loss
            self.save_checkpoint(loss, model)
        elif nan_loss:
            # NaN fails every comparison and would otherwise fall into the
            # "improved" branch, overwriting the best checkpoint.
            pass
        elif loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose and self.rank == 0:
                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # save_checkpoint reads the previous best_loss for its message,
            # so it must run before best_loss is updated.
            self.save_checkpoint(loss, model)
            self.best_loss = loss
            self.counter = 0

        if nan_loss:
            self.counter += self.patience  # trigger early stopping at once
            self.early_stop = True

    def is_nan(self, loss):
        """Return True when ``loss`` is NaN (generic: NaN != NaN)."""
        try:
            # bool() inside the try so that values whose comparison result
            # has no single truth value are reported as "not NaN".
            return bool(loss != loss)
        except Exception:
            # Types that do not support the comparison are treated as not NaN.
            return False

    def save_checkpoint(self, loss, model):
        """Save ``model.state_dict()`` to ``self.path``.

        NOTE(review): every rank that calls this writes to the same path;
        confirm callers only invoke it where that is intended.
        """
        if self.verbose and self.rank == 0:
            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)