Procházet zdrojové kódy

修改绘图方式

node04 před 1 dnem
rodič
revize
be52c02de5
1 změnil soubory, kde provedl 22 přidání a 6 odebrání
  1. 22 6
      data_loader.py

+ 22 - 6
data_loader.py

@@ -7,6 +7,7 @@ import gc
 from concurrent.futures import ProcessPoolExecutor, as_completed
 import pymongo
 from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
+import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 from matplotlib import font_manager
@@ -243,6 +244,8 @@ def plot_c1_trend(df, output_dir="."):
     output_dir_temp = os.path.join(output_dir, city_pair)
     os.makedirs(output_dir_temp, exist_ok=True)
 
+    df = df[df['baggage_weight'] == 0]  # 只保留无行李的
+
     # 创建图表对象
     fig = plt.figure(figsize=(14, 8))
 
@@ -280,9 +283,22 @@ def plot_c1_trend(df, output_dir="."):
             zorder=3,
         )
 
-        # 添加注释 (小时数, 价格)
+        # 添加注释 (小时数, 价格, 舱位)
+        # 点密集时自动抽样,避免文字严重重叠
+        n_points = len(change_points)
+        max_labels = 30
+        step = max(1, int(np.ceil(n_points / max_labels)))
+        label_points = change_points.iloc[::step].copy()
+
+        # 确保最后一个点始终有注释
+        if n_points > 0 and label_points.index[-1] != change_points.index[-1]:
+            label_points = pd.concat([label_points, change_points.tail(1)])
+
+        rotation_angle = 45 if n_points > max_labels else 25
+        label_fontsize = 4 if n_points > max_labels else 5
+
         for _, row in change_points.iterrows():
-            text = f"({row['hours_until_departure']}, {row['price_total']})"
+            text = f"({row['hours_until_departure']}, {row['price_total']}, {row['cabins']})"
             plt.annotate(
                 text,
                 xy=(row['update_hour'], row['price_total']),
@@ -290,10 +306,10 @@ def plot_c1_trend(df, output_dir="."):
                 textcoords="offset points",
                 ha='left',
                 va='center',
-                fontsize=5,  # 字体稍小
+                fontsize=label_fontsize,  # 字体稍小
                 color='gray',
                 alpha=0.8,
-                rotation=25,
+                rotation=rotation_angle,
             )
         
         del change_points
@@ -655,8 +671,8 @@ if __name__ == "__main__":
     output_dir = f"./photo"
     os.makedirs(output_dir, exist_ok=True)
 
-    from_date_begin = "2026-04-01"
-    from_date_end = "2026-04-20"
+    from_date_begin = "2026-04-21"
+    from_date_end = "2026-05-10"
 
     uo_city_pairs = uo_city_pairs_new.copy()