Quellcode durchsuchen

提交用于对接王康规则的相关代码, 调整训练与预测时间参数

node04 vor 1 Woche
Ursprung
Commit
6caa7c16b7
6 geänderte Dateien mit 134 neuen und 50 gelöschten Zeilen
  1. 9 9
      data_preprocess.py
  2. 114 33
      descending_cabin_task.py
  3. 5 3
      follow_up.py
  4. 3 3
      main_pe_0.py
  5. 1 1
      main_tr_0.py
  6. 2 1
      result_keep_verify.py

+ 9 - 9
data_preprocess.py

@@ -926,7 +926,7 @@ def preprocess_data_simple(df_input, is_train=False):
 
     # 训练过程
     if is_train:
-        df_target = df_input[(df_input['hours_until_departure'] >= 12) & (df_input['hours_until_departure'] <= 360)].copy()   # 扩展至360小时(15天) 
+        df_target = df_input[(df_input['hours_until_departure'] >= 8) & (df_input['hours_until_departure'] <= 240)].copy()   # 扩展至240小时(10天) 
         df_target = df_target.sort_values(
             by=['gid', 'hours_until_departure'],
             ascending=[True, False]
@@ -1073,7 +1073,7 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     ).reset_index(drop=True)
 
     df_sorted = df_sorted[
-        df_sorted['hours_until_departure'].between(12, 360)
+        df_sorted['hours_until_departure'].between(8, 240)
     ].reset_index(drop=True)
 
     # 每个 gid 取 hours_until_departure 最小的一条
@@ -1082,9 +1082,9 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         .reset_index(drop=True)
     )
 
-    # 确保 hours_until_departure 在 [12, 360] 的 范围内
+    # 确保 hours_until_departure 在 [8, 240] 的 范围内
     # df_min_hours = df_min_hours[
-    #     df_min_hours['hours_until_departure'].between(12, 360)
+    #     df_min_hours['hours_until_departure'].between(8, 240)
     # ].reset_index(drop=True)
 
     drop_info_csv_path = os.path.join(output_dir, f'{group_route_str}_drop_info.csv')
@@ -1200,7 +1200,7 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     ).round(4)
 
     # 综合评分阈值:大于阈值的都认为值得投放
-    target_score_threshold = 0.75
+    target_score_threshold = 0.8
     # df_min_hours['target_score_threshold'] = target_score_threshold
     df_min_hours['is_good_target'] = (df_min_hours['target_score'] >= target_score_threshold).astype(int)
 
@@ -1488,8 +1488,8 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
     _pred_dt = pd.to_datetime(str(pred_time_str), format="%Y%m%d%H%M", errors="coerce")
     df_min_hours["update_hour"] = _pred_dt.strftime("%Y-%m-%d %H:%M:%S")
     _dep_hour = pd.to_datetime(df_min_hours["from_time"], errors="coerce").dt.floor("h")
