Ejemplo n.º 1
0
Archivo: gsn.py Proyecto: apple/ml-gsn
    def validation_epoch_end(self, outputs):
        """Compute FID over the validation features of all processes and log it.

        Each process stores only the features of its own shard of the validation
        data, so the real and fake features are gathered from every rank before FID
        is computed on rank 0. The result is broadcast so that every rank logs the
        same tensor under ``metrics/fid``. When torch.distributed is not initialized
        the local features are used directly.

        Args:
            outputs: per-step validation outputs; unused, the features are read from
                ``self.fid`` instead.
        """
        distributed = dist.is_available() and dist.is_initialized()

        # each process stores features separately, so gather them together to calculate FID over the full distribution
        real_features = self._all_gather_variable_size(
            dim_zero_cat(self.fid.real_features), distributed)
        fake_features = self._all_gather_variable_size(
            dim_zero_cat(self.fid.fake_features), distributed)

        rank = dist.get_rank() if distributed else 0
        if rank == 0:
            # calculate_fid returns a numpy array; convert it to a plain float
            fid_value = float(calculate_fid(real_features, fake_features))
            # but we need a torch tensor for DDP
            fid = torch.tensor([fid_value], device=self.device)
            print('')
            print('FID with {} samples: {}'.format(len(real_features), fid_value))
            print('')
        else:
            fid = torch.tensor([0.0], device=self.device)

        # share the result with all GPUs so that the checkpointing function doesn't crash
        if distributed:
            dist.broadcast(tensor=fid, src=0)
        self.log('metrics/fid', fid, rank_zero_only=True)

    @staticmethod
    def _all_gather_variable_size(features, distributed=True):
        """Concatenate ``features`` from all ranks along dim 0.

        ``dist.all_gather`` requires identically shaped tensors on every rank. The
        per-rank row counts are therefore exchanged first, each shard is zero-padded
        to the largest count, and the padding rows are dropped after gathering.

        Args:
            features: local tensor whose first dimension indexes samples.
            distributed: when False, ``features`` is returned unchanged.

        Returns:
            Tensor holding the rows of every rank, in rank order.
        """
        if not distributed:
            return features
        world_size = dist.get_world_size()

        local_count = torch.tensor([features.shape[0]], device=features.device)
        counts = [torch.zeros_like(local_count) for _ in range(world_size)]
        dist.all_gather(counts, local_count)
        counts = [int(count.item()) for count in counts]

        padding = max(counts) - features.shape[0]
        if padding > 0:
            pad_rows = features.new_zeros((padding,) + tuple(features.shape[1:]))
            features = torch.cat([features, pad_rows], dim=0)

        gathered = [torch.empty_like(features) for _ in range(world_size)]
        dist.all_gather(gathered, features)
        # strip each shard's padding before concatenating
        return dim_zero_cat([shard[:count] for shard, count in zip(gathered, counts)])
Ejemplo n.º 2
0
    def compute(
        self
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor],
                                                    List[Tensor]]]:
        """Return the precision-recall curve over every accumulated prediction.

        Returns:
            A ``(precision, recall, thresholds)`` tuple where

            precision:
                element ``i`` is the precision of predictions with
                ``score >= thresholds[i]``; the final element is 1.
                For multiclass inputs, a list of such tensors (one per class).
            recall:
                element ``i`` is the recall of predictions with
                ``score >= thresholds[i]``; the final element is 0.
                For multiclass inputs, a list of such tensors (one per class).
            thresholds:
                the thresholds used for the precision/recall values.
        """
        all_preds, all_target = dim_zero_cat(self.preds), dim_zero_cat(self.target)
        return _precision_recall_curve_compute(
            all_preds, all_target, self.num_classes, self.pos_label
        )
Ejemplo n.º 3
0
 def compute(self):
     """Return the Pearson correlation coefficient of the accumulated predictions and targets."""
     return _pearson_corrcoef_compute(dim_zero_cat(self.preds), dim_zero_cat(self.target))
Ejemplo n.º 4
0
 def compute(self):
     """Return Spearman's correlation coefficient over the accumulated state."""
     all_preds, all_target = dim_zero_cat(self.preds), dim_zero_cat(self.target)
     return _spearman_corrcoef_compute(all_preds, all_target)
