
Refine validation details and add comments

node04 committed 1 week ago
parent commit 71eadb6bb2
3 changed files with 59 additions and 24 deletions
  1. data_preprocess.py (+10 −3)
  2. main_tr_0.py (+4 −4)
  3. result_validate_0.py (+45 −17)

+ 10 - 3
data_preprocess.py

@@ -1306,13 +1306,20 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
                                 df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = 0.0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'k0'
-                            # keep the previous price-drop verdict; only the probability changes
+                            # diverging keep/drop votes: decide by probability
                             else:
                                 drop_prob = round(length_drop / (length_keep + length_drop), 2)
+                                # keep the previous price-drop verdict, update the probability
+                                if drop_prob >= 0.45:
+                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
+                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'd1'
+                                # overturn to no price drop, update the probability
+                                else:
+                                    df_min_hours.loc[idx, 'simple_will_price_drop'] = 0
+                                    df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'k1'
                                 df_min_hours.loc[idx, 'simple_drop_in_hours'] = 0
                                 df_min_hours.loc[idx, 'simple_drop_in_hours_prob'] = drop_prob
-                                df_min_hours.loc[idx, 'simple_drop_in_hours_dist'] = 'k1'
-
+                                
                             # elif length_keep == length_drop:   # keep and drop counts are equal; use probability 0.5
 
                             #     df_min_hours.loc[idx, 'simple_will_price_drop'] = 1
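Note: the added branch replaces the former unconditional 'k1' labeling with a 0.45 threshold on the drop-vote ratio. A minimal sketch of that rule in isolation, assuming length_keep and length_drop count the samples voting to keep vs. drop the price (the helper name is hypothetical; the 0.45 cutoff and the 'd1'/'k1' labels come from the diff):

def classify_drop(length_keep: int, length_drop: int) -> tuple[int, float, str]:
    # hypothetical helper mirroring the thresholding added above
    drop_prob = round(length_drop / (length_keep + length_drop), 2)
    if drop_prob >= 0.45:
        # enough drop votes: keep the price-drop verdict, tag 'd1'
        return 1, drop_prob, 'd1'
    # otherwise overturn to "no drop", tag 'k1'
    return 0, drop_prob, 'k1'

print(classify_drop(6, 4))   # (0, 0.4, 'k1')
print(classify_drop(5, 5))   # (1, 0.5, 'd1')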

+ 4 - 4
main_tr_0.py

@@ -17,8 +17,8 @@ def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
         df_old = pd.read_csv(csv_path, encoding='utf-8-sig')
         if key_cols and all(c in df_old.columns for c in key_cols):
             df_old_keys = df_old[key_cols].drop_duplicates()
-            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)
-            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])
+            df_add = df_new.merge(df_old_keys, on=key_cols, how='left', indicator=True)  # indicator=True adds a _merge column to the result df
+            df_add = df_add[df_add['_merge'] == 'left_only'].drop(columns=['_merge'])    # 'left_only': rows that exist only in the left table (df_new)
         else:
             df_add = df_new.copy()
         df_merged = pd.concat([df_old, df_add], ignore_index=True)
@@ -28,7 +28,7 @@ def merge_and_overwrite_csv(df_new, csv_path, dedup_cols):
 
     sort_cols = [c for c in dedup_cols if c in df_merged.columns]
     if sort_cols:
-        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)
+        df_merged = df_merged.sort_values(by=sort_cols).reset_index(drop=True)  # re-sort by the grouping columns
 
     df_merged.to_csv(csv_path, index=False, encoding='utf-8-sig')
 
@@ -49,7 +49,7 @@ def start_train():
     # date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
-    date_begin = "2025-12-01"
+    date_begin = "2026-01-27"   # 2025-12-01  2026-01-27 
 
     print(f"训练时间范围: {date_begin} 到 {date_end}")
 

+ 45 - 17
result_validate_0.py

@@ -226,7 +226,7 @@ def validate_process_auto(node, interval_hours):
     print()
 
 
-def validate_process_zong(node):
+def validate_process_zong(node, enable_min_batch_flag=False, min_batch_time_str=None):
     object_dir = "./predictions_0"
 
     output_dir = f"./validate/{node}_zong"
@@ -248,9 +248,27 @@ def validate_process_zong(node):
         return
 
     csv_files.sort()
-    list_df_will_drop = []    
+    list_df_will_drop = []
+
+    min_batch_dt = None
+    if enable_min_batch_flag:
+        if not min_batch_time_str:
+            print("enable_min_batch_flag=True but min_batch_time_str not provided; exiting")
+            return        
+        min_batch_dt = datetime.datetime.strptime(min_batch_time_str, "%Y%m%d%H%M")
+        min_batch_dt = min_batch_dt.replace(minute=0, second=0, microsecond=0)
     
+    # iterate over all prediction files
     for csv_file in csv_files:
+        batch_time_str = (
+            csv_file.replace("future_predictions_", "").replace(".csv", "")
+        )
+        batch_dt = datetime.datetime.strptime(batch_time_str, "%Y%m%d%H%M")
+        batch_hour_dt = batch_dt.replace(minute=0, second=0, microsecond=0)
+        # skip batches earlier than min_batch_dt
+        if min_batch_dt is not None and batch_hour_dt < min_batch_dt:
+            continue
+        
         csv_path = os.path.join(object_dir, csv_file)
         try:
             df_predict = pd.read_csv(csv_path)
@@ -266,10 +284,6 @@ def validate_process_zong(node):
             print(f"缺少 will_price_drop 字段,跳过: {csv_file}")
             continue
 
-        batch_time_str = (
-            csv_file.replace("future_predictions_", "").replace(".csv", "")
-        )
-
         df_predict_will_drop = df_predict[df_predict["will_price_drop"] == 1].copy()
 
         if df_predict_will_drop.empty:
@@ -277,7 +291,7 @@ def validate_process_zong(node):
         
         # df_predict_will_drop["batch_file"] = csv_file
         df_predict_will_drop["batch_time"] = batch_time_str
-        list_df_will_drop.append(df_predict_will_drop)
+        list_df_will_drop.append(df_predict_will_drop)   # collect each batch's will_drop rows
 
         del df_predict
     
@@ -285,34 +299,40 @@ def validate_process_zong(node):
         print("所有批次的 will_drop 都为空")
         return
 
+    # === 1. Concatenate all will_drop results ===
     df_predict_will_drop_all = pd.concat(list_df_will_drop, ignore_index=True)
 
+    # free the temporary list (worth doing when the list is large)
     del list_df_will_drop
     
     before_rows = len(df_predict_will_drop_all)
+    # group keys that uniquely identify a flight
     group_keys = ["city_pair", "flight_number_1", "flight_number_2", "flight_day"]
+    # === 2. Convert batch_time to datetime for gap checks ===
     df_predict_will_drop_all["batch_dt"] = pd.to_datetime(
         df_predict_will_drop_all["batch_time"],
         format="%Y%m%d%H%M",
-        errors="coerce",
+        errors="coerce",   # invalid timestamps become NaT
     )
+    # === 3. Auto-infer the normal batch_time step size (minutes) ===
     diff_minutes = (
         df_predict_will_drop_all["batch_dt"].dropna().sort_values().drop_duplicates().diff()
         .dt.total_seconds()
         .div(60)
         .dropna()
     )
+    # - take the most frequent gap as the expected step; default 60 minutes
     expected_step_minutes = (
         int(diff_minutes.value_counts().idxmax()) if not diff_minutes.empty else 60
     )
-
+    # === 4. Sort by flight + batch time to prepare the continuity check ===
     df_predict_will_drop_all.sort_values(
         by=group_keys + ["batch_dt"],
         inplace=True,
         ignore_index=True,
         na_position="last",
     )
-
+    # === 5. Compute the gap between adjacent batch_dt values within each group ===
     df_predict_will_drop_all["prev_batch_dt"] = df_predict_will_drop_all.groupby(group_keys)[
         "batch_dt"
     ].shift(1)
@@ -321,20 +341,26 @@ def validate_process_zong(node):
         .dt.total_seconds()
         .div(60)
     )
-
+    # === 6. Flag whether a row starts a new contiguous segment ===
+    # A new segment starts when:
+    #   1) prev_batch_dt is missing (first row in the group)
+    #   2) batch_dt is missing (uncommon)
+    #   3) the gap to the previous row != the expected step
     df_predict_will_drop_all["is_new_segment"] = (
         df_predict_will_drop_all["prev_batch_dt"].isna()
         | df_predict_will_drop_all["batch_dt"].isna()
         | (df_predict_will_drop_all["gap_minutes"] != expected_step_minutes)
     )
+    # === 7. Generate segment ids (segment_id) ===
+    # within the same flight, increment by 1 each time a new segment starts
     df_predict_will_drop_all["segment_id"] = df_predict_will_drop_all.groupby(group_keys)[
         "is_new_segment"
     ].cumsum()
-
+    # === 8. Compute each segment's trailing hours_until_departure ===
     df_segment_last = df_predict_will_drop_all.groupby(
         group_keys + ["segment_id"], as_index=False
     ).agg(last_hours_until_departure=("hours_until_departure", "last"))
-
+    # === 9. Keep only the first record of each segment and attach the segment-tail info ===
     df_predict_will_drop_filter = df_predict_will_drop_all.drop_duplicates(
         subset=group_keys + ["segment_id"], keep="first"
     ).merge(
@@ -342,7 +368,7 @@ def validate_process_zong(node):
         on=group_keys + ["segment_id"],
         how="left",
     )
-
+    # === 10. Drop the intermediate helper columns ===
     df_predict_will_drop_filter = (
         df_predict_will_drop_filter.drop(
             columns=[
@@ -355,7 +381,7 @@ def validate_process_zong(node):
         )
         .reset_index(drop=True)
     )
-
+    # === 11. Reorder columns (put last_hours_until_departure right after price_change_percent) ===
     if "price_change_percent" in df_predict_will_drop_filter.columns:
         cols = df_predict_will_drop_filter.columns.tolist()
         if "last_hours_until_departure" in cols:
@@ -415,8 +441,10 @@ if __name__ == "__main__":
     if interval_hours == 0:
         # node, pred_time_str = "node0127", "202601301500"
         # validate_process(node, interval_hours, pred_time_str)
-        node = "node0127"
-        validate_process_zong(node)
+        # node = "node0127"
+        # validate_process_zong(node)  # unconditional aggregation
+        node = "node0203"
+        validate_process_zong(node, True, "202602031100")  # conditional aggregation
     # 1: automatic validation
     else:
         node = "node0127"
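The heart of the new aggregation (steps 3 through 7) is run-length segmentation on a time axis: infer the normal step as the modal gap between batches, then start a new segment wherever the observed gap deviates. A minimal sketch of the same idea on toy timestamps (variable names are illustrative, not from result_validate_0.py):

import pandas as pd

batch_times = pd.Series(pd.to_datetime(
    ["202602030900", "202602031000", "202602031100",
     "202602031400", "202602031500"],
    format="%Y%m%d%H%M",
)).sort_values(ignore_index=True)

# step 3: the most frequent gap (in minutes) is the expected step
gaps = batch_times.diff().dt.total_seconds().div(60)
expected_step = int(gaps.value_counts().idxmax())  # 60 here

# step 6: a gap != expected step (or a missing gap) starts a new segment
is_new_segment = gaps.ne(expected_step)  # NaN != 60 is True, so row 0 starts one

# step 7: the cumulative sum of the flags yields per-run segment ids
segment_id = is_new_segment.cumsum()
print(expected_step, segment_id.tolist())  # 60 [1, 1, 1, 2, 2]

The script applies the cumsum per flight via groupby(group_keys), so segment ids restart for each flight; the toy version above shows the single-group case.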