def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Weighted sum of the asymmetric focal loss and the asymmetric focal Tversky loss.

    ``loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss``,
    where both component losses are computed by ``self.asy_focal_loss`` and
    ``self.asy_focal_tversky_loss``.

    Args:
        y_pred : the shape should be BNH[WD], where N is the number of classes.
            It only supports binary segmentation.
            NOTE(review): this method applies no activation itself; ``y_pred`` is
            passed (after optional one-hot conversion) unchanged to the component
            losses — confirm whether they expect logits or probabilities.
        y_true : the shape should be BNH[WD], where N is the number of classes.
            It only supports binary segmentation.

    Returns:
        the combined loss reduced according to ``self.reduction``.

    Raises:
        ValueError: When input and target are different shape.
        ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5.
        ValueError: When the maximum value of ``y_true`` is not ``self.num_classes - 1``.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
        )

    if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
        raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

    # single-channel inputs are expanded to ``self.num_classes`` one-hot channels
    if y_pred.shape[1] == 1:
        y_pred = one_hot(y_pred, num_classes=self.num_classes)
        y_true = one_hot(y_true, num_classes=self.num_classes)

    # the largest label present must be the last class index
    if torch.max(y_true) != self.num_classes - 1:
        raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")

    n_pred_ch = y_pred.shape[1]
    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
    asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

    loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

    if self.reduction == LossReduction.SUM.value:
        return torch.sum(loss)  # sum over all remaining dims
    if self.reduction == LossReduction.NONE.value:
        return loss  # unreduced combined loss
    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(loss)
    raise ValueError(
        f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
    )
def __call__(self, img: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]) -> NdarrayOrTensor:
    """
    Ensemble several predictions by voting across dim 0 of the stacked input.

    ``self.get_stacked_torch(img)`` produces the tensor to vote over (assumed to
    stack the ensemble members along dim 0 — confirm against its definition).

    - When ``self.num_classes`` is set, label data are one-hot encoded on dim 1
      (unless dim 1 already has more than one channel, which only triggers a
      warning), averaged over dim 0, and the class with the highest mean wins.
    - Otherwise the data are treated as One-Hot and the mean is rounded to 0/1.

    The voted tensor is passed through ``self.post_convert`` together with ``img``.
    """
    stacked = self.get_stacked_torch(img)
    use_labels = self.num_classes is not None
    keep_channel = True

    if use_labels:
        already_onehot = stacked.ndimension() > 1 and stacked.shape[1] > 1
        if already_onehot:
            warnings.warn("no need to specify num_classes for One-Hot format data.")
        else:
            # a 1-D stack has no channel dim, so the channel added by one_hot is dropped after voting
            keep_channel = stacked.ndimension() != 1
            stacked = one_hot(stacked, self.num_classes, dim=1)

    averaged = stacked.float().mean(dim=0)

    if use_labels:
        # label data: pick the class with the highest mean vote
        voted = averaged.argmax(dim=0, keepdim=keep_channel)
    else:
        # One-Hot data: round each averaged entry to 0 or 1
        voted = averaged.round()

    return self.post_convert(voted, img)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Asymmetric focal loss over a two-channel (background, foreground) prediction.

    The background channel (index 0) cross entropy is modulated by
    ``(1 - y_pred[:, 0]) ** self.gamma`` and weighted by ``1 - self.delta``; the
    foreground channel (index 1) keeps its plain cross entropy, weighted by
    ``self.delta``. Only channels 0 and 1 are used.

    Args:
        y_pred: prediction with classes on dim 1. Values are clamped to
            ``[self.epsilon, 1 - self.epsilon]`` before ``log``, so they are
            presumably probabilities — confirm with callers.
        y_true: ground truth with the same shape as ``y_pred``, or a label map
            that is one-hot encoded when ``self.to_onehot_y`` is set.

    Returns:
        scalar tensor: per-position sum of the two weighted channel terms,
        averaged over all remaining dims.

    Raises:
        ValueError: when ``y_true`` and ``y_pred`` shapes differ (after optional one-hot).
    """
    n_pred_ch = y_pred.shape[1]

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
        )

    # keep log() finite
    y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
    cross_entropy = -y_true * torch.log(y_pred)

    # background: focal modulation, weighted by (1 - delta)
    back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
    back_ce = (1 - self.delta) * back_ce

    # foreground: plain cross entropy, weighted by delta
    fore_ce = cross_entropy[:, 1]
    fore_ce = self.delta * fore_ce

    loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
    return loss
def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """
    Vote over ensemble members stacked along dim 0.

    A list/tuple is stacked with ``torch.stack``; anything else goes through
    ``torch.as_tensor``. With ``self.num_classes`` set, label data are one-hot
    encoded on dim 2 (skipped with a warning when dim 2 already has more than one
    channel), averaged over dim 0 and reduced by ``argmax`` on dim 1. Without it,
    the data are treated as One-Hot and the mean is rounded to 0/1.
    """
    if isinstance(img, (tuple, list)):
        stacked = torch.stack(img)
    else:
        stacked = torch.as_tensor(img)

    if self.num_classes is not None:
        keep_channel = True
        if stacked.ndimension() > 2 and stacked.shape[2] > 1:
            warnings.warn("no need to specify num_classes for One-Hot format data.")
        else:
            # a 2-D stack has no channel dim, so drop the added channel after voting
            keep_channel = stacked.ndimension() != 2
            stacked = one_hot(stacked, self.num_classes, dim=2)

    averaged = stacked.float().mean(dim=0)

    if self.num_classes is None:
        # One-Hot data: round each averaged entry to 0 or 1
        return averaged.round()
    # label data: most common class wins
    return averaged.argmax(dim=1, keepdim=keep_channel)
def __call__(
    self,
    img,
    argmax: Optional[bool] = None,
    to_onehot: Optional[bool] = None,
    n_classes: Optional[int] = None,
    threshold_values: Optional[bool] = None,
    logit_thresh: Optional[float] = None,
):
    """
    Discretize ``img`` by optional argmax, One-Hot conversion and thresholding.

    Args:
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: whether to convert input data into the one-hot format.
            Defaults to ``self.to_onehot``.
        n_classes: the number of classes to convert to One-Hot format.
            Defaults to ``self.n_classes``.
        threshold_values: whether threshold the float value to int number 0 or 1.
            Defaults to ``self.threshold_values``.
        logit_thresh: the threshold value for thresholding operation.
            Defaults to ``self.logit_thresh``.

    Returns:
        the transformed data as a float tensor.

    Raises:
        AssertionError: when One-Hot conversion is requested but neither
            ``n_classes`` nor ``self.n_classes`` is an integer.

    Note:
        each boolean flag is OR-ed with its ``self`` counterpart, so passing
        ``False`` does not disable a step that is enabled on the instance.
    """
    if argmax or self.argmax:
        img = torch.argmax(img, dim=1, keepdim=True)

    if to_onehot or self.to_onehot:
        _nclasses = self.n_classes if n_classes is None else n_classes
        # explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type is unchanged for existing callers
        if not isinstance(_nclasses, int):
            raise AssertionError("One of self.n_classes or n_classes must be an integer")
        img = one_hot(img, _nclasses)

    if threshold_values or self.threshold_values:
        img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)

    return img.float()
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Weighted sum of ``self.dice`` and ``self.focal`` evaluated on the same tensors.

    Args:
        input: the shape should be BNH[WD]. The input should be the original logits
            due to the restriction of ``monai.losses.FocalLoss``.
        target: the shape should be BNH[WD] or B1H[WD].

    Returns:
        ``self.lambda_dice * dice + self.lambda_focal * focal``.

    Raises:
        ValueError: When number of dimensions for input and target are different.
        ValueError: When number of channels for target is neither 1 nor the same as input.
            (NOTE(review): this check is not performed in this method; presumably it
            is raised by ``self.dice`` / ``self.focal`` — confirm.)
    """
    if len(input.shape) != len(target.shape):
        raise ValueError("the number of dimensions for input and target should be the same.")

    channels = input.shape[1]
    single_channel = channels == 1

    if self.to_onehot_y:
        if single_channel:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            target = one_hot(target, num_classes=channels)

    if not self.include_background:
        if single_channel:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # drop the background (first) channel from both tensors
            target, input = target[:, 1:], input[:, 1:]

    dice_term = self.dice(input, target)
    focal_term = self.focal(input, target)
    total_loss: torch.Tensor = self.lambda_dice * dice_term + self.lambda_focal * focal_term
    return total_loss
