def test_gccphat(device):
    import torch

    from speechbrain.processing.features import STFT
    from speechbrain.processing.multi_mic import Covariance, GccPhat

    # Creating the test signal: the second channel is a delayed copy of the
    # first one, so the expected TDOA equals `delay` samples.
    fs = 16000
    delay = 60

    sig = torch.randn([10, fs], device=device)
    sig_delayed = torch.cat(
        (torch.zeros([10, delay], device=device), sig[:, 0:-delay]), 1
    )
    xs = torch.stack((sig_delayed, sig), -1)

    stft = STFT(sample_rate=fs).to(device)
    Xs = stft(xs)

    # Computing the covariance matrices and the GCC-PHAT delay estimates
    cov = Covariance().to(device)
    gccphat = GccPhat().to(device)

    XXs = cov(Xs)
    tdoas = torch.abs(gccphat(XXs))

    # Every estimated delay should match the ground-truth delay
    n_valid_tdoas = torch.sum(torch.abs(tdoas[..., 1] - delay) < 1e-3)
    assert n_valid_tdoas == Xs.shape[0] * Xs.shape[1]

    # The modules should also be traceable with TorchScript
    assert torch.jit.trace(stft, xs)
    assert torch.jit.trace(cov, Xs)
    assert torch.jit.trace(gccphat, XXs)
def __init__(
    self,
    deltas=False,
    context=False,
    requires_grad=False,
    sample_rate=16000,
    n_fft=400,
    n_mels=40,
    filter_shape="triangular",
    param_change_factor=1.0,
    param_rand_factor=0.0,
    left_frames=5,
    right_frames=5,
):
    super().__init__()
    self.deltas = deltas
    self.context = context
    self.requires_grad = requires_grad

    self.compute_STFT = STFT(sample_rate=sample_rate, n_fft=n_fft)
    self.compute_fbanks = Filterbank(
        n_fft=n_fft,
        n_mels=n_mels,
        f_min=0,
        f_max=sample_rate / 2,
        freeze=not requires_grad,
        filter_shape=filter_shape,
        param_change_factor=param_change_factor,
        param_rand_factor=param_rand_factor,
    )
    self.compute_deltas = Deltas(input_size=n_mels)
    self.context_window = ContextWindow(
        left_frames=left_frames,
        right_frames=right_frames,
    )
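For context, the modules initialized above are typically chained as STFT, magnitude spectrum, Mel filterbank, optional deltas, and an optional context window. The sketch below shows such a forward method; it is illustrative rather than the exact method of the surrounding class, and it assumes `spectral_magnitude` from `speechbrain.processing.features`.

# Illustrative forward pass built from the modules initialized above.
import torch
from speechbrain.processing.features import spectral_magnitude

def forward(self, wav):
    """Computes filterbank features from a batch of waveforms [batch, time]."""
    STFT = self.compute_STFT(wav)        # complex STFT: [batch, frames, bins, 2]
    mag = spectral_magnitude(STFT)       # magnitude spectrogram
    fbanks = self.compute_fbanks(mag)    # Mel filterbank energies

    if self.deltas:
        delta1 = self.compute_deltas(fbanks)
        delta2 = self.compute_deltas(delta1)
        fbanks = torch.cat([fbanks, delta1, delta2], dim=2)

    if self.context:
        fbanks = self.context_window(fbanks)

    return fbanks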
def __init__(self, sampling_rate=16000):
    super().__init__()

    self.fs = sampling_rate
    self.stft = STFT(sample_rate=self.fs)
    self.cov = Covariance()
    self.gccphat = GccPhat()
    self.delaysum = DelaySum()
    self.istft = ISTFT(sample_rate=self.fs)
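This constructor only wires up the processing modules; a forward pass would normally chain them in the same order as the delay_and_sum example later in this listing. The method below is an illustrative sketch, not necessarily the class's actual forward.

# Illustrative forward pass: delay-and-sum beamforming with the modules above.
def forward(self, xs):
    """xs: [batch, time, n_mics] multichannel waveforms."""
    Xs = self.stft(xs)             # multichannel STFT
    XXs = self.cov(Xs)             # spatial covariance matrices
    tdoas = self.gccphat(XXs)      # time differences of arrival (GCC-PHAT)
    Ys = self.delaysum(Xs, tdoas)  # steered and summed spectrum
    ys = self.istft(Ys)            # back to the time domain
    return ys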
def __init__(self, win_length=36, hop_length=12, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # Rebuild the parent's STFT with a custom window and hop size, keeping
    # its original sampling rate and FFT size.
    sample_rate = self.compute_STFT.sample_rate
    n_fft = self.compute_STFT.n_fft
    self.compute_STFT = STFT(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )
def __init__(
    self,
    deltas=True,
    context=True,
    requires_grad=False,
    sample_rate=16000,
    f_min=0,
    f_max=None,
    n_fft=400,
    n_mels=23,
    n_mfcc=20,
    filter_shape="triangular",
    param_change_factor=1.0,
    param_rand_factor=0.0,
    left_frames=5,
    right_frames=5,
    win_length=25,
    hop_length=10,
):
    super().__init__()
    self.deltas = deltas
    self.context = context
    self.requires_grad = requires_grad

    if f_max is None:
        f_max = sample_rate / 2

    self.compute_STFT = STFT(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )
    self.compute_fbanks = Filterbank(
        sample_rate=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
        freeze=not requires_grad,
        filter_shape=filter_shape,
        param_change_factor=param_change_factor,
        param_rand_factor=param_rand_factor,
    )
    self.compute_dct = DCT(input_size=n_mels, n_out=n_mfcc)
    self.compute_deltas = Deltas(input_size=n_mfcc)
    self.context_window = ContextWindow(
        left_frames=left_frames,
        right_frames=right_frames,
    )
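A minimal usage sketch follows. It assumes this __init__ is the one from speechbrain.lobes.features.MFCC (inferred from its parameters), with a forward method analogous to the filterbank sketch above plus a DCT step; the class name and input shape are assumptions.

# Hypothetical usage; class name and shapes are assumptions.
import torch
from speechbrain.lobes.features import MFCC

feature_maker = MFCC(sample_rate=16000, n_mfcc=20, deltas=True, context=True)
wav = torch.randn([4, 16000])   # [batch, time]: four 1-second waveforms at 16 kHz
feats = feature_maker(wav)      # [batch, frames, features]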
def test_istft(device):
    import torch

    from speechbrain.processing.features import STFT
    from speechbrain.processing.features import ISTFT

    fs = 16000
    inp = torch.randn([10, 16000], device=device)
    inp = torch.stack(3 * [inp], -1)

    compute_stft = STFT(sample_rate=fs).to(device)
    compute_istft = ISTFT(sample_rate=fs).to(device)
    out = compute_istft(compute_stft(inp), sig_length=16000)

    # The reconstruction should match the input up to numerical error
    assert torch.sum(torch.abs(inp - out) < 5e-5) >= inp.numel() - 5

    assert torch.jit.trace(compute_stft, inp)
    assert torch.jit.trace(compute_istft, compute_stft(inp))
def generalized_eigenvalue(audio_file, diffuse=True, show_plots=False):
    # `fs`, `read_audio`, `Gev`, `plt`, and the STFT modules are expected to
    # be imported/defined at module level.
    xs_speech = read_audio(audio_file)
    xs_speech = xs_speech.unsqueeze(0)  # [batch, time, n_mics]

    stft = STFT(sample_rate=fs)
    cov = Covariance()
    gev = Gev()
    istft = ISTFT(sample_rate=fs)

    Xs = stft(xs_speech)

    # Note: both covariance matrices are estimated from the same noisy signal
    # here; in practice the noise covariance would normally be computed from
    # noise-only segments.
    SSs = cov(Xs)
    NNs = cov(Xs)

    Ys_gev = gev(Xs, SSs, NNs)
    ys_gev = istft(Ys_gev)

    if show_plots:
        plt.figure(1)
        plt.title("Noisy signal at microphone 1")
        plt.imshow(
            torch.transpose(
                torch.log(Xs[0, :, :, 0, 0] ** 2 + Xs[0, :, :, 1, 0] ** 2), 1, 0
            ),
            origin="lower",
        )

        plt.figure(2)
        plt.title("Noisy signal at microphone 1")
        plt.plot(xs_speech.squeeze()[:, 0])

        plt.figure(3)
        plt.title("Beamformed signal")
        plt.imshow(
            torch.transpose(
                torch.log(
                    Ys_gev[0, :, :, 0, 0] ** 2 + Ys_gev[0, :, :, 1, 0] ** 2
                ),
                1,
                0,
            ),
            origin="lower",
        )

        plt.figure(4)
        plt.title("Beamformed signal")
        plt.plot(ys_gev.squeeze())

        plt.show()

    return ys_gev.squeeze()
def delay_and_sum(audio_file, show_plots=False):
    # `fs`, `read_audio`, and `plt` are expected to be defined at module level.
    xs_speech = read_audio(audio_file)
    xs_speech = xs_speech.unsqueeze(0)  # [batch, time, n_mics]

    stft = STFT(sample_rate=fs)
    cov = Covariance()
    gccphat = GccPhat()
    delaysum = DelaySum()
    istft = ISTFT(sample_rate=fs)

    Xs = stft(xs_speech)
    XXs = cov(Xs)
    tdoas = gccphat(XXs)
    Ys_ds = delaysum(Xs, tdoas)
    ys_ds = istft(Ys_ds)

    if show_plots:
        plt.figure(1)
        plt.title("Noisy signal at microphone 1")
        plt.imshow(
            torch.transpose(
                torch.log(Xs[0, :, :, 0, 0] ** 2 + Xs[0, :, :, 1, 0] ** 2), 1, 0
            ),
            origin="lower",
        )

        plt.figure(2)
        plt.title("Noisy signal at microphone 1")
        plt.plot(xs_speech.squeeze()[:, 0])

        plt.figure(3)
        plt.title("Beamformed signal")
        plt.imshow(
            torch.transpose(
                torch.log(Ys_ds[0, :, :, 0, 0] ** 2 + Ys_ds[0, :, :, 1, 0] ** 2),
                1,
                0,
            ),
            origin="lower",
        )

        plt.figure(4)
        plt.title("Beamformed signal")
        plt.plot(ys_ds.squeeze())

        plt.show()

    return ys_ds.squeeze()
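For completeness, a hypothetical invocation of the two beamforming examples above; the file path is a placeholder and `fs` must be defined at module level before calling either function.

# Hypothetical usage of the two beamforming examples above.
fs = 16000
ys_ds = delay_and_sum("multi_mic_recording.wav", show_plots=False)
ys_gev = generalized_eigenvalue("multi_mic_recording.wav", show_plots=False)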
import os
import sys
import torch
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
from speechbrain.processing.features import STFT, ISTFT
from speechbrain.processing.multi_mic import Covariance
from speechbrain.processing.multi_mic import GccPhat
from speechbrain.processing.multi_mic import DelaySum

logger = logging.getLogger(__name__)

stft = STFT(sample_rate=16000)
cov = Covariance()
gccphat = GccPhat()
delaysum = DelaySum()
istft = ISTFT(sample_rate=16000)


# Define training procedure
class ASR_Brain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Given an input batch, it computes the phoneme probabilities."""
        batch = batch.to(self.device)
        wavs1, wav_lens1 = batch.sig1
        wavs2, wav_lens2 = batch.sig2
        wavs3, wav_lens3 = batch.sig3
        wavs4, wav_lens4 = batch.sig4
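The snippet ends after unpacking the four microphone channels. A typical continuation inside compute_forward would apply the module-level beamforming objects along the lines of the sketch below, mirroring the delay-and-sum flow used elsewhere in this listing; the variable names past this point are assumptions.

# Hypothetical continuation of compute_forward: stack the four channels and
# apply delay-and-sum beamforming with the module-level objects defined above.
xs = torch.stack([wavs1, wavs2, wavs3, wavs4], dim=-1)  # [batch, time, 4 mics]
Xs = stft(xs)
XXs = cov(Xs)
tdoas = gccphat(XXs)
Ys = delaysum(Xs, tdoas)
wavs = istft(Ys)  # beamformed single-channel waveform for the ASR front end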