Ejemplo n.º 1
0
    def compute(self) -> List[float]:
        """
        Evaluate cmc@k on everything accumulated so far, once per k in ``self.topk_args``.

        Samples whose ``is_query`` flag equals 1 act as queries; all remaining
        samples form the gallery. A query/gallery pair is a match when their
        labels are equal.

        Returns:
            one ``cmc_score`` value per entry of ``self.topk_args``, in the same order
        """
        storage = self.storage

        is_query = (storage[self.is_query_key] == 1).to(torch.bool)
        is_gallery = ~is_query

        all_embeddings = storage[self.embeddings_key].float()
        all_labels = storage[self.labels_key]

        q_embeddings, g_embeddings = all_embeddings[is_query], all_embeddings[is_gallery]
        q_labels, g_labels = all_labels[is_query], all_labels[is_gallery]

        # The query labels are reshaped into a column so the comparison broadcasts
        # against the gallery labels. Assuming labels are 1-D (TODO confirm), entry
        # [i, j] is True when query i and gallery sample j share a label.
        matches = (g_labels == q_labels.reshape(-1, 1)).to(torch.bool)

        return [
            cmc_score(
                query_embeddings=q_embeddings,
                gallery_embeddings=g_embeddings,
                conformity_matrix=matches,
                topk=k,
            )
            for k in self.topk_args
        ]
Ejemplo n.º 2
0
def test_no_mask_cmc_score(
    query_embeddings,
    gallery_embeddings,
    conformity_matrix,
    available_samples,
    topk,
) -> None:
    """
    Check that masked_cmc_score and cmc_score agree when the mask keeps
    every sample available for scoring.

    Assumes the ``available_samples`` fixture marks all samples as available.
    """
    shared_kwargs = {
        "query_embeddings": query_embeddings,
        "gallery_embeddings": gallery_embeddings,
        "conformity_matrix": conformity_matrix,
        "topk": topk,
    }
    masked_result = masked_cmc_score(available_samples=available_samples, **shared_kwargs)
    plain_result = cmc_score(**shared_kwargs)
    assert masked_result == plain_result
Ejemplo n.º 3
0
def test_cmc_score_with_samples(generate_samples_for_cmc_score):
    """
    Score cmc@1 on sets of well-separated data clusters whose labels were
    assigned wrongly with probability ``error_rate``. The resulting score
    should be within a small tolerance of ``1 - error_rate``.
    """
    tolerance = 0.05
    for sample in generate_samples_for_cmc_score:
        error_rate, query_embs, query_labels, gallery_embs, gallery_labels = sample
        expected_cmc_at_1 = 1 - error_rate

        # Column of query labels vs. row of gallery labels -> pairwise match matrix.
        is_match = (query_labels.reshape((-1, 1)) == gallery_labels).to(torch.bool)

        observed_cmc_at_1 = cmc_score(
            query_embeddings=query_embs,
            gallery_embeddings=gallery_embs,
            conformity_matrix=is_match,
            topk=1,
        )
        assert abs(observed_cmc_at_1 - expected_cmc_at_1) <= tolerance