def test_shape(self, input_data, expected_shape, expected_result=None):
    # check shape, optional values, and dtype of one_hot(**input_data)
    result = one_hot(**input_data)
    self.assertEqual(result.shape, expected_shape)

    if expected_result is not None:
        self.assertTrue(np.allclose(expected_result, result.numpy()))

    # when no dtype is requested, one_hot is expected to return float
    expected_dtype = input_data["dtype"] if "dtype" in input_data else torch.float
    self.assertEqual(result.dtype, expected_dtype)
def __call__(
    self,
    img: torch.Tensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[bool] = None,
    num_classes: Optional[int] = None,
    threshold_values: Optional[bool] = None,
    logit_thresh: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Discretize a channel-first tensor (no batch dim): argmax, One-Hot, threshold, round.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: whether to convert input data into the one-hot format.
            Defaults to ``self.to_onehot``.
        num_classes: the number of classes to convert to One-Hot format.
            Defaults to ``self.num_classes``.
        threshold_values: whether threshold the float value to int number 0 or 1.
            Defaults to ``self.threshold_values``.
        logit_thresh: the threshold value for thresholding operation.
            Defaults to ``self.logit_thresh``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"].

    Returns:
        the converted data as a float tensor.

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``num_classes`` instead.
    """
    # the deprecated ``n_classes`` is honoured only when ``num_classes`` was left unset
    if num_classes is None and n_classes is not None:
        num_classes = n_classes

    if argmax or self.argmax:
        img = torch.argmax(img, dim=0, keepdim=True)

    if to_onehot or self.to_onehot:
        class_count = num_classes if num_classes is not None else self.num_classes
        if not isinstance(class_count, int):
            raise AssertionError("One of self.num_classes or num_classes must be an integer")
        img = one_hot(img, num_classes=class_count, dim=0)

    if threshold_values or self.threshold_values:
        cutoff = logit_thresh if logit_thresh is not None else self.logit_thresh
        img = img >= cutoff

    rounding = rounding if rounding is not None else self.rounding
    if rounding is not None:
        look_up_option(rounding, ["torchrounding"])
        img = torch.round(img)

    return img.float()
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
) -> NdarrayOrTensor:
    """
    Discretize ``img`` by optional argmax, One-Hot conversion, thresholding and rounding.

    ``dim``, ``keepdim`` and ``dtype`` for these operations are read from
    ``self.kwargs`` (defaults ``0``, ``True`` and ``torch.float``). The result is
    converted back to the container type of ``img``.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"].

    Raises:
        ValueError: when the effective ``to_onehot`` is not an integer.
    """
    data: torch.Tensor
    data, *_ = convert_data_type(img, torch.Tensor)  # type: ignore
    opts = self.kwargs
    axis = opts.get("dim", 0)
    out_dtype = opts.get("dtype", torch.float)

    if argmax or self.argmax:
        data = torch.argmax(data, dim=axis, keepdim=opts.get("keepdim", True))

    n_onehot = to_onehot if to_onehot is not None else self.to_onehot
    if n_onehot is not None:
        if not isinstance(n_onehot, int):
            raise ValueError("the number of classes for One-Hot must be an integer.")
        data = one_hot(data, num_classes=n_onehot, dim=axis, dtype=out_dtype)

    cutoff = threshold if threshold is not None else self.threshold
    if cutoff is not None:
        data = data >= cutoff

    mode = rounding if rounding is not None else self.rounding
    if mode is not None:
        look_up_option(mode, ["torchrounding"])
        data = torch.round(data)

    converted, *_ = convert_to_dst_type(data, img, dtype=out_dtype)
    return converted
def __call__(self, img, to_onehot: Optional[bool] = None, num_classes: Optional[int] = None):
    """
    Split ``img`` along dim 1 into a list of single-channel slices.

    Args:
        to_onehot: whether to convert the data to One-Hot format first.
            Defaults to ``self.to_onehot``.
        num_classes: the class number used to convert to One-Hot format if `to_onehot` is True.
            Defaults to ``self.num_classes``.

    Returns:
        a list with one tensor per channel of dim 1, each keeping a size-1 dim 1.

    Raises:
        AssertionError: when One-Hot conversion is requested but no integer class number is available.
    """
    if to_onehot or self.to_onehot:
        if num_classes is None:
            num_classes = self.num_classes
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        if not isinstance(num_classes, int):
            raise AssertionError("must specify class number for One-Hot.")
        img = one_hot(img, num_classes)

    # ``i : i + 1`` keeps the channel dimension on every slice
    return [img[:, i : i + 1] for i in range(img.shape[1])]
def do_binarization(
    input_data: torch.Tensor,
    bin_mode: Union[str, Callable] = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
) -> torch.Tensor:
    """
    Binarize ``input_data`` by thresholding, mutually-exclusive argmax, or a callable.

    Args:
        input_data: the input that to be binarized, in the shape [B] or [BN] or [BNHW] or [BNHWD].
            It is not modified in place.
        bin_mode: can be ``"threshold"`` or ``"mutually_exclusive"``, or a callable function.

            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `input_data` will be converted by a combination of
              argmax and to_onehot.
            - a callable is applied to `input_data` and its result returned.
        bin_threshold: the threshold to binarize the input data, can be a single value or a sequence of
            values that each one of the value represents a threshold for a class.

    Returns:
        the binarized data. Any other string ``bin_mode`` returns ``input_data`` unchanged.

    Raises:
        AssertionError: when `bin_threshold` is a sequence and the input has the shape [B].
        AssertionError: when `bin_threshold` is a sequence but the length != the number of classes.
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` the input has the shape [B].
        AssertionError: when `bin_mode` is ``"mutually_exclusive"`` the input has the shape [B, 1].
    """
    if callable(bin_mode):
        return bin_mode(input_data)

    input_ndim = input_data.ndimension()
    if bin_mode == "threshold":
        if isinstance(bin_threshold, Sequence):
            # explicit raises (not ``assert``) so the checks survive ``python -O``
            if input_ndim <= 1:
                raise AssertionError("a sequence of thresholds are used for multi-class tasks.")
            if input_data.shape[1] != len(bin_threshold):
                raise AssertionError("the length of the sequence should be the same as the number of classes.")
            # work on a copy so the caller's tensor is not overwritten class by class
            input_data = input_data.clone()
            for cls_num, cls_threshold in enumerate(bin_threshold):
                input_data[:, cls_num] = (input_data[:, cls_num] > cls_threshold).float()
        else:
            input_data = (input_data > bin_threshold).float()
    elif bin_mode == "mutually_exclusive":
        if input_ndim <= 1:
            raise AssertionError("mutually_exclusive is used for multi-class tasks.")
        n_classes = input_data.shape[1]
        if n_classes <= 1:
            raise AssertionError("mutually_exclusive is used for multi-class tasks.")
        input_data = torch.argmax(input_data, dim=1, keepdim=True)
        input_data = one_hot(input_data, num_classes=n_classes)
    return input_data
def test_consistency_with_cross_entropy_2d_onehot_label(self):
    # with gamma=0 the focal loss should match cross entropy (scaled by 1 / class_num here)
    focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean")
    ce = nn.CrossEntropyLoss(reduction="mean")
    class_num, batch_size = 10, 128
    worst = 0
    for _ in range(100):
        # random scores of shape (batch_size, class_num, 8, 4) and matching integer labels
        logits = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
        labels = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4))
        if torch.cuda.is_available():
            logits, labels = logits.cuda(), labels.cuda()
        focal_out = focal_loss(logits, one_hot(labels, num_classes=class_num))
        ce_out = ce(logits, labels[:, 0]) / class_num
        diff = abs(float(focal_out.cpu().detach()) - float(ce_out.cpu().detach()))
        worst = max(worst, diff)
    self.assertAlmostEqual(worst, 0.0, places=3)
def test_consistency_with_cross_entropy_classification_01(self):
    # with gamma=0.1 the focal loss is expected to differ from BCE-with-logits
    focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
    ce = nn.BCEWithLogitsLoss(reduction="mean")
    class_num, batch_size = 10, 128
    worst = 0
    for _ in range(100):
        # random scores of shape (batch_size, class_num) and one label per sample
        scores = torch.rand(batch_size, class_num, requires_grad=True)
        labels = torch.randint(low=0, high=class_num, size=(batch_size, 1)).long()
        if torch.cuda.is_available():
            scores, labels = scores.cuda(), labels.cuda()
        focal_out = focal_loss(scores, labels)
        bce_out = ce(scores, one_hot(labels, num_classes=class_num))
        diff = abs(float(focal_out.cpu().detach()) - float(bce_out.cpu().detach()))
        worst = max(worst, diff)
    self.assertNotAlmostEqual(worst, 0.0, places=3)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Asymmetric focal Tversky loss over a two-channel (background, foreground) prediction.

    Per class: ``tversky = (tp + eps) / (tp + delta * fn + (1 - delta) * fp + eps)``,
    with sums taken over the spatial dims (dim 2 onward). The background term is
    ``1 - tversky[:, 0]``; the foreground term is
    ``(1 - tversky[:, 1]) * (1 - tversky[:, 1]) ** (-self.gamma)``.

    Args:
        y_pred: prediction with classes on dim 1, clamped to
            ``[self.epsilon, 1 - self.epsilon]`` (presumably probabilities — confirm).
        y_true: ground truth with the same shape as ``y_pred``, or a label map that is
            one-hot encoded when ``self.to_onehot_y`` is set.

    Returns:
        scalar tensor: mean of the background and foreground terms over the batch.

    Raises:
        ValueError: when ``y_true`` and ``y_pred`` shapes differ (after optional one-hot).
    """
    n_pred_ch = y_pred.shape[1]

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            y_true = one_hot(y_true, num_classes=n_pred_ch)

    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})"
        )

    # clip the prediction to avoid NaN
    y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
    axis = list(range(2, len(y_pred.shape)))

    # true positives (tp), false negatives (fn) and false positives (fp) per class
    tp = torch.sum(y_true * y_pred, dim=axis)
    fn = torch.sum(y_true * (1 - y_pred), dim=axis)
    fp = torch.sum((1 - y_true) * y_pred, dim=axis)
    # delta weights false negatives, (1 - delta) weights false positives
    dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

    # losses per class; only the foreground term gets the focal exponent
    back_dice = 1 - dice_class[:, 0]
    fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

    # average class scores
    loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
    return loss
def __call__(
    self,
    img: NdarrayOrTensor,
    argmax: Optional[bool] = None,
    to_onehot: Optional[int] = None,
    threshold: Optional[float] = None,
    rounding: Optional[str] = None,
    n_classes: Optional[int] = None,  # deprecated
    num_classes: Optional[int] = None,  # deprecated
    logit_thresh: Optional[float] = None,  # deprecated
    threshold_values: Optional[bool] = None,  # deprecated
) -> NdarrayOrTensor:
    """
    Discretize a channel-first ``img``: argmax, One-Hot, threshold, then rounding.

    The result is converted back to the container type of ``img`` with float dtype.

    Args:
        img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
            will automatically add it.
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``self.argmax``.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``self.to_onehot``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
            Defaults to ``self.threshold``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"].

    Raises:
        AssertionError: when the effective ``to_onehot`` is not an integer.

    .. deprecated:: 0.6.0
        ``n_classes`` is deprecated, use ``to_onehot`` instead.

    .. deprecated:: 0.7.0
        ``num_classes`` is deprecated, use ``to_onehot`` instead.
        ``logit_thresh`` is deprecated, use ``threshold`` instead.
        ``threshold_values`` is deprecated, use ``threshold`` instead.
    """
    # Legacy boolean spellings: ``to_onehot=True`` takes its class count from
    # ``num_classes``; ``threshold=True`` takes its cutoff from ``logit_thresh``.
    # NOTE(review): ``n_classes`` and ``threshold_values`` are not read in this body;
    # presumably a decorator outside this view remaps them — confirm.
    if isinstance(to_onehot, bool):
        warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
        to_onehot = num_classes if to_onehot else None
    if isinstance(threshold, bool):
        warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
        threshold = logit_thresh if threshold else None

    img = convert_to_tensor(img, track_meta=get_track_meta())
    data, *_ = convert_data_type(img, torch.Tensor)

    if argmax or self.argmax:
        data = torch.argmax(data, dim=0, keepdim=True)

    n_onehot = to_onehot if to_onehot is not None else self.to_onehot
    if n_onehot is not None:
        if not isinstance(n_onehot, int):
            raise AssertionError("the number of classes for One-Hot must be an integer.")
        data = one_hot(data, num_classes=n_onehot, dim=0)

    cutoff = threshold if threshold is not None else self.threshold
    if cutoff is not None:
        data = data >= cutoff

    mode = rounding if rounding is not None else self.rounding
    if mode is not None:
        look_up_option(mode, ["torchrounding"])
        data = torch.round(data)

    result, *_ = convert_to_dst_type(data, img, dtype=torch.float)
    return result
def test_convergence(self, loss_type, loss_args, forward_args):
    """
    Train a small fully-connected network to segment one synthetic 3D image with
    ``loss_type(**loss_args)``, as a sanity check on the loss gradient.

    The per-step loss values are recorded in ``loss_history``. The assertions
    check that the initial argmax prediction differs from the target and that
    the final argmax prediction matches it.
    """
    learning_rate = 0.001
    max_iter = 40

    # define a simple 3d example: a 2x2 foreground square repeated on 3 slices
    target_seg = torch.tensor(
        [
            [
                # slice 0
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            ]
        ],
        device=self.device,
    )
    target_seg = torch.unsqueeze(target_seg, dim=0)
    image = 12 * target_seg + 27
    image = image.float().to(self.device)
    num_classes = 2
    num_voxels = 3 * 4 * 4

    target_onehot = one_hot(target_seg, num_classes=num_classes)

    # define a one layer model
    class OnelayerNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(num_voxels, 200)
            self.acti = nn.ReLU()
            self.layer_2 = nn.Linear(200, num_voxels * num_classes)

        def forward(self, x):
            x = x.view(-1, num_voxels)
            x = self.layer_1(x)
            x = self.acti(x)
            x = self.layer_2(x)
            x = x.view(-1, num_classes, 3, 4, 4)
            return x

    # initialise the network
    net = OnelayerNet().to(self.device)

    # initialize the loss
    loss = loss_type(**loss_args)

    # initialize an Adam optimizer
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    loss_history = []
    init_output = None

    # train the network
    for iter_i in range(max_iter):
        # set the gradient to zero
        optimizer.zero_grad()

        # forward pass
        output = net(image)
        if init_output is None:
            init_output = torch.argmax(output, 1).detach().cpu().numpy()

        if loss_args["to_onehot_y"] is False:
            loss_val = loss(output, target_onehot, **forward_args)
        else:
            loss_val = loss(output, target_seg, **forward_args)

        if iter_i % 10 == 0:
            pred = torch.argmax(output, 1).detach().cpu().numpy()
            gt = target_seg.detach().cpu().numpy()[:, 0]
            # voxel accuracy: fraction of voxels whose argmax matches the label
            print(f"{loss_type.__name__} iter: {iter_i}, acc: {np.sum(pred == gt) / np.prod(pred.shape)}")

        # backward pass
        loss_val.backward()
        optimizer.step()

        # stats
        loss_history.append(loss_val.item())

    pred = torch.argmax(output, 1).detach().cpu().numpy()
    target = target_seg.detach().cpu().numpy()[:, 0]
    # initial predictions are bad
    self.assertTrue(not np.allclose(init_output, target))
    # final predictions are good
    np.testing.assert_allclose(pred, target)
def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor:
    """
    Dice (or Jaccard, when ``self.jaccard``) loss per batch item and channel, then reduced.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].
        smooth: a small constant to avoid nan.

    Returns:
        the loss reduced according to ``self.reduction``; with ``"none"`` the
        shape is [B, N] (spatial dims are summed).

    Raises:
        AssertionError: When ``input`` and ``target`` shapes differ after the optional conversions.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    if self.sigmoid:
        input = torch.sigmoid(input)

    n_pred_ch = input.shape[1]
    if self.softmax:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            input = torch.softmax(input, 1)

    if self.other_act is not None:
        input = self.other_act(input)

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            target = one_hot(target, num_classes=n_pred_ch)

    if not self.include_background:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    # explicit raise (not ``assert``) so the check also runs under ``python -O``
    if target.shape != input.shape:
        raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, dim=reduce_axis)

    if self.squared_pred:
        target = torch.pow(target, 2)
        input = torch.pow(input, 2)

    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(input, dim=reduce_axis)

    denominator = ground_o + pred_o

    if self.jaccard:
        denominator = 2.0 * (denominator - intersection)

    f: torch.Tensor = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

    if self.reduction == LossReduction.MEAN.value:
        f = torch.mean(f)  # the batch and channel average
    elif self.reduction == LossReduction.SUM.value:
        f = torch.sum(f)  # sum over the batch and channel dims
    elif self.reduction == LossReduction.NONE.value:
        pass  # returns [N, n_classes] losses
    else:
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

    return f
def compute_roc_auc(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    softmax: bool = False,
    other_act: Optional[Callable] = None,
    average: Union[Average, str] = Average.MACRO,
):
    """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2].
        y: ground truth to compute ROC AUC metric, the first dim is batch.
            example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`).
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
        other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``.
            for example: `other_act = lambda x: torch.log_softmax(x)`.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``softmax=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        AssertionError: When ``y`` and ``y_pred`` shapes differ after conversion.
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.
        When `y_pred` has a single channel, `to_onehot_y`, `softmax` and `other_act` are not applied.

    """
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    if y_pred_ndim not in (1, 2):
        raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).")
    if y_ndim not in (1, 2):
        raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).")

    # a trailing singleton class dim is treated as the 1-D (binary) case
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred_ndim == 1:
        if to_onehot_y:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        if softmax:
            warnings.warn("y_pred has only one channel, softmax=True ignored.")
        return _calculate(y, y_pred)

    n_classes = y_pred.shape[1]
    if to_onehot_y:
        y = one_hot(y, n_classes)
    if softmax and other_act is not None:
        raise ValueError("Incompatible values: softmax=True and other_act is not None.")
    if softmax:
        y_pred = y_pred.float().softmax(dim=1)
    if other_act is not None:
        if not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        y_pred = other_act(y_pred)

    # explicit raise (not ``assert``) so the check also runs under ``python -O``
    if y.shape != y_pred.shape:
        raise AssertionError("data shapes of y_pred and y do not match.")

    average = Average(average)
    if average == Average.MICRO:
        return _calculate(y.flatten(), y_pred.flatten())

    # one AUC per class: iterate over the class dim
    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
    auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)]
    if average == Average.NONE:
        return auc_values
    if average == Average.MACRO:
        return np.mean(auc_values)
    if average == Average.WEIGHTED:
        weights = [sum(y_) for y_ in y]
        return np.average(auc_values, weights=weights)
    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
def compute_meandice(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
) -> torch.Tensor:
    """Computes Dice score metric from full size Tensor and collects average.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32].
        y: ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32].
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
        include_background: whether to include Dice computation on the first channel of
            the predicted output (``False`` drops it). Defaults to True.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.  Defaults to False.
        sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False.
        other_act: callable function to replace `sigmoid` as activation layer if needed, Defaults to ``None``.
            for example: `other_act = torch.tanh`.
        logit_thresh: the threshold value used to convert (for example, after sigmoid if `sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Raises:
        ValueError: When ``sigmoid=True`` and ``other_act is not None``. Incompatible values.
        TypeError: When ``other_act`` is not an ``Optional[Callable]``.
        ValueError: When ``sigmoid=True`` and ``mutually_exclusive=True`` with a multi-channel
            ``y_pred``. Incompatible values.
        AssertionError: When ``y`` and ``y_pred`` shapes differ after conversion.

    Returns:
        Dice scores per batch and per class, (shape [batch_size, n_classes]). Entries
        whose ground truth channel is empty are ``nan``.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``,
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)
    if sigmoid and other_act is not None:
        raise ValueError("Incompatible values: sigmoid=True and other_act is not None.")
    if sigmoid:
        y_pred = y_pred.float().sigmoid()

    if other_act is not None:
        if not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        y_pred = other_act(y_pred)

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn("y_pred has only one class, mutually_exclusive=True ignored.")
        if to_onehot_y:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        if not include_background:
            warnings.warn("y_pred has only one channel, include_background=False ignored.")
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if sigmoid:
                raise ValueError("Incompatible values: sigmoid=True and mutually_exclusive=True.")
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, num_classes=n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, num_classes=n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    # explicit raise (not ``assert``) so the check also runs under ``python -O``
    if y.shape != y_pred.shape:
        raise AssertionError(
            f"Ground truth one-hot has differing shape ({y.shape!r}) from source ({y_pred.shape!r})"
        )

    y = y.float()
    y_pred = y_pred.float()

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, dim=reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, dim=reduce_axis)
    denominator = y_o + y_pred_o

    # Dice is undefined (nan) where the ground truth channel is empty
    f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
    return f  # returns array of Dice shape: [batch, n_classes]
def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """
    Generalized Dice loss: per-class Dice terms weighted by ``self.w_func(ground_o)``.

    Args:
        input (tensor): the shape should be BNH[WD].
        target (tensor): the shape should be BNH[WD].
        smooth: a small constant to avoid nan.

    Returns:
        the loss reduced according to ``self.reduction``; with ``"none"`` the shape
        is [B] (one value per batch item, classes are combined by the weighting).

    Raises:
        AssertionError: when ``input`` and ``target`` shapes differ after the optional conversions.
        ValueError: reduction={self.reduction} is invalid.
    """
    if self.sigmoid:
        input = torch.sigmoid(input)
    n_pred_ch = input.shape[1]
    if n_pred_ch == 1:
        if self.softmax:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
    else:
        if self.softmax:
            input = torch.softmax(input, 1)
        if self.to_onehot_y:
            target = one_hot(target, n_pred_ch)
        if not self.include_background:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    # explicit raise (not ``assert``) so the check also runs under ``python -O``
    if target.shape != input.shape:
        raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, reduce_axis)

    ground_o = torch.sum(target, reduce_axis)
    pred_o = torch.sum(input, reduce_axis)

    denominator = ground_o + pred_o

    w = self.w_func(ground_o.float())
    # replace infinite weights in each batch row by the largest remaining weight of that row
    for b in w:
        infs = torch.isinf(b)
        b[infs] = 0.0
        b[infs] = torch.max(b)

    f = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth)

    # accept ``self.reduction`` stored either as a LossReduction member or as its string value
    try:
        reduction = LossReduction(self.reduction)
    except ValueError:
        raise ValueError(f"reduction={self.reduction} is invalid.") from None

    if reduction == LossReduction.MEAN:
        f = torch.mean(f)  # the batch average
    elif reduction == LossReduction.SUM:
        f = torch.sum(f)  # sum over the batch dim
    elif reduction == LossReduction.NONE:
        pass  # returns [N] losses
    else:
        raise ValueError(f"reduction={self.reduction} is invalid.")

    return f
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the Tversky loss.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].

    Returns:
        ``1 - (TP + smooth_nr) / (TP + alpha * FP + beta * FN + smooth_dr)``.
        It is computed per sample and class (``[B, N]``, or ``[N]`` when
        ``self.batch``), then reduced according to ``self.reduction``.

    Raises:
        AssertionError: When input and target (after optional one-hot conversion
            and background removal) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    pred = torch.sigmoid(input) if self.sigmoid else input
    gt = target
    n_pred_ch = pred.shape[1]
    single_channel = n_pred_ch == 1

    if self.softmax:
        if single_channel:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            pred = torch.softmax(pred, 1)

    if self.other_act is not None:
        pred = self.other_act(pred)

    if self.to_onehot_y:
        if single_channel:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            gt = one_hot(gt, num_classes=n_pred_ch)

    if not self.include_background:
        if single_channel:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # drop channel 0 from both tensors
            gt, pred = gt[:, 1:], pred[:, 1:]

    if gt.shape != pred.shape:
        raise AssertionError(f"ground truth has differing shape ({gt.shape}) from input ({pred.shape})")

    # spatial dims, plus the batch dim when self.batch is set
    axes: List[int] = list(range(2, pred.dim()))
    if self.batch:
        axes = [0] + axes

    tp = torch.sum(pred * gt, axes)
    fp = self.alpha * torch.sum(pred * (1 - gt), axes)
    fn = self.beta * torch.sum((1 - pred) * gt, axes)

    score: torch.Tensor = 1.0 - (tp + self.smooth_nr) / (tp + fp + fn + self.smooth_dr)

    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(score)
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(score)  # sum over the batch and channel dims
    if self.reduction == LossReduction.NONE.value:
        return score
    raise ValueError(
        f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
    )
def forward(self, input, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor:
    """
    Combine an (optional) attention-map loss with the Dice loss of the prediction.

    Args:
        input: a pair ``(x, att_maps)``.
            - ``x`` is the prediction fed to a multi-channel ``Dice`` loss
              (``to_onehot_y=True, softmax=True``).
            - ``att_maps`` is a sequence of attention maps, used only when
              ``self.supervised_attention`` is set.
        target: the shape should be BNH[WD].
        smooth: unused; kept for interface compatibility.

    Returns:
        ``total_att_loss + pred_loss``.
        - ``total_att_loss`` is the mean single-channel ``Dice`` loss over all
          attention maps (0 when supervision is off).
        - ``pred_loss`` is the multi-channel ``Dice`` loss of ``x``, optionally
          with hardness weighting.

    Raises:
        AssertionError: When an attention map's shape is not an integer multiple
            of the preceding map's shape in every dimension.
    """
    x, att_maps = input
    loss_function_single_channel = Dice(to_onehot_y=False, softmax=False)
    total_att_loss = 0

    if self.supervised_attention:
        n_levels = len(att_maps)
        g_l = target
        # Walk att_maps from last to first. After each level, the target is
        # max-pooled by the integer shape ratio to the preceding map.
        for level in range(n_levels):
            att_map = att_maps[n_levels - level - 1]
            att_loss = loss_function_single_channel(att_map, g_l)
            total_att_loss = total_att_loss + 1 / n_levels * att_loss

            if level < n_levels - 1:
                curr_shape = att_map.shape
                next_shape = att_maps[n_levels - level - 2].shape
                # Explicit raise (not `assert`) so the check survives `python -O`.
                if not all(c % n == 0 for c, n in zip(curr_shape, next_shape)):
                    raise AssertionError(
                        f"attention map shape {tuple(curr_shape)} is not a multiple of "
                        f"the next attention map shape {tuple(next_shape)}"
                    )
                ratio = [c // n for c, n in zip(curr_shape, next_shape)]
                pool = ratio[2:5]  # spatial dims only (3D pooling)
                g_l = torch.nn.functional.max_pool3d(g_l, kernel_size=pool, stride=pool)

    hardness_weight = None
    if self.hardness_weighting:
        hardness_lambda = 0.6
        # weight = lambda * |softmax(x) - onehot(target)| + (1 - lambda)
        hardness_weight = hardness_lambda * abs(
            torch.softmax(x, dim=1) - one_hot(target, num_classes=x.shape[1])
        ) + (1.0 - hardness_lambda)

    loss_function_multi_channel = Dice(to_onehot_y=True, softmax=True, hardness_weight=hardness_weight)
    pred_loss = loss_function_multi_channel(x, target)
    return total_att_loss + pred_loss
def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """
    Compute the Tversky loss.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].
        smooth: a small constant added to numerator and denominator to avoid nan.

    Returns:
        ``1 - (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``.
        It is computed per sample and class (``[B, N]``), then reduced according
        to ``self.reduction``.

    Raises:
        AssertionError: When input and target (after optional one-hot conversion
            and background removal) have different shapes.
        ValueError: reduction={self.reduction} is invalid.
    """
    pred = torch.sigmoid(input) if self.sigmoid else input
    gt = target
    n_pred_ch = pred.shape[1]

    if n_pred_ch == 1:
        # Channel-wise options make no sense for one channel: warn about each skipped one.
        if self.softmax:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
    else:
        if self.softmax:
            pred = torch.softmax(pred, 1)
        if self.to_onehot_y:
            gt = one_hot(gt, n_pred_ch)
        if not self.include_background:
            # drop channel 0 from both tensors
            gt, pred = gt[:, 1:], pred[:, 1:]

    assert gt.shape == pred.shape, f"ground truth has differing shape ({gt.shape}) from input ({pred.shape})"

    axes = list(range(2, pred.dim()))  # spatial dims only
    tp = torch.sum(pred * gt, axes)
    fp = self.alpha * torch.sum(pred * (1 - gt), axes)
    fn = self.beta * torch.sum((1 - pred) * gt, axes)
    score = 1.0 - (tp + smooth) / (tp + fp + fn + smooth)

    if self.reduction == LossReduction.MEAN:
        return score.mean()
    if self.reduction == LossReduction.SUM:
        return score.sum()  # sum over the batch and channel dims
    if self.reduction == LossReduction.NONE:
        return score  # [B, N] losses
    raise ValueError(f"reduction={self.reduction} is invalid.")
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the Dice loss (Jaccard form when ``self.jaccard`` is set).

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].

    Returns:
        ``1 - (2 * intersection + smooth_nr) / (denominator + smooth_dr)``.
        It is computed per sample and class (``[B, N]``, or ``[N]`` when
        ``self.batch``), then averaged, summed or returned as-is according to
        ``self.reduction``.

    Raises:
        AssertionError: When input and target (after optional one-hot conversion
            and background removal) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    pred = torch.sigmoid(input) if self.sigmoid else input
    gt = target
    n_pred_ch = pred.shape[1]
    single_channel = n_pred_ch == 1

    if self.softmax:
        if single_channel:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            pred = torch.softmax(pred, 1)

    if self.other_act is not None:
        pred = self.other_act(pred)

    if self.to_onehot_y:
        if single_channel:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            gt = one_hot(gt, num_classes=n_pred_ch)

    if not self.include_background:
        if single_channel:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # drop channel 0 from both tensors
            gt, pred = gt[:, 1:], pred[:, 1:]

    if gt.shape != pred.shape:
        raise AssertionError(f"ground truth has differing shape ({gt.shape}) from input ({pred.shape})")

    # spatial dims, plus the batch dim when self.batch is set
    axes = list(range(2, pred.dim()))
    if self.batch:
        axes = [0] + axes

    intersection = torch.sum(gt * pred, dim=axes)
    # NOTE(review): an earlier commented-out experiment scaled `intersection` by
    # `self.label_weights` on a hard-coded CUDA device. It is not active; labels
    # are unweighted here.

    if self.squared_pred:
        gt, pred = torch.pow(gt, 2), torch.pow(pred, 2)

    denominator = torch.sum(gt, dim=axes) + torch.sum(pred, dim=axes)
    if self.jaccard:
        denominator = 2.0 * (denominator - intersection)

    f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(f)  # the batch and channel average
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(f)  # sum over the batch and channel dims
    if self.reduction == LossReduction.NONE.value:
        return f
    raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the generalised (class-weighted) Dice loss.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].

    Returns:
        A scalar for ``"mean"``/``"sum"`` reduction. For ``"none"`` the
        per-sample losses (``[B]``), or a scalar when ``self.batch`` is set;
        the class dimension is summed by the weighted ratio.

    Raises:
        AssertionError: When input and target (after optional one-hot conversion
            and background removal) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    if self.sigmoid:
        input = torch.sigmoid(input)
    n_pred_ch = input.shape[1]
    if self.softmax:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            input = torch.softmax(input, 1)

    if self.other_act is not None:
        input = self.other_act(input)

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            target = one_hot(target, num_classes=n_pred_ch)

    if not self.include_background:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    if target.shape != input.shape:
        raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

    # reducing only spatial dimensions (not batch nor channels) ...
    reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
    if self.batch:
        # ... unless batch mode, where the batch dim is reduced too
        reduce_axis = [0] + reduce_axis
    intersection = torch.sum(target * input, reduce_axis)
    ground_o = torch.sum(target, reduce_axis)
    pred_o = torch.sum(input, reduce_axis)
    denominator = ground_o + pred_o

    w = self.w_func(ground_o.float())
    # Replace infinite weights with the largest finite weight along the class
    # (last) dimension. `w` is [B, N], or [N] when self.batch. The old per-row
    # loop iterated the 1-D batch case element by element, so every inf became 0
    # instead of the max finite weight.
    infs = torch.isinf(w)
    w[infs] = 0.0
    max_values = torch.max(w, dim=-1, keepdim=True)[0]
    w = torch.where(infs, max_values, w)

    class_dim = 0 if self.batch else 1
    f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(class_dim) + self.smooth_nr) / (
        (denominator * w).sum(class_dim) + self.smooth_dr
    )

    if self.reduction == LossReduction.MEAN.value:
        f = torch.mean(f)  # the batch average
    elif self.reduction == LossReduction.SUM.value:
        f = torch.sum(f)  # sum over the batch
    elif self.reduction == LossReduction.NONE.value:
        pass  # per-sample losses [B] (scalar when self.batch)
    else:
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
    return f
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the Dice loss (Jaccard form when ``self.jaccard`` is set).

    Args:
        input: the shape should be BNH[WD], where N is the number of classes.
        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

    Returns:
        A scalar for ``"mean"``/``"sum"`` reduction. For ``"none"``, the
        per-sample, per-class loss is reshaped to ``[B, N, 1, ...]`` so that it
        broadcasts against ``input``.

    Raises:
        AssertionError: When input and target (after one hot transform if set) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

    Example:
        >>> from monai.losses.dice import *  # NOQA
        >>> import numpy as np
        >>> import torch
        >>> from monai.losses.dice import DiceLoss
        >>> B, C, H, W = 7, 5, 3, 2
        >>> input = torch.rand(B, C, H, W)
        >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
        >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
        >>> self = DiceLoss(reduction='none')
        >>> loss = self(input, target)
        >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    """
    pred = torch.sigmoid(input) if self.sigmoid else input
    gt = target
    n_pred_ch = pred.shape[1]
    single_channel = n_pred_ch == 1

    if self.softmax:
        if single_channel:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            pred = torch.softmax(pred, 1)

    if self.other_act is not None:
        pred = self.other_act(pred)

    if self.to_onehot_y:
        if single_channel:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            gt = one_hot(gt, num_classes=n_pred_ch)

    if not self.include_background:
        if single_channel:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # drop channel 0 from both tensors
            gt, pred = gt[:, 1:], pred[:, 1:]

    if gt.shape != pred.shape:
        raise AssertionError(f"ground truth has different shape ({gt.shape}) from input ({pred.shape})")

    # spatial dims, plus the batch dim when self.batch is set
    axes: List[int] = list(range(2, pred.dim()))
    if self.batch:
        axes = [0] + axes

    intersection = torch.sum(gt * pred, dim=axes)
    if self.squared_pred:
        gt, pred = torch.pow(gt, 2), torch.pow(pred, 2)

    denominator = torch.sum(gt, dim=axes) + torch.sum(pred, dim=axes)
    if self.jaccard:
        denominator = 2.0 * (denominator - intersection)

    f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(f)  # the batch and channel average
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(f)  # sum over the batch and channel dims
    if self.reduction == LossReduction.NONE.value:
        # Keep a shape that broadcasts against the input: [B, N] + trailing singleton dims.
        return f.view(list(f.shape[0:2]) + [1] * (pred.dim() - 2))
    raise ValueError(
        f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
    )
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the generalised (class-weighted) Dice loss.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD].

    Returns:
        A scalar for ``"mean"``/``"sum"`` reduction. For ``"none"``, the loss
        (class dim kept with size 1 via ``keepdim``) is reshaped to
        ``[B, 1, 1, ...]`` so it broadcasts against ``input``.

    Raises:
        AssertionError: When input and target (after optional one-hot conversion
            and background removal) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    pred = torch.sigmoid(input) if self.sigmoid else input
    gt = target
    n_pred_ch = pred.shape[1]
    single_channel = n_pred_ch == 1

    if self.softmax:
        if single_channel:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        else:
            pred = torch.softmax(pred, 1)

    if self.other_act is not None:
        pred = self.other_act(pred)

    if self.to_onehot_y:
        if single_channel:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            gt = one_hot(gt, num_classes=n_pred_ch)

    if not self.include_background:
        if single_channel:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # drop channel 0 from both tensors
            gt, pred = gt[:, 1:], pred[:, 1:]

    if gt.shape != pred.shape:
        raise AssertionError(f"ground truth has differing shape ({gt.shape}) from input ({pred.shape})")

    # spatial dims, plus the batch dim when self.batch is set
    axes: List[int] = list(range(2, pred.dim()))
    if self.batch:
        axes = [0] + axes

    intersection = torch.sum(gt * pred, axes)
    ground_o = torch.sum(gt, axes)
    denominator = ground_o + torch.sum(pred, axes)

    w = self.w_func(ground_o.float())
    # Zero the infinite weights, then fill them with the maximum remaining weight:
    # the global max in batch mode, the per-row max otherwise.
    infs = torch.isinf(w)
    w[infs] = 0.0
    fill = torch.max(w) if self.batch else torch.max(w, dim=1)[0].unsqueeze(dim=1)
    w = w + infs * fill

    class_dim = 0 if self.batch else 1
    numer = 2.0 * (intersection * w).sum(class_dim, keepdim=True) + self.smooth_nr
    denom = (denominator * w).sum(class_dim, keepdim=True) + self.smooth_dr
    f: torch.Tensor = 1.0 - (numer / denom)

    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(f)  # the batch and channel average
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(f)  # sum over the batch and channel dims
    if self.reduction == LossReduction.NONE.value:
        # Keep a shape that broadcasts against the input.
        return f.view(list(f.shape[0:2]) + [1] * (pred.dim() - 2))
    raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute the focal loss.

    Args:
        input: the shape should be BNH[WD], where N is the number of classes.
            The input should be the original logits since it will be transformed by
            `F.log_softmax` in the forward function.
        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

    Returns:
        ``mean_over_voxels(-(1 - p)^gamma * t * log(p) * class_weight)`` per sample
        and class (``[B, N]``), then summed, averaged or returned as-is according to
        ``self.reduction``.

    Raises:
        AssertionError: When input and target (after one hot transform if set) have different shapes.
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        ValueError: When ``self.weight`` is a sequence and the length is not equal to the number of classes.
        ValueError: When ``self.weight`` is/contains a value that is less than 0.
    """
    n_pred_ch = input.shape[1]

    if self.to_onehot_y:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            target = one_hot(target, num_classes=n_pred_ch)

    if not self.include_background:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    if target.shape != input.shape:
        raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

    # Flatten to B x N x num_voxels. `reshape` (not `view`) also accepts
    # non-contiguous layouts.
    i = input.reshape(input.size(0), input.size(1), -1)
    t = target.reshape(target.size(0), target.size(1), -1)

    # Log-probabilities and probabilities over the class dimension.
    logpt = F.log_softmax(i, dim=1)
    pt = torch.exp(logpt)  # B,N,H*W

    if self.weight is not None:
        if isinstance(self.weight, (float, int)):
            class_weight = torch.as_tensor([self.weight] * i.size(1))
        else:
            class_weight = torch.as_tensor(self.weight)
            if class_weight.ndim == 0:
                # A 0-d tensor / numpy scalar applies the same weight to every class.
                class_weight = class_weight.repeat(i.size(1))
        if class_weight.size(0) != i.size(1):
            raise ValueError(
                "the length of the weight sequence should be the same as the number of classes. "
                + "If `include_background=False`, the number should not include class 0."
            )
        if class_weight.min() < 0:
            raise ValueError("the value/values of weights should be no less than 0.")
        # Match dtype and device of the logits before building the weight map.
        class_weight = class_weight.to(i)
        # Per-class weight map broadcast over batch and voxels.
        at = class_weight[None, :, None]  # N => 1,N,1
        at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
        # Multiply the log proba by their weights.
        logpt = logpt * at

    # Focal modulation (1 - p)^gamma, then mean over voxels.
    weight = torch.pow(-pt + 1.0, self.gamma)
    loss = torch.mean(-weight * t * logpt, dim=-1)

    if self.reduction == LossReduction.SUM.value:
        return loss.sum()
    if self.reduction == LossReduction.NONE.value:
        return loss
    if self.reduction == LossReduction.MEAN.value:
        return loss.mean()
    raise ValueError(
        f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
    )
def compute_confusion_metric(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    activation: Optional[Union[str, Callable]] = None,
    bin_mode: Optional[str] = "threshold",
    bin_threshold: Union[float, Sequence[float]] = 0.5,
    metric_name: str = "hit_rate",
    average: Union[Average, str] = Average.MACRO,
    zero_division: int = 0,
) -> Union[np.ndarray, List[float], float]:
    """
    Compute confusion matrix related metrics.

    This function supports to calculate all metrics mentioned in:
    `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.
    Before calculating, an activation function and/or a binarization manipulation can be employed
    to pre-process the original inputs. Zero division is handled by replacing the result into a
    single value. Referring to:
    `sklearn.metrics <https://scikit-learn.org/stable/modules/model_evaluation.html>`_.

    Args:
        y_pred: predictions. As for classification tasks,
            `y_pred` should have the shape [B] or [BN]. As for segmentation tasks,
            the shape should be [BNHW] or [BNHWD].
        y: ground truth, the first dim is batch.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            Ignored (with a warning) when `y_pred` is 1-D.
        activation: [``"sigmoid"``, ``"softmax"``]
            Activation method, if specified, an activation function will be employed for `y_pred`.
            Defaults to None.
            The parameter can also be a callable function, for example:
            ``activation = lambda x: torch.log_softmax(x)``.
        bin_mode: [``"threshold"``, ``"mutually_exclusive"``]
            Binarization method, if specified, a binarization manipulation will be employed
            for `y_pred`.

            - ``"threshold"``, a single threshold or a sequence of thresholds should be set.
            - ``"mutually_exclusive"``, `y_pred` will be converted by a combination of `argmax` and `to_onehot`.
        bin_threshold: the threshold for binarization, can be a single value or a sequence of
            values that each one of the value represents a threshold for a class.
        metric_name: [``"sensitivity"``, ``"specificity"``, ``"precision"``, ``"negative predictive value"``,
            ``"miss rate"``, ``"fall out"``, ``"false discovery rate"``, ``"false omission rate"``,
            ``"prevalence threshold"``, ``"threat score"``, ``"accuracy"``, ``"balanced accuracy"``,
            ``"f1 score"``, ``"matthews correlation coefficient"``, ``"fowlkes mallows index"``,
            ``"informedness"``, ``"markedness"``]
            Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),
            and you can also input those names instead. Defaults to ``"hit_rate"``.
        average: [``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``]
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
              This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
              weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
              indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

        zero_division: the value to return when there is a zero division, for example, when all
            predictions and labels are negative. Defaults to 0.

    Raises:
        AssertionError: when data shapes of `y_pred` and `y` do not match.
        AssertionError: when specify activation function and ``mutually_exclusive`` mode at the same time.
    """
    # one-hot for ground truth
    if to_onehot_y:
        if y_pred.ndimension() == 1:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        else:
            y = one_hot(y, num_classes=y_pred.shape[1])

    # Explicit raises (not `assert`) so validation still runs under `python -O`;
    # the documented AssertionError type and messages are unchanged.
    if y.shape != y_pred.shape:
        raise AssertionError("data shapes of y_pred and y do not match.")

    # activation for predictions
    if activation is not None:
        if bin_mode == "mutually_exclusive":
            raise AssertionError("activation is unnecessary for mutually exclusive classes.")
        y_pred = do_activation(y_pred, activation=activation)

    # binarization for predictions
    if bin_mode is not None:
        y_pred = do_binarization(y_pred, bin_mode=bin_mode, bin_threshold=bin_threshold)

    # get confusion matrix elements
    con_list = cal_confusion_matrix_elements(y_pred, y)
    # get simplified metric name
    metric_name = check_metric_name_and_unify(metric_name)
    return do_calculate_metric(con_list, metric_name, average=average, zero_division=zero_division)
def dice_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    include_background: bool = True,
    softmax: bool = False,
    to_onehot: bool = True,
    squared_pred: bool = False,
    reduction: Union[LossReduction, str] = LossReduction.MEAN,
    smooth: float = 1e-5,
):
    """
    Dice loss function, from Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for
    Volumetric Medical Image Segmentation, 3DV, 2016.

    Args:
        input: predict tensor, the shape should be BNH[WD].
        target: target tensor, the shape should be BNH[WD].
        include_background: if False, channel 0 is removed from both `input` and `target`
            before computing the loss (ignored, with a warning, for single-channel input).
        softmax: if True, apply a softmax function to the prediction.
        to_onehot: whether to convert `target` into the one-hot format. Defaults to True.
        squared_pred: use squared versions of targets and predictions in the denominator or not.
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied (per-sample, per-class losses ``[B, N]``).
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.

        smooth: a small constant to avoid nan.

    Raises:
        AssertionError: When input and target (after one-hot conversion and background
            removal) have different shapes.
        ValueError: When ``reduction`` is not a valid ``LossReduction``.
    """
    n_pred_ch = input.shape[1]
    if softmax:
        input = torch.softmax(input, 1)

    if to_onehot:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        else:
            # torch.nn.functional.one_hot appends the class dimension last (BNH[WD]C), so the
            # project `one_hot` helper is used instead; it is assumed to return classes on dim 1.
            target = one_hot(target, num_classes=n_pred_ch)

    if not include_background:
        if n_pred_ch == 1:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
        else:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]

    if target.shape != input.shape:
        raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, dim=reduce_axis)

    if squared_pred:
        target = torch.pow(target, 2)
        input = torch.pow(input, 2)

    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(input, dim=reduce_axis)
    denominator = ground_o + pred_o

    f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

    # LossReduction(...) raises ValueError itself for unknown reduction names.
    reduction = LossReduction(reduction).value
    if reduction == LossReduction.MEAN.value:
        f = torch.mean(f)  # the batch and channel average
    elif reduction == LossReduction.SUM.value:
        f = torch.sum(f)  # sum over the batch and channel dims
    elif reduction == LossReduction.NONE.value:
        pass  # returns [B, N] losses
    else:
        raise ValueError(f'Unsupported reduction: {reduction}, available options are ["mean", "sum", "none"].')
    return f