def test_split_by_count():
    """Splitting by absolute counts yields nested, seed-deterministic subsets.

    Checks that:
      * each split holds exactly the requested number of items (``X`` and ``y``),
      * every smaller split is contained in the next larger one,
      * the same seed reproduces identical splits,
      * a different seed produces different splits, and
      * ``min_size`` enlarges a split that is smaller than the minimum.
    """
    X_vals = np.arange(100)
    y_vals = np.arange(100, 200)
    data_dict = {'X': X_vals, 'y': y_vals}
    split_counts = [4, 50, 100]

    ds = DatasetSplitter(seed=0)
    split_dict = ds.split(input_dict=data_dict, split_counts=split_counts)

    # Walk the splits from smallest to largest; each must contain all values of
    # the previous one. ``None`` marks "no previous split yet" -- comparing a
    # NumPy array against ``[]`` is an elementwise comparison, not an emptiness
    # check, and errors (or is ambiguous) on current NumPy.
    prev_x_vals, prev_y_vals = None, None
    for split in split_counts:
        current_split = split_dict[str(split)]
        current_x_vals = current_split['X']
        current_y_vals = current_split['y']
        assert len(current_x_vals) == split
        assert len(current_y_vals) == split
        if prev_x_vals is not None:
            # make sure all previous values are in the current split
            assert np.all(np.isin(prev_x_vals, current_x_vals))
            assert np.all(np.isin(prev_y_vals, current_y_vals))
        prev_x_vals, prev_y_vals = current_x_vals, current_y_vals

    # same seed should produce the same splits, key for key and value for value
    ds = DatasetSplitter(seed=0)
    split_dict_same_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
    assert set(split_dict_same_seed) == set(split_dict)
    for split in split_dict_same_seed:
        current_split = split_dict_same_seed[split]
        original_split = split_dict[split]
        for data in current_split:
            assert np.array_equal(current_split[data], original_split[data])

    # different seed should produce different values
    ds = DatasetSplitter(seed=1)
    split_dict_diff_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
    for split in split_dict_diff_seed:
        current_split = split_dict_diff_seed[split]
        original_split = split_dict[split]
        for data in current_split:
            assert not np.array_equal(current_split[data], original_split[data])

    # setting a minimum size: the 4-item split is enlarged to min_size items
    split_dict = ds.split(input_dict=data_dict, min_size=10, split_counts=split_counts)
    assert len(split_dict['4']['X']) == 10
def test__duplicate_indices():
    """Padding short index arrays up to ``min_size`` keeps the same index values."""
    min_size = 8
    for n_indices in (5, 1, 7):
        original = np.arange(n_indices)
        padded = DatasetSplitter()._duplicate_indices(indices=original, min_size=min_size)
        # the padded result is exactly min_size long ...
        assert len(padded) == min_size
        # ... and holds every original index and no new ones
        assert set(original) == set(padded)
def test__validate_dict():
    """``_validate_dict`` accepts keys ``'X'``/``'y'`` and rejects misspelled keys."""
    ds = DatasetSplitter()

    # keys 'X' and 'y': must not raise
    ds._validate_dict({'X': 1, 'y': 2})

    # either key misspelled: must raise ValueError
    for bad_dict in ({'X': 1, 'y1': 2}, {'X1': 1, 'y': 2}):
        with pytest.raises(ValueError):
            ds._validate_dict(bad_dict)
def test__validate_split_counts():
    """``_validate_split_counts`` sorts valid counts and rejects invalid ones.

    Each invalid input must raise ``ValueError``: a zero-sized first split,
    duplicated counts, and non-integer counts.
    """
    ds = DatasetSplitter()

    # Unsorted split_counts come back sorted. Compare against the sorted
    # *input* so the check also catches dropped or altered values
    # (``x == sorted(x)`` alone would pass for e.g. an empty result).
    split_counts = [5, 1, 10]
    valid_counts = ds._validate_split_counts(split_counts=split_counts)
    assert valid_counts == sorted(split_counts)

    invalid_split_counts = [
        [0, 1, 4],     # first split_count is size 0
        [4, 8, 8],     # duplicate split_counts
        [4, 0.25, 7],  # non-integer split counts
    ]
    for bad_counts in invalid_split_counts:
        with pytest.raises(ValueError):
            ds._validate_split_counts(split_counts=bad_counts)
