Пример #1
0
class DefSegLoss(TorchLoss):
    """Weighted sum of a flow smoothness term, a flow magnitude term and a label Dice term.

    ``weights`` is indexed as ``[grad, deform, label_dice]``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create the three sub-losses.

        Args:
            weights: multipliers for the grad, deform and label Dice terms, in that order.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weights = weights
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and add it to the running total.

        predicted = (warped template, warped maps, pred maps, flow); only
        ``predicted[0]`` and ``predicted[3]`` are read here.

        Args:
            predicted: indexable collection of network outputs laid out as above.
            outputs: pair ``(label, template)``; ``template`` is not used by this loss.

        Returns:
            The weighted sum of the three terms. ``.item()`` is called on it,
            so it is presumably a scalar tensor (depends on the sub-losses).
        """
        # Unpacking still enforces that ``outputs`` has exactly two elements.
        label, _ = outputs
        flow = predicted[3]

        grad_loss = self.grad_loss.cumulate(flow, None)
        # MSE against an all-zero field penalises the flow magnitude.
        # zeros_like keeps the target on the flow's own device and dtype
        # instead of forcing it onto the current CUDA device.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)

        loss = (
            grad_loss * self.weights[0]
            + deform_loss * self.weights[1]
            + label_dice_loss * self.weights[2]
        )

        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset loss of the same class with the same configuration."""
        new_loss = self.__class__(
            weights=self.weights,
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string built from the descriptions of the three sub-losses."""
        return "{}, deform {}, label {}".format(
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
            self.label_dice_loss.description(),
        )

    def reset(self):
        """Reset this loss (via the base class) and every sub-loss."""
        super().reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
Пример #2
0
class DefSegLoss(TorchLoss):
    """Composite loss over predicted maps, warped outputs and the flow field.

    ``weights`` is indexed as::

        [pred-map BCE, pred-map Dice, pred-map MSE,
         label Dice, label MSE,
         template Dice, template MSE,
         grad, deform]

    While ``self.epoch <= 10`` a warm-up weight vector is used in which only
    the pred-map BCE term is active. ``epoch`` is incremented by every call to
    ``reset()``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create all sub-losses.

        Args:
            weights: nine multipliers, laid out as described on the class.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.weights = weights

        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()

        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()

        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()

        # Number of reset() calls so far; drives the warm-up in cumulate().
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and add it to the running total.

        predicted = (warped template, warped maps, pred maps, flow)

        Args:
            predicted: indexable collection of network outputs laid out as above.
            outputs: pair ``(label, template)``.

        Returns:
            The weighted sum of all terms. ``.item()`` is called on it, so it
            is presumably a scalar tensor (depends on the sub-losses).
        """
        label, template = outputs
        if self.epoch <= 10:
            # Warm-up: only the pred-map BCE term contributes.
            weights = [1, 0, 0, 0, 0, 0, 0, 0, 0]
        else:
            weights = self.weights

        pred_maps = predicted[2]
        flow = predicted[3]

        # Every sub-loss is cumulated even when its weight is zero so that
        # its own running statistics stay available for description().
        pred_map_bce_loss = self.pred_maps_bce_loss.cumulate(pred_maps, label)
        pred_map_dice_loss = self.pred_maps_dice_loss.cumulate(pred_maps, label)
        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(pred_maps, label)
        pred_map_loss = (
            weights[0] * pred_map_bce_loss
            + weights[1] * pred_map_dice_loss
            + weights[2] * pred_map_mse_loss
        )

        grad_loss = self.grad_loss.cumulate(flow, None)
        # MSE against an all-zero field penalises the flow magnitude.
        # zeros_like keeps the target on the flow's own device and dtype
        # instead of forcing it onto the current CUDA device.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)
        label_mse_loss = self.label_mse_loss.cumulate(predicted[0], label)
        label_loss = weights[3] * label_dice_loss + weights[4] * label_mse_loss

        template_dice_loss = self.template_dice_loss.cumulate(predicted[1], template)
        template_mse_loss = self.template_mse_loss.cumulate(predicted[1], template)
        template_loss = weights[5] * template_dice_loss + weights[6] * template_mse_loss

        # The grad and deform terms use the selected ``weights`` (not
        # ``self.weights``) so the warm-up vector applies to them as well.
        loss = (
            pred_map_loss
            + label_loss
            + template_loss
            + grad_loss * weights[7]
            + deform_loss * weights[8]
        )
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a new loss of the same class and configuration.

        The new instance has ``reset()`` called once, so its ``epoch`` is 1.
        """
        new_loss = self.__class__(
            weights=self.weights,
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string: the total plus the per-term descriptions.

        The template losses are accumulated but not included in the summary.
        """
        return "total {:.4f}, pred map {}, pred map {}, pred map {}, label {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_bce_loss.description(),
            self.pred_maps_dice_loss.description(),
            self.pred_maps_mse_loss.description(),
            self.label_dice_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset this loss and every sub-loss, and advance the epoch counter."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.template_mse_loss.reset()
        self.template_dice_loss.reset()
