def register_task_cls(task_config_cls):
    """Builds a decorator that registers a Task factory under a TaskConfig class.

    Typical usage:

    ```
    @dataclasses.dataclass
    class MyTaskConfig(TaskConfig):
      # Add fields here.
      pass

    @register_task_cls(MyTaskConfig)
    class MyTask(Task):
      # Inherits def __init__(self, task_config).
      pass

    my_task_config = MyTaskConfig()
    my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
    ```

    Besides a class itself, any other callable that creates a Task from a
    TaskConfig can be decorated by the result of this function, as long as
    there is at most one registration for each config class.

    Args:
      task_config_cls: a subclass of TaskConfig (*not* an instance of
        TaskConfig). Each task_config_cls can only be used for a single
        registration.

    Returns:
      A callable for use as a class decorator that registers the decorated
      class for creation from an instance of task_config_cls.
    """
    # The config class itself is the lookup key in the task registry.
    task_decorator = registry.register(_REGISTERED_TASK_CLS, task_config_cls)
    return task_decorator
def register_data_loader_cls(data_config_cls):
    """Builds a decorator that registers a DataLoader factory under a DataConfig class.

    Typical usage:

    ```
    @dataclasses.dataclass
    class MyDataConfig(DataConfig):
      # Add fields here.
      pass

    @register_data_loader_cls(MyDataConfig)
    class MyDataLoader:
      # Inherits def __init__(self, data_config).
      pass

    my_data_config = MyDataConfig()

    # Returns MyDataLoader(my_data_config).
    my_loader = get_data_loader(my_data_config)
    ```

    Args:
      data_config_cls: a subclass of DataConfig (*not* an instance of
        DataConfig).

    Returns:
      A callable for use as a class decorator that registers the decorated
      class for creation from an instance of data_config_cls.
    """
    # The config class itself is the lookup key in the data-loader registry.
    loader_decorator = registry.register(
        _REGISTERED_DATA_LOADER_CLS, data_config_cls
    )
    return loader_decorator
def register_config_factory(name):
    """Registers an ExperimentConfig factory method under `name`.

    Args:
      name: the key passed to `registry.register` for the
        `_REGISTERED_CONFIGS` collection.

    Returns:
      Whatever `registry.register` returns for this collection and key;
      presumably a decorator to apply to the factory, as with the other
      register_* helpers -- confirm against `registry.register`.
    """
    config_decorator = registry.register(_REGISTERED_CONFIGS, name)
    return config_decorator
def register_task_cls(task_config: cfg.TaskConfig) -> Task:
    """Registers a Task factory in the task registry, keyed by `task_config`.

    This forwards to `registry.register` with the `_REGISTERED_TASK_CLS`
    collection. Despite what the previous docstring said, it has nothing to
    do with ExperimentConfig factories (those go through
    `_REGISTERED_CONFIGS`).

    NOTE(review): this is the second definition of `register_task_cls` in this
    file. Python binds the name to the most recent `def`, so this definition
    replaces the earlier, fully documented one. Confirm whether this
    function was meant to carry a different name, or should be removed.

    NOTE(review): the annotations look inaccurate. The earlier definition
    documents its argument as a TaskConfig *subclass* (not an instance), and
    documents the result of the same `registry.register` call as a decorator
    callable rather than a `Task`. They are left untouched here to avoid
    changing the signature; verify and correct them together with the
    duplicate-name issue.

    Args:
      task_config: the registration key forwarded to `registry.register`.

    Returns:
      The result of `registry.register(_REGISTERED_TASK_CLS, task_config)`.
    """
    registration = registry.register(_REGISTERED_TASK_CLS, task_config)
    return registration