예제 #1
0
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file or
    directory.

    Security note: the state is deserialized through `pickle_module`
    (dill by default). With ``auto_install=True``, modules listed in the
    saved config's extensions section are installed via
    `install_extensions`. Only load paths from trusted sources.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See torch.load documentation for
        more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    Component
        object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If a Component object is not loadable from the given path. This
        happens when the config is empty or nonexistent, contains more
        than two YAML documents, or yields a schema that cannot be called
        with the stash (e.g. because an extension is not installed).
    ImportError
        If the config's extensions section lists a module that is not
        installed and ``auto_install`` is False.

    """
    empty_config_message = (
        "Cannot load schema from empty config. This object may not have been saved"
        " for any of the following reasons:\n - The object was not created from a"
        " config or with compile method\n - The object originally linked to other"
        " objects that cannot be represented in YAML")

    state = load_state_from_file(path, map_location, pickle_module,
                                 **pickle_load_args)
    root_metadata = state._metadata['']
    yaml_config = root_metadata[FLAMBE_CONFIG_KEY]
    # The stash entry is optional; fall back to None when it is absent.
    stash = root_metadata.get(FLAMBE_STASH_KEY)
    # Must run before parsing; presumably registers the default modules so
    # their YAML tags resolve.
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    # An empty config parses to an empty list and is reported by the final
    # ``else`` branch below.
    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")
    if len(yamls) == 2:
        if yamls[0] is not None:
            extensions = dict(yamls[0])
            for module_name, source in extensions.items():
                if not is_installed_module(module_name):
                    if not auto_install:
                        raise ImportError(
                            f"Module {module_name} is required and not installed. "
                            "Please 'pip install' the package containing the module "
                            "or set auto_install flag to True.")
                    logger.warning(
                        f"auto_install==True, installing missing Module "
                        f"{module_name}: {source}")
                    install_extensions({module_name: source})
                    logger.debug(f"Installed module {module_name} from {source}")
                import_modules([module_name])
                logger.debug(f"Automatically imported {module_name}")

            # Reload with extensions' module imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]

            # Set the extensions to the schema so that they are passed when
            # compiling the component. An empty main document is left as
            # None so the LoadError below fires instead of an AttributeError.
            if schema is not None:
                schema.add_extensions_metadata(extensions)
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")

    if schema is None:
        raise LoadError(empty_config_message)

    _update_link_refs(schema)
    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found

    try:
        instance = schema(stash)
    except TypeError as err:
        # Chain the original error: a TypeError raised inside the component's
        # constructor would otherwise be hidden behind this message.
        raise LoadError(
            "Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            "extensions section at the top and install as necessary. Alternatively "
            "set auto_install=True") from err
    instance.load_state(state)
    return instance
예제 #2
0
def test_is_installed_module():
    """`is_installed_module` is True for an importable package, False otherwise."""
    present = exts.is_installed_module("pytest")
    assert present is True

    bogus_name = "some_inexistent_package_0987654321"
    absent = exts.is_installed_module(bogus_name)
    assert absent is False