def parse_arguments(self, arg_list):
    """A version of speechbrain.parse_arguments enhanced for hyperparameter
    optimization.

    If a parameter named 'hpopt' is provided, hyperparameter
    optimization and reporting will be enabled.

    If the parameter value corresponds to a filename, it will
    be read as a hyperpyyaml file, and the contents will be added
    to "overrides". This is useful for cases where the values of
    certain hyperparameters are different during hyperparameter
    optimization vs during full training (e.g. number of epochs,
    saving files, etc). Values given on the command line take
    precedence over values read from that file.

    When 'hpopt' is truthy, ``self.enabled`` is set to True and
    ``self.reporter`` is created via ``get_reporter``.

    Arguments
    ---------
    arg_list : list
        a list of arguments

    Returns
    -------
    param_file : str
        The location of the parameters file.
    run_opts : dict
        Run options, such as distributed, device, etc.
    overrides : dict
        The overrides to pass to ``load_hyperpyyaml``.

    Example
    -------
    >>> ctx = HyperparameterOptimizationContext()
    >>> arg_list = ["hparams.yaml", "--x", "1", "--y", "2"]
    >>> hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list)
    >>> print(f"File: {hparams_file}, Overrides: {overrides}")
    File: hparams.yaml, Overrides: {'x': 1, 'y': 2}
    """
    hparams_file, run_opts, overrides_yaml = sb.parse_arguments(arg_list)

    # With no command-line overrides the YAML text is empty; loading it would
    # not produce a mapping, so fall back to an empty dict instead of failing
    # on the ``.get`` calls below.
    overrides = load_hyperpyyaml(overrides_yaml) if overrides_yaml else {}

    hpopt = overrides.get(KEY_HPOPT, False)
    hpopt_mode = overrides.get(KEY_HPOPT_MODE) or DEFAULT_REPORTER

    if hpopt:
        self.enabled = True
        self.reporter = get_reporter(
            hpopt_mode, *self.reporter_args, **self.reporter_kwargs
        )
        # 'hpopt' may be a plain flag or a path to an extra hyperpyyaml file
        # holding optimization-specific values.
        if isinstance(hpopt, str) and os.path.exists(hpopt):
            with open(hpopt) as hpopt_file:
                trial_id = get_trial_id()
                hpopt_overrides = load_hyperpyyaml(
                    hpopt_file,
                    overrides={"trial_id": trial_id},
                    overrides_must_match=False,
                )
            # Command-line overrides win over values from the hpopt file.
            overrides = dict(hpopt_overrides, **overrides)
            # NOTE(review): the hpopt control keys are only stripped in this
            # branch (when an hpopt file was merged) - confirm that keeping
            # them for a bare ``--hpopt True`` is intended.
            for key in (KEY_HPOPT, KEY_HPOPT_MODE):
                overrides.pop(key, None)

    return hparams_file, run_opts, overrides
yield tokens sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) # 4. Set output: sb.dataio.dataset.set_output_keys( datasets, ["id", "sig", "semantics", "tokens_bos", "tokens_eos", "tokens"], ) return train_data, valid_data, test_data, tokenizer if __name__ == "__main__": # Load hyperparameters file with command-line overrides hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) with open(hparams_file) as fin: hparams = load_hyperpyyaml(fin, overrides) show_results_every = 100 # plots results every N iterations # If distributed_launch=True then # create ddp_group with the right communication protocol sb.utils.distributed.ddp_init_group(run_opts) # Create experiment directory sb.create_experiment_directory( experiment_directory=hparams["output_folder"], hyperparams_to_save=hparams_file, overrides=overrides, )