Example #1
0
    def load(self, split, decoders, shuffle) -> tf.data.Dataset:
        """Load the requested split of this dataset through `tfds.load`.

        Args:
            split: Split specification, forwarded unchanged to `tfds.load`.
            decoders: Decoders, forwarded unchanged to `tfds.load`.
            shuffle: Passed to `tfds.load` as the `shuffle_files` entry of
                `as_dataset_kwargs`.

        Returns:
            Whatever `tfds.load` returns for `self.name` / `split`.

        Raises:
            AssertionError: Re-raised (with exception context suppressed) when
                `tfds.load` fails an assertion. If `self.download` is falsy, a
                warning hinting at the `download` field is emitted first.
        """
        try:
            request = {
                "name": self.name,
                "split": split,
                "data_dir": self.data_dir,
                "download": self.download,
                "decoders": decoders,
                "as_dataset_kwargs": {"shuffle_files": shuffle},
            }
            return tfds.load(**request)
        except AssertionError as load_error:
            downloading_disabled = not self.download
            if downloading_disabled:
                # Warn only when downloading is disabled, pointing the user
                # at the `download` field.
                hint = (
                    f"Field 'download' of component {self.__class__.__name__} "
                    "is False. If the TFDS dataset is not downloaded, set "
                    "'download' to True to call `download_and_prepare()` "
                    "automatically."
                )
                utils.warn(hint)
            raise load_error from None
Example #2
0
 def load(self, splits, decoders, shuffle) -> tf.data.Dataset:
     """Load several TFDS splits and chain them into a single dataset.

     Args:
         splits: Mapping whose keys are passed to `tfds.load` as `name` and
             whose values are passed as `split`. Entries are loaded in
             iteration order.
         decoders: Decoders, forwarded unchanged to every `tfds.load` call.
         shuffle: Passed to `tfds.load` as the `shuffle_files` entry of
             `as_dataset_kwargs`.

     Returns:
         The first loaded dataset with every later one appended through
         `.concatenate(...)`, or `None` when `splits` is empty.

     Raises:
         AssertionError: Re-raised (with exception context suppressed) when
             `tfds.load` fails an assertion. If `self.download` is falsy, a
             warning hinting at the `download` field is emitted first.
     """
     combined = None
     for dataset_name, dataset_split in splits.items():
         try:
             request = {
                 "name": dataset_name,
                 "split": dataset_split,
                 "data_dir": self.data_dir,
                 "download": self.download,
                 "decoders": decoders,
                 "as_dataset_kwargs": {"shuffle_files": shuffle},
             }
             part = tfds.load(**request)
         except AssertionError as load_error:
             downloading_disabled = not self.download
             if downloading_disabled:
                 hint = (
                     f"Field 'download' of component {self.__class__.__name__} "
                     "is False. If the TFDS dataset is not downloaded, set "
                     "'download' to True to call `download_and_prepare()` "
                     "automatically."
                 )
                 utils.warn(hint)
             raise load_error from None

         # The first dataset seeds the result; later ones are appended.
         if combined is None:
             combined = part
         else:
             combined = combined.concatenate(part)
     return combined
