Example #1
    def test_augmentation(self, transformer_config):
        raw = np.random.rand(32, 96, 96)
        # copy raw into each of the label's channels so raw and label can be
        # compared directly after the transforms are applied
        label = np.stack([raw] * 3)
        # create temporary h5 file
        tmp_file = NamedTemporaryFile()
        tmp_path = tmp_file.name
        with h5py.File(tmp_path, 'w') as f:
            f.create_dataset('raw', data=raw)
            f.create_dataset('label', data=label)

        # set phase='train' in order to execute the train transformers
        phase = 'train'
        dataset = StandardHDF5Dataset(
            tmp_path,
            phase=phase,
            slice_builder_config=_slice_builder_conf((16, 64, 64),
                                                     (8, 32, 32)),
            transformer_config=transformer_config[phase]['transformer'])

        # test augmentations using a DataLoader with 4 worker processes
        data_loader = DataLoader(dataset,
                                 batch_size=1,
                                 num_workers=4,
                                 shuffle=True)
        for (img, label) in data_loader:
            # identical augmentations are applied to raw and label, so the
            # image must match every copy of raw stored in the label channels
            for i in range(label.shape[0]):
                assert np.allclose(img, label[i])
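
The helper `_slice_builder_conf` used above is not shown in this example. A minimal sketch of what it plausibly returns, assuming the same config-dict format that appears literally in Example #3 (`name`, `patch_shape`, `stride_shape`):

def _slice_builder_conf(patch_shape, stride_shape):
    # hypothetical helper: builds the SliceBuilder config dict expected by
    # StandardHDF5Dataset (keys inferred from the literal dict in Example #3)
    return {
        'name': 'SliceBuilder',
        'patch_shape': patch_shape,
        'stride_shape': stride_shape
    }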
Example #2
    def test_hdf5_dataset(self):
        path = create_random_dataset((128, 128, 128))

        patch_shapes = [(127, 127, 127), (69, 70, 70), (32, 64, 64)]
        stride_shapes = [(1, 1, 1), (17, 23, 23), (32, 64, 64)]

        phase = 'test'

        for patch_shape, stride_shape in zip(patch_shapes, stride_shapes):
            with h5py.File(path, 'r') as f:
                raw = f['raw'][...]
                label = f['label'][...]

                dataset = StandardHDF5Dataset(
                    path,
                    phase=phase,
                    slice_builder_config=create_slice_builder(
                        patch_shape, stride_shape),
                    transformer_config=transformer_config[phase],
                    raw_internal_path='raw',
                    label_internal_path='label')

                # create zero arrays of the same shape as the original datasets
                # to verify that every element is visited during iteration
                visit_raw = np.zeros_like(raw)
                visit_label = np.zeros_like(label)

                for (_, idx) in dataset:
                    visit_raw[idx] = 1
                    visit_label[idx] = 1

                # verify that every element was visited at least once
                assert np.all(visit_raw)
                assert np.all(visit_label)
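
The `create_random_dataset` helper is defined elsewhere in the test suite. A plausible, self-contained sketch, assuming it writes a random `raw` dataset and one label dataset per name to a temporary HDF5 file and returns the file path (the `label_datasets` parameter is inferred from its use in Example #4):

import h5py
import numpy as np
from tempfile import NamedTemporaryFile


def create_random_dataset(shape, label_datasets=('label',)):
    # hypothetical helper: creates a temporary HDF5 file with a random 'raw'
    # dataset and one integer dataset per requested label name, then returns
    # the path to the file
    tmp_file = NamedTemporaryFile(delete=False, suffix='.h5')
    with h5py.File(tmp_file.name, 'w') as f:
        f.create_dataset('raw', data=np.random.rand(*shape))
        for name in label_datasets:
            f.create_dataset(name, data=np.random.randint(0, 2, size=shape))
    return tmp_file.name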
Example #3
    def test_embeddings_predictor(self, tmpdir):
        config = {'model': {'output_heads': 1}, 'device': torch.device('cpu')}

        slice_builder_config = {
            'name': 'SliceBuilder',
            'patch_shape': (64, 200, 200),
            'stride_shape': (40, 150, 150)
        }

        transformer_config = {
            'raw': [{
                'name': 'ToTensor',
                'expand_dims': False,
                'dtype': 'long'
            }]
        }

        gt_file = 'resources/sample_ovule.h5'
        output_file = os.path.join(tmpdir, 'output_segmentation.h5')

        dataset = StandardHDF5Dataset(
            gt_file,
            phase='test',
            slice_builder_config=slice_builder_config,
            transformer_config=transformer_config,
            mirror_padding=None,
            # feed the ground-truth labels in as the network input
            raw_internal_path='label')

        loader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            collate_fn=prediction_collate)

        predictor = FakePredictor(FakeModel(),
                                  loader,
                                  output_file,
                                  config,
                                  clustering='meanshift',
                                  bandwidth=0.5)

        predictor.predict()

        with h5py.File(gt_file, 'r') as f:
            with h5py.File(output_file, 'r') as g:
                gt = f['label'][...]
                segm = g['segmentation/meanshift'][...]
                arand_error = adapted_rand(segm, gt)

                assert arand_error < 0.1
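
`FakeModel` and `FakePredictor` are test doubles whose definitions are not shown. Because the dataset feeds the ground-truth labels in as the raw input (`raw_internal_path='label'`), a model that merely echoes each label id as a one-channel embedding lets mean-shift clustering recover the segmentation, which is what the `adapted_rand` assertion checks. A minimal sketch of such a model, as an assumption rather than the original definition:

import torch.nn as nn


class FakeModel(nn.Module):
    # hypothetical test double: emits a single embedding channel per voxel
    # that simply echoes the label id, so mean-shift clustering with a small
    # bandwidth separates the instances cleanly
    def forward(self, x):
        # x: (batch, D, H, W) long tensor of label ids -> (batch, 1, D, H, W)
        return x.unsqueeze(1).float()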
Example #4
    def test_hdf5_with_multiple_label_datasets(self, transformer_config):
        path = create_random_dataset((128, 128, 128),
                                     label_datasets=['label1', 'label2'])
        patch_shape = (32, 64, 64)
        stride_shape = (32, 64, 64)
        phase = 'train'
        dataset = StandardHDF5Dataset(
            path,
            phase=phase,
            slice_builder_config=_slice_builder_conf(patch_shape,
                                                     stride_shape),
            transformer_config=transformer_config[phase]['transformer'],
            raw_internal_path='raw',
            label_internal_path=['label1', 'label2'])

        for raw, labels in dataset:
            assert len(labels) == 2
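
For completeness, a short usage sketch (not part of the original test) showing how a dataset with two `label_internal_path` entries behaves behind a DataLoader: the default collate function batches each label tensor separately, so every batch carries a list of two label tensors:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=1, shuffle=True)
for raw, labels in loader:
    # one batched tensor per internal label path ('label1', 'label2')
    label1, label2 = labels
    assert label1.shape == label2.shape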