def test__validate_split_proportions():
    """``_validate_split_proportions`` sorts valid proportions and rejects invalid ones.

    Each invalid input must raise ``ValueError``: a zero-sized first split,
    a proportion greater than 1, and duplicated proportions.
    """
    ds = DatasetSplitter()

    # Unsorted split_proportions come back sorted. Compare against the sorted
    # *input* so the check also catches dropped or altered values
    # (``x == sorted(x)`` alone would pass for e.g. an empty result).
    split_proportions = [0.8, 0.3, 0.5]
    valid_proportions = ds._validate_split_proportions(
        split_proportions=split_proportions)
    assert valid_proportions == sorted(split_proportions)

    invalid_split_proportions = [
        [0, 0.25, 0.5],    # first split_proportion is size 0
        [0.1, 0.25, 1.5],  # last split_proportion is greater than 1
        [0.1, 0.1, 1],     # duplicate split_proportions
    ]
    for bad_proportions in invalid_split_proportions:
        with pytest.raises(ValueError):
            ds._validate_split_proportions(split_proportions=bad_proportions)
def test__init__():
    """The constructor stores the given seed on the instance."""
    expected_seed = 123
    assert DatasetSplitter(seed=expected_seed).seed == expected_seed
def test_split_by_proportion():
    """Splitting by proportions yields nested, seed-deterministic subsets.

    Checks that:
      * each split holds ``int(n_items * proportion)`` items (``X`` and ``y``),
      * every smaller split is contained in the next larger one,
      * the same seed reproduces identical splits,
      * a different seed produces different splits,
      * a proportion amounting to fewer than one item still yields one item, and
      * ``min_size`` enlarges a split that is smaller than the minimum.
    """
    X_vals = np.arange(100)
    y_vals = np.arange(100, 200)
    data_dict = {'X': X_vals, 'y': y_vals}
    n_items = len(X_vals)
    split_proportions = [0.1, 0.5, 1]

    ds = DatasetSplitter(seed=0)
    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)

    # Walk the splits from smallest to largest; each must contain all values of
    # the previous one. ``None`` marks "no previous split yet" -- comparing a
    # NumPy array against ``[]`` is an elementwise comparison, not an emptiness
    # check, and errors (or is ambiguous) on current NumPy.
    prev_x_vals, prev_y_vals = None, None
    for split in split_proportions:
        current_split = split_dict[str(split)]
        current_x_vals = current_split['X']
        current_y_vals = current_split['y']
        expected_len = int(n_items * split)
        assert len(current_x_vals) == expected_len
        assert len(current_y_vals) == expected_len
        if prev_x_vals is not None:
            # make sure all previous values are in the current split
            assert np.all(np.isin(prev_x_vals, current_x_vals))
            assert np.all(np.isin(prev_y_vals, current_y_vals))
        prev_x_vals, prev_y_vals = current_x_vals, current_y_vals

    # same seed should produce the same splits, key for key and value for value
    ds = DatasetSplitter(seed=0)
    split_dict_same_seed = ds.split(input_dict=data_dict,
                                    split_proportions=split_proportions)
    assert set(split_dict_same_seed) == set(split_dict)
    for split in split_dict_same_seed:
        current_split = split_dict_same_seed[split]
        original_split = split_dict[split]
        for data in current_split:
            assert np.array_equal(current_split[data], original_split[data])

    # different seed should produce different values
    ds = DatasetSplitter(seed=1)
    split_dict_diff_seed = ds.split(input_dict=data_dict,
                                    split_proportions=split_proportions)
    for split in split_dict_diff_seed:
        current_split = split_dict_diff_seed[split]
        original_split = split_dict[split]
        for data in current_split:
            assert not np.array_equal(current_split[data], original_split[data])

    # a split corresponding to fewer than 1 image returns a single image
    split_proportions = [0.001, 0.3, 1]
    ds = DatasetSplitter(seed=0)
    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)
    assert len(split_dict['0.001']['X']) == 1

    # setting a minimum size: the single-image split is enlarged to min_size
    split_dict = ds.split(input_dict=data_dict, min_size=10,
                          split_proportions=split_proportions)
    assert len(split_dict['0.001']['X']) == 10