Ejemplo n.º 5
0
 def compute(self) -> Tensor:
     """Compute SSIM over the accumulated predictions and targets.

     The concatenated state is passed to ``_ssim_compute`` together with the
     configured ``kernel_size``, ``sigma``, ``reduction``, ``data_range``, ``k1``
     and ``k2``.
     """
     preds = dim_zero_cat(self.preds)
     target = dim_zero_cat(self.target)
     return _ssim_compute(
         preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2
     )
Ejemplo n.º 6
0
    def compute(
        self
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor],
                                                    List[Tensor]]]:
        """Compute the precision-recall curve.

        Returns:
            3-element tuple containing

            precision:
                tensor where element ``i`` is the precision of predictions with
                ``score >= thresholds[i]`` and the last element is 1.
                If multiclass, this is a list of such tensors, one for each class.
            recall:
                tensor where element ``i`` is the recall of predictions with
                ``score >= thresholds[i]`` and the last element is 0.
                If multiclass, this is a list of such tensors, one for each class.
            thresholds:
                Thresholds used for computing precision/recall scores

        Raises:
            ValueError: if ``num_classes`` is falsy (e.g. ``None`` or ``0``).
        """
        # validate the configuration before concatenating the (possibly large) state
        if not self.num_classes:
            raise ValueError(
                f"`num_classes` has to be a positive number, but got {self.num_classes}"
            )
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        return _precision_recall_curve_compute(preds, target, self.num_classes,
                                               self.pos_label)
Ejemplo n.º 7
0
 def compute(self) -> Tensor:
     """Return the AUC of all ``(x, y)`` points accumulated by ``update``.

     ``self.reorder`` is forwarded to ``_auc_compute`` as ``reorder``.
     """
     xs, ys = dim_zero_cat(self.x), dim_zero_cat(self.y)
     return _auc_compute(xs, ys, reorder=self.reorder)
Ejemplo n.º 8
0
    def compute(self) -> Tuple[Tensor, Tensor]:
        """Return the mean and (population) standard deviation of KID over random subsets.

        For each of ``self.subsets`` rounds, ``self.subset_size`` rows are drawn
        without replacement from the real and from the fake features, and
        ``poly_mmd`` is evaluated on the pair with the configured ``degree``,
        ``gamma`` and ``coef``.

        Implementation inspired by `Fid Score`_

        Raises:
            ValueError: if either feature set has fewer rows than ``self.subset_size``.
        """
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)
        n_real, n_fake = real_features.shape[0], fake_features.shape[0]

        for n_available in (n_real, n_fake):
            if n_available < self.subset_size:
                raise ValueError("Argument `subset_size` should be smaller than the number of samples")

        def _random_subset(features, n_available):
            # a fresh permutation per draw, truncated to the subset size
            return features[torch.randperm(n_available)[: self.subset_size]]

        scores = torch.stack([
            poly_mmd(
                _random_subset(real_features, n_real),
                _random_subset(fake_features, n_fake),
                self.degree,
                self.gamma,
                self.coef,
            )
            for _ in range(self.subsets)
        ])
        return scores.mean(), scores.std(unbiased=False)
Ejemplo n.º 9
0
    def compute(self) -> Tensor:
        """Return the calibration error over all collected confidences and accuracies.

        The concatenated state is binned with ``self.bin_boundaries`` and
        aggregated with ``self.norm`` inside ``_ce_compute``.

        Returns:
            Tensor: calibration error across the previously collected examples.
        """
        return _ce_compute(
            dim_zero_cat(self.confidences),
            dim_zero_cat(self.accuracies),
            self.bin_boundaries,
            norm=self.norm,
        )
Ejemplo n.º 10
0
    def compute(self) -> Union[Tensor, List[Tensor]]:
        """Compute the average precision score.

        Returns:
            tensor with average precision. If multiclass will return list
            of such tensors, one for each class

        Raises:
            ValueError: if ``num_classes`` is falsy (e.g. ``None`` or ``0``).
        """
        # validate the configuration before concatenating the (possibly large) state
        if not self.num_classes:
            raise ValueError(f"`num_classes` has to be a positive number, but got {self.num_classes}")
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average)
Ejemplo n.º 11
0
    def compute(self) -> Union[Tensor, List[Tensor]]:
        """Return the average precision over all accumulated predictions.

        Returns:
            A tensor with the average precision or, for multiclass inputs, a list
            holding one such tensor per class.
        """
        cat_preds, cat_target = dim_zero_cat(self.preds), dim_zero_cat(self.target)
        return _average_precision_compute(
            cat_preds, cat_target, self.num_classes, self.pos_label
        )
