Пример #1
0
class DefSegLoss(TorchLoss):
    """Weighted sum of a ``Grad`` flow term, a flow MSE-to-zero term and a label Dice term.

    Each component loss object is driven through its own ``cumulate`` /
    ``reset``; this object additionally accumulates the weighted total in
    ``_cum_loss`` / ``_count``.

    Args:
        weights: Three multipliers applied, in order, to the ``Grad`` loss on
            the flow, the MSE between the flow and zero, and the Dice loss
            between ``predicted[0]`` and the label.
        penalty: Forwarded to ``Grad``.
        loss_mult: Forwarded to ``Grad``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        super().__init__()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weights = weights
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and update running totals.

        Args:
            predicted: ``(warped template, warped maps, pred maps, flow)``.
                Only index 0 (compared against the label) and index 3 (the
                flow) are used.
            outputs: ``(label, template)``; ``template`` is unused here.

        Returns:
            The loss tensor (still attached to the autograd graph).
        """
        label, _template = outputs
        flow = predicted[3]

        grad_loss = self.grad_loss.cumulate(flow, None)
        # Regress the flow towards zero. zeros_like keeps the target on the
        # flow's own device/dtype rather than forcing a float32 CUDA tensor.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)

        loss = (
            grad_loss * self.weights[0]
            + deform_loss * self.weights[1]
            + label_dice_loss * self.weights[2]
        )

        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a fresh, reset instance with the same configuration."""
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult, weights=self.weights
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string built from the component descriptions."""
        return "{}, deform {}, label {}".format(
            self.grad_loss.description(), self.deform_mse_loss.description(),
            self.label_dice_loss.description(),
        )

    def reset(self):
        """Reset the running totals of this object and of every component loss."""
        super().reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
Пример #2
0
class DefLoss(TorchLoss):
    """Weighted sum of MSE terms over predicted maps, warped templates and labels, plus flow terms.

    Args:
        weights: Multipliers indexed as follows:
            0 - MSE between pred maps and the label;
            1 - MSE between the affine-warped and the deformably-warped
                template and the label (shared by both terms);
            2 - ``Grad`` loss on the flow;
            3 - MSE between the flow and zero;
            4 - MSE between the warped / affine-warped label and the batch
                atlas (shared by both terms).
        penalty: Forwarded to ``Grad``.
        loss_mult: Forwarded to ``Grad``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        super().__init__()
        self.weights = weights

        self.pred_maps_mse_loss = MSELoss()

        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        self.label_mse_loss = MSELoss()
        self.label_mse_loss_affine = MSELoss()

        self.atlas_mse_loss = MSELoss()
        self.atlas_mse_loss_affine = MSELoss()

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and update running totals.

        Args:
            predicted: Indexable sequence laid out as
                0 affine_warped_template, 1 warped_template, 2 pred_maps,
                3 preint_flow, 4 warped_image, 5 warped_label,
                6 affine_warped_label, 7 batch_atlas.
                Index 4 is not used here.
            outputs: The label target.

        Returns:
            The loss tensor (still attached to the autograd graph).
        """
        label = outputs
        weights = self.weights
        flow = predicted[3]

        grad_loss = self.grad_loss.cumulate(flow, None)
        # Regress the flow towards zero. zeros_like keeps the target on the
        # flow's own device/dtype rather than forcing a float32 CUDA tensor.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = weights[0] * pred_map_mse_loss

        # The affine and deformable template-vs-label terms share weights[1].
        affine_label_mse_loss = self.label_mse_loss_affine.cumulate(
            predicted[0], label)
        affine_label_loss = weights[1] * affine_label_mse_loss

        label_mse_loss = self.label_mse_loss.cumulate(predicted[1], label)
        label_loss = weights[1] * label_mse_loss + affine_label_loss

        # Warped labels (deformable and affine) vs. the atlas share weights[4].
        atlas_mse_loss = self.atlas_mse_loss.cumulate(predicted[5], predicted[7])
        atlas_mse_loss_affine = self.atlas_mse_loss_affine.cumulate(
            predicted[6], predicted[7])
        atlas_loss = atlas_mse_loss * weights[4] + atlas_mse_loss_affine * weights[4]

        loss = (label_loss + grad_loss * weights[2] + deform_loss * weights[3]
                + pred_map_loss + atlas_loss)

        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a fresh, reset instance with the same configuration."""
        new_loss = self.__class__(penalty=self.grad_loss.penalty,
                                  loss_mult=self.grad_loss.loss_mult,
                                  weights=self.weights)
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string: running total plus selected component descriptions.

        NOTE(review): the affine-label and atlas terms are not included here.
        """
        return "total {:.4f}, pred map {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_mse_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset the running totals of this object and of every component loss."""
        super().reset()
        self.pred_maps_mse_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_mse_loss.reset()
        self.label_mse_loss_affine.reset()
        self.atlas_mse_loss.reset()
        self.atlas_mse_loss_affine.reset()
Пример #3
0
class DefSegLoss(TorchLoss):
    """Multi-term loss with a warm-up phase in which only the pred-map BCE term is active.

    Args:
        weights: Nine multipliers, by index:
            0 pred-map BCE, 1 pred-map Dice, 2 pred-map MSE,
            3 label Dice, 4 label MSE, 5 template Dice, 6 template MSE,
            7 ``Grad`` loss on the flow, 8 MSE between the flow and zero.
        penalty: Forwarded to ``Grad``.
        loss_mult: Forwarded to ``Grad``.
        warmup_epochs: While ``self.epoch <= warmup_epochs``, every term
            except the pred-map BCE (weight 1) is weighted by 0. Defaults to
            10, the value that used to be hard-coded.
    """

    # Weights used during warm-up: only the pred-map BCE term contributes.
    _WARMUP_WEIGHTS = (1, 0, 0, 0, 0, 0, 0, 0, 0)

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None,
                 warmup_epochs: int = 10):
        super().__init__()
        self.weights = weights
        self.warmup_epochs = warmup_epochs

        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()

        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()

        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()

        # Incremented by every reset(); compared against warmup_epochs.
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and update running totals.

        Args:
            predicted: ``(warped template, warped maps, pred maps, flow)``.
                Index 0 is compared with the label, index 1 with the
                template, index 2 with the label, and index 3 is the flow.
            outputs: ``(label, template)``.

        Returns:
            The loss tensor (still attached to the autograd graph).
        """
        label, template = outputs
        if self.epoch <= self.warmup_epochs:
            weights = self._WARMUP_WEIGHTS
        else:
            weights = self.weights

        pred_map_bce_loss = self.pred_maps_bce_loss.cumulate(predicted[2], label)
        pred_map_dice_loss = self.pred_maps_dice_loss.cumulate(predicted[2], label)
        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = (weights[0] * pred_map_bce_loss
                         + weights[1] * pred_map_dice_loss
                         + weights[2] * pred_map_mse_loss)

        flow = predicted[3]
        grad_loss = self.grad_loss.cumulate(flow, None)
        # Regress the flow towards zero. zeros_like keeps the target on the
        # flow's own device/dtype rather than forcing a float32 CUDA tensor.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)
        label_mse_loss = self.label_mse_loss.cumulate(predicted[0], label)
        label_loss = weights[3] * label_dice_loss + weights[4] * label_mse_loss

        template_dice_loss = self.template_dice_loss.cumulate(predicted[1], template)
        template_mse_loss = self.template_mse_loss.cumulate(predicted[1], template)
        template_loss = weights[5] * template_dice_loss + weights[6] * template_mse_loss

        # Use the phase-dependent `weights` (not self.weights) so the flow
        # terms are switched off during warm-up like every other term.
        loss = (pred_map_loss + label_loss + template_loss
                + grad_loss * weights[7] + deform_loss * weights[8])
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a fresh, reset instance with the same configuration.

        NOTE(review): the epoch counter is not copied; the new instance
        starts its own count (1 after the reset below).
        """
        new_loss = self.__class__(penalty=self.grad_loss.penalty,
                                  loss_mult=self.grad_loss.loss_mult,
                                  weights=self.weights,
                                  warmup_epochs=self.warmup_epochs)
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string: running total plus component descriptions (template terms omitted)."""
        return "total {:.4f}, pred map {}, pred map {}, pred map {}, label {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_bce_loss.description(),
            self.pred_maps_dice_loss.description(),
            self.pred_maps_mse_loss.description(),
            self.label_dice_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Advance the epoch counter and reset this object's and every component's running totals."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.template_mse_loss.reset()
        self.template_dice_loss.reset()
Пример #4
0
class DefSegLoss(TorchLoss):
    """MSE between the warped template and the target, plus a weighted ``Grad`` term on the flow.

    Args:
        template: Template image. A NumPy array is converted to a float32
            CUDA tensor with a leading axis of size 1; anything else is
            stored as-is. It is kept for ``new()`` but not used by
            ``cumulate``.
        penalty: Forwarded to ``Grad``.
        loss_mult: Forwarded to ``Grad``.
        weight: Multiplier on the ``Grad`` term.
    """

    def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
        super().__init__()
        self.mse_loss = MSELoss()
        # Constructed and reset, but no longer part of the loss (the former
        # BCE term on the predicted maps was disabled).
        self.bce_loss = BCELoss(logit=True)
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weight = weight
        if isinstance(template, np.ndarray):
            self.template = torch.from_numpy(
                template).float().cuda().unsqueeze(0)
        else:
            self.template = template

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and update running totals.

        Args:
            predicted: ``(warped template, pred maps, flow)``; index 1 is
                unused.
            outputs: Target compared with the warped template.

        Returns:
            The loss tensor (still attached to the autograd graph).
        """
        mse_loss = self.mse_loss.cumulate(predicted[0], outputs)
        grad_loss = self.grad_loss.cumulate(predicted[2], None)
        loss = mse_loss + grad_loss * self.weight
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a fresh, reset instance sharing the same template and configuration."""
        new_loss = self.__class__(template=self.template,
                                  penalty=self.grad_loss.penalty,
                                  loss_mult=self.grad_loss.loss_mult,
                                  weight=self.weight)
        new_loss.reset()
        return new_loss

    def description(self):
        """Return the MSE and ``Grad`` component descriptions."""
        return "{}, {}".format(self.mse_loss.description(),
                               self.grad_loss.description())

    def reset(self):
        """Reset the running totals of this object and of every component loss."""
        super().reset()
        self.mse_loss.reset()
        self.bce_loss.reset()
        self.grad_loss.reset()