예제 #1
0
    def test_get_dataloaders_for_unknown_framework(
        self,
        mocked_load_dataloader_config: MagicMock,
    ) -> None:
        """Check that an unrecognised framework yields an empty dataloader list."""
        # The mocked config loader only knows about these two frameworks.
        known_frameworks = ("framework_foo", "framework_bar")
        mocked_load_dataloader_config.return_value = [
            {"name": framework_name, "help": None, "params": {}}
            for framework_name in known_frameworks
        ]

        # Request a framework that is absent from the mocked config.
        request_data = {"config": {"framework": "framework_unknown"}}
        dataloaders = Feeder(data=request_data).get_dataloaders()

        mocked_load_dataloader_config.assert_called_once()
        self.assertEqual([], dataloaders)
예제 #2
0
    def test_get_dataloaders(
        self,
        mocked_load_dataloader_config: MagicMock,
    ) -> None:
        """Check that get_dataloaders returns the requested framework's params."""
        # The same list object is used as the mocked params and the expectation.
        foo_params = [{"name": "param1"}, {"name": "param2"}, {"name": "param3"}]
        foo_entry = {"name": "framework_foo", "help": None, "params": foo_params}
        bar_entry = {"name": "framework_bar", "help": None, "params": {}}
        mocked_load_dataloader_config.return_value = [foo_entry, bar_entry]

        request_data = {"config": {"framework": "framework_foo"}}
        dataloaders = Feeder(data=request_data).get_dataloaders()

        mocked_load_dataloader_config.assert_called_once()
        self.assertEqual(foo_params, dataloaders)
예제 #3
0
 def test_get_dataloaders_fails_without_framework(self) -> None:
     """Check that get_dataloaders raises when the data carries no config."""
     empty_feeder = Feeder(data={})
     # Callable form of assertRaisesRegex: the method is invoked by the assertion.
     self.assertRaisesRegex(
         ClientErrorException,
         "Framework not set.",
         empty_feeder.get_dataloaders,
     )