Example #1
0
def dice(outputs: torch.Tensor,
         targets: torch.Tensor,
         eps: float = 1e-7,
         threshold: float = None,
         activation: str = "Sigmoid"):
    """
    Computes the Dice score averaged over every (sample, channel) pair.

    Args:
        outputs (torch.Tensor): raw model predictions. They are reduced
            over dims 2 and 3, so a 4-D tensor is presumably expected
            (e.g. batch x channels x H x W -- confirm with callers)
        targets (torch.Tensor): ground truth, multiplied element-wise
            with the activated outputs
        eps (float): smoothing term added to both the numerator and
            the denominator
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        torch.Tensor: scalar tensor with the mean Dice score
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        predictions = (predictions > threshold).float()

    def _reduce_trailing(tensor):
        # After the first reduction the former dim 3 becomes dim 2,
        # so two means over dim 2 collapse the original dims 2 and 3.
        return tensor.mean(dim=2).mean(dim=2)

    overlap = _reduce_trailing(targets * predictions)
    total = _reduce_trailing(targets) + _reduce_trailing(predictions)

    return ((2 * overlap + eps) / (total + eps)).mean()
Example #2
0
def iou(outputs: torch.Tensor,
        targets: torch.Tensor,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid"):
    """
    Computes the IoU (Jaccard) score over all elements of the tensors.

    Args:
        outputs (torch.Tensor): raw model predictions
        targets (torch.Tensor): ground truth, multiplied element-wise
            with the activated outputs
        eps (float): added to the denominator to avoid zero division
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        torch.Tensor: scalar tensor with the IoU score
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        predictions = (predictions > threshold).float()

    overlap = torch.sum(targets * predictions)
    total = torch.sum(targets) + torch.sum(predictions)

    # |A or B| = |A| + |B| - |A and B|
    return overlap / (total - overlap + eps)
Example #3
0
    def dice(self,
             outputs,
             targets,
             eps=1e-7,
             threshold=0.5,
             activation='Sigmoid'):
        """
        Computes the per-sample Dice score and returns its mean.

        Each sample is flattened to one row before the score is computed.
        A sample whose score is below ``eps`` while its target sum is
        also below ``eps`` is counted as a perfect score of 1.

        Args:
            outputs: raw model predictions
            targets: ground truth; ``len(targets)`` is used as batch size
            eps: added to the denominator and used as the "zero" cutoff
            threshold: when not None, the activated outputs are
                binarized with ``outputs > threshold``
            activation: activation name handed to ``get_activation_fn``

        Returns:
            mean of the per-sample scores, as produced by ``np.mean``
        """
        predictions = get_activation_fn(activation)(outputs)

        if threshold is not None:
            predictions = (predictions > threshold).float()

        n_samples = len(targets)
        flat_pred = predictions.view(n_samples, -1)
        flat_true = targets.view(n_samples, -1)

        overlap = torch.sum(flat_true * flat_pred, dim=1)
        total = torch.sum(flat_true, dim=1) + torch.sum(flat_pred, dim=1)
        per_sample = (2 * overlap / (total + eps)).cpu().numpy()

        scores = []
        for idx, value in enumerate(per_sample):
            # The target sum is only inspected for (near-)zero scores.
            empty_target = (
                not (value >= eps)
                and torch.sum(flat_true[idx]).cpu().numpy() < eps
            )
            scores.append(1 if empty_target else value)

        return np.mean(scores)
