示例#1
0
def test__reshape_dict_by_image(tmp_path, mocker):
    """Exercise `_reshape_dict` with resize='by_image' using a patched cell-size function.

    `compute_cell_size` in `caliban_toolbox.dataset_builder` is replaced by
    `mocked_compute_cell_size`, so the resize applied to each image is controlled by
    the mock rather than by real image content.
    """
    mocker.patch('caliban_toolbox.dataset_builder.compute_cell_size', mocked_compute_cell_size)
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # build the input dict
    tissue_names = ['tissue1', 'tissue2', 'tissue3']
    platform_names = ['platform1', 'platform2', 'platform3']
    input_dict = _create_test_dict(tissues=tissue_names, platforms=platform_names)

    # NOTE(review): original comment says this matches the input image size - confirm
    # against _create_test_dict
    target_shape = (40, 40)

    result = builder._reshape_dict(data_dict=input_dict, resize='by_image',
                                   output_shape=target_shape)
    resized_X, resized_tissues = result['X'], result['tissue_list']
    assert resized_X.shape[1:3] == target_shape

    # the max over each image is used as the constant value that image was tagged with;
    # it does not depend on the tissue being checked, so compute it once
    per_image_max = np.max(resized_X, axis=(1, 2, 3))
    tissue_array = np.array(resized_tissues)

    # every image carrying a given constant value must carry the matching tissue label
    for fill_value, tissue_name in enumerate(tissue_names):
        matching_labels = tissue_array[per_image_max == fill_value]
        assert np.all(matching_labels == tissue_name)

        # Per the fixtures (not visible here) there are 5 images per tissue type.
        # Even-valued tissues are enlarged 2x per dimension, giving 4x as many images;
        # odd-valued tissues are shrunk and padded, leaving the count unchanged.
        expected_count = 5 * 4 if fill_value % 2 == 0 else 5
        assert len(matching_labels) == expected_count
示例#2
0
def test__validate_dataset(tmp_path):
    """Check that `_validate_dataset` raises ValueError for malformed dataset directories.

    Cases covered, in order: a path that does not exist, a dataset directory with no
    experiment folders, an experiment folder holding an NPZ but no metadata file, and
    an experiment folder holding a metadata file but no NPZ.
    """
    _create_minimal_dataset(tmp_path)
    db = DatasetBuilder(dataset_path=tmp_path)

    # bad path
    with pytest.raises(ValueError):
        db._validate_dataset('bad_path')

    dataset_path = os.path.join(tmp_path, 'example_dataset')
    os.makedirs(dataset_path)

    # no folders in supplied dataset
    with pytest.raises(ValueError):
        db._validate_dataset(dataset_path)

    experiment_path = os.path.join(dataset_path, 'experiment_1')
    npz_path = os.path.join(experiment_path, 'example_file.npz')
    os.makedirs(experiment_path)
    Path(npz_path).touch()

    # supplied experiment has an NPZ and no metadata file.
    # Validate dataset_path (not tmp_path): the malformed experiment lives under
    # dataset_path, so validating the parent would not exercise this scenario.
    with pytest.raises(ValueError):
        db._validate_dataset(dataset_path)

    # directory has a metadata file and no NPZ
    os.remove(npz_path)
    Path(os.path.join(experiment_path, 'metadata.json')).touch()

    with pytest.raises(ValueError):
        db._validate_dataset(dataset_path)
