Example #1
0
    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        For every sample the network output is turned into one segmentation per channel
        (via ``self.input_to_segm``). The Adapted Rand error of each channel against the
        ground truth is computed, and the best (minimum) value is kept. Samples whose ground
        truth contains a single label are scored 0 and skipped.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): channel-first ground truth segmentation. A single channel
                is selected along axis 1 (the last one if ``self.use_last_target``, otherwise
                the first), which leaves one label volume per batch sample.

        Returns:
            average ARand Error across the batch (0-dim torch tensor)
        """
        def _arand_err(gt, seg):
            # Adapted Rand error between ground truth `gt` and one channel's segmentation `seg`.
            # NOTE(review): a segmentation containing a single label is scored 0, presumably to
            # avoid the zero-division warning mentioned below. The caller takes the minimum over
            # channels, so such a channel wins (or ties) that minimum. Confirm this is intended.
            n_seg = len(np.unique(seg))
            if n_seg == 1:
                return 0.
            return adapted_rand_error(gt, seg)[0]

        # converts input and target to numpy arrays
        input, target = convert_to_numpy(input, target)
        if self.use_last_target:
            target = target[:, -1, ...]  # 4D
        else:
            # use 1st target channel
            target = target[:, 0, ...]  # 4D

        # ensure target is of integer type; builtin `int` is exactly what the `np.int` alias
        # referred to (the alias was removed in NumPy 1.24 and now raises AttributeError)
        target = target.astype(int)

        per_batch_arand = []
        for _input, _target in zip(input, target):
            n_clusters = len(np.unique(_target))
            # skip ARand eval if there is only one label in the patch due to the zero-division error in Arand impl
            # xxx/skimage/metrics/_adapted_rand_error.py:70: RuntimeWarning: invalid value encountered in double_scalars
            # precision = sum_p_ij2 / sum_a2
            logger.info(f'Number of ground truth clusters: {n_clusters}')
            if n_clusters == 1:
                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
                per_batch_arand.append(0.)
                continue

            # convert _input to segmentation CDHW
            segm = self.input_to_segm(_input)
            assert segm.ndim == 4

            if self.save_plots:
                # save predicted and ground truth segmentation
                plot_segm(segm, _target, self.plots_dir)

            # compute per channel arand and return the minimum value
            per_channel_arand = [_arand_err(_target, channel_segm) for channel_segm in segm]
            logger.info(f'Min ARand for channel: {np.argmin(per_channel_arand)}')
            per_batch_arand.append(np.min(per_channel_arand))

        # return mean arand error
        mean_arand = torch.mean(torch.tensor(per_batch_arand))
        logger.info(f'ARand: {mean_arand.item()}')
        return mean_arand
Example #2
0
    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        For every sample the network output is converted into a set of segmentations (one per
        entry along its first axis, via ``self.input_to_segm``). The Adapted Rand error of
        each against the ground truth is computed, and the minimum is kept. The result is the
        mean of these per-sample minima.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 4D (NDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch (0-dim torch tensor)
        """
        # converts input and target to numpy arrays
        input, target = self._convert_to_numpy(input, target)
        # ensure target is of integer type; builtin `int` is exactly what the `np.int` alias
        # referred to (the alias was removed in NumPy 1.24 and now raises AttributeError)
        target = target.astype(int)

        per_batch_arand = []
        for _batch_inst, (_input, _target) in enumerate(zip(input, target)):
            # NOTE: logged before the optional connected-components relabelling below
            LOGGER.info(
                f'Number of ground truth clusters: {len(np.unique(_target))}')

            # convert _input to segmentation
            segm = self.input_to_segm(_input)

            # run connected components if necessary (relabel the ground truth with connectivity=1)
            if self.run_target_cc:
                _target = measure.label(_target, connectivity=1)

            if self.save_plots:
                # save predicted and ground truth segmentation
                plot_segm(segm, _target, self.plots_dir)

            # one segmentation per channel is expected (presumably CDHW) — TODO confirm
            assert segm.ndim == 4

            # compute per channel arand and keep the minimum value
            per_channel_arand = [adapted_rand(channel_segm, _target) for channel_segm in segm]

            # get the min arand across channels
            min_arand, c_index = np.min(per_channel_arand), np.argmin(per_channel_arand)
            LOGGER.info(
                f'Batch: {_batch_inst}. Min AdaptedRand error: {min_arand}, channel: {c_index}'
            )
            per_batch_arand.append(min_arand)

        # return mean arand error
        return torch.mean(torch.tensor(per_batch_arand))