def load(self, split, decoders, shuffle) -> tf.data.Dataset:
    """Load the requested split of this component's TFDS dataset.

    Args:
        split: Forwarded to `tfds.load` as `split`.
        decoders: Forwarded to `tfds.load` as `decoders`.
        shuffle: Forwarded to `tfds.load` as the `shuffle_files` entry of
            `as_dataset_kwargs`.

    Returns:
        Whatever `tfds.load` returns for `self.name` (annotated as a
        `tf.data.Dataset`).

    Raises:
        AssertionError: Re-raised unchanged if `tfds.load` raises it. When
            `self.download` is falsy a warning hinting at the `download`
            field is emitted first.
    """
    try:
        dataset = tfds.load(
            name=self.name,
            split=split,
            data_dir=self.data_dir,
            download=self.download,
            decoders=decoders,
            as_dataset_kwargs=dict(shuffle_files=shuffle),
        )
    except AssertionError as load_error:
        # Presumably TFDS signals "dataset not prepared" this way; the hint is
        # only useful when automatic downloading is switched off.
        download_disabled = not self.download
        if download_disabled:
            component_name = self.__class__.__name__
            utils.warn(
                f"Field 'download' of component {component_name} is False. "
                "If the TFDS dataset is not downloaded, set 'download' to "
                "True to call `download_and_prepare()` automatically."
            )
        raise load_error from None
    return dataset
def load(self, splits, decoders, shuffle) -> tf.data.Dataset:
    """Load every `(dataset name, split)` pair and chain the results together.

    Args:
        splits: Mapping from TFDS dataset name to the split to load from it.
            Entries are loaded in the mapping's iteration order.
        decoders: Forwarded to every `tfds.load` call as `decoders`.
        shuffle: Forwarded to every `tfds.load` call as the `shuffle_files`
            entry of `as_dataset_kwargs`.

    Returns:
        The first loaded dataset with each subsequent one appended via
        `.concatenate(...)`, or `None` when `splits` is empty.

    Raises:
        AssertionError: Re-raised unchanged if any `tfds.load` call raises it.
            When `self.download` is falsy a warning hinting at the `download`
            field is emitted first.
    """
    combined = None
    for dataset_name, split in splits.items():
        try:
            loaded = tfds.load(
                name=dataset_name,
                split=split,
                data_dir=self.data_dir,
                download=self.download,
                decoders=decoders,
                as_dataset_kwargs=dict(shuffle_files=shuffle),
            )
        except AssertionError as load_error:
            download_disabled = not self.download
            if download_disabled:
                component_name = self.__class__.__name__
                utils.warn(
                    f"Field 'download' of component {component_name} is False. "
                    "If the TFDS dataset is not downloaded, set 'download' to "
                    "True to call `download_and_prepare()` automatically."
                )
            raise load_error from None

        # The first dataset seeds the chain; later ones are appended to it.
        if combined is None:
            combined = loaded
        else:
            combined = combined.concatenate(loaded)
    return combined
