def forward(self, compute_STFT, wav):
    """Turn a waveform batch into log-compressed spectral features.

    Arguments
    ---------
    compute_STFT : callable
        Callable applied to ``wav`` to obtain its STFT.
    wav : torch.Tensor
        Input waveforms, forwarded unchanged to ``compute_STFT``.

    Returns
    -------
    torch.Tensor
        ``log1p`` of ``spectral_magnitude(STFT, power=1)``.
    """
    stft = compute_STFT(wav)

    # power spectrum
    spectrum = spectral_magnitude(stft, power=1)

    # log1p compresses the dynamic range, so small differences between
    # values carry less weight.
    return torch.log1p(spectrum)
def init_matrices(self, train_loader):
    """Randomly initialise the NMF factors ``self.w`` and ``self.h``.

    Only the first batch of ``train_loader`` is read; it is used solely to
    size ``self.h``.

    Arguments
    ---------
    train_loader : iterable
        Loader whose batches expose ``wav.data``.
    """
    first_batch = next(iter(train_loader))
    spec = self.hparams.compute_features(first_batch.wav.data)
    spec = spectral_magnitude(spec, power=2)

    # Number of columns of h: product of the first two dimensions of the
    # spectrogram (presumably batch size x time frames -- TODO confirm).
    n_columns = spec.shape[0] * spec.shape[1]

    eps = 1e-20

    # Entries are drawn from [1, 1.1), each column is divided by its sum,
    # and eps is added after the division.
    w_init = 0.1 * torch.rand(self.hparams.m, self.hparams.K) + 1
    self.w = w_init / w_init.sum(dim=0) + eps

    h_init = 0.1 * torch.rand(self.hparams.K, n_columns) + 1
    self.h = h_init / h_init.sum(dim=0) + eps
def compute_forward(self, batch, stage):
    """Predict a log-compressed spectrum from the noisy signal.

    Arguments
    ---------
    batch : object
        Batch providing ``to(device)`` and a ``noisy_sig`` pair
        ``(waveforms, lengths)``.
    stage : sb.Stage
        Current stage; waveforms are only resynthesised outside of TRAIN.

    Returns
    -------
    predict_spec
        Output of ``self.hparams.model`` on the log-compressed features.
    predict_wav
        Resynthesised waveform, or ``None`` during training.
    """
    batch = batch.to(self.device)
    noisy_wavs, _ = batch.noisy_sig

    noisy_stft = self.hparams.compute_STFT(noisy_wavs)
    noisy_feats = torch.log1p(spectral_magnitude(noisy_stft, power=0.5))

    predict_spec = self.hparams.model(noisy_feats)

    if stage == sb.Stage.TRAIN:
        return predict_spec, None

    # Also return predicted wav: expm1 undoes the log1p compression before
    # the spectrum is handed to the resynthesis module with the noisy input.
    predict_wav = self.hparams.resynth(torch.expm1(predict_spec), noisy_wavs)
    return predict_spec, predict_wav
def forward(self, wav):
    """Generate filterbank features from the input waveforms.

    Arguments
    ---------
    wav : tensor
        A batch of audio signals to transform to features.

    Returns
    -------
    tensor
        Filterbank features, concatenated with their first and second
        deltas along dimension 2 when ``self.deltas`` is set, and passed
        through ``self.context_window`` when ``self.context`` is set.
    """
    magnitude = spectral_magnitude(self.compute_STFT(wav))
    feats = self.compute_fbanks(magnitude)

    if self.deltas:
        first_order = self.compute_deltas(feats)
        second_order = self.compute_deltas(first_order)
        feats = torch.cat([feats, first_order, second_order], dim=2)

    return self.context_window(feats) if self.context else feats