Example #3
0
def configure(
    instance,
    conf: Dict[str, Any],
    name: Optional[str] = None,
    interactive: bool = False,
):
    """
    Configure the component instance with parameters from the `conf` dict.

    Configuration passed through `conf` takes precedence over and will
    overwrite any values already set on the instance - either class defaults
    or those set in `__init__`.

    Args:
        instance: The component instance to configure.
        conf: Mapping from field name to configured value. Keys of the form
            `<field>.<rest>` are forwarded, as `<rest>`, to the sub-component
            held by the component field `<field>`. Keys consumed by a field
            of `instance` are deleted from this dict in place.
        name: Optional name to assign to `instance.__component_name__`
            before configuring. When `None` the existing name is kept.
        interactive: When true, fields that have no value are filled in by
            prompting instead of raising.

    Raises:
        TypeError: If `instance` is not a component instance.
        ValueError: If the instance has already been configured, if a field
            ends up without a value (and `interactive` is false or no
            suitable class exists), or if a key of `conf` does not
            correspond to any field of the component.
    """
    # Only component instances can be configured.
    if not utils.is_component_instance(instance):
        raise TypeError(
            "Only @component, @factory, and @task instances can be configured. "
            f"Received: {instance}.")

    # Configuration can only happen once.
    if instance.__component_configured__:
        raise ValueError(
            f"Component '{instance.__component_name__}' has already been configured."
        )

    if name is not None:
        instance.__component_name__ = name

    # Set the correct value for each field.
    for field in instance.__component_fields__.values():
        full_name = f"{instance.__component_name__}.{field.name}"
        field_type_name = (field.type.__name__
                           if inspect.isclass(field.type) else str(field.type))

        if isinstance(field, ComponentField):
            # Create a list of all component subclasses of the field type, and
            # add to the list all factory classes which can build the type (or
            # any subclass of the type).
            component_subclasses = list(
                utils.generate_component_subclasses(field.type))
            for type_subclass in utils.generate_subclasses(field.type):
                component_subclasses.extend(
                    FACTORY_REGISTRY.get(type_subclass, []))

        if field.name in conf:
            conf_field_value = conf[field.name]

            if isinstance(field, ComponentField):
                # The configuration value could be a string specifying a component
                # or factory class to instantiate.
                if len(component_subclasses) > 0 and isinstance(
                        conf_field_value, str):
                    for subclass in component_subclasses:
                        if (conf_field_value == subclass.__name__
                                or conf_field_value == subclass.__qualname__ or
                                utils.convert_to_snake_case(conf_field_value)
                                == utils.convert_to_snake_case(
                                    subclass.__name__)):
                            conf_field_value = subclass()
                            break

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name] = conf_field_value

            # This value has now been 'consumed', so delete it from `conf`.
            del conf[field.name]

        # If there's a value in scope, we don't need to do anything.
        elif field.name in instance.__component_fields_with_values_in_scope__:
            pass

        # If the field explicitly allows values to be missing, there's no need
        # to do anything. Skip the bookkeeping at the bottom of the loop: the
        # field has no value, so it must not be recorded as having one in
        # scope (that set is later propagated to sub-components).
        elif field.allow_missing:
            continue

        # If there is only one concrete component subclass of the annotated
        # type, we assume the user must intend to use that subclass, and so
        # instantiate and use an instance automatically.
        elif isinstance(field,
                        ComponentField) and len(component_subclasses) == 1:
            component_cls = list(component_subclasses)[0]
            utils.warn(
                f"'{utils.type_name_str(component_cls)}' is the only concrete component "
                f"class that satisfies the type of the annotated field '{full_name}'. "
                "Using an instance of this class by default.", )
            # This is safe because we don't allow custom `__init__` methods.
            conf_field_value = component_cls()

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name] = conf_field_value

        # If we are running interactively, prompt for a value.
        elif interactive:
            if isinstance(field, ComponentField):
                if len(component_subclasses) > 0:
                    component_cls = utils.prompt_for_component_subclass(
                        full_name, component_subclasses)
                    # This is safe because we don't allow custom `__init__` methods.
                    conf_field_value = component_cls()
                else:
                    raise ValueError(
                        "No component or factory class is defined which satisfies the "
                        f"type {field_type_name} of field {full_name}. If such a class "
                        "has been defined, it must be imported before calling "
                        "`configure`.")
            else:
                conf_field_value = utils.prompt_for_value(
                    full_name, field.type)

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name] = conf_field_value

        # Otherwise, raise an appropriate error.
        else:
            if isinstance(field, ComponentField):
                if len(component_subclasses) > 0:
                    raise ValueError(
                        f"Component field '{full_name}' of type '{field_type_name}' "
                        f"has no default or configured class. Please configure "
                        f"'{full_name}' with one of the following @component or "
                        "@factory classes:" + "\n    ".join([""] + list(
                            utils.type_name_str(c)
                            for c in component_subclasses)))
                else:
                    raise ValueError(
                        f"Component field '{full_name}' of type '{field_type_name}' "
                        f"has no default or configured class. No defined @component "
                        "or @factory class satisfies this type. Please define an "
                        f"@component class subclassing '{field_type_name}', or an "
                        "@factory class with a `build()` method returning a "
                        f"'{field_type_name}' instance. This class must be imported "
                        "before invoking `configure()`.")
            # Only reachable for non-component fields: both component-field
            # branches above raise.
            raise ValueError(
                "No configuration value found for annotated field "
                f"'{full_name}' of type '{field_type_name}'.")

        # At this point we are certain that this field has a value, so keep
        # track of that fact.
        instance.__component_fields_with_values_in_scope__.add(field.name)

    # Check that all `conf` values are being used, and throw if we've been
    # passed an un-used option.
    for key in conf:
        error = ValueError(
            f"Key '{key}' does not correspond to any field of component "
            f"'{instance.__component_name__}'."
            "\n\n"
            "If you have nested components as follows:\n\n"
            "```\n"
            "@component\n"
            "class ChildComponent:\n"
            "    a: int = Field(0)\n"
            "\n"
            "@task\n"
            "class SomeTask:\n"
            "    child: ChildComponent = ComponentField(ChildComponent)\n"
            "    def run(self):\n"
            "        print(self.child.a)\n"
            "```\n\n"
            "then trying to configure `a=<SOME_VALUE>` will fail. You instead need to "
            "fully qualify the key name, and configure the value with "
            "`child.a=<SOME_VALUE>`.")

        if "." in key:
            # A dotted key is only valid when its first segment names a
            # component field of this instance.
            scoped_component_name = key.split(".")[0]
            if not (scoped_component_name in instance.__component_fields__
                    and isinstance(
                        instance.__component_fields__[scoped_component_name],
                        ComponentField)):
                raise error
        elif key not in instance.__component_fields__:
            raise error

    # Recursively configure any sub-components.
    for field in instance.__component_fields__.values():
        if not isinstance(field, ComponentField):
            continue

        try:
            sub_component_instance = base_getattr(instance, field.name)
        except AttributeError as e:
            if field.allow_missing:
                continue
            raise e from None

        if not utils.is_component_instance(sub_component_instance):
            continue

        full_name = f"{instance.__component_name__}.{field.name}"

        if not sub_component_instance.__component_configured__:
            # Set the component parent so that inherited fields function
            # correctly.
            sub_component_instance.__component_parent__ = instance

            # Extend the field names in scope. All fields with values defined in
            # the scope of the parent are also accessible in the child.
            sub_component_instance.__component_fields_with_values_in_scope__ |= (
                instance.__component_fields_with_values_in_scope__)

            # Configure the nested sub-component. The configuration we use
            # consists of all keys scoped to `field.name`, with that prefix
            # stripped.
            field_name_scoped_conf = {
                a[len(f"{field.name}."):]: b
                for a, b in conf.items() if a.startswith(f"{field.name}.")
            }
            configure(
                sub_component_instance,
                field_name_scoped_conf,
                name=full_name,
                interactive=interactive,
            )

    instance.__component_configured__ = True

    # Give the component class a chance to run its post-configuration hook.
    if hasattr(instance.__class__, "__post_configure__"):
        instance.__post_configure__()
