Example #1
    def separate_file(self, path, savedir="."):
        """Separate sources from file.

        Arguments
        ---------
        path : str
            Path to file which has a mixture of sources. It can be a local
            path, a web url, or a huggingface repo.
        savedir : path
            Path where to store the wav signals (when downloaded from the web).

        Returns
        -------
        tensor
            Separated sources
        """
        source, fl = split_path(path)
        path = fetch(fl, source=source, savedir=savedir)

        batch, fs_file = torchaudio.load(path)
        batch = batch.to(self.device)
        fs_model = self.hparams.sample_rate

        # resample the data if needed
        if fs_file != fs_model:
            print("Resampling the audio from {} Hz to {} Hz".format(
                fs_file, fs_model))
            tf = torchaudio.transforms.Resample(orig_freq=fs_file,
                                                new_freq=fs_model)
            batch = batch.mean(dim=0, keepdim=True)
            batch = tf(batch)

        est_sources = self.separate_batch(batch)
        est_sources = est_sources / est_sources.max(dim=1, keepdim=True)[0]
        return est_sources
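
A minimal usage sketch for separate_file, assuming the SepformerSeparation interface from speechbrain.pretrained and the public speechbrain/sepformer-wsj02mix checkpoint (both assumptions; substitute your own model and mixture file):

import torchaudio
from speechbrain.pretrained import SepformerSeparation

# Download the pretrained separator and run it on a local mixture file.
model = SepformerSeparation.from_hparams(
    source="speechbrain/sepformer-wsj02mix",            # assumed public HF repo
    savedir="pretrained_models/sepformer-wsj02mix",
)
est_sources = model.separate_file(path="mixture.wav")   # placeholder path

# est_sources has shape [batch, time, n_sources]; save the first estimate at 8 kHz.
torchaudio.save("source1_hat.wav", est_sources[:, :, 0].detach().cpu(), 8000)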
Example #2
    def separate_file(self, path, savedir="."):
        """Separate sources from file.

        Arguments
        ---------
        path : str
            Path to file which has a mixture of sources. It can be a local
            path, a web url, or a huggingface repo.
        savedir : path
            Path where to store the wav signals (when downloaded from the web).

        Returns
        -------
        tensor
            Separated sources
        """
        source, fl = split_path(path)
        path = fetch(fl, source=source, savedir=savedir)

        batch, _ = torchaudio.load(path)

        # Move the mixture to the same device as the model parameters
        batch = batch.to(self.device)

        est_sources = self.separate_batch(batch)
        est_sources = est_sources / est_sources.max(dim=1, keepdim=True)[0]
        return est_sources
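
Because this variant discards the file's sample rate and never resamples, a caller may want to match the model's rate beforehand. A hedged sketch using plain torchaudio (the file names and the 8 kHz target are placeholders; check hparams.sample_rate for the real value):

import torchaudio

waveform, fs_file = torchaudio.load("mixture.wav")       # placeholder path
fs_model = 8000                                           # assumed model rate
if fs_file != fs_model:
    # Downmix to mono and resample to the rate the separator was trained on.
    waveform = waveform.mean(dim=0, keepdim=True)
    resample = torchaudio.transforms.Resample(orig_freq=fs_file, new_freq=fs_model)
    waveform = resample(waveform)
torchaudio.save("mixture_resampled.wav", waveform, fs_model)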
Example #3
    def from_hparams(
        cls,
        source,
        hparams_file="hyperparams.yaml",
        overrides={},
        savedir=None,
        use_auth_token=False,
        **kwargs,
    ):
        """Fetch and load based from outside source based on HyperPyYAML file

        The source can be a location on the filesystem or online/huggingface

        The hyperparams file should contain a "modules" key, which is a
        dictionary of torch modules used for computation.

        The hyperparams file should contain a "pretrainer" key, which is a
        speechbrain.utils.parameter_transfer.Pretrainer

        Arguments
        ---------
        source : str
            The location to use for finding the model. See
            ``speechbrain.pretrained.fetching.fetch`` for details.
        hparams_file : str
            The name of the hyperparameters file to use for constructing
            the modules necessary for inference. Must contain two keys:
            "modules" and "pretrainer", as described.
        overrides : dict
            Any changes to make to the hparams file when it is loaded.
        savedir : str or Path
            Where to put the pretraining material. If not given, will use
            ./pretrained_models/<class-name>-hash(source).
        use_auth_token : bool (default: False)
            If True, HuggingFace's auth_token will be used to load private models
            from the HuggingFace Hub. The default is False because the majority
            of models are public.
        """
        if savedir is None:
            clsname = cls.__name__
            savedir = f"./pretrained_models/{clsname}-{hash(source)}"
        hparams_local_path = fetch(hparams_file, source, savedir,
                                   use_auth_token)

        # Load the modules:
        with open(hparams_local_path) as fin:
            hparams = load_hyperpyyaml(fin, overrides)

        # Pretraining:
        pretrainer = hparams["pretrainer"]
        pretrainer.set_collect_in(savedir)
        # For distributed setups, collect the files on the main process only:
        run_on_main(pretrainer.collect_files,
                    kwargs={"default_source": source})
        # Load on the CPU. Later the params can be moved elsewhere by specifying
        # run_opts={"device": ...}
        pretrainer.load_collected(device="cpu")

        # Now return the system
        return cls(hparams["modules"], hparams, **kwargs)
Example #4
    def load_audio(self, path, savedir="."):
        """Load an audio file with this model"s input spec

        When using a speech model, it is important to use the same type of data
        as was used to train the model. This means, for example, using the same
        sampling rate and number of channels. It is, however, possible to
        convert a file from a higher sampling rate to a lower one (downsampling).
        Similarly, it is simple to downmix a stereo file to mono.
        The path can be a local path, a web url, or a link to a huggingface repo.
        """
        source, fl = split_path(path)
        path = fetch(fl, source=source, savedir=savedir)
        signal, sr = torchaudio.load(path, channels_first=False)
        return self.audio_normalizer(signal, sr)
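
A short sketch of how load_audio is typically combined with a pretrained interface; EncoderClassifier and the spkrec-ecapa-voxceleb repo are assumptions, and the audio path is a placeholder:

from speechbrain.pretrained import EncoderClassifier

classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",           # assumed public HF repo
    savedir="pretrained_models/spkrec-ecapa-voxceleb",
)
# load_audio downmixes/resamples via audio_normalizer to match the model's spec.
signal = classifier.load_audio("utterance.wav")           # placeholder path
embeddings = classifier.encode_batch(signal.unsqueeze(0))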
Example #5
    def collect_files(self, default_source=None):
        """Fetches parameters from known paths with fallback default_source

        The actual parameter files may reside elsewhere, but this ensures a
        symlink in the self.collect_in directory. The symlink always uses the
        loadable key in the filename. This standardization makes it easier to
        orchestrate pretraining on e.g. distributed setups.

        Use the default_source if you have everything organized neatly into one
        location, like a HuggingFace Hub repo.

        Arguments
        ---------
        default_source : str or Path
            This is used for each loadable which doesn't have a path already
            specified. If the loadable has key "asr", then the file to look for is
            default_source/asr.ckpt

        Returns
        -------
        dict
            Mapping from loadable key to a local path from which the loadable's
            parameters can be loaded. This is not used within this class, but it
            can be helpful to the caller.
        """
        logger.debug(
            f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
        )
        self.collect_in.mkdir(exist_ok=True)
        loadable_paths = {}
        for name in self.loadables:
            save_filename = name + PARAMFILE_EXT
            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = save_filename
                source = default_source
            else:
                raise ValueError(f"Path not specified for '{name}', "
                                 "and no default_source given!")
            path = fetch(filename,
                         source,
                         self.collect_in,
                         save_filename=save_filename)
            loadable_paths[name] = path
        return loadable_paths
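
A hedged sketch of using collect_files in isolation, assuming the Pretrainer constructor accepts collect_in and loadables as in speechbrain.utils.parameter_transfer; the repo name and stand-in module are placeholders:

import torch
from speechbrain.utils.parameter_transfer import Pretrainer

model = torch.nn.Linear(10, 10)  # stand-in module; real recipes pass the actual model
pretrainer = Pretrainer(
    collect_in="pretrained_checkpoints",
    loadables={"model": model},
)
# Every loadable without an explicit path falls back to <default_source>/<key>.ckpt.
paths = pretrainer.collect_files(default_source="speechbrain/some-repo")  # placeholder repo
pretrainer.load_collected(device="cpu")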
Example #6
    def _from_pretrained(self, source, config, model, save_path):
        """This function manages the source checking and loading of the params.
        # 1. Is the model from HF or a local path
        # 2. Is the model pretrained with HF or SpeechBrain
        # 3. Download (if appropriate) and load with respect to 1. and 2.
        """

        is_sb, ckpt_file = self._check_model_source(source)
        if is_sb:
            config = config.from_pretrained(source, cache_dir=save_path)
            self.model = model(config)
            self.model.gradient_checkpointing_disable()  # Required by DDP
            # fetch the checkpoint file
            ckpt_full_path = fetch(filename=ckpt_file,
                                   source=source,
                                   savedir=save_path)
            # We transfer the parameters from the checkpoint.
            self._load_sb_pretrained_w2v2_parameters(ckpt_full_path)
        else:
            self.model = model.from_pretrained(source, cache_dir=save_path)
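
A brief sketch of the wrapper this method typically belongs to; HuggingFaceWav2Vec2 and its module path are assumptions based on SpeechBrain's lobes, and the checkpoint name only illustrates the plain-HF branch:

from speechbrain.lobes.models.huggingface_wav2vec import HuggingFaceWav2Vec2

# A plain HF checkpoint takes the model.from_pretrained(...) branch above; a
# SpeechBrain-trained w2v2 repo would instead go through fetch() and parameter transfer.
w2v2 = HuggingFaceWav2Vec2(
    source="facebook/wav2vec2-base-960h",   # assumed HF checkpoint
    save_path="pretrained_w2v2",
)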
Example #7
def create_rirs(output_dir, sr=8000):
    """
    This function creates the room impulse responses from the WHAMR! dataset
    The implementation is based on the scripts from http://wham.whisper.ai/

    Arguments:
    ------
    output_dir (str) : directory for saving the RIRs
    sr (int) : sampling rate with which we save

    """

    assert (pyroomacoustics.__version__ == "0.3.1"
            ), "The pyroomacoustics version needs to be 0.3.1"

    os.makedirs(output_dir)

    metafilesdir = os.path.dirname(os.path.realpath(__file__))
    filelist = [
        "mix_2_spk_filenames_tr.csv",
        "mix_2_spk_filenames_cv.csv",
        "mix_2_spk_filenames_tt.csv",
        "reverb_params_tr.csv",
        "reverb_params_cv.csv",
        "reverb_params_tt.csv",
    ]

    savedir = os.path.join(metafilesdir, "data")
    for fl in filelist:
        if not os.path.exists(os.path.join(savedir, fl)):
            fetch(
                "metadata/" + fl,
                "speechbrain/sepformer-whamr",
                savedir=savedir,
                save_filename=fl,
            )

    FILELIST_STUB = os.path.join(metafilesdir, "data",
                                 "mix_2_spk_filenames_{}.csv")

    SPLITS = ["tr"]

    reverb_param_stub = os.path.join(metafilesdir, "data",
                                     "reverb_params_{}.csv")

    for splt in SPLITS:

        wsjmix_path = FILELIST_STUB.format(splt)
        wsjmix_df = pd.read_csv(wsjmix_path)

        reverb_param_path = reverb_param_stub.format(splt)
        reverb_param_df = pd.read_csv(reverb_param_path)

        utt_ids = wsjmix_df.output_filename.values

        for output_name in tqdm(utt_ids):
            utt_row = reverb_param_df[reverb_param_df["utterance_id"] ==
                                      output_name]
            room = WhamRoom(
                [
                    utt_row["room_x"].iloc[0],
                    utt_row["room_y"].iloc[0],
                    utt_row["room_z"].iloc[0],
                ],
                [
                    [
                        utt_row["micL_x"].iloc[0],
                        utt_row["micL_y"].iloc[0],
                        utt_row["mic_z"].iloc[0],
                    ],
                    [
                        utt_row["micR_x"].iloc[0],
                        utt_row["micR_y"].iloc[0],
                        utt_row["mic_z"].iloc[0],
                    ],
                ],
                [
                    utt_row["s1_x"].iloc[0],
                    utt_row["s1_y"].iloc[0],
                    utt_row["s1_z"].iloc[0],
                ],
                [
                    utt_row["s2_x"].iloc[0],
                    utt_row["s2_y"].iloc[0],
                    utt_row["s2_z"].iloc[0],
                ],
                utt_row["T60"].iloc[0],
            )
            room.generate_rirs()

            rir = room.rir_reverberant

            for i, mics in enumerate(rir):
                for j, source in enumerate(mics):
                    h = resample_poly(source, sr, 16000)
                    h_torch = torch.from_numpy(h).float().unsqueeze(0)

                    torchaudio.save(
                        os.path.join(
                            output_dir,
                            "{}_{}_".format(i, j) + output_name,
                        ),
                        h_torch,
                        sr,
                    )
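
A minimal usage sketch; note that the output directory must not already exist (os.makedirs is called without exist_ok) and pyroomacoustics 0.3.1 must be installed:

# Generate the 8 kHz WHAMR! training RIRs into ./whamr_rirs (placeholder directory).
create_rirs(output_dir="whamr_rirs", sr=8000)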