示例#3
0
def test__init__(tmp_path):
    """Check that the constructor stores the dataset path and rejects an invalid one."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # the path handed to the constructor is exposed unchanged
    assert builder.dataset_path == tmp_path

    # a path that is not a valid dataset must be refused
    with pytest.raises(ValueError):
        DatasetBuilder(dataset_path='bad_path')
示例#4
0
def test__get_metadata(tmp_path):
    """Check that `_get_metadata` returns the tissue and platform of each experiment."""
    experiment_names = ['exp1', 'exp2']
    tissue_names = ['tissue1', 'tissue2']
    platform_names = ['platform1', 'platform2']

    # one NPZ per experiment
    _create_test_dataset(path=tmp_path, experiments=experiment_names,
                         platforms=platform_names, tissues=tissue_names, npz_num=[1, 1])

    builder = DatasetBuilder(tmp_path)

    # the i-th experiment is expected to carry the i-th tissue and platform
    for exp_name, expected_tissue, expected_platform in zip(experiment_names, tissue_names,
                                                            platform_names):
        metadata = builder._get_metadata(os.path.join(tmp_path, exp_name))
        assert metadata['tissue'] == expected_tissue
        assert metadata['platform'] == expected_platform
示例#5
0
def test__reshape_dict_by_value(tmp_path):
    """Exercise `_reshape_dict` with a numeric resize factor, both enlarging and shrinking."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # build the input dict
    tissue_names = ['tissue1', 'tissue2', 'tissue3']
    platform_names = ['platform1', 'platform2', 'platform3']
    input_dict = _create_test_dict(tissues=tissue_names, platforms=platform_names)

    # NOTE(review): original comment says this matches the input image size - confirm
    # against _create_test_dict
    target_shape = (40, 40)

    def check_resize(factor, images_per_tissue):
        # Reshape with the given factor, then verify the output shape, the
        # tissue-label/image pairing, and the number of images per tissue.
        result = builder._reshape_dict(data_dict=input_dict, resize=factor,
                                       output_shape=target_shape)
        resized_X, resized_tissues = result['X'], result['tissue_list']
        assert resized_X.shape[1:3] == target_shape

        # the max over each image is used as the constant value it was tagged with
        per_image_max = np.max(resized_X, axis=(1, 2, 3))
        tissue_array = np.array(resized_tissues)

        for fill_value, tissue_name in enumerate(tissue_names):
            matching_labels = tissue_array[per_image_max == fill_value]
            assert np.all(matching_labels == tissue_name)
            assert len(matching_labels) == images_per_tissue

    # Per the fixtures (not visible here) there are 5 images per tissue type.
    # Enlarging each dimension 3x should give 9x as many images.
    check_resize(3, 5 * 9)

    # Shrinking each dimension by half: images are padded, so the count is unchanged.
    check_resize(0.5, 5)
示例#6
0
def test__identify_tissue_and_platform_types(tmp_path):
    """Check that every tissue and platform in the dataset is recorded on the builder."""
    # five experiments with repeated tissue and platform values, one NPZ each
    tissue_names = ['tissue1', 'tissue2', 'tissue3', 'tissue2', 'tissue1']
    platform_names = ['platform1', 'platform1', 'platform2', 'platform2', 'platform3']
    experiment_names = ['exp{}'.format(idx) for idx in range(5)]
    _create_test_dataset(tmp_path, experiments=experiment_names, tissues=tissue_names,
                         platforms=platform_names, npz_num=[1] * 5)

    builder = DatasetBuilder(dataset_path=tmp_path)
    builder._identify_tissue_and_platform_types()

    # compare as sets: duplicates and ordering are irrelevant here
    assert set(builder.all_tissues) == set(tissue_names)
    assert set(builder.all_platforms) == set(platform_names)
示例#7
0
def test__load_experiment_multiple_npz(tmp_path):
    """Check that `_load_experiment` combines all NPZs of one experiment."""
    experiment_names = ['exp1']
    tissue_names = ['tissue1']
    platform_names = ['platform1']
    _create_test_dataset(tmp_path, experiments=experiment_names, tissues=tissue_names,
                         platforms=platform_names, npz_num=[5])

    builder = DatasetBuilder(tmp_path)

    experiment_dir = os.path.join(tmp_path, experiment_names[0])
    X, y, tissue, platform = builder._load_experiment(experiment_dir)

    # 5 NPZs; the fixture is expected to write 10 images into each (not visible here)
    assert X.shape[0] == 50
    assert y.shape[0] == 50

    # metadata of the single experiment is returned alongside the arrays
    assert tissue == tissue_names[0]
    assert platform == platform_names[0]
