class DefSegLoss(TorchLoss):
    """Weighted sum of three terms computed from a predicted flow and a warped template.

    The terms, in the order of their entries in ``weights``:

    0. ``Grad`` loss on the flow (``predicted[3]``),
    1. MSE between the flow and an all-zero tensor of the same shape,
    2. Dice loss between ``predicted[0]`` and the label.

    NOTE(review): this file defines more than one class named ``DefSegLoss``;
    at import time a later definition replaces an earlier one -- confirm which
    one callers actually get.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create the sub-losses.

        Args:
            weights: multipliers for (grad, deform, label dice); indices 0-2 are read.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weights = weights
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the combined loss for one batch and add it to the running total.

        Args:
            predicted: indexable as (warped template, warped maps, pred maps, flow);
                only items 0 and 3 are read here.
            outputs: a pair ``(label, template)``; only ``label`` is used.

        Returns:
            The combined loss tensor for this batch.
        """
        # Unpacking is kept so a malformed ``outputs`` still fails loudly.
        label, _ = outputs
        flow = predicted[3]

        grad_term = self.grad_loss.cumulate(flow, None)
        # zeros_like keeps the target on the flow's own device and dtype, so the
        # loss also works on CPU and on non-default CUDA devices.
        deform_term = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))
        dice_term = self.label_dice_loss.cumulate(predicted[0], label)

        loss = (
            grad_term * self.weights[0]
            + deform_term * self.weights[1]
            + dice_term * self.weights[2]
        )
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance built with the same configuration."""
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weights=self.weights,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return the sub-loss descriptions joined into one string."""
        return "{}, deform {}, label {}".format(
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
            self.label_dice_loss.description(),
        )

    def reset(self):
        """Reset this loss's accumulators and those of every sub-loss."""
        super().reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
class DefLoss(TorchLoss):
    """Weighted combination of MSE and ``Grad`` terms over a tuple of predictions.

    ``weights`` is read at indices 0-4:

    0. predicted-map MSE (``predicted[2]`` vs label),
    1. label MSE, applied to both ``predicted[1]`` and ``predicted[0]`` vs label,
    2. ``Grad`` loss on ``predicted[3]``,
    3. MSE between ``predicted[3]`` and zeros,
    4. atlas MSE, applied to both ``predicted[5]`` and ``predicted[6]`` vs
       ``predicted[7]``.
    """

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create the sub-losses.

        Args:
            weights: multipliers for the terms listed in the class docstring.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.weights = weights
        self.pred_maps_mse_loss = MSELoss()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()
        self.label_mse_loss = MSELoss()
        self.label_mse_loss_affine = MSELoss()
        self.atlas_mse_loss = MSELoss()
        self.atlas_mse_loss_affine = MSELoss()

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the combined loss for one batch and add it to the running total.

        Args:
            predicted: indexable as
                affine_warped_template 0, warped_template 1, pred_maps 2,
                preint_flow 3, warped_image 4, warped_label 5,
                affine_warped_label 6, batch_atlas 7.
                Item 4 is not read here.
            outputs: the label tensor.

        Returns:
            The combined loss tensor for this batch.
        """
        label = outputs
        weights = self.weights
        flow = predicted[3]
        atlas = predicted[7]

        grad_loss = self.grad_loss.cumulate(flow, None)
        # zeros_like keeps the target on the flow's own device and dtype, so the
        # loss also works on CPU and on non-default CUDA devices.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        pred_map_loss = weights[0] * self.pred_maps_mse_loss.cumulate(
            predicted[2], label
        )

        # The affine and non-affine label terms share weights[1].
        affine_label_loss = weights[1] * self.label_mse_loss_affine.cumulate(
            predicted[0], label
        )
        label_mse_loss = self.label_mse_loss.cumulate(predicted[1], label)
        label_loss = weights[1] * label_mse_loss + affine_label_loss

        # The affine and non-affine atlas terms share weights[4].
        atlas_mse_loss = self.atlas_mse_loss.cumulate(predicted[5], atlas)
        atlas_mse_loss_affine = self.atlas_mse_loss_affine.cumulate(
            predicted[6], atlas
        )
        atlas_loss = atlas_mse_loss * weights[4] + atlas_mse_loss_affine * weights[4]

        loss = (
            label_loss
            + grad_loss * weights[2]
            + deform_loss * weights[3]
            + pred_map_loss
            + atlas_loss
        )
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance built with the same configuration."""
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weights=self.weights,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return the total and a subset of the sub-loss descriptions.

        The affine-label and atlas terms are accumulated but not reported here.
        """
        return "total {:.4f}, pred map {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_mse_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset this loss's accumulators and those of every sub-loss."""
        super().reset()
        self.pred_maps_mse_loss.reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_mse_loss.reset()
        self.label_mse_loss_affine.reset()
        self.atlas_mse_loss.reset()
        self.atlas_mse_loss_affine.reset()
