| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import torch
def chunk_list_with_index(lst, group_size):
    """Split ``lst`` into consecutive chunks, each tagged with its start index.

    Used to process a list (e.g. a list of routes) in fixed-size batches
    while remembering where each batch starts in the original list.

    Args:
        lst: Sequence to split. It is sliced, so it must support ``len()``
            and slice indexing.
        group_size: Maximum number of items per chunk; must be positive.

    Returns:
        A list of ``(start_index, chunk)`` tuples in original order. The
        final chunk may hold fewer than ``group_size`` items. An empty
        ``lst`` yields an empty list.

    Raises:
        ValueError: If ``group_size`` is zero or negative.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking message; reject both.
    if group_size <= 0:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")
    return [
        (start, lst[start:start + group_size])
        for start in range(0, len(lst), group_size)
    ]
def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
    """Move an existing column so it sits immediately before another column.

    Args:
        df: DataFrame containing both columns.
        insert_col_name: Name of the existing column to move.
        base_col_name: Name of the column the moved column is placed before.
        inplace: When True (default) ``df`` itself is modified; when False
            the work is done on a copy and ``df`` is left untouched.

    Returns:
        The re-ordered DataFrame: ``df`` itself when ``inplace`` is True,
        otherwise the modified copy. If both names are equal the frame is
        returned without any re-ordering.

    Raises:
        ValueError: If either column name is not present in ``df``.
    """
    if not inplace:
        df = df.copy()

    if base_col_name not in df.columns:
        raise ValueError(f"base_col_name '{base_col_name}' 不存在")
    if insert_col_name not in df.columns:
        raise ValueError(f"insert_col_name '{insert_col_name}' 不存在")

    # Nothing to move when the column is asked to precede itself.
    if base_col_name == insert_col_name:
        return df

    # Remove the column first and only then look up the base position:
    # if the moved column was to the left of the base column, popping it
    # shifts the base one slot left, and a position computed beforehand
    # would place the column *after* the base instead of before it.
    col_data = df.pop(insert_col_name)
    insert_idx = df.columns.get_loc(base_col_name)
    df.insert(insert_idx, insert_col_name, col_data)
    return df
def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True, threshold=28):
    """Build fixed-length feature sequences, one per ``gid`` group.

    Each ``gid`` group is split by baggage allowance (30 and 20). For each
    allowance, the rows whose ``Hours_Until_Departure`` lies in
    ``[threshold, threshold + input_length)`` are selected. A group is kept
    only when both allowances yield exactly ``input_length`` rows.

    Args:
        df: DataFrame with at least the columns ``gid``, ``city_pair``,
            ``flight_day``, ``flight_number_1``, ``flight_number_2``,
            ``dep_time_1``, ``baggage``, ``Hours_Until_Departure``,
            ``Adult_Total_Price``, ``update_hour``, plus everything named in
            ``features`` and ``target_vars``. ``dep_time_1`` values must
            provide ``strftime``.
        features: Column names used as model inputs.
        target_vars: Column names used as targets; may be empty/None.
        input_length: Required number of rows per baggage allowance.
        is_train: When True and ``target_vars`` is non-empty, targets are
            collected alongside the sequences.
        threshold: Lower bound (inclusive) on ``Hours_Until_Departure``,
            i.e. the cut-off in hours before departure.

    Returns:
        A tuple ``(sequences, targets, group_ids)``:

        * ``sequences``: list of float32 tensors of shape
          ``(2, input_length, len(features))``; index 0 is baggage 30,
          index 1 is baggage 20.
        * ``targets``: list of float32 tensors taken from the first row of
          the baggage-30 window (empty unless ``is_train`` and
          ``target_vars`` are both truthy).
        * ``group_ids``: list of tuples ``(city_pair, flight_day,
          flight_number_1, flight_number_2, dep_time, baggage, price,
          hours_until_departure, update_hour)``; the last four come from the
          last row of the baggage-30 window and are stringified.

    Note:
        Rows are consumed in the order they appear in ``df``; no sorting is
        done here. Presumably the caller has already ordered rows by
        ``Hours_Until_Departure`` -- TODO confirm against the caller.
    """
    sequences = []
    targets = []
    group_ids = []
    window_end = threshold + input_length

    # gid is built from city_pair, flight_day, flight_number_1 and
    # flight_number_2; it does not include baggage.
    for _, df_group in df.groupby('gid'):
        # Select the hour window separately for each baggage allowance.
        windows = []
        for baggage in (30, 20):
            df_bag = df_group[df_group['baggage'] == baggage]
            hours = df_bag['Hours_Until_Departure']
            windows.append(df_bag[(hours >= threshold) & (hours < window_end)])

        # Both allowances must supply a complete window, otherwise the two
        # sequences could not be stacked into one tensor.
        if any(len(window) != input_length for window in windows):
            continue
        window_30, window_20 = windows

        # Stack along a new leading axis: (2, input_length, len(features)).
        combined_features = torch.stack([
            torch.tensor(window_30[features].to_numpy(), dtype=torch.float32),
            torch.tensor(window_20[features].to_numpy(), dtype=torch.float32),
        ])
        sequences.append(combined_features)

        if is_train and target_vars:
            seq_targets = window_30[target_vars].iloc[0].to_numpy()
            targets.append(torch.tensor(seq_targets, dtype=torch.float32))

        # Identify the sequence by its flight keys plus a snapshot of the
        # last row of the baggage-30 window.
        first_row = df_group.iloc[0]
        last_row = window_30.iloc[-1]
        group_ids.append((
            first_row['city_pair'],
            first_row['flight_day'],
            first_row['flight_number_1'],
            first_row['flight_number_2'],
            df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S'),
            str(last_row['baggage']),
            str(last_row['Adult_Total_Price']),
            str(last_row['Hours_Until_Departure']),
            str(last_row['update_hour']),
        ))

    # The visible source ended without returning the collected results.
    return sequences, targets, group_ids
|