Example #1
    def forward(self, x, y):
        shp_x = x.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(x)  # (b,c,x,y,z)
        else:
            net_output = x

        gt_onehot = gt2onehot(net_output, y, axes)  # (b,c,x,y,z)

        intersection = sum_tensor(net_output * gt_onehot, axes, keepdim=False)
        ground_o = sum_tensor(gt_onehot**2, axes, keepdim=False)
        pred_o = sum_tensor(net_output**2, axes, keepdim=False)
        dc = 2.0 * (intersection + self.smooth) / (ground_o + pred_o +
                                                   self.smooth)

        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc
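
A note on the helper used throughout: every snippet in this collection calls sum_tensor, which is not shown. A minimal sketch of what it might look like, assuming it simply sums over the requested axes (consistent with how it is called above); the name sum_tensor_sketch is hypothetical:

import torch

def sum_tensor_sketch(inp: torch.Tensor, axes, keepdim: bool = False) -> torch.Tensor:
    # Reduce over all requested axes at once; mirrors how sum_tensor is used above.
    return inp.sum(dim=tuple(int(a) for a in axes), keepdim=keepdim)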
Example #2
    def run_online_evaluation(self, output, target):
        with torch.no_grad():
            num_classes = output.shape[1]
            output_softmax = softmax_helper(output)
            output_seg = output_softmax.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target == c).float(),
                    axes=axes)
                fp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target != c).float(),
                    axes=axes)
                fn_hard[:, c - 1] = sum_tensor(
                    (output_seg != c).float() * (target == c).float(),
                    axes=axes)

            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            self.online_eval_foreground_dc.append(
                list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
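
The per-batch counts appended here are presumably aggregated at the end of an epoch. A hedged sketch of what that aggregation could look like (the helper name and the exact reduction are assumptions, not part of the snippet above):

import numpy as np

def summarize_online_eval_sketch(online_eval_tp, online_eval_fp, online_eval_fn):
    # Sum the per-batch counts, then form a "global" Dice per foreground class.
    tp = np.sum(online_eval_tp, 0)
    fp = np.sum(online_eval_fp, 0)
    fn = np.sum(online_eval_fn, 0)
    dice_per_class = 2 * tp / (2 * tp + fp + fn + 1e-8)
    return dice_per_class.mean(), list(dice_per_class)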
Example #3
    def run_online_evaluation(self, output, target):
        with torch.no_grad():
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

            # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
            #                                         axes, None)
            # print_if_rank0("before allgather", tp_hard.shape)
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)

        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)
        self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
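
awesome_allgather_function is not shown here; the calls above imply it gathers a tensor from every DDP rank along dim 0 while staying usable inside autograd. A hedged sketch of such a function (an assumption about its internals, not the original implementation):

import torch
import torch.distributed as dist

class AllGatherSketch(torch.autograd.Function):
    # Gather a tensor from all ranks along dim 0 and route gradients back per rank.

    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice corresponding to its own input.
        rank = dist.get_rank()
        chunk = grad_output.shape[0] // dist.get_world_size()
        return grad_output[rank * chunk:(rank + 1) * chunk]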
Example #4
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None,
                          square_nominator=False, square_denom=False):
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:] # this is the case when use_bg=False
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    weights = torch.ones(tp.shape)
    weights[0] = background_weight
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights

    nominator = tp

    if square_nominator:
        nominator = nominator ** 2

    if square_denom:
        denom = 2 * tp ** 2 + fp ** 2 + fn ** 2
    else:
        denom = 2 * tp + fp + fn

    result = (- ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights).mean()
    return result
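
A small usage sketch for the function above. The shapes are illustrative; it assumes gt is already one-hot with the same shape as the softmaxed output (which the element-wise products require) and that sum_tensor is in scope:

import torch

b, c = 2, 3
net_output = torch.softmax(torch.randn(b, c, 16, 16, 16), dim=1)
gt = torch.zeros_like(net_output)
gt.scatter_(1, torch.randint(0, c, (b, 1, 16, 16, 16)), 1)  # one-hot ground truth
loss = soft_dice_per_batch_2(net_output, gt, background_weight=0.5)
print(loss)  # negative, class-weighted soft Dice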
Example #5
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(tp, dim=1)),
                         dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(fp, dim=1)),
                         dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(fn, dim=1)),
                         dim=1)

    if square:
        tp = tp**2
        fp = fp**2
        fn = fn**2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
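
A short usage sketch with dummy tensors (assumes sum_tensor is importable). The label map is expanded and one-hot encoded inside the function, so only the softmaxed prediction and the integer labels are needed:

import torch

net_output = torch.softmax(torch.randn(2, 4, 8, 8, 8), dim=1)  # (b, c, x, y, z)
gt = torch.randint(0, 4, (2, 8, 8, 8))                          # label map (b, x, y, z)

tp, fp, fn = get_tp_fp_fn(net_output, gt)  # each of shape (b, c)

smooth = 1e-5
dice_per_class = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
loss = -dice_per_class.mean()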
Example #6
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1., square_nominator=False, square_denom=False):
    axes = tuple(range(2, len(net_output.size())))
    if square_nominator:
        intersect = sum_tensor((net_output * gt) ** 2, axes, keepdim=False)
    else:
        intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    if square_denom:
        denom = sum_tensor(net_output ** 2 + gt ** 2, axes, keepdim=False)
    else:
        denom = sum_tensor(net_output + gt, axes, keepdim=False)
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
    return result
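
A quick sanity-check sketch (assumes sum_tensor is in scope): for a prediction that exactly matches a one-hot ground truth, the returned loss is -1 per class with the default smoothing:

import torch

gt = torch.zeros(1, 2, 4, 4)
gt[:, 0] = 1                       # one-hot: every pixel belongs to class 0
print(soft_dice(gt.clone(), gt))   # tensor(-1.) -- perfect overlap, negated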
Example #7
    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(x.shape, y.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            gt = y.long()
            y_onehot = torch.zeros(shp_x)
            if x.device.type == "cuda":
                y_onehot = y_onehot.cuda(x.device.index)
            y_onehot.scatter_(1, gt, 1)

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        if not self.do_bg:
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask,
                                        self.square)

        # GDL weight computation, we use 1/V
        volumes = sum_tensor(
            y_onehot, axes) + 1e-6  # add some eps to prevent div by zero

        if self.square_volumes:
            volumes = volumes**2

        # apply weights
        tp = tp / volumes
        fp = fp / volumes
        fn = fn / volumes

        # sum over classes
        if self.batch_dice:
            axis = 0
        else:
            axis = 1

        tp = tp.sum(axis, keepdim=False)
        fp = fp.sum(axis, keepdim=False)
        fn = fn.sum(axis, keepdim=False)

        # compute dice
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        dc = dc.mean()

        return -dc
Example #8
    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        shp_y = y.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                y = y.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(x.shape, y.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = y
            else:
                y = y.long()
                y_onehot = torch.zeros(shp_x)
                if x.device.type == "cuda":
                    y_onehot = y_onehot.cuda(x.device.index)
                y_onehot.scatter_(1, y, 1)

        intersect = x * y_onehot
        # values in the denominator get smoothed
        denominator = x**2 + y_onehot**2

        # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
        # squaring)
        intersect = sum_tensor(intersect, axes, False) + self.smooth
        denominator = sum_tensor(denominator, axes, False) + self.smooth

        dc = 2 * intersect / denominator

        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()

        return -dc
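
The label-map to one-hot conversion via scatter_ recurs in most of these forward passes. As a standalone, hedged helper (hypothetical name; same shape convention, with gt given as (b, 1, x, y(, z)) like y after the view() above):

import torch

def to_onehot_sketch(gt: torch.Tensor, num_classes: int) -> torch.Tensor:
    shp = (gt.shape[0], num_classes, *gt.shape[2:])
    y_onehot = torch.zeros(shp, device=gt.device)
    y_onehot.scatter_(1, gt.long(), 1)  # write a 1 into the channel given by each label
    return y_onehot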
Example #9
    def forward(self, net_output, gt, loss_mask=None):
        shp_x = net_output.shape
        shp_y = gt.shape
        # class_num = shp_x[1]

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                gt = gt.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(shp_x)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            softmax_output = self.apply_nonlin(net_output)
        else:
            softmax_output = net_output

        # one hot encoding of the background ("no object") voxels
        bg_onehot = 1 - y_onehot
        squared_error = (y_onehot - softmax_output)**2
        specificity_part = sum_tensor(squared_error * y_onehot, axes) / (
            sum_tensor(y_onehot, axes) + self.smooth)
        sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (
            sum_tensor(bg_onehot, axes) + self.smooth)

        ss = self.r * specificity_part + (1 - self.r) * sensitivity_part

        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]
        ss = ss.mean()

        return ss
Example #10
def mget_tp_fp_fn(net_output,
                  gt,
                  focal_conduct,
                  axes=None,
                  mask=None,
                  square=False):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    # focal_conduct arrives flattened as [N, C] with N = B*H*W*D, while
    # y_onehot is [B, C, H, W, D]
    # first reshape [N, C] to [B, H, W, D, C]
    focal_conduct = torch.reshape(
        focal_conduct,
        (net_output.shape[0], net_output.shape[2], net_output.shape[3],
         net_output.shape[4], net_output.shape[1]))
    # then move the channel axis forward: [B, H, W, D, C] -> [B, C, H, W, D]
    focal_conduct = focal_conduct.permute(0, 4, 1, 2, 3)

    # this is like focal Tversky loss, but it would suppress too much
    # focal_fp = net_output * (1 - y_onehot) * focal_conduct2

    focal_fp = net_output * (1 - y_onehot)
    focal_fn = (1 - net_output) * y_onehot * focal_conduct
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(tp, dim=1)),
                         dim=1)
        focal_fp = torch.stack(tuple(x_i * mask[:, 0]
                                     for x_i in torch.unbind(focal_fp, dim=1)),
                               dim=1)
        focal_fn = torch.stack(tuple(x_i * mask[:, 0]
                                     for x_i in torch.unbind(focal_fn, dim=1)),
                               dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(fp, dim=1)),
                         dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0]
                               for x_i in torch.unbind(fn, dim=1)),
                         dim=1)

    if square:
        tp = tp**2
        focal_fp = focal_fp**2
        focal_fn = focal_fn**2
        fp = fp**2
        fn = fn**2

    # print(tp.size())
    tp = sum_tensor(tp, axes, keepdim=False)
    focal_fp = sum_tensor(focal_fp, axes, keepdim=False)
    focal_fn = sum_tensor(focal_fn, axes, keepdim=False)
    # print(tp.size())
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return focal_fp, focal_fn, tp, fp, fn
Example #11
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax,
                                            target[i],
                                            axes,
                                            mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target == c).float(),
                        axes=axes)
                    fp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target != c).float(),
                        axes=axes)
                    fn_hard[:, c - 1] = sum_tensor(
                        (output_seg != c).float() * (target == c).float(),
                        axes=axes)

                # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
                #                                         axes, None)
                # print_if_rank0("before allgather", tp_hard.shape)
                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(
                    tp_hard.detach().cpu().numpy().sum(0),
                    fp_hard.detach().cpu().numpy().sum(0),
                    fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
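
self.ds_loss_weights is not defined in this snippet. A hedged sketch of one common way such deep-supervision weights are constructed (halving per resolution level, then normalizing; an assumption, not necessarily what this trainer uses):

import numpy as np

num_ds_outputs = 5
weights = np.array([1 / (2 ** i) for i in range(num_ds_outputs)])
ds_loss_weights = weights / weights.sum()  # highest-resolution output gets the largest weight
print(ds_loss_weights)                     # approx. [0.516 0.258 0.129 0.065 0.032]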
Example #12
    def forward(self, x, y=None, return_hard_tp_fp_fn=False):
        res = super(Generic_UNet_DP,
                    self).forward(x)  # regular Generic_UNet forward pass

        if y is None:
            return res
        else:
            # compute ce loss
            if self._deep_supervision and self.do_ds:
                ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)]
                tps = []
                fps = []
                fns = []

                res_softmax = softmax_helper(res[0])
                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0])
                tps.append(tp)
                fps.append(fp)
                fns.append(fn)
                for i in range(1, len(y)):
                    ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
                    res_softmax = softmax_helper(res[i])
                    tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i])
                    tps.append(tp)
                    fps.append(fp)
                    fns.append(fn)
                ret = ce_losses, tps, fps, fns
            else:
                ce_loss = self.ce_loss(res, y).unsqueeze(0)

                # tp fp and fn need the output to be softmax
                res_softmax = softmax_helper(res)

                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y)

                ret = ce_loss, tp, fp, fn

            if return_hard_tp_fp_fn:
                if self._deep_supervision and self.do_ds:
                    output = res[0]
                    target = y[0]
                else:
                    target = y
                    output = res

                with torch.no_grad():
                    num_classes = output.shape[1]
                    output_softmax = softmax_helper(output)
                    output_seg = output_softmax.argmax(1)
                    target = target[:, 0]
                    axes = tuple(range(1, len(target.shape)))
                    tp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fn_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    for c in range(1, num_classes):
                        tp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target == c).float(),
                            axes=axes)
                        fp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target != c).float(),
                            axes=axes)
                        fn_hard[:, c - 1] = sum_tensor(
                            (output_seg != c).float() * (target == c).float(),
                            axes=axes)

                    tp_hard = tp_hard.sum(0, keepdim=False)[None]
                    fp_hard = fp_hard.sum(0, keepdim=False)[None]
                    fn_hard = fn_hard.sum(0, keepdim=False)[None]

                    ret = *ret, tp_hard, fp_hard, fn_hard
            return ret
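
The tuple returned here (ce_loss, tp, fp, fn, optionally the hard counts) is presumably reduced by the trainer after DataParallel concatenates the per-GPU results along dim 0. A hedged sketch for the non-deep-supervision case (function name and smoothing are assumptions):

import torch

def combine_dp_outputs_sketch(ce_loss, tp, fp, fn, smooth=1e-5):
    # ce_loss: (num_gpus,) due to unsqueeze(0); tp/fp/fn: (total_batch, num_classes)
    tp, fp, fn = tp.sum(0), fp.sum(0), fn.sum(0)
    dc = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    return ce_loss.mean() - dc.mean()  # cross-entropy plus negative soft Dice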
Example #13
def soft_dice_per_batch_2(net_output,
                          gt,
                          smooth=1.,
                          smooth_in_nom=1.,
                          background_weight=1,
                          rebalance_weights=None,
                          square_nominator=False,
                          square_denom=False):
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[
            1:]  # this is the case when use_bg=False
    # print('\nrebalance_weights is:',rebalance_weights)
    #rebalance_weights is: None
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    # print('\naxes is:',axes)
    #axes is: (0, 2, 3, 4)
    # print('\nnet_output shape is:',net_output.shape)
    # print('\ngt shape is:',gt.shape)
    #net_output shape is: torch.Size([8, 2, 192, 192, 48])
    #gt shape is: torch.Size([8, 2, 192, 192, 48])
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    # print('\ntp is:',tp)
    #tp shape is: torch.Size([2])
    #tp is: tensor([62684.4570, 82510.1562], device='cuda:4', grad_fn=<SumBackward2>)
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    # print('\nfn is:',fn)
    #fn shape is: torch.Size([2])
    #fn is: tensor([195664.5312, 103144.8438], device='cuda:4', grad_fn=<SumBackward2>)
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    # print('\nfp is:',fp)
    #fp shape is: torch.Size([2])
    #fp is: tensor([3610596., 6475380.], device='cuda:4', grad_fn=<SumBackward2>)
    weights = torch.ones(tp.shape)
    # print('\nweights shape is:',weights.shape)
    #weights shape is: torch.Size([2])
    weights[0] = background_weight
    # print('\nbackground_weight is:',background_weight)
    #background_weight is: 1
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights

    nominator = tp

    if square_nominator:
        nominator = nominator**2

    if square_denom:
        denom = 2 * tp**2 + fp**2 + fn**2
    else:
        denom = 2 * tp + fp + fn

    # result_1=(- ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights)
    # print('\nresult_1 is:',result_1)
    #result_1 is: tensor([-0.0616, -0.0038], device='cuda:4', grad_fn=<MulBackward0>)
    dice_1 = (((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights)
    # print('\ndice_1 is:',dice_1)
    result_1 = torch.pow((-torch.log(dice_1[0])), 0.3) * 0.4 + torch.pow(
        (-torch.log(dice_1[1])), 0.3) * 0.6
    # print('\nresult_1 is:',result_1)

    # result = (- ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights).mean()
    # print('\nresult is:',result)
    # result is: tensor(-0.0327, device='cuda:4', grad_fn=<MeanBackward0>)
    # Note: the commented-out result above is the negative soft Dice; the value actually
    # returned is the class-weighted exponential-logarithmic combination in result_1.
    return result_1
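
The active return value replaces the plain negative soft Dice with a class-weighted exponential-logarithmic term. As a standalone sketch of that expression (weights 0.4/0.6 and exponent 0.3 taken from the code above; the function name is hypothetical):

import torch

def exp_log_dice_sketch(dice_per_class: torch.Tensor,
                        class_weights=(0.4, 0.6), gamma=0.3) -> torch.Tensor:
    # Matches: (-log(dice_0))**gamma * w_0 + (-log(dice_1))**gamma * w_1
    terms = torch.pow(-torch.log(dice_per_class), gamma)
    weights = torch.as_tensor(class_weights, dtype=terms.dtype, device=terms.device)
    return (terms * weights).sum()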