def configure(
    instance,
    conf: Dict[str, Any],
    name: Optional[str] = None,
    interactive: bool = False,
):
    """
    Configure the component instance with parameters from the `conf` dict.

    Configuration passed through `conf` takes precedence over and will overwrite
    any values already set on the instance - either class defaults or those set
    in `__init__`.

    For each field of the instance the first applicable rule is used:

    1. A value for the field is present in `conf`. For component fields a
       string value is matched against the names of the known component and
       factory classes and, on a match, replaced by a new instance.
    2. The field already has a value in scope: nothing to do.
    3. The field allows missing values: it is left unset and is *not* recorded
       as having a value in scope.
    4. The field is a component field with exactly one candidate class: that
       class is instantiated (with a warning).
    5. `interactive` is true: the user is prompted.
    6. Otherwise a `ValueError` is raised.

    Afterwards every not-yet-configured sub-component is configured
    recursively with the `conf` keys scoped to its field name (`"child.a"`
    becomes `"a"`), and finally `__post_configure__` is invoked if the class
    defines it.

    Args:
        instance: The component instance to configure.
        conf: Mapping from (possibly dotted) field names to values. It is only
            read, never modified.
        name: Optional name to assign to `instance.__component_name__`.
        interactive: Whether to prompt for values that cannot be resolved.

    Raises:
        TypeError: If `instance` is not a component instance.
        ValueError: If the instance has already been configured, if a field
            has no resolvable value, or if `conf` contains a key that does not
            correspond to any field.
        AttributeError: If reading a component field that does not allow
            missing values fails while collecting sub-components.
    """
    # Only component instances can be configured.
    if not utils.is_component_instance(instance):
        raise TypeError(
            "Only @component, @factory, and @task instances can be configured. "
            f"Received: {instance}."
        )

    # Configuration can only happen once.
    if instance.__component_configured__:
        raise ValueError(
            f"Component '{instance.__component_name__}' has already been configured."
        )

    if name is not None:
        instance.__component_name__ = name

    # Set the correct value for each field.
    for field in instance.__component_fields__.values():
        full_name = f"{instance.__component_name__}.{field.name}"
        field_type_name = (
            field.type.__name__ if inspect.isclass(field.type) else str(field.type)
        )
        is_component_field = isinstance(field, ComponentField)

        if is_component_field:
            # Create a list of all component subclasses of the field type, and
            # add to the list all factory classes which can build the type (or
            # any subclass of the type).
            component_subclasses = list(
                utils.generate_component_subclasses(field.type)
            )
            for type_subclass in utils.generate_subclasses(field.type):
                component_subclasses.extend(FACTORY_REGISTRY.get(type_subclass, []))

        if field.name in conf:
            conf_field_value = conf[field.name]

            # The configuration value could be a string specifying a component
            # or factory class to instantiate.
            if (
                is_component_field
                and component_subclasses
                and isinstance(conf_field_value, str)
            ):
                for subclass in component_subclasses:
                    if (
                        conf_field_value == subclass.__name__
                        or conf_field_value == subclass.__qualname__
                        or utils.convert_to_snake_case(conf_field_value)
                        == utils.convert_to_snake_case(subclass.__name__)
                    ):
                        conf_field_value = subclass()
                        break

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

            # The key is deliberately left in `conf`: the caller's dict must not
            # be mutated, and the unused-key check below accepts every key that
            # names a field.

        # If there's a value in scope, we don't need to do anything.
        elif field.name in instance.__component_fields_with_values_in_scope__:
            pass

        # If the field explicitly allows values to be missing, there's no need
        # to do anything. Skip the bookkeeping below: the field has no value,
        # so it must not be reported as being in scope (this set is also
        # handed down to sub-components).
        elif field.allow_missing:
            continue

        # If there is only one concrete component subclass of the annotated
        # type, we assume the user must intend to use that subclass, and so
        # instantiate and use an instance automatically.
        elif is_component_field and len(component_subclasses) == 1:
            component_cls = component_subclasses[0]
            utils.warn(
                f"'{utils.type_name_str(component_cls)}' is the only concrete component "
                f"class that satisfies the type of the annotated field '{full_name}'. "
                "Using an instance of this class by default.",
            )
            # This is safe because we don't allow custom `__init__` methods.
            conf_field_value = component_cls()

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

        # If we are running interactively, prompt for a value.
        elif interactive:
            if is_component_field:
                if not component_subclasses:
                    raise ValueError(
                        "No component or factory class is defined which satisfies the "
                        f"type {field_type_name} of field {full_name}. If such a class "
                        "has been defined, it must be imported before calling "
                        "`configure`."
                    )
                component_cls = utils.prompt_for_component_subclass(
                    full_name, component_subclasses
                )
                # This is safe because we don't allow custom `__init__` methods.
                conf_field_value = component_cls()
            else:
                conf_field_value = utils.prompt_for_value(full_name, field.type)

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

        # Otherwise, raise an appropriate error.
        else:
            if is_component_field:
                if component_subclasses:
                    raise ValueError(
                        f"Component field '{full_name}' of type '{field_type_name}' "
                        f"has no default or configured class. Please configure "
                        f"'{full_name}' with one of the following @component or "
                        "@factory classes:"
                        + "\n ".join(
                            [""]
                            + [utils.type_name_str(c) for c in component_subclasses]
                        )
                    )
                raise ValueError(
                    f"Component field '{full_name}' of type '{field_type_name}' "
                    f"has no default or configured class. No defined @component "
                    "or @factory class satisfies this type. Please define an "
                    f"@component class subclassing '{field_type_name}', or an "
                    "@factory class with a `build()` method returning a "
                    f"'{field_type_name}' instance. This class must be imported "
                    "before invoking `configure()`."
                )
            raise ValueError(
                "No configuration value found for annotated field "
                f"'{full_name}' of type '{field_type_name}'."
            )

        # At this point we are certain that this field has a value, so keep
        # track of that fact.
        instance.__component_fields_with_values_in_scope__.add(field.name)

    # Check that all `conf` values are being used, and throw if we've been
    # passed an un-used option.
    for key in conf:
        error = ValueError(
            f"Key '{key}' does not correspond to any field of component "
            f"'{instance.__component_name__}'."
            "\n\n"
            "If you have nested components as follows:\n\n"
            "```\n"
            "@component\n"
            "class ChildComponent:\n"
            " a: int = Field(0)\n"
            "\n"
            "@task\n"
            "class SomeTask:\n"
            " child: ChildComponent = ComponentField(ChildComponent)\n"
            " def run(self):\n"
            " print(self.child.a)\n"
            "```\n\n"
            "then trying to configure `a=<SOME_VALUE>` will fail. You instead need to "
            "fully qualify the key name, and configure the value with "
            "`child.a=<SOME_VALUE>`."
        )

        if "." in key:
            # A dotted key is only valid if its first segment names a
            # component field; the remainder is checked by the sub-component.
            scoped_component_name = key.split(".")[0]
            if not (
                scoped_component_name in instance.__component_fields__
                and isinstance(
                    instance.__component_fields__[scoped_component_name],
                    ComponentField,
                )
            ):
                raise error
        elif key not in instance.__component_fields__:
            raise error

    # Recursively configure any sub-components.
    for field in instance.__component_fields__.values():
        if not isinstance(field, ComponentField):
            continue

        try:
            sub_component_instance = base_getattr(instance, field.name)
        except AttributeError as e:
            if field.allow_missing:
                continue
            raise e from None

        if not utils.is_component_instance(sub_component_instance):
            continue

        if sub_component_instance.__component_configured__:
            continue

        full_name = f"{instance.__component_name__}.{field.name}"

        # Set the component parent so that inherited fields function
        # correctly.
        sub_component_instance.__component_parent__ = instance

        # Extend the field names in scope. All fields with values defined in
        # the scope of the parent are also accessible in the child.
        sub_component_instance.__component_fields_with_values_in_scope__ |= (
            instance.__component_fields_with_values_in_scope__
        )

        # Configure the nested sub-component. The configuration we use
        # consists of all keys scoped to `field.name`, with that prefix
        # stripped.
        prefix = f"{field.name}."
        field_name_scoped_conf = {
            key[len(prefix):]: value
            for key, value in conf.items()
            if key.startswith(prefix)
        }
        configure(
            sub_component_instance,
            field_name_scoped_conf,
            name=full_name,
            interactive=interactive,
        )

    instance.__component_configured__ = True

    if hasattr(instance.__class__, "__post_configure__"):
        instance.__post_configure__()
