예제 #1
0
    def test_get_transforms(
        self,
        mocked_load_transforms_config: MagicMock,
        mocked_load_transforms_filter_config: MagicMock,
    ) -> None:
        """Check that get_transforms narrows framework transforms by domain.

        framework_foo defines transform1-3 while the domain_bar filter lists
        transform2-4, so only transform2 and transform3 are expected back.
        """
        foo_transforms = [
            {"name": name} for name in ("transform1", "transform2", "transform3")
        ]
        bar_transforms = [
            {"name": name} for name in ("transform4", "transform5", "transform6")
        ]
        mocked_load_transforms_config.return_value = [
            {"name": "framework_foo", "help": None, "params": foo_transforms},
            {"name": "framework_bar", "help": None, "params": bar_transforms},
        ]
        mocked_load_transforms_filter_config.return_value = {
            "framework_foo": {
                "domain_foo": ["transform1", "transform2", "transform3", "transform4"],
                "domain_bar": ["transform2", "transform3", "transform4"],
            },
        }
        request = {
            "config": {"framework": "framework_foo", "domain": "domain_bar"},
        }

        transforms = Feeder(data=request).get_transforms()

        mocked_load_transforms_config.assert_called_once()
        mocked_load_transforms_filter_config.assert_called_once()
        self.assertEqual(
            [{"name": "transform2"}, {"name": "transform3"}],
            transforms,
        )
예제 #2
0
    def test_get_precisions_for_missing_framework(
        self,
        mocked_load_precisions_config: MagicMock,
    ) -> None:
        """Check that get_precisions raises when the request has no framework."""
        mocked_load_precisions_config.return_value = {
            "framework_foo": [{precision: {}} for precision in ("foo", "bar", "baz")],
            "framework_bar": [{"foo": {}}],
        }
        feeder_without_config = Feeder(data={})

        with self.assertRaisesRegex(ClientErrorException, "Framework not set."):
            feeder_without_config.get_precisions()
예제 #3
0
    def test_get_precisions_for_unknown_framework(
        self,
        mocked_load_precisions_config: MagicMock,
    ) -> None:
        """Check get_precisions returns [] for a framework absent from config."""
        mocked_load_precisions_config.return_value = {
            "framework_foo": [{"foo": {}}, {"bar": {}}, {"baz": {}}],
            "framework_bar": [{"foo": {}}],
        }
        request = {"config": {"framework": "framework_baz"}}

        precisions = Feeder(data=request).get_precisions()

        self.assertEqual([], precisions)
예제 #4
0
    def test_get_dataloaders_for_unknown_framework(
        self,
        mocked_load_dataloader_config: MagicMock,
    ) -> None:
        """Check get_dataloaders returns [] for a framework absent from config."""
        mocked_load_dataloader_config.return_value = [
            {"name": framework_name, "help": None, "params": {}}
            for framework_name in ("framework_foo", "framework_bar")
        ]
        request = {"config": {"framework": "framework_unknown"}}

        dataloaders = Feeder(data=request).get_dataloaders()

        mocked_load_dataloader_config.assert_called_once()
        self.assertEqual([], dataloaders)
예제 #5
0
    def test_get_metrics_for_unknown_framework(
        self,
        mocked_check_module: MagicMock,
        mocked_load_help_lpot_params: MagicMock,
    ) -> None:
        """Test that get_metrics offers only the custom metric for an unknown framework.

        The mocked LPOT help params describe topk, metric1 and metric3, yet
        none of them should be returned for "unknown_framework".
        """
        mocked_load_help_lpot_params.return_value = {
            "__help__topk": "help for topk",
            "topk": {
                "__help__k": "help for k in topk",
                "__help__missing_param": "help for missing_param in topk",
            },
            "__help__metric1": "help for metric1",
            "__help__metric3": "help for metric3",
        }

        expected = [
            {
                "name": "custom",
                "help": "",
                "value": None,
            },
        ]

        feeder = Feeder(data={
            "config": {
                "framework": "unknown_framework",
            },
        }, )

        actual = feeder.get_metrics()

        mocked_check_module.assert_called_once_with("unknown_framework")
        mocked_load_help_lpot_params.assert_called_once_with("metrics")
        self.assertEqual(expected, actual)
예제 #6
0
    def test_get_models(
        self,
        mocked_load_model_config: MagicMock,
    ) -> None:
        """Check get_models lists the models of the requested framework/domain.

        Only framework_foo/domain2 defines models, so model1-3 with their help
        texts are expected, in config order.
        """
        domain2_models = {
            "__help__model1": "help for model 1",
            "model1": {},
            "__help__model2": "help for model 2",
            "model2": {},
            "__help__model3": "help for model 3",
            "model3": {},
        }
        mocked_load_model_config.return_value = {
            "__help__framework_foo": "help for framework_foo",
            "framework_foo": {
                "__help__domain1": "help text for framework_foo/domain1",
                "domain1": {},
                "__help__domain2": "help text for framework_foo/domain2",
                "domain2": domain2_models,
                "__help__domain3": "help text for framework_foo/domain3",
                "domain3": {},
            },
            "__help__framework_bar": "help for framework_bar",
            "framework_bar": {
                "__help__domain1": "help text for framework_bar/domain1",
                "domain1": {},
            },
            "__help__framework_baz": "help for framework_baz",
            "framework_baz": {
                "__help__domain1": "help text for framework_baz/domain1",
                "domain1": {},
            },
        }
        expected = [
            {"name": model_name, "help": help_text}
            for model_name, help_text in (
                ("model1", "help for model 1"),
                ("model2", "help for model 2"),
                ("model3", "help for model 3"),
            )
        ]
        request = {"config": {"framework": "framework_foo", "domain": "domain2"}}

        models = Feeder(data=request).get_models()

        mocked_load_model_config.assert_called_once()
        self.assertEqual(expected, models)
예제 #7
0
 def test_feed_with_unknown_param_fails(self) -> None:
     """Check that feed rejects a param it has no method for."""
     unsupported_request = {"param": "foo"}
     feeder = Feeder(data=unsupported_request)
     expected_message = "Could not found method for foo parameter."
     with self.assertRaisesRegex(ClientErrorException, expected_message):
         feeder.feed()
예제 #8
0
 def test_get_models_fails_without_domain(self) -> None:
     """Test that get_models fails when the config sets a framework but no domain."""
     feeder = Feeder(data={
         "config": {
             "framework": "framework_foo",
         },
     }, )
     with self.assertRaisesRegex(ClientErrorException, "Domain not set."):
         feeder.get_models()
예제 #9
0
    def test_get_quantization_approaches_for_fake_framework(self) -> None:
        """Check that a fake framework is offered only post_training_static_quant."""
        request = {"config": {"framework": "framework_foo"}}
        approaches = Feeder(data=request).get_quantization_approaches()

        self.assertEqual(
            ["post_training_static_quant"],
            [entry.get("name") for entry in approaches],
        )
예제 #10
0
    def test_feed_for_framework_works(self, mocked_get_frameworks: MagicMock) -> None:
        """Test that feed returns the frameworks list under the "framework" key.

        ``mocked_get_frameworks`` supplies the list; feed() for the
        "framework" param is expected to return it unchanged.
        """
        frameworks = [
            {"name": "framework1"},
            {"name": "framework2"},
            {"name": "framework3"},
        ]
        mocked_get_frameworks.return_value = frameworks
        expected = {"framework": frameworks}

        feeder = Feeder(data={"param": "framework"})
        actual = feeder.feed()

        self.assertEqual(expected, actual)
예제 #11
0
    def test_get_strategies(
        self,
        mocked_load_help_lpot_params: MagicMock,
    ) -> None:
        """Check get_strategies pairs each supported strategy with its help text.

        The help entry for ``strategy_unknown`` must be dropped, while
        ``strategy3`` (which has no help entry) gets an empty help string.
        """
        help_params = {
            "__help__strategy1": "help1",
            "__help__strategy_unknown": "this should be skipped",
            "__help__strategy2": "help2",
        }
        mocked_load_help_lpot_params.return_value = help_params
        expected = [
            {"name": name, "help": help_text}
            for name, help_text in (
                ("strategy1", "help1"),
                ("strategy2", "help2"),
                ("strategy3", ""),
            )
        ]

        strategies = Feeder.get_strategies()

        mocked_load_help_lpot_params.assert_called_once_with("strategies")
        self.assertEqual(expected, strategies)
예제 #12
0
    def test_get_transforms_without_domain(
        self,
        mocked_load_transforms_config: MagicMock,
    ) -> None:
        """Check get_transforms returns every framework transform when no domain is set."""
        foo_names = ("transform1", "transform2", "transform3")
        bar_names = ("transform4", "transform5", "transform6")
        mocked_load_transforms_config.return_value = [
            {
                "name": "framework_foo",
                "help": None,
                "params": [{"name": name} for name in foo_names],
            },
            {
                "name": "framework_bar",
                "help": None,
                "params": [{"name": name} for name in bar_names],
            },
        ]
        request = {"config": {"framework": "framework_foo"}}

        transforms = Feeder(data=request).get_transforms()

        mocked_load_transforms_config.assert_called_once()
        self.assertEqual([{"name": name} for name in foo_names], transforms)
예제 #13
0
    def test_get_dataloaders(
        self,
        mocked_load_dataloader_config: MagicMock,
    ) -> None:
        """Check that get_dataloaders returns the params of the requested framework."""
        supported_params = [{"name": "param1"}, {"name": "param2"}, {"name": "param3"}]
        mocked_load_dataloader_config.return_value = [
            {"name": "framework_foo", "help": None, "params": supported_params},
            {"name": "framework_bar", "help": None, "params": {}},
        ]
        feeder = Feeder(data={"config": {"framework": "framework_foo"}})

        dataloaders = feeder.get_dataloaders()

        mocked_load_dataloader_config.assert_called_once()
        self.assertEqual(supported_params, dataloaders)
예제 #14
0
    def test_get_frameworks(
        self,
        mocked_load_model_config: MagicMock,
        mocked_get_frameworks: MagicMock,
    ) -> None:
        """Check get_frameworks keeps only known frameworks from the model config.

        framework_bar is absent from the mocked known-framework list and must
        be dropped. The expected order follows the model config
        (framework_foo before framework_baz), not the order of the mocked
        known-framework list.
        """
        foo_help = "framework_foo is in known frameworks, so should be inculded"
        baz_help = "framework_baz is in known frameworks, so should be inculded"
        mocked_load_model_config.return_value = {
            "__help__framework_foo": foo_help,
            "framework_foo": {
                "__help__domain1": "help text for framework_foo/domain1",
                "domain1": {},
            },
            "__help__framework_bar": "framework_bar is not known, so should be ignored",
            "framework_bar": {
                "__help__domain1": "help text for framework_bar/domain1",
                "domain1": {},
            },
            "__help__framework_baz": baz_help,
            "framework_baz": {
                "__help__domain1": "help text for framework_baz/domain1",
                "domain1": {},
            },
        }
        mocked_get_frameworks.return_value = ["framework_baz", "framework_foo"]

        frameworks = Feeder.get_frameworks()

        self.assertEqual(
            [
                {"name": "framework_foo", "help": foo_help},
                {"name": "framework_baz", "help": baz_help},
            ],
            frameworks,
        )
예제 #15
0
    def test_get_metrics_for_onnxrt(
        self,
        mocked_check_module: MagicMock,
        mocked_load_help_lpot_params: MagicMock,
    ) -> None:
        """Test that get_metrics lists the onnxrt metrics with their help texts.

        Help texts come from the mocked LPOT help params; metrics and params
        without a help entry are expected to carry an empty help string.
        """
        mocked_load_help_lpot_params.return_value = {
            "__help__topk": "help for topk",
            "topk": {
                "__help__k": "help for k in topk",
                "__help__missing_param": "help for missing_param in topk",
            },
            "__help__metric1": "help for metric1",
            "__help__metric3": "help for metric3",
        }

        expected = [
            {
                "name": "topk",
                "help": "help for topk",
                "params": [
                    {
                        "name": "k",
                        "help": "help for k in topk",
                        "value": [1, 5],
                    },
                ],
            },
            {
                "name": "COCOmAP",
                "help": "",
                "params": [
                    {
                        "name": "anno_path",
                        "help": "",
                        "value": "",
                    },
                ],
            },
            {
                "name": "MSE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "RMSE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "MAE",
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            },
            {
                "name": "metric1",
                "help": "help for metric1",
                "value": None,
            },
            {
                "name": "custom",
                "help": "",
                "value": None,
            },
        ]

        feeder = Feeder(
            data={
                "config": {
                    "framework": "onnxrt",
                },
            },
        )

        actual = feeder.get_metrics()

        mocked_check_module.assert_called_once_with("onnxrt")
        mocked_load_help_lpot_params.assert_called_once_with("metrics")
        self.assertEqual(expected, actual)
예제 #16
0
 def test_get_metrics_fails_without_framework(self) -> None:
     """Test that get_metrics fails when no framework is configured."""
     feeder = Feeder(data={})
     with self.assertRaisesRegex(ClientErrorException, "Framework not set."):
         feeder.get_metrics()
예제 #17
0
 def test_feed_without_param_failes(self) -> None:
     """Check that feed raises when the request data carries no param."""
     empty_feeder = Feeder(data={})
     expected_message = "Parameter not defined."
     with self.assertRaisesRegex(ClientErrorException, expected_message):
         empty_feeder.feed()