def preprocess(self,
               secrets: Optional[str] = None,
               download_ext: bool = True,
               install_ext: bool = False,
               import_ext: bool = True,
               check_tags: bool = True,
               **kwargs) -> Tuple[Runnable, Dict[str, str]]:
    """Preprocess the runnable file.

    Looks for syntax errors, import errors, etc. Also injects the
    secrets into the runnables.

    If this method runs and ends without exceptions, then the
    experiment is ok to be run. If this method raises an Error and
    the SafeExecutionContext is used as a context manager, then the
    __exit__ method will be executed.

    Parameters
    ----------
    secrets: Optional[str]
        Optional path to the secrets file
    install_ext: bool
        Whether to install the extensions or not.
        This process also downloads the remote extensions.
        Defaults to False.
    import_ext: bool
        Whether to import the extensions or not. Defaults to True.
    check_tags: bool
        Whether to check that all tags are valid. Defaults to True.

    Returns
    -------
    Tuple[Runnable, Dict[str, str]]
        A tuple containing the compiled Runnable and a dict
        containing the extensions the Runnable uses.

    Raises
    ------
    Exception
        Depending on the error.

    """
    content, extensions = self.first_parse()

    config = configparser.ConfigParser()
    if secrets:
        config.read(secrets)

    if install_ext:
        t = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
        extensions = download_extensions(extensions, t)
        install_extensions(extensions, user_flag=False)

    if import_ext:
        import_modules(extensions.keys())

    # Check that all tags are valid
    if check_tags:
        self.check_tags(content)

    # Compile the runnable now that the extensions were imported.
    runnable = self.compile_runnable(content)

    if secrets:
        runnable.inject_secrets(secrets)
    if extensions:
        runnable.inject_extensions(extensions)

    runnable.parse()
    return runnable, extensions
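# Hedged usage sketch, not part of the original source: shows how
# preprocess is typically reached through a SafeExecutionContext used
# as a context manager, so that __exit__ runs if preprocessing raises.
# The import path and the 'config.yaml' / 'secrets.ini' file names are
# assumptions for illustration only.
def _example_preprocess_usage():
    from flambe.runnable.context import SafeExecutionContext  # assumed module path

    with SafeExecutionContext('config.yaml') as ex:
        # install_ext=True also downloads the remote extensions before
        # installing them; import_ext (True by default) registers their
        # YAML tags so the config can be compiled.
        runnable, extensions = ex.preprocess(secrets='secrets.ini',
                                             install_ext=True)
        runnable.run()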
def test_exporter_builder():
    with tmpdir() as d, tmpdir() as d2, tmpfile(
            mode="w", suffix=".yaml") as f, tmpfile(mode="w", suffix=".yaml") as f2:

        # First run an experiment
        exp = """
!Experiment

name: exporter
save_path: {}

pipeline:
  dataset: !SSTDataset
    transform:
      text: !TextField
      label: !LabelField
  model: !TextClassifier
    embedder: !Embedder
      embedding: !torch.Embedding
        num_embeddings: !@ dataset.text.vocab_size
        embedding_dim: 30
      encoder: !PooledRNNEncoder
        input_size: 30
        rnn_type: lstm
        n_layers: 1
        hidden_size: 16
    output_layer: !SoftmaxLayer
      input_size: !@ model[embedder].encoder.rnn.hidden_size
      output_size: !@ dataset.label.vocab_size
  exporter: !Exporter
    model: !@ model
    text: !@ dataset.text
"""
        exp = exp.format(d)
        f.write(exp)
        f.flush()

        ret = subprocess.run(['flambe', f.name, '-i'])
        assert ret.returncode == 0

        # Then run a builder
        builder = """
flambe_inference: tests/data/dummy_extensions/inference/
---
!Builder

destination: {0}

component: !flambe_inference.DummyInferenceEngine
  model: !TextClassifier.load_from_path
    path: {1}
"""
        base = os.path.join(d, "output__exporter", "exporter")
        path_aux = [
            x for x in os.listdir(base) if os.path.isdir(os.path.join(base, x))
        ][0]  # Should be only 1 folder bc of no variants
        model_path = os.path.join(base, path_aux, "checkpoint",
                                  "checkpoint.flambe", "model")

        builder = builder.format(d2, model_path)
        f2.write(builder)
        f2.flush()

        ret = subprocess.run(['flambe', f2.name, '-i'])
        assert ret.returncode == 0

        # The extensions need to be imported using the extensions.py module
        extensions.import_modules(["flambe_inference"])

        # Import the module after import_modules (which registered tags already)
        from flambe_inference import DummyInferenceEngine

        eng1 = flambe.load(d2)
        assert type(eng1) is DummyInferenceEngine
        assert type(eng1.model) is TextClassifier

        extension_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "tests/data/dummy_extensions/inference")
        assert eng1._extensions == {"flambe_inference": extension_path}

        eng2 = DummyInferenceEngine.load_from_path(d2)
        assert type(eng2) is DummyInferenceEngine
        assert type(eng2.model) is TextClassifier
        assert eng2._extensions == {"flambe_inference": extension_path}

        assert module_equals(eng1.model, eng2.model)
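# Hedged sketch, not taken from the test utilities: one plausible way a
# helper like the `module_equals` used above can compare two torch
# modules, by checking that every named parameter tensor matches. The
# real helper may differ; this is illustration only.
def _example_module_equals(m1, m2) -> bool:
    import torch

    params1 = dict(m1.named_parameters())
    params2 = dict(m2.named_parameters())
    if params1.keys() != params2.keys():
        return False
    # All corresponding parameter tensors must be element-wise equal
    return all(torch.equal(params1[name], params2[name]) for name in params1)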
def _setup(self, config: Dict):
    """Subclasses should override this for custom initialization."""
    # Set this flag to False if we find an error, or reduce
    self.name = config['name']
    self.run_flag = True

    custom_modules = config['custom_modules']
    setup_default_modules()
    import_modules(custom_modules)

    # Get the current computation block
    target_block_id = config['to_run']
    self.block_id = target_block_id

    # Update the schemas with the configuration
    schemas: Dict[str, Schema] = Schema.deserialize(config['schemas'])
    schemas_copy = deepcopy(schemas)
    global_vars = config['global_vars']

    self.verbose = config['verbose']
    self.hyper_params = config['hyper_params']
    self.debug = config['debug']

    with TrialLogging(log_dir=self.logdir,
                      verbose=self.verbose,
                      console_prefix=self.block_id,
                      hyper_params=self.hyper_params,
                      capture_warnings=True):

        # Compile, activate links, and load checkpoints
        filled_schemas: Dict = OrderedDict()

        for block_id, schema_block in schemas.items():
            block_params = config['params'][block_id]
            utils.update_schema_with_params(schemas_copy[block_id], block_params)

            # First activate links from previous blocks in the pipeline
            utils.update_link_refs(schemas_copy, block_id, global_vars)
            block: Component = schemas_copy[block_id]()
            filled_schemas[block_id] = schemas_copy[block_id]

            if block_id in config['checkpoints']:
                # Get the block hash
                needed_set = utils.extract_needed_blocks(schemas, block_id, global_vars)
                needed_blocks = ((k, v) for k, v in filled_schemas.items() if k in needed_set)
                block_hash = repr(OrderedDict(needed_blocks))

                # Check the mask; if it's False then we end immediately
                mask_value = config['checkpoints'][block_id]['mask'][block_hash]
                if mask_value is False:
                    self.run_flag = False
                    return

                # There should be a checkpoint
                checkpoint = config['checkpoints'][block_id]['paths'][block_hash]
                state = load_state_from_file(checkpoint)
                block.load_state(state)

            # Holding compiled objects alongside schemas is okay but not
            # fully expressed in our type annotations.
            # TODO: fix this in our utils type annotations
            schemas_copy[block_id] = block  # type: ignore

        # If everything went well, just compile
        self.block = schemas_copy[target_block_id]

        # Add tb prefix to computables in case multiple plots are requested
        if not config['merge_plot']:
            self.block.tb_log_prefix = self.name
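# Hedged illustration, not part of the adapter: the checkpoint lookup in
# _setup keys both the 'mask' and 'paths' dicts by the repr of an
# OrderedDict of the already-filled schemas a block depends on, so the
# same pipeline prefix always maps to the same saved state. The block
# names and the checkpoint path below are made up.
def _example_block_hash():
    from collections import OrderedDict

    filled = OrderedDict([('dataset', '<Schema for !SSTDataset>'),
                          ('model', '<Schema for !TextClassifier>')])
    needed = {'dataset', 'model'}
    block_hash = repr(OrderedDict((k, v) for k, v in filled.items() if k in needed))

    checkpoints = {'model': {'mask': {block_hash: True},
                             'paths': {block_hash: '/tmp/checkpoints/model.pt'}}}
    # The same key retrieves both the mask decision and the checkpoint path
    assert checkpoints['model']['mask'][block_hash] is True
    return checkpoints['model']['paths'][block_hash]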
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file
    or directory.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : type
        Location (device) where items will be moved. ONLY used when
        the directory save format is used. See torch.load
        documentation for more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : type
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : type
        Additional args that `pickle_module` should use to load; see
        torch.load documentation for more details

    Returns
    -------
    Component
        Object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If a Component object is not loadable from the given path
        because extensions are not installed, or the config is empty,
        nonexistent, or otherwise invalid.

    """
    state = load_state_from_file(path, map_location, pickle_module, **pickle_load_args)
    yaml_config = state._metadata[''][FLAMBE_CONFIG_KEY]
    stash = state._metadata[''][FLAMBE_STASH_KEY] \
        if FLAMBE_STASH_KEY in state._metadata[''] else None
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    if yamls is None:
        raise LoadError(
            "Cannot load schema from empty config. This object may not have been saved"
            " for any of the following reasons:\n - The object was not created from a"
            " config or with compile method\n - The object originally linked to other"
            " objects that cannot be represented in YAML")
    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")
    if len(yamls) == 2:
        if yamls[0] is not None:
            extensions = dict(yamls[0])
            custom_modules = extensions.keys()
            for x in custom_modules:
                if not is_installed_module(x):
                    if auto_install:
                        logger.warn(f"auto_install==True, installing missing Module "
                                    f"{x}: {extensions[x]}")
                        install_extensions({x: extensions[x]})
                        logger.debug(f"Installed module {x} from {extensions[x]}")
                    else:
                        raise ImportError(
                            f"Module {x} is required and not installed. Please 'pip install'"
                            " the package containing the module or set the auto_install flag"
                            " to True.")
                import_modules([x])
                logger.debug(f"Automatically imported {x}")

            # Reload with the extensions' modules imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]

            # Set the extensions on the schema so that they are
            # passed along when compiling the component.
            schema.add_extensions_metadata(extensions)
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")

    if schema is None:
        raise LoadError(
            "Cannot load schema from empty config. This object may not have been saved"
            " for any of the following reasons:\n - The object was not created from a"
            " config or with compile method\n - The object originally linked to other"
            " objects that cannot be represented in YAML")

    _update_link_refs(schema)

    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found
    try:
        instance = schema(stash)
    except TypeError:
        raise LoadError(
            f"Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            f"extensions section at the top and install as necessary. Alternatively "
            f"set auto_install=True")
    instance.load_state(state)
    return instance
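# Hedged usage sketch, not part of the module: the common entry points
# for restoring a saved component. The save directory path below is a
# placeholder, and the sketch assumes it lives in the same module as
# load_state_from_file.
def _example_load_usage():
    import flambe

    # Rebuilds the object from its saved config and then loads its state.
    # auto_install=True pip-installs any extensions listed in the config
    # that are not already importable.
    model = flambe.load('path/to/saved/output', auto_install=True)

    # To inspect the raw state without recompiling the object, the lower
    # level helper can be used directly.
    state = load_state_from_file('path/to/saved/output')
    return model, state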