示例#8
0
def test__validate_output_shape(tmp_path):
    """Check expansion of a single output shape and rejection of malformed ones."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # a single shape, given as list or tuple, is repeated once per split
    for single_shape in ([222, 333], (222, 333)):
        validated_shape = builder._validate_output_shape(single_shape)
        assert validated_shape == [single_shape] * 3

    invalid_shapes = (
        # not all splits specified
        [(123, 456), (789, 1011)],
        # not all splits have 2 entries
        [(12, 34), (56, 78), (910, 1112, 1314)],
        # too many splits
        [(12, 34), (56, 78), (910, 1112), (1314, )],
        # not a list/tuple
        56,
    )
    for bad_shape in invalid_shapes:
        with pytest.raises(ValueError):
            builder._validate_output_shape(output_shape=bad_shape)
示例#9
0
def test__validate_categories(tmp_path):
    """Check normalisation of valid category selections and rejection of unknown ones."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    known_categories = ['cat1', 'cat2', 'cat3']

    # a single valid category string is wrapped in a list
    single = 'cat1'
    result = builder._validate_categories(category_list=known_categories,
                                          supplied_categories=single)
    assert result == [single]

    # 'all' expands to the full category list
    result = builder._validate_categories(category_list=known_categories,
                                          supplied_categories='all')
    assert np.all(result == known_categories)

    # a valid subset given as a list is passed through
    subset = ['cat1', 'cat3']
    result = builder._validate_categories(category_list=known_categories,
                                          supplied_categories=subset)
    assert np.all(result == subset)

    # an unknown category is rejected, whether given alone or inside a list
    for invalid in ('cat4', ['cat4', 'cat1']):
        with pytest.raises(ValueError):
            builder._validate_categories(category_list=known_categories,
                                         supplied_categories=invalid)
示例#10
0
def test_build_dataset(tmp_path):
    """End-to-end checks of `build_dataset` over category subsets, crops and options."""
    # five experiments, each with a distinct tissue and platform
    experiment_names = ['exp{}'.format(idx) for idx in range(5)]
    all_tissues = ['tissue1', 'tissue2', 'tissue3', 'tissue4', 'tissue5']
    all_platforms = ['platform5', 'platform4', 'platform3', 'platform2', 'platform1']
    _create_test_dataset(tmp_path, experiments=experiment_names, tissues=all_tissues,
                         platforms=all_platforms, npz_num=[2, 2, 4, 6, 8])

    builder = DatasetBuilder(tmp_path)

    def assert_categories(split_dicts, expected_tissues, expected_platforms):
        # every returned split must contain exactly the requested tissues and platforms
        for split in split_dicts:
            found_tissues = split['tissue_list']
            found_platforms = split['platform_list']
            assert set(found_tissues) == set(expected_tissues)
            assert set(found_platforms) == set(expected_platforms)

    # dataset with all data included
    full_splits = builder.build_dataset(tissues=all_tissues, platforms=all_platforms,
                                        output_shape=(20, 20))
    assert_categories(full_splits, all_tissues, all_platforms)

    # dataset with only a subset included
    some_tissues, some_platforms = all_tissues[:3], all_platforms[:3]
    subset_splits = builder.build_dataset(tissues=some_tissues, platforms=some_platforms,
                                          output_shape=(20, 20))
    assert_categories(subset_splits, some_tissues, some_platforms)

    # cropping to 1/2 the size on each dimension should yield 4x as many crops
    cropped_splits = builder.build_dataset(tissues=some_tissues, platforms=some_platforms,
                                           output_shape=(10, 10), relabel=True)

    for base_split, crop_split in zip(subset_splits, cropped_splits):
        base_X, crop_X = base_split['X'], crop_split['X']
        assert base_X.shape[0] * 4 == crop_X.shape[0]

    # after relabeling, each split's labels hold exactly two distinct values
    for crop_split in cropped_splits:
        assert len(np.unique(crop_split['y'])) == 2

    # a different output size for each split
    per_split_shapes = ((10, 10), (15, 15), (20, 20))
    mixed_splits = builder.build_dataset(tissues=some_tissues, platforms=some_platforms,
                                         output_shape=list(per_split_shapes))

    for split_idx, expected_shape in enumerate(per_split_shapes):
        assert mixed_splits[split_idx]['X'].shape[1:3] == expected_shape

    # full run-through with non-default options; only checks that it completes
    builder.build_dataset(tissues='all', platforms=some_platforms, output_shape=(10, 10),
                          relabel=True, resize='by_image', small_object_threshold=5,
                          balance=True)
