Quellcode durchsuchen

小幅修正训练

node04 vor 2 Tagen
Ursprung
Commit
ef406d5c0c
5 geänderte Dateien mit 121 neuen und 102 gelöschten Zeilen
  1. 2 2
      data_loader.py
  2. 101 92
      data_preprocess.py
  3. 7 3
      evaluate.py
  4. 1 1
      main_tr.py
  5. 10 4
      utils.py

+ 2 - 2
data_loader.py

@@ -657,9 +657,9 @@ def load_train_data(db, flight_route_list, table_name, date_begin, date_end, out
                     df_b2[cols] = df_b2[cols].astype("string")
 
                     df_b12 = pd.concat([df_b1, df_b2]).reset_index(drop=True)
-                    print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
+                    # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已将远期表和近期表合并,形状: {df_b12.shape}")
                     df_b12 = fill_hourly_crawl_date(df_b12, rear_fill=2)
-                    print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
+                    # print(f"📊 dep_date:{dep_date}, baggage:{baggage} 已合并且补齐为完整小时序列,形状: {df_b12.shape}")
                     # print(df_b12.dtypes)
                     list_12.append(df_b12)
                     del df_b12

+ 101 - 92
data_preprocess.py

@@ -324,104 +324,113 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
     # days_to_holiday 插在 update_hour 前面
     insert_df_col(df_input, 'days_to_holiday', 'update_hour')
 
-    # 制作targets
-    print(f"\n>>> 开始处理 对应区间: n_hours = {current_n_hours}")
-    target_lower_limit = 4
-    target_upper_limit = current_n_hours
-    mask_targets = (df_input['hours_until_departure'] >= target_lower_limit) & (df_input['hours_until_departure'] < target_upper_limit) & (df_input['baggage'] == 30)
-    df_targets = df_input.loc[mask_targets].copy()
-
-    targets_amout = df_targets.shape[0]
-    print(f"当前 目标区间数据量: {targets_amout}, 区间: [{target_lower_limit}, {target_upper_limit})")
-
-    if targets_amout == 0:
-        print(f">>> n_hours = {current_n_hours} 无有效数据,跳过")
-        return pd.DataFrame()
-
-    print(">>> 计算 price_at_n_hours")
-    df_input_object = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['baggage'] == 30)].copy()
-    df_last = df_input_object.groupby('gid', observed=True).last().reset_index()   # 一般落在起飞前28小时
-    
-    # 提取并重命名 price 列
-    df_last_price_at_n_hours = df_last[['gid', 'adult_total_price']].rename(columns={'adult_total_price': 'price_at_n_hours'})
-    print(">>> price_at_n_hours计算完成,示例:")
-    print(df_last_price_at_n_hours.head(5))
-    
-    # 计算降价信息
-    print(">>> 计算降价信息")
-    df_targets = df_targets.merge(df_last_price_at_n_hours, on='gid', how='left')
-    df_targets['price_drop_amount'] = df_targets['price_at_n_hours'] - df_targets['adult_total_price']
-    df_targets['price_dropped'] = (
-        (df_targets['adult_total_price'] < df_targets['price_at_n_hours']) &
-        (df_targets['price_drop_amount'] >= 5)  # 降幅不能太小
-    )
-    df_price_drops = df_targets[df_targets['price_dropped']].copy()
-
-    price_drops_len = df_price_drops.shape[0]
-    if price_drops_len == 0:
-        print(f">>> n_hours = {current_n_hours} 无降价信息")
-        # 创建包含指定列的空 DataFrame
-        df_price_drop_info = pd.DataFrame({
-            'gid': pd.Series(dtype='int64'),
-            'first_drop_hours_until_departure': pd.Series(dtype='int64'),
-            'price_at_first_drop_hours': pd.Series(dtype='float64')
+    # 训练模式
+    if is_training:
+        print(">>> 训练模式:计算 target 相关列")
+        print(f"\n>>> 开始处理 对应区间: n_hours = {current_n_hours}")
+        target_lower_limit = 4
+        target_upper_limit = current_n_hours
+        mask_targets = (df_input['hours_until_departure'] >= target_lower_limit) & (df_input['hours_until_departure'] < target_upper_limit) & (df_input['baggage'] == 30)
+        df_targets = df_input.loc[mask_targets].copy()
+
+        targets_amout = df_targets.shape[0]
+        print(f"当前 目标区间数据量: {targets_amout}, 区间: [{target_lower_limit}, {target_upper_limit})")
+
+        if targets_amout == 0:
+            print(f">>> n_hours = {current_n_hours} 无有效数据,跳过")
+            return pd.DataFrame()
+
+        print(">>> 计算 price_at_n_hours")
+        df_input_object = df_input[(df_input['hours_until_departure'] >= current_n_hours) & (df_input['baggage'] == 30)].copy()
+        df_last = df_input_object.groupby('gid', observed=True).last().reset_index()   # 一般落在起飞前28小时
+        
+        # 提取并重命名 price 列
+        df_last_price_at_n_hours = df_last[['gid', 'adult_total_price']].rename(columns={'adult_total_price': 'price_at_n_hours'})
+        print(">>> price_at_n_hours计算完成,示例:")
+        print(df_last_price_at_n_hours.head(5))
+        
+        # 计算降价信息
+        print(">>> 计算降价信息")
+        df_targets = df_targets.merge(df_last_price_at_n_hours, on='gid', how='left')
+        df_targets['price_drop_amount'] = df_targets['price_at_n_hours'] - df_targets['adult_total_price']
+        df_targets['price_dropped'] = (
+            (df_targets['adult_total_price'] < df_targets['price_at_n_hours']) &
+            (df_targets['price_drop_amount'] >= 5)  # 降幅不能太小
+        )
+        df_price_drops = df_targets[df_targets['price_dropped']].copy()
+
+        price_drops_len = df_price_drops.shape[0]
+        if price_drops_len == 0:
+            print(f">>> n_hours = {current_n_hours} 无降价信息")
+            # 创建包含指定列的空 DataFrame
+            df_price_drop_info = pd.DataFrame({
+                'gid': pd.Series(dtype='int64'),
+                'first_drop_hours_until_departure': pd.Series(dtype='int64'),
+                'price_at_first_drop_hours': pd.Series(dtype='float64')
+            })
+        else:
+            df_price_drop_info = df_price_drops.groupby('gid', observed=True).first().reset_index()   # 第一次发生的降价
+            df_price_drop_info = df_price_drop_info[['gid', 'hours_until_departure', 'adult_total_price']].rename(columns={
+                    'hours_until_departure': 'first_drop_hours_until_departure',
+                    'adult_total_price': 'price_at_first_drop_hours'
+            })
+            print(">>> 降价信息计算完成,示例:")
+            print(df_price_drop_info.head(5))
+        
+        # 合并信息
+        df_gid_info = df_last_price_at_n_hours.merge(df_price_drop_info, on='gid', how='left')
+        df_gid_info['will_price_drop'] = df_gid_info['price_at_first_drop_hours'].notnull().astype(int)
+        df_gid_info['amount_of_price_drop'] = df_gid_info['price_at_n_hours'] - df_gid_info['price_at_first_drop_hours']
+        df_gid_info['amount_of_price_drop'] = df_gid_info['amount_of_price_drop'].fillna(0)  # 区别    
+        df_gid_info['time_to_price_drop'] = current_n_hours - df_gid_info['first_drop_hours_until_departure']
+        df_gid_info['time_to_price_drop'] = df_gid_info['time_to_price_drop'].fillna(0)  # 区别
+
+        del df_input_object
+        del df_last
+        del df_last_price_at_n_hours
+        del df_targets
+        del df_price_drops
+        del df_price_drop_info
+        gc.collect()
+        
+        # 将目标变量合并到输入数据中
+        print(">>> 将目标变量信息合并到 df_input")
+        df_input = df_input.merge(df_gid_info[['gid', 'will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']], on='gid', how='left')
+        # 使用 0 填充 NaN 值
+        df_input[['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']] = df_input[
+            ['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']].fillna(0)
+        df_input = df_input.rename(columns={
+            'will_price_drop': 'target_will_price_drop',
+            'amount_of_price_drop': 'target_amount_of_drop',
+            'time_to_price_drop': 'target_time_to_drop'
         })
+        
+        # 计算每个 gid 分组在 df_targets 中的 adult_total_price 最小值
+        # print(">>> 计算每个 gid 分组的 adult_total_price 最小值...")
+        # df_min_price_by_gid = df_targets.groupby('gid')['adult_total_price'].min().reset_index()
+        # df_min_price_by_gid = df_min_price_by_gid.rename(columns={'adult_total_price': 'min_price'})
+        # gid_count = df_min_price_by_gid.shape[0]
+        # print(f">>> 计算完成,共 {gid_count} 个 gid 分组")
+
+        # # 将最小价格 merge 到 df_inputs 中
+        # print(">>> 将最小价格 merge 到输入数据中...")
+        # df_input = df_input.merge(df_min_price_by_gid, on='gid', how='left')
+
+        print(">>> 合并后 df_input 样例:")
+        print(df_input[['gid', 'hours_until_departure', 'adult_total_price', 'target_will_price_drop', 'target_amount_of_drop', 'target_time_to_drop']].head(5))
+
+    # 预测模式
     else:
-        df_price_drop_info = df_price_drops.groupby('gid', observed=True).first().reset_index()   # 第一次发生的降价
-        df_price_drop_info = df_price_drop_info[['gid', 'hours_until_departure', 'adult_total_price']].rename(columns={
-                'hours_until_departure': 'first_drop_hours_until_departure',
-                'adult_total_price': 'price_at_first_drop_hours'
-        })
-        print(">>> 降价信息计算完成,示例:")
-        print(df_price_drop_info.head(5))
-    
-    # 合并信息
-    df_gid_info = df_last_price_at_n_hours.merge(df_price_drop_info, on='gid', how='left')
-    df_gid_info['will_price_drop'] = df_gid_info['price_at_first_drop_hours'].notnull().astype(int)
-    df_gid_info['amount_of_price_drop'] = df_gid_info['price_at_n_hours'] - df_gid_info['price_at_first_drop_hours']
-    df_gid_info['amount_of_price_drop'] = df_gid_info['amount_of_price_drop'].fillna(0)  # 区别    
-    df_gid_info['time_to_price_drop'] = df_gid_info['first_drop_hours_until_departure']
-    df_gid_info['time_to_price_drop'] = df_gid_info['time_to_price_drop'].fillna(0)  # 区别
-
-    del df_input_object
-    del df_last
-    del df_last_price_at_n_hours
-    del df_targets
-    del df_price_drops
-    del df_price_drop_info
-    gc.collect()
-    
-    # 将目标变量合并到输入数据中
-    print(">>> 将目标变量信息合并到 df_input")
-    df_input = df_input.merge(df_gid_info[['gid', 'will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']], on='gid', how='left')
-    # 使用 0 填充 NaN 值
-    df_input[['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']] = df_input[
-        ['will_price_drop', 'amount_of_price_drop', 'time_to_price_drop']].fillna(0)
-    df_input = df_input.rename(columns={
-        'will_price_drop': 'target_will_price_drop',
-        'amount_of_price_drop': 'target_amount_of_drop',
-        'time_to_price_drop': 'target_time_to_drop'
-    })
-    
-    # 计算每个 gid 分组在 df_targets 中的 adult_total_price 最小值
-    # print(">>> 计算每个 gid 分组的 adult_total_price 最小值...")
-    # df_min_price_by_gid = df_targets.groupby('gid')['adult_total_price'].min().reset_index()
-    # df_min_price_by_gid = df_min_price_by_gid.rename(columns={'adult_total_price': 'min_price'})
-    # gid_count = df_min_price_by_gid.shape[0]
-    # print(f">>> 计算完成,共 {gid_count} 个 gid 分组")
-
-    # # 将最小价格 merge 到 df_inputs 中
-    # print(">>> 将最小价格 merge 到输入数据中...")
-    # df_input = df_input.merge(df_min_price_by_gid, on='gid', how='left')
-
-    print(">>> 合并后 df_input 样例:")
-    print(df_input[['gid', 'hours_until_departure', 'adult_total_price', 'target_will_price_drop', 'target_amount_of_drop', 'target_time_to_drop']].head(5))
+        print(">>> 预测模式:补齐 target 相关列(全部置 0)")
+        df_input['target_will_price_drop'] = 0
+        df_input['target_amount_of_drop'] = 0.0
+        df_input['target_time_to_drop'] = 0
 
     # 按顺序排列
     order_columns = [
         "city_pair", "from_city_code", "from_city_num", "to_city_code", "to_city_num", "flight_day", 
         "seats_remaining", "baggage", "baggage_level", 
-        "price_change_times_total", "price_last_change_hours", "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_time_to_drop",
+        "price_change_times_total", "price_last_change_hours", "adult_total_price", "Adult_Total_Price", "target_will_price_drop", "target_amount_of_drop", "target_time_to_drop",
         "days_to_departure", "days_to_holiday", "hours_until_departure", "Hours_Until_Departure", "update_hour", "gid",
         "flight_number_1", "flight_1_num", "airport_pair_1", "dep_time_1", "arr_time_1", "fly_duration_1", 
         "flight_by_hour", "flight_by_day", "flight_day_of_month", "flight_day_of_week", "flight_day_of_quarter", "flight_day_is_weekend", "is_transfer", 
@@ -430,7 +439,7 @@ def preprocess_data(df_input, features, categorical_features, is_training=True,
         "global_arr_time", "arr_country", "arr_country_is_holiday", "any_country_is_holiday",
     ]
     df_input = df_input[order_columns]
-
+    
     return df_input
 
 

+ 7 - 3
evaluate.py

@@ -108,6 +108,8 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
             'price': [info[6] for info in group_info],
             'Hours_until_Departure': [info[7] for info in group_info],
             'update_hour': [info[8] for info in group_info],
+            'target_amount_of_drop': [info[9] for info in group_info],  # 训练时的验证才有这两个target列
+            'target_time_to_drop': [info[10] for info in group_info],
             'probability': y_preds_class,
             'Actual_Will_Price_Drop': y_trues_class_labels,
             'Predicted_Will_Price_Drop': y_preds_class_labels,
@@ -115,11 +117,13 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
 
         # 数值处理
         threshold = 1e-3
-        numeric_columns = ['probability',
-                           # 'Actual_Amount_Of_Drop', 'Predicted_Amount_Of_Drop', 'Actual_Time_To_Drop', 'Predicted_Time_To_Drop'
-                           ]
+        numeric_columns = ['probability', 'target_amount_of_drop', 'target_time_to_drop']
         for col in numeric_columns:
             results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
+            if col in ['target_time_to_drop']:
+                results_df[col] = results_df[col].round(0).astype(int)
+            if col in ['target_amount_of_drop']:
+                results_df[col] = results_df[col].round(2)
         
         # 保存结果
         results_df_path = os.path.join(output_dir, csv_file)

+ 1 - 1
main_tr.py

@@ -568,7 +568,7 @@ def _validate_group_structure(group_ids):
     
     sample = group_ids[0]
     assert isinstance(sample, tuple), "元素必须是元组"
-    assert len(sample) == 9, "元组长度必须为9"
+    assert len(sample) == 11, "元组长度必须为11"
 
 def debug_print_shard_info(sequences, targets, group_ids, rank, local_rank, world_size):
     """分布式环境下按Rank顺序打印分片前5条样本"""

+ 10 - 4
utils.py

@@ -77,10 +77,16 @@ def create_fixed_length_sequences(df, features, target_vars, input_length=452, i
             name_c = [city_pair, flight_day, flight_number_1, flight_number_2, dep_time_str]
             # 直接获取最后一行的相关信息
             last_row = df_group_bag_30_filtered.iloc[-1]
-            name_c.extend([str(last_row['baggage']),
-                           str(last_row['Adult_Total_Price']), 
-                           str(last_row['Hours_Until_Departure']), 
-                           str(last_row['update_hour'])])
+            next_name_li = [str(last_row['baggage']), 
+                            str(last_row['Adult_Total_Price']), 
+                            str(last_row['Hours_Until_Departure']),
+                            str(last_row['update_hour']), 
+                            ]
+            if is_train:
+                next_name_li.append(last_row['target_amount_of_drop'])
+                next_name_li.append(last_row['target_time_to_drop'])
+
+            name_c.extend(next_name_li)
             group_ids.append(tuple(name_c))
         
         del df_group_bag_30_filtered, df_group_bag_20_filtered