class DefSegLoss(TorchLoss):
    """Nine-term weighted loss with an initial warm-up phase.

    ``weights`` is read at indices 0-8:

    0-2. BCE, Dice and MSE between ``predicted[2]`` and the label,
    3-4. Dice and MSE between ``predicted[0]`` and the label,
    5-6. Dice and MSE between ``predicted[1]`` and the template,
    7.   ``Grad`` loss on ``predicted[3]``,
    8.   MSE between ``predicted[3]`` and zeros.

    While ``self.epoch <= WARMUP_EPOCHS`` the configured weights are replaced by
    a fixed vector that keeps only term 0. ``self.epoch`` is advanced by one on
    every ``reset()`` call.

    NOTE(review): this file defines more than one class named ``DefSegLoss``;
    at import time a later definition replaces an earlier one -- confirm which
    one callers actually get.
    """

    # Highest value of ``self.epoch`` for which the warm-up weights are used.
    WARMUP_EPOCHS = 10
    # Weights applied during warm-up: only the predicted-map BCE term is active.
    _WARMUP_WEIGHTS = (1, 0, 0, 0, 0, 0, 0, 0, 0)

    def __init__(self, weights: List[float], penalty="l2", loss_mult=None):
        """Create the sub-losses and start the epoch counter at zero.

        Args:
            weights: multipliers for the nine terms listed in the class docstring.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
        """
        super().__init__()
        self.weights = weights
        self.pred_maps_bce_loss = BCELoss(logit=False)
        self.pred_maps_dice_loss = DiceLoss()
        self.pred_maps_mse_loss = MSELoss()
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.deform_mse_loss = MSELoss()
        self.label_dice_loss = DiceLoss()
        self.label_mse_loss = MSELoss()
        self.template_dice_loss = DiceLoss()
        self.template_mse_loss = MSELoss()
        self.epoch = 0

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the combined loss for one batch and add it to the running total.

        Every sub-loss is evaluated (and therefore accumulated) on every call,
        including during warm-up; only the weighting of the total changes.

        Args:
            predicted: indexable as (warped template, warped maps, pred maps, flow).
            outputs: a pair ``(label, template)``.

        Returns:
            The combined loss tensor for this batch.
        """
        label, template = outputs
        if self.epoch <= self.WARMUP_EPOCHS:
            weights = self._WARMUP_WEIGHTS
        else:
            weights = self.weights

        pred_maps = predicted[2]
        flow = predicted[3]

        pred_map_bce_loss = self.pred_maps_bce_loss.cumulate(pred_maps, label)
        pred_map_dice_loss = self.pred_maps_dice_loss.cumulate(pred_maps, label)
        pred_map_mse_loss = self.pred_maps_mse_loss.cumulate(pred_maps, label)
        pred_map_loss = (
            weights[0] * pred_map_bce_loss
            + weights[1] * pred_map_dice_loss
            + weights[2] * pred_map_mse_loss
        )

        grad_loss = self.grad_loss.cumulate(flow, None)
        # zeros_like keeps the target on the flow's own device and dtype, so the
        # loss also works on CPU and on non-default CUDA devices.
        deform_loss = self.deform_mse_loss.cumulate(flow, torch.zeros_like(flow))

        label_dice_loss = self.label_dice_loss.cumulate(predicted[0], label)
        label_mse_loss = self.label_mse_loss.cumulate(predicted[0], label)
        label_loss = weights[3] * label_dice_loss + weights[4] * label_mse_loss

        template_dice_loss = self.template_dice_loss.cumulate(predicted[1], template)
        template_mse_loss = self.template_mse_loss.cumulate(predicted[1], template)
        template_loss = (
            weights[5] * template_dice_loss + weights[6] * template_mse_loss
        )

        # Use the weights selected above for the flow terms too. Reading
        # self.weights here would silently bypass the warm-up zeros at
        # indices 7 and 8.
        loss = (
            pred_map_loss
            + label_loss
            + template_loss
            + grad_loss * weights[7]
            + deform_loss * weights[8]
        )
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance built with the same configuration.

        NOTE(review): the copy does not inherit ``self.epoch``; it starts from
        the constructor's 0 and the ``reset()`` below advances it to 1, so the
        copy begins its own warm-up -- confirm this is what callers expect.
        """
        new_loss = self.__class__(
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weights=self.weights,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return the total and a subset of the sub-loss descriptions.

        The two template terms are accumulated but not reported here.
        """
        return "total {:.4f}, pred map {}, pred map {}, pred map {}, label {}, label {}, grad {}, deform {}, ".format(
            self.log(),
            self.pred_maps_bce_loss.description(),
            self.pred_maps_dice_loss.description(),
            self.pred_maps_mse_loss.description(),
            self.label_dice_loss.description(),
            self.label_mse_loss.description(),
            self.grad_loss.description(),
            self.deform_mse_loss.description(),
        )

    def reset(self):
        """Reset all accumulators and advance the epoch counter by one."""
        super().reset()
        self.epoch += 1
        self.pred_maps_bce_loss.reset()
        self.pred_maps_mse_loss.reset()
        self.pred_maps_dice_loss.reset()
        self.grad_loss.reset()
        self.deform_mse_loss.reset()
        self.label_dice_loss.reset()
        self.label_mse_loss.reset()
        self.template_mse_loss.reset()
        self.template_dice_loss.reset()
class DefSegLoss(TorchLoss):
    """MSE on the warped template plus a weighted ``Grad`` loss on the flow.

    ``template`` and ``bce_loss`` are stored and reset but are not read by the
    active ``cumulate``; they are kept so the constructor signature and the
    instance attributes stay unchanged for existing callers.

    NOTE(review): other classes named ``DefSegLoss`` appear earlier in this
    file; being defined later, this one replaces them at import time unless yet
    another definition follows -- confirm which one callers actually get.
    """

    def __init__(self, template, penalty="l2", loss_mult=None, weight=0.01):
        """Create the sub-losses and store the template.

        Args:
            template: a NumPy array or an already-prepared object. Arrays are
                converted to a float tensor with a leading dimension of size 1;
                anything else is stored as given.
            penalty: forwarded to ``Grad``.
            loss_mult: forwarded to ``Grad``.
            weight: multiplier applied to the ``Grad`` term.
        """
        super().__init__()
        self.mse_loss = MSELoss()
        self.bce_loss = BCELoss(logit=True)
        self.grad_loss = Grad(penalty=penalty, loss_mult=loss_mult)
        self.weight = weight
        if isinstance(template, np.ndarray):
            template_tensor = torch.from_numpy(template).float()
            # Move to the GPU only when one exists, so that constructing the
            # loss does not raise on CPU-only hosts.
            if torch.cuda.is_available():
                template_tensor = template_tensor.cuda()
            self.template = template_tensor.unsqueeze(0)
        else:
            self.template = template

    def cumulate(
        self,
        predicted: Union[torch.Tensor, Iterable[torch.Tensor]],
        outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
    ):
        """Compute the combined loss for one batch and add it to the running total.

        Args:
            predicted: indexable as (warped template, pred maps, flow); items 0
                and 2 are read.
            outputs: target passed to the MSE term together with ``predicted[0]``.

        Returns:
            The combined loss tensor for this batch.
        """
        mse_term = self.mse_loss.cumulate(predicted[0], outputs)
        grad_term = self.grad_loss.cumulate(predicted[2], None)
        loss = mse_term + grad_term * self.weight
        self._cum_loss += loss.item()
        self._count += 1
        return loss

    def new(self):
        """Return a freshly reset instance built with the same configuration."""
        new_loss = self.__class__(
            template=self.template,
            penalty=self.grad_loss.penalty,
            loss_mult=self.grad_loss.loss_mult,
            weight=self.weight,
        )
        new_loss.reset()
        return new_loss

    def description(self):
        """Return the MSE and ``Grad`` sub-loss descriptions joined by a comma."""
        return "{}, {}".format(
            self.mse_loss.description(), self.grad_loss.description()
        )

    def reset(self):
        """Reset this loss's accumulators and those of every sub-loss."""
        super().reset()
        self.mse_loss.reset()
        self.bce_loss.reset()
        self.grad_loss.reset()