Ejemplo n.º 1
0
    def run_online_evaluation(self, output, target):
        """Accumulate hard (argmax-based) foreground TP/FP/FN counts for one batch.

        The prediction is turned into a hard segmentation by taking the argmax
        over dim 1 and is compared class by class (class 0 skipped) with
        ``target``. Counts are summed over all spatial axes and over the batch.
        The resulting per-class vectors are appended to
        ``self.online_eval_foreground_dc``, ``self.online_eval_tp``,
        ``self.online_eval_fp`` and ``self.online_eval_fn``.

        :param output: network output; dim 1 indexes the classes
            (presumably (b, c, x, y(, z)) logits -- TODO confirm with caller).
        :param target: label map with a singleton channel at dim 1
            (assumed (b, 1, x, y(, z))); the channel is dropped via ``[:, 0]``.
        """
        with torch.no_grad():
            num_classes = output.shape[1]
            # A channel-wise softmax does not change the argmax, so the hard
            # segmentation is taken from the raw output directly (as the DDP
            # variant of this method does). This avoids materializing a second
            # output-sized tensor.
            output_seg = output.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))

            counts_shape = (target.shape[0], num_classes - 1)
            # Allocate on the prediction's device directly rather than on the
            # CPU followed by a copy.
            tp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fn_hard = torch.zeros(counts_shape, device=output_seg.device)
            for c in range(1, num_classes):
                # Build each 0/1 class mask once and reuse it for tp/fp/fn.
                pred_c = (output_seg == c).float()
                gt_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

            # Reduce over the batch: one count per foreground class.
            tp_hard = tp_hard.sum(0).cpu().numpy()
            fp_hard = fp_hard.sum(0).cpu().numpy()
            fn_hard = fn_hard.sum(0).cpu().numpy()

            # 1e-8 guards against 0/0 for classes absent from both maps.
            self.online_eval_foreground_dc.append(
                list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
Ejemplo n.º 2
0
    def forward(self, x, y, loss_mask=None):
        """Generalized Dice loss with per-class weights 1 / V (V = class volume).

        :param x: network output, (b, c, x, y(, z)).
        :param y: label map of shape (b, 1, x, y(, z)) or (b, x, y(, z)), or a
            one-hot encoding with the same shape as ``x``.
        :param loss_mask: optional validity mask, forwarded unchanged to
            ``get_tp_fp_fn_tn``.
        :return: the negated mean generalized Dice coefficient.
        """
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            # add the missing channel axis: (b, x, y(, z)) -> (b, 1, x, y(, z))
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(x.shape, y.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            gt = y.long()
            # Allocate on x's device directly so the buffer ends up next to x
            # for any backend, not only CUDA.
            y_onehot = torch.zeros(shp_x, device=x.device)
            y_onehot.scatter_(1, gt, 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask,
                                        self.square)

        # GDL weight computation, we use 1/V
        # add some eps to prevent div by zero for classes absent from y
        volumes = sum_tensor(y_onehot, axes) + 1e-6

        if self.square_volumes:
            volumes = volumes**2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # Sum over classes. With batch_dice the batch axis was already reduced
        # (axis 0 is in `axes`), so the class axis is 0; otherwise it is 1.
        class_axis = 0 if self.batch_dice else 1

        tp = tp.sum(class_axis)
        fp = fp.sum(class_axis)
        fn = fn.sum(class_axis)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        return -dc.mean()
Ejemplo n.º 3
0
    def forward(self, x, y, loss_mask=None):
        """Soft Dice loss with squared terms in the denominator.

        Per (batch,) class this computes
        ``2 * (sum(x * y) + smooth) / (sum(x**2 + y**2) + smooth)``, summing
        over the spatial axes (plus the batch axis when ``self.batch_dice``).
        Note that ``smooth`` enters the numerator twice because it is added
        before the factor 2.

        :param x: network output, (b, c, x, y(, z)); ``self.apply_nonlin`` is
            applied first when set.
        :param y: label map of shape (b, 1, x, y(, z)) or (b, x, y(, z)), or a
            one-hot encoding with the same shape as ``x``.
        :param loss_mask: accepted for interface compatibility; not used here.
        :return: the negated mean Dice coefficient.
        """
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                # add the missing channel axis: (b, x, y(, z)) -> (b, 1, x, y(, z))
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all(i == j for i, j in zip(x.shape, y.shape)):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                y = y.long()
                # Allocate on x's device directly; zeros() is already float32,
                # so no extra cast is needed after scatter_.
                y_onehot = torch.zeros(shp_x, device=x.device)
                y_onehot.scatter_(1, y, 1)

        intersect = x * y_onehot
        # values in the denominator get smoothed
        denominator = x**2 + y_onehot**2

        # aggregation was previously done in get_tp_fp_fn, but needs to be done
        # here now (needs to be done after squaring)
        intersect = sum_tensor(intersect, axes, False) + self.smooth
        denominator = sum_tensor(denominator, axes, False) + self.smooth

        dc = 2 * intersect / denominator

        if not self.do_bg:
            # With batch_dice the class axis is 0, otherwise it is 1.
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]

        return -dc.mean()
Ejemplo n.º 4
0
    def run_online_evaluation(self, output, target):
        """Accumulate hard foreground TP/FP/FN counts, gathered across workers.

        Only ``output[0]`` and ``target[0]`` are evaluated (presumably the
        full-resolution head under deep supervision -- confirm with caller).
        The hard segmentation is the argmax over dim 1. Per foreground class
        (class 0 skipped), counts are summed over spatial axes and the local
        batch, passed through ``awesome_allgather_function``, and summed over
        dim 0 of the gathered result. Results are appended to
        ``self.online_eval_foreground_dc``, ``self.online_eval_tp``,
        ``self.online_eval_fp`` and ``self.online_eval_fn``.

        :param output: sequence of network outputs; ``output[0]`` has the
            classes on dim 1.
        :param target: sequence of label maps; ``target[0]`` is assumed to
            be (b, 1, x, y(, z)).
        """
        with torch.no_grad():
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))

            counts_shape = (target.shape[0], num_classes - 1)
            # Allocate on the prediction's device directly rather than on the
            # CPU followed by a copy.
            tp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fn_hard = torch.zeros(counts_shape, device=output_seg.device)
            for c in range(1, num_classes):
                # Build each 0/1 class mask once and reuse it for tp/fp/fn.
                pred_c = (output_seg == c).float()
                gt_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

            # Sum over the local batch and keep a leading axis of size 1 so
            # the per-worker results can be gathered along dim 0.
            tp_hard = tp_hard.sum(0)[None]
            fp_hard = fp_hard.sum(0)[None]
            fn_hard = fn_hard.sum(0)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)

        # Sum the gathered per-worker counts.
        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)
        # 1e-8 guards against 0/0 for classes absent from both maps.
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
Ejemplo n.º 5
0
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """Compute soft true/false positives/negatives of a prediction vs. a label.

    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or
    a one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))

    :param net_output: prediction (e.g. softmax probabilities), (b, c, x, y(, z))
    :param gt: label map or one hot encoding, see above
    :param axes: axes to sum over; defaults to all spatial axes (2, 3, ...).
        An empty tuple ``()`` means no summation.
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels;
        only channel 0 of the mask is used
    :param square: if True then tp, fp, fn and tn are squared before summation
    :return: tuple ``(tp, fp, fn, tn)``
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # add the missing channel axis: (b, x, y(, z)) -> (b, 1, x, y(, z))
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            # Allocate on net_output's device directly so this works for any
            # backend, not only CUDA.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # Broadcasting channel 0 of the mask over the class axis gives the same
        # values as multiplying every class slice by mask[:, 0], without
        # creating one temporary per class via unbind/stack.
        valid = mask[:, :1]
        tp = tp * valid
        fp = fp * valid
        fn = fn * valid
        tn = tn * valid

    if square:
        tp = tp**2
        fp = fp**2
        fn = fn**2
        tn = tn**2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn
Ejemplo n.º 6
0
    def forward(self, x, y=None, return_hard_tp_fp_fn=False):
        """Run the network and, when targets are given, compute loss ingredients.

        Computing these inside ``forward`` presumably lets each DataParallel
        replica do its share of the loss work (class name suggests DP --
        confirm).

        :param x: network input.
        :param y: optional targets. With deep supervision active
            (``self._deep_supervision and self.do_ds``) this is a sequence with
            one target per output; otherwise a single target tensor.
        :param return_hard_tp_fp_fn: if True (and ``y`` is given), append the
            batch-summed hard TP/FP/FN per foreground class, computed on the
            first (or only) output.
        :return: the raw network output if ``y`` is None. Otherwise a tuple
            ``(ce_loss, tp, fp, fn)``, where each entry is a list with one
            element per output under deep supervision. It is optionally
            extended by ``(tp_hard, fp_hard, fn_hard)``.
        """
        res = super(Generic_UNet_DP,
                    self).forward(x)  # regular Generic_UNet forward pass

        if y is None:
            return res

        if self._deep_supervision and self.do_ds:
            ce_losses, tps, fps, fns = [], [], [], []
            for i in range(len(y)):
                ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
                # tp fp and fn need the output to be softmax
                tp, fp, fn, _ = get_tp_fp_fn_tn(softmax_helper(res[i]), y[i])
                tps.append(tp)
                fps.append(fp)
                fns.append(fn)
            ret = ce_losses, tps, fps, fns
            output, target = res[0], y[0]
        else:
            ce_loss = self.ce_loss(res, y).unsqueeze(0)
            # tp fp and fn need the output to be softmax
            tp, fp, fn, _ = get_tp_fp_fn_tn(softmax_helper(res), y)
            ret = ce_loss, tp, fp, fn
            output, target = res, y

        if return_hard_tp_fp_fn:
            ret = (*ret, *self._hard_tp_fp_fn(output, target))
        return ret

    @staticmethod
    def _hard_tp_fp_fn(output, target):
        """Return batch-summed hard (tp, fp, fn) per foreground class, each (1, c - 1).

        :param output: network output; dim 1 indexes the classes.
        :param target: label map assumed to be (b, 1, x, y(, z)).
        """
        with torch.no_grad():
            num_classes = output.shape[1]
            # A channel-wise softmax does not change the argmax, so the hard
            # segmentation is taken from the raw output directly.
            output_seg = output.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))

            counts_shape = (target.shape[0], num_classes - 1)
            # Allocate on the prediction's device directly rather than on the
            # CPU followed by a copy.
            tp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fn_hard = torch.zeros(counts_shape, device=output_seg.device)
            for c in range(1, num_classes):
                # Build each 0/1 class mask once and reuse it for tp/fp/fn.
                pred_c = (output_seg == c).float()
                gt_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

            # Sum over the batch; keep a leading axis of size 1.
            return (tp_hard.sum(0)[None], fp_hard.sum(0)[None],
                    fn_hard.sum(0)[None])