فهرست منبع

提交绘图相关

node04 4 روز پیش
والد
کامیت
74ffd0a1c3
3فایلهای تغییر یافته به همراه327 افزوده شده و 15 حذف شده
  1. 1 0
      .gitignore
  2. 326 15
      data_loader.py
  3. BIN
      simhei.ttf

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+photo/

+ 326 - 15
data_loader.py

@@ -1,12 +1,21 @@
+import os
 import time
 import random
 from datetime import datetime, timedelta
+import gc
+from concurrent.futures import ProcessPoolExecutor, as_completed
 import pymongo
 from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
 import pandas as pd
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+import matplotlib.dates as mdates
 from uo_atlas_import import mongo_con_parse
 from config import mongo_config, mongo_table_uo, uo_city_pairs
 
+font_path = "./simhei.ttf"
+font_prop = font_manager.FontProperties(fname=font_path)
+
 
 def query_groups_of_city_pair(db, city_pair, table_name, min_days=10, max_retries=3, base_sleep=1.0):
     """根据city_pair查询航线, 筛选1个月内至少有10天起飞的航线"""
@@ -103,6 +112,8 @@ def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_da
                 "cabins": 1,
                 "ticket_amount": 1,
                 "currency": 1,
+                "price_base": 1,
+                "price_tax": 1,
                 "price_total": 1
             }            
             pipeline = [
@@ -152,6 +163,7 @@ def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_da
                             df.loc[missing_mask, 'from_time'] = df.loc[missing_mask, 'from_date'].astype(str).str.strip() + ' ' + more_time
                     else:
                         # 无法得到起飞日期的抛弃
+                        print(f"⚠️ 无法提取有效起飞时间,抛弃该条记录")
                         return pd.DataFrame()
 
                 print(f"📊 已转换为 DataFrame,形状: {df.shape}")
@@ -173,9 +185,209 @@ def query_flight_range_status(db, table_name, city_pair, flight_numbers, from_da
             time.sleep(sleep_time)
 
 
-def fill_hourly_create_time(df):
+def plot_c1_trend(df, output_dir="."):
+    """
+    根据传入的 dataframe 绘制 price_total 随 update_hour 的趋势图,
+    并按照 baggage 分类进行分组绘制。
+    """
+    # 颜色与线型配置(按顺序循环使用)
+    colors = ['green', 'blue', 'red', 'brown']
+    linestyles = ['--', '--', '--', '--']
+
+    # 确保时间字段为 datetime 类型
+    if not hasattr(df['update_hour'], 'dt'):
+        df['update_hour'] = pd.to_datetime(df['update_hour'])
+    
+    city_pair = df['citypair'].mode().iloc[0]
+    flight_numbers = df['flight_numbers'].mode().iloc[0]
+    from_time = df['from_time'].mode().iloc[0]
+    
+    output_dir_temp = os.path.join(output_dir, city_pair)
+    os.makedirs(output_dir_temp, exist_ok=True)
+
+    # 创建图表对象
+    fig = plt.figure(figsize=(14, 8))
+
+    # 按 baggage_weight 分类绘制
+    for i, (baggage_value, group) in enumerate(df.groupby('baggage_weight')):
+        # 按时间排序
+        df_g = group.sort_values('update_hour').reset_index(drop=True)
+
+        # 找价格变化点:与前一行不同的价格即为变化点
+        # keep first row + change rows + last row
+        change_points = df_g.loc[
+            (df_g['price_total'] != df_g['price_total'].shift(1)) |
+            (df_g.index == 0) |
+            (df_g.index == len(df_g) - 1)  # 终点
+        ].drop_duplicates(subset=['update_hour'])
+
+        # 绘制阶梯线(平缓-突变)+ 变化点
+        plt.step(
+            change_points['update_hour'],
+            change_points['price_total'],
+            where='post',
+            color=colors[i % len(colors)],
+            linestyle=linestyles[i % len(linestyles)],
+            linewidth=2,
+            label=f"Baggage {baggage_value}"
+        )
+        
+        plt.scatter(
+            change_points['update_hour'],
+            change_points['price_total'],
+            s=30,
+            facecolors='white',
+            edgecolors=colors[i % len(colors)],
+            linewidths=2,
+            zorder=3,
+        )
+
+        # 添加注释 (小时数, 价格)
+        for _, row in change_points.iterrows():
+            text = f"({row['hours_until_departure']}, {row['price_total']})"
+            plt.annotate(
+                text,
+                xy=(row['update_hour'], row['price_total']),
+                xytext=(0, 0),  # 向右偏移
+                textcoords="offset points",
+                ha='left',
+                va='center',
+                fontsize=5,  # 字体稍小
+                color='gray',
+                alpha=0.8,
+                rotation=25,
+            )
+        
+        del change_points
+        del df_g
+
+    # 自动优化日期显示
+    plt.gcf().autofmt_xdate()
+    plt.xlabel('时刻', fontsize=12, fontproperties=font_prop)
+    plt.ylabel('价格', fontsize=12, fontproperties=font_prop)
+    plt.title(f'价格变化趋势 - 航线: {city_pair} 航班号: {flight_numbers}\n起飞时间: {from_time}',
+              fontsize=14, fontweight='bold', fontproperties=font_prop)
+
+    # 设置 x 轴刻度为每天
+    ax = plt.gca()
+    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))    # 每天一个主刻度
+    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))  # 显示月-日
+    ax.xaxis.set_minor_locator(mdates.HourLocator(byhour=[12]))  # 指定在12:00显示副刻度
+    ax.xaxis.set_minor_formatter(mdates.DateFormatter(''))       # 输出空字符串
+
+    # 添加图例
+    plt.legend(bbox_to_anchor=(1.05, 1), loc='best', prop=font_prop)
+    plt.grid(True, alpha=0.3)
+    plt.tight_layout()
+
+    safe_flight = flight_numbers.replace(",", "_")
+    safe_dep_time = from_time.strftime("%Y-%m-%d %H%M%S")
+    save_file = f"{city_pair} {safe_flight} {safe_dep_time}.png"
+    output_path = os.path.join(output_dir_temp, save_file)
+    # 保存图片(在显示之前)
+    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
+
+    # 关闭图形释放内存
+    plt.close(fig)
+    
+
+def fill_hourly_create_time(df, head_fill=0, rear_fill=0):
     """补齐成小时粒度数据"""
