Example #1
    def preprocess(self,
                   secrets: Optional[str] = None,
                   download_ext: bool = True,
                   install_ext: bool = False,
                   import_ext: bool = True,
                   check_tags: bool = True,
                   **kwargs) -> Tuple[Runnable, Dict[str, str]]:
        """Preprocess the runnable file.

        Looks for syntax errors, import errors, etc. Also injects
        the secrets into the runnables.

        If this method runs and ends without exceptions, then the
        experiment is safe to run. If this method raises an error and
        the SafeExecutionContext is used as a context manager,
        then the __exit__ method will be executed.

        Parameters
        ----------
        secrets: Optional[str]
            Optional path to the secrets file.
        download_ext: bool
            Whether to download the remote extensions.
            Defaults to True.
        install_ext: bool
            Whether to install the downloaded extensions. Only takes
            effect when download_ext is True. Defaults to False.
        import_ext: bool
            Whether to import the extensions or not.
            Defaults to True.
        check_tags: bool
            Whether to check that all tags are valid. Defaults to True.

        Returns
        -------
        Tuple[Runnable, Dict[str, str]]
            A tuple containing the compiled Runnable and a dict
            containing the extensions the Runnable uses.

        Raises
        ------
        Exception
            If preprocessing fails at any step: syntax or import
            errors found during parsing, extension download or
            installation failures, invalid tags, or compilation errors.

        """
        content, extensions = self.first_parse()

        config = configparser.ConfigParser()
        if secrets:
            config.read(secrets)

        if download_ext:
            t = os.path.join(FLAMBE_GLOBAL_FOLDER, "extensions")
            extensions = download_extensions(extensions, t)
            if install_ext:
                install_extensions(extensions, user_flag=False)

        if import_ext:
            import_modules(extensions.keys())

        # Check that all tags are valid
        if check_tags:
            self.check_tags(content)

        # Compile the runnable now that the extensions have been imported.
        runnable = self.compile_runnable(content)

        if secrets:
            runnable.inject_secrets(secrets)

        if extensions:
            runnable.inject_extensions(extensions)

        runnable.parse()

        return runnable, extensions
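A minimal usage sketch of this method, assuming a SafeExecutionContext constructed from the path to a runnable YAML file; the constructor signature and file names here are illustrative, not the library's confirmed API:

# Hypothetical usage; 'experiment.yaml' and 'secrets.ini' are illustrative.
with SafeExecutionContext('experiment.yaml') as ex:
    # preprocess raises on syntax/import errors; if it raises inside
    # the context manager, __exit__ runs as described above.
    runnable, extensions = ex.preprocess(secrets='secrets.ini',
                                         install_ext=True)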
Example #2
def test_exporter_builder():
    with tmpdir() as d, tmpdir() as d2, tmpfile(
            mode="w", suffix=".yaml") as f, tmpfile(mode="w",
                                                    suffix=".yaml") as f2:
        # First run an experiment
        exp = """
!Experiment

name: exporter
save_path: {}

pipeline:
  dataset: !SSTDataset
    transform:
      text: !TextField
      label: !LabelField
  model: !TextClassifier
    embedder: !Embedder
      embedding: !torch.Embedding
        num_embeddings: !@ dataset.text.vocab_size
        embedding_dim: 30
      encoder: !PooledRNNEncoder
        input_size: 30
        rnn_type: lstm
        n_layers: 1
        hidden_size: 16
    output_layer: !SoftmaxLayer
      input_size: !@ model[embedder].encoder.rnn.hidden_size
      output_size: !@ dataset.label.vocab_size

  exporter: !Exporter
    model: !@ model
    text: !@ dataset.text
"""

        exp = exp.format(d)
        f.write(exp)
        f.flush()
        ret = subprocess.run(['flambe', f.name, '-i'])
        assert ret.returncode == 0

        # Then run a builder

        builder = """
flambe_inference: tests/data/dummy_extensions/inference/
---

!Builder

destination: {0}

component: !flambe_inference.DummyInferenceEngine
  model: !TextClassifier.load_from_path
    path: {1}
"""
        base = os.path.join(d, "output__exporter", "exporter")
        path_aux = [
            x for x in os.listdir(base) if os.path.isdir(os.path.join(base, x))
        ][0]  # Should be the only folder, since there are no variants
        model_path = os.path.join(base, path_aux, "checkpoint",
                                  "checkpoint.flambe", "model")

        builder = builder.format(d2, model_path)
        f2.write(builder)
        f2.flush()

        ret = subprocess.run(['flambe', f2.name, '-i'])
        assert ret.returncode == 0

        # The extensions need to be imported using the extensions.py module
        extensions.import_modules(["flambe_inference"])

        # Import the module after import_modules (which has already registered the tags)
        from flambe_inference import DummyInferenceEngine

        eng1 = flambe.load(d2)

        assert type(eng1) is DummyInferenceEngine
        assert type(eng1.model) is TextClassifier

        extension_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "tests/data/dummy_extensions/inference")
        assert eng1._extensions == {"flambe_inference": extension_path}

        eng2 = DummyInferenceEngine.load_from_path(d2)

        assert type(eng2) is DummyInferenceEngine
        assert type(eng2.model) is TextClassifier

        assert eng2._extensions == {"flambe_inference": extension_path}

        assert module_equals(eng1.model, eng2.model)
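The module_equals helper asserted above is assumed to compare the two models parameter by parameter; a minimal sketch of such a check (not necessarily the test suite's actual implementation):

import torch

def module_equals(m1: torch.nn.Module, m2: torch.nn.Module) -> bool:
    # Hypothetical sketch: modules match if their state dicts have the
    # same keys and identical tensors.
    s1, s2 = m1.state_dict(), m2.state_dict()
    if s1.keys() != s2.keys():
        return False
    return all(torch.equal(s1[k], s2[k]) for k in s1)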
Example #3
    def _setup(self, config: Dict):
        """Subclasses should override this for custom initialization."""
        self.name = config['name']
        # run_flag is set to False if an error is found or a checkpoint
        # mask indicates this block should not run
        self.run_flag = True
        custom_modules = config['custom_modules']
        setup_default_modules()
        import_modules(custom_modules)
        # Get the current computation block
        target_block_id = config['to_run']
        self.block_id = target_block_id
        # Update the schemas with the configuration
        schemas: Dict[str, Schema] = Schema.deserialize(config['schemas'])
        schemas_copy = deepcopy(schemas)
        global_vars = config['global_vars']
        self.verbose = config['verbose']
        self.hyper_params = config['hyper_params']
        self.debug = config['debug']

        with TrialLogging(log_dir=self.logdir,
                          verbose=self.verbose,
                          console_prefix=self.block_id,
                          hyper_params=self.hyper_params,
                          capture_warnings=True):

            # Compile, activate links, and load checkpoints
            filled_schemas: Dict = OrderedDict()
            for block_id, schema_block in schemas.items():

                block_params = config['params'][block_id]

                utils.update_schema_with_params(schemas_copy[block_id],
                                                block_params)

                # First activate links from previous blocks in the
                # pipeline
                utils.update_link_refs(schemas_copy, block_id, global_vars)
                block: Component = schemas_copy[block_id]()
                filled_schemas[block_id] = schemas_copy[block_id]

                if block_id in config['checkpoints']:
                    # Get the block hash
                    needed_set = utils.extract_needed_blocks(
                        schemas, block_id, global_vars)
                    needed_blocks = ((k, v) for k, v in filled_schemas.items()
                                     if k in needed_set)
                    block_hash = repr(OrderedDict(needed_blocks))

                    # Check the mask, if it's False then we end
                    # immediately
                    mask_value = config['checkpoints'][block_id]['mask'][
                        block_hash]
                    if mask_value is False:
                        self.run_flag = False
                        return

                    # There should be a checkpoint
                    checkpoint = config['checkpoints'][block_id]['paths'][
                        block_hash]
                    state = load_state_from_file(checkpoint)
                    block.load_state(state)

                # Holding compiled objects alongside schemas is okay
                # but not fully expressed in our type annotations.
                # TODO: fix this in our utils type annotations
                schemas_copy[block_id] = block  # type: ignore

        # If everything went well, just compile
        self.block = schemas_copy[target_block_id]

        # Add tb prefix to computables in case multiple plots are
        # requested
        if not config['merge_plot']:
            self.block.tb_log_prefix = self.name
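For reference, a sketch of the config dict this _setup expects, inferred from the keys accessed above; every value below is illustrative:

# Hypothetical config; keys mirror those read in _setup, values are made up.
config = {
    'name': 'trial_0',
    'custom_modules': ['my_extension'],
    'to_run': 'model',                  # id of the block to compute
    'schemas': serialized_schemas,      # assumed output of Schema serialization
    'global_vars': {},
    'verbose': False,
    'hyper_params': {'model.lr': 0.001},
    'debug': False,
    'params': {'model': {'lr': 0.001}},
    'checkpoints': {},                  # per-block 'mask' and 'paths' mappings
    'merge_plot': True,
}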
Example #4
def load(path: str,
         map_location=None,
         auto_install=False,
         pickle_module=dill,
         **pickle_load_args):
    """Load object with state from the given path

    Loads a flambe object by using the saved config files, and then
    loads the saved state into said object. See `load_state_from_file`
    for details regarding how the state is loaded from the save file or
    directory.

    Parameters
    ----------
    path : str
        Path to the save file or directory
    map_location : str or torch.device, optional
        Location (device) where items will be moved. ONLY used when the
        directory save format is used. See the torch.load documentation
        for more details (the default is None).
    auto_install : bool
        If True, automatically installs extensions as needed.
    pickle_module : module
        Pickle module that has load and dump methods; dump should
        accept a pickle_protocol parameter (the default is dill).
    **pickle_load_args : Any
        Additional args that `pickle_module` should use to load; see
        the torch.load documentation for more details.

    Returns
    -------
    Component
        object with both the architecture (config) and state that was
        saved to path

    Raises
    ------
    LoadError
        If a Component object is not loadable from the given path
        because extensions are not installed, or the config is empty,
        nonexistent, or otherwise invalid.

    """
    state = load_state_from_file(path, map_location, pickle_module,
                                 **pickle_load_args)
    yaml_config = state._metadata[''][FLAMBE_CONFIG_KEY]
    stash = state._metadata[''][FLAMBE_STASH_KEY] \
        if FLAMBE_STASH_KEY in state._metadata[''] else None
    setup_default_modules()
    yamls = list(yaml.load_all(yaml_config))

    if not yamls:
        raise LoadError(
            "Cannot load schema from empty config. This object may not have been saved"
            " for any of the following reasons:\n - The object was not created from a"
            " config or with the compile method\n - The object originally linked to"
            " other objects that cannot be represented in YAML")
    if len(yamls) > 2:
        raise LoadError(
            f"{os.path.join(path, CONFIG_FILE_NAME)} should contain an (optional) "
            "extensions section and the main object.")
    if len(yamls) == 2:
        if yamls[0] is not None:
            extensions = dict(yamls[0])
            custom_modules = extensions.keys()
            for x in custom_modules:
                if not is_installed_module(x):
                    if auto_install:
                        logger.warning(
                            f"auto_install==True, installing missing module "
                            f"{x}: {extensions[x]}")
                        install_extensions({x: extensions[x]})
                        logger.debug(
                            f"Installed module {x} from {extensions[x]}")
                    else:
                        raise ImportError(
                            f"Module {x} is required and not installed. Please 'pip install'"
                            "the package containing the module or set auto_install flag"
                            " to True.")
                import_modules([x])
                logger.debug(f"Automatically imported {x}")

            # Reload with extensions' module imported (and registered)
            schema = list(yaml.load_all(yaml_config))[1]

            # Set the extensions to the schema so that they are
            # passed when compiling the component.
            schema.add_extensions_metadata(extensions)
        else:
            schema = yamls[1]
    elif len(yamls) == 1:
        schema = yamls[0]
    else:
        raise LoadError(
            "No config found at location; cannot load. Try just loading state with "
            "the function 'load_state_from_file'")

    if schema is None:
        raise LoadError(
            "Cannot load schema from empty config. This object may not have been saved"
            " for any of the following reasons:\n - The object was not created from a"
            "config or with compile method\n - The object originally linked to other"
            "objects that cannot be represented in YAML")

    _update_link_refs(schema)
    # TODO: maybe replace with instance check if solution to circular
    # dependency with component is found

    try:
        instance = schema(stash)
    except TypeError:
        raise LoadError(
            f"Loaded object is not callable - likely because an extension is not "
            f"installed. Check if {os.path.join(path, CONFIG_FILE_NAME)} has an "
            f"extensions section at the top and install as necessary. Alternatively "
            f"set auto_install=True")
    instance.load_state(state)
    return instance
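A minimal usage sketch (the save path is illustrative; 'cpu' is a standard torch.load device string):

import flambe

# Load a saved component onto CPU, installing any missing extensions.
obj = flambe.load('path/to/save_directory',
                  map_location='cpu',
                  auto_install=True)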