def compute_objectives(self, predictions, batch, stage):
    """Compute the loss between predicted and clean log-compressed spectra.

    Outside of TRAIN the STOI and PESQ metrics are also updated, and in
    TEST the predicted waveforms are written to
    ``self.hparams.enhanced_folder``.

    Arguments
    ---------
    predictions : tuple
        ``(predict_spec, predict_wav)`` pair.
    batch : object
        Batch providing ``id`` and a ``clean_sig`` pair
        ``(waveforms, relative lengths)``.
    stage : sb.Stage
        Current stage.

    Returns
    -------
    loss
        Value returned by ``self.hparams.compute_cost``.
    """
    predict_spec, predict_wav = predictions
    ids = batch.id
    clean_wav, lens = batch.clean_sig

    # Targets go through the same compression as the model inputs.
    clean_stft = self.hparams.compute_STFT(clean_wav)
    target_spec = torch.log1p(spectral_magnitude(clean_stft, power=0.5))

    loss = self.hparams.compute_cost(predict_spec, target_spec, lens)
    self.loss_metric.append(
        ids, predict_spec, target_spec, lens, reduction="batch"
    )

    if stage == sb.Stage.TRAIN:
        return loss

    # Evaluate speech quality/intelligibility
    self.stoi_metric.append(
        ids, predict_wav, clean_wav, lens, reduction="batch"
    )
    self.pesq_metric.append(
        batch.id, predict=predict_wav, target=clean_wav, lengths=lens
    )

    if stage != sb.Stage.TEST:
        return loss

    # Write wavs to file; relative lengths are scaled by the padded length
    # so each signal can be cut back to its own number of samples.
    sample_counts = lens * clean_wav.shape[1]
    for name, wav, length in zip(ids, predict_wav, sample_counts):
        enhance_path = os.path.join(self.hparams.enhanced_folder, name)
        if not enhance_path.endswith(".wav"):
            enhance_path += ".wav"
        trimmed = wav[: int(length)].cpu()
        torchaudio.save(
            enhance_path,
            trimmed.unsqueeze(0),
            self.hparams.Sample_rate,
        )

    return loss
def compute_forward(self, batch):
    """Apply one multiplicative NMF update to ``self.w`` and ``self.h``.

    Arguments
    ---------
    batch : PaddedBatch
        Batch whose ``wav.data`` is passed to
        ``self.hparams.compute_features``.

    Returns
    -------
    tuple
        ``(w @ h, w, h / g, deviation)`` where ``g`` holds the column sums
        of the input matrix and ``deviation`` is the mean absolute
        difference between the input matrix and ``w @ h`` (a Python float).
    """
    spec = self.hparams.compute_features(batch.wav.data)
    spec = spectral_magnitude(spec, power=2)

    # concatenate all the inputs: flatten the leading dimensions, then
    # transpose so the last input dimension indexes the rows.
    spec = spec.reshape(-1, spec.size(-1)).t()

    eps = 1e-20
    col_sums = spec.sum(dim=0) + eps
    spec_normed = spec / col_sums

    ratio = spec_normed / (self.w @ self.h + eps)

    # Update w first: the h update below relies on the refreshed w.
    w_next = self.w * (ratio @ self.h.t())
    self.w = w_next / (w_next.sum(dim=0) + eps)

    h_next = self.h * (self.w.t() @ ratio)
    # sparsity
    h_next = h_next + 0.02 * h_next ** (1.0 + 0.1)

    self.h = h_next / (h_next.sum(dim=0) + eps)
    # Scale the columns of h back by the column sums of the input.
    self.h *= col_sums

    deviation = (spec - self.w @ self.h).abs().mean().item()

    return self.w @ self.h, self.w, self.h / col_sums, deviation
def compute_feats(self, wavs):
    """Feature computation pipeline: STFT -> spectral magnitude -> log1p.

    Arguments
    ---------
    wavs : torch.Tensor
        Waveforms passed to ``self.hparams.compute_STFT``.

    Returns
    -------
    torch.Tensor
        ``log1p`` of ``spectral_magnitude(STFT, power=0.5)``.
    """
    stft = self.hparams.compute_STFT(wavs)
    magnitude = spectral_magnitude(stft, power=0.5)
    return torch.log1p(magnitude)