Ejemplo n.º 12
0
 def compute(self) -> Tensor:
     """Return AUROC for every input passed to ``update`` so far.

     Raises:
         RuntimeError: if ``self.mode`` has not been determined yet.
     """
     if self.mode:
         return _auroc_compute(
             dim_zero_cat(self.preds),
             dim_zero_cat(self.target),
             self.mode, self.num_classes, self.pos_label, self.average, self.max_fpr,
         )
     raise RuntimeError("You have to have determined mode.")
Ejemplo n.º 13
0
 def compute(self) -> Tensor:
     """Return AUROC computed from all inputs accumulated by ``update``.

     ``mode``, ``num_classes``, ``pos_label``, ``average`` and ``max_fpr`` are
     forwarded positionally to ``_auroc_compute``.
     """
     cat_preds = dim_zero_cat(self.preds)
     cat_target = dim_zero_cat(self.target)
     return _auroc_compute(
         cat_preds, cat_target,
         self.mode, self.num_classes, self.pos_label, self.average, self.max_fpr,
     )
Ejemplo n.º 14
0
    def _sync_dist(self,
                   dist_sync_fn: Callable = gather_all_tensors,
                   process_group: Optional[Any] = None) -> None:
        """Gather every registered metric state from all processes and reduce it in place.

        Args:
            dist_sync_fn: function applied to every tensor found in the states. Its
                result is stacked or flattened below, so it is expected to return a
                list (presumably one entry per process).
            process_group: group to synchronize over; falls back to
                ``self.process_group`` when not given.

        Raises:
            TypeError: if a registered reduction function is neither callable nor ``None``.
        """
        input_dict = {attr: getattr(self, attr) for attr in self._reductions}

        for attr, reduction_fn in self._reductions.items():
            # pre-concatenate metric states that are lists to reduce number of all_gather operations
            if reduction_fn == dim_zero_cat and isinstance(
                    input_dict[attr], list) and len(input_dict[attr]) > 1:
                input_dict[attr] = [dim_zero_cat(input_dict[attr])]

        output_dict = apply_to_collection(
            input_dict,
            Tensor,
            dist_sync_fn,
            group=process_group or self.process_group,
        )

        for attr, reduction_fn in self._reductions.items():
            if not (callable(reduction_fn) or reduction_fn is None):
                raise TypeError('reduction_fn must be callable or None')

            gathered = output_dict[attr]
            # an empty list state gathers to an empty list: there is nothing to
            # stack, flatten or reduce, and ``gathered[0]`` would raise IndexError
            if isinstance(gathered, list) and not gathered:
                setattr(self, attr, [])
                continue

            # pre-processing ops (stack or flatten for inputs)
            if isinstance(gathered[0], Tensor):
                gathered = torch.stack(gathered)
            elif isinstance(gathered[0], list):
                gathered = _flatten(gathered)

            reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
            setattr(self, attr, reduced)
Ejemplo n.º 15
0
    def compute(self) -> Tuple[Tensor, Tensor]:
        """Return the mean and standard deviation of a per-split exp-KL score.

        The accumulated features are randomly shuffled along dim 0 and normalized
        with a softmax over dim 1. They are then cut into ``self.splits`` chunks
        along dim 0. For each chunk the score is
        ``exp(mean_over_rows(KL(p || p_mean)))``, where ``p_mean`` is the chunk's
        mean distribution (presumably the Inception Score; confirm).
        """
        logits = dim_zero_cat(self.features)
        # shuffle so that each split is a random subset of the samples
        logits = logits[torch.randperm(logits.shape[0])]

        probs = logits.softmax(dim=1)
        log_probs = logits.log_softmax(dim=1)

        split_scores = []
        for p, log_p in zip(probs.chunk(self.splits, dim=0),
                            log_probs.chunk(self.splits, dim=0)):
            marginal = p.mean(dim=0, keepdim=True)
            kl_div = p * (log_p - marginal.log())
            split_scores.append(kl_div.sum(dim=1).mean().exp())

        scores = torch.stack(split_scores)
        return scores.mean(), scores.std()
