示例#1
0
class TestNDCG(RetrievalMetricTester):
    """Tests for ``RetrievalNormalizedDCG`` and ``retrieval_normalized_dcg``.

    Each test hands its parametrized inputs to a ``run_*`` helper called on
    ``self``; ``_ndcg_at_k`` is supplied as ``sk_metric`` wherever a helper
    takes one.
    """

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_class_metric(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str, k: int,
    ):
        """Run the class-metric helper for every stacked parameter combination.

        ``k`` is ``None`` in one of the parametrized cases.
        """
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalNormalizedDCG, sk_metric=_ndcg_at_k,
            metric_args=dict(empty_target_action=empty_target_action, k=k),
        )

    @pytest.mark.parametrize(
        **_default_metric_functional_input_arguments_with_non_binary_target
    )
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
        """Run the functional-metric helper, passing ``k`` as an extra keyword."""
        # ``metric_args`` stays empty here; ``k`` travels as its own keyword.
        self.run_functional_metric_test(
            preds=preds, target=target,
            metric_functional=retrieval_normalized_dcg, sk_metric=_ndcg_at_k,
            metric_args={}, k=k,
        )

    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_precision_cpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the CPU precision helper."""
        self.run_precision_test_cpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalNormalizedDCG,
            metric_functional=retrieval_normalized_dcg,
        )

    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_precision_gpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the GPU precision helper."""
        self.run_precision_test_gpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalNormalizedDCG,
            metric_functional=retrieval_normalized_dcg,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
            _errors_test_class_metric_parameters_k,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
        message: str, metric_args: dict,
    ):
        """Forward each error-case parameter set with ``ValueError`` as the exception type."""
        self.run_metric_class_arguments_test(
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalNormalizedDCG,
            message=message, exception_type=ValueError,
            metric_args=metric_args, kwargs_update={},
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_functional_metric_parameters_default,
            _errors_test_functional_metric_parameters_k,
        )
    )
    def test_arguments_functional_metric(
        self, preds: Tensor, target: Tensor, message: str, metric_args: dict,
    ):
        """Forward each functional error case; ``metric_args`` goes in as ``kwargs_update``."""
        self.run_functional_metric_arguments_test(
            preds=preds, target=target,
            metric_functional=retrieval_normalized_dcg,
            message=message, exception_type=ValueError,
            kwargs_update=metric_args,
        )
示例#2
0
class TestMRR(RetrievalMetricTester):
    """Tests for ``RetrievalMRR`` and ``retrieval_reciprocal_rank``.

    Each test hands its parametrized inputs to a ``run_*`` helper called on
    ``self``; ``_reciprocal_rank`` is supplied as ``sk_metric`` wherever a
    helper takes one.
    """

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str,
    ):
        """Run the class-metric helper for every stacked parameter combination."""
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalMRR, sk_metric=_reciprocal_rank,
            metric_args=dict(empty_target_action=empty_target_action),
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    def test_functional_metric(self, preds: Tensor, target: Tensor):
        """Run the functional-metric helper with empty ``metric_args``."""
        self.run_functional_metric_test(
            preds=preds, target=target,
            metric_functional=retrieval_reciprocal_rank,
            sk_metric=_reciprocal_rank,
            metric_args={},
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the CPU precision helper."""
        self.run_precision_test_cpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalMRR,
            metric_functional=retrieval_reciprocal_rank,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the GPU precision helper."""
        self.run_precision_test_gpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalMRR,
            metric_functional=retrieval_reciprocal_rank,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
        message: str, metric_args: dict,
    ):
        """Forward each error-case parameter set with ``ValueError`` as the exception type."""
        self.run_metric_class_arguments_test(
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalMRR,
            message=message, exception_type=ValueError,
            metric_args=metric_args, kwargs_update={},
        )

    @pytest.mark.parametrize(
        **_errors_test_functional_metric_parameters_default
    )
    def test_arguments_functional_metric(
        self, preds: Tensor, target: Tensor, message: str, metric_args: dict,
    ):
        """Forward each functional error case; ``metric_args`` goes in as ``kwargs_update``."""
        self.run_functional_metric_arguments_test(
            preds=preds, target=target,
            metric_functional=retrieval_reciprocal_rank,
            message=message, exception_type=ValueError,
            kwargs_update=metric_args,
        )
示例#3
0
class TestRPrecision(RetrievalMetricTester):
    """Tests for ``RetrievalRPrecision`` and ``retrieval_r_precision``.

    Each test hands its parametrized inputs to a ``run_*`` helper called on
    ``self``; ``_r_precision`` is supplied as ``sk_metric`` wherever a helper
    takes one.
    """

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    # 0 is deliberately left out of the ignore_index values: the cases whose
    # targets are all 0 would fail with it.
    @pytest.mark.parametrize("ignore_index", [None, 1])
    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_class_metric(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str, ignore_index: int,
    ):
        """Run the class-metric helper for every stacked parameter combination.

        ``ignore_index`` is ``None`` in one of the parametrized cases.
        """
        metric_args = {
            "empty_target_action": empty_target_action,
            "ignore_index": ignore_index,
        }
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalRPrecision, sk_metric=_r_precision,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_ignore_index
    )
    def test_class_metric_ignore_index(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str,
    ):
        """Run the class-metric helper with ``ignore_index`` fixed to ``-100``."""
        metric_args = {
            "empty_target_action": empty_target_action,
            "ignore_index": -100,
        }
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalRPrecision, sk_metric=_r_precision,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(**_default_metric_functional_input_arguments)
    def test_functional_metric(self, preds: Tensor, target: Tensor):
        """Run the functional-metric helper with empty ``metric_args``."""
        self.run_functional_metric_test(
            preds=preds, target=target,
            metric_functional=retrieval_r_precision,
            sk_metric=_r_precision,
            metric_args={},
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_cpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the CPU precision helper."""
        self.run_precision_test_cpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(**_default_metric_class_input_arguments)
    def test_precision_gpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the GPU precision helper."""
        self.run_precision_test_gpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalRPrecision,
            metric_functional=retrieval_r_precision,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_default,
            _errors_test_class_metric_parameters_no_pos_target,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
        message: str, metric_args: dict,
    ):
        """Forward each error-case parameter set with ``ValueError`` as the exception type."""
        self.run_metric_class_arguments_test(
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalRPrecision,
            message=message, exception_type=ValueError,
            metric_args=metric_args, kwargs_update={},
        )

    @pytest.mark.parametrize(
        **_errors_test_functional_metric_parameters_default
    )
    def test_arguments_functional_metric(
        self, preds: Tensor, target: Tensor, message: str, metric_args: dict,
    ):
        """Forward each functional error case; ``metric_args`` goes in as ``kwargs_update``."""
        self.run_functional_metric_arguments_test(
            preds=preds, target=target,
            metric_functional=retrieval_r_precision,
            message=message, exception_type=ValueError,
            kwargs_update=metric_args,
        )