def _masked_source_stft(source_spec, total_spec, mix_mag, mix_phasor):
    """Apply the ratio mask of one source to the mixture and shape it for ISTFT."""
    # Multiplication order (mask * magnitude) * phasor is kept as-is so the
    # floating-point result matches the straightforward formulation.
    masked = (
        (source_spec / total_spec).unsqueeze(-1)
        * mix_mag.unsqueeze(-1)
        * mix_phasor
    )
    # Add a batch dimension of one and swap the two spectral axes for ISTFT.
    return masked.unsqueeze(0).permute(0, 2, 1, 3)


def _rescale_waveform(wav, div_factor):
    """Divide a waveform by ``div_factor`` times its standard deviation."""
    return wav / (div_factor * wav.std())


def reconstruct_results(
    X1hat, X2hat, X_stft, sample_rate, win_length, hop_length,
):
    """This function reconstructs the separated spectra into waveforms.

    Each source is obtained by applying the ratio mask
    ``Xkhat / (X1hat + X2hat)`` to the mixture magnitude, re-attaching the
    mixture phase, and inverting the result with an ISTFT.

    Arguments
    ---------
    X1hat : torch.tensor
        The separated spectrum for source 1 of size [BS, nfft/2 + 1, T],
        where,  BS = batch size, nfft = fft size, T = length of the spectra.
    X2hat : torch.tensor
        The separated spectrum for source 2 of size [BS, nfft/2 + 1, T].
        The size definitions are the same as X1hat.
    X_stft : torch.tensor
        The STFT of the mixtures; both the mixture phase and the mixture
        magnitude are derived from it.
        The size is [BS x nfft//2 + 1 x T x 2] where,
        BS = batch size, nfft = fft size, T = number of time steps in
        the spectra.
        The last dimension is to represent complex numbers.
    sample_rate : int
        The sampling rate (in Hz) in which we would like to save the results.
    win_length : int
        The length of stft windows (in ms).
    hop_length : int
        The length with which we shift the STFT windows (in ms).

    Returns
    -------
    x1hats : list
        List of waveforms for source 1.
    x2hats : list
        List of waveforms for source 2.

    Example
    -------
    >>> BS, nfft, T = 10, 512, 16000
    >>> sample_rate, win_length, hop_length = 16000, 25, 10
    >>> X1hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X2hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
    >>> x1hats, x2hats = reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)
    """
    ISTFT = spf.ISTFT(
        sample_rate=sample_rate, win_length=win_length, hop_length=hop_length
    )

    phase_mix = spectral_phase(X_stft)
    mag_mix = spectral_magnitude(X_stft, power=2)

    # (cos, sin) of the mixture phase, computed once for the whole batch
    # and shared by both sources.
    phasor_mix = torch.stack(
        [torch.cos(phase_mix), torch.sin(phase_mix)], dim=-1
    )

    eps = 1e-25
    div_factor = 10

    x1hats, x2hats = [], []
    for i in range(X1hat.shape[0]):
        # Both masks share the same denominator.
        total_spec = eps + X1hat[i] + X2hat[i]

        X1hat_stft = _masked_source_stft(
            X1hat[i], total_spec, mag_mix[i], phasor_mix[i]
        )
        X2hat_stft = _masked_source_stft(
            X2hat[i], total_spec, mag_mix[i], phasor_mix[i]
        )

        x1hats.append(_rescale_waveform(ISTFT(X1hat_stft), div_factor))
        x2hats.append(_rescale_waveform(ISTFT(X2hat_stft), div_factor))

    return x1hats, x2hats
def _fit_nmf_model(hparams, label):
    """Create an NMF_Brain, initialise its matrices and fit it on the train data."""
    brain = NMF_Brain(hparams=hparams)
    train_loader = sb.dataio.dataloader.make_dataloader(
        hparams["train_data"], **hparams["loader_kwargs"]
    )
    brain.init_matrices(train_loader)

    print("fitting model {}".format(label))
    brain.fit(
        train_set=train_loader,
        valid_set=None,
        epoch_counter=range(hparams["N_epochs"]),
        progressbar=False,
    )
    return brain