Example #4
0
def component(cls: Type):
    """Class decorator that converts `cls` into a Zookeeper component.

    The decorator validates the class, installs the component `__init__`,
    collects every `Field` defined along the MRO into
    `cls.__component_fields__`, wraps attribute access / `dir`, installs
    `__str__` / `__repr__`, and sets the class-level defaults for the
    component name, parent and configured flag.

    Args:
        cls: The class to decorate.

    Returns:
        The same class object, modified in place.

    Raises:
        TypeError: If `cls` is not a class, is abstract, is already a
            component, defines a custom `__init__`, or has an invalid
            `__post_configure__` attribute.
        ValueError: If a field inherited from a super-class is overridden
            with a value that is not a `Field`.
    """
    if not inspect.isclass(cls):
        raise TypeError("Only classes can be decorated with @component.")

    if inspect.isabstract(cls):
        raise TypeError(
            "Abstract classes cannot be decorated with @component.")

    if utils.is_component_class(cls):
        raise TypeError(
            f"The class {cls.__name__} is already a component; the @component "
            "decorator cannot be applied again.")

    # `__component_init__` is acceptable as an existing init because a class
    # inheriting from a component will already have it.
    permitted_inits = (object.__init__, __component_init__)
    if cls.__init__ not in permitted_inits:
        raise TypeError(
            "Component classes must not define a custom `__init__` method.")
    cls.__init__ = __component_init__

    if hasattr(cls, "__post_configure__"):
        hook = cls.__post_configure__
        if not callable(hook):
            raise TypeError(
                "The `__post_configure__` attribute of a @component class must "
                "be a method.")
        hook_params = inspect.signature(hook).parameters
        param_count = len(hook_params)
        # At most one parameter is accepted, and if present it must be `self`.
        if param_count > 1 or (param_count == 1 and "self" not in hook_params):
            raise TypeError(
                "The `__post_configure__` method of a @component class must "
                "take no arguments except `self`, but "
                f"`{cls.__name__}.__post_configure__` "
                f"accepts arguments {tuple(hook_params)}.")

    # Gather fields from the least-derived class to the most-derived one so
    # that definitions on subclasses replace those of their super-classes.
    fields = {
        attr_name: attr_value
        for klass in reversed(inspect.getmro(cls))
        for attr_name, attr_value in klass.__dict__.items()
        if isinstance(attr_value, Field)
    }

    if not fields:
        utils.warn(f"Component {cls.__name__} has no defined fields.")

    # Reject a super-class field that a subclass has replaced with a plain
    # (non-Field) value.
    for attr_name in dir(cls):
        if attr_name not in fields:
            continue
        current_value = getattr(cls, attr_name)
        if isinstance(current_value, Field):
            continue
        super_class = fields[attr_name].host_component_class
        raise ValueError(
            f"Field '{attr_name}' is defined on super-class "
            f"{super_class.__name__}. In subclass {cls.__name__}, "
            f"'{attr_name}' has been overriden with value: {current_value}."
            "\n\n"
            f"If you wish to change the default value of field '{attr_name}' "
            f"in a subclass of {super_class.__name__}, please wrap the new "
            "default value in a new `Field` instance.")

    cls.__component_fields__ = fields

    # Make attribute access, assignment, deletion and `dir` field-aware.
    for install_wrapper in (
            _wrap_getattribute,
            _wrap_setattr,
            _wrap_delattr,
            _wrap_dir,
    ):
        install_wrapper(cls)

    # Give components readable string representations.
    cls.__str__ = __component_str__
    cls.__repr__ = __component_repr__

    # Class-level defaults; configuration replaces these per instance.
    cls.__component_name__ = cls.__name__
    cls.__component_parent__ = None
    cls.__component_configured__ = False

    return cls