-    df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(360, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
-    df_min_hours["valid_end_hour"] = (_dep_hour - pd.to_timedelta(12, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
+    df_min_hours["valid_begin_hour"] = (_dep_hour - pd.to_timedelta(240, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
+    df_min_hours["valid_end_hour"] = (_dep_hour - pd.to_timedelta(8, unit="h")).dt.strftime("%Y-%m-%d %H:%M:%S")
 
     # 要展示在预测表里的字段
     order_cols = ['city_pair', 'flight_day', 'flight_number_1', 'flight_number_2', 'from_time', 
@@ -1523,12 +1523,12 @@ def predict_data_simple(df_input, group_route_str, output_dir, predict_dir=".",
         na_position='last',
     ).reset_index(drop=True)
 
-    # 时间段过滤 过滤掉异常时间(update_hour 早于 crawl_date)
+    # 时间段过滤 过滤掉异常时间(update_hour 早于 crawl_date, 以及超过8小时不更新的数据
     update_dt = pd.to_datetime(df_predict["update_hour"], errors="coerce")
     crawl_dt = pd.to_datetime(df_predict["crawl_date"], errors="coerce")
     dt_diff = update_dt - crawl_dt
     df_predict = df_predict.loc[
-        (dt_diff >= pd.Timedelta(0)) & (dt_diff <= pd.Timedelta(hours=12))
+        (dt_diff >= pd.Timedelta(0)) & (dt_diff <= pd.Timedelta(hours=8))
         # (dt_diff >= pd.Timedelta(0))
     ].reset_index(drop=True)
     print("更新时间过滤完成")

+ 114 - 33
descending_cabin_task.py

@@ -113,8 +113,11 @@ class FlightPriceClient:
                     time.sleep(RETRY_INTERVAL)
                 resp = requests.post(url, headers=self.headers, json=payload, timeout=30)
                 resp.raise_for_status()
-                print(resp.json())
-                return resp.json()
+                body = resp.json()
+                print(json.dumps(body, ensure_ascii=False)[:200])
+                return body
+                # print(resp.json())
+                # return resp.json()
             except requests.Timeout as e:
                 last_err = FlightPriceRequestError(f"请求超时: {url}", cause=e)
             except requests.ConnectionError as e:
@@ -183,9 +186,10 @@ class ResultMatcher:
         第 i 段需满足:seg[i].cabin == 第 i 个舱位、seg[i].baggage == 第 i 个行李、seg[i].flight_number == 第 i 个航班号。
         返回匹配到的那条 result 项(含 data),未匹配到返回 None。
         """
-        cabin_list = [s.strip() for s in (cabins or "").split(";")]
-        baggage_list = [s.strip() for s in (baggages or "").split(";")]
-        flight_list = [s.strip() for s in (flight_numbers or "").split(";")] if flight_numbers else []
+        separator = '|'  # 更换分隔符由;为|
+        cabin_list = [s.strip() for s in (cabins or "").split(separator)]
+        baggage_list = [s.strip() for s in (baggages or "").split(separator)]
+        flight_list = [s.strip() for s in (flight_numbers or "").split(separator)] if flight_numbers else []
 
         n = len(cabin_list)
         if n == 0 or len(baggage_list) != n:
@@ -257,17 +261,19 @@ class FlightPriceTaskRunner:
         self.client = client or FlightPriceClient()
         self.matcher = ResultMatcher()
         self.handler = VerifyResultHandler()
-        # self.rate = fetch_rate("USD", "CNY")
+        self.rate = fetch_rate("USD", "CNY")
         
 
     def run(
         self,
         task: dict,
+        do_verify: bool = True,
     ) -> dict:
         """
         执行单条任务。task 需含: from_city_code, to_city_code, from_day, cabin, baggage, adult_total_price。
         flight_number 可选,用于匹配。
-        返回: {"status": "ok"|"placeholder"|"no_match", "price_info": {...}, "raw_verify": ...}
+        do_verify=False 时仅执行询价+匹配并返回 matched(不做验价)
+        返回: {"status": "ok"|"placeholder"|"no_match", "price_info": {...}, "raw_verify": ..., "raw_search": ..., "matched": ...}
         """
         from_city_code = task["from_city_code"]
         to_city_code = task["to_city_code"]
@@ -296,6 +302,38 @@ class FlightPriceTaskRunner:
         data = matched.get("data")
         if not data:
             return {"status": "no_data", "msg": "匹配项无 data", "raw_search": search_resp}
+
+        # 只询价不验价走的流程
+        if not do_verify:
+            # return {"status": "ok", "raw_search": search_resp, "matched": matched, "data": data}
+            expected_in_currency, rate_err = self._expected_price_in_verify_currency(task, matched)
+            if rate_err:
+                return {
+                    "status": "rate_error",
+                    "msg": rate_err,
+                    "raw_search": search_resp,
+                    "matched": matched,
+                    "data": data,
+                }
+            actual = matched.get("adult_total_price")   # 询价和验价接口出来的币种已经是人民币, 不用再转换
+            if self._price_within_threshold(expected_in_currency, actual):  # 对比
+                return {
+                    "status": "ok",
+                    "price_info": self._extract_price_info(matched),
+                    "raw_search": search_resp,
+                    "matched": matched,
+                    "data": data,
+                }
+            return {
+                "status": "price_not_within_threshold",
+                "msg": "询价结果价格不在阈值内",
+                "expected": expected_in_currency,
+                "actual": actual,
+                "raw_search": search_resp,
+                "matched": matched,
+                "data": data,
+            }
+            
         # 2. 验价(先 not_verify=False)
         try:
             verify_resp = self.client.verify_price(
@@ -385,14 +423,16 @@ class FlightPriceTaskRunner:
             expected_val = float(task.get("adult_total_price"))
         except (TypeError, ValueError):
             return None, "任务 adult_total_price 无效"
-        task_currency = (task.get("currency") or "USD").strip().upper()
-        verify_currency = (valid.get("currency") or "CNY").strip().upper()
-        if task_currency == verify_currency:
-            return expected_val, None
-        rate = fetch_rate(task_currency, verify_currency)
-        if rate is None:
-            return None, "汇率获取失败"
-        return expected_val * rate, None
+        if self.rate is None:
+            task_currency = (task.get("currency") or "USD").strip().upper()
+            verify_currency = (valid.get("currency") or "CNY").strip().upper()
+            if task_currency == verify_currency:
+                return expected_val, None
+            rate = fetch_rate(task_currency, verify_currency)
+            if rate is None:
+                return None, "汇率获取失败"
+            self.rate = rate
+        return expected_val * self.rate, None
 
     @staticmethod
     def _price_within_threshold(
@@ -428,7 +468,7 @@ class FlightPriceTaskRunner:
 def _process_one_task(row, runner):
     """处理单条任务:构建 end_task、执行 run、解析结果。成功返回 flight_data 字典,失败返回 None。"""
     task = row
-
+    separator = '|'  # 分隔符由;更换为|
 
     thread_name = threading.current_thread().name
     # print(f"[thread_name: {thread_name}] 正在处理任务: {task}")
@@ -438,9 +478,15 @@ def _process_one_task(row, runner):
 
     flight_numbers = task["flight_number_1"].strip()
     if task["flight_number_2"].strip() != "VJ":
-        flight_numbers += ";" + task["flight_number_2"].strip()
-    cabins = ";".join(["Y"] * len(flight_numbers.split(";")))
-    baggages = ";".join([f"1-{task['baggage']}"] * len(flight_numbers.split(";")))
+        flight_numbers += separator + task["flight_number_2"].strip()
+    cabins = separator.join(["Y"] * len(flight_numbers.split(separator)))
+
+    if str(task['baggage']) == '0':
+        baggage_str = "-;-;-;-"
+    else:
+        baggage_str = f"1-{task['baggage']}"
+
+    baggages = separator.join([baggage_str] * len(flight_numbers.split(separator)))
 
     end_task = {
         "from_city_code": from_city_code,
@@ -456,7 +502,8 @@ def _process_one_task(row, runner):
     # print(end_task)
     # print("--------------------------------")
 
-    out = runner.run(end_task)
+    time.sleep(1)
+    out = runner.run(end_task, do_verify=False)  # 不验价,仅询价
     # print(json.dumps(out, ensure_ascii=False, indent=2))
     if out.get("status") != "ok":
         # print(f"[thread_name={thread_name}] 错误: {out.get('msg')}")
@@ -464,15 +511,35 @@ def _process_one_task(row, runner):
 
     # print(f"价格: {out.get('price_info').get('adult_total_price')}")
     raw_verify = out.get("raw_verify")
-    results = raw_verify.get("result")
+    if raw_verify:
+        results = raw_verify.get("result") or []
+    else:
+        matched = out.get("matched") or {}
+        results = [matched] if matched else []
     if not results:
         return None
 
+    print("raw_verify pass")
+
+    # task 存放了 keep_info 的全部字段
+    drop_price_change_upper = float(task.get("drop_price_change_upper"))   # 降价的最小幅度
+    drop_price_change_lower = float(task.get("drop_price_change_lower"))
+
+    max_threshold = round(drop_price_change_upper * runner.rate * 0.5)   # 降价阈值要按汇率转人民币(四舍五入到整数)
+
     result = results[0]
-    segments = result.get("segments")
+    # adult_price = result.get("adult_price")
+    # adult_tax = result.get("adult_tax")
+    # adult_total_price = result.get("adult_total_price")
+    segments = result.get("segments") or []
+    if not segments:
+        return None
     end_segments = []
     baggage = segments[0].get("baggage")
-    pc, kg = [int(i) for i in baggage.split("-")]
+    if baggage == "-;-;-;-":
+        pc, kg = 0, 0   # 无行李的设置?
+    else:
+        pc, kg = [int(i) for i in baggage.split("-")]
     for seg in segments:
         flight_number = seg.get("flight_number")
         operating_flight_number = seg.get("operating_flight_number")
@@ -486,24 +553,27 @@ def _process_one_task(row, runner):
 
         end_segment = {
             "carrier": seg.get("carrier"),
-            "dep_air_port": seg.get("dep_air_port"),
-            "arr_air_port": seg.get("arr_air_port"),
+            "flight_number": flight_number,
+            # "dep_air_port": seg.get("dep_air_port"),
+            # "arr_air_port": seg.get("arr_air_port"),
             "dep_city_code": seg.get("dep_city_code"),
             "arr_city_code": seg.get("arr_city_code"),
-            "flight_number": flight_number,
-            "operating_flight_number": operating_flight_number,
+            # "operating_flight_number": operating_flight_number,
             "cabin": seg.get("cabin"),
             "dep_time": dep_time,
-            "arr_time": arr_time,
+            # "arr_time": arr_time,
         }
         end_segments.append(end_segment)
 
     return {
         "trip_type": 1,
-        "segments": end_segments,
-        "price_add": 0,
+        # "cover_price": adult_price,
+        # "cover_tax": adult_tax,
         "bag_amount": pc,
         "bag_weight": kg,
+        "max_threshold": max_threshold,
+        "segments": end_segments,
+        "ret_segments": [],
         "task": task
     }
 
@@ -553,9 +623,12 @@ def main():
 
     policy_list = []
     keep_info_end = []
-    max_workers = 3  # 并发线程数,可按需要调整
+    max_workers = 5  # 并发线程数,可按需要调整
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = {executor.submit(_process_one_task, task, runner): task for task in task_list}
+        total = len(futures)
+        done = 0
+        failed = 0
         for future in as_completed(futures):
             try:
                 flight_data = future.result()
@@ -564,12 +637,20 @@ def main():
                     keep_info_end.append(task)
                     policy_list.append(flight_data)
             except Exception as e:
+                failed += 1
                 task = futures[future]
-                print(f"任务异常 {task}: {e}")
+                # print(f"任务异常 {task}: {e}")
+                logger.error(f"任务异常 {task}: {e}")
+            finally:
+                done += 1
+                logger.info(
+                    f"进度: {done}/{total}, policy: {len(policy_list)}, keep: {len(keep_info_end)}, failed: {failed}"
+                )
 
     # 3 批量一次性上传政策
     logger.info(f"数据过滤后, 上传政策: {len(policy_list)}")
-    logger.info(f"policy_list: {policy_list}")
+    # logger.info(f"policy_list: {policy_list}")
+    logger.info(f"policy_list: {json.dumps(policy_list, ensure_ascii=False, default=str)[:1000]}")
     if len(policy_list) > 0:
         # 这里批量一次性上传政策 
         payload = {"items": policy_list}

+ 5 - 3
follow_up.py

@@ -121,6 +121,7 @@ def follow_up_handle():
     if df_keep_info.empty:
         df_keep_info = df_last_predict_will_drop.copy()
         df_keep_info["into_update_hour"] = df_keep_info['update_hour']
+        # df_keep_info["into_price"] = df_keep_info['adult_total_price']
         df_keep_info["keep_flag"] = 1
         # df_keep_info["last_predict_time"] = target_time
 
@@ -181,6 +182,7 @@ def follow_up_handle():
         # keep_flag 设为 1
         if not df_to_add.empty:
             df_to_add['into_update_hour'] = df_to_add['update_hour']
+            # df_to_add['into_price'] = df_to_add['adult_total_price']
             df_to_add["keep_flag"] = 1
         
         df_keep_with_merge = df_keep_info.reset_index().merge(
@@ -348,6 +350,6 @@ def follow_up_handle():
 if __name__ == "__main__":
     time.sleep(2)
     follow_up_handle()
-    # time.sleep(10)
-    # from descending_cabin_task import main as descending_cabin_task_main
-    # descending_cabin_task_main()
+    time.sleep(5)
+    from descending_cabin_task import main as descending_cabin_task_main
+    descending_cabin_task_main()

+ 3 - 3
main_pe_0.py

@@ -38,9 +38,9 @@ def start_predict():
         except Exception as e:
             print(f"remove {csv_path} info: {str(e)}")
 
-    # 预测时间范围,满足起飞时间 在12小时后到360小时后
-    pred_hour_begin = hourly_time + timedelta(hours=12)
-    pred_hour_end = hourly_time + timedelta(hours=360)
+    # 预测时间范围,满足起飞时间 在8小时后到240小时后
+    pred_hour_begin = hourly_time + timedelta(hours=8)
+    pred_hour_end = hourly_time + timedelta(hours=240)
 
     pred_date_end = pred_hour_end.strftime("%Y-%m-%d")
     pred_date_begin = pred_hour_begin.strftime("%Y-%m-%d")

+ 1 - 1
main_tr_0.py

@@ -50,7 +50,7 @@ def start_train():
     # date_end = datetime.today().strftime("%Y-%m-%d")
     date_end = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
     # date_begin = (datetime.today() - timedelta(days=32)).strftime("%Y-%m-%d")
-    date_begin = "2026-03-11"   # 2026-01-01 2026-03-11 2026-03-16
+    date_begin = "2026-01-01"   # 2026-01-01 2026-03-11 2026-03-16 2026-03-18
 
     print(f"训练时间范围: {date_begin} 到 {date_end}")
 

+ 2 - 1
result_keep_verify.py

@@ -57,6 +57,7 @@ def _validate_keep_info_df(df_keep_info_part):
                     .reset_index(drop=True)
                 )
                 mask_drop = df_query["adult_total_price"] < entry_price
+                # mask_drop = (df_query["adult_total_price"] < entry_price) & (df_query["crawl_dt"] > update_dt)
                 if mask_drop.any():
                     first_row = df_query.loc[mask_drop].iloc[0]
                     price_diff = entry_price - first_row["adult_total_price"]
@@ -173,5 +174,5 @@ def verify_process(min_batch_time_str, max_batch_time_str):
         
 
 if __name__ == "__main__":
-    verify_process("202603121800", "202603160800")
+    verify_process("202603161800", "202603180800")
     pass