Example #1
0
    def parse_arguments(self, arg_list):
        """A version of speechbrain.parse_arguments enhanced for hyperparameter optimization.

        If a parameter named 'hpopt' is provided, hyperparameter
        optimization and reporting will be enabled.

        If the parameter value corresponds to a filename, it will
        be read as a hyperpyyaml file, and the contents will be added
        to "overrides". This is useful for cases where the values of
        certain hyperparameters are different during hyperparameter
        optimization vs during full training (e.g. number of epochs, saving
        files, etc). Values given on the command line take precedence over
        values read from that file, and the hpopt control keys are removed
        from the returned overrides in that case.

        Arguments
        ---------
        arg_list : list
            A list of command-line arguments (e.g. ``sys.argv[1:]``).

        Returns
        -------
        param_file : str
            The location of the parameters file.
        run_opts : dict
            Run options, such as distributed, device, etc.
        overrides : dict
            The overrides to pass to ``load_hyperpyyaml``. This is an empty
            dict when no overrides were supplied.

        Example
        -------
        >>> ctx = HyperparameterOptimizationContext()
        >>> arg_list = ["hparams.yaml", "--x", "1", "--y", "2"]
        >>> hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list)
        >>> print(f"File: {hparams_file}, Overrides: {overrides}")
        File: hparams.yaml, Overrides: {'x': 1, 'y': 2}
        """
        hparams_file, run_opts, overrides_yaml = sb.parse_arguments(arg_list)
        # With no command-line overrides the YAML text is empty, and loading
        # an empty document does not yield a mapping. Normalize to {} so the
        # .get() calls below (and downstream consumers) always see a dict.
        overrides = {}
        if overrides_yaml:
            overrides = load_hyperpyyaml(overrides_yaml) or {}
        hpopt = overrides.get(KEY_HPOPT, False)
        hpopt_mode = overrides.get(KEY_HPOPT_MODE) or DEFAULT_REPORTER
        if hpopt:
            self.enabled = True
            self.reporter = get_reporter(
                hpopt_mode, *self.reporter_args, **self.reporter_kwargs
            )
            # A string value naming an existing file is treated as an extra
            # hyperpyyaml file of hpopt-specific overrides.
            if isinstance(hpopt, str) and os.path.exists(hpopt):
                with open(hpopt) as hpopt_file:
                    trial_id = get_trial_id()
                    hpopt_overrides = load_hyperpyyaml(
                        hpopt_file,
                        overrides={"trial_id": trial_id},
                        overrides_must_match=False,
                    )
                # Command-line overrides win over values from the hpopt file;
                # an empty hpopt file contributes nothing.
                overrides = dict(hpopt_overrides or {}, **overrides)
                # The hpopt control keys are consumed here and are not
                # forwarded as hyperparameter overrides.
                for key in (KEY_HPOPT, KEY_HPOPT_MODE):
                    overrides.pop(key, None)
        return hparams_file, run_opts, overrides
Example #2
0
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets,
        ["id", "sig", "semantics", "tokens_bos", "tokens_eos", "tokens"],
    )
    return train_data, valid_data, test_data, tokenizer


# Script entry point: parse the command line, load the hyperparameters,
# initialise distributed training if requested, and set up the experiment
# directory.
if __name__ == "__main__":

    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file) as fin:
        # The overrides are applied on top of the values in the YAML file.
        hparams = load_hyperpyyaml(fin, overrides)

    # NOTE(review): not used within the visible lines; presumably read by
    # code later in this block — verify.
    show_results_every = 100  # plots results every N iterations

    # If distributed_launch=True then
    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create experiment directory
    # (rooted at hparams["output_folder"]; the hparams file path and the raw
    # overrides are passed along, presumably so the run's configuration is
    # saved with the experiment — confirm against speechbrain docs)
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )