Example 1
0
    def _setup(self, config: Dict):
        """Subclasses should override this for custom initialization.

        Compiles the pipeline of blocks described in ``config``: each
        schema is filled with its hyper-parameters, its links to earlier
        blocks in the pipeline are resolved, and it is compiled into a
        live ``Component``.  If a checkpoint mask marks this block as
        not runnable, ``self.run_flag`` is set to False and the method
        returns early without compiling further.

        Parameters
        ----------
        config : Dict
            Serialized setup payload. Keys read here: 'name',
            'custom_modules', 'to_run', 'schemas', 'global_vars',
            'verbose', 'hyper_params', 'debug', 'params',
            'checkpoints' and 'merge_plot'.
        """
        # run_flag is flipped to False below if a checkpoint mask says
        # this block must not run.
        self.name = config['name']
        self.run_flag = True
        custom_modules = config['custom_modules']
        setup_default_modules()
        import_modules(custom_modules)
        # Get the current computation block
        target_block_id = config['to_run']
        self.block_id = target_block_id
        # Update the schemas with the configuration
        schemas: Dict[str, Schema] = Schema.deserialize(config['schemas'])
        # Work on a copy: its entries are replaced in-place by compiled
        # objects further down, while `schemas` stays pristine for
        # dependency extraction and hashing.
        schemas_copy = deepcopy(schemas)
        global_vars = config['global_vars']
        self.verbose = config['verbose']
        self.hyper_params = config['hyper_params']
        self.debug = config['debug']

        # NOTE(review): self.logdir is not assigned in this method —
        # presumably provided by the base class; confirm.
        with TrialLogging(log_dir=self.logdir,
                          verbose=self.verbose,
                          console_prefix=self.block_id,
                          hyper_params=self.hyper_params,
                          capture_warnings=True):

            # Compile, activate links, and load checkpoints
            filled_schemas: Dict = OrderedDict()
            for block_id, schema_block in schemas.items():

                block_params = config['params'][block_id]

                utils.update_schema_with_params(schemas_copy[block_id],
                                                block_params)

                # First activate links from previous blocks in the
                # pipeline
                utils.update_link_refs(schemas_copy, block_id, global_vars)
                # Calling the filled schema compiles it into a live object
                block: Component = schemas_copy[block_id]()
                filled_schemas[block_id] = schemas_copy[block_id]

                if block_id in config['checkpoints']:
                    # Hash only the schemas this block actually depends
                    # on, so the hash is stable across unrelated blocks.
                    needed_set = utils.extract_needed_blocks(
                        schemas, block_id, global_vars)
                    needed_blocks = ((k, v) for k, v in filled_schemas.items()
                                     if k in needed_set)
                    block_hash = repr(OrderedDict(needed_blocks))

                    # Check the mask, if it's False then we end
                    # immediately
                    mask_value = config['checkpoints'][block_id]['mask'][
                        block_hash]
                    if mask_value is False:
                        self.run_flag = False
                        return

                    # There should be a checkpoint
                    checkpoint = config['checkpoints'][block_id]['paths'][
                        block_hash]
                    state = load_state_from_file(checkpoint)
                    block.load_state(state)

                # Holding compiled objects alongside schemas is okay
                # but not fully expressed in our type annotations.
                # TODO: fix this in our utils type annotations
                schemas_copy[block_id] = block  # type: ignore

        # If everything went well, just compile
        self.block = schemas_copy[target_block_id]

        # Add tb prefix to computables in case multiple plots are
        # requested
        if not config['merge_plot']:
            self.block.tb_log_prefix = self.name
Example 2
0
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file or
    directory.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See torch.load documentation for
        more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    Component
        object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If a Component object is not loadable from the given path
        because extensions are not installed, or the config is empty,
        nonexistent, or otherwise invalid.

    """
    state = load_state_from_file(path, map_location, pickle_module,
                                 **pickle_load_args)
    yaml_config = state._metadata[''][FLAMBE_CONFIG_KEY]
    # .get() replaces the original membership-test + lookup pair
    stash = state._metadata[''].get(FLAMBE_STASH_KEY)
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    # Shared error text. The continuation strings start with a space so
    # the concatenated message reads correctly (the original was missing
    # spaces: "from aconfig", "otherobjects").
    empty_config_msg = (
        "Cannot load schema from empty config. This object may not have been saved"
        " for any of the following reasons:\n - The object was not created from a"
        " config or with compile method\n - The object originally linked to other"
        " objects that cannot be represented in YAML")

    # BUG FIX: list() never returns None, so the original
    # `if yamls is None` check was dead code; an empty document list is
    # the actual empty-config signal.
    if not yamls:
        raise LoadError(empty_config_msg)
    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")
    if len(yamls) == 2:
        if yamls[0] is not None:
            # First document is the extensions mapping: module -> source
            extensions = dict(yamls[0])
            for module_name in extensions:
                if not is_installed_module(module_name):
                    if auto_install:
                        # logger.warn is deprecated; use warning()
                        logger.warning(
                            f"auto_install==True, installing missing Module "
                            f"{module_name}: {extensions[module_name]}")
                        install_extensions({module_name: extensions[module_name]})
                        logger.debug(
                            f"Installed module {module_name} from "
                            f"{extensions[module_name]}")
                    else:
                        # Spaces added between the concatenated pieces
                        # (original read "'pip install'the package ... flagto").
                        raise ImportError(
                            f"Module {module_name} is required and not installed. "
                            "Please 'pip install' the package containing the module "
                            "or set auto_install flag to True.")
                import_modules([module_name])
                logger.debug(f"Automatically imported {module_name}")

            # Reload with extensions' module imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]

            # Set the extensions to the schema so that they are
            # passed when compiling the component.
            schema.add_extensions_metadata(extensions)
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")

    if schema is None:
        raise LoadError(empty_config_msg)

    _update_link_refs(schema)
    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found

    try:
        instance = schema(stash)
    except TypeError as e:
        # Chain the original TypeError so the root cause stays visible
        raise LoadError(
            f"Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            f"extensions section at the top and install as necessary. Alternatively "
            f"set auto_install=True") from e
    instance.load_state(state)
    return instance