Browse Source

提交部分预测, 修改训练过程

node04 cách đây 2 ngày
mục cha
commit
cbce610e75
5 tập tin đã thay đổi với 157 bổ sung và 15 xóa
  1. 1 0
      .gitignore
  2. 58 12
      data_process.py
  3. 95 0
      main_pe.py
  4. 1 1
      main_tr.py
  5. 2 2
      uo_atlas_import.py

+ 1 - 0
.gitignore

@@ -33,6 +33,7 @@ desktop.ini
 # 生成的图表图片
 photo/
 data_shards/
+predictions/
 
 # 字体文件(体积大,不适合版本控制)
 *.ttf

+ 58 - 12
data_process.py

@@ -4,7 +4,7 @@ import gc
 import os
 
 
-def preprocess_data_simple(df_input, is_train=False):
+def preprocess_data_simple(df_input, is_train=False, hourly_time=None):
 
     print(">>> 开始数据预处理")
     # 城市码映射成数字(不用)
@@ -31,12 +31,13 @@ def preprocess_data_simple(df_input, is_train=False):
     df_input = df_input[df_input['hours_until_departure'] <= 480]
     df_input = df_input[df_input['baggage_weight'] == 20]   # 先保留20公斤行李的
 
-    # 在hours_until_departure 的末尾 保留真实的而不是补齐的数据
+    # 在hours_until_departure 的末尾 保留到当前时刻的数据
     if not is_train:
-        _tail_filled = df_input.groupby(['gid', 'baggage_weight'])['is_filled'].transform(
-            lambda s: s.iloc[::-1].cummin().iloc[::-1]
-        )
-        df_input = df_input[~((df_input['is_filled'] == 1) & (_tail_filled == 1))]
+        df_input = df_input[df_input['update_hour'] <= hourly_time].copy()
+    else:
+        df_input = df_input.copy()  # 训练集也 copy 一下保持一致性
+    
+    df_input = df_input.reset_index(drop=True)
 
     # 价格变化最小量阈值
     price_change_amount_threshold = 5
@@ -89,12 +90,15 @@ def preprocess_data_simple(df_input, is_train=False):
             ascending=[True, True, False]
         ).reset_index(drop=True)
 
-        # 对于先升后降的分析
+        # 每条对应的前一条记录
         prev_pct = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_percent'].shift(1)
         prev_amo = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_change_amount'].shift(1)
         prev_dur = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_duration_hours'].shift(1)
         prev_price = df_target.groupby(['gid', 'baggage_weight'], group_keys=False)['price_total'].shift(1)