def component(cls: Type):
    """Class decorator that turns `cls` into a Zookeeper component.

    The class is validated, its `__init__` is replaced by
    `__component_init__`, all `Field` attributes found along its MRO are
    gathered into `cls.__component_fields__`, attribute access / `dir` /
    `str` / `repr` are overridden, and the component bookkeeping attributes
    are initialised. The (mutated) class itself is returned.

    Raises:
        TypeError: If `cls` is not a class, is abstract, is already a
            component, defines a custom `__init__`, or has a
            `__post_configure__` attribute that is not callable or accepts
            anything other than `self`.
        ValueError: If a field inherited from a super-class is shadowed by a
            value that is not a `Field`.
    """
    if not inspect.isclass(cls):
        raise TypeError("Only classes can be decorated with @component.")

    if inspect.isabstract(cls):
        raise TypeError("Abstract classes cannot be decorated with @component.")

    if utils.is_component_class(cls):
        raise TypeError(
            f"The class {cls.__name__} is already a component; the @component decorator "
            "cannot be applied again."
        )

    # `__component_init__` is tolerated because a class inheriting from a
    # component already carries it as its init method.
    allowed_inits = (object.__init__, __component_init__)
    if cls.__init__ not in allowed_inits:
        raise TypeError("Component classes must not define a custom `__init__` method.")
    cls.__init__ = __component_init__

    if hasattr(cls, "__post_configure__"):
        post_configure = cls.__post_configure__
        if not callable(post_configure):
            raise TypeError(
                "The `__post_configure__` attribute of a @component class must be a "
                "method."
            )
        call_args = inspect.signature(post_configure).parameters
        arg_count = len(call_args)
        takes_only_self = arg_count == 0 or (arg_count == 1 and "self" in call_args)
        if not takes_only_self:
            raise TypeError(
                "The `__post_configure__` method of a @component class must take no "
                f"arguments except `self`, but `{cls.__name__}.__post_configure__` "
                f"accepts arguments {tuple(call_args)}."
            )

    # Gather the fields of the class and all of its super-classes. Walking the
    # MRO back-to-front lets definitions on more derived classes replace those
    # of their bases.
    fields = {
        attr_name: attr_value
        for klass in reversed(inspect.getmro(cls))
        for attr_name, attr_value in klass.__dict__.items()
        if isinstance(attr_value, Field)
    }

    if not fields:
        utils.warn(f"Component {cls.__name__} has no defined fields.")

    # Reject a field defined on a super-class that has been overridden with a
    # non-Field value.
    for attr_name in dir(cls):
        if attr_name not in fields:
            continue
        current_value = getattr(cls, attr_name)
        if isinstance(current_value, Field):
            continue
        super_class = fields[attr_name].host_component_class
        raise ValueError(
            f"Field '{attr_name}' is defined on super-class {super_class.__name__}. "
            f"In subclass {cls.__name__}, '{attr_name}' has been overriden with value: "
            f"{current_value}.\n\n"
            f"If you wish to change the default value of field '{attr_name}' in a "
            f"subclass of {super_class.__name__}, please wrap the new default "
            "value in a new `Field` instance."
        )

    cls.__component_fields__ = fields

    # Make attribute access, assignment, deletion and `dir()` field-aware.
    for wrap in (_wrap_getattribute, _wrap_setattr, _wrap_delattr, _wrap_dir):
        wrap(cls)

    # Components get dedicated `__str__` and `__repr__` implementations.
    cls.__str__ = __component_str__
    cls.__repr__ = __component_repr__

    # Class-level defaults; these are overridden during configuration.
    cls.__component_name__ = cls.__name__
    cls.__component_parent__ = None
    cls.__component_configured__ = False

    return cls
