Example #1
    def forward(self, compute_STFT, wav):
        feats = compute_STFT(wav)
        feats = spectral_magnitude(feats, power=1)  # power spectrum

        # Log1p compresses the dynamic range of the magnitudes
        feats = torch.log1p(feats)

        return feats
    def init_matrices(self, train_loader):
        """
        This function is used to initialize the parameter matrices
        """

        batch = next(iter(train_loader))
        X = self.hparams.compute_features(batch.wav.data)
        X = spectral_magnitude(X, power=2)
        # Total number of spectral frames across the batch
        n = X.shape[0] * X.shape[1]

        # Initialize w ([m, K] templates) and h ([K, n] activations) with
        # positive random values, normalized column-wise
        eps = 1e-20
        w = 0.1 * torch.rand(self.hparams.m, self.hparams.K) + 1
        self.w = w / torch.sum(w, dim=0) + eps

        h = 0.1 * torch.rand(self.hparams.K, n) + 1
        self.h = h / torch.sum(h, dim=0) + eps
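A minimal usage sketch of the log-magnitude feature pipeline from the forward method above, assuming SpeechBrain's speechbrain.processing.features.STFT and spectral_magnitude (the sample rate, FFT size, and shapes below are illustrative, not taken from the example):

import torch
from speechbrain.processing.features import STFT, spectral_magnitude

compute_STFT = STFT(sample_rate=16000, n_fft=512, win_length=25, hop_length=10)
wav = torch.randn(4, 16000)        # batch of 1-second signals
stft = compute_STFT(wav)           # [batch, frames, n_fft // 2 + 1, 2]
feats = torch.log1p(spectral_magnitude(stft, power=1))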
Example #3
    def compute_forward(self, batch, stage):
        batch = batch.to(self.device)
        noisy_wavs, lens = batch.noisy_sig

        feats = self.hparams.compute_STFT(noisy_wavs)
        feats = spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)

        predict_spec = self.hparams.model(feats)

        # Also return predicted wav
        if stage != sb.Stage.TRAIN:
            predict_wav = self.hparams.resynth(torch.expm1(predict_spec),
                                               noisy_wavs)
        else:
            predict_wav = None

        return predict_spec, predict_wav
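Since the model is trained on log1p-compressed magnitudes, torch.expm1 in the resynthesis branch simply inverts that compression before the waveform is rebuilt. A minimal sketch of the round trip (shapes are illustrative):

import torch

mag = torch.rand(2, 100, 257)      # hypothetical magnitude spectrogram
log_mag = torch.log1p(mag)         # compressed features seen by the model
recovered = torch.expm1(log_mag)   # inverse mapping applied before resynthesis
assert torch.allclose(recovered, mag, atol=1e-6)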
Example #4
    def forward(self, wav):
        """Returns a set of features generated from the input waveforms.

        Arguments
        ---------
        wav : tensor
            A batch of audio signals to transform to features.
        """
        STFT = self.compute_STFT(wav)
        mag = spectral_magnitude(STFT)
        fbanks = self.compute_fbanks(mag)
        if self.deltas:
            delta1 = self.compute_deltas(fbanks)
            delta2 = self.compute_deltas(delta1)
            fbanks = torch.cat([fbanks, delta1, delta2], dim=2)
        if self.context:
            fbanks = self.context_window(fbanks)
        return fbanks
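With both options enabled, the two delta streams triple the feature dimension and the context window then multiplies it by the number of stacked frames. A rough sketch, assuming SpeechBrain's Deltas and ContextWindow modules (all sizes are illustrative):

import torch
from speechbrain.processing.features import Deltas, ContextWindow

fbanks = torch.rand(4, 100, 40)              # [batch, time, n_mels]
deltas = Deltas(input_size=40)
d1 = deltas(fbanks)
d2 = deltas(d1)
feats = torch.cat([fbanks, d1, d2], dim=2)   # [4, 100, 120]
cw = ContextWindow(left_frames=2, right_frames=2)
feats = cw(feats)                            # [4, 100, 600]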
Example #5
    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss given the predicted and targeted outputs"""
        predict_spec, predict_wav = predictions
        ids = batch.id
        clean_wav, lens = batch.clean_sig

        feats = self.hparams.compute_STFT(clean_wav)
        feats = spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)

        loss = self.hparams.compute_cost(predict_spec, feats, lens)

        self.loss_metric.append(ids,
                                predict_spec,
                                feats,
                                lens,
                                reduction="batch")

        if stage != sb.Stage.TRAIN:
            # Evaluate speech quality/intelligibility
            self.stoi_metric.append(ids,
                                    predict_wav,
                                    clean_wav,
                                    lens,
                                    reduction="batch")
            self.pesq_metric.append(batch.id,
                                    predict=predict_wav,
                                    target=clean_wav,
                                    lengths=lens)

        # Write wavs to file
        if stage == sb.Stage.TEST:
            lens = lens * clean_wav.shape[1]
            for name, wav, length in zip(ids, predict_wav, lens):
                enhance_path = os.path.join(self.hparams.enhanced_folder, name)
                if not enhance_path.endswith(".wav"):
                    enhance_path = enhance_path + ".wav"
                torchaudio.save(
                    enhance_path,
                    torch.unsqueeze(wav[:int(length)].cpu(), 0),
                    self.hparams.Sample_rate,
                )

        return loss
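The lens values attached to a batch are relative lengths in [0, 1], so multiplying by the padded signal length converts them to sample counts before trimming, as the saving loop above does. A small illustrative sketch:

import torch

clean_wav = torch.rand(3, 32000)          # padded batch of waveforms
lens = torch.tensor([1.0, 0.5, 0.75])     # relative lengths
abs_lens = lens * clean_wav.shape[1]      # -> [32000., 16000., 24000.]
trimmed = [w[:int(l)] for w, l in zip(clean_wav, abs_lens)]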
    def compute_forward(self, batch):
        """Forward pass: performs one round of multiplicative NMF updates
        on the features computed from the input batch.

        Arguments
        ---------
        batch : PaddedBatch
            The input tensor or tensors for processing.
        """

        X = self.hparams.compute_features(batch.wav.data)
        X = spectral_magnitude(X, power=2)

        # Flatten batch and time into a single frame axis: X becomes [features, frames]
        X = X.reshape(-1, X.size(-1)).t()

        eps = 1e-20
        g = X.sum(dim=0) + eps
        z = X / g  # normalize each frame to sum to one

        # Ratio of the data to the current reconstruction w @ h
        v = z / (torch.matmul(self.w, self.h) + eps)

        # Multiplicative update for w, followed by column normalization
        nw = self.w * torch.matmul(v, self.h.t())
        self.w = nw / (torch.sum(nw, dim=0) + eps)

        # Multiplicative update for h
        nh = self.h * torch.matmul(self.w.t(), v)
        # Sparsity-promoting term on the activations
        nh = nh + 0.02 * nh**(1.0 + 0.1)

        self.h = nh / (torch.sum(nh, dim=0) + eps)

        # Restore the per-frame scale removed by the normalization above
        self.h *= g

        # Mean absolute reconstruction error, for monitoring
        deviation = (X - torch.matmul(self.w, self.h)).abs().mean().item()

        return torch.matmul(self.w, self.h), self.w, self.h / g, deviation
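For reference, here is a minimal, self-contained sketch of the normalized multiplicative updates that compute_forward applies once per call, run for several iterations on random data (sizes and iteration count are illustrative, and the sparsity term is omitted):

import torch

eps = 1e-20
m, K, n = 257, 20, 100
X = torch.rand(m, n)                     # nonnegative data matrix
g = X.sum(dim=0) + eps
Z = X / g                                # frames normalized to sum to one
W = torch.rand(m, K)
H = torch.rand(K, n)
W, H = W / W.sum(dim=0), H / H.sum(dim=0)

for _ in range(100):
    V = Z / (W @ H + eps)                # data / reconstruction ratio
    W = W * (V @ H.t())                  # multiplicative update for W
    W = W / (W.sum(dim=0) + eps)
    H = H * (W.t() @ V)                  # multiplicative update for H
    H = H / (H.sum(dim=0) + eps)

error = (X - W @ (H * g)).abs().mean()   # reconstruction error on the original scale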
Example #7
    def compute_feats(self, wavs):
        """Feature computation pipeline"""
        feats = self.hparams.compute_STFT(wavs)
        feats = spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)
        return feats
Example #8
def reconstruct_results(
    X1hat, X2hat, X_stft, sample_rate, win_length, hop_length,
):
    """This function reconstructs the separated spectra into waveforms.

    Arguments
    ---------
    X1hat : torch.tensor
        The separated spectrum for source 1 of size [BS, nfft/2 + 1, T],
        where BS = batch size, nfft = fft size, T = length of the spectra.
    X2hat : torch.tensor
        The separated spectrum for source 2 of size [BS, nfft/2 + 1, T].
        The size definitions are the same as for X1hat.
    X_stft : torch.tensor
        The complex STFT of the mixture.
        The size is [BS x nfft//2 + 1 x T x 2] where
        BS = batch size, nfft = fft size, T = number of time steps in the spectra.
        The last dimension holds the real and imaginary parts.
    sample_rate : int
        The sampling rate (in Hz) in which we would like to save the results.
    win_length : int
        The length of stft windows (in ms).
    hop_length : int
        The length with which we shift the STFT windows (in ms).

    Returns
    -------
    x1hats : list
        List of waveforms for source 1.
    x2hats : list
        List of waveforms for source 2.

    Example
    -------
    >>> BS, nfft, T = 10, 512, 16000
    >>> sample_rate, win_length, hop_length = 16000, 25, 10
    >>> X1hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X2hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
    >>> x1hats, x2hats = reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)
    """

    ISTFT = spf.ISTFT(
        sample_rate=sample_rate, win_length=win_length, hop_length=hop_length
    )

    phase_mix = spectral_phase(X_stft)
    mag_mix = spectral_magnitude(X_stft, power=2)

    x1hats, x2hats = [], []
    eps = 1e-25
    # For each item: apply a ratio mask to the mixture spectrogram and
    # reuse the mixture phase to build a complex STFT per source
    for i in range(X1hat.shape[0]):
        X1hat_stft = (
            (X1hat[i] / (eps + X1hat[i] + X2hat[i])).unsqueeze(-1)
            * mag_mix[i].unsqueeze(-1)
            * torch.cat(
                [
                    torch.cos(phase_mix[i].unsqueeze(-1)),
                    torch.sin(phase_mix[i].unsqueeze(-1)),
                ],
                dim=-1,
            )
        )

        X2hat_stft = (
            (X2hat[i] / (eps + X1hat[i] + X2hat[i])).unsqueeze(-1)
            * mag_mix[i].unsqueeze(-1)
            * torch.cat(
                [
                    torch.cos(phase_mix[i].unsqueeze(-1)),
                    torch.sin(phase_mix[i].unsqueeze(-1)),
                ],
                dim=-1,
            )
        )
        X1hat_stft = X1hat_stft.unsqueeze(0).permute(0, 2, 1, 3)
        X2hat_stft = X2hat_stft.unsqueeze(0).permute(0, 2, 1, 3)
        shat1 = ISTFT(X1hat_stft)
        shat2 = ISTFT(X2hat_stft)

        # Normalize the output amplitude relative to its standard deviation
        div_factor = 10
        x1 = shat1 / (div_factor * shat1.std())
        x2 = shat2 / (div_factor * shat2.std())

        x1hats.append(x1)
        x2hats.append(x2)
    return x1hats, x2hats
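Each per-source STFT above is obtained by applying a ratio mask to the mixture spectrogram and reusing the mixture phase. A minimal sketch of that magnitude-plus-phase reconstruction (shapes are illustrative):

import torch

mag = torch.rand(257, 100)                   # masked magnitude for one source
phase = torch.rand(257, 100) * 2 * torch.pi  # mixture phase
stft = torch.stack([mag * torch.cos(phase),
                    mag * torch.sin(phase)], dim=-1)  # [257, 100, 2] real/imag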
def main():
    experiment_dir = os.path.dirname(os.path.realpath(__file__))
    hparams_file = os.path.join(experiment_dir, "hyperparams.yaml")
    data_folder = "../../../../samples/audio_samples/sourcesep_samples"
    data_folder = os.path.realpath(os.path.join(experiment_dir, data_folder))
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, {"data_folder": data_folder})

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
    )
    torch.manual_seed(0)

    NMF1 = NMF_Brain(hparams=hparams)
    train_loader = sb.dataio.dataloader.make_dataloader(
        hparams["train_data"], **hparams["loader_kwargs"])

    NMF1.init_matrices(train_loader)

    print("fitting model 1")
    NMF1.fit(
        train_set=train_loader,
        valid_set=None,
        epoch_counter=range(hparams["N_epochs"]),
        progressbar=False,
    )
    W1hat = NMF1.training_out[1]

    NMF2 = NMF_Brain(hparams=hparams)
    train_loader = sb.dataio.dataloader.make_dataloader(
        hparams["train_data"], **hparams["loader_kwargs"])
    NMF2.init_matrices(train_loader)

    print("fitting model 2")
    NMF2.fit(
        train_set=train_loader,
        valid_set=None,
        epoch_counter=range(hparams["N_epochs"]),
        progressbar=False,
    )
    W2hat = NMF2.training_out[1]

    # Separate the test mixtures using the two learned dictionaries
    mixture_loader = sb.dataio.dataloader.make_dataloader(
        hparams["test_data"], **hparams["loader_kwargs"])
    mix_batch = next(iter(mixture_loader))

    Xmix = NMF1.hparams.compute_features(mix_batch.wav.data)
    Xmix_mag = spectral_magnitude(Xmix, power=2)

    X1hat, X2hat = sb_nmf.NMF_separate_spectra([W1hat, W2hat], Xmix_mag)

    x1hats, x2hats = sb_nmf.reconstruct_results(
        X1hat,
        X2hat,
        Xmix.permute(0, 2, 1, 3),
        hparams["sample_rate"],
        hparams["win_length"],
        hparams["hop_length"],
    )

    if hparams["save_reconstructed"]:
        savepath = "results/save/"
        if not os.path.exists("results"):
            os.mkdir("results")

        if not os.path.exists(savepath):
            os.mkdir(savepath)

        for i, (x1hat, x2hat) in enumerate(zip(x1hats, x2hats)):
            write_audio(
                os.path.join(savepath, "separated_source1_{}.wav".format(i)),
                x1hat.squeeze(0),
                16000,
            )
            write_audio(
                os.path.join(savepath, "separated_source2_{}.wav".format(i)),
                x2hat.squeeze(0),
                16000,
            )

        if hparams["copy_original_files"]:
            datapath = "samples/audio_samples/sourcesep_samples"

        filedir = os.path.dirname(os.path.realpath(__file__))
        speechbrain_path = os.path.abspath(os.path.join(
            filedir, "../../../.."))
        copypath = os.path.realpath(os.path.join(speechbrain_path, datapath))

        all_files = os.listdir(copypath)
        wav_files = [fl for fl in all_files if ".wav" in fl]

        for wav_file in wav_files:
            shutil.copy(copypath + "/" + wav_file, savepath)