Exemplo n.º 1
0
    def test_get_metrics_for_unknown_framework(
        self,
        mocked_check_module: MagicMock,
        mocked_load_help_lpot_params: MagicMock,
    ) -> None:
        """Test get_metrics for a framework name it does not recognize.

        The mocked LPOT help params describe "topk", "metric1" and "metric3".
        None of those are expected for "unknown_framework": the result must be
        exactly the single "custom" entry with empty help and a None value.
        The test also asserts that the mocked `check_module` is called once
        with the framework name, and that help params are loaded once for
        "metrics".
        """
        mocked_load_help_lpot_params.return_value = {
            "__help__topk": "help for topk",
            "topk": {
                "__help__k": "help for k in topk",
                "__help__missing_param": "help for missing_param in topk",
            },
            "__help__metric1": "help for metric1",
            "__help__metric3": "help for metric3",
        }

        # Only the "custom" placeholder is expected; none of the help entries
        # above should surface for an unrecognized framework.
        expected = [
            {
                "name": "custom",
                "help": "",
                "value": None,
            },
        ]

        feeder = Feeder(
            data={
                "config": {
                    "framework": "unknown_framework",
                },
            },
        )

        actual = feeder.get_metrics()

        mocked_check_module.assert_called_once_with("unknown_framework")
        mocked_load_help_lpot_params.assert_called_once_with("metrics")
        self.assertEqual(expected, actual)
Exemplo n.º 2
0
    def test_get_metrics_for_onnxrt(
        self,
        mocked_check_module: MagicMock,
        mocked_load_help_lpot_params: MagicMock,
    ) -> None:
        """Test get_metrics for the "onnxrt" framework.

        The result must equal `expected` exactly, including its order, because
        assertEqual compares lists element by element.

        Help text from the mocked LPOT help params is attached where it exists
        ("topk", its "k" param, "metric1"). Metrics and params without help
        text get an empty string. Help keys with no matching entry in the
        result ("missing_param" under topk, "metric3") must not appear.

        The test also asserts that the mocked `check_module` is called once
        with "onnxrt", and that help params are loaded once for "metrics".
        """
        mocked_load_help_lpot_params.return_value = {
            "__help__topk": "help for topk",
            "topk": {
                "__help__k": "help for k in topk",
                "__help__missing_param": "help for missing_param in topk",
            },
            "__help__metric1": "help for metric1",
            "__help__metric3": "help for metric3",
        }

        expected = [
            {
                "name": "topk",
                "help": "help for topk",
                "params": [
                    {
                        "name": "k",
                        "help": "help for k in topk",
                        "value": [1, 5],
                    },
                ],
            },
            {
                "name": "COCOmAP",
                "help": "",
                "params": [
                    {
                        "name": "anno_path",
                        "help": "",
                        "value": "",
                    },
                ],
            },
            {
                "name": "MSE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "RMSE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "MAE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "metric1",
                "help": "help for metric1",
                "value": None,
            },
            # "custom" is expected as the final entry.
            {
                "name": "custom",
                "help": "",
                "value": None,
            },
        ]

        feeder = Feeder(
            data={
                "config": {
                    "framework": "onnxrt",
                },
            },
        )

        actual = feeder.get_metrics()

        mocked_check_module.assert_called_once_with("onnxrt")
        mocked_load_help_lpot_params.assert_called_once_with("metrics")
        self.assertEqual(expected, actual)
Exemplo n.º 3
0
 def test_get_metrics_fails_without_framework(self) -> None:
     """Test that get_metrics raises when no framework is configured.

     The Feeder is built with empty data (no "config" key at all), so
     get_metrics() must raise ClientErrorException with a message matching
     "Framework not set.".
     """
     feeder = Feeder(data={})
     # assertRaisesRegex searches the message with a regex; the unescaped "."
     # matches any character, which is harmless for this message.
     with self.assertRaisesRegex(ClientErrorException, "Framework not set."):
         feeder.get_metrics()