Example #1
    def __call__(self, input, target):
        assert isinstance(input, torch.Tensor) and isinstance(target, torch.Tensor)
        assert input.dim() == 5
        assert target.dim() == 5

        input, target = convert_to_numpy(input, target)
        if self.use_last_target:
            target = target[:, -1, ...]  # 4D
        else:
            # use 1st target channel
            target = target[:, 0, ...]  # 4D

        batch_aps = []
        # iterate over the batch
        for inp, tar in zip(input, target):
            segs = self.input_to_seg(inp)  # 4D
            # convert target to seg
            tar = self.target_to_seg(tar)
            # filter small instances if necessary
            tar = self._filter_instances(tar)

            # compute average precision per channel
            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]

            logger.info(f'Max Average Precision for channel: {np.argmax(segs_aps)}')
            # save max AP
            batch_aps.append(np.max(segs_aps))

        return torch.tensor(batch_aps).mean()
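
Every example in this section calls a convert_to_numpy helper that is not shown. A minimal sketch of what it presumably does, assuming the inputs are torch tensors that may live on the GPU and may require grad:

import numpy as np
import torch


def convert_to_numpy(*tensors):
    # Detach each tensor from the autograd graph, move it to host memory
    # and convert it to a numpy array; pass non-tensors through unchanged.
    def _to_numpy(t):
        if isinstance(t, torch.Tensor):
            return t.detach().cpu().numpy()
        return np.asarray(t)

    return tuple(_to_numpy(t) for t in tensors)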
Example #2
    def __call__(self, input, target):
        """
		:param input: 5D probability maps torch float tensor (NxCxDxHxW)
		:param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
		:return: intersection over union averaged over all channels
		"""
        assert input.dim() == 5

        predictions, target = convert_to_numpy(input, target)
        predictions = predictions[0]

        # global Otsu threshold on the predictions
        global_thresh = threshold_otsu(predictions)

        # zero out the low-intensity background below the threshold
        low_intensity_region = np.where(predictions < global_thresh)
        predictions = np.array(predictions)  # writable copy before the in-place edit
        predictions[low_intensity_region] = 0

        # restore the batch axis: CxDxHxW -> 1xCxDxHxW
        predictions = np.expand_dims(predictions, axis=0)

        predictions = torch.tensor(predictions)

        target = torch.tensor(target)

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target,
                                       C=n_classes,
                                       ignore_index=self.ignore_index)

        assert predictions.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(predictions, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out ignore_index
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                _target[mask] = 0

            # convert to uint8 just in case
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = []
            for c in range(n_classes):
                if c in self.skip_channels:
                    continue

                per_channel_iou.append(
                    self._jaccard_index(binary_prediction[c], _target[c]))

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))
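
The expand_as_one_hot call above is also external to this snippet. A sketch of a plausible implementation, assuming integer labels and the usual scatter_-based one-hot expansion along the channel axis:

import torch


def expand_as_one_hot(input, C, ignore_index=None):
    # NxDxHxW integer label tensor -> NxCxDxHxW one-hot tensor.
    assert input.dim() == 4
    input = input.unsqueeze(1).long()  # Nx1xDxHxW; scatter_ needs int64 indices
    shape = list(input.size())
    shape[1] = C

    if ignore_index is not None:
        # temporarily map ignored voxels to label 0, scatter,
        # then write ignore_index back into every channel
        mask = input.expand(shape) == ignore_index
        input = input.clone()
        input[input == ignore_index] = 0
        result = torch.zeros(shape).scatter_(1, input, 1)
        result[mask] = ignore_index
        return result

    return torch.zeros(shape).scatter_(1, input, 1)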
Example #3
    def __call__(self, input, target):
        """
        Compute ARand Error for each input, target pair in the batch and return the mean value.

        Args:
            input (torch.tensor): 5D (NCDHW) output from the network
            target (torch.tensor): 4D (NDHW) ground truth segmentation

        Returns:
            average ARand Error across the batch
        """
        def _arand_err(gt, seg):
            n_seg = len(np.unique(seg))
            if n_seg == 1:
                return 0.
            return adapted_rand_error(gt, seg)[0]

        # converts input and target to numpy arrays
        input, target = convert_to_numpy(input, target)
        if self.use_last_target:
            target = target[:, -1, ...]  # 4D
        else:
            # use 1st target channel
            target = target[:, 0, ...]  # 4D

        # ensure target is of integer type
        target = target.astype(np.int64)  # np.int was removed in NumPy 1.24

        per_batch_arand = []
        for _input, _target in zip(input, target):
            n_clusters = len(np.unique(_target))
            # skip ARand eval if there is only one label in the patch due to the zero-division error in Arand impl
            # xxx/skimage/metrics/_adapted_rand_error.py:70: RuntimeWarning: invalid value encountered in double_scalars
            # precision = sum_p_ij2 / sum_a2
            logger.info(f'Number of ground truth clusters: {n_clusters}')
            if n_clusters == 1:
                logger.info(
                    'Skipping ARandError computation: only 1 label present in the ground truth'
                )
                per_batch_arand.append(0.)
                continue

            # convert _input to segmentation CDHW
            segm = self.input_to_segm(_input)
            assert segm.ndim == 4

            # compute per channel arand and return the minimum value
            per_channel_arand = [
                _arand_err(_target, channel_segm) for channel_segm in segm
            ]
            logger.info(
                f'Min ARand for channel: {np.argmin(per_channel_arand)}')
            per_batch_arand.append(np.min(per_channel_arand))

        # return mean arand error
        mean_arand = torch.mean(torch.tensor(per_batch_arand))
        logger.info(f'ARand: {mean_arand.item()}')
        return mean_arand
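
For reference, adapted_rand_error comes from skimage.metrics and returns an (error, precision, recall) triple; the metric above keeps only the error. A toy, self-contained check (labels chosen as 1 and 2 to stay clear of any background-label handling):

import numpy as np
from skimage.metrics import adapted_rand_error

gt = np.ones((8, 8), dtype=int)
gt[:, 4:] = 2                      # two-label ground truth
seg = np.ones((8, 8), dtype=int)
seg[:, 5:] = 2                     # segmentation with a shifted boundary

error, precision, recall = adapted_rand_error(gt, seg)
print(error)  # 0.0 would mean a perfect match; here it is small but non-zero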
Example #4
    def __call__(self, input, target):
        """
		:input: The predictions from the network (BS*C*D*H*W)
		:target: The ground truth passed to the network (BS*C*D*H*W)
		:return: Number of matched peaks
		"""

        # look for local max within fixed range between input and target

        # Get input and predictions in required format
        input, target = convert_to_numpy(input, target)

        # TODO: pass the original ground-truth coordinates here

        # Take the peaks from the predictions
        local_max = skimage.feature.peak_local_max(input[0][0], min_distance=4)

        # Take the foreground pixels from ground truth
        foreground = np.where(target[0][0] == 1.0)
        foreground_coords = []

        for idx in range(len(foreground[0])):
            foreground_coords.append(
                [foreground[0][idx], foreground[1][idx], foreground[2][idx]])

        eval_dict = {}

        # Matching procedure
        for coord in local_max:
            z, y, x = coord[0], coord[1], coord[2]
            eval_dict[(z, y, x)] = {}

            for gt_coord in foreground_coords:
                gt_z, gt_y, gt_x = gt_coord[0], gt_coord[1], gt_coord[2]

                # Taking anisotropy into account
                std_euc = distance.seuclidean([z, y, x], [gt_z, gt_y, gt_x],
                                              [3.0, 1.3, 1.3])

                # Keeping a threshold for matching
                if std_euc <= 8.0:
                    eval_dict[(z, y, x)][gt_z, gt_y, gt_x] = std_euc

        n_count = 0
        # a predicted peak counts as matched if it found at least one
        # ground-truth voxel within the distance threshold
        for k, v in eval_dict.items():
            if v != {}:
                n_count += 1

        print(n_count)
        return torch.tensor(n_count)
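
The seuclidean call deserves a note: scipy's standardized Euclidean distance divides each squared coordinate difference by the corresponding entry of the variance vector V, so the [3.0, 1.3, 1.3] weights make a step along the coarse z axis count for less than a step in y or x. A quick check of the arithmetic:

from scipy.spatial import distance

# seuclidean(u, v, V) = sqrt(sum((u_i - v_i) ** 2 / V_i))
d = distance.seuclidean([10, 20, 20], [12, 20, 20], [3.0, 1.3, 1.3])
print(d)  # sqrt(2 ** 2 / 3.0) ~= 1.155, vs. 2.0 for the plain Euclidean distance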
Example #5
    def __call__(self, input, target):
        if target.dim() == 5:
            if self.use_last_target:
                target = target[:, -1, ...]  # 4D
            else:
                # use 1st target channel
                target = target[:, 0, ...]  # 4D

        input1 = input2 = input
        multi_head = isinstance(input, tuple)
        if multi_head:
            input1, input2 = input

        input1, input2, target = convert_to_numpy(input1, input2, target)

        batch_aps = []
        # iterate over the batch
        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
            if multi_head:
                inp = (inp1, inp2)
            else:
                inp = inp1

            segs = self.input_to_seg(inp, tar)  # returns 4D
            assert segs.ndim == 4
            # convert target to seg
            tar = self.target_to_seg(tar)

            # filter small instances if necessary
            tar = self._filter_instances(tar)

            # compute average precision per channel
            segs_aps = [
                self.metric(self._filter_instances(seg), tar) for seg in segs
            ]

            logger.info(
                f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}'
            )
            # save max AP
            batch_aps.append(np.max(segs_aps))

        return torch.tensor(batch_aps).mean()
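
The _filter_instances helper used in examples #1 and #5 is not shown either. One plausible implementation, assuming a min_instance_size setting below which labelled objects are relabelled as background:

import numpy as np


def filter_instances(seg, min_instance_size=None):
    # Suppress instances smaller than min_instance_size by mapping them to
    # background (0), so tiny spurious objects cannot dominate the AP score.
    if min_instance_size is not None:
        labels, counts = np.unique(seg, return_counts=True)
        for label, count in zip(labels, counts):
            if count < min_instance_size:
                seg[seg == label] = 0
    return seg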
Example #6
    def __call__(self, input, target):
        """
		:input: The predictions from the network (BS*C*D*H*W)
		:target: The ground truth passed to the network (BS*C*D*H*W)
		:return: Number of matched peaks
		"""

        matching_count = 0
        neighbor_threshold = 8.0

        # look for local max within fixed range between input and target

        # Get input and predictions in required format
        input, target = convert_to_numpy(input, target)

        # Take the peaks from the predictions
        # (min_distance is a hyperparameter here)
        local_max = skimage.feature.peak_local_max(input[0][0], min_distance=4)

        # Take the foreground pixels from ground truth
        foreground = np.where(target[0][0] == 1.0)
        foreground_coords = []

        for idx in range(len(foreground[0])):
            foreground_coords.append(
                [foreground[0][idx], foreground[1][idx], foreground[2][idx]])

        # Matching procedure
        eval_dict = match_routine(local_max, foreground_coords,
                                  neighbor_threshold)

        # a predicted peak counts as matched if it found at least one
        # ground-truth voxel within the distance threshold
        for eval_key, eval_val in eval_dict.items():
            if eval_val != {}:
                matching_count += 1

        print(matching_count)
        return torch.tensor(matching_count)
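
match_routine is referenced here (and in example #12) without being shown. A sketch consistent with the inline matching loop of example #4, with the anisotropy variances as an assumed default:

from scipy.spatial import distance


def match_routine(source_coords, candidate_coords, neighbor_threshold,
                  variances=(3.0, 1.3, 1.3)):
    # For every source coordinate, record all candidates that lie within
    # neighbor_threshold under the standardized Euclidean metric.
    eval_dict = {}
    for src in source_coords:
        z, y, x = src[0], src[1], src[2]
        eval_dict[(z, y, x)] = {}
        for cand in candidate_coords:
            cz, cy, cx = cand[0], cand[1], cand[2]
            std_euc = distance.seuclidean([z, y, x], [cz, cy, cx],
                                          list(variances))
            if std_euc <= neighbor_threshold:
                eval_dict[(z, y, x)][(cz, cy, cx)] = std_euc
    return eval_dict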
Example #7
    def __call__(self, input, target):
        predictions, target = convert_to_numpy(input, target)

        predictions = predictions[0]
        target = target[0]

        global_thresh = threshold_otsu(predictions[0])

        foreground = predictions > global_thresh

        local_max = skimage.feature.peak_local_max(predictions[0], min_distance=2)

        # volume with each peak marked by a unique positive label
        # (use the prediction shape instead of hardcoding it)
        temp_max = np.zeros(predictions.shape)

        for i, each in enumerate(local_max):
            temp_max[0][each[0], each[1], each[2]] = i + 1

        # dilate the peaks: distance transform of the background,
        # then keep voxels within the distance threshold
        inv_temp_max = np.logical_not(temp_max)
        dist_tr = ndimage.distance_transform_edt(inv_temp_max)

        thresh_tr = dist_tr > 3
        thresh_temp = np.logical_not(thresh_tr).astype(np.float64)

        # drop dilated voxels that fall outside the Otsu foreground
        extra = np.where(thresh_temp != foreground)
        thresh_temp[extra] = 0

        watershed_output = watershed(thresh_temp, temp_max, mask=thresh_temp).astype(np.uint16)

        wshed_peaks = []
        wshed = np.where(watershed_output[0] != 0)

        for wid in range(len(wshed[0])):
            wshed_peaks.append([wshed[0][wid], wshed[1][wid], wshed[2][wid]])

        # keep only the detected peaks that survive the watershed masking
        intersection_peaks = []
        for each in local_max:
            z, y, x = each[0], each[1], each[2]
            for ws in wshed_peaks:
                wsz, wsy, wsx = ws[0], ws[1], ws[2]
                if z == wsz and y == wsy and x == wsx:
                    intersection_peaks.append(ws)

        gt_foreground = np.where(target[0] == 1.0)
        gt_coords = []

        for idx in range(len(gt_foreground[0])):
            gt_coords.append([gt_foreground[0][idx], gt_foreground[1][idx], gt_foreground[2][idx]])

        # match every ground-truth peak to nearby predicted peaks
        eval_dict = {}

        for gtc in gt_coords:
            gt_z, gt_y, gt_x = gtc[0], gtc[1], gtc[2]
            eval_dict[(gt_z, gt_y, gt_x)] = {}

            for peak in intersection_peaks:
                z, y, x = peak[0], peak[1], peak[2]

                # standardized Euclidean distance to account for the anisotropic z axis
                std_euc = distance.seuclidean([gt_z, gt_y, gt_x], [z, y, x], [3.0, 1.3, 1.3])

                if std_euc <= 6.0:
                    eval_dict[gt_z, gt_y, gt_x][(z, y, x)] = std_euc

        tp_count = 0
        fn_count = 0

        # a ground-truth peak with at least one match is a TP; one with no
        # match is a FN
        for k1, v1 in eval_dict.items():
            if v1 != {}:
                tp_count += 1
            else:
                fn_count += 1

        print('Total GT peaks: ' + str(len(gt_coords)))
        print('Prediction peaks after thresholding: ' + str(len(intersection_peaks)))
        print('Matched GT and peaks (TP): ' + str(tp_count))
        print('No prediction peak for GT (FN): ' + str(fn_count))

        eval_dict = {}

        fp_count = 0
        tp_count = 0

        # match in the opposite direction: every predicted peak against the GT
        for peak in intersection_peaks:
            z, y, x = peak[0], peak[1], peak[2]
            eval_dict[(z, y, x)] = {}

            for gtc in gt_coords:
                gt_z, gt_y, gt_x = gtc[0], gtc[1], gtc[2]

                std_euc = distance.seuclidean([z, y, x], [gt_z, gt_y, gt_x], [3.0, 1.3, 1.3])

                if std_euc <= 6.0:
                    eval_dict[z, y, x][(gt_z, gt_y, gt_x)] = std_euc

        # a predicted peak with no matching ground truth is a FP
        for k1, v1 in eval_dict.items():
            if v1 != {}:
                tp_count += 1
            else:
                fp_count += 1

        print('No GT for a peak (FP): ' + str(fp_count))

        precision_val = self.precision(tp_count, fp_count)
        recall_val = self.recall(tp_count, fn_count)

        # guard against the degenerate case of no matches at all
        if precision_val + recall_val > 0:
            F1_score = 2 * precision_val * recall_val / (precision_val + recall_val)
        else:
            F1_score = 0.0

        return torch.tensor(F1_score)
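
self.precision and self.recall are not shown in this snippet. Minimal sketches under the standard definitions, with zero-division guards (assumptions, since the class implementation is not visible):

def precision(tp, fp):
    # fraction of predicted peaks that matched a ground-truth peak
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0


def recall(tp, fn):
    # fraction of ground-truth peaks that a prediction matched
    return tp / (tp + fn) if (tp + fn) > 0 else 0.0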
Example #8
    def __call__(self, input, target):
        """
        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
        :return: intersection over union averaged over all channels
        """
        assert input.dim() == 5

        predictions, target = convert_to_numpy(input, target)
        predictions = predictions[0]

        # global otsu threshold on the predictions
        global_thresh = threshold_otsu(predictions)

        # define the foreground based on threshold
        foreground = predictions > global_thresh

        # get the local peaks from the predictions
        local_max = skimage.feature.peak_local_max(predictions[0], min_distance=5)

        # prepare the volume with peaks, matching the prediction shape
        local_peaks_vol = np.zeros(predictions.shape)

        for coordinate in local_max:
            local_peaks_vol[0][coordinate[0], coordinate[1], coordinate[2]] = 1.0

        # dilate the peaks
        inv_local_peaks_vol = np.logical_not(local_peaks_vol)

        # get distance transform
        local_peaks_edt = ndimage.distance_transform_edt(inv_local_peaks_vol)

        # threshold the edt and invert back: fg as 1, bg as 0
        spherical_peaks = local_peaks_edt > 3
        spherical_peaks = np.logical_not(spherical_peaks).astype(np.float64)

        # get the outliers based on threshold and set zero
        outliers = np.where(spherical_peaks != foreground)
        spherical_peaks[outliers] = 0

        spherical_peaks = np.expand_dims(spherical_peaks, axis=0)

        spherical_peaks = torch.tensor(spherical_peaks)
        target = torch.tensor(target)

        n_classes = input.size()[1]

        if target.dim() == 4:
            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)

        assert spherical_peaks.size() == target.size()

        per_batch_iou = []
        for _input, _target in zip(spherical_peaks, target):
            binary_prediction = self._binarize_predictions(_input, n_classes)

            if self.ignore_index is not None:
                # zero out ignore_index
                mask = _target == self.ignore_index
                binary_prediction[mask] = 0
                _target[mask] = 0

            # convert to uint8 just in case
            binary_prediction = binary_prediction.byte()
            _target = _target.byte()

            per_channel_iou = []
            for c in range(n_classes):
                if c in self.skip_channels:
                    continue

                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))

            assert per_channel_iou, "All channels were ignored from the computation"
            mean_iou = torch.mean(torch.tensor(per_channel_iou))
            per_batch_iou.append(mean_iou)

        return torch.mean(torch.tensor(per_batch_iou))
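
_binarize_predictions and _jaccard_index belong to the surrounding class and are not shown here. The Jaccard index at least is standard; a sketch over the uint8 tensors produced above (an assumption, not necessarily the exact class implementation):

import torch


def jaccard_index(prediction, target):
    # intersection over union of two binary uint8 volumes; the clamp
    # avoids division by zero when both volumes are empty
    intersection = torch.sum(prediction & target).float()
    union = torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
    return intersection / union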
Example #9
    def __call__(self, input, target):
        input, target = convert_to_numpy(input, target)
        return normalized_root_mse(target, input)
Example #10
    def __call__(self, input, target):
        input, target = convert_to_numpy(input, target)
        return mean_squared_error(target, input)
Example #11
    def __call__(self, input, target):
        input, target = convert_to_numpy(input, target)
        return peak_signal_noise_ratio(target, input)
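
Examples #9-#11 all wrap regression metrics from skimage.metrics, which take the ground truth as the first argument (the order matters for normalized_root_mse and peak_signal_noise_ratio). A self-contained usage check:

import numpy as np
from skimage.metrics import (mean_squared_error, normalized_root_mse,
                             peak_signal_noise_ratio)

rng = np.random.default_rng(0)
target = rng.random((16, 16)).astype(np.float32)
prediction = target + 0.05 * rng.standard_normal((16, 16)).astype(np.float32)

print(mean_squared_error(target, prediction))
print(normalized_root_mse(target, prediction))
print(peak_signal_noise_ratio(target, prediction, data_range=1.0))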
Example #12
    def __call__(self, input, target):

        threshold_val = 2
        neighbor_threshold = 8.0
        tp_count = 0
        fp_count = 0
        fn_count = 0
        predictions, target = convert_to_numpy(input, target)

        predictions = predictions[0]
        target = target[0]

        global_thresh = threshold_otsu(predictions[0])

        foreground = predictions > global_thresh

        local_max = skimage.feature.peak_local_max(predictions[0],
                                                   min_distance=3)

        temp_max = np.zeros(predictions.shape)  # avoid hardcoding the volume shape

        for i, each in enumerate(local_max):
            temp_max[0][each[0], each[1], each[2]] = i + 1

        thresh_temp = dilation_routine(temp_max, threshold_val=threshold_val)

        extra = np.where(thresh_temp != foreground)

        thresh_temp[extra] = 0

        watershed_output = watershed(thresh_temp, temp_max,
                                     mask=thresh_temp).astype(np.uint16)

        wshed_peaks = []
        wshed = np.where(watershed_output[0] != 0)

        for wid, wval in enumerate(wshed[0]):
            wshed_peaks.append([wshed[0][wid], wshed[1][wid], wshed[2][wid]])

        intersection_peaks = []
        for each in local_max:
            z, y, x = each[0], each[1], each[2]
            for ws in wshed_peaks:
                wsz, wsy, wsx = ws[0], ws[1], ws[2]
                if z == wsz and y == wsy and x == wsx:
                    intersection_peaks.append(ws)

        gt_foreground = np.where(target[0] == 1.0)
        gt_coords = []

        for idx in range(len(gt_foreground[0])):
            gt_coords.append([
                gt_foreground[0][idx], gt_foreground[1][idx],
                gt_foreground[2][idx]
            ])

        eval_dict = match_routine(gt_coords, intersection_peaks,
                                  neighbor_threshold)

        # a ground-truth peak with at least one match is a TP; one with no
        # match is a FN
        for k1, v1 in eval_dict.items():
            if v1 != {}:
                tp_count += 1
            else:
                fn_count += 1

        print('Total GT peaks: ' + str(len(gt_coords)))
        print('Prediction peaks after thresholding: ' +
              str(len(intersection_peaks)))
        print('Matched GT and peaks (TP): ' + str(tp_count))

        print('No prediction peak for GT (FN): ' + str(fn_count))

        fp_eval_dict = match_routine(intersection_peaks, gt_coords,
                                     neighbor_threshold)

        # a predicted peak with no matching ground truth is a FP
        for k1, v1 in fp_eval_dict.items():
            if v1 == {}:
                fp_count += 1

        print('No GT for a peak (FP): ' + str(fp_count))

        precision_val = self.precision(tp_count, fp_count)
        recall_val = self.recall(tp_count, fn_count)

        # guard against the degenerate case of no matches at all
        if precision_val + recall_val > 0:
            F1_score = 2 * precision_val * recall_val / (precision_val + recall_val)
        else:
            F1_score = 0.0

        return torch.tensor(F1_score)
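
dilation_routine is another helper that is not shown; example #7 performs the same steps inline. A sketch under that assumption:

import numpy as np
from scipy import ndimage


def dilation_routine(peaks_vol, threshold_val):
    # Grow each marked peak into a small sphere: take the Euclidean distance
    # transform of the background and keep voxels within threshold_val of a
    # peak, mirroring the inline logic of example #7.
    inv_peaks = np.logical_not(peaks_vol)
    dist = ndimage.distance_transform_edt(inv_peaks)
    return (dist <= threshold_val).astype(np.float64)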