def _setup(self, config: Dict): """Subclasses should override this for custom initialization.""" # Set this flag to False, if we find an error, or reduce self.name = config['name'] self.run_flag = True custom_modules = config['custom_modules'] setup_default_modules() import_modules(custom_modules) # Get the current computation block target_block_id = config['to_run'] self.block_id = target_block_id # Update the schemas with the configuration schemas: Dict[str, Schema] = Schema.deserialize(config['schemas']) schemas_copy = deepcopy(schemas) global_vars = config['global_vars'] self.verbose = config['verbose'] self.hyper_params = config['hyper_params'] self.debug = config['debug'] with TrialLogging(log_dir=self.logdir, verbose=self.verbose, console_prefix=self.block_id, hyper_params=self.hyper_params, capture_warnings=True): # Compile, activate links, and load checkpoints filled_schemas: Dict = OrderedDict() for block_id, schema_block in schemas.items(): block_params = config['params'][block_id] utils.update_schema_with_params(schemas_copy[block_id], block_params) # First activate links from previous blocks in the # pipeline utils.update_link_refs(schemas_copy, block_id, global_vars) block: Component = schemas_copy[block_id]() filled_schemas[block_id] = schemas_copy[block_id] if block_id in config['checkpoints']: # Get the block hash needed_set = utils.extract_needed_blocks( schemas, block_id, global_vars) needed_blocks = ((k, v) for k, v in filled_schemas.items() if k in needed_set) block_hash = repr(OrderedDict(needed_blocks)) # Check the mask, if it's False then we end # immediately mask_value = config['checkpoints'][block_id]['mask'][ block_hash] if mask_value is False: self.run_flag = False return # There should be a checkpoint checkpoint = config['checkpoints'][block_id]['paths'][ block_hash] state = load_state_from_file(checkpoint) block.load_state(state) # Holding compiled objects alongside schemas is okay # but not fully expressed in our type annotations. 
# TODO: fix this in our utils type annotations schemas_copy[block_id] = block # type: ignore # If everything went well, just compile self.block = schemas_copy[target_block_id] # Add tb prefix to computables in case multiple plots are # requested if not config['merge_plot']: self.block.tb_log_prefix = self.name
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file
    or directory.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when
        the directory save format is used. See torch.load
        documentation for more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    Component
        object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If a Component object is not loadable from the given path
        because extensions are not installed, or the config is empty,
        nonexistent, or otherwise invalid.

    """
    # Shared by both empty-config failure paths below. NOTE: spaces at
    # the literal joins were previously missing, producing a garbled
    # message ("from aconfig", "to otherobjects").
    empty_config_msg = (
        "Cannot load schema from empty config. This object may not have been saved"
        " for any of the following reasons:\n - The object was not created from a"
        " config or with compile method\n - The object originally linked to other"
        " objects that cannot be represented in YAML")

    state = load_state_from_file(path, map_location, pickle_module,
                                 **pickle_load_args)
    yaml_config = state._metadata[''][FLAMBE_CONFIG_KEY]
    stash = state._metadata[''][FLAMBE_STASH_KEY] \
        if FLAMBE_STASH_KEY in state._metadata[''] else None
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    # BUG FIX: `list(...)` never returns None, so the previous
    # `if yamls is None` guard was dead code and an empty config
    # slipped through. Check for emptiness instead.
    if not yamls:
        raise LoadError(empty_config_msg)
    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")
    if len(yamls) == 2:
        if yamls[0] is not None:
            # First document is the extensions mapping: module -> location
            extensions = dict(yamls[0])
            custom_modules = extensions.keys()
            for x in custom_modules:
                if not is_installed_module(x):
                    if auto_install:
                        # logger.warn is deprecated; use warning
                        logger.warning(
                            f"auto_install==True, installing missing Module "
                            f"{x}: {extensions[x]}")
                        install_extensions({x: extensions[x]})
                        logger.debug(
                            f"Installed module {x} from {extensions[x]}")
                    else:
                        # BUG FIX: added missing space after 'pip install'
                        # at the literal join.
                        raise ImportError(
                            f"Module {x} is required and not installed. Please 'pip install' "
                            "the package containing the module or set auto_install flag"
                            " to True.")
                import_modules([x])
                logger.debug(f"Automatically imported {x}")

            # Reload with extensions' module imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]
            # Set the extensions to the schema so that they are
            # passed when compiling the component.
            schema.add_extensions_metadata(extensions)
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        # Single document: the schema itself, no extensions section
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")
    if schema is None:
        raise LoadError(empty_config_msg)

    _update_link_refs(schema)
    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found
    try:
        instance = schema(stash)
    except TypeError as err:
        # Chain the original TypeError so the root cause is visible
        raise LoadError(
            f"Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            f"extensions section at the top and install as necessary. Alternatively "
            f"set auto_install=True") from err
    instance.load_state(state)
    return instance