예제 #1
0
 def _compute_metric_score(
     metric: torchmetrics.Metric,
     custom_metric_func: Callable,
     label: torch.Tensor,
     logits: Optional[torch.Tensor] = None,
     embeddings1: Optional[torch.Tensor] = None,
     embeddings2: Optional[torch.Tensor] = None,
     reverse_prob: Optional[bool] = False,
 ):
     """Feed one batch of predictions into ``metric`` via ``metric.update``.

     Two input modes are supported, selected by whether ``logits`` is given:

     * ``logits`` provided:
         - ``torchmetrics.AUROC`` receives ``compute_probability(logits=logits)``
           as ``preds`` and ``label`` as ``target``.
         - ``BaseAggregator`` receives ``custom_metric_func(logits, label)``.
         - any other metric receives ``logits.squeeze(dim=1)`` and ``label``.
     * ``logits`` is ``None``:
         - ``BaseAggregator`` receives
           ``custom_metric_func(embeddings1, embeddings2, label)``.
         - any other metric receives the output of ``compute_probability`` on
           the two embeddings (forwarding ``reverse_prob``) and ``label``.

     Args:
         metric: The metric whose state is updated.
         custom_metric_func: Callable producing the value fed to a
             ``BaseAggregator`` metric.
         label: Ground-truth targets for the batch.
         logits: Model outputs; when ``None`` the embedding path is used.
         embeddings1: First embedding of each pair (embedding path only).
         embeddings2: Second embedding of each pair (embedding path only).
         reverse_prob: Forwarded to ``compute_probability`` on the embedding path.

     Returns:
         None. Results are accumulated only through ``metric.update``.
     """
     if logits is None:
         # Embedding-pair path.
         if isinstance(metric, BaseAggregator):
             metric.update(custom_metric_func(embeddings1, embeddings2, label))
             return
         pair_prob = compute_probability(
             embeddings1=embeddings1,
             embeddings2=embeddings2,
             reverse_prob=reverse_prob,
         )
         metric.update(pair_prob, label)
         return

     # Logits path.
     if isinstance(metric, torchmetrics.AUROC):
         # Per the original author, this branch is intended for binary classification only.
         prob_from_logits = compute_probability(logits=logits)
         metric.update(preds=prob_from_logits, target=label)
     elif isinstance(metric, BaseAggregator):
         metric.update(custom_metric_func(logits, label))
     else:
         metric.update(logits.squeeze(dim=1), label)
예제 #2
0
def _assert_half_support(metric_module: Metric,
                         metric_functional: Callable,
                         preds: Tensor,
                         target: Tensor,
                         device: str = "cpu",
                         **kwargs_update):
    """
    Test if a metric can be used with half precision tensors.

    The first element (``[0]``) of ``preds``, ``target`` and of every tensor
    passed through ``kwargs_update`` is moved to ``device``. Floating-point
    elements are additionally cast with ``.half()``, and non-floating ones
    keep their dtype. Non-tensor keyword values are passed through unchanged.
    The module (after ``.to(device)``) and the functional are each called
    with the converted inputs, and their outputs are handed to
    ``_assert_tensor``.

    Args:
        metric_module: the metric module to test
        metric_functional: the metric functional to test
        preds: torch tensor with predictions
        target: torch tensor with targets
        device: determine device, either "cpu" or "cuda"
        kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.
    """

    def _first_as_half(batch):
        # Only floating tensors are cast; integer/bool tensors keep their dtype.
        sample = batch[0]
        if sample.is_floating_point():
            sample = sample.half()
        return sample.to(device)

    y_hat = _first_as_half(preds)
    y = _first_as_half(target)

    converted_kwargs = {}
    for name, value in kwargs_update.items():
        if isinstance(value, Tensor):
            converted_kwargs[name] = _first_as_half(value)
        else:
            converted_kwargs[name] = value

    metric_module = metric_module.to(device)
    _assert_tensor(metric_module(y_hat, y, **converted_kwargs))
    _assert_tensor(metric_functional(y_hat, y, **converted_kwargs))
예제 #3
0
 def _compute_metric_score(
     self,
     metric: torchmetrics.Metric,
     custom_metric_func: Callable,
     logits: torch.Tensor,
     label: torch.Tensor,
 ):
     """Feed one batch of logits into ``metric`` via ``metric.update``.

     * ``torchmetrics.AUROC`` / ``torchmetrics.AveragePrecision``: logits are
       cast to float, softmax-normalised over dim 1, and column 1 is passed
       as ``preds`` with ``label`` as ``target``.
     * ``BaseAggregator``: receives ``custom_metric_func(logits, label)``.
     * any other metric: receives ``logits.squeeze(dim=1)`` and ``label``.

     Args:
         metric: The metric whose state is updated.
         custom_metric_func: Callable producing the value fed to a
             ``BaseAggregator`` metric.
         logits: Model outputs. The ``[:, 1]`` indexing on the probability
             path assumes at least two columns on dim 1 (the original
             comment says this is for binary classification).
         label: Ground-truth targets for the batch.

     Returns:
         None. Results are accumulated only through ``metric.update``.
     """
     probability_metrics = (torchmetrics.AUROC, torchmetrics.AveragePrecision)
     if isinstance(metric, probability_metrics):
         # Column 1 is taken as the positive-class probability (binary classification only).
         positive_prob = F.softmax(logits.float(), dim=1)[:, 1]
         metric.update(preds=positive_prob, target=label)
         return

     if isinstance(metric, BaseAggregator):
         metric.update(custom_metric_func(logits, label))
         return

     metric.update(logits.squeeze(dim=1), label)
예제 #4
0
def _assert_half_support(
    metric_module: Metric,
    metric_functional: Callable,
    preds: torch.Tensor,
    target: torch.Tensor,
    device: str = "cpu",
    **kwargs_update,
):
    """
    Test if a metric can be used with half precision tensors.

    The first element (``[0]``) of ``preds`` and ``target`` is moved to
    ``device``. Floating-point elements are additionally cast with
    ``.half()``, and non-floating ones (e.g. integer labels) keep their
    dtype. The module (after ``.to(device)``) and the functional are each
    called with the converted inputs, and their outputs are handed to
    ``_assert_tensor``.

    Args:
        metric_module: the metric module to test
        metric_functional: the metric functional to test
        preds: torch tensor with predictions
        target: torch tensor with targets
        device: determine device, either "cpu" or "cuda"
        kwargs_update: Additional keyword arguments forwarded to both the
            metric module and the functional. Tensor values are converted
            exactly like ``preds``/``target`` (first element, half precision
            if floating, moved to ``device``). Other values are passed
            through unchanged. Defaults to no extra arguments, which matches
            the previous behaviour.
    """

    def _first_on_device(batch):
        # Only floating tensors are cast to half; integer/bool tensors must keep their dtype.
        sample = batch[0]
        if sample.is_floating_point():
            sample = sample.half()
        return sample.to(device)

    y_hat = _first_on_device(preds)
    y = _first_on_device(target)
    extra_kwargs = {
        key: _first_on_device(value) if isinstance(value, torch.Tensor) else value
        for key, value in kwargs_update.items()
    }
    metric_module = metric_module.to(device)
    _assert_tensor(metric_module(y_hat, y, **extra_kwargs))
    _assert_tensor(metric_functional(y_hat, y, **extra_kwargs))