예제 #1
0
    def preprocess(self,
                   secrets: Optional[str] = None,
                   download_ext: bool = True,
                   install_ext: bool = False,
                   import_ext: bool = True,
                   check_tags: bool = True,
                   **kwargs) -> Tuple[Runnable, Dict[str, str]]:
        """Preprocess the runnable file.

        Looks for syntax errors, import errors, etc. Also injects
        the secrets into the runnables.

        If this method runs and ends without exceptions, then the
        experiment is ok to be run. If this method raises an Error and
        the SafeExecutionContext is used as context manager,
        then the __exit__ method will be executed.

        Parameters
        ----------
        secrets: Optional[str]
            Optional path to the secrets file. The file is read with
            ``configparser`` before compiling (the parsed values are
            discarded), so a file with invalid INI syntax raises here.
            The path itself is what gets injected into the runnable.
        download_ext: bool
            Not consulted by this method: remote extensions are only
            downloaded as part of ``install_ext``. Defaults to True.
        install_ext: bool
            Whether to install the extensions or not.
            This process also downloads the remote extensions.
            Defaults to False.
        import_ext: bool
            Whether to import the extensions or not.
            Defaults to True.
        check_tags: bool
            Whether to check that all tags are valid. Defaults to True.
        **kwargs
            Accepted but not used by this method.

        Returns
        -------
        Tuple[Runnable, Dict[str, str]]
            A tuple containing the compiled Runnable and a dict
            containing the extensions the Runnable uses. When
            ``install_ext`` is set, the dict is the one returned by
            ``download_extensions``.

        Raises
        ------
        Exception
            Depending on the error.

        """
        content, extensions = self.first_parse()

        if secrets:
            # Parse the secrets file only to surface INI syntax errors
            # early; the parsed values are not used. ConfigParser.read
            # silently skips a path it cannot open.
            configparser.ConfigParser().read(secrets)

        if install_ext:
            # Downloading happens here whatever ``download_ext`` says.
            # NOTE(review): confirm whether download_ext should gate this.
            target = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
            extensions = download_extensions(extensions, target)
            install_extensions(extensions, user_flag=False)

        if import_ext:
            import_modules(extensions.keys())

        # Check that all tags are valid
        if check_tags:
            self.check_tags(content)

        # Compile the runnable now that the extensions (if requested)
        # were imported.
        runnable = self.compile_runnable(content)

        if secrets:
            runnable.inject_secrets(secrets)

        if extensions:
            runnable.inject_extensions(extensions)

        runnable.parse()

        return runnable, extensions
예제 #2
0
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file or
    directory.

    The saved config must hold one YAML document (the main object) or
    two documents (an optional extensions section followed by the main
    object). Every extension listed in that section is imported, after
    first being installed if it is missing and ``auto_install`` is
    True. The config is then re-parsed so the main object can be built.

    NOTE(review): the state is loaded through ``pickle_module`` (dill by
    default) and the config through the module's YAML loader; both can
    construct arbitrary objects, so only load save files from trusted
    sources.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See torch.load documentation for
        more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    Component
        object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If the config has no YAML documents, more than two documents,
        or an empty main object, or if calling the parsed object raises
        TypeError (likely because an extension is not installed).
    ImportError
        If an extension listed in the config is not installed and
        ``auto_install`` is False.

    """
    state = load_state_from_file(path, map_location, pickle_module,
                                 **pickle_load_args)
    metadata = state._metadata['']
    yaml_config = metadata[FLAMBE_CONFIG_KEY]
    stash = metadata[FLAMBE_STASH_KEY] if FLAMBE_STASH_KEY in metadata else None
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")

    # Extensions declared in the config; None when there is no section.
    extensions = None
    if len(yamls) == 2:
        if yamls[0] is not None:
            extensions = dict(yamls[0])
            for module, source in extensions.items():
                if not is_installed_module(module):
                    if not auto_install:
                        raise ImportError(
                            f"Module {module} is required and not installed. "
                            "Please 'pip install' the package containing the "
                            "module or set auto_install flag to True.")
                    logger.warning(
                        f"auto_install==True, installing missing Module "
                        f"{module}: {source}")
                    install_extensions({module: source})
                    logger.debug(f"Installed module {module} from {source}")
                import_modules([module])
                logger.debug(f"Automatically imported {module}")

            # Reload with extensions' module imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")

    # Checked before touching the schema so an empty main object always
    # yields LoadError, even when an extensions section is present.
    if schema is None:
        raise LoadError(
            "Cannot load schema from empty config. This object may not have been saved"
            " for any of the following reasons:\n - The object was not created from a"
            " config or with compile method\n - The object originally linked to other"
            " objects that cannot be represented in YAML")

    if extensions is not None:
        # Set the extensions to the schema so that they are
        # passed when compiling the component.
        schema.add_extensions_metadata(extensions)

    _update_link_refs(schema)
    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found

    try:
        instance = schema(stash)
    except TypeError as err:
        # Chain the cause: a TypeError raised inside the constructor
        # would otherwise be hidden behind the "not callable" message.
        raise LoadError(
            "Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            "extensions section at the top and install as necessary. Alternatively "
            "set auto_install=True") from err
    instance.load_state(state)
    return instance