def test_get_dataloaders_for_unknown_framework(
    self,
    mocked_load_dataloader_config: MagicMock,
) -> None:
    """Check get_dataloaders yields nothing for a framework absent from config.

    The mocked dataloader config only lists ``framework_foo`` and
    ``framework_bar``. Requesting ``framework_unknown`` must load the config
    exactly once and produce an empty list.
    """
    # Each entry gets its own fresh, empty ``params`` dict.
    known_frameworks = [
        {"name": framework_name, "help": None, "params": {}}
        for framework_name in ("framework_foo", "framework_bar")
    ]
    mocked_load_dataloader_config.return_value = known_frameworks

    request_data = {"config": {"framework": "framework_unknown"}}
    dataloaders = Feeder(data=request_data).get_dataloaders()

    mocked_load_dataloader_config.assert_called_once()
    self.assertEqual([], dataloaders)
def test_get_dataloaders(
    self,
    mocked_load_dataloader_config: MagicMock,
) -> None:
    """Check get_dataloaders returns the params of the selected framework.

    ``framework_foo`` is mocked with three params and ``framework_bar`` with
    none. Selecting ``framework_foo`` must load the config exactly once and
    return that framework's params list.
    """
    foo_params = [{"name": param_name} for param_name in ("param1", "param2", "param3")]

    mocked_load_dataloader_config.return_value = [
        {"name": "framework_foo", "help": None, "params": foo_params},
        {"name": "framework_bar", "help": None, "params": {}},
    ]

    feeder = Feeder(data={"config": {"framework": "framework_foo"}})
    result = feeder.get_dataloaders()

    mocked_load_dataloader_config.assert_called_once()
    self.assertEqual(foo_params, result)
def test_get_dataloaders_fails_without_framework(self) -> None:
    """Test that get_dataloaders fails when no config given.

    The Feeder is built with empty data (no ``config`` key). Calling
    get_dataloaders must raise ClientErrorException whose message contains
    the literal text "Framework not set.".
    """
    feeder = Feeder(data={})
    # assertRaisesRegex interprets the message as a regular expression, so the
    # trailing "." is escaped to match a literal period rather than any char.
    with self.assertRaisesRegex(ClientErrorException, r"Framework not set\."):
        feeder.get_dataloaders()