|
@@ -34,6 +34,18 @@ def safe_train_test_split(*arrays, test_size=0.2, random_state=None, stratify=No
|
|
|
stratify=stratify
|
|
stratify=stratify
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
def adaptive_test_size(n_samples: int) -> float:
    """Choose the held-out fraction (``test_size``) from the total sample count.

    Smaller datasets are given a larger held-out share.

    Args:
        n_samples: Total number of samples available for splitting.

    Returns:
        0.4 when ``n_samples < 50``, 0.3 when ``50 <= n_samples < 100``,
        and 0.2 otherwise.
    """
    if n_samples < 50:
        # Very small dataset: force a larger share to keep a validation set.
        return 0.4
    if n_samples < 100:
        return 0.3
    # NOTE: an intermediate tier (n_samples < 300 -> 0.25) existed only as
    # commented-out code in the original and is not active.
    # Regular case.
    return 0.2
|
|
|
|
|
|
|
|
# 分布式数据集准备
|
|
# 分布式数据集准备
|
|
|
def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
|
|
def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=False, rank=0, local_rank=0, world_size=1):
|
|
@@ -62,11 +74,23 @@ def prepare_data_distribute(sequences, targets, group_ids, flag_distributed=Fals
|
|
|
group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
|
|
group_ids_filtered = [g for i, g in enumerate(group_ids) if valid_mask[i]]
|
|
|
targets_array_filtered = targets_array[valid_mask]
|
|
targets_array_filtered = targets_array[valid_mask]
|
|
|
|
|
|
|
|
|
|
+ n_samples = len(sequences_filtered)
|
|
|
|
|
+ test_size = adaptive_test_size(n_samples)
|
|
|
|
|
+
|
|
|
|
|
+ # 只有在样本数本身不太小的情况下,才做最小验证集兜底
|
|
|
|
|
+ min_val_samples = 10
|
|
|
|
|
+ if n_samples >= 2 * min_val_samples:
|
|
|
|
|
+ expected_val = int(n_samples * test_size)
|
|
|
|
|
+ if expected_val < min_val_samples:
|
|
|
|
|
+ test_size = min_val_samples / n_samples
|
|
|
|
|
+
|
|
|
|
|
+ print(f"test_size: {test_size}")
|
|
|
|
|
+
|
|
|
# 第一步:将28样本拆分为训练集(80%)和临时集(20%)
|
|
# 第一步:将28样本拆分为训练集(80%)和临时集(20%)
|
|
|
train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
|
|
train_28, temp_28, train_28_targets, temp_28_targets, train_28_gids, temp_28_gids = safe_train_test_split(
|
|
|
sequences_filtered, targets_filtered, group_ids_filtered,
|
|
sequences_filtered, targets_filtered, group_ids_filtered,
|
|
|
stratify=targets_array_filtered,
|
|
stratify=targets_array_filtered,
|
|
|
- test_size=0.2,
|
|
|
|
|
|
|
+ test_size=test_size,
|
|
|
random_state=42,
|
|
random_state=42,
|
|
|
rank=rank,
|
|
rank=rank,
|
|
|
local_rank=local_rank
|
|
local_rank=local_rank
|