def dice(outputs: torch.Tensor,
         targets: torch.Tensor,
         eps: float = 1e-7,
         threshold: float = 0.5,
         activation: str = "Sigmoid"):
    """
    Computes the Dice score on channel 0 of the given tensors.

    Args:
        outputs (torch.Tensor): raw model predictions; only index 0
            along dim 1 is used
        targets (torch.Tensor): ground truth; only index 0 along dim 1
            is used
        eps (float): smoothing term for the denominator; it is also used
            in the numerator when both prediction and target are empty
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        torch.Tensor: scalar tensor with the Dice score; it equals 1
        when both the target and the prediction are empty
    """
    outputs = outputs[:, 0, ...]
    targets = targets[:, 0, ...]

    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)

    if threshold is not None:
        outputs = (outputs > threshold).float()

    intersection = torch.sum(targets * outputs)
    union = torch.sum(targets) + torch.sum(outputs)
    # The `eps * (union == 0)` term is active only when I and U are both 0,
    # which makes the score eps / eps == 1 for an empty target that was
    # predicted as empty. When U != 0 the term vanishes, so a zero
    # intersection still yields 0. The factor 2 must multiply the
    # intersection only: doubling the eps term as well would return 2
    # instead of 1 in the empty case.
    both_empty = (union == 0).float()
    dice = (2 * intersection + eps * both_empty) / (union + eps)

    return dice
Example #5
0
def accuracy(
        outputs,
        targets,
        topk=(1, ),
        threshold: float = None,
        activation: str = None
):
    """
    Computes the accuracy@k (in percent) for the specified values of k.

    Args:
        outputs: model predictions; either per-class scores of shape
            batch_size x num_classes, or a single column / 1-D tensor
            of already decided labels
        targets: ground-truth labels; ``targets.size(0)`` is used as
            the batch size
        topk: iterable with the values of k to evaluate
        threshold: when not None, the activated outputs are binarized
            with ``outputs > threshold``
        activation: activation name handed to ``get_activation_fn``

    Returns:
        list: one single-element tensor per value in ``topk`` holding
        the percentage of samples whose target is among the top-k
        predictions
    """
    max_k = max(topk)
    batch_size = targets.size(0)

    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)

    # `is not None` so that an explicit threshold of 0.0 is honoured
    if threshold is not None:
        outputs = (outputs > threshold).long()

    if len(outputs.shape) == 1 or outputs.shape[1] == 1:
        # A single row of predictions; reshape also covers 1-D outputs,
        # for which `.t()` would leave a shape `expand_as` cannot match.
        pred = outputs.reshape(1, -1)
    else:
        _, pred = outputs.topk(max_k, 1, True, True)
        pred = pred.t()
    correct = pred.eq(targets.long().view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` may be non-contiguous (it derives from a transposed
        # tensor), in which case `view(-1)` raises; `reshape` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
Example #6
0
    def on_batch_end(self, state: RunnerState):
        """
        Feeds the activated outputs of the current batch to the meters.

        Outputs and targets are read from ``state`` by ``self.output_key``
        and ``self.input_key``, detached and cast to float. Column ``i``
        of both tensors is then added to ``self.meters[i]``.
        """
        raw_logits = state.output[self.output_key].detach().float()
        raw_targets = state.input[self.input_key].detach().float()
        scores = get_activation_fn(self.activation)(raw_logits)

        for class_idx in range(self.num_classes):
            class_meter = self.meters[class_idx]
            class_meter.add(scores[:, class_idx], raw_targets[:, class_idx])
Example #7
0
def accuracy(
        outputs,
        targets,
        topk=(1, ),
        threshold: float = None,
        activation: str = None,
):
    """
    Computes the accuracy (as a fraction in [0, 1]).

    It can be used either for:

    1. Multi-class task, in this case:

      - you can use topk.
      - threshold and activation are not required.
      - targets is a tensor: batch_size
      - outputs is a tensor: batch_size x num_classes
      - computes the accuracy@k for the specified values of k.

    2. Multi-label task, in this case:

      - you must specify threshold and activation
      - topk will not be used
        (because there is no method to apply top-k in
        multi-label classification).
      - outputs, targets are tensors with shape: batch_size x num_classes
      - targets is a tensor with binary vectors

    Returns:
        list: for the multi-label case a single-element list with the
        fraction of matching entries; otherwise one single-element
        tensor per value in ``topk``
    """
    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)

    # `is not None` so that an explicit threshold of 0.0 is honoured
    if threshold is not None:
        outputs = (outputs > threshold).long()

    # multi-label classification
    if len(targets.shape) > 1 and targets.size(1) > 1:
        res = (targets.long() == outputs.long()).sum().float() / np.prod(
            targets.shape)
        return [res]

    max_k = max(topk)
    batch_size = targets.size(0)

    if len(outputs.shape) == 1 or outputs.shape[1] == 1:
        # A single row of predictions; reshape also covers 1-D outputs,
        # for which `.t()` would leave a shape `expand_as` cannot match.
        pred = outputs.reshape(1, -1)
    else:
        _, pred = outputs.topk(max_k, 1, True, True)
        pred = pred.t()
    correct = pred.eq(targets.long().view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` may be non-contiguous (it derives from a transposed
        # tensor), in which case `view(-1)` raises; `reshape` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res
Example #8
0
def dice_apex(outputs,
              targets,
              eps: float = 1e-7,
              activation: str = "Sigmoid"):
    """
    Computes the mean per-sample Dice score on binarized tensors.

    Both the activated outputs and the targets are binarized at 0.5.
    A sample with an empty target scores 1 if the prediction is empty
    as well and 0 otherwise; every other sample gets the regular Dice
    score ``2 * |P * T| / (|P| + |T|)``.

    Args:
        outputs: raw model predictions
        targets: ground truth; ``len(targets)`` is used as batch size
        eps (float): unused; kept so existing callers keep working
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        float: mean Dice score over the batch
    """
    activation_fn = get_activation_fn(activation)
    targets = targets.float()

    outputs = activation_fn(outputs)

    batch_size = len(targets)

    with torch.no_grad():
        outputs = outputs.view(batch_size, -1)
        targets = targets.view(batch_size, -1)
        assert (outputs.shape == targets.shape)

        p = (outputs > 0.5).float()
        t = (targets > 0.5).float()

        t_sum = t.sum(-1)
        p_sum = p.sum(-1)
        # Samples are split by whether their binarized target is empty.
        neg_index = torch.nonzero(t_sum == 0)
        pos_index = torch.nonzero(t_sum >= 1)

        dice_neg = (p_sum == 0).float()
        dice_pos = 2 * (p * t).sum(-1) / ((p + t).sum(-1))

        # Only the entries that belong to each group are kept; the
        # remaining values (possibly NaN from 0 / 0) are discarded.
        dice = torch.cat([dice_pos[pos_index], dice_neg[neg_index]])
        dice = dice.mean().item()

    return dice
Example #9
0
    def accuracy(outputs: torch.Tensor,
                 targets: torch.Tensor,
                 eps: float = 1e-7,
                 threshold: float = None,
                 activation: str = "Sigmoid"):
        """
        Computes the fraction of elements where outputs equal targets.

        Both tensors are cast to ``long`` before the comparison, so
        non-binarized outputs are truncated to integers.

        Args:
            outputs (torch.Tensor): raw model predictions
            targets (torch.Tensor): ground truth
            eps (float): added to the element count in the denominator
            threshold (float): when not None, the activated outputs are
                binarized with ``outputs > threshold``
            activation (str): activation name handed to
                ``get_activation_fn``

        Returns:
            torch.Tensor: scalar tensor with the accuracy
        """
        predictions = get_activation_fn(activation)(outputs)

        if threshold is not None:
            predictions = (predictions > threshold).float()

        matches = predictions.long() == targets.long()
        n_correct = matches.long().float().sum()
        n_total = targets.view(-1).size(0)

        return n_correct / (n_total + eps)
Example #10
0
    def forward(self, logits, targets):
        """
        Blends ``self.ce_loss`` with a class-averaged Dice loss.

        The Dice part is the mean over classes of ``1 - dice`` between
        column ``cls`` of the activated logits and the mask
        ``targets == cls``. The two parts are mixed with
        ``self.dice_weight``.
        """
        activated = get_activation_fn(self.activation)(logits)
        cross_entropy = self.ce_loss(logits, targets)

        dice_term = 0
        for class_idx in range(self.num_classes):
            class_mask = (targets == class_idx).float()
            class_scores = activated[:, class_idx]
            class_dice = criterion.dice(
                class_scores,
                class_mask,
                eps=1e-7,
                activation='none',
                threshold=None,
            )
            dice_term += (1 - class_dice) / self.num_classes

        ce_weight = 1 - self.dice_weight
        return ce_weight * cross_entropy + self.dice_weight * dice_term
Example #11
0
def f1_score(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    beta: float = 1.0,
    eps: float = 1e-7,
    threshold: float = None,
    activation: str = "Sigmoid",
):
    """
    Computes the F-beta score over all elements (F_1 for ``beta=1``).

    Args:
        outputs (torch.Tensor): raw model predictions
        targets (torch.Tensor): ground truth
        beta (float): beta parameter of the F-score
        eps (float): added to numerator and denominator to avoid
            zero division
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        torch.Tensor: scalar tensor with the score

    Main origins of inspiration:
        https://github.com/qubvel/segmentation_models.pytorch
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        predictions = (predictions > threshold).float()

    tp = torch.sum(targets * predictions)
    fp = torch.sum(predictions) - tp
    fn = torch.sum(targets) - tp

    # F_beta = (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)
    denominator = (1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps
    numerator = (1 + beta ** 2) * tp + eps

    return numerator / denominator
Example #12
0
def iou(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    # values are discarded, only None check
    # used for compatibility with MultiMetricCallback
    classes: List[str] = None,
    eps: float = 1e-7,
    threshold: float = None,
    activation: str = "Sigmoid",
) -> Union[float, List[float]]:
    """
    Computes the IoU (Jaccard) score, optionally per channel.

    Args:
        outputs (torch.Tensor): raw model predictions
        targets (torch.Tensor): ground truth
        classes (List[str]): only checked against None; when given,
            the sums run over dims 0, 2 and 3 so that one score per
            channel is returned
        eps (float): epsilon to avoid zero division
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        Union[float, List[float]]: IoU (Jaccard) score(s)
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        predictions = (predictions > threshold).float()

    if classes is None:
        # no class list: a single score over every element
        reduce = torch.sum
    else:
        # class list given: keep the channel dim, reduce the rest
        reduce = partial(torch.sum, dim=[0, 2, 3])

    overlap = reduce(targets * predictions)
    total = reduce(targets) + reduce(predictions)

    # The guard is non-zero only where both target and prediction are
    # empty; there the score becomes eps / eps == 1. Everywhere else it
    # vanishes, so a zero overlap still scores 0.
    empty_guard = eps * (total == 0)

    return (overlap + empty_guard) / (total - overlap + eps)
Example #13
0
def soft_dice(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,
    activation: str = "Sigmoid",
    weight=(0.2, 0.8)
):
    """
    Computes a weighted soft Dice loss, averaged over the batch.

    Activated outputs and targets are flattened per sample, mapped
    with ``x * 2 - 1`` and scaled by a per-element weight that is
    ``weight[0]`` where the target is 0 and ``weight[1]`` where it
    is 1. The loss of one sample is
    ``1 - 2 * sum(p * t) / (sum(p * p) + sum(t * t))``.

    Args:
        outputs (torch.Tensor): raw model predictions; ``len(outputs)``
            is used as batch size
        targets (torch.Tensor): ground truth, cast to float
        eps (float): unused; kept so existing callers keep working
        threshold (float): unused; kept so existing callers keep working
        activation (str): activation name handed to ``get_activation_fn``
        weight: two-element sequence ``(w_for_target_0, w_for_target_1)``;
            the default is a tuple so it cannot be mutated between calls

    Returns:
        torch.Tensor: scalar tensor with the mean loss (lower is better)
    """
    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)
    targets = targets.float()

    batch_size = len(outputs)
    p = outputs.view(batch_size, -1)
    t = targets.view(batch_size, -1)

    # Linear blend between the two weights, driven by the target value;
    # detached so that the weights carry no gradient.
    w = t.detach() * (weight[1] - weight[0]) + weight[0]

    p = w * (p * 2 - 1)
    t = w * (t * 2 - 1)

    intersection = (p * t).sum(-1)
    union = (p * p).sum(-1) + (t * t).sum(-1)
    loss = 1 - 2 * intersection / union

    return loss.mean()
Example #14
0
    def on_batch_end(self, state: RunnerState):
        """
        Adds the class-averaged Dice score of the batch to the metrics.

        The activated outputs are reduced to the index of their maximum
        along dim 1. For every class the mask ``prediction == cls`` is
        scored against ``targets == cls`` with ``criterion.dice``, and
        the mean over ``self.num_classes`` is stored under
        ``self.prefix``.
        """
        raw_outputs = state.output[self.output_key]
        targets = state.input[self.input_key]

        activated = get_activation_fn(self.activation)(raw_outputs)
        predicted = activated.max(dim=1)[1]

        mean_dice = 0
        for class_idx in range(self.num_classes):
            target_mask = (targets == class_idx).float()
            predicted_mask = (predicted == class_idx).float()
            class_dice = criterion.dice(
                predicted_mask,
                target_mask,
                eps=1e-7,
                activation='none',
                threshold=None,
            )
            mean_dice += class_dice / self.num_classes

        state.metrics.add_batch_value(name=self.prefix, value=mean_dice)
Example #15
0
 def __init__(
     self,
     metric_names: List[str],
     meter_list: List,
     input_key: str = "targets",
     output_key: str = "logits",
     class_names: List[str] = None,
     num_classes: int = 2,
     activation: str = "Sigmoid",
 ):
     """
     Stores the callback configuration and resolves the activation.

     Every argument is kept on the instance under the same name, except
     ``meter_list`` which is stored as ``self.meters``. The activation
     name is additionally turned into a callable with
     ``get_activation_fn`` and stored as ``self.activation_fn``.
     """
     super().__init__(CallbackOrder.Metric)

     # where to read targets / predictions from
     self.input_key = input_key
     self.output_key = output_key

     # what is measured and how it is reported
     self.meters = meter_list
     self.metric_names = metric_names
     self.num_classes = num_classes
     self.class_names = class_names

     # activation applied to the outputs
     self.activation = activation
     self.activation_fn = get_activation_fn(activation)
Example #16
0
    def iou(self,
            outputs: torch.Tensor,
            targets: torch.Tensor,
            eps: float = 1e-7,
            threshold: float = None,
            activation: str = "Sigmoid"):
        """
        Computes the per-sample IoU (Jaccard) score and returns its mean.

        Each sample is flattened to one row before the score is computed.
        A sample whose score is below ``eps`` while its target sum is
        also below ``eps`` is counted as a perfect score of 1.

        Args:
            outputs (torch.Tensor): raw model predictions
            targets (torch.Tensor): ground truth; ``len(targets)`` is
                used as batch size
            eps (float): added to the denominator and used as the
                "zero" cutoff
            threshold (float): when not None, the activated outputs are
                binarized with ``outputs > threshold``
            activation (str): activation name handed to
                ``get_activation_fn``

        Returns:
            mean of the per-sample scores, as produced by ``np.mean``
        """
        predictions = get_activation_fn(activation)(outputs)

        if threshold is not None:
            predictions = (predictions > threshold).float()

        n_samples = len(targets)
        flat_pred = predictions.view(n_samples, -1)
        flat_true = targets.view(n_samples, -1)

        overlap = torch.sum(flat_true * flat_pred, dim=1)
        total = torch.sum(flat_true, dim=1) + torch.sum(flat_pred, dim=1)
        per_sample = (overlap / (total - overlap + eps)).cpu().numpy()

        scores = []
        for idx, value in enumerate(per_sample):
            # The target sum is only inspected for (near-)zero scores.
            empty_target = (
                not (value >= eps)
                and torch.sum(flat_true[idx]).cpu().numpy() < eps
            )
            scores.append(1 if empty_target else value)

        return np.mean(scores)
    def on_batch_end(self, state):
        """
        Writes per-class and class-averaged Dice scores to the metrics.

        The activated outputs are reduced to the index of their maximum
        along dim 1. Classes start at 0, or at 1 when ``self.include_bg``
        is falsy, and the average is taken over the classes that were
        visited. Per-class values are written only when
        ``self.class_names`` is set; the average goes under
        ``self.prefix``.
        """
        raw_outputs = state.batch_out[self.output_key]
        targets = state.batch_in[self.input_key]

        activated = get_activation_fn(self.activation)(raw_outputs)
        predicted = activated.max(dim=1)[1]

        if self.include_bg:
            first_cls, n_averaged = 0, self.num_classes
        else:
            first_cls, n_averaged = 1, self.num_classes - 1

        mean_dice = 0
        for cls in range(first_cls, self.num_classes):
            target_mask = (targets == cls).float()
            predicted_mask = (predicted == cls).float()
            score = _dice(
                predicted_mask,
                target_mask,
                eps=1e-7,
                activation='none',
                threshold=None,
            )
            if self.class_names is not None:
                state.batch_metrics[f"{self.prefix}_{self.class_names[cls]}"] = score
            mean_dice += score / n_averaged

        state.batch_metrics[self.prefix] = mean_dice
Example #18
0
def dice_wo_back(outputs: torch.Tensor,
                 targets: torch.Tensor,
                 eps: float = 1e-7,
                 threshold: float = None,
                 activation: str = "Sigmoid"):
    """
    Computes the Dice score with the last channel left out.

    The last index along dim 1 (treated as the background layer) is
    dropped from both tensors, and one score is computed over all
    remaining elements.

    Args:
        outputs (torch.Tensor): raw model predictions
        targets (torch.Tensor): ground truth
        eps (float): added to the denominator to avoid zero division
        threshold (float): when not None, the activated outputs are
            binarized with ``outputs > threshold``
        activation (str): activation name handed to ``get_activation_fn``

    Returns:
        torch.Tensor: scalar tensor with the Dice score
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        predictions = (predictions > threshold).float()

    # keep every channel except the last one (background)
    kept_pred = predictions[:, :-1, ...]
    kept_true = targets[:, :-1, ...]

    overlap = torch.sum(kept_true * kept_pred)
    total = torch.sum(kept_true) + torch.sum(kept_pred)

    return (2 * overlap) / (total + eps)
Example #19
0
 def __init__(
     self,
     metric_names: List[str],
     meter_list: List,
     input_key: str = "targets",
     output_key: str = "logits",
     class_names: List[str] = None,
     num_classes: int = 2,
     activation: str = "Sigmoid",
 ):
     """
     Args:
         metric_names (List[str]): names of the metrics to print.
             They must follow the order in which the meters in
             `meter_list` output their metrics
         meter_list (list-like): meters.meter.Meter instances, one per
             class, i.e. len(meter_list) == num_classes
         input_key (str): key of the input used for the metric
             calculation; specifies our ``y_true``
         output_key (str): key of the output used for the metric
             calculation; specifies our ``y_pred``
         class_names (List[str]): class names to display in the logs.
             If None, defaults to indices for each class, starting from 0.
         num_classes (int): number of classes; must be > 1
         activation (str): a torch.nn activation applied to the logits.
             Must be one of ['none', 'Sigmoid', 'Softmax2d']
     """
     super().__init__(CallbackOrder.Metric)

     # where to read targets / predictions from
     self.input_key = input_key
     self.output_key = output_key

     # what is measured and how it is reported
     self.meters = meter_list
     self.metric_names = metric_names
     self.num_classes = num_classes
     self.class_names = class_names

     # activation name and the callable resolved from it
     self.activation = activation
     self.activation_fn = get_activation_fn(activation)