示例#11
0
def test__balance_dict(tmp_path):
    """Check `_balance_dict` on balanced and unbalanced inputs, and its seed behaviour."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    X_data = np.random.rand(9, 10, 10, 3)
    y_data = np.random.rand(9, 10, 10, 1)
    even_tissues = np.array(['tissue1'] * 3 + ['tissue2'] * 3 + ['tissue3'] * 3)
    platform_list = np.array(['platform1'] * 3 + ['platform2'] * 3 + ['platform3'] * 3)

    balanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': even_tissues,
                     'platform_list': platform_list}
    balanced_out = builder._balance_dict(data_dict=balanced_dict, seed=0,
                                         category='tissue_list')

    # input is already balanced, so every entry must come back identical
    for key in balanced_out:
        assert np.all(balanced_out[key] == balanced_dict[key])

    # tissue3 has the most examples; the other two need to be upsampled
    skewed_tissues = np.array(['tissue1'] * 1 + ['tissue2'] * 2 + ['tissue3'] * 6)
    unbalanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': skewed_tissues,
                       'platform_list': platform_list}
    upsampled = builder._balance_dict(data_dict=unbalanced_dict, seed=0,
                                      category='tissue_list')

    # the last six entries (tissue3 in the input) are unchanged
    for key in upsampled:
        assert np.all(upsampled[key][-6:] == unbalanced_dict[key][-6:])

    # tissue1 has a single source example, so every copy must equal it
    tissue1_rows = np.where(upsampled['tissue_list'] == 'tissue1')[0]
    for key in upsampled:
        source_row = unbalanced_dict[key][0]
        for row in tissue1_rows:
            assert np.all(upsampled[key][row] == source_row)

    # tissue2 has two source examples; each copy must equal one of them
    tissue2_rows = np.where(upsampled['tissue_list'] == 'tissue2')[0]
    for key in upsampled:
        first_source = unbalanced_dict[key][1]
        second_source = unbalanced_dict[key][2]
        for row in tissue2_rows:
            candidate = upsampled[key][row]
            assert np.all(candidate == first_source) or np.all(candidate == second_source)

    # the same seed must reproduce the same output
    repeat_out = builder._balance_dict(data_dict=unbalanced_dict, seed=0,
                                       category='tissue_list')
    for key in repeat_out:
        assert np.all(repeat_out[key] == upsampled[key])

    # a different seed must change the image and label arrays
    reseeded_out = builder._balance_dict(data_dict=unbalanced_dict, seed=1,
                                         category='tissue_list')
    for key in ['X', 'y']:
        assert not np.all(reseeded_out[key] == upsampled[key])
示例#12
0
def test__reshape_dict_no_resize(tmp_path):
    """Check that `_reshape_dict` with resize=False only crops to the output shape."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # build the input dict
    tissue_names = ['tissue1', 'tissue2', 'tissue3']
    platform_names = ['platform1', 'platform2', 'platform3']
    input_dict = _create_test_dict(tissues=tissue_names, platforms=platform_names)

    # half the original size on each dimension (per the original comment),
    # so 4x as many crops are expected
    target_shape = (20, 20)

    result = builder._reshape_dict(data_dict=input_dict, resize=False,
                                   output_shape=target_shape)
    cropped_X, cropped_tissues = result['X'], result['tissue_list']
    assert cropped_X.shape[1:3] == target_shape

    assert cropped_X.shape[0] == 4 * input_dict['X'].shape[0]

    # the first pixel of each crop is compared to the tissue's index in the list:
    # crops matching that value must carry the matching tissue label
    tissue_array = np.array(cropped_tissues)
    first_pixel = cropped_X[:, 0, 0, 0]
    for fill_value, tissue_name in enumerate(tissue_names):
        matching_labels = tissue_array[first_pixel == fill_value]
        assert np.all(matching_labels == tissue_name)