示例#4
0
class TestNDCG(RetrievalMetricTester):
    """Tests for ``RetrievalNormalizedDCG`` and ``retrieval_normalized_dcg``.

    Each test hands its parametrized inputs to a ``run_*`` helper called on
    ``self``; ``_ndcg_at_k`` is supplied as ``sk_metric`` wherever a helper
    takes one. The two argument-error tests skip themselves when ``target``
    is a floating-point tensor.
    """

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    # 0 is deliberately left out of the ignore_index values: the cases whose
    # targets are all 0 would fail with it.
    @pytest.mark.parametrize("ignore_index", [None, 3])
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_class_metric(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str, ignore_index: int,
        k: int,
    ):
        """Run the class-metric helper for every stacked parameter combination.

        Both ``k`` and ``ignore_index`` are ``None`` in some parametrized cases.
        """
        metric_args = {
            "empty_target_action": empty_target_action,
            "k": k,
            "ignore_index": ignore_index,
        }
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalNormalizedDCG, sk_metric=_ndcg_at_k,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_ignore_index
    )
    def test_class_metric_ignore_index(
        self, ddp: bool, indexes: Tensor, preds: Tensor, target: Tensor,
        dist_sync_on_step: bool, empty_target_action: str, k: int,
    ):
        """Run the class-metric helper with ``ignore_index`` fixed to ``-100``."""
        metric_args = {
            "empty_target_action": empty_target_action,
            "k": k,
            "ignore_index": -100,
        }
        self.run_class_metric_test(
            ddp=ddp, dist_sync_on_step=dist_sync_on_step,
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalNormalizedDCG, sk_metric=_ndcg_at_k,
            metric_args=metric_args,
        )

    @pytest.mark.parametrize(
        **_default_metric_functional_input_arguments_with_non_binary_target
    )
    @pytest.mark.parametrize("k", [None, 1, 4, 10])
    def test_functional_metric(self, preds: Tensor, target: Tensor, k: int):
        """Run the functional-metric helper, passing ``k`` as an extra keyword."""
        # ``metric_args`` stays empty here; ``k`` travels as its own keyword.
        self.run_functional_metric_test(
            preds=preds, target=target,
            metric_functional=retrieval_normalized_dcg, sk_metric=_ndcg_at_k,
            metric_args={}, k=k,
        )

    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_precision_cpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the CPU precision helper."""
        self.run_precision_test_cpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalNormalizedDCG,
            metric_functional=retrieval_normalized_dcg,
        )

    @pytest.mark.parametrize(
        **_default_metric_class_input_arguments_with_non_binary_target
    )
    def test_precision_gpu(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
    ):
        """Forward the inputs to the GPU precision helper."""
        self.run_precision_test_gpu(
            indexes=indexes, preds=preds, target=target,
            metric_module=RetrievalNormalizedDCG,
            metric_functional=retrieval_normalized_dcg,
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_class_metric_parameters_with_nonbinary,
            _errors_test_class_metric_parameters_k,
        )
    )
    def test_arguments_class_metric(
        self, indexes: Tensor, preds: Tensor, target: Tensor,
        message: str, metric_args: dict,
    ):
        """Forward each error-case parameter set with ``ValueError`` as the exception type.

        Cases whose ``target`` is a floating-point tensor are skipped.
        """
        target_is_float = target.is_floating_point()
        if target_is_float:
            pytest.skip("NDCG metric works with float target input")

        self.run_metric_class_arguments_test(
            indexes=indexes, preds=preds, target=target,
            metric_class=RetrievalNormalizedDCG,
            message=message, exception_type=ValueError,
            metric_args=metric_args, kwargs_update={},
        )

    @pytest.mark.parametrize(
        **_concat_tests(
            _errors_test_functional_metric_parameters_with_nonbinary,
            _errors_test_functional_metric_parameters_k,
        )
    )
    def test_arguments_functional_metric(
        self, preds: Tensor, target: Tensor, message: str, metric_args: dict,
    ):
        """Forward each functional error case; ``metric_args`` goes in as ``kwargs_update``.

        Cases whose ``target`` is a floating-point tensor are skipped.
        """
        target_is_float = target.is_floating_point()
        if target_is_float:
            pytest.skip("NDCG metric works with float target input")

        self.run_functional_metric_arguments_test(
            preds=preds, target=target,
            metric_functional=retrieval_normalized_dcg,
            message=message, exception_type=ValueError,
            kwargs_update=metric_args,
        )