def test__reshape_dict_by_value(tmp_path):
    """Check `DatasetBuilder._reshape_dict` when `resize` is a numeric factor.

    The reshaped data is produced once with an upscaling factor (3) and once
    with a downscaling factor (0.5). For each case the test verifies that the
    crops have the requested spatial shape, that every crop whose maximum pixel
    value equals a tissue's index carries that tissue's label, and that the
    number of crops per tissue matches the expected count.

    Args:
        tmp_path: pytest fixture providing a temporary directory in which the
            minimal dataset is created.
    """
    _create_minimal_dataset(tmp_path)
    db = DatasetBuilder(tmp_path)

    # create dict
    tissues = ['tissue1', 'tissue2', 'tissue3']
    platforms = ['platform1', 'platform2', 'platform3']
    data_dict = _create_test_dict(tissues=tissues, platforms=platforms)

    # same size as input data
    output_shape = (40, 40)

    # There were originally 5 images of each tissue type.
    #   resize=3:   each dimension was resized 3x, so there should be 9x more images
    #   resize=0.5: each dimension was halved, and because the images are padded
    #               there should be the same total number of images
    resize_cases = ((3, 5 * 9), (0.5, 5))

    for resize, expected_count in resize_cases:
        reshaped_dict = db._reshape_dict(data_dict=data_dict, resize=resize,
                                         output_shape=output_shape)
        X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list']
        assert X_reshaped.shape[1:3] == output_shape

        # The test identifies each image's tissue by its maximum pixel value, which is
        # expected to equal the tissue's index. Neither the per-image maximum nor the
        # label array depends on the tissue being checked, so compute them once.
        image_val = np.max(X_reshaped, axis=(1, 2, 3))
        label_array = np.array(tissue_list_reshaped)

        # make sure that for each tissue, the arrays with correct value have correct
        # tissue label
        for constant_val, tissue in enumerate(tissues):
            tissue_labels = label_array[image_val == constant_val]
            assert np.all(tissue_labels == tissue)
            assert len(tissue_labels) == expected_count, 'resize=%r' % (resize,)
def test__reshape_dict_by_image(tmp_path, mocker):
    """Check `DatasetBuilder._reshape_dict` when `resize='by_image'`.

    `compute_cell_size` in `caliban_toolbox.dataset_builder` is replaced with
    `mocked_compute_cell_size` for the duration of the test. The test then
    verifies that the crops have the requested spatial shape, that every crop
    whose maximum pixel value equals a tissue's index carries that tissue's
    label, and that the number of crops per tissue matches the expected count.

    Args:
        tmp_path: pytest fixture providing a temporary directory in which the
            minimal dataset is created.
        mocker: pytest-mock fixture used to patch `compute_cell_size`.
    """
    mocker.patch('caliban_toolbox.dataset_builder.compute_cell_size', mocked_compute_cell_size)
    _create_minimal_dataset(tmp_path)
    db = DatasetBuilder(tmp_path)

    # create dict
    tissues = ['tissue1', 'tissue2', 'tissue3']
    platforms = ['platform1', 'platform2', 'platform3']
    data_dict = _create_test_dict(tissues=tissues, platforms=platforms)

    # same size as input data
    output_shape = (40, 40)

    reshaped_dict = db._reshape_dict(data_dict=data_dict, resize='by_image',
                                     output_shape=output_shape)
    X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list']
    assert X_reshaped.shape[1:3] == output_shape

    # The test identifies each image's tissue by its maximum pixel value, which is
    # expected to equal the tissue's index. Neither the per-image maximum nor the
    # label array depends on the tissue being checked, so compute them once.
    image_val = np.max(X_reshaped, axis=(1, 2, 3))
    label_array = np.array(tissue_list_reshaped)

    # make sure that for each tissue, the arrays with correct value have correct tissue label
    for constant_val, tissue in enumerate(tissues):
        tissue_labels = label_array[image_val == constant_val]
        assert np.all(tissue_labels == tissue)

        # There were originally 5 images of each tissue type. Tissue types with even values
        # are resized to be 2x larger on each dimension, and should have 4x more images.
        # Tissue types with odd values are resized to be smaller, which leads to the same
        # number of unique images due to padding.
        expected_count = 5 * 4 if constant_val % 2 == 0 else 5
        assert len(tissue_labels) == expected_count
def test__reshape_dict_no_resize(tmp_path):
    """Check `DatasetBuilder._reshape_dict` when resizing is disabled.

    With `resize=False` and an output shape of (20, 20), the test expects four
    times as many crops as input images, each with the requested spatial
    shape, and expects every crop whose first pixel equals a tissue's index to
    carry that tissue's label.

    Args:
        tmp_path: pytest fixture providing a temporary directory in which the
            minimal dataset is created.
    """
    _create_minimal_dataset(tmp_path)
    db = DatasetBuilder(tmp_path)

    # build the input dictionary
    tissues = ['tissue1', 'tissue2', 'tissue3']
    platforms = ['platform1', 'platform2', 'platform3']
    data_dict = _create_test_dict(tissues=tissues, platforms=platforms)

    # half the original size along each dimension, hence 4x as many crops expected
    output_shape = (20, 20)

    reshaped_dict = db._reshape_dict(
        data_dict=data_dict,
        resize=False,
        output_shape=output_shape,
    )
    X_reshaped = reshaped_dict['X']
    tissue_list_reshaped = reshaped_dict['tissue_list']

    assert X_reshaped.shape[1:3] == output_shape
    n_input_images = data_dict['X'].shape[0]
    assert X_reshaped.shape[0] == 4 * n_input_images

    # for every tissue, crops holding that tissue's value must carry its label
    for constant_val, tissue in enumerate(tissues):
        first_pixel = X_reshaped[:, 0, 0, 0]
        matches = first_pixel == constant_val
        tissue_labels = np.array(tissue_list_reshaped)[matches]
        assert np.all(tissue_labels == tissue)