def test_register_error(self):
    """Checks that invalid registrations and lookups of unknown keys fail."""
    registry_dict = {}

    def leaf_fn():
        pass

    def nested_fn():
        pass

    # Occupy 'functions/func_0' with a function object.
    registry.register(registry_dict, 'functions/func_0')(leaf_fn)

    # Registering a child path beneath an occupied leaf is expected to raise
    # KeyError.
    with self.assertRaises(KeyError):
        registry.register(registry_dict, 'functions/func_0/sub_func')(nested_fn)

    # Looking up a key that was never registered is expected to raise
    # LookupError.
    with self.assertRaises(LookupError):
        registry.lookup(registry_dict, 'non-exist')
def test_register(self):
    """Checks that functions and classes round-trip through register/lookup.

    Covers both '/'-separated string keys and a class object used as a key.
    """
    registry_dict = {}

    # A function registered under a nested string path.
    @registry.register(registry_dict, 'functions/func_0')
    def registered_fn():
        pass

    self.assertEqual(
        registry.lookup(registry_dict, 'functions/func_0'), registered_fn)

    # A class registered under a nested string path.
    @registry.register(registry_dict, 'classes/cls_0')
    class ClassRegistryKey:
        pass

    self.assertEqual(
        registry.lookup(registry_dict, 'classes/cls_0'), ClassRegistryKey)

    # A class registered with another class object as its key.
    @registry.register(registry_dict, ClassRegistryKey)
    class ClassRegistryValue:
        pass

    self.assertEqual(
        registry.lookup(registry_dict, ClassRegistryKey), ClassRegistryValue)
def get_task_cls(task_config_cls):
    """Returns the Task class registered for ``task_config_cls``.

    Args:
      task_config_cls: Key into ``_REGISTERED_TASK_CLS``; judging by the name,
        a task config class (not an instance).

    Returns:
      Whatever object was registered under that key.

    Any exception raised by ``registry.lookup`` (e.g. for an unregistered key)
    propagates unchanged.
    """
    return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
def get_exp_config_creater(exp_name: str):
    """Looks up ExperimentConfig factory methods.

    Args:
      exp_name: Name under which the factory was registered in
        ``_REGISTERED_CONFIGS``.

    Returns:
      The registered factory for ``exp_name``. Any exception raised by
      ``registry.lookup`` (e.g. for an unknown name) propagates unchanged.
    """
    return registry.lookup(_REGISTERED_CONFIGS, exp_name)
def get_task_cls(task_config: cfg.TaskConfig) -> Task:
    """Returns the Task class registered for a task config.

    Args:
      task_config: Either a ``cfg.TaskConfig`` instance, as the annotation
        suggests, or the config class itself. Instances are normalized to
        their class before the lookup. Any other key, such as a class, is
        passed through unchanged, so existing callers behave exactly as before.

    Returns:
      The object registered in ``_REGISTERED_TASK_CLS`` for that config class.

    Any exception raised by ``registry.lookup`` (e.g. for an unregistered
    config class) propagates unchanged.
    """
    # The annotation invites passing a config *instance*, but task registries
    # are keyed by config *classes*. Looking up an instance would silently
    # miss, so resolve it to its class first.
    if isinstance(task_config, cfg.TaskConfig):
        lookup_key = task_config.__class__
    else:
        lookup_key = task_config
    return registry.lookup(_REGISTERED_TASK_CLS, lookup_key)
def get_data_loader(data_config):
    """Creates a data_loader from data_config.

    The loader class is looked up in ``_REGISTERED_DATA_LOADER_CLS`` by the
    *class* of ``data_config``. It is then instantiated with ``data_config``
    as its only argument.

    Args:
      data_config: Config object whose class was registered with a loader.

    Returns:
      The newly constructed data loader instance.
    """
    loader_cls = registry.lookup(_REGISTERED_DATA_LOADER_CLS,
                                 data_config.__class__)
    return loader_cls(data_config)