Пример #3
0
class DefLoss(TorchLoss):
    """Composite loss over predicted maps, warped labels, the flow field and an atlas.

    ``weights`` is indexed as ``[pred map, label, grad, deform, atlas]``. The
    label weight is applied to both the affine and the deformable label MSE,
    and the atlas weight to both atlas MSE terms.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create all sub-losses.

        Args:
            weights: five multipliers, laid out as described on the class.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.weights = weights

        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()

        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()

        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()

        self.label_dice_loss_affine = DiceLoss()
        self.label_mse_loss_affine = MSELoss()

        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()

        self.atlas_mse_loss = MSELoss()
        self.atlas_mse_loss_affine = MSELoss()

        # Number of reset() calls so far.
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the weighted loss for one batch and add it to the running total.

        predicted = affine_warped_template, warped_template, pred_maps,
        preint_flow, warped_image, warped_label, affine_warped_label,
        batch_atlas

        Args:
            predicted: indexable collection of network outputs laid out as above.
            outputs: pair ``(label, template)``; ``template`` is not used by this loss.

        Returns:
            The weighted sum of all terms. ``.item()`` is called on it, so it
            is presumably a scalar tensor (depends on the sub-losses).
        """
        # Unpacking still enforces that ``outputs`` has exactly two elements.
        label, _ = outputs
        weights = self.weights
        flow = predicted[3]
        atlas = predicted[7]

        grad_loss = self.grad_loss.cumulate(flow, None)
        # MSE against an all-zero field penalises the flow magnitude.
        # zeros_like keeps the target on the flow's own device and dtype
        # instead of forcing it onto the current CUDA device.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(predicted[2], label)
        pred_map_loss = weights[0] * pred_map_mse_loss

        affine_label_mse_loss = self.label_mse_loss_affine.cumulate(predicted[0], label)
        affine_label_loss = weights[1] * affine_label_mse_loss

        label_mse_loss = self.label_mse_loss.cumulate(predicted[1], label)
        label_loss = weights[1] * label_mse_loss + affine_label_loss

        atlas_mse_loss = self.atlas_mse_loss.cumulate(predicted[5], atlas)
        atlas_mse_loss_affine = self.atlas_mse_loss_affine.cumulate(predicted[6], atlas)
        atlas_loss = atlas_mse_loss * weights[4] + atlas_mse_loss_affine * weights[4]

        loss = (
            label_loss
            + grad_loss * weights[2]
            + deform_loss * weights[3]
            + pred_map_loss
            + atlas_loss
        )

        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a new loss of the same class and configuration.

        The new instance has ``reset()`` called once, so its ``epoch`` is 1.
        """
        new_loss = self.__class__(
            weights=self.weights,
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return a summary string: the total plus four per-term descriptions.

        The affine label and atlas terms are accumulated but not included in
        the summary.
        """
        return "total {:.4f}, pred map {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_mse_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset this loss and every sub-loss, and advance the epoch counter."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()

        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.label_dice_loss_affine.reset()
        self.label_mse_loss_affine.reset()
        # The atlas losses are cumulated on every batch, so they must be
        # cleared here too; the template losses are reset for completeness.
        self.atlas_mse_loss.reset()
        self.atlas_mse_loss_affine.reset()
        self.template_dice_loss.reset()
        self.template_mse_loss.reset()