Example #1
    def parse_arguments(self, arg_list):
        """A version of speechbrain.parse_arguments enhanced for hyperparameter optimization.

        If a parameter named 'hpopt' is provided, hyperparameter
        optimization and reporting will be enabled.

        If the parameter value corresponds to a filename, it will
        be read as a HyperPyYAML file, and its contents will be added
        to "overrides". This is useful when the values of certain
        hyperparameters differ between hyperparameter optimization and
        full training (e.g. the number of epochs, file saving, etc.).

        Arguments
        ---------
        arg_list : list
            A list of arguments to parse.

        Returns
        -------
        param_file : str
            The location of the parameters file.
        run_opts : dict
            Run options, such as distributed, device, etc.
        overrides : dict
            The overrides to pass to ``load_hyperpyyaml``.

        Example
        -------
        >>> ctx = HyperparameterOptimizationContext()
        >>> arg_list = ["hparams.yaml", "--x", "1", "--y", "2"]
        >>> hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list)
        >>> print(f"File: {hparams_file}, Overrides: {overrides}")
        File: hparams.yaml, Overrides: {'x': 1, 'y': 2}
        """
        hparams_file, run_opts, overrides_yaml = sb.parse_arguments(arg_list)
        overrides = load_hyperpyyaml(overrides_yaml)
        hpopt = overrides.get(KEY_HPOPT, False)
        hpopt_mode = overrides.get(KEY_HPOPT_MODE) or DEFAULT_REPORTER
        if hpopt:
            self.enabled = True
            self.reporter = get_reporter(hpopt_mode, *self.reporter_args,
                                         **self.reporter_kwargs)
            if isinstance(hpopt, str) and os.path.exists(hpopt):
                with open(hpopt) as hpopt_file:
                    trial_id = get_trial_id()
                    hpopt_overrides = load_hyperpyyaml(
                        hpopt_file,
                        overrides={"trial_id": trial_id},
                        overrides_must_match=False,
                    )
                    overrides = dict(hpopt_overrides, **overrides)
                    for key in [KEY_HPOPT, KEY_HPOPT_MODE]:
                        if key in overrides:
                            del overrides[key]
        return hparams_file, run_opts, overrides
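A hedged usage sketch for the method above; the file name "hpopt.yaml" and the keys inside it are illustrative assumptions rather than part of the API:

# Hypothetical contents of hpopt.yaml (plain HyperPyYAML), shrinking the run
# for each optimization trial:
#
#     number_of_epochs: 1
#     ckpt_enable: False

ctx = HyperparameterOptimizationContext()
arg_list = ["hparams.yaml", "--hpopt", "hpopt.yaml", "--x", "1"]
hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list)
# If hpopt.yaml exists on disk, its contents are merged into `overrides`
# (command-line values win), and the hpopt / hpopt_mode keys are stripped out.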
Example #2
def main(device="cpu"):
    experiment_dir = pathlib.Path(__file__).resolve().parent
    hparams_file = experiment_dir / "hyperparams.yaml"
    data_folder = "../../../../samples/audio_samples/nn_training_samples"
    data_folder = (experiment_dir / data_folder).resolve()

    # Load model hyperparameters:
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin)

    # Dataset creation
    train_data, valid_data = data_prep(data_folder, hparams)

    # Trainer initialization
    seq2seq_brain = seq2seqBrain(
        hparams["modules"],
        hparams["opt_class"],
        hparams,
        run_opts={"device": device},
    )

    # Training/validation loop
    seq2seq_brain.fit(
        range(hparams["N_epochs"]),
        train_data,
        valid_data,
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )
    # Evaluation is run separately (now just evaluating on valid data)
    seq2seq_brain.evaluate(valid_data)

    # Check that model overfits for integration test
    assert seq2seq_brain.train_loss < 1.0
def main():
    overrides = {
        "output_folder": output_folder,
        "data_folder": os.path.join(experiment_dir, "..", "..", "..",
                                    "samples"),
    }
    with open(hyperparams_file) as fin:
        hyperparams = load_hyperpyyaml(fin, overrides)

    sb.create_experiment_directory(
        experiment_directory=output_folder,
        hyperparams_to_save=hyperparams_file,
        overrides=overrides,
    )

    dataloader = sb.dataio.dataloader.make_dataloader(
        dataset=hyperparams["sample_data"],
        batch_size=hyperparams["batch_size"])
    for ids, (wav, wav_len) in dataloader:
        wav_drop = hyperparams["drop_freq"](wav)
        # save results on file
        for i, snt_id in enumerate(ids):
            filepath = os.path.join(
                hyperparams["output_folder"], "save", snt_id + ".flac")
            write_audio(filepath, wav_drop[i], 16000)
def main():
    pytest.importorskip("numba")
    experiment_dir = pathlib.Path(__file__).resolve().parent
    hparams_file = experiment_dir / "hyperparams.yaml"
    data_folder = "../../../../samples/audio_samples/nn_training_samples"
    data_folder = (experiment_dir / data_folder).resolve()

    # Load model hyperparameters:
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin)

    # Dataset creation
    train_data, valid_data, label_encoder = data_prep(data_folder, hparams)

    # Trainer initialization
    transducer_brain = TransducerBrain(
        hparams["modules"], hparams["opt_class"], hparams
    )

    # Training/validation loop
    transducer_brain.fit(
        range(hparams["N_epochs"]),
        train_data,
        valid_data,
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )
    # Evaluation is run separately (now just evaluating on valid data)
    transducer_brain.evaluate(valid_data)

    # Check that model overfits for integration test
    assert transducer_brain.train_loss < 1.0
Example #5
def main():
    experiment_dir = pathlib.Path(__file__).resolve().parent
    hparams_file = experiment_dir / "hyperparams.yaml"
    data_folder = "../../../../samples/audio_samples/nn_training_samples"
    data_folder = (experiment_dir / data_folder).resolve()

    # Load model hyperparameters:
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin)

    # Dataset creation
    train_data, valid_data = data_prep(data_folder)

    # Trainer initialization
    gan_brain = EnhanceGanBrain(modules=hparams["modules"], hparams=hparams)

    # Training/validation loop
    gan_brain.fit(
        range(hparams["N_epochs"]),
        train_data,
        valid_data,
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )
    # Evaluation is run separately (now just evaluating on valid data)
    gan_brain.evaluate(valid_data)

    # Check test loss (mse), train loss is GAN loss
    assert gan_brain.test_loss < 0.002
Example #6
def main(device="cpu"):

    experiment_dir = os.path.dirname(os.path.abspath(__file__))
    hparams_file = os.path.join(experiment_dir, "hyperparams.yaml")
    data_folder = "../../../../../samples/audio_samples/vad"
    data_folder = os.path.abspath(os.path.join(experiment_dir, data_folder))
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin)

    # Data IO creation
    train_data, valid_data = data_prep(data_folder, hparams)

    # Trainer initialization
    ctc_brain = VADBrain(
        hparams["modules"],
        hparams["opt_class"],
        hparams,
        run_opts={"device": device},
    )

    # Training/validation loop
    ctc_brain.fit(
        range(hparams["N_epochs"]),
        train_data,
        valid_data,
        train_loader_kwargs=hparams["dataloader_options"],
        valid_loader_kwargs=hparams["dataloader_options"],
    )
    # Evaluation is run separately (now just evaluating on valid data)
    ctc_brain.evaluate(valid_data)

    # Check if model overfits for integration test
    assert ctc_brain.train_loss < 1.0
Example #7
    def from_hparams(
        cls,
        source,
        hparams_file="hyperparams.yaml",
        overrides={},
        savedir=None,
        use_auth_token=False,
        **kwargs,
    ):
        """Fetch and load based from outside source based on HyperPyYAML file

        The source can be a location on the filesystem or online/huggingface

        The hyperparams file should contain a "modules" key, which is a
        dictionary of torch modules used for computation.

        The hyperparams file should contain a "pretrainer" key, which is a
        speechbrain.utils.parameter_transfer.Pretrainer

        Arguments
        ---------
        source : str
            The location to use for finding the model. See
            ``speechbrain.pretrained.fetching.fetch`` for details.
        hparams_file : str
            The name of the hyperparameters file to use for constructing
            the modules necessary for inference. Must contain two keys:
            "modules" and "pretrainer", as described.
        overrides : dict
            Any changes to make to the hparams file when it is loaded.
        savedir : str or Path
            Where to put the pretraining material. If not given, will use
            ./pretrained_models/<class-name>-hash(source).
        use_auth_token : bool (default: False)
            If True, Hugging Face's auth_token will be used to load private
            models from the HuggingFace Hub; the default is False because the
            majority of models are public.
        """
        if savedir is None:
            clsname = cls.__name__
            savedir = f"./pretrained_models/{clsname}-{hash(source)}"
        hparams_local_path = fetch(hparams_file, source, savedir,
                                   use_auth_token)

        # Load the modules:
        with open(hparams_local_path) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Pretraining:
        pretrainer = hparams["pretrainer"]
        pretrainer.set_collect_in(savedir)
        # For distributed setups, have this here:
        run_on_main(pretrainer.collect_files,
                    kwargs={"default_source": source})
        # Load on the CPU. Later the params can be moved elsewhere by specifying
        # run_opts={"device": ...}
        pretrainer.load_collected(device="cpu")

        # Now return the system
        return cls(hparams["modules"], hparams, **kwargs)
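A hedged sketch of the kind of YAML file and call the docstring above describes; the module names, checkpoint layout, and the EncoderDecoderASR example below are illustrative assumptions, not taken from this code:

# Hypothetical hyperparams.yaml containing the two required keys:
#
#     model: !new:speechbrain.lobes.models.Xvector.Xvector
#     modules:
#         model: !ref <model>
#     pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
#         loadables:
#             model: !ref <model>
#
# One possible call through a Pretrained subclass such as EncoderDecoderASR
# (the source shown is a public Hugging Face repo; a local directory that
# contains hyperparams.yaml also works):
#
#     asr = EncoderDecoderASR.from_hparams(
#         source="speechbrain/asr-crdnn-rnnlm-librispeech",
#         savedir="pretrained_models/asr-crdnn-rnnlm-librispeech",
#     )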
    def __init__(
        self,
        hparams_file="https://www.dropbox.com/s/ct72as3hapy8kb5/ecapa_big.yaml?dl=1",
        overrides={},
        freeze_params=True,
        norm_emb=True,
        save_folder="emb_model",
    ):
        """Downloads the pretrained modules specified in the yaml"""
        super().__init__()
        self.norm_emb = norm_emb

        save_model_path = os.path.join(save_folder, "embedding.yaml")
        download_file(hparams_file, save_model_path)
        hparams_file = save_model_path

        # Loading modules defined in the yaml file
        with open(hparams_file) as fin:
            overrides["save_folder"] = save_folder
            self.hparams = load_hyperpyyaml(fin, overrides)

        # putting modules on the right device
        # We need to check if DDP has been initialised
        # in order to give the right device
        if torch.distributed.is_initialized():
            self.device = ":".join([
                self.hparams["device"].split(":")[0], os.environ["LOCAL_RANK"]
            ])
        else:
            self.device = self.hparams["device"]

        # Creating directory where pre-trained models are stored
        if not os.path.isabs(self.hparams["save_folder"]):
            dirname = os.path.dirname(__file__)
            self.hparams["save_folder"] = os.path.join(
                dirname, self.hparams["save_folder"])
        if not os.path.isdir(self.hparams["save_folder"]):
            os.makedirs(self.hparams["save_folder"])

        # putting modules on the right device
        self.embedding_model = self.hparams["embedding_model"].to(self.device)
        self.mean_var_norm = self.hparams["mean_var_norm"].to(self.device)
        self.mean_var_norm_emb = self.hparams["mean_var_norm_emb"].to(
            self.device)
        self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

        # Load pretrained modules
        self.load_model()

        # If we don't want to backprop, freeze the pretrained parameters
        if freeze_params:
            self.embedding_model.eval()
            for p in self.embedding_model.parameters():
                p.requires_grad = False
def combine_multiple_hyperpyyaml_files_into_one(
    input_hyperpyyaml_files=None, extra_kv=None, output_hyperpyyaml_file=None
):
    """Combine several HyperPyYAML config files into a single file.

    Arguments
    ---------
    input_hyperpyyaml_files : dict
        Mapping from key to HyperPyYAML config file for this experiment.
    extra_kv : dict
        Extra key/value pairs to merge into the combined hyperparameters.
    output_hyperpyyaml_file : str
        Path of the combined HyperPyYAML file to write.
    """
    if extra_kv is None:
        extra_kv = {}
    combined_hparams = {}
    for k, v in input_hyperpyyaml_files.items():
        with open(v, "r") as f:
            combined_hparams[k] = load_hyperpyyaml(f)

    for k, v in extra_kv.items():
        combined_hparams[k] = v

    write_hyperpyyaml_file(output_hyperpyyaml_file, combined_hparams)
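A hedged usage sketch for the helper above; the file paths and keys are placeholders:

combine_multiple_hyperpyyaml_files_into_one(
    input_hyperpyyaml_files={
        "model": "hparams/model.yaml",        # placeholder paths
        "training": "hparams/training.yaml",
    },
    extra_kv={"seed": 1234},
    output_hyperpyyaml_file="hparams/combined.yaml",
)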
Example #10
    def __init__(
        self,
        hparams_file="https://www.dropbox.com/s/54vmm04g3gezwz3/pretrained_ASR_BPE1000.yaml?dl=1",
        save_folder="asr_model",
        overrides={},
        freeze_params=True,
    ):
        """Downloads the pretrained modules specified in the yaml"""
        super().__init__()

        save_model_path = os.path.join(save_folder, "ASR.yaml")
        download_file(hparams_file, save_model_path)
        hparams_file = save_model_path

        # Loading modules defined in the yaml file
        with open(hparams_file) as fin:
            overrides["save_folder"] = save_folder
            self.hparams = load_hyperpyyaml(fin, overrides)

        # putting modules on the right device
        # We need to check if DDP has been initialised
        # in order to give the right device
        if torch.distributed.is_initialized():
            self.device = ":".join([
                self.hparams["device"].split(":")[0], os.environ["LOCAL_RANK"]
            ])
        else:
            self.device = self.hparams["device"]

        # Creating directory where pre-trained models are stored
        if not os.path.isdir(self.hparams["save_folder"]):
            os.makedirs(self.hparams["save_folder"])

        # putting modules on the right device
        self.mod = torch.nn.ModuleDict(self.hparams["modules"]).to(self.device)

        # Load pretrained modules
        self.load_asr()

        # The tokenizer is the one used by the LM
        self.tokenizer = self.hparams["lm_model"].tokenizer

        # If we don't want to backprop, freeze the pretrained parameters
        if freeze_params:
            self.mod.asr_model.eval()
            for p in self.mod.asr_model.parameters():
                p.requires_grad = False
            self.mod.lm_model.eval()
            for p in self.mod.lm_model.parameters():
                p.requires_grad = False
    def __init__(
        self,
        hparams_file="hparams/pretrained.yaml",
        save_folder="lm_TAS",
        overrides={},
        freeze_params=True,
    ):
        """Downloads the pretrained modules specified in the yaml"""
        super().__init__()

        self.save_folder = save_folder

        # Download yaml file from the web
        save_file = os.path.join(save_folder, "LM_TAS.yaml")
        download_file(hparams_file, save_file)
        hparams_file = save_file

        # Loading modules defined in the yaml file
        with open(hparams_file) as fin:
            overrides["save_folder"] = save_folder
            self.hparams = load_hyperpyyaml(fin, overrides)

        if not os.path.isdir(self.hparams["save_folder"]):
            os.makedirs(self.hparams["save_folder"])

        # putting modules on the right device
        # We need to check if DDP has been initialised
        # in order to give the right device
        if torch.distributed.is_initialized():
            self.device = ":".join(
                [self.hparams["device"].split(":")[0], os.environ["LOCAL_RANK"]]
            )
        else:
            self.device = self.hparams["device"]

        self.net = self.hparams["net"].to(self.device)

        # Load pretrained modules
        self.load_lm()

        # Load tokenizer
        self.tokenizer = self.hparams["tokenizer"].spm

        # If we don't want to backprop, freeze the pretrained parameters
        if freeze_params:
            self.net.eval()
            for p in self.net.parameters():
                p.requires_grad = False
Example #12
    def __init__(
        self,
        hparams_file="hparams/pretrained_BPE1000.yaml",
        overrides={},
        freeze_params=True,
    ):
        """Downloads the pretrained modules specified in the yaml"""
        super().__init__()

        # Loading modules defined in the yaml file
        with open(hparams_file) as fin:
            self.hparams = load_hyperpyyaml(fin, overrides)

        self.device = self.hparams["device"]

        # Creating directory where pre-trained models are stored
        if not os.path.isabs(self.hparams["save_folder"]):
            dirname = os.path.dirname(__file__)
            self.hparams["save_folder"] = os.path.join(
                dirname, self.hparams["save_folder"])
        if not os.path.isdir(self.hparams["save_folder"]):
            os.makedirs(self.hparams["save_folder"])

        # putting modules on the right device
        self.mod = torch.nn.ModuleDict(self.hparams["modules"]).to(self.device)

        # Load pretrained modules
        self.load_tokenizer()
        self.load_asr()

        # If we don't want to backprop, freeze the pretrained parameters
        if freeze_params:
            self.mod.asr_model.eval()
            for p in self.mod.asr_model.parameters():
                p.requires_grad = False
            self.mod.lm_model.eval()
            for p in self.mod.lm_model.parameters():
                p.requires_grad = False
Example #13
def test_load_hyperpyyaml(tmpdir):
    from hyperpyyaml import (
        load_hyperpyyaml,
        RefTag,
        Placeholder,
        dump_hyperpyyaml,
    )

    # Basic functionality
    yaml = """
    a: 1
    thing: !new:collections.Counter {}
    """
    things = load_hyperpyyaml(yaml)
    assert things["a"] == 1
    from collections import Counter

    assert things["thing"].__class__ == Counter

    overrides = {"a": 2}
    things = load_hyperpyyaml(yaml, overrides=overrides)
    assert things["a"] == 2
    overrides = "{a: 2}"
    things = load_hyperpyyaml(yaml, overrides=overrides)
    assert things["a"] == 2
    overrides = "{thing: !new:collections.Counter {b: 3}}"
    things = load_hyperpyyaml(yaml, overrides=overrides)
    assert things["thing"]["b"] == 3

    # String replacement
    yaml = """
    a: abc
    b: !ref <a>
    thing: !new:collections.Counter
        a: !ref <a>
    thing2: !new:zip
        - !ref <a>
        - abc
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing"]["a"] == things["a"]
    assert things["a"] == things["b"]
    assert next(things["thing2"]) == ("a", "a")

    # String interpolation
    yaml = """
    a: "a"
    b: !ref <a>/b
    """
    things = load_hyperpyyaml(yaml)
    assert things["b"] == "a/b"

    # Substitution with string conversion
    yaml = """
    a: 1
    b: !ref <a>/b
    """
    things = load_hyperpyyaml(yaml)
    assert things["b"] == "1/b"

    # Nested structures:
    yaml = """
    constants:
        a: 1
    thing: !new:collections.Counter
        other: !new:collections.Counter
            a: !ref <constants[a]>
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing"]["other"].__class__ == Counter
    assert things["thing"]["other"]["a"] == things["constants"]["a"]

    # Positional arguments
    yaml = """
    a: hello
    thing: !new:collections.Counter
        - !ref <a>
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing"]["l"] == 2

    # Invalid class
    yaml = """
    thing: !new:abcdefg.hij
    """
    with pytest.raises(ImportError):
        things = load_hyperpyyaml(yaml)

    # Invalid reference
    yaml = """
    constants:
        a: 1
        b: !ref <constants[c]>
    """
    with pytest.raises(ValueError):
        things = load_hyperpyyaml(yaml)

    # Anchors and aliases
    yaml = """
    thing1: !new:collections.Counter &thing
        a: 3
        b: 5
    thing2: !new:collections.Counter
        <<: *thing
        b: 7
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing1"]["a"] == things["thing2"]["a"]
    assert things["thing1"]["b"] != things["thing2"]["b"]

    # Test references point to same object
    yaml = """
    thing1: !new:collections.Counter
        a: 3
        b: 5
    thing2: !ref <thing1>
    thing3: !new:hyperpyyaml.TestThing
        - !ref <thing1>
        - abc
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing2"]["b"] == things["thing1"]["b"]
    things["thing2"]["b"] = 7
    assert things["thing2"]["b"] == things["thing1"]["b"]
    assert things["thing3"].args[0] == things["thing1"]

    # Copy tag
    yaml = """
    thing1: !new:collections.Counter
        a: 3
        b: 5
    thing2: !copy <thing1>
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing2"]["b"] == things["thing1"]["b"]
    things["thing2"]["b"] = 7
    assert things["thing2"]["b"] != things["thing1"]["b"]

    # Name tag
    yaml = """
    Counter: !name:collections.Counter
    """
    things = load_hyperpyyaml(yaml)
    counter = things["Counter"]()
    assert counter.__class__ == Counter

    # Module tag
    yaml = """
    mod: !module:collections
    """
    things = load_hyperpyyaml(yaml)
    assert things["mod"].__name__ == "collections"

    # Apply tag
    yaml = """
    a: !apply:sum [[1, 2]]
    """
    things = load_hyperpyyaml(yaml)
    assert things["a"] == 3

    # Apply method
    yaml = """
    a: "A STRING"
    common_kwargs:
        thing1: !ref <a.lower>
        thing2: 2
    c: !apply:hyperpyyaml.TestThing.from_keys
        args:
            - 1
            - 2
        kwargs: !ref <common_kwargs>
    """
    things = load_hyperpyyaml(yaml)
    assert things["c"].kwargs["thing1"]() == "a string"
    assert things["c"].specific_key() == "a string"

    # Refattr:
    yaml = """
    thing1: "A string"
    thing2: !ref <thing1.lower>
    thing3: !new:hyperpyyaml.TestThing
        - !ref <thing1.lower>
        - abc
    """
    things = load_hyperpyyaml(yaml)
    assert things["thing2"]() == "a string"
    assert things["thing3"].args[0]() == "a string"

    # Placeholder
    yaml = """
    a: !PLACEHOLDER
    """
    with pytest.raises(ValueError) as excinfo:
        things = load_hyperpyyaml(yaml)
    assert str(excinfo.value) == "'a' is a !PLACEHOLDER and must be replaced."

    # Import
    imported_yaml = """
    a: !PLACEHOLDER
    b: !PLACEHOLDER
    c: !ref <a> // <b>
    """

    import os.path

    test_yaml_file = os.path.join(tmpdir, "test.yaml")
    with open(test_yaml_file, "w") as w:
        w.write(imported_yaml)

    yaml = f"""
    a: 3
    b: !PLACEHOLDER
    import: !include:{test_yaml_file}
        a: !ref <a>
        b: !ref <b>
    d: !ref <import[c]>
    """

    things = load_hyperpyyaml(yaml, {"b": 3})
    assert things["a"] == things["b"]
    assert things["import"]["c"] == 1
    assert things["d"] == things["import"]["c"]

    # Dumping
    dump_dict = {
        "data_folder": Placeholder(),
        "examples": {
            "ex1": RefTag(os.path.join("<data_folder>", "ex1.wav"))
        },
    }

    from io import StringIO

    stringio = StringIO()
    dump_hyperpyyaml(dump_dict, stringio)
    assert stringio.getvalue() == ("data_folder: !PLACEHOLDER\nexamples:\n"
                                   "  ex1: !ref <data_folder>/ex1.wav\n")
def main():
    experiment_dir = os.path.dirname(os.path.realpath(__file__))
    hparams_file = os.path.join(experiment_dir, "hyperparams.yaml")
    data_folder = "../../../../samples/audio_samples/sourcesep_samples"
    data_folder = os.path.realpath(os.path.join(experiment_dir, data_folder))
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, {"data_folder": data_folder})

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
    )
    torch.manual_seed(0)

    NMF1 = NMF_Brain(hparams=hparams)
    train_loader = sb.dataio.dataloader.make_dataloader(
        hparams["train_data"], **hparams["loader_kwargs"])

    NMF1.init_matrices(train_loader)

    print("fitting model 1")
    NMF1.fit(
        train_set=train_loader,
        valid_set=None,
        epoch_counter=range(hparams["N_epochs"]),
        progressbar=False,
    )
    W1hat = NMF1.training_out[1]

    NMF2 = NMF_Brain(hparams=hparams)
    train_loader = sb.dataio.dataloader.make_dataloader(
        hparams["train_data"], **hparams["loader_kwargs"])
    NMF2.init_matrices(train_loader)

    print("fitting model 2")
    NMF2.fit(
        train_set=train_loader,
        valid_set=None,
        epoch_counter=range(hparams["N_epochs"]),
        progressbar=False,
    )
    W2hat = NMF2.training_out[1]

    # separate
    mixture_loader = sb.dataio.dataloader.make_dataloader(
        hparams["test_data"], **hparams["loader_kwargs"])
    mix_batch = next(iter(mixture_loader))

    Xmix = NMF1.hparams.compute_features(mix_batch.wav.data)
    Xmix_mag = spectral_magnitude(Xmix, power=2)

    X1hat, X2hat = sb_nmf.NMF_separate_spectra([W1hat, W2hat], Xmix_mag)

    x1hats, x2hats = sb_nmf.reconstruct_results(
        X1hat,
        X2hat,
        Xmix.permute(0, 2, 1, 3),
        hparams["sample_rate"],
        hparams["win_length"],
        hparams["hop_length"],
    )

    if hparams["save_reconstructed"]:
        savepath = "results/save/"
        if not os.path.exists("results"):
            os.mkdir("results")

        if not os.path.exists(savepath):
            os.mkdir(savepath)

        for i, (x1hat, x2hat) in enumerate(zip(x1hats, x2hats)):
            write_audio(
                os.path.join(savepath, "separated_source1_{}.wav".format(i)),
                x1hat.squeeze(0),
                16000,
            )
            write_audio(
                os.path.join(savepath, "separated_source2_{}.wav".format(i)),
                x2hat.squeeze(0),
                16000,
            )

        if hparams["copy_original_files"]:
            datapath = "samples/audio_samples/sourcesep_samples"

        filedir = os.path.dirname(os.path.realpath(__file__))
        speechbrain_path = os.path.abspath(os.path.join(
            filedir, "../../../.."))
        copypath = os.path.realpath(os.path.join(speechbrain_path, datapath))

        all_files = os.listdir(copypath)
        wav_files = [fl for fl in all_files if ".wav" in fl]

        for wav_file in wav_files:
            shutil.copy(copypath + "/" + wav_file, savepath)
Example #15
--hparams_key model

Authors
 * Peter Plantinga 2020
"""

import torch
import argparse
from collections import OrderedDict
from hyperpyyaml import load_hyperpyyaml

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hparams", required=True)
    parser.add_argument("--hparams_key", required=True)
    parser.add_argument("--old_ckpt", required=True)
    parser.add_argument("--new_ckpt", required=True)
    args = parser.parse_args()

    with open(args.hparams) as f:
        hparams = load_hyperpyyaml(f, overrides={"data_folder": "asdf"})

    ckpt = torch.load(args.old_ckpt)
    assert len(hparams[args.hparams_key].state_dict()) == len(ckpt)

    new_state_dict = OrderedDict()
    for old_key, new_key in zip(ckpt, hparams[args.hparams_key].state_dict()):
        new_state_dict[new_key] = ckpt[old_key]

    torch.save(new_state_dict, args.new_ckpt)
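A hedged example invocation of the script above; the script name and checkpoint paths are placeholders, and the hparams file must tolerate the dummy "data_folder" override used when it is loaded:

# python convert_ckpt_keys.py \
#     --hparams hparams/train.yaml --hparams_key model \
#     --old_ckpt results/old/model.ckpt --new_ckpt results/new/model.ckpt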
Example #16
    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets,
        ["id", "sig", "semantics", "tokens_bos", "tokens_eos", "tokens"],
    )
    return train_data, valid_data, test_data, tokenizer


if __name__ == "__main__":

    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    show_results_every = 100  # plots results every N iterations

    # If distributed_launch=True then
    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Dataset prep (parsing SLURP)
def load_hparams(hparams_fname):
    with open(hparams_fname, "r") as f:
        return load_hyperpyyaml(f)