Ejemplo n.º 16
0
    def _sync_dist(self, dist_sync_fn=gather_all_tensors):
        """Gather every registered metric state from all processes and reduce it in place.

        Args:
            dist_sync_fn: function applied to every tensor found in the states. Its
                result is stacked or flattened below, so it is expected to return a
                list (presumably one entry per process). It is called with
                ``group=self.process_group``.

        Raises:
            TypeError: if a registered reduction function is neither callable nor
                ``None``. This was previously an ``assert``, which is stripped when
                Python runs with ``-O``.
        """
        input_dict = {
            attr: getattr(self, attr)
            for attr in self._reductions.keys()
        }
        for attr, reduction_fn in self._reductions.items():
            # pre-concatenate metric states that are lists to reduce number of all_gather operations
            if reduction_fn == dim_zero_cat and isinstance(
                    input_dict[attr], list) and len(input_dict[attr]) > 1:
                input_dict[attr] = [dim_zero_cat(input_dict[attr])]
        output_dict = apply_to_collection(
            input_dict,
            Tensor,
            dist_sync_fn,
            group=self.process_group,
        )

        for attr, reduction_fn in self._reductions.items():
            if not (callable(reduction_fn) or reduction_fn is None):
                raise TypeError('reduction_fn must be callable or None')

            gathered = output_dict[attr]
            # an empty list state gathers to an empty list: there is nothing to
            # stack, flatten or reduce, and ``gathered[0]`` would raise IndexError
            if isinstance(gathered, list) and not gathered:
                setattr(self, attr, [])
                continue

            # pre-processing ops (stack or flatten for inputs)
            if isinstance(gathered[0], Tensor):
                gathered = torch.stack(gathered)
            elif isinstance(gathered[0], list):
                gathered = _flatten(gathered)

            reduced = reduction_fn(gathered) if reduction_fn is not None else gathered
            setattr(self, attr, reduced)
Ejemplo n.º 17
0
 def compute(self) -> Tensor:
     """Compute SSIM over the accumulated predictions and targets.

     The concatenated state is passed to ``_ssim_compute`` along with the kernel
     settings (``gaussian_kernel``, ``sigma``, ``kernel_size``), ``reduction``,
     ``data_range``, ``k1``/``k2`` and the ``return_full_image`` /
     ``return_contrast_sensitivity`` flags.

     NOTE(review): with either ``return_*`` flag set, ``_ssim_compute`` may return
     more than a single tensor; confirm against the ``-> Tensor`` annotation.
     """
     preds = dim_zero_cat(self.preds)
     target = dim_zero_cat(self.target)
     return _ssim_compute(
         preds,
         target,
         self.gaussian_kernel,
         self.sigma,
         self.kernel_size,
         self.reduction,
         self.data_range,
         self.k1,
         self.k2,
         self.return_full_image,
         self.return_contrast_sensitivity,
     )
Ejemplo n.º 18
0
 def compute(self) -> Tensor:
     """Compute multi-scale SSIM over the accumulated predictions and targets.

     The concatenated state is passed to ``_multiscale_ssim_compute`` along with
     the kernel settings (``gaussian_kernel``, ``sigma``, ``kernel_size``),
     ``reduction``, ``data_range``, ``k1``/``k2``, the per-scale ``betas`` and
     the ``normalize`` option.
     """
     preds = dim_zero_cat(self.preds)
     target = dim_zero_cat(self.target)
     return _multiscale_ssim_compute(
         preds,
         target,
         self.gaussian_kernel,
         self.sigma,
         self.kernel_size,
         self.reduction,
         self.data_range,
         self.k1,
         self.k2,
         self.betas,
         self.normalize,
     )
