|
|
@@ -6,8 +6,7 @@ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
|
|
|
import pandas as pd
|
|
|
import os
|
|
|
import random
|
|
|
-import threading
|
|
|
-from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
+from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
from matplotlib import font_manager
|
|
|
@@ -641,19 +640,19 @@ def plot_c12_trend(df, output_dir="."):
|
|
|
|
|
|
|
|
|
def process_flight_group(args):
|
|
|
- """处理单个航班号的线程函数(独立数据库连接)"""
|
|
|
- thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
|
|
|
+ """处理单个航班号的进程函数(独立数据库连接)"""
|
|
|
+ process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
|
|
|
flight_nums = each_group.get("flight_numbers")
|
|
|
details = each_group.get("details")
|
|
|
|
|
|
- print(f"[线程{thread_id}] 开始处理航班号: {flight_nums}")
|
|
|
+ print(f"[进程{process_id}] 开始处理航班号: {flight_nums}")
|
|
|
|
|
|
- # 为每个线程创建独立的数据库连接
|
|
|
+ # 为每个进程创建独立的数据库连接
|
|
|
try:
|
|
|
client, db = mongo_con_parse(db_config)
|
|
|
- print(f"[线程{thread_id}] ✅ 数据库连接创建成功")
|
|
|
+ print(f"[进程{process_id}] ✅ 数据库连接创建成功")
|
|
|
except Exception as e:
|
|
|
- print(f"[线程{thread_id}] ❌ 数据库连接创建失败: {e}")
|
|
|
+ print(f"[进程{process_id}] ❌ 数据库连接创建失败: {e}")
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
try:
|
|
|
@@ -667,7 +666,7 @@ def process_flight_group(args):
|
|
|
|
|
|
# 保证远期表里有数据
|
|
|
if df1.empty:
|
|
|
- print(f"[线程{thread_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
|
|
|
+ print(f"[进程{process_id}] 航班号:{flight_nums} 远期表无数据, 跳过")
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
# 查询近期表
|
|
|
@@ -680,7 +679,7 @@ def process_flight_group(args):
|
|
|
|
|
|
# 保证近期表里有数据
|
|
|
if df2.empty:
|
|
|
- print(f"[线程{thread_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
|
|
|
+ print(f"[进程{process_id}] 航班号:{flight_nums} 近期表无数据, 跳过")
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
# 起飞天数、行李配额以近期表的为主
|
|
|
@@ -722,7 +721,7 @@ def process_flight_group(args):
|
|
|
|
|
|
# 合并前检查是否都有数据
|
|
|
if df_b1.empty and df_b2.empty:
|
|
|
- print(f"[线程{thread_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
|
|
|
+ print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} 远期表和近期表都为空,跳过")
|
|
|
continue
|
|
|
|
|
|
cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
|
|
|
@@ -743,13 +742,13 @@ def process_flight_group(args):
|
|
|
if list_12:
|
|
|
df_c12 = pd.concat(list_12, ignore_index=True)
|
|
|
if plot_flag:
|
|
|
- print(f"[线程{thread_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
|
|
|
+ print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据合并完成,总形状: {df_c12.shape}")
|
|
|
plot_c12_trend(df_c12, output_dir)
|
|
|
- print(f"[线程{thread_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
|
|
|
+ print(f"[进程{process_id}] ✅ dep_date:{dep_date}, 所有 baggage 数据绘图完成")
|
|
|
else:
|
|
|
df_c12 = pd.DataFrame()
|
|
|
if plot_flag:
|
|
|
- print(f"[线程{thread_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
|
|
|
+ print(f"[进程{process_id}] ⚠️ dep_date:{dep_date}, 所有 baggage 数据合并为空")
|
|
|
|
|
|
del list_12
|
|
|
list_mid.append(df_c12)
|
|
|
@@ -761,33 +760,33 @@ def process_flight_group(args):
|
|
|
|
|
|
if list_mid:
|
|
|
df_mid = pd.concat(list_mid, ignore_index=True)
|
|
|
- print(f"[线程{thread_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
|
|
|
+ print(f"[进程{process_id}] ✅ 航班号:{flight_nums} 所有 起飞日期 数据合并完成,总形状: {df_mid.shape}")
|
|
|
else:
|
|
|
df_mid = pd.DataFrame()
|
|
|
- print(f"[线程{thread_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
|
|
|
+ print(f"[进程{process_id}] ⚠️ 航班号:{flight_nums} 所有 起飞日期 数据合并为空")
|
|
|
|
|
|
del list_mid
|
|
|
del df1
|
|
|
del df2
|
|
|
gc.collect()
|
|
|
- print(f"[线程{thread_id}] 结束处理航班号: {flight_nums}")
|
|
|
+ print(f"[进程{process_id}] 结束处理航班号: {flight_nums}")
|
|
|
return df_mid
|
|
|
|
|
|
except Exception as e:
|
|
|
- print(f"[线程{thread_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
|
|
|
+ print(f"[进程{process_id}] ❌ 处理航班号:{flight_nums} 时发生异常: {e}")
|
|
|
return pd.DataFrame()
|
|
|
finally:
|
|
|
# 确保关闭数据库连接
|
|
|
try:
|
|
|
client.close()
|
|
|
- print(f"[线程{thread_id}] ✅ 数据库连接已关闭")
|
|
|
+ print(f"[进程{process_id}] ✅ 数据库连接已关闭")
|
|
|
except:
|
|
|
pass
|
|
|
|
|
|
|
|
|
def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, plot_flag=False,
|
|
|
- use_multithread=False, max_workers=None):
|
|
|
- """加载训练数据(支持多线程)"""
|
|
|
+ use_multiprocess=False, max_workers=None):
|
|
|
+ """加载训练数据(支持多进程)"""
|
|
|
timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
|
|
|
date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d") # 查询时的格式
|
|
|
date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
|
|
|
@@ -800,7 +799,7 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
|
|
|
route = f"{from_city}-{to_city}"
|
|
|
print(f"开始处理航线: {route}")
|
|
|
|
|
|
- # 在主线程中查询航班号分组(避免多线程重复查询)
|
|
|
+ # 在主进程中查询航班号分组(避免多进程重复查询)
|
|
|
main_client, main_db = mongo_con_parse(db_config)
|
|
|
all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
|
|
|
main_client.close()
|
|
|
@@ -808,18 +807,18 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
|
|
|
all_groups_len = len(all_groups)
|
|
|
print(f"该航线共有{all_groups_len}个航班号")
|
|
|
|
|
|
- if use_multithread and all_groups_len > 1:
|
|
|
+ if use_multiprocess and all_groups_len > 1:
|
|
|
print(f"启用多线程处理,最大线程数: {max_workers}")
|
|
|
- # 多线程处理
|
|
|
- thread_args = []
|
|
|
- thread_id = 0
|
|
|
+ # 多进程处理
|
|
|
+ process_args = []
|
|
|
+ process_id = 0
|
|
|
for each_group in all_groups:
|
|
|
- thread_id += 1
|
|
|
- args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
|
|
|
- thread_args.append(args)
|
|
|
+ process_id += 1
|
|
|
+ args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
|
|
|
+ process_args.append(args)
|
|
|
|
|
|
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
|
- future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(thread_args, all_groups)}
|
|
|
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
|
|
+ future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
|
|
|
|
|
|
for future in as_completed(future_to_group):
|
|
|
each_group = future_to_group[future]
|
|
|
@@ -835,11 +834,11 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
|
|
|
print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
|
|
|
|
|
|
else:
|
|
|
- # 单线程处理(线程编号为0)
|
|
|
- print("使用单线程处理")
|
|
|
- thread_id = 0
|
|
|
+ # 单进程处理(进程编号为0)
|
|
|
+ print("使用单进程处理")
|
|
|
+ process_id = 0
|
|
|
for each_group in all_groups:
|
|
|
- args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
|
|
|
+ args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
|
|
|
flight_nums = each_group.get("flight_numbers", "未知")
|
|
|
try:
|
|
|
df_mid = process_flight_group(args)
|
|
|
@@ -1000,7 +999,7 @@ if __name__ == "__main__":
|
|
|
from utils import chunk_list_with_index
|
|
|
|
|
|
cpu_cores = os.cpu_count() # 你的系统是72
|
|
|
- max_workers = min(16, cpu_cores) # 最大不超过16个线程
|
|
|
+ max_workers = min(8, cpu_cores) # 最大不超过8个进程
|
|
|
|
|
|
output_dir = f"./output"
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
@@ -1009,7 +1008,7 @@ if __name__ == "__main__":
|
|
|
date_begin = "2025-12-07"
|
|
|
date_end = datetime.today().strftime("%Y-%m-%d")
|
|
|
|
|
|
- flight_route_list = vj_flight_route_list_hot[0:] # 热门 vj_flight_route_list_hot 冷门 vj_flight_route_list_nothot
|
|
|
+ flight_route_list = vj_flight_route_list_hot[4:] # 热门 vj_flight_route_list_hot 冷门 vj_flight_route_list_nothot
|
|
|
table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB # 热门 CLEAN_VJ_HOT_NEAR_INFO_TAB 冷门 CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
|
|
|
is_hot = 1 # 1 热门 0 冷门
|
|
|
group_size = 1
|
|
|
@@ -1021,7 +1020,7 @@ if __name__ == "__main__":
|
|
|
print(f"第 {idx} 组 :", group_route_list)
|
|
|
start_time = time.time()
|
|
|
load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=False,
|
|
|
- use_multithread=False, max_workers=max_workers)
|
|
|
+ use_multiprocess=True, max_workers=max_workers)
|
|
|
end_time = time.time()
|
|
|
run_time = round(end_time - start_time, 3)
|
|
|
print(f"用时: {run_time} 秒")
|