示例#13
0
def test__load_all_experiments(tmp_path):
    """Check split sizes and metadata/image pairing after `_load_all_experiments`."""
    # five experiments, each with a distinct tissue and platform
    experiment_names = ['exp{}'.format(idx) for idx in range(5)]
    tissues = ['tissue1', 'tissue2', 'tissue3', 'tissue4', 'tissue5']
    platforms = ['platform5', 'platform4', 'platform3', 'platform2', 'platform1']
    npz_counts = [2, 2, 4, 6, 8]
    _create_test_dataset(tmp_path, experiments=experiment_names, tissues=tissues,
                         platforms=platforms, npz_num=npz_counts)

    # the fixture is expected to write 10 images per NPZ (not visible here)
    total_img_num = np.sum(npz_counts) * 10

    builder = DatasetBuilder(tmp_path)
    builder._identify_tissue_and_platform_types()

    # train / val / test fractions
    split_ratios = (0.7, 0.2, 0.1)
    builder._load_all_experiments(data_split=list(split_ratios), seed=None)

    splits = (builder.train_dict, builder.val_dict, builder.test_dict)

    # each split must hold its share of the images, with one metadata entry per image
    for ratio, split in zip(split_ratios, splits):
        X_data, y_data = split['X'], split['y']
        assert X_data.shape[0] == ratio * total_img_num
        assert y_data.shape[0] == ratio * total_img_num

        tissue_list, platform_list = split['tissue_list'], split['platform_list']
        assert len(tissue_list) == len(platform_list) == X_data.shape[0]

    # the i-th tissue and the i-th platform both belong to experiment i, whose images
    # are expected to be filled with the constant value i
    for split in splits:
        X_data = split['X']
        for list_key, category_names in (('tissue_list', tissues),
                                         ('platform_list', platforms)):
            category_list = split[list_key]
            for constant_val, category in enumerate(category_names):
                # images whose metadata matches this category
                images = X_data[category_list == category]
                assert np.all(images == constant_val)