Example #5
0
def configure_component_instance(
    instance,
    conf: Dict[str, Any],
    name: str,
    fields_in_scope: AbstractSet[str],
    interactive: bool,
):
    """Configure the component instance with parameters from the `conf` dict.

    This method is recursively called for each component instance in the component tree
    by the exported `configure` function.

    Args:
        instance: The component instance to configure.
        conf: Mapping from field name to configured value. If the component
            class defines `__pre_configure__`, it is called with a shallow
            copy of `conf` and its return value replaces `conf`.
        name: Name assigned to `instance.__component_name__` (skipped when
            `None`).
        fields_in_scope: Names of fields that already have a value in an
            enclosing scope; merged into the instance's own set.
        interactive: When true, fields that have no value are filled in by
            prompting instead of raising.

    Raises:
        ValueError: If `__pre_configure__` does not return a dict, if a field
            ends up without a value (and `interactive` is false or no
            suitable class exists), or if a key of `conf` does not
            correspond to any field of the component.
    """
    if name is not None:
        instance.__component_name__ = name

    if hasattr(instance.__class__, "__pre_configure__"):
        # Pass a copy so the hook cannot mutate the caller's dict.
        conf = instance.__pre_configure__({**conf})
        if not isinstance(conf, dict):
            raise ValueError(
                "Expected the `__pre_configure__` method of component "
                f"'{instance.__component_name__}' to return a dict of configuration, "
                f"but received: {conf}"
            )

    # Extend the field names in scope.
    instance.__component_fields_with_values_in_scope__ |= fields_in_scope

    # Set the correct value for each field.
    for field in instance.__component_fields__.values():
        full_name = f"{instance.__component_name__}.{field.name}"
        field_type_name = (
            field.type.__name__ if inspect.isclass(field.type) else str(field.type)
        )

        if isinstance(field, ComponentField):
            # Create a list of all component subclasses of the field type, and
            # add to the list all factory classes which can build the type (or
            # any subclass of the type).
            component_subclasses = list(utils.generate_component_subclasses(field.type))
            for type_subclass in utils.generate_subclasses(field.type):
                component_subclasses.extend(FACTORY_REGISTRY.get(type_subclass, []))

        if field.name in conf:
            conf_field_value = conf[field.name]

            if isinstance(field, ComponentField):
                # The configuration value could be a string specifying a component
                # or factory class to instantiate.
                if len(component_subclasses) > 0 and isinstance(conf_field_value, str):
                    for subclass in component_subclasses:
                        if (
                            conf_field_value == subclass.__name__
                            or conf_field_value == subclass.__qualname__
                            or utils.convert_to_snake_case(conf_field_value)
                            == utils.convert_to_snake_case(subclass.__name__)
                        ):
                            conf_field_value = subclass()
                            break

            # If this isn't the case, then it's a user type error, but we don't
            # throw here and instead let the run-time type-checking take care of
            # it (which will provide a better error message).
            if utils.is_component_instance(conf_field_value):
                # Set the component parent so that field value inheritance will
                # work correctly.
                conf_field_value.__component_parent__ = instance

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

        # If there's a value in scope, we don't need to do anything.
        elif field.name in instance.__component_fields_with_values_in_scope__:
            pass

        # If the field explicitly allows values to be missing, there's no need
        # to do anything. `continue` also skips the bookkeeping at the bottom
        # of the loop, because the field has no value to record.
        elif field.allow_missing:
            continue

        # If there is only one concrete component subclass of the annotated
        # type, we assume the user must intend to use that subclass, and so
        # instantiate and use an instance automatically.
        elif isinstance(field, ComponentField) and len(component_subclasses) == 1:
            component_cls = list(component_subclasses)[0]
            utils.warn(
                f"'{utils.type_name_str(component_cls)}' is the only concrete component "
                f"class that satisfies the type of the annotated field '{full_name}'. "
                "Using an instance of this class by default.",
            )
            # This is safe because we don't allow custom `__init__` methods.
            conf_field_value = component_cls()
            # Set the component parent so that field value inheritance will work
            # correctly.
            conf_field_value.__component_parent__ = instance

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

        # If we are running interactively, prompt for a value.
        elif interactive:
            if isinstance(field, ComponentField):
                if len(component_subclasses) > 0:
                    component_cls = utils.prompt_for_component_subclass(
                        full_name, component_subclasses
                    )
                    # This is safe because we don't allow custom `__init__` methods.
                    conf_field_value = component_cls()
                    # Set the component parent so that field value inheritance
                    # will work correctly, exactly as for the automatically
                    # selected single-subclass case above.
                    conf_field_value.__component_parent__ = instance
                else:
                    raise ValueError(
                        "No component or factory class is defined which satisfies the "
                        f"type {field_type_name} of field {full_name}. If such a class "
                        "has been defined, it must be imported before calling "
                        "`configure`."
                    )
            else:
                conf_field_value = utils.prompt_for_value(full_name, field.type)

            # Set the value on the instance.
            instance.__component_configured_field_values__[
                field.name
            ] = conf_field_value

        # Otherwise, raise an appropriate error.
        else:
            if isinstance(field, ComponentField):
                if len(component_subclasses) > 0:
                    raise ValueError(
                        f"Component field '{full_name}' of type '{field_type_name}' "
                        f"has no default or configured class. Please configure "
                        f"'{full_name}' with one of the following @component or "
                        "@factory classes:"
                        + "\n    ".join(
                            [""]
                            + list(utils.type_name_str(c) for c in component_subclasses)
                        )
                    )
                else:
                    raise ValueError(
                        f"Component field '{full_name}' of type '{field_type_name}' "
                        f"has no default or configured class. No defined @component "
                        "or @factory class satisfies this type. Please define an "
                        f"@component class subclassing '{field_type_name}', or an "
                        "@factory class with a `build()` method returning a "
                        f"'{field_type_name}' instance. This class must be imported "
                        "before invoking `configure()`."
                    )
            # Only reachable for non-component fields: both component-field
            # branches above raise.
            raise ValueError(
                "No configuration value found for annotated field "
                f"'{full_name}' of type '{field_type_name}'."
            )

        # At this point we are certain that this field has a value, so keep
        # track of that fact.
        instance.__component_fields_with_values_in_scope__.add(field.name)

    # Check that all `conf` values are being used, and throw if we've been
    # passed an un-used option.
    for key in conf:
        error = ValueError(
            f"Key '{key}' does not correspond to any field of component "
            f"'{instance.__component_name__}'."
            "\n\n"
            "If you have nested components as follows:\n\n"
            "```\n"
            "@component\n"
            "class ChildComponent:\n"
            "    a: int = Field(0)\n"
            "\n"
            "@task\n"
            "class SomeTask:\n"
            "    child: ChildComponent = ComponentField(ChildComponent)\n"
            "    def run(self):\n"
            "        print(self.child.a)\n"
            "```\n\n"
            "then trying to configure `a=<SOME_VALUE>` will fail. You instead need to "
            "fully qualify the key name, and configure the value with "
            "`child.a=<SOME_VALUE>`."
        )

        if "." in key:
            # A dotted key is only valid when its first segment names a
            # component field of this instance.
            scoped_component_name = key.split(".")[0]
            if not (
                scoped_component_name in instance.__component_fields__
                and isinstance(
                    instance.__component_fields__[scoped_component_name], ComponentField
                )
            ):
                raise error
        elif key not in instance.__component_fields__:
            raise error

    instance.__component_configured__ = True

    # Give the component class a chance to run its post-configuration hook.
    if hasattr(instance.__class__, "__post_configure__"):
        instance.__post_configure__()