-    pass
+    df = df.copy()
+
+    # 1. 转 datetime
+    df['create_time'] = pd.to_datetime(df['create_time'])
+    df['from_time'] = pd.to_datetime(df['from_time'])
+
+    # 添加一个用于分组的小时字段
+    df['update_hour'] = df['create_time'].dt.floor('h')
+
+    # 2. 排序规则:同一小时内,按原始时间戳排序
+    # 假设你想保留最早的一条
+    df = df.sort_values(['update_hour', 'create_time'])
+
+    # 3. 按小时去重,保留该小时内最早(最晚)的一条
+    df = df.drop_duplicates(subset=['update_hour'], keep='last')   #  keep='first'  keep='last'
+
+    # 4. 标记原始数据
+    df['is_filled'] = 0
+
+    # 5. 排序 + 设索引
+    df = df.sort_values('update_hour').set_index('update_hour')
+
+    # 6. 构造完整小时轴
+    start_of_hour = df.index.min()  # 默认 第一天 最早 开始
+    if head_fill == 1:
+        start_of_hour = df.index.min().normalize()  # 强制 第一天 00:00 开始
+
+    end_of_hour = df.index.max()  # 默认 最后一天 最晚 结束
+    if rear_fill == 1:
+        end_of_hour = df.index.max().normalize() + pd.Timedelta(hours=23)  # 强制 最后一天 23:00 结束
+    elif rear_fill == 2:
+        if 'from_time' in df.columns:
+            last_dep_time = df['from_time'].iloc[-1]
+            if pd.notna(last_dep_time):
+                # 对齐到整点小时(向下取整)
+                end_of_hour = last_dep_time.floor('h')
+
+    full_index = pd.date_range(
+        start=start_of_hour,
+        end=end_of_hour,
+        freq='1h'
+    )
+    # 7. 按小时补齐
+    df = df.reindex(full_index)
+
+    # 先恢复 dtype(关键!)
+    df = df.infer_objects(copy=False)
+
+    # 8. 新增出来的行标记为 1
+    df['is_filled'] = df['is_filled'].fillna(1)
+
+    # 9. 前向填充
+    df = df.ffill()
+
+    # 10. 还原整型字段
+    int_cols = [
+        'ticket_amount',
+        'baggage_weight',
+        'is_filled',
+    ]
+    for col in int_cols:
+        if col in df.columns:
+            df[col] = df[col].astype('int64')
+
+    # 10.5 价格字段统一保留两位小数
+    price_cols = [
+        'price_base',
+        'price_tax',
+        'price_total'
+    ]
+    for col in price_cols:
+        if col in df.columns:
+            df[col] = df[col].astype('float64').round(2)
+    
+    # 10.6 新增:距离起飞还有多少小时
+    if 'from_time' in df.columns:
+        # 创建临时字段(整点)
+        df['from_hour'] = df['from_time'].dt.floor('h')
+        # 计算小时差 df.index 此时就是 update_hour
+        df['hours_until_departure'] = (
+                (df['from_hour'] - df.index) / pd.Timedelta(hours=1)
+        ).astype('int64')
+        # 新增:距离起飞还有多少天
+        df['days_to_departure'] = (df['hours_until_departure'] // 24).astype('int64')
+        # 删除临时字段
+        df = df.drop(columns=['from_hour'])
+    
+    # 11. 写回 update_hour
+    df['update_hour'] = df.index
+
+    # 12. 恢复普通索引
+    df = df.reset_index(drop=True)
+
+    return df
+
 
 def process_flight_numbers(args):
     process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir = args
@@ -191,10 +403,59 @@ def process_flight_numbers(args):
     
     try:
         # 查询
-        df_1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
+        df1 = query_flight_range_status(db, mongo_table_uo, city_pair, flight_numbers, from_date_begin, from_date_end)
+
+        if df1.empty:
+            return pd.DataFrame()
         
-        df_f1 = fill_hourly_create_time(df_1)
+        common_dep_dates = df1['from_date'].unique()
+        common_baggages = df1['baggage_weight'].unique()
+
+        list_mid = []
+        for dep_date in common_dep_dates:
+            # 起飞日期筛选
+            df_d1 = df1[df1["from_date"] == dep_date].copy()
+            list_f1 = []
+            for baggage in common_baggages:
+                # 行李配额筛选
+                df_b1 = df_d1[df_d1["baggage_weight"] == baggage].copy()
+                if df_b1.empty:
+                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 为空,跳过")
+                    continue
+                df_f1 = fill_hourly_create_time(df_b1, rear_fill=2)
+                list_f1.append(df_f1)
+                del df_f1
+                del df_b1
 
+            if list_f1:
+                df_c1 = pd.concat(list_f1, ignore_index=True)
+                if plot_flag:
+                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c1.shape}")
+                    plot_c1_trend(df_c1, output_dir)
+                    print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
+            else:
+                df_c1 = pd.DataFrame()
+                if plot_flag:
+                    print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
+            
+            del list_f1
+            list_mid.append(df_c1)
+
+            del df_c1
+            del df_d1
+        
+        if list_mid:
+            df_mid = pd.concat(list_mid, ignore_index=True)
+            print(f"[进程{process_id}] ✅ 航班号:{flight_numbers} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
+        else:
+            df_mid = pd.DataFrame()
+            print(f"[进程{process_id}] ⚠️ 航班号:{flight_numbers} 所有 起飞日期 数据合并为空")
+        
+        del list_mid
+        del df1
+        gc.collect()
+        print(f"[进程{process_id}] 结束处理航班号: {flight_numbers}")
+        return df_mid
         
     except Exception as e:
         print(f"[进程{process_id}] ❌ 处理航班号:{flight_numbers} 时发生异常: {e}")
@@ -210,7 +471,8 @@ def process_flight_numbers(args):
 
 def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=True, plot_flag=False, output_dir='.', 
               use_multiprocess=False, max_workers=None):