示例#14
0
def test__subset_data_dict(tmp_path):
    """Check that `_subset_data_dict` keeps only entries matching both filters."""
    _create_minimal_dataset(tmp_path)

    X = np.arange(100)
    y = np.arange(100)
    tissue_list = np.array(['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40)
    platform_list = np.array(['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40)
    data_dict = {'X': X, 'y': y, 'tissue_list': tissue_list, 'platform_list': platform_list}

    builder = DatasetBuilder(tmp_path)

    valid_selections = (
        # all tissues, one platform
        (['tissue1', 'tissue2', 'tissue3'], ['platform1']),
        # all platforms, one tissue
        (np.array(['tissue2']), np.array(['platform1', 'platform2', 'platform3'])),
        # drop tissue 1 and platform 3
        (np.array(['tissue2', 'tissue3']), np.array(['platform1', 'platform2'])),
    )
    for tissues, platforms in valid_selections:
        subset_dict = builder._subset_data_dict(data_dict=data_dict, tissues=tissues,
                                                platforms=platforms)
        X_subset = subset_dict['X']

        # an entry is expected to survive only if both its tissue and platform were requested
        keep_idx = np.logical_and(np.isin(tissue_list, tissues),
                                  np.isin(platform_list, platforms))
        assert np.all(X_subset == X[keep_idx])

    # tissue/platform combination that doesn't exist
    with pytest.raises(ValueError):
        builder._subset_data_dict(data_dict=data_dict, tissues=np.array(['tissue1']),
                                  platforms=np.array(['platform3']))
示例#15
0
def test_summarize_dataset(tmp_path):
    """Check the per-tissue and per-platform counts returned by `summarize_dataset`."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # train holds all categories, val drops the first, test drops the last
    tissues = ['tissue1', 'tissue2', 'tissue3']
    platforms = ['platform1', 'platform2', 'platform3']
    train_dict = _create_test_dict(tissues=tissues, platforms=platforms)
    val_dict = _create_test_dict(tissues=tissues[1:], platforms=platforms[1:])
    test_dict = _create_test_dict(tissues=tissues[:-1], platforms=platforms[:-1])

    # write two label values into every image; per the original comment this is
    # meant to give each image 2 cells for counting purposes
    for split in (train_dict, val_dict, test_dict):
        labels = split['y']
        labels[:, 0, 0, 0] = 5
        labels[:, 10, 0, 0] = 12

        split['y'] = labels

    builder.train_dict = train_dict
    builder.val_dict = val_dict
    builder.test_dict = test_dict

    tissue_summary, platform_summary = builder.summarize_dataset()

    # every tissue and platform must appear in its summary
    for tissue_name, platform_name in zip(tissues, platforms):
        assert tissue_name in tissue_summary
        assert platform_name in platform_summary

    # check the arithmetic of each summary entry
    for summary in (tissue_summary, platform_summary):
        for category, stats in summary.items():

            # two cells per image
            cell_num = stats['cell_num']
            image_num = stats['image_num']
            assert cell_num == image_num * 2

            # the middle categories are present in all three dicts and so total 15
            # images; the others are present in two dicts and total 10
            expected_images = 15 if category in ['tissue2', 'platform2'] else 10
            assert image_num == expected_images
示例#16
0
def test__clean_labels(tmp_path):
    """Check relabeling, small-object removal and sparse-image removal in `_clean_labels`."""
    _create_minimal_dataset(tmp_path)
    builder = DatasetBuilder(tmp_path)

    # one label image: two separate regions sharing label 2, plus a 2-pixel region labeled 3
    label_img = np.zeros((50, 50), dtype='int')
    label_img[:10, :10] = 2
    label_img[12:17, 12:17] = 2
    label_img[20:22, 22:23] = 3

    # batch of two images; the second is left empty
    y_batch = np.zeros((2, 50, 50, 1), dtype='int')
    y_batch[0, ..., 0] = label_img

    input_dict = {'X': np.zeros_like(y_batch), 'y': y_batch,
                  'tissue_list': np.array(['tissue1', 'tissue2']),
                  'platform_list': np.array(['platform2', 'platform3'])}

    def unique_label_count(**clean_kwargs):
        # number of distinct values in the cleaned label array, background included
        cleaned = builder._clean_labels(data_dict=input_dict, **clean_kwargs)
        return len(np.unique(cleaned['y']))

    # relabel=False: the two distinct input labels remain, plus 0 for background
    assert unique_label_count(relabel=False) == 2 + 1

    # relabel=True: the two regions sharing label 2 get separate labels
    assert unique_label_count(relabel=True) == 3 + 1

    # with a small-object threshold of 15, one fewer label remains
    assert unique_label_count(relabel=True, small_object_threshold=15) == 2 + 1

    # min_objects=1: only the first image and its metadata survive
    filtered = builder._clean_labels(data_dict=input_dict, relabel=True, min_objects=1)
    assert filtered['y'].shape[0] == 1
    assert filtered['X'].shape[0] == 1
    assert len(filtered['tissue_list']) == 1
    assert filtered['tissue_list'][0] == 'tissue1'
    assert len(filtered['platform_list']) == 1
    assert filtered['platform_list'][0] == 'platform2'