-        drop_mask = (prev_pct > 0) & (df_target['price_change_percent'] < 0)
+        
+        # 对于先升后降(先降再降)的分析
+        seg_start_mask = df_target['price_duration_hours'].eq(1)   # 开始变价节点
+        drop_mask = seg_start_mask & ((prev_pct > 0) | (prev_pct < 0)) & (df_target['price_change_percent'] < 0)
 
         df_drop_nodes = df_target.loc[drop_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week']].copy()
         df_drop_nodes.rename(columns={'hours_until_departure': 'drop_hours_until_departure'}, inplace=True)
@@ -124,9 +128,9 @@ def preprocess_data_simple(df_input, is_train=False):
         # 按顺序排列 去掉gid
         df_drop_nodes = df_drop_nodes[flight_info_cols + ['baggage_weight'] + drop_info_cols]
         
-        # 对于“上涨后再次上涨”的分析(连续两个正向变价段)
-        seg_start_mask = df_target['price_duration_hours'].eq(1)
-        rise_mask = seg_start_mask & (prev_pct > 0) & (df_target['price_change_percent'] > 0)
+        # 对于先升再升(先降再升)的分析
+        # seg_start_mask = df_target['price_duration_hours'].eq(1)
+        rise_mask = seg_start_mask & ((prev_pct > 0) | (prev_pct < 0)) & (df_target['price_change_percent'] > 0)
 
         df_rise_nodes = df_target.loc[rise_mask, ['gid', 'baggage_weight', 'hours_until_departure', 'days_to_departure', 'update_hour', 'update_week']].copy()
         df_rise_nodes.rename(columns={'hours_until_departure': 'rise_hours_until_departure'}, inplace=True)
@@ -169,4 +173,46 @@ def preprocess_data_simple(df_input, is_train=False):
         return df_input, df_drop_nodes, df_rise_nodes, df_envelope
     
     return df_input, None, None, None
- 
+ 
+
+def predict_data_simple(df_input, city_pair, output_dir, predict_dir=".", pred_time_str=""):
+    if df_input is None or df_input.empty:
+        return pd.DataFrame()
+
+    df_sorted = df_input.sort_values(
+        by=['gid', 'baggage_weight', 'hours_until_departure'],
+        ascending=[True, True, False],
+    ).reset_index(drop=True)
+
+    df_sorted = df_sorted[
+        df_sorted['hours_until_departure'].between(24, 360)
+    ].reset_index(drop=True)
+
+    # 每个 gid  baggage_weight 取 hours_until_departure 最小的一条
+    df_min_hours = (
+        df_sorted.drop_duplicates(subset=['gid', 'baggage_weight'], keep='last')
+        .reset_index(drop=True)
+    )
+
+    # 读历史升价-降价
+    drop_info_csv_path = os.path.join(output_dir, f'{city_pair}_drop_info.csv')
+    if os.path.exists(drop_info_csv_path):
+        df_drop_nodes = pd.read_csv(drop_info_csv_path)
+    else:
+        df_drop_nodes = pd.DataFrame()
+
+    # 读历史升价-升价
+    rise_info_csv_path = os.path.join(output_dir, f'{city_pair}_rise_info.csv')
+    if os.path.exists(rise_info_csv_path):
+        df_rise_nodes = pd.read_csv(rise_info_csv_path)
+    else:
+        df_rise_nodes = pd.DataFrame()
+    
+    # ==================== 跨航班日包络线 + 降价潜力 ====================
+    print(">>> 构建跨航班日价格包络线")
+    flight_key = ['citypair', 'flight_numbers', 'baggage_weight']
+    day_key = flight_key + ['from_date']
+    
+    
+
+    pass

+ 95 - 0
main_pe.py

@@ -0,0 +1,95 @@
+import os
+import time
+from datetime import datetime, timedelta
+from config import mongo_config, uo_city_pairs_new
+from data_loader import load_data
+from data_process import preprocess_data_simple, predict_data_simple
+
+
+def start_predict():
+    print(f"开始预测")
+
+    output_dir = "./data_shards"
+    predict_dir = "./predictions"
+
+    os.makedirs(predict_dir, exist_ok=True)
+
+    cpu_cores = os.cpu_count()  # 你的系统是72
+    max_workers = min(4, cpu_cores)  # 最大不超过4个进程
+
+    # 当前时间,取整时
+    current_time = datetime.now() 
+    current_time_str = current_time.strftime("%Y%m%d%H%M")
+    hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
+    hourly_time_str = hourly_time.strftime("%Y%m%d%H%M")
+    print(f"预测时间:{current_time_str}, (取整): {hourly_time_str}")
+
+    # 清空上一次(同小时内)预测结果
+    csv_file_list = [f'future_predictions_{hourly_time_str}.csv']
+    for csv_file in csv_file_list:
+        try:
+            csv_path = os.path.join(predict_dir, csv_file)
+            os.remove(csv_path)
+        except Exception as e:
+            print(f"remove {csv_path} info: {str(e)}")
+
+    # 预测时间范围,满足起飞时间 在24小时后到360小时后
+    pred_hour_begin = hourly_time + timedelta(hours=24)
+    pred_hour_end = hourly_time + timedelta(hours=360)
+
+    pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
+    pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d") 
+
+    print(f"预测起飞时间范围: {pred_date_begin} 到 {pred_date_end}")
+
+    uo_city_pairs = uo_city_pairs_new.copy()
+    uo_city_pair_list = [f"{pair[:3]}-{pair[3:]}" for pair in uo_city_pairs]   
+
+    # 如果临时处理中断,从日志里找到 中断的索引 修改它
+    resume_idx = 0
+    uo_city_pair_list = uo_city_pair_list[resume_idx:]
+
+    # 打印预测阶段起始索引顺序
+    max_len = len(uo_city_pair_list) + resume_idx
+    print(f"预测阶段起始索引顺序:{resume_idx} ~ {max_len - 1}") 
+
+    for idx, uo_city_pair in enumerate(uo_city_pair_list, start=resume_idx):
+        print(f"第 {idx} 组 :", uo_city_pair)
+
+        # 加载预测数据 (仅仅是天数取到以后)
+        start_time = time.time()
+        df_test = load_data(mongo_config, uo_city_pair, pred_date_begin, pred_date_end, is_train=False,
+                            use_multiprocess=True, max_workers=max_workers)
+        end_time = time.time()
+        run_time = round(end_time - start_time, 3)
+        print(f"用时: {run_time} 秒")
+
+        if df_test.empty:
+            print(f"预测数据为空,跳过此批次。")
+            continue
+
+        # 按起飞时间过滤
+        df_test['from_hour'] = df_test['from_time'].dt.floor('h')
+        # 使用整点时间进行比较过滤
+        mask = (df_test['from_hour'] >= pred_hour_begin) & (df_test['from_hour'] < pred_hour_end)
+        original_count = len(df_test)
+        df_test = df_test[mask].reset_index(drop=True)
+        filtered_count = len(df_test)
+        # 删除临时字段
+        df_test = df_test.drop(columns=['from_hour'])
+        print(f"按起飞时间过滤:过滤前 {original_count} 条,过滤后 {filtered_count} 条")
+
+        if filtered_count == 0:
+            print(f"起飞时间在 {pred_hour_begin} 到 {pred_hour_end} 之间没有航班,跳过此批次。")
+            continue
+
+        df_test_inputs, _, _, _,  = preprocess_data_simple(df_test, is_train=False, hourly_time=hourly_time)
+
+        df_predict = predict_data_simple(df_test_inputs, uo_city_pair, output_dir, predict_dir, hourly_time_str)
+        
+        pass
+
+
+
+if __name__ == "__main__":
+    start_predict()

+ 1 - 1
main_tr.py

@@ -20,7 +20,7 @@ def start_train():
     cpu_cores = os.cpu_count()  # 你的系统是72
     max_workers = min(8, cpu_cores)  # 最大不超过8个进程
 
-    from_date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
+    from_date_end = (datetime.today() - timedelta(days=0)).strftime("%Y-%m-%d")  # 截止日改为今天
     from_date_begin = "2026-03-17"
 
     print(f"训练时间范围: {from_date_begin} 到 {from_date_end}")

+ 2 - 2
uo_atlas_import.py

@@ -230,8 +230,8 @@ def main_import_process(create_at_begin, create_at_end):
     print()
 
 if __name__ == "__main__":
-    create_at_begin = "2026-03-25 00:00:00"
-    create_at_end = "2026-03-25 23:59:59"
+    create_at_begin = "2026-03-26 00:00:00"
+    create_at_end = "2026-03-26 15:59:59"
     main_import_process(create_at_begin, create_at_end)
     
     # try: