コード例 #1
0
ファイル: test_params_feeder.py プロジェクト: intel/lpot
    def test_get_transforms(
        self,
        mocked_load_transforms_config: MagicMock,
        mocked_load_transforms_filter_config: MagicMock,
    ) -> None:
        """Check that get_transforms narrows transforms by framework and domain.

        The mocked transforms config defines three transforms per framework,
        and the mocked filter config lists allowed transforms per framework and
        domain. For framework_foo / domain_bar only transform2 and transform3
        are expected back: transform4 is in the domain filter but is not
        defined for framework_foo.
        """
        foo_transforms = [
            {"name": "transform1"},
            {"name": "transform2"},
            {"name": "transform3"},
        ]
        bar_transforms = [
            {"name": "transform4"},
            {"name": "transform5"},
            {"name": "transform6"},
        ]
        mocked_load_transforms_config.return_value = [
            {"name": "framework_foo", "help": None, "params": foo_transforms},
            {"name": "framework_bar", "help": None, "params": bar_transforms},
        ]

        foo_domain_filters = {
            "domain_foo": [
                "transform1",
                "transform2",
                "transform3",
                "transform4",
            ],
            "domain_bar": ["transform2", "transform3", "transform4"],
        }
        mocked_load_transforms_filter_config.return_value = {
            "framework_foo": foo_domain_filters,
        }

        requested_config = {
            "framework": "framework_foo",
            "domain": "domain_bar",
        }
        feeder = Feeder(data={"config": requested_config})

        result = feeder.get_transforms()

        # Each config loader should be consulted exactly once per request.
        mocked_load_transforms_config.assert_called_once()
        mocked_load_transforms_filter_config.assert_called_once()
        self.assertEqual(
            [{"name": "transform2"}, {"name": "transform3"}],
            result,
        )
コード例 #2
0
ファイル: test_params_feeder.py プロジェクト: vuiseng9/lpot
    def test_get_transforms_for_unknown_framework(
        self,
        mocked_load_transforms_config: MagicMock,
    ) -> None:
        """Check that an unrecognised framework yields an empty transform list.

        The mocked transforms config only defines "framework_foo" and
        "framework_bar". Requesting "framework_unknown" is expected to return
        an empty list rather than raise.
        """
        known_frameworks = [
            {
                "name": "framework_foo",
                "help": None,
                "params": [
                    {"name": "transform1"},
                    {"name": "transform2"},
                    {"name": "transform3"},
                ],
            },
            {
                "name": "framework_bar",
                "help": None,
                "params": [
                    {"name": "transform4"},
                    {"name": "transform5"},
                    {"name": "transform6"},
                ],
            },
        ]
        mocked_load_transforms_config.return_value = known_frameworks

        feeder = Feeder(data={"config": {"framework": "framework_unknown"}})
        transforms = feeder.get_transforms()

        mocked_load_transforms_config.assert_called_once()
        self.assertEqual(transforms, [])
コード例 #3
0
ファイル: test_params_feeder.py プロジェクト: intel/lpot
 def test_get_transforms_fails_without_framework(self) -> None:
     """Check that get_transforms raises when the feeder has no config.

     The feeder is built from an empty data dict, so no framework is set.
     A ClientErrorException whose message matches "Framework not set." is
     expected.
     """
     feeder_without_config = Feeder(data={})
     # NOTE(review): the pattern is a regex, so the trailing "." matches any
     # character; escape it if an exact message match is intended.
     self.assertRaisesRegex(
         ClientErrorException,
         "Framework not set.",
         feeder_without_config.get_transforms,
     )