def configure_component_instance(
    instance,
    conf: Dict[str, Any],
    name: str,
    fields_in_scope: AbstractSet[str],
    interactive: bool,
):
    """Configure a single component instance from the `conf` dict.

    According to the original documentation this is called once per component
    instance in the component tree by the exported `configure` function; this
    function itself does not recurse.

    Steps, in order:

    1. Assign `name` (when not `None`) to `instance.__component_name__`.
    2. If the class defines `__pre_configure__`, call it with a shallow copy
       of `conf` and use its return value, which must be a dict, as the
       configuration from then on.
    3. Merge `fields_in_scope` into the instance's set of in-scope fields.
    4. Resolve a value for every field: from `conf`, from the enclosing scope,
       by leaving it unset when missing values are allowed, from the only
       candidate component class, or from an interactive prompt.
    5. Reject `conf` keys that do not correspond to any field.
    6. Mark the instance as configured and call `__post_configure__` if the
       class defines it.

    Raises:
        ValueError: If `__pre_configure__` does not return a dict, if a field
            has no resolvable value, or if `conf` contains an unknown key.
    """
    if name is not None:
        instance.__component_name__ = name

    if hasattr(instance.__class__, "__pre_configure__"):
        # The hook receives a copy and returns the configuration to use.
        conf = instance.__pre_configure__({**conf})
        if not isinstance(conf, dict):
            raise ValueError(
                "Expected the `__pre_configure__` method of component "
                f"'{instance.__component_name__}' to return a dict of configuration, "
                f"but received: {conf}"
            )

    # Extend the field names in scope.
    instance.__component_fields_with_values_in_scope__ |= fields_in_scope

    # Resolve a value for each field in turn.
    for field in instance.__component_fields__.values():
        full_name = f"{instance.__component_name__}.{field.name}"
        if inspect.isclass(field.type):
            field_type_name = field.type.__name__
        else:
            field_type_name = str(field.type)

        is_component_field = isinstance(field, ComponentField)
        if is_component_field:
            # Candidate classes: every component subclass of the field type,
            # plus every factory class able to build the type or one of its
            # subclasses.
            candidates = list(utils.generate_component_subclasses(field.type))
            for type_subclass in utils.generate_subclasses(field.type):
                candidates.extend(FACTORY_REGISTRY.get(type_subclass, []))

        if field.name in conf:
            value = conf[field.name]

            if is_component_field:
                # A string value may name the component or factory class to
                # instantiate.
                if candidates and isinstance(value, str):
                    for candidate in candidates:
                        if (
                            value == candidate.__name__
                            or value == candidate.__qualname__
                            or utils.convert_to_snake_case(value)
                            == utils.convert_to_snake_case(candidate.__name__)
                        ):
                            value = candidate()
                            break

                # A value that is not a component instance is a user type
                # error, but it is not reported here: the run-time type
                # checking produces a better error message.
                if utils.is_component_instance(value):
                    # The parent link is needed for field value inheritance.
                    value.__component_parent__ = instance

            instance.__component_configured_field_values__[field.name] = value

        elif field.name in instance.__component_fields_with_values_in_scope__:
            # A value is already in scope; nothing to set.
            pass

        elif field.allow_missing:
            # The field may stay unset, and must not be marked as in scope.
            continue

        elif is_component_field and len(candidates) == 1:
            # With a single candidate class, assume that is the one intended
            # and instantiate it automatically.
            only_candidate = candidates[0]
            utils.warn(
                f"'{utils.type_name_str(only_candidate)}' is the only concrete component "
                f"class that satisfies the type of the annotated field '{full_name}'. "
                "Using an instance of this class by default.",
            )
            # Safe because custom `__init__` methods are not allowed.
            value = only_candidate()
            # The parent link is needed for field value inheritance.
            value.__component_parent__ = instance
            instance.__component_configured_field_values__[field.name] = value

        elif interactive:
            # Ask the user for a class (component fields) or a plain value.
            if not is_component_field:
                value = utils.prompt_for_value(full_name, field.type)
            elif candidates:
                chosen_cls = utils.prompt_for_component_subclass(
                    full_name, candidates
                )
                # Safe because custom `__init__` methods are not allowed.
                # NOTE(review): unlike the two branches above, no
                # `__component_parent__` is set here - confirm whether the
                # caller takes care of that.
                value = chosen_cls()
            else:
                raise ValueError(
                    "No component or factory class is defined which satisfies the "
                    f"type {field_type_name} of field {full_name}. If such a class "
                    "has been defined, it must be imported before calling "
                    "`configure`."
                )
            instance.__component_configured_field_values__[field.name] = value

        else:
            # Nothing could be resolved: raise the most specific error.
            if not is_component_field:
                raise ValueError(
                    "No configuration value found for annotated field "
                    f"'{full_name}' of type '{field_type_name}'."
                )
            if candidates:
                raise ValueError(
                    f"Component field '{full_name}' of type '{field_type_name}' "
                    f"has no default or configured class. Please configure "
                    f"'{full_name}' with one of the following @component or "
                    "@factory classes:"
                    + "\n ".join(
                        [""] + list(utils.type_name_str(c) for c in candidates)
                    )
                )
            raise ValueError(
                f"Component field '{full_name}' of type '{field_type_name}' "
                f"has no default or configured class. No defined @component "
                "or @factory class satisfies this type. Please define an "
                f"@component class subclassing '{field_type_name}', or an "
                "@factory class with a `build()` method returning a "
                f"'{field_type_name}' instance. This class must be imported "
                "before invoking `configure()`."
            )

        # Every path reaching this point has a value for the field, so record
        # that fact.
        instance.__component_fields_with_values_in_scope__.add(field.name)

    # Make sure every `conf` entry is used; fail on any stray option.
    for key in conf:
        known_fields = instance.__component_fields__
        if "." in key:
            # Dotted keys must be scoped to one of this instance's component
            # fields.
            owner = key.split(".")[0]
            key_is_known = owner in known_fields and isinstance(
                known_fields[owner], ComponentField
            )
        else:
            key_is_known = key in known_fields

        if not key_is_known:
            raise ValueError(
                f"Key '{key}' does not correspond to any field of component "
                f"'{instance.__component_name__}'."
                "\n\n"
                "If you have nested components as follows:\n\n"
                "```\n"
                "@component\n"
                "class ChildComponent:\n"
                " a: int = Field(0)\n"
                "\n"
                "@task\n"
                "class SomeTask:\n"
                " child: ChildComponent = ComponentField(ChildComponent)\n"
                " def run(self):\n"
                " print(self.child.a)\n"
                "```\n\n"
                "then trying to configure `a=<SOME_VALUE>` will fail. You instead need to "
                "fully qualify the key name, and configure the value with "
                "`child.a=<SOME_VALUE>`."
            )

    instance.__component_configured__ = True

    if hasattr(instance.__class__, "__post_configure__"):
        instance.__post_configure__()