Example #1
0
    def __call__(self, input, target):
        """Return the minimum AdaptedRand error across the predicted channels.

        Each evaluated channel of ``input`` is thresholded with ``self.threshold``
        (and logically inverted when ``self.invert_pmaps`` is set). It is then labelled
        with 1-connectivity connected components and compared against ``target``
        with ``adapted_rand``.

        Args:
            input: a 5D ``torch.Tensor`` (only the first batch element is used) or a
                4D ``np.ndarray`` whose first axis indexes channels.
            target: ground truth, as one of the following.
                - A 4D ``torch.Tensor``: the first batch element is used.
                - A 5D ``torch.Tensor`` when ``self.use_last_target`` is set: the
                  first batch element and last channel are used.
                - A 3D ``np.ndarray``.

        Returns:
            The smallest AdaptedRand error over the evaluated channels. Only channel 0
            is evaluated when ``self.use_first_input`` is set.
        """
        if isinstance(input, torch.Tensor):
            assert input.dim() == 5
            # keep only the first batch element -> 4D numpy array
            input = input[0].detach().cpu().numpy()

        if isinstance(target, torch.Tensor):
            if not self.use_last_target:
                assert target.dim() == 4
                # keep only the first batch element -> 3D numpy array
                target = target[0].detach().cpu().numpy()
            else:
                # if use_last_target == True the target must be 5D (NxCxDxHxW);
                # take the last channel of the first batch element -> 3D
                assert target.dim() == 5
                target = target[0, -1].detach().cpu().numpy()

        if isinstance(input, np.ndarray):
            assert input.ndim == 4

        if isinstance(target, np.ndarray):
            assert target.ndim == 3

        # compute only on the first input channel if requested, otherwise on all of them
        n_channels = 1 if self.use_first_input else input.shape[0]

        # make sure that target is 'int' type. The target is identical for every
        # channel, so cast it once here instead of copying it on every loop iteration.
        target = target.astype(np.int64)

        per_channel_arand = []
        for c in range(n_channels):
            # threshold probability maps
            predictions = input[c] > self.threshold

            if self.invert_pmaps:
                # for connected component analysis we need to treat boundary signal as
                # background: assign 0-label to the boundary mask
                predictions = np.logical_not(predictions)

            predictions = predictions.astype(np.uint8)
            # run connected components on the predicted mask; consider only 1-connectivity
            predicted = measure.label(predictions, background=0, connectivity=1)
            # compute AdaptedRand error
            per_channel_arand.append(adapted_rand(predicted, target))

        # get minimum AdaptedRand error across channels
        min_arand, c_index = np.min(per_channel_arand), np.argmin(per_channel_arand)
        LOGGER.info(f'Min AdaptedRand error: {min_arand}, channel: {c_index}')
        return min_arand
Example #2
0
    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        For every batch element, ``input`` is converted to a per-channel segmentation
        via ``self.input_to_segm``. When ``self.run_target_cc`` is set, the target is
        relabelled with 1-connectivity connected components. The minimum AdaptedRand
        error over the segmentation channels is then kept.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 4D (NDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch (as a torch scalar tensor)
        """
        # converts input and target to numpy arrays
        input, target = self._convert_to_numpy(input, target)
        # ensure target is of integer type. `np.int` was removed in NumPy 1.24 (it
        # raises AttributeError there), so use an explicit 64-bit integer dtype.
        target = target.astype(np.int64)

        per_batch_arand = []
        for batch_idx, (_input, _target) in enumerate(zip(input, target)):
            LOGGER.info(
                f'Number of ground truth clusters: {len(np.unique(_target))}')

            # convert _input to segmentation
            segm = self.input_to_segm(_input)

            # run connected components if necessary
            if self.run_target_cc:
                _target = measure.label(_target, connectivity=1)

            if self.save_plots:
                # save predicted and ground truth segmentation
                plot_segm(segm, _target, self.plots_dir)

            # one segmentation per channel is expected
            assert segm.ndim == 4

            # compute per channel arand and keep the minimum value
            per_channel_arand = [
                adapted_rand(channel_segm, _target) for channel_segm in segm
            ]
            min_arand, c_index = np.min(per_channel_arand), np.argmin(
                per_channel_arand)
            LOGGER.info(
                f'Batch: {batch_idx}. Min AdaptedRand error: {min_arand}, channel: {c_index}'
            )
            per_batch_arand.append(min_arand)

        # return mean arand error
        return torch.mean(torch.tensor(per_batch_arand))
Example #3
0
    def test_embeddings_predictor(self, tmpdir):
        """Run ``FakePredictor`` with mean-shift clustering and check segmentation quality.

        The 'label' dataset of the sample file is used as the raw input.
        ``FakePredictor`` wraps a ``FakeModel``, and its 'segmentation/meanshift'
        output must match the ground-truth labels with an AdaptedRand error below 0.1.
        """
        config = {'model': {'output_heads': 1}, 'device': torch.device('cpu')}

        # labels double as the raw input: cast to long tensors, no extra channel axis
        raw_transforms = [{
            'name': 'ToTensor',
            'expand_dims': False,
            'dtype': 'long'
        }]
        patch_config = {
            'name': 'SliceBuilder',
            'patch_shape': (100, 200, 200),
            'stride_shape': (60, 150, 150)
        }

        gt_path = 'resources/sample_cells.h5'
        out_path = os.path.join(tmpdir, 'output_segmentation.h5')

        dataset = HDF5Dataset(gt_path,
                              phase='test',
                              slice_builder_config=patch_config,
                              transformer_config={'raw': raw_transforms},
                              raw_internal_path='label')
        loader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            collate_fn=prediction_collate)

        FakePredictor(FakeModel(),
                      loader,
                      out_path,
                      config,
                      clustering='meanshift',
                      bandwidth=0.5).predict()

        with h5py.File(gt_path, 'r') as gt_h5, h5py.File(out_path, 'r') as out_h5:
            expected = gt_h5['label'][...]
            produced = out_h5['segmentation/meanshift'][...]

        assert adapted_rand(produced, expected) < 0.1
Example #4
0
 def __call__(self, input, target):
     """Return ``adapted_rand(input, target)``, forwarding ``self.all_stats`` as ``all_stats``."""
     want_all_stats = self.all_stats
     return adapted_rand(input, target, all_stats=want_all_stats)