
Adapt data loading to multiprocessing and run with DDP

node04, 1 month ago
Origin
Commit
d6c0a6b2fd
7 changed files with 75 additions and 49 deletions
  1. data_loader.py (+37 −38)
  2. evaluate.py (+12 −0)
  3. main_pe.py (+8 −1)
  4. main_tr.py (+10 −6)
  5. predict.py (+4 −0)
  6. result_validate.py (+2 −2)
  7. utils.py (+2 −2)

+ 37 - 38
data_loader.py

@@ -6,8 +6,7 @@ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
 import pandas as pd
 import os
 import random
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib import font_manager
@@ -641,19 +640,19 @@ def plot_c12_trend(df, output_dir="."):
 
 
 def process_flight_group(args):
-    """处理单个航班号的线程函数(独立数据库连接)"""
-    thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
+    """处理单个航班号的程函数(独立数据库连接)"""
+    process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir = args
     flight_nums = each_group.get("flight_numbers")
     details = each_group.get("details")
 
-    print(f"[thread {thread_id}] Start processing flight numbers: {flight_nums}")
+    print(f"[process {process_id}] Start processing flight numbers: {flight_nums}")
 
-    # Create a dedicated database connection for each thread
+    # Create a dedicated database connection for each process
     try:
         client, db = mongo_con_parse(db_config)
-        print(f"[thread {thread_id}] ✅ Database connection established")
+        print(f"[process {process_id}] ✅ Database connection established")
     except Exception as e:
-        print(f"[thread {thread_id}] ❌ Failed to create database connection: {e}")
+        print(f"[process {process_id}] ❌ Failed to create database connection: {e}")
         return pd.DataFrame()
 
     try:
@@ -667,7 +666,7 @@ def process_flight_group(args):
         
         # Make sure the far-term table has data
         if df1.empty:
-            print(f"[thread {thread_id}] flight numbers:{flight_nums} no data in far-term table, skipping")
+            print(f"[process {process_id}] flight numbers:{flight_nums} no data in far-term table, skipping")
             return pd.DataFrame()
         
         # Query the near-term table
@@ -680,7 +679,7 @@ def process_flight_group(args):
             
         # Make sure the near-term table has data
         if df2.empty:
-            print(f"[thread {thread_id}] flight numbers:{flight_nums} no data in near-term table, skipping")
+            print(f"[process {process_id}] flight numbers:{flight_nums} no data in near-term table, skipping")
             return pd.DataFrame()
         
         # Days-to-departure and baggage quota follow the near-term table
@@ -722,7 +721,7 @@ def process_flight_group(args):
 
                 # Before merging, check that at least one side has data
                 if df_b1.empty and df_b2.empty:
-                    print(f"[thread {thread_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} both far-term and near-term tables are empty, skipping")
+                    print(f"[process {process_id}] ⚠️ dep_date:{dep_date}, baggage:{baggage} both far-term and near-term tables are empty, skipping")
                     continue
 
                 cols = ["seg1_flight_number", "seg1_dep_air_port", "seg1_arr_air_port",
@@ -743,13 +742,13 @@ def process_flight_group(args):
             if list_12:
                 df_c12 = pd.concat(list_12, ignore_index=True)
                 if plot_flag:
-                    print(f"[thread {thread_id}] ✅ dep_date:{dep_date}, all baggage data merged, total shape: {df_c12.shape}")
+                    print(f"[process {process_id}] ✅ dep_date:{dep_date}, all baggage data merged, total shape: {df_c12.shape}")
                     plot_c12_trend(df_c12, output_dir)
-                    print(f"[thread {thread_id}] ✅ dep_date:{dep_date}, all baggage data plotted")
+                    print(f"[process {process_id}] ✅ dep_date:{dep_date}, all baggage data plotted")
             else:
                 df_c12 = pd.DataFrame()
                 if plot_flag:
-                    print(f"[thread {thread_id}] ⚠️ dep_date:{dep_date}, merged baggage data is empty")
+                    print(f"[process {process_id}] ⚠️ dep_date:{dep_date}, merged baggage data is empty")
 
             del list_12
             list_mid.append(df_c12)
@@ -761,33 +760,33 @@ def process_flight_group(args):
 
         if list_mid:
             df_mid = pd.concat(list_mid, ignore_index=True)
-            print(f"[thread {thread_id}] ✅ flight numbers:{flight_nums} all departure-date data merged, total shape: {df_mid.shape}")
+            print(f"[process {process_id}] ✅ flight numbers:{flight_nums} all departure-date data merged, total shape: {df_mid.shape}")
         else:
             df_mid = pd.DataFrame()
-            print(f"[thread {thread_id}] ⚠️ flight numbers:{flight_nums} merged departure-date data is empty")
+            print(f"[process {process_id}] ⚠️ flight numbers:{flight_nums} merged departure-date data is empty")
         
         del list_mid
         del df1
         del df2
         gc.collect()
-        print(f"[thread {thread_id}] Finished processing flight numbers: {flight_nums}")
+        print(f"[process {process_id}] Finished processing flight numbers: {flight_nums}")
         return df_mid
     
     except Exception as e:
-        print(f"[thread {thread_id}] ❌ Exception while processing flight numbers:{flight_nums}: {e}")
+        print(f"[process {process_id}] ❌ Exception while processing flight numbers:{flight_nums}: {e}")
         return pd.DataFrame()
     finally:
         # Make sure the database connection is closed
         try:
             client.close()
-            print(f"[thread {thread_id}] ✅ Database connection closed")
+            print(f"[process {process_id}] ✅ Database connection closed")
         except:
             pass
 
 
 def load_train_data(db_config, flight_route_list, table_name, date_begin, date_end, output_dir='.', is_hot=1, plot_flag=False,
-                    use_multithread=False, max_workers=None):
-    """加载训练数据(支持多线程)"""
+                    use_multiprocess=False, max_workers=None):
+    """加载训练数据(支持多程)"""
     timestamp_str = datetime.now().strftime("%Y%m%d%H%M%S")
     date_begin_s = datetime.strptime(date_begin, "%Y-%m-%d").strftime("%Y%m%d")  # format used for queries
     date_end_s = datetime.strptime(date_end, "%Y-%m-%d").strftime("%Y%m%d")
@@ -800,7 +799,7 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
         route = f"{from_city}-{to_city}"
         print(f"开始处理航线: {route}")
 
-        # Query flight-number groups in the main thread (avoid duplicate queries across threads)
+        # Query flight-number groups in the main process (avoid duplicate queries across processes)
         main_client, main_db = mongo_con_parse(db_config)
         all_groups = query_groups_of_city_code(main_db, from_city, to_city, table_name)
         main_client.close()
@@ -808,18 +807,18 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
         all_groups_len = len(all_groups)
         print(f"该航线共有{all_groups_len}个航班号")
         
-        if use_multithread and all_groups_len > 1:
+        if use_multiprocess and all_groups_len > 1:
             print(f"启用多线程处理,最大线程数: {max_workers}")
-            # Multi-thread processing
-            thread_args = []
-            thread_id = 0
+            # Multi-process processing
+            process_args = []
+            process_id = 0
             for each_group in all_groups:
-                thread_id += 1
-                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
-                thread_args.append(args)
+                process_id += 1
+                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                process_args.append(args)
             
-            with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(thread_args, all_groups)}
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                future_to_group = {executor.submit(process_flight_group, args): each_group for args, each_group in zip(process_args, all_groups)}
                 
                 for future in as_completed(future_to_group):
                     each_group = future_to_group[future]
@@ -835,11 +834,11 @@ def load_train_data(db_config, flight_route_list, table_name, date_begin, date_e
                         print(f"❌ 航班号:{flight_nums} 处理异常: {e}")
 
         else:
-            # Single-thread processing (thread id 0)
-            print("Using single-thread processing")
-            thread_id = 0
+            # Single-process processing (process id 0)
+            print("Using single-process processing")
+            process_id = 0
             for each_group in all_groups:
-                args = (thread_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
+                args = (process_id, db_config, each_group, from_city, to_city, date_begin_s, date_end_s, is_hot, plot_flag, output_dir)
                 flight_nums = each_group.get("flight_numbers", "unknown")
                 try:
                     df_mid = process_flight_group(args)
@@ -1000,7 +999,7 @@ if __name__ == "__main__":
     from utils import chunk_list_with_index
 
     cpu_cores = os.cpu_count()  # 72 on this system
-    max_workers = min(16, cpu_cores)  # cap at 16 threads
+    max_workers = min(8, cpu_cores)  # cap at 8 processes
 
     output_dir = f"./output"
     os.makedirs(output_dir, exist_ok=True)
@@ -1009,7 +1008,7 @@ if __name__ == "__main__":
     date_begin = "2025-12-07"
     date_end = datetime.today().strftime("%Y-%m-%d")
 
-    flight_route_list = vj_flight_route_list_hot[0:]  # hot: vj_flight_route_list_hot  non-hot: vj_flight_route_list_nothot
+    flight_route_list = vj_flight_route_list_hot[4:]  # hot: vj_flight_route_list_hot  non-hot: vj_flight_route_list_nothot
     table_name = CLEAN_VJ_HOT_NEAR_INFO_TAB  # hot: CLEAN_VJ_HOT_NEAR_INFO_TAB  non-hot: CLEAN_VJ_NOTHOT_NEAR_INFO_TAB
     is_hot = 1   # 1 = hot, 0 = non-hot
     group_size = 1
@@ -1021,7 +1020,7 @@ if __name__ == "__main__":
         print(f"第 {idx} 组 :", group_route_list)
         start_time = time.time()
         load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot, plot_flag=False,
-                        use_multithread=False, max_workers=max_workers)
+                        use_multiprocess=True, max_workers=max_workers)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
         print(f"用时: {run_time} 秒")

+ 12 - 0
evaluate.py

@@ -116,6 +116,18 @@ def evaluate_model_distribute(model, device, sequences, targets, group_ids, batc
             'Predicted_Will_Price_Drop': y_preds_class_labels,
         })
 
+        # Convert target_time_to_drop to a nullable integer
+        results_df['target_time_to_drop'] = pd.to_numeric(
+            results_df['target_time_to_drop'], errors='coerce'
+        ).astype('Int64')
+
+        # Convert target_amount_of_drop to float
+        results_df['target_amount_of_drop'] = pd.to_numeric(
+            results_df['target_amount_of_drop'], errors='coerce'
+        ).astype(float)
+
+        # print(results_df.dtypes)
+
         # Numeric handling
         threshold = 1e-3
         numeric_columns = ['probability', 'target_amount_of_drop', 'target_time_to_drop']
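
The 'Int64' (capital I) target in evaluate.py is pandas' nullable integer dtype: combined with errors='coerce', entries that fail to parse become <NA> instead of raising or silently leaving the column as float. A quick self-contained illustration of the two-step conversion:

import pandas as pd

s = pd.Series(["3", "oops", None])
nums = pd.to_numeric(s, errors="coerce")  # -> 3.0, NaN, NaN
print(nums.dtype)            # float64
print(nums.astype("Int64"))  # -> 3, <NA>, <NA>, dtype: Int64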

+ 8 - 1
main_pe.py

@@ -50,12 +50,16 @@ def start_predict():
     #     except Exception as e:
     #         print(f"remove {csv_path} error: {str(e)}")
 
+    cpu_cores = os.cpu_count()  # 72 on this system
+    max_workers = min(4, cpu_cores)  # cap at 4 processes
+
     model, _ = initialize_model()
 
     # Current time, truncated to the hour
     current_time = datetime.now() 
     hourly_time = current_time.replace(minute=0, second=0, microsecond=0)
     pred_time_str = hourly_time.strftime("%Y%m%d%H%M")
+    print(f"预测时间(取整): {pred_time_str}")
 
     # Prediction window: departure time between 28 and 40 hours from now
     pred_hour_begin = hourly_time + timedelta(hours=28)
@@ -114,7 +118,8 @@ def start_predict():
         
         # Load test data (only the time range extends to the day after tomorrow)
         start_time = time.time()
-        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot)
+        df_test = load_train_data(mongodb_config, group_route_list, table_name, pred_date_begin, pred_date_end, output_dir, is_hot,
+                                  use_multiprocess=True, max_workers=max_workers)
         end_time = time.time()
         run_time = round(end_time - start_time, 3)
         print(f"用时: {run_time} 秒")
@@ -190,6 +195,8 @@ def start_predict():
         predict_future_distribute(model, sequences, group_ids, target_scaler=target_scaler, output_dir=output_dir, pred_time_str=pred_time_str)
 
     print("所有批次的预测结束")
+    print()
+
     # After all batches finish, filter the results in one pass
     # csv_file = 'future_predictions.csv'
     # csv_path = os.path.join(output_dir, csv_file)

+ 10 - 6
main_tr.py

@@ -118,11 +118,11 @@ def start_train():
     # Things to do only when rank == 0
     if rank == 0:
         # If the run was interrupted, comment out the code below
-        batch_dir = os.path.join(output_dir, "batches")
-        try:
-            shutil.rmtree(batch_dir)
-        except FileNotFoundError:
-            print(f"rank:{rank}, {batch_dir} not found")
+        # batch_dir = os.path.join(output_dir, "batches")
+        # try:
+        #     shutil.rmtree(batch_dir)
+        # except FileNotFoundError:
+        #     print(f"rank:{rank}, {batch_dir} not found")
 
         # If the run was interrupted, comment out the code below
         csv_file_list = ['evaluate_results.csv']
@@ -133,6 +133,9 @@ def start_train():
             except Exception as e:
                 print(f"remove {csv_path}: {str(e)}")
 
+        cpu_cores = os.cpu_count()  # 72 on this system
+        max_workers = min(8, cpu_cores)  # cap at 8 processes
+
         # Make sure the directories exist
         os.makedirs(output_dir, exist_ok=True) 
         os.makedirs(photo_dir, exist_ok=True)
@@ -259,7 +262,8 @@ def start_train():
             
             # Load training data
             start_time = time.time()
-            df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot)
+            df_train = load_train_data(mongodb_config, group_route_list, table_name, date_begin, date_end, output_dir, is_hot,
+                                       use_multiprocess=True, max_workers=max_workers)
             end_time = time.time()
             run_time = round(end_time - start_time, 3)
             print(f"用时: {run_time} 秒")

+ 4 - 0
predict.py

@@ -65,6 +65,10 @@ def predict_future_distribute(model, sequences, group_ids, batch_size=16, target
     for col in numeric_columns:
         results_df[col] = results_df[col].where(results_df[col].abs() >= threshold, 0)
     
+    # Change the prediction output path
+    output_dir = './predictions'
+    os.makedirs(output_dir, exist_ok=True)
+
     csv_path1 = os.path.join(output_dir, f'future_predictions_{pred_time_str}.csv')
     results_df.to_csv(csv_path1, mode='a', index=False, header=not os.path.exists(csv_path1), encoding='utf-8-sig')
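
Note the write pattern kept here: to_csv with mode='a' and header=not os.path.exists(csv_path1) appends batch after batch while emitting the header only once. A self-contained sketch of the idiom (path and data are made up):

import os
import pandas as pd

csv_path = "./predictions/demo.csv"
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
for batch in ([{"a": 1}], [{"a": 2}]):
    df = pd.DataFrame(batch)
    # Header is written only if the file does not exist yet, so
    # repeated appends still yield one well-formed CSV.
    df.to_csv(csv_path, mode="a", index=False,
              header=not os.path.exists(csv_path))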
 

+ 2 - 2
result_validate.py

@@ -9,7 +9,7 @@ def validate_process(node, date, pred_time_str):
     output_dir = f"./validate/{node}_{date}"
     os.makedirs(output_dir, exist_ok=True)
 
-    object_dir = "./data_shards"
+    object_dir = "./predictions"
     csv_file = f'future_predictions_{pred_time_str}.csv'  
     csv_path = os.path.join(object_dir, csv_file)
 
@@ -113,5 +113,5 @@ def validate_process(node, date, pred_time_str):
 
 
 if __name__ == "__main__":
-    node, date, pred_time_str = "node0108", "0109", "202601091100"
+    node, date, pred_time_str = "node0108", "0110", "202601100800"
     validate_process(node, date, pred_time_str)

+ 2 - 2
utils.py

@@ -82,8 +82,8 @@ def create_fixed_length_sequences(df, features, target_vars, threshold=36, input
                             str(last_row['crawl_date']), 
                             ]
             if is_train:
-                next_name_li.append(last_row['target_amount_of_drop'])
-                next_name_li.append(last_row['target_time_to_drop'])
+                next_name_li.append(str(last_row['target_amount_of_drop']))
+                next_name_li.append(str(last_row['target_time_to_drop']))
 
             name_c.extend(next_name_li)
             group_ids.append(tuple(name_c))
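
Stringifying the two targets before they join the group-id tuple keeps every tuple element a string. One plausible motivation (an assumption; the commit gives no rationale) is NaN handling: two distinct float("nan") objects never compare equal, so otherwise identical-looking tuples fail equality checks, while their string forms match:

# Two tuples that print the same but compare unequal, because each
# float("nan") is a distinct object and NaN != NaN.
a = (float("nan"), 10.0)
b = (float("nan"), 10.0)
print(a == b)                                             # False
print((str(a[0]), str(a[1])) == (str(b[0]), str(b[1])))   # True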