utils.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import torch
  2. # 航线列表分组切片并带上索引
  3. def chunk_list_with_index(lst, group_size):
  4. return [(i, lst[i:i + group_size]) for i in range(0, len(lst), group_size)]
  5. # pandas 在指定列之前插入新列
  6. def insert_df_col(df, insert_col_name, base_col_name, inplace=True):
  7. if not inplace:
  8. df = df.copy()
  9. if base_col_name not in df.columns:
  10. raise ValueError(f"base_col_name '{base_col_name}' 不存在")
  11. if insert_col_name not in df.columns:
  12. raise ValueError(f"insert_col_name '{insert_col_name}' 不存在")
  13. if base_col_name == insert_col_name:
  14. return df
  15. insert_idx = df.columns.get_loc(base_col_name)
  16. col_data = df.pop(insert_col_name)
  17. df.insert(insert_idx, insert_col_name, col_data)
  18. return df
  19. # 真正创建序列过程
  20. def create_fixed_length_sequences(df, features, target_vars, input_length=452, is_train=True):
  21. sequences = []
  22. targets = []
  23. group_ids = []
  24. threshold = 28 # 截止起飞小时数
  25. # gid 基于 city_pair, flight_day, flight_number_1, flight_number_2 分组 不包括 baggage
  26. grouped = df.groupby(['gid'])
  27. for _, df_group in grouped:
  28. city_pair = df_group['city_pair'].iloc[0]
  29. flight_day = df_group['flight_day'].iloc[0]
  30. flight_number_1 = df_group['flight_number_1'].iloc[0]
  31. flight_number_2 = df_group['flight_number_2'].iloc[0]
  32. dep_time_str = df_group['dep_time_1'].iloc[0].strftime('%Y-%m-%d %H:%M:%S')
  33. # 按行李配额分开
  34. df_group_bag_30 = df_group[df_group['baggage']==30]
  35. df_group_bag_20 = df_group[df_group['baggage']==20]
  36. # 过滤训练时间段 (28 ~ 480)
  37. df_group_bag_30_filtered = df_group_bag_30[(df_group_bag_30['Hours_Until_Departure'] >= threshold) & (df_group_bag_30['Hours_Until_Departure'] < threshold + input_length)]
  38. df_group_bag_20_filtered = df_group_bag_20[(df_group_bag_20['Hours_Until_Departure'] >= threshold) & (df_group_bag_20['Hours_Until_Departure'] < threshold + input_length)]
  39. # 条件: 长度要一致
  40. condition_list = [
  41. len(df_group_bag_30_filtered) == input_length,
  42. len(df_group_bag_20_filtered) == input_length,
  43. ]
  44. if all(condition_list):
  45. seq_features_1 = df_group_bag_30_filtered[features].to_numpy()
  46. seq_features_2 = df_group_bag_20_filtered[features].to_numpy()
  47. # 将几个特征序列沿着第 0 维拼接,得到形状为 (2, 452, 25)
  48. combined_features = torch.stack([torch.tensor(seq_features_1, dtype=torch.float32),
  49. torch.tensor(seq_features_2, dtype=torch.float32)])
  50. # 将拼接后的结果添加到 sequences 列表中
  51. sequences.append(combined_features)
  52. if is_train and target_vars:
  53. seq_targets = df_group_bag_30_filtered[target_vars].iloc[0].to_numpy()
  54. targets.append(torch.tensor(seq_targets, dtype=torch.float32))
  55. name_c = [city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str]
  56. # 直接获取最后一行的相关信息
  57. last_row = df_group_bag_30_filtered.iloc[-1]
  58. name_c.extend([str(last_row['baggage']),
  59. str(last_row['Adult_Total_Price']),
  60. str(last_row['Hours_Until_Departure']),
  61. str(last_row['update_hour'])])
  62. group_ids.append(tuple(name_c))
  63. pass
  64. pass
  65. pass