def main():
    """Fit two NMF models, separate a test mixture and optionally save the results.

    Steps:
      1. Load ``hyperparams.yaml`` located next to this file.
      2. Fit two ``NMF_Brain`` models and keep their learned ``W`` matrices.
      3. Separate the first batch of ``hparams["test_data"]`` and
         reconstruct waveforms for both sources.
      4. If ``hparams["save_reconstructed"]`` is set, write the separated
         sources to ``results/save/``; if ``hparams["copy_original_files"]``
         is also set, copy the original sample wavs next to them.
    """
    experiment_dir = os.path.dirname(os.path.realpath(__file__))
    hparams_file = os.path.join(experiment_dir, "hyperparams.yaml")
    data_folder = "../../../../samples/audio_samples/sourcesep_samples"
    data_folder = os.path.realpath(os.path.join(experiment_dir, data_folder))

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, {"data_folder": data_folder})

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
    )

    torch.manual_seed(0)

    # NOTE(review): both models are fitted on hparams["train_data"]; they
    # only differ through the random initialisation -- confirm this is
    # intended.
    NMF1 = _fit_nmf_model(hparams, 1)
    W1hat = NMF1.training_out[1]

    NMF2 = _fit_nmf_model(hparams, 2)
    W2hat = NMF2.training_out[1]

    # separate
    mixture_loader = sb.dataio.dataloader.make_dataloader(
        hparams["test_data"], **hparams["loader_kwargs"]
    )
    mix_batch = next(iter(mixture_loader))

    Xmix = NMF1.hparams.compute_features(mix_batch.wav.data)
    Xmix_mag = spectral_magnitude(Xmix, power=2)

    X1hat, X2hat = sb_nmf.NMF_separate_spectra([W1hat, W2hat], Xmix_mag)

    x1hats, x2hats = sb_nmf.reconstruct_results(
        X1hat,
        X2hat,
        Xmix.permute(0, 2, 1, 3),
        hparams["sample_rate"],
        hparams["win_length"],
        hparams["hop_length"],
    )

    if hparams["save_reconstructed"]:
        savepath = "results/save/"
        # Creates "results" and "results/save" in one call and tolerates
        # directories that already exist.
        os.makedirs(savepath, exist_ok=True)

        # Save with the same rate that was used for reconstruction instead
        # of a hard-coded value.
        sample_rate = hparams["sample_rate"]
        for i, (x1hat, x2hat) in enumerate(zip(x1hats, x2hats)):
            write_audio(
                os.path.join(savepath, "separated_source1_{}.wav".format(i)),
                x1hat.squeeze(0),
                sample_rate,
            )
            write_audio(
                os.path.join(savepath, "separated_source2_{}.wav".format(i)),
                x2hat.squeeze(0),
                sample_rate,
            )

        # The copy step lives inside the save branch because it needs the
        # directory created above.
        if hparams["copy_original_files"]:
            datapath = "samples/audio_samples/sourcesep_samples"
            speechbrain_path = os.path.abspath(
                os.path.join(experiment_dir, "../../../..")
            )
            copypath = os.path.realpath(
                os.path.join(speechbrain_path, datapath)
            )

            # Only files whose name ends with ".wav" are copied.
            for file_name in os.listdir(copypath):
                if file_name.endswith(".wav"):
                    shutil.copy(os.path.join(copypath, file_name), savepath)
def compute_feats(self, wavs):
    """Compute log-compressed spectral features for a batch of waveforms.

    Arguments
    ---------
    wavs : torch.Tensor
        Waveforms passed to ``self.hparams.compute_STFT``.

    Returns
    -------
    torch.Tensor
        ``log1p(spectral_magnitude(STFT, power=0.5))``.
    """
    return torch.log1p(
        spectral_magnitude(self.hparams.compute_STFT(wavs), power=0.5)
    )