Ejemplo n.º 19
0
Archivo: fid.py Proyecto: Borda/metrics
    def compute(self) -> Tensor:
        """Calculate FID score based on accumulated extracted features from the two distributions.

        Means and unbiased covariances are computed in double precision for the
        real and the fake features, then passed to ``_compute_fid``. The result is
        cast back to the dtype of the accumulated real features.

        Raises:
            RuntimeError: if either distribution has fewer than two samples. The
                unbiased covariance divides by ``n - 1``.
        """
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)
        # computation is extremely sensitive so it needs to happen in double precision
        orig_dtype = real_features.dtype
        real_features = real_features.double()
        fake_features = fake_features.double()

        # the two feature sets may contain a different number of samples
        n_real = real_features.shape[0]
        n_fake = fake_features.shape[0]
        if n_real < 2 or n_fake < 2:
            raise RuntimeError(
                "More than one sample is required for both the real and fake distributions to compute FID"
            )

        # calculate mean and covariance, each normalized by its own sample count
        mean1 = real_features.mean(dim=0)
        mean2 = fake_features.mean(dim=0)
        diff1 = real_features - mean1
        diff2 = fake_features - mean2
        cov1 = 1.0 / (n_real - 1) * diff1.t().mm(diff1)
        cov2 = 1.0 / (n_fake - 1) * diff2.t().mm(diff2)

        # compute fid
        return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)
Ejemplo n.º 20
0
 def compute(self) -> Tensor:
     """Return the spectral distortion index of the accumulated predictions vs. targets.

     ``self.p`` and ``self.reduction`` are forwarded to
     ``_spectral_distortion_index_compute``.
     """
     all_preds, all_target = dim_zero_cat(self.preds), dim_zero_cat(self.target)
     return _spectral_distortion_index_compute(
         all_preds, all_target, self.p, self.reduction
     )
Ejemplo n.º 21
0
 def compute(self) -> Tensor:
     """Return Spearman's rank correlation of all accumulated predictions and targets."""
     return _spearman_corrcoef_compute(
         dim_zero_cat(self.preds),
         dim_zero_cat(self.target),
     )
Ejemplo n.º 22
0
 def compute(self) -> Tensor:
     """Return the KL divergence aggregated from the accumulated state.

     Without a reduction (``None`` or ``'none'``), ``self.measures`` is
     concatenated along dim 0 first; it is presumably a list of per-batch values
     in that case. Otherwise it is passed through as-is. ``self.total`` and
     ``self.reduction`` are forwarded to ``_kld_compute``.
     """
     if self.reduction in (None, 'none'):
         measures = dim_zero_cat(self.measures)
     else:
         measures = self.measures
     return _kld_compute(measures, self.total, self.reduction)
Ejemplo n.º 23
0
 def compute(self) -> Tensor:
     """Compute ERGAS over the accumulated predictions and targets.

     The concatenated state is passed to ``_ergas_compute`` together with
     ``self.ratio`` and ``self.reduction``.
     """
     preds = dim_zero_cat(self.preds)
     target = dim_zero_cat(self.target)
     return _ergas_compute(preds, target, self.ratio, self.reduction)
Ejemplo n.º 24
0
 def compute(self) -> Tensor:
     """Return the concatenation of all accumulated values.

     A non-empty list state is concatenated along dim 0 with ``dim_zero_cat``.
     Any other state (an empty list, or a non-list value such as a tensor) is
     returned unchanged.
     """
     state = self.value
     if not isinstance(state, list) or not state:
         return state
     return dim_zero_cat(state)
Ejemplo n.º 25
0
 def compute(self) -> Tensor:
     """Return the cosine similarity between the accumulated predictions and targets.

     ``self.reduction`` is forwarded to ``_cosine_similarity_compute``.
     """
     all_preds, all_target = dim_zero_cat(self.preds), dim_zero_cat(self.target)
     return _cosine_similarity_compute(all_preds, all_target, self.reduction)
Ejemplo n.º 26
0
 def compute(self) -> Tensor:
     """Return the SAM score of the accumulated predictions vs. targets.

     SAM presumably stands for Spectral Angle Mapper (confirm against
     ``_sam_compute``). ``self.reduction`` is forwarded unchanged.
     """
     return _sam_compute(
         dim_zero_cat(self.preds),
         dim_zero_cat(self.target),
         self.reduction,
     )