-
+    list_all = []
+    
     print(f"开始处理航线: {city_pair}")
     main_client, main_db = mongo_con_parse(db_config)
     all_groups = query_groups_of_city_pair(main_db, city_pair, mongo_table_uo)
@@ -219,23 +481,71 @@ def load_data(db_config, city_pair, from_date_begin, from_date_end, is_train=Tru
     all_groups_len = len(all_groups)
     print(f"该航线共有{all_groups_len}组航班号")
 
-    print("使用单进程处理")
-    process_id = 0
+    if use_multiprocess and all_groups_len > 1:
+        print(f"启用多进程处理,最大进程数: {max_workers}")
+        # 多进程处理
+        process_args = []
+        process_id = 0
+        for each_group in all_groups:
+            flight_numbers = each_group.get("flight_numbers", "未知")
+            process_id += 1
+            args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
+            process_args.append(args)
 
-    for each_group in all_groups:
-        flight_numbers = each_group.get("flight_numbers", "未知")
-        args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
-        try:
-            df_mid = process_flight_numbers(args)
+        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+            future_to_group = {executor.submit(process_flight_numbers, args): each_group for args, each_group in zip(process_args, all_groups)}
+            for future in as_completed(future_to_group):
+                each_group = future_to_group[future]
+                flight_numbers = each_group.get("flight_numbers", "未知")
+                try:
+                    df_mid = future.result()
+                    if not df_mid.empty:
+                        list_all.append(df_mid)
+                        print(f"✅ 航班号:{flight_numbers} 处理完成")
+                    else:
+                        print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
+                except Exception as e:
+                    print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
             pass
 
-        except Exception as e:
-            print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
+    else:
+        print("使用单进程处理")
+        process_id = 0
+
+        for each_group in all_groups:
+            flight_numbers = each_group.get("flight_numbers", "未知")
+            args = (process_id, db_config, city_pair, flight_numbers, from_date_begin, from_date_end, is_train, plot_flag, output_dir)
+            try:
+                df_mid = process_flight_numbers(args)
+                if not df_mid.empty:
+                    list_all.append(df_mid)
+                    print(f"✅ 航班号:{flight_numbers} 处理完成")
+                else:
+                    print(f"⚠️ 航班号:{flight_numbers} 处理结果为空")
+            except Exception as e:
+                print(f"❌ 航班号:{flight_numbers} 处理异常: {e}")
+
+    print(f"结束处理航线: {city_pair}")
+    if list_all:
+        df_all = pd.concat(list_all, ignore_index=True)
+    else:
+        df_all = pd.DataFrame()
+    
+    print(f"本批次数据加载完毕, 总形状: {df_all.shape}")
+    del list_all
+    gc.collect()
 
+    return df_all
 
 
 if __name__ == "__main__":
     
+    cpu_cores = os.cpu_count()  # 你的系统是72
+    max_workers = min(8, cpu_cores)  # 最大不超过8个进程
+
+    output_dir = f"./photo"
+    os.makedirs(output_dir, exist_ok=True)
+
     from_date_begin = "2026-03-17"
     from_date_end = "2026-04-01"
 
@@ -247,7 +557,8 @@ if __name__ == "__main__":
         print(f"第 {idx} 组 :", uo_city_pair)
 
         start_time = time.time()
-        load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end)
+        load_data(mongo_config, uo_city_pair, from_date_begin, from_date_end,
+                  plot_flag=True, output_dir=output_dir, use_multiprocess=True, max_workers=max_workers)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
         print(f"用时: {run_time} 秒")

BIN
simhei.ttf