Exemplo n.º 1
0
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "map",
        map_args: List[int] = None,
        num_classes: int = None,
    ):
        """
        Callback that logs mean average accuracy (map@K) for a set of K values.

        The metric function is ``metrics.mean_average_accuracy``; the list of
        K values is handed to the parent callback both as ``list_args`` and
        as ``topk``.

        Args:
            input_key (str): key holding the ground truth (`y_true`) used to
                compute mean average accuracy at K
            output_key (str): key holding the predictions (`y_pred`) used to
                compute mean average accuracy at K
            prefix (str): key for the metric's name
            map_args (List[int]): which map@K values to log, e.g.
                [1] - map@1
                [1, 3] - map@1 and map@3
                [1, 3, 5] - map@1, map@3 and map@5
            num_classes (int): number of classes used to derive default K
                values via ``get_default_topk_args`` when ``map_args`` is
                None (or any other falsy value, such as an empty list)
        """
        # Falsy map_args (None or []) falls back to the defaults.
        topk_values = (
            map_args if map_args else get_default_topk_args(num_classes)
        )

        # The same K list drives both the logged names (list_args) and the
        # metric computation (topk).
        parent_kwargs = {
            "prefix": prefix,
            "metric_fn": metrics.mean_average_accuracy,
            "list_args": topk_values,
            "input_key": input_key,
            "output_key": output_key,
            "topk": topk_values,
        }
        super().__init__(**parent_kwargs)
Exemplo n.º 2
0
    def __init__(
        self,
        embeddings_key: str = "logits",
        labels_key: str = "targets",
        is_query_key: str = "is_query",
        prefix: str = "cmc",
        topk_args: List[int] = None,
        num_classes: int = None,
    ):
        """
        Callback that computes cumulative matching characteristics (CMC).

        If the current object comes from the query set, your dataset should
        output `True` under `is_query_key`, and `False` if it comes from the
        gallery. See `QueryGalleryDataset` in `catalyst.contrib.data.ml` for
        more information.

        At batch end, the callback accumulates all embeddings.

        Args:
            embeddings_key (str): embeddings key in output dict
            labels_key (str): labels key in output dict
            is_query_key (str): key of a bool flag that is True if the
                current object is from the query set
            prefix (str): key for the metric's name
            topk_args (List[int]): specifies which cmc@K to log.
                [1] - cmc@1
                [1, 3] - cmc@1 and cmc@3
                [1, 3, 5] - cmc@1, cmc@3 and cmc@5
            num_classes (int): number of classes used to calculate default
                ``topk_args`` (via ``get_default_topk_args``) when
                ``topk_args`` is None or otherwise falsy (e.g. an empty list)
        """
        super().__init__(order=CallbackOrder.Metric)
        # K values to log; a falsy topk_args (None or []) falls back to defaults.
        self.list_args = topk_args or get_default_topk_args(num_classes)
        # Scoring function for CMC; presumably invoked in a later hook
        # not visible here.
        self._metric_fn = cmc_score
        self._prefix = prefix
        self.embeddings_key = embeddings_key
        self.labels_key = labels_key
        self.is_query_key = is_query_key
        # Accumulation state: all start as None here; presumably filled in
        # and reset by other callback hooks (not visible in this chunk).
        self._gallery_embeddings: torch.Tensor = None
        self._query_embeddings: torch.Tensor = None
        self._gallery_labels: torch.Tensor = None
        self._query_labels: torch.Tensor = None
        self._gallery_idx = None
        self._query_idx = None
        self._query_size = None
        self._gallery_size = None
Exemplo n.º 3
0
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "accuracy",
        accuracy_args: List[int] = None,
        num_classes: int = None,
        threshold: float = None,
        activation: str = None,
    ):
        """
        Callback that logs accuracy@K (top-K accuracy) for a set of K values.

        The metric function is ``metrics.accuracy``; the list of K values is
        passed to the parent callback both as ``list_args`` and as ``topk``,
        and ``threshold`` / ``activation`` are forwarded unchanged.

        Args:
            input_key (str): key holding the ground truth (`y_true`) for the
                accuracy calculation
            output_key (str): key holding the predictions (`y_pred`) for the
                accuracy calculation
            prefix (str): key for the metric's name
            accuracy_args (List[int]): which accuracy@K values to log, e.g.
                [1] - accuracy
                [1, 3] - accuracy at 1 and 3
                [1, 3, 5] - accuracy at 1, 3 and 5
            num_classes (int): number of classes used to derive default K
                values via ``get_default_topk_args`` when ``accuracy_args``
                is None (or any other falsy value, such as an empty list)
            threshold (float): threshold for binarizing the outputs
            activation (str): a torch.nn activation applied to the outputs;
                must be one of ``"none"``, ``"Sigmoid"``, or ``"Softmax"``
        """
        if accuracy_args:
            topk_values = accuracy_args
        else:
            # None or an empty list: derive defaults from num_classes.
            topk_values = get_default_topk_args(num_classes)

        super().__init__(
            prefix=prefix,
            metric_fn=metrics.accuracy,
            list_args=topk_values,
            input_key=input_key,
            output_key=output_key,
            topk=topk_values,
            threshold=threshold,
            activation=activation,
        )