import gc
import time

import torch
from torch.utils.data import Dataset


# Split a list of routes into fixed-size chunks, keeping the start index of each chunk
def chunk_list_with_index(lst, group_size):
    return [(i, lst[i:i + group_size]) for i in range(0, len(lst), group_size)]
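
# Usage sketch (the route codes below are made-up placeholders, not real data):
#   chunk_list_with_index(['PEK-SHA', 'CAN-CTU', 'SZX-XIY', 'KMG-HGH', 'PVG-CKG'], 2)
#   -> [(0, ['PEK-SHA', 'CAN-CTU']), (2, ['SZX-XIY', 'KMG-HGH']), (4, ['PVG-CKG'])]
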
# Move an existing pandas column so that it sits immediately before a given base column
def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    if not inplace:
        df = df.copy()
    if base_col_name not in df.columns:
        raise ValueError(f"base_col_name '{base_col_name}' does not exist")
    if insert_col_name not in df.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' does not exist")
    if base_col_name == insert_col_name:
        return df
    # Pop first, then look up the base column's position, so the index stays correct
    # even when insert_col_name originally sits before base_col_name
    col_data = df.pop(insert_col_name)
    insert_idx = df.columns.get_loc(base_col_name)
    df.insert(insert_idx, insert_col_name, col_data)
    return df
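
# Usage sketch for insert_df_col on a toy frame (assumes pandas is imported as pd elsewhere;
# the column names here are placeholders, not the real dataset's columns):
#   df = pd.DataFrame({'a': [1], 'b': [2], 'c': [3]})
#   insert_df_col(df, insert_col_name='c', base_col_name='a')
#   list(df.columns)  # -> ['c', 'a', 'b']
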
# Build the fixed-length sequences
def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
    print(">> Start building sequences")
    start_time = time.time()
    sequences = []
    targets = []
    group_ids = []
    threshold = 28  # cut-off in hours before departure
    # gid groups by city_pair, flight_day, flight_number_1 and flight_number_2; baggage is excluded
    grouped = df.groupby(['gid'])
    for _, df_group in grouped:
        city_pair = df_group['city_pair'].iloc[0]
        flight_day = df_group['flight_day'].iloc[0]
        flight_number_1 = df_group['flight_number_1'].iloc[0]
        flight_number_2 = df_group['flight_number_2'].iloc[0]
        dep_time_str = df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S')

        # Split by baggage allowance
        df_group_bag_30 = df_group[df_group['baggage'] == 30]
        df_group_bag_20 = df_group[df_group['baggage'] == 20]
        # Keep the training window: threshold <= Hours_Until_Departure < threshold + input_length (28 to 480)
        df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
        df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
        # Condition: both filtered slices must be exactly input_length rows long
        condition_list = [
            len(df_group_bag_30_filtered) == input_length,
            len(df_group_bag_20_filtered) == input_length,
        ]
        if all(condition_list):
            seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
            seq_features_2 = df_group_bag_20_filtered[features].to_numpy()

            # Stack the two feature sequences along dim 0, giving shape (2, input_length, n_features), e.g. (2, 452, 25)
            combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),
                                             torch.tensor(seq_features_2, dtype=torch.float32)])
            # Append the stacked tensor to the sequence list
            sequences.append(combined_features)
            if is_train and target_vars:
                seq_targets = df_group_bag_30_filtered[target_vars].iloc[0].to_numpy()
                targets.append(torch.tensor(seq_targets, dtype=torch.float32))

            name_c = [city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str]
            # Take the remaining metadata straight from the last row
            last_row = df_group_bag_30_filtered.iloc[-1]
            next_name_li = [str(last_row['baggage']),
                            str(last_row['Adult_Total_Price']),
                            str(last_row['Hours_Until_Departure']),
                            str(last_row['update_hour']),
                            ]
            if is_train:
                next_name_li.append(last_row['target_amount_of_drop'])
                next_name_li.append(last_row['target_time_to_drop'])
            name_c.extend(next_name_li)
            group_ids.append(tuple(name_c))

        # Release the per-group frames before moving on to the next group
        del df_group_bag_30_filtered, df_group_bag_20_filtered
        del df_group_bag_30, df_group_bag_20
        del df_group
        gc.collect()
    print(">> Finished building sequences")
    end_time = time.time()
    run_time = round(end_time - start_time, 3)
    print(f"Elapsed: {run_time} s")
    print(f"Number of sequences generated: {len(sequences)}")

    return sequences, targets, group_ids
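
# Illustrative sketch only: the feature/target lists below are assumptions about how
# create_fixed_length_sequences might be called, not the pipeline's actual configuration.
def _example_build_sequences(df):
    features = ['Adult_Total_Price', 'Hours_Until_Departure']        # assumed feature columns
    target_vars = ['target_amount_of_drop', 'target_time_to_drop']   # assumed target columns
    # Each returned sequence has shape (2, input_length, len(features)):
    # one slice per baggage allowance (30 and 20).
    return create_fixed_length_sequences(df, features, target_vars,
                                         input_length=452, is_train=True)
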
class FlightDataset(Dataset):
    """Wraps the sequence tensors (and, optionally, targets and group ids) for a DataLoader."""

    def __init__(self, X_sequences, y_sequences=None, group_ids=None):
        self.X_sequences = X_sequences
        self.y_sequences = y_sequences
        self.group_ids = group_ids
        self.return_group_ids = group_ids is not None

    def __len__(self):
        return len(self.X_sequences)

    def __getitem__(self, idx):
        # Return (X, y, group_id), (X, group_id), (X, y) or X depending on what was provided
        if self.return_group_ids:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx], self.group_ids[idx]
            else:
                return self.X_sequences[idx], self.group_ids[idx]
        else:
            if self.y_sequences:
                return self.X_sequences[idx], self.y_sequences[idx]
            else:
                return self.X_sequences[idx]
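
# Illustrative sketch: wrap the outputs of create_fixed_length_sequences in a FlightDataset
# and a DataLoader. Batch size, shuffling and num_workers are placeholder choices; a
# DistributedSampler would typically replace shuffle=True under DDP.
def _example_build_dataloader(sequences, targets, group_ids):
    from torch.utils.data import DataLoader
    dataset = FlightDataset(sequences, y_sequences=targets, group_ids=group_ids)
    # Each batch yields (X, y, group_id) because group_ids is provided;
    # X has shape (batch_size, 2, input_length, n_features).
    return DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)
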
class EarlyStoppingDist:
    """Early stopping for distributed training."""

    def __init__(self, patience=10, verbose=False, delta=0, path='best_model.pth', rank=0, local_rank=0):
        """
        Args:
            patience (int): how many epochs to wait after the monitored loss stops improving before stopping
            verbose (bool): whether to print progress messages
            delta (float): minimum decrease of the monitored loss that counts as an improvement
            path (str): where to save the best model
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.rank = rank
        self.local_rank = local_rank

    def __call__(self, loss, model):
        if self.is_nan(loss):
            # A NaN loss cannot recover: trigger early stopping immediately and skip checkpointing
            self.counter += self.patience
            self.early_stop = True
            return
        if self.best_loss is None:
            self.best_loss = loss
            self.save_checkpoint(loss, model)
        elif loss > self.best_loss - self.delta:
            self.counter += 1
            if self.verbose and self.rank == 0:
                print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(loss, model)
            self.best_loss = loss
            self.counter = 0

    def is_nan(self, loss):
        """Check whether the loss is NaN (works for plain floats and tensors)."""
        try:
            # NaN is the only value that compares unequal to itself
            return loss != loss
        except Exception:
            # Fall back for types that do not support comparison
            return False

    def save_checkpoint(self, loss, model):
        """Save the current model as the best checkpoint so far."""
        if self.verbose and self.rank == 0:
            print(f'Rank:{self.rank}, Local Rank:{self.local_rank}, Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
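
# Illustrative sketch of how EarlyStoppingDist plugs into a training loop.
# model, train_one_epoch, rank and local_rank are placeholders, not part of this module.
def _example_training_loop(model, train_one_epoch, n_epochs=100, rank=0, local_rank=0):
    early_stopping = EarlyStoppingDist(patience=10, verbose=True, path='best_model.pth',
                                       rank=rank, local_rank=local_rank)
    for epoch in range(n_epochs):
        epoch_loss = train_one_epoch(model)   # placeholder: returns the epoch's mean loss
        early_stopping(epoch_loss, model)     # checkpoints when the loss improves
        if early_stopping.early_stop:
            print(f'Rank:{rank}, early stopping at epoch {epoch}')
            break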