def _compute_metric_score(
    metric: torchmetrics.Metric,
    custom_metric_func: Callable,
    label: torch.Tensor,
    logits: Optional[torch.Tensor] = None,
    embeddings1: Optional[torch.Tensor] = None,
    embeddings2: Optional[torch.Tensor] = None,
    reverse_prob: Optional[bool] = False,
):
    """
    Update ``metric`` in place with one batch of predictions.

    Two input modes are supported. When ``logits`` is not None the metric is
    fed from the logits; otherwise it is fed from the pair
    ``embeddings1`` / ``embeddings2``.

    Args:
        metric: the metric object whose ``update`` is called.
        custom_metric_func: callable whose result is passed to ``update`` when
            ``metric`` is a ``BaseAggregator``. It is called as
            ``(logits, label)`` in logits mode and as
            ``(embeddings1, embeddings2, label)`` in embedding mode.
        label: ground-truth targets for the batch.
        logits: model logits; selects logits mode when not None.
        embeddings1: first embedding tensor (embedding mode only).
        embeddings2: second embedding tensor (embedding mode only).
        reverse_prob: forwarded to ``compute_probability`` in embedding mode.
    """
    if logits is None:
        # Embedding mode: aggregators take the custom value, every other
        # metric takes the probability computed from the two embeddings.
        if isinstance(metric, BaseAggregator):
            metric.update(custom_metric_func(embeddings1, embeddings2, label))
            return
        pair_prob = compute_probability(
            embeddings1=embeddings1,
            embeddings2=embeddings2,
            reverse_prob=reverse_prob,
        )
        metric.update(pair_prob, label)
        return

    # Logits mode.
    if isinstance(metric, torchmetrics.AUROC):
        # only for binary classification
        metric.update(preds=compute_probability(logits=logits), target=label)
    elif isinstance(metric, BaseAggregator):
        metric.update(custom_metric_func(logits, label))
    else:
        # Drop dim 1 before updating; assumes logits has a size-1 dim there
        # for this branch — TODO confirm against callers.
        metric.update(logits.squeeze(dim=1), label)
def _assert_half_support(
    metric_module: Metric,
    metric_functional: Callable,
    preds: Tensor,
    target: Tensor,
    device: str = "cpu",
    **kwargs_update,
):
    """
    Test if a metric can be used with half precision tensors.

    The first element (index 0) of ``preds``, ``target`` and of every tensor
    in ``kwargs_update`` is taken. It is cast to half precision if it is
    floating point, then moved to ``device``. Both the module and the
    functional form of the metric are then run on these inputs, and each
    result is checked with ``_assert_tensor``.

    Args:
        metric_module: the metric module to test
        metric_functional: the metric functional to test
        preds: torch tensor with predictions
        target: torch tensor with targets
        device: determine device, either "cpu" or "cuda"
        kwargs_update: Additional keyword arguments that will be passed with
            preds and target when running update on the metric. Non-tensor
            values are forwarded unchanged.
    """

    def _first_batch_on_device(batches):
        # Take element 0; only floating-point data is converted to half, so
        # integer targets and indices keep their dtype.
        first = batches[0]
        if first.is_floating_point():
            first = first.half()
        return first.to(device)

    y_hat = _first_batch_on_device(preds)
    y = _first_batch_on_device(target)

    half_kwargs = {}
    for name, value in kwargs_update.items():
        if isinstance(value, Tensor):
            value = _first_batch_on_device(value)
        half_kwargs[name] = value

    metric_module = metric_module.to(device)
    _assert_tensor(metric_module(y_hat, y, **half_kwargs))
    _assert_tensor(metric_functional(y_hat, y, **half_kwargs))
def _compute_metric_score(
    self,
    metric: torchmetrics.Metric,
    custom_metric_func: Callable,
    logits: torch.Tensor,
    label: torch.Tensor,
):
    """
    Update ``metric`` in place with one batch of logits and labels.

    The metric type decides what is passed to ``update``:

    * ``AUROC`` / ``AveragePrecision``: the softmax over dim 1 of the logits
      (computed in float32), keeping only column index 1.
    * ``BaseAggregator``: ``custom_metric_func(logits, label)``.
    * anything else: ``logits`` with dim 1 squeezed, together with ``label``.

    Args:
        metric: the metric object whose ``update`` is called.
        custom_metric_func: callable producing the value fed to aggregators.
        logits: model logits; dim 1 is treated as the class dimension.
        label: ground-truth targets for the batch.
    """
    probability_metrics = (torchmetrics.AUROC, torchmetrics.AveragePrecision)

    if isinstance(metric, probability_metrics):
        # only for binary classification: keep the class-index-1 column.
        # Cast to float first so softmax does not run in reduced precision.
        class1_prob = F.softmax(logits.float(), dim=1)[:, 1]
        metric.update(preds=class1_prob, target=label)
        return

    if isinstance(metric, BaseAggregator):
        metric.update(custom_metric_func(logits, label))
        return

    metric.update(logits.squeeze(dim=1), label)
def _assert_half_support(
    metric_module: Metric,
    metric_functional: Callable,
    preds: torch.Tensor,
    target: torch.Tensor,
    device: str = "cpu",
    **kwargs_update,
):
    """
    Test if a metric can be used with half precision tensors.

    The first element (index 0) of ``preds`` and ``target`` is taken. It is
    cast to half precision if it is floating point, then moved to
    ``device``. Both the module and the functional form of the metric are
    then called on these inputs, and each result is checked with
    ``_assert_tensor``.

    Args:
        metric_module: the metric module to test
        metric_functional: the metric functional to test
        preds: torch tensor with predictions
        target: torch tensor with targets
        device: determine device, either "cpu" or "cuda"
        kwargs_update: Additional keyword arguments that will be passed with
            preds and target when running update on the metric. Tensor
            values get the same treatment as preds/target (element 0, half
            if floating point, moved to ``device``). Other values are
            forwarded unchanged. Omitting them gives the previous behavior.
    """
    # Only floating-point data is converted to half; integer targets keep
    # their dtype.
    y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
    y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
    kwargs_update = {
        k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, torch.Tensor) else v
        for k, v in kwargs_update.items()
    }
    metric_module = metric_module.to(device)
    _assert_tensor(metric_module(y_hat, y, **kwargs_update))
    _assert_tensor(metric_functional(y_hat, y, **kwargs_update))