Example #1
0
    def forward(self, signal):
        """Return the STFT of ``signal`` as stacked magnitude and phase.

        The transform runs under ``torch.no_grad()``, so no gradient flows back
        to ``signal``.

        Args:
            signal: waveform tensor accepted by ``torch.stft``.

        Returns:
            Tensor of shape ``(*stft_shape, 2)``: ``[..., 0]`` is the magnitude
            and ``[..., 1]`` the phase in radians (``atan2(imag, real)``).
        """
        with torch.no_grad():
            # return_complex=True is mandatory for real inputs on PyTorch >= 2.0;
            # the complex result yields magnitude and phase directly.
            spec = torch.stft(signal,
                              n_fft=self.n_fft,
                              hop_length=self.hop_length,
                              win_length=self.n_fft,
                              window=self.window,
                              return_complex=True)
            mag = spec.abs()
            phase = spec.angle()

        return torch.stack((mag, phase), dim=-1)
Example #2
0
    def forward(self, x):
        """Score a batch of waveforms from an STFT-derived representation.

        Args:
            x: tensor of shape (batch, channels, time). It is flattened with
               ``view(batch, -1)`` before the STFT, so channels end up
               concatenated along time — presumably channels == 1; TODO confirm.

        Returns:
            ``([features], [score])`` from ``self.main`` and ``self.judge``.
        """
        batch, _, _ = x.shape  # unpacking also enforces a 3-D input
        # One hop of zeros on the right before the non-centered STFT.
        x = F.pad(x, (0, self.hop_size))
        # return_complex is mandatory for real inputs on PyTorch >= 2.0.
        spec = torch.stft(x.view(batch, -1),
                          n_fft=self.window_size,
                          hop_length=self.hop_size,
                          win_length=self.window_size,
                          normalized=True,
                          center=False,
                          return_complex=True)

        # Drop the DC bin. NOTE(review): this is |Re(STFT)| (same values as the
        # old x[:, 1:, :, 0].abs()), not the spectral magnitude — confirm
        # whether spec.abs() was intended before changing it.
        x = spec[:, 1:, :].real.abs()
        features, x = self.main(x, return_features=True)
        x = self.judge(x)
        return [features], [x]
Example #3
0
def smooth(sin, kernel_size, c):
    """Keep only the dominant STFT bin(s) of every frame and map back to samples.

    An STFT with hop 1 is taken (reflect padding; no window is passed, so a
    rectangular one is used). Within each frame every bin whose magnitude
    equals the frame maximum is kept and all others are zeroed. An inverse FFT
    is then taken along the frequency axis and the real part of its first
    output sample is returned for each frame.

    Args:
        sin: batched signal (N, T); the indexing below requires the batch axis.
        kernel_size: FFT size.
        c: unused; kept for interface compatibility.

    Returns:
        Real tensor of shape (N, T + 1).
    """
    spectrum = torch.stft(
        sin,
        n_fft=kernel_size,
        hop_length=1,
        pad_mode="reflect",
        return_complex=True,
    )  # (N, kernel_size // 2 + 1, T + 1)
    magnitude = spectrum.abs()
    # Per-frame peak, with a singleton frequency axis for broadcasting.
    peak = magnitude.max(dim=1).values.unsqueeze(1)
    keep = (magnitude == peak).float()
    sparse = keep * spectrum
    del magnitude, peak, keep  # free intermediates before the inverse FFT
    return torch.fft.ifft(sparse, dim=1).real[:, 0, :]
Example #4
0
    def forward(self, x, amplitude):
        """Run one refinement stage and return ``(z, z - residual)``.

        Args:
            x: tensor permuted here with ``(0, 2, 3, 1)``; presumably
               (batch, 2, freq, time) real/imag layout — TODO confirm.
            amplitude: permuted the same way and passed to ``self.pa``.

        Returns:
            ``z``: ``y`` round-tripped through iSTFT and STFT, channels-first;
            and ``z`` minus the output of the dnn1/dnn2/dnn3 stack.
        """
        x = x.permute(0, 2, 3, 1)
        y = self.pa(amplitude.permute(0, 2, 3, 1), x)

        # Round-trip y through iSTFT -> STFT.
        z = functional.istft(y, hparams.fft_size)
        # PyTorch >= 2.0 requires return_complex for real inputs; view_as_real
        # restores the trailing (real, imag) axis that cat/permute below expect.
        z = torch.view_as_real(torch.stft(z, hparams.fft_size, return_complex=True))

        # Named `features` rather than `input` to avoid shadowing the builtin.
        features = torch.cat((x, y, z), dim=3)
        z = z.permute(0, 3, 1, 2)
        del x, y  # drop references before running the network
        features = features.permute(0, 3, 1, 2)
        features = self.dnn1(features)

        return z, z - self.dnn3(self.dnn2(features) + features)
Example #5
0
def _test_istft_is_inverse_of_stft(kwargs):
    """Check that istft(stft(x)) reconstructs x for random signals.

    For each of three (batch, length) shapes and 100 trials, a random signal is
    generated, transformed with ``torch.stft(**kwargs)``, inverted with
    ``torchaudio.functional.istft`` (same kwargs, ``length`` pinned to the
    input length), and compared via ``_compare_estimate``.
    """
    shapes = ((2, 20), (3, 15), (4, 10))
    for shape in shapes:
        for trial in range(100):
            signal = common_utils.random_float_tensor(trial, shape)
            spectrum = torch.stft(signal, **kwargs)
            reconstructed = torchaudio.functional.istft(
                spectrum, length=signal.size(1), **kwargs)
            _compare_estimate(signal, reconstructed)
Example #6
0
 def forward(self, audio):
     """Return the STFT magnitude of ``audio``.

     ``audio`` is reflect-padded by ``(n_fft - hop_length) // 2`` on both sides
     and squeezed on dim 1 (presumably (batch, 1, time) — TODO confirm); the
     STFT itself runs with ``center=False``.

     Returns:
         Magnitude tensor of shape (batch, n_fft // 2 + 1, frames).
     """
     pad = (self.n_fft - self.hop_length) // 2
     audio = F.pad(audio, (pad, pad), "reflect").squeeze(1)
     # return_complex=True is mandatory for real inputs on PyTorch >= 2.0.
     spec = torch.stft(
         audio,
         n_fft=self.n_fft,
         hop_length=self.hop_length,
         win_length=self.win_length,
         window=self.window,
         center=False,
         return_complex=True,
     )
     return spec.abs()
Example #7
0
    def forward(self, x, seq_len):
        """Compute (optionally log / spliced / normalized) STFT magnitude features.

        Args:
            x: waveform batch indexed as (batch, time). Not modified in place.
            seq_len: per-example lengths; ``self.get_seq_len`` converts them
                into the frame counts used for masking.

        Returns:
            Features in ``x``'s original dtype. Frames at or beyond each
            example's length are zeroed and, when ``self.pad_to > 0``, the time
            axis is right-padded to a multiple of ``self.pad_to``.
        """
        dtype = x.dtype
        x = x.to(torch.float)

        seq_len = self.get_seq_len(seq_len)

        # dither -- out of place: x.to(torch.float) returns the caller's tensor
        # itself when it is already float32, so `+=` would mutate the input.
        if self.dither > 0:
            x = x + self.dither * torch.randn_like(x)

        # preemphasis
        if getattr(self, 'preemph', None) is not None:
            x = torch.cat(
                (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]),
                dim=1)

        # magnitude spectrogram; return_complex is mandatory on PyTorch >= 2.0
        x = torch.stft(x,
                       n_fft=self.n_fft,
                       hop_length=self.hop_length,
                       win_length=self.win_length,
                       center=self.center,
                       window=self.window.to(torch.float),
                       return_complex=True)
        x = x.abs()

        # log features if required
        if self.log:
            x = torch.log(x + 1e-20)

        # frame splicing if required
        if self.frame_splicing > 1:
            x = splice_frames(x, self.frame_splicing)

        # normalize if required
        if self.normalize:
            x = normalize_batch(x, seq_len, normalize_type=self.normalize)

        # Zero every frame at or beyond seq_len, then pad to a multiple of
        # `pad_to` (for efficiency).
        max_len = x.size(-1)
        frame_idx = torch.arange(max_len, device=seq_len.device)
        mask = frame_idx.expand(x.size(0), max_len) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
        del mask
        if self.pad_to > 0:
            pad_amt = x.size(-1) % self.pad_to
            if pad_amt != 0:
                x = nn.functional.pad(x, (0, self.pad_to - pad_amt))

        return x.to(dtype)
    def to_spec_complex(self, input_signal: torch.Tensor):
        """Return the STFT of ``input_signal`` in real/imag-last layout.

        input_signal: (*, signal)
        output: (*, N, T, 2) with ``[..., 0]`` real and ``[..., 1]`` imaginary.
        """
        # PyTorch >= 2.0 requires return_complex for real inputs; view_as_real
        # keeps the documented trailing-2 layout. NOTE(review): self.window must
        # already match input_signal's device/dtype — confirm at the call site.
        spec = torch.stft(input_signal,
                          self.n_fft,
                          self.hop_length,
                          window=self.window,
                          return_complex=True)
        return torch.view_as_real(spec)
Example #9
0
 def computec(inputs, targets, n_ffts=(2048, 512, 128, 32)):
     """Multi-resolution spectral loss between two signals.

     Both signals are averaged over dim 1. For every FFT size, the power
     spectra (|STFT|^2 + eps) are compared with an L1 distance on linear and
     on log scale.

     Args:
         inputs: tensor averaged over dim 1 — presumably (batch, channels, time).
         targets: same shape as ``inputs``.
         n_ffts: FFT sizes to evaluate; defaults to the original four scales.

     Returns:
         Scalar tensor: mean over batch and scales of (L1 + log-L1).
     """
     eps = 0.0001

     def power_spec(signal, n_fft):
         # return_complex=True is mandatory for real inputs on PyTorch >= 2.0.
         spec = torch.stft(torch.mean(signal, dim=1), n_fft=n_fft, return_complex=True)
         return spec.real ** 2 + spec.imag ** 2 + eps  # (batch, freq, frames)

     loss = []
     for n_fft in n_ffts:
         spec_inputs = power_spec(inputs, n_fft)
         spec_targets = power_spec(targets, n_fft)
         # mean over frames (dim 2), then over frequency (dim 1) -> (batch,)
         l1 = torch.abs(spec_inputs - spec_targets).mean(dim=2).mean(dim=1)
         l1_log = torch.abs(torch.log(spec_inputs) - torch.log(spec_targets)).mean(dim=2).mean(dim=1)
         loss.append(l1 + l1_log)
     return torch.mean(torch.stack(loss, dim=1))
Example #10
0
 def forward(self, x):
     """Apply ``amp`` to the STFT of ``x`` at every scale in ``self.scales``.

     Scale ``i`` uses ``n_fft = scale``, window ``self.windows[i]``, hop
     ``int((1 - self.overlap) * scale)`` and ``center=False``.

     Returns:
         A list indexed ``[scale]`` when ``self.reshape`` is false, otherwise
         a nested list indexed ``[batch][scale]``.
     """
     stfts = []
     for i, scale in enumerate(self.scales):
         spec = torch.stft(x, n_fft=scale, window=self.windows[i],
                           hop_length=int((1 - self.overlap) * scale),
                           center=False, return_complex=True)
         # return_complex is mandatory on PyTorch >= 2.0; amp() was previously
         # handed the real/imag-last layout, so keep giving it that.
         stfts.append(amp(torch.view_as_real(spec)))
     if self.reshape:
         # regroup from [scale][batch] to [batch][scale]
         stfts = [[per_scale[b] for per_scale in stfts]
                  for b in range(x.shape[0])]
     return stfts
    def closure(status_dict):
        """Run one update step on ``status_dict`` and return the current magnitude.

        Reads ``status_dict['x']`` and ``status_dict['pre_spec']`` and writes
        new values for both back. Relies on names from the enclosing scope
        (``n_fft``, ``processed_args``, ``lr``, ``target_spec``, ``istft``,
        ``norm_envelope``) that are not visible here.

        Returns:
            ``abs()`` of the STFT of the incoming ``x`` (before the update).
        """
        x = status_dict['x']
        pre_spec = status_dict['pre_spec']

        # NOTE(review): `.abs()` below presumably relies on processed_args
        # requesting a complex result (return_complex=True) — confirm.
        new_spec = torch.stft(x, n_fft, **processed_args)
        output = new_spec.abs()
        # Subtract lr * the previous spectrum; the result becomes the next pre_spec.
        new_spec = new_spec - pre_spec * lr
        status_dict['pre_spec'] = new_spec

        # Rescale to target_spec magnitude; for a complex new_spec this keeps its
        # phase. The 1e-16 avoids division by zero.
        norm = new_spec.abs().add_(1e-16)
        new_spec = new_spec * target_spec / norm
        # istft here returns a pair; only the first element (the signal) is kept.
        x, _ = istft(new_spec, norm_envelope=norm_envelope)
        status_dict['x'] = x
        return output
Example #12
0
    def __call__(self, waveform):
        """Compute the STFT of ``waveform`` with optional post-processing.

        Each step is controlled by an attribute:
          * ``abs_val``: return the magnitude; otherwise the result keeps a
            trailing (real, imag) axis.
          * ``log_val``: apply ``log10(x + EPS)``.
          * ``transpose``: reverse the order of all dimensions (what ``.T`` did).

        Returns:
            A torch tensor (previously the log path silently turned it into a
            NumPy array).
        """
        # return_complex is mandatory for real inputs on PyTorch >= 2.0.
        spec = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length,
                          win_length=self.win_length, window=self.window,
                          onesided=self.onesided, return_complex=True)
        if self.abs_val:
            stft_matrix = spec.abs()
        else:
            # keep the legacy real/imag-last layout
            stft_matrix = torch.view_as_real(spec)

        if self.log_val:
            # torch.log10 keeps a tensor on the input's device; np.log10 converted
            # to an ndarray and fails on CUDA or grad-requiring tensors.
            stft_matrix = torch.log10(stft_matrix + EPS)
        if self.transpose:
            # Tensor.T on >2-D tensors is deprecated; permute reproduces its
            # reverse-all-dimensions behavior explicitly.
            stft_matrix = stft_matrix.permute(*reversed(range(stft_matrix.dim())))

        return stft_matrix
Example #13
0
    def forward(self, x):
        """Return a scaled log10 mel spectrogram of ``x``.

        Args:
            x: waveform; a 3-D input has dim 1 squeezed (presumably
               (batch, 1, time) — TODO confirm).

        Returns:
            ``(log10(clamp(mel, min=1e-5)) + 5) / 5``. In training mode the
            frame axis is truncated to ``x.shape[-1] // self.hop`` frames
            (``x`` measured after the optional right padding).
        """
        if x.dim() == 3:
            x = x.squeeze(1)

        if not config.USE_CACHED_PADDING:
            x = nn.functional.pad(x, (0, self.nfft - self.hop))

        # return_complex is mandatory for real inputs on PyTorch >= 2.0;
        # view_as_real restores the (..., 2) layout `module` was given before.
        S = torch.view_as_real(
            torch.stft(x, self.nfft, self.hop, 512, center=self.center,
                       return_complex=True))
        # NOTE(review): win_length and the scale divisor are hard-coded to 512
        # instead of following self.nfft — confirm this is intentional.
        S = 2 * module(S) / 512
        S_mel = self.mel.matmul(S)

        if self.training:
            S_mel = S_mel[..., :x.shape[-1] // self.hop]
        return (torch.log10(torch.clamp(S_mel, min=1e-5)) + 5) / 5
Example #14
0
 def get_mel(self, x):
     """Return the log mel power spectrogram of a batch of waveforms.

     Args:
         x: (batch, time) waveforms.

     Returns:
         Tensor of shape (batch, n_mels, frames) holding ``log(mel + 1e-7)``.
     """
     batch_size = x.size(0)
     # return_complex is mandatory for real inputs on PyTorch >= 2.0.
     spec = torch.stft(x,
                       self.win_size,
                       self.hop_size,
                       window=self.window,
                       pad_mode='constant',
                       return_complex=True)
     S = spec.real ** 2 + spec.imag ** 2  # power, (batch, n_freq, frames)
     mel_filt = torch.sparse_coo_tensor(self.filter_idx, self.filter_value,
                                        self.filter_size)
     N = S.size(1)
     # Fold the batch into the frame axis so one sparse matmul covers all examples.
     mel_S = mel_filt @ S.transpose(0, 1).contiguous().view(N, -1)
     # compress
     mel_S.add_(1e-7).log_()
     return mel_S.view(self.n_mels, batch_size, -1).transpose(0, 1)
Example #15
0
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor passed straight to ``torch.stft``.

    Returns:
        Tensor: Magnitude spectrogram (B, fft_size // 2 + 1, #frames).
        No transpose is applied, so frequency comes before frames.

    """
    if is_pytorch_17plus:
        # return_complex=False is deprecated on recent PyTorch; request complex
        # output and view it as (..., 2) so the indexing below is unchanged.
        x_stft = torch.view_as_real(
            torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True))
    else:
        # Pre-1.7 PyTorch has no return_complex and already returns (..., 2).
        x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]
    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    spectrum = torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7))
    return spectrum
Example #16
0
    def test_stft_roundtrip_complex_window(self, device, dtype):
        """stft -> istft round-trips the input when the window is complex.

        A cdouble window yields a two-sided complex STFT; the inverse, taken
        with ``return_complex=True``, must match ``x`` (for real ``x``: real
        part equal, imaginary part ~0).
        """
        inputs = (torch.randn(600, device=device, dtype=dtype),
                  torch.randn(807, device=device, dtype=dtype),
                  torch.randn(12, 14, device=device, dtype=dtype),
                  torch.randn(9, 6, device=device, dtype=dtype))
        n_ffts = (50, 27)
        hop_lengths = (None, 10)
        pad_modes = ("constant",)
        normalized_flags = (True, False)

        for x, n_fft, hop_length, pad_mode, normalized in product(
                inputs, n_ffts, hop_lengths, pad_modes, normalized_flags):
            window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
            x_stft = torch.stft(x,
                                n_fft=n_fft,
                                hop_length=hop_length,
                                window=window,
                                center=True,
                                pad_mode=pad_mode,
                                normalized=normalized)
            self.assertEqual(x_stft.dtype, torch.cdouble)
            self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided

            x_roundtrip = torch.istft(x_stft,
                                      n_fft=n_fft,
                                      hop_length=hop_length,
                                      window=window,
                                      center=True,
                                      normalized=normalized,
                                      length=x.size(-1),
                                      return_complex=True)
            # Check the round-trip result (this line used to re-check x_stft).
            self.assertEqual(x_roundtrip.dtype, torch.cdouble)

            if not dtype.is_complex:
                self.assertEqual(x_roundtrip.imag,
                                 torch.zeros_like(x_roundtrip.imag),
                                 atol=1e-6,
                                 rtol=0)
                self.assertEqual(x_roundtrip.real, x)
            else:
                self.assertEqual(x_roundtrip, x)
Example #17
0
    def forward(self, x, is_istft=True):
        """Mask the STFT of ``x`` with a U-Net and optionally invert it.

        Args:
            x: waveform batch accepted by ``torch.stft`` (presumably
               (batch, time) — TODO confirm).
            is_istft: if true, return the masked signal via ``torch.istft``;
                otherwise return the masked real/imag spectrogram
                (batch, 1, freq, frames, 2).
        """
        # return_complex is mandatory for real inputs on PyTorch >= 2.0;
        # view_as_real keeps the (..., 2) layout the network operates on.
        x = torch.view_as_real(torch.stft(input=x,
                                          n_fft=self.n_fft,
                                          hop_length=self.hop_length,
                                          normalized=True,
                                          return_complex=True))
        # drop the last frame, then add a channel axis
        x = x.narrow(2, 0, x.shape[2] - 1)
        x = x.unsqueeze(1)

        # downsampling/encoding
        d0 = self.downsample0(x)
        d1 = self.downsample1(d0)
        d2 = self.downsample2(d1)
        d3 = self.downsample3(d2)
        d4 = self.downsample4(d3)

        # upsampling/decoding with skip-connections
        u0 = self.upsample0(d4)
        c0 = torch.cat((u0, d3), dim=1)

        u1 = self.upsample1(c0)
        c1 = torch.cat((u1, d2), dim=1)

        u2 = self.upsample2(c1)
        c2 = torch.cat((u2, d1), dim=1)

        u3 = self.upsample3(c2)
        c3 = torch.cat((u3, d0), dim=1)

        u4 = self.upsample4(c3)

        # u4 is the mask; align x's frame axis (dim 3) with it before applying.
        if x.shape[3] < u4.shape[3]:
            x = F.pad(x, (0, 0, 0, 1))
        if u4.shape[3] < x.shape[3]:
            x = x.narrow(3, 0, u4.shape[3])
        output = u4 * x
        if is_istft:
            output = torch.squeeze(output, 1)
            # torch.istft requires a complex spectrogram on PyTorch >= 2.0.
            output = torch.istft(torch.view_as_complex(output.contiguous()),
                                 n_fft=self.n_fft,
                                 hop_length=self.hop_length,
                                 normalized=True)

        return output
	def pre_process(self, data, data_length):
		"""Compute mean-normalized log filter-bank features from raw audio.

		Args:
			data: waveform batch; the slicing below treats it as (batch, time).
			data_length: iterable of per-example sample counts, or None to use
				the full frame count for every example.

		Returns:
			(filter_banks, ilens): features of shape (batch, frames, n_filters)
			and a LongTensor with each example's frame count.
		"""

		# ToDo - write the code for generating pitch features

		fbank = self.fbank[data.get_device()]

		pre_emphasis = config.fbank['pre_emphasis']
		frame_size = config.fbank['frame_size']
		frame_stride = config.fbank['frame_stride']
		n_fft = config.fbank['n_fft']
		rate = config.fbank['rate']

		emphasized_data = torch.zeros_like(data).float()

		if config.use_cuda:
			emphasized_data = emphasized_data.to(data.device)

		# Pre-emphasis: y[t] = x[t] - a * x[t-1]; the first sample is copied as-is.
		emphasized_data[:, 1:] = data[:, 1:] - pre_emphasis * data[:, :-1]
		emphasized_data[:, 0] = data[:, 0]

		frame_length, frame_step = frame_size * rate, frame_stride * rate  # Convert from seconds to samples
		frame_length = int(frame_length)
		frame_step = int(frame_step)

		# return_complex is mandatory for real inputs on PyTorch >= 2.0; abs() of
		# the complex STFT equals the former norm over the (real, imag) axis.
		spectrum = torch.stft(
			emphasized_data,
			n_fft=n_fft,
			hop_length=frame_step,
			win_length=frame_length,
			window=torch.hamming_window(frame_length, device=emphasized_data.device),
			pad_mode='constant',
			return_complex=True,
		)
		mag_frames = spectrum.abs().transpose(2, 1)  # (batch, frames, freq)

		pow_frames = ((1.0 / n_fft) * (mag_frames ** 2))  # Power Spectrum

		filter_banks = torch.matmul(pow_frames, fbank.transpose(1, 0))
		# Replace exact zeros with float64 machine epsilon so log10 stays finite.
		filter_banks[filter_banks == 0] = 2.220446049250313e-16
		# NOTE(review): 20*log10 of a power quantity; 10*log10 is the usual dB
		# convention for power — confirm before changing (affects trained models).
		filter_banks = 20 * torch.log10(filter_banks)  # dB
		filter_banks -= (torch.mean(filter_banks, dim=(0, 1), keepdim=True) + 1e-8)

		if data_length is None:
			ilens = torch.full((filter_banks.shape[0],), filter_banks.shape[1], dtype=torch.long)
		else:
			# Integer arithmetic: the former FloatTensor detour rounded lengths
			# above 2**24 samples.
			ilens = torch.tensor(
				[int(length) // frame_step + 1 for length in data_length],
				dtype=torch.long,
			)

		return filter_banks, ilens
    def forward(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
            mixture_all: [M, N + L/2 + 1, K], mixture_w concatenated with the
                log1p STFT magnitude and passed through ``self.se``.
        """
        # Non-centered STFT with hop L/2 (no window given, so rectangular);
        # return_complex is mandatory for real inputs on PyTorch >= 2.0.
        mixture_f = torch.stft(mixture, self.L, self.L // 2, center=False,
                               return_complex=True).abs()  # [M, L/2+1, K]

        mixture = torch.unsqueeze(mixture, 1)  # [M, 1, T]
        mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
        mixture_all = torch.cat((mixture_w, torch.log1p(mixture_f)), 1)  # [M, N+L/2+1, K]
        mixture_all = self.se(mixture_all)
        return mixture_w, mixture_all
Example #20
0
def complex_stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and return its real and imaginary parts.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor passed to ``torch.stft``.

    Returns:
        tuple[Tensor, Tensor]: real and imaginary parts, each
        (B, #frames, fft_size // 2 + 1).
    """
    if is_pytorch_17plus:
        # `return_complex` exists from 1.7 and is mandatory for real inputs
        # from 2.0. The branches were previously swapped, which passed it to
        # pre-1.7 PyTorch and indexed complex output as if it were (..., 2).
        x_stft = torch.view_as_real(
            torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True))
    else:
        # Pre-1.7 PyTorch has no return_complex and already returns (..., 2).
        x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    return real.transpose(2, 1), imag.transpose(2, 1)
Example #21
0
    def __call__(self, x):
        """Return the power spectrogram of channel 0 (and 1, if present) of ``x``.

        Args:
            x: tensor indexed as ``x[:, channel, :]``, i.e. (batch, channels, time).

        Returns:
            (batch, 1, freq, frames) for a single-channel input, otherwise
            (batch, 2, freq, frames) built from channels 0 and 1 only.
        """
        window = to_device(self.window, x.device)

        def power(signal):
            # Normalized, centered, onesided STFT; return_complex is mandatory
            # for real inputs on PyTorch >= 2.0.
            spec = torch.stft(signal,
                              n_fft=self.n_fft,
                              hop_length=self.n_hop,
                              win_length=self.n_fft,
                              window=window,
                              onesided=True,
                              center=True,
                              pad_mode='constant',
                              normalized=True,
                              return_complex=True)
            # magnitude^2 from the real and imaginary parts, plus a channel axis
            return (spec.real ** 2 + spec.imag ** 2).unsqueeze(1)

        X_left = power(x[:, 0, :])
        if x.size(1) > 1:
            res = torch.cat([X_left, power(x[:, 1, :])], dim=1)
            assert res.dim() == 4  # (n samples, channels, h, w)
            return res

        assert X_left.dim() == 4  # (n samples, channels, h, w)
        return X_left  # mono: only the single channel
Example #22
0
 def forward(self, waveforms):
     """Compute an (optionally log-scaled) filterbank magnitude spectrogram.

     Args:
         waveforms: (batch, time) audio.

     Returns:
         (batch, n_mel, time): ``self.mel_fb`` applied to the STFT magnitude,
         passed through ``to_log`` when ``self.log`` is set.
     """
     # return_complex is mandatory for real inputs on PyTorch >= 2.0.
     spec = torch.stft(waveforms,
                       self.n_fft,
                       hop_length=self.hop_length,
                       window=self.window,
                       return_complex=True)
     power = spec.real ** 2 + spec.imag ** 2  # (batch, n_freq, time)
     mag_stfts = torch.sqrt(power + EPS)  # without EPS, backpropagating can yield NaN.
     # Project onto the filterbank basis (self.mel_fb)
     mag_melgrams = torch.matmul(self.mel_fb, mag_stfts)
     if self.log:
         mag_melgrams = to_log(mag_melgrams)
     return mag_melgrams
Example #23
0
def compute_stft(wav):
    """
    Computes stft feature from wav

    Args:
        wav (Tensor): B x L

    Returns:
        (stft, mag): ``stft`` is the real/imag spectrogram (B, F', frames, 2),
        truncated to the first ``F`` frequency bins, and ``mag`` is its
        magnitude (B, F', frames).

    Uses the module-level ``win_length`` (as n_fft), ``hop_length``, ``win``
    and ``F``.
    """
    # return_complex is mandatory for real inputs on PyTorch >= 2.0;
    # view_as_real keeps the (..., 2) layout returned to callers.
    stft = torch.view_as_real(
        torch.stft(wav, win_length, hop_length=hop_length, window=win,
                   return_complex=True))

    # only keep freqs smaller than F
    stft = stft[:, :F, :, :]
    real = stft[:, :, :, 0]
    im = stft[:, :, :, 1]
    mag = torch.sqrt(real**2 + im**2)
    return stft, mag
Example #24
0
def split_wav(filepath):
    """Load an audio file and split its magnitude spectrogram into chunks.

    The file is mixed down to mono and transformed with ``torch.stft``, using
    the module-level ``n_fft``, ``hop_sz`` and ``window_fn``. The frame axis is
    cut to ``n_splits * 512`` frames, where ``n_splits = frames // hop_sz``,
    and then split into ``n_splits`` equal chunks.

    Returns:
        Tensor of shape (n_splits, n_freq, chunk_frames).

    Raises:
        ValueError: if the spectrogram has fewer than ``hop_sz`` frames.
    """
    wavData, _ = torchaudio.load(filepath)
    wavData = torch.mean(wavData, dim=0)

    # return_complex is mandatory for real inputs on PyTorch >= 2.0.
    complex_mix = torch.stft(wavData, n_fft=n_fft, hop_length=hop_sz,
                             window=window_fn, return_complex=True)
    complex_mix_mag = complex_mix.abs()

    n_splits = complex_mix_mag.size(1) // hop_sz
    if n_splits == 0:
        raise ValueError(
            f"{filepath}: spectrogram has {complex_mix_mag.size(1)} frames, "
            f"fewer than hop_sz={hop_sz}; cannot split")

    # NOTE(review): n_splits divides by hop_sz but the slice uses 512; chunks
    # only have equal size (required by torch.stack) when n_splits * 512 fits
    # in the frame count, e.g. hop_sz >= 512 — confirm which constant is meant.
    complex_mix_mag = complex_mix_mag[:, 0:n_splits * 512]

    chunks = torch.chunk(complex_mix_mag, n_splits, 1)
    return torch.stack(chunks, dim=0)
Example #25
0
 def stft(self, x):
     """Return the STFT of ``x`` in real/imag-last layout.

     Must be 1 x len. Inputs shorter than ``n_fft`` are zero-padded on the
     right to ``n_fft`` samples first.

     Returns:
         Tensor of shape (1, n_fft // 2 + 1, frames, 2).
     """
     if x.shape[-1] < self.n_fft:
         # Pad to n_fft (the length the guard checks; it used to pad to
         # win_length) and allocate on x's device instead of the CPU.
         shape = list(x.shape)
         shape[-1] = self.n_fft
         padded = torch.zeros(shape, device=x.device)
         padded[..., :x.shape[-1]] = x
         x = padded
     # return_complex is mandatory for real inputs on PyTorch >= 2.0;
     # view_as_real keeps the former (..., 2) output layout.
     spec = torch.stft(x,
                       n_fft=self.n_fft,
                       hop_length=self.hop_length,
                       win_length=self.win_length,
                       window=self.window.to(dtype=torch.float),
                       return_complex=True)
     return torch.view_as_real(spec)
def spectral_local_response_normalization(x, size=3, n_fft=512):
    """Divide the STFT of ``x`` by a local average of its spectral magnitude.

    For every frame the magnitude is averaged over frequency. That per-frame
    value is smoothed over ``size`` neighboring frames (``avg_pool1d`` with
    stride 1, padding excluded from the mean), and both STFT components are
    divided by it. The result is inverted with ``istft`` and cropped to
    ``x``'s length.

    Args:
        x: (batch, time) signal.
        size: pooling window in frames. With padding ``size // 2`` the frame
            count is preserved only for odd sizes.
        n_fft: FFT size (no window is passed, so rectangular; default hop).

    Returns:
        Normalized signal, cropped to at most ``x.shape[1]`` samples.
    """
    # return_complex is mandatory for real inputs on PyTorch >= 2.0;
    # the (..., 2) view keeps the layout ``istft`` was given before.
    x_stft = torch.view_as_real(torch.stft(x, n_fft, return_complex=True))
    amplitude = torch.sqrt(x_stft[:, :, :, 0]**2 + x_stft[:, :, :, 1]**2)
    x_stft_norm = torch.mean(amplitude, dim=1, keepdim=True)  # (batch, 1, frames)
    x_avg = F.avg_pool1d(x_stft_norm,
                         kernel_size=size,
                         stride=1,
                         padding=size // 2,
                         count_include_pad=False)
    # NOTE(review): an all-zero neighborhood makes x_avg 0 and this inf/NaN.
    normalized_stft = x_stft / x_avg.unsqueeze(3)
    normalized_x = istft(normalized_stft)
    return normalized_x[:, :x.shape[1]]
def torch_spectrogram(signal, n_fft=None, hop=None, window=np.hanning):
    """Return the power spectrogram |STFT|^2 of ``signal``.

    Args:
        signal: tensor accepted by ``torch.stft``.
        n_fft: FFT and window length (required in practice despite the None default).
        hop: hop length; ``torch.stft``'s default applies when None.
        window: callable ``window(n) -> np.ndarray``; the window is scaled to
            sum to 1.

    Returns:
        Power spectrogram (..., n_fft // 2 + 1, frames), onesided, not centered.
    """
    taps = window(n_fft)
    taps = torch.from_numpy(taps / taps.sum()).to(signal.device).float()
    stft_kwargs = dict(n_fft=n_fft,
                       hop_length=hop,
                       win_length=n_fft,
                       window=taps,
                       onesided=True,
                       center=False,
                       normalized=False)
    try:
        spec = torch.stft(signal, return_complex=True, **stft_kwargs)
    except TypeError:
        # PyTorch < 1.7 has no `return_complex` and returns (..., 2) directly.
        # Only TypeError is caught: the old bare `except:` also swallowed
        # genuine errors (bad shapes, device mismatch) and retried blindly.
        return (torch.stft(signal, **stft_kwargs) ** 2).sum(-1)
    return spec.real ** 2 + spec.imag ** 2
Example #28
0
def multiscale_fft(signal, scales, overlap):
    """Return the STFT magnitude of ``signal`` for each FFT size in ``scales``.

    Scale ``s`` uses a Hann window of length ``s`` (moved to ``signal``'s
    device/dtype), hop ``int(s * (1 - overlap))``, centering, and a normalized
    transform.
    """
    def magnitude(size):
        hop = int(size * (1 - overlap))
        taps = torch.hann_window(size).to(signal)
        spec = torch.stft(signal,
                          n_fft=size,
                          hop_length=hop,
                          win_length=size,
                          window=taps,
                          center=True,
                          normalized=True,
                          return_complex=True)
        return spec.abs()

    return [magnitude(size) for size in scales]
Example #29
0
    def _stft(self, data: torch.Tensor, n_fft: int, hop_length: int):
        """STFT of ``data`` with ``self._stft_window`` and ``win_length = n_fft``.

        Centered with reflect padding, not normalized. Returned in the
        real/imag-last layout ``(..., freq, frames, 2)``, as the pre-2.0
        ``torch.stft`` produced for real inputs.
        """
        # return_complex is mandatory for real inputs on PyTorch >= 2.0.
        stft = torch.stft(
            data,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            window=self._stft_window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            return_complex=True,
        )
        return torch.view_as_real(stft)
Example #30
0
 def __call__(self, x):
     """Return the STFT magnitude of ``x``, floored at sqrt(1e-8).

     Centered with reflect padding (compatible with audio.py), onesided, not
     normalized. Output: (B, D, T), presumably batch x freq x frames.
     """
     # return_complex is mandatory for real inputs on PyTorch >= 2.0.
     spec = torch.stft(
         x,
         self.n_fft,
         self.hop_length,
         self.win_length,
         self.window,
         center=True,
         pad_mode="reflect",  # compatible with audio.py
         normalized=False,
         onesided=True,
         return_complex=True)
     real, imag = spec.real, spec.imag
     return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-8))
Example #31
0
    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of fourier bins, which should be the window size divided
                by 2 plus 1.

        """
        sig, is_variable = _check_is_variable(sig)

        assert sig.dim() == 2

        # The old call used the pre-0.4.1 positional signature
        # stft(signal, frame_length, hop, fft_size, normalized, onesided,
        #      window, pad_end). Under the current signature those positions
        # mean (n_fft, hop_length, win_length, window, center, ...), so `True`
        # ended up as the window. Keyword arguments make the mapping explicit.
        if self.pad > 0:
            # former pad_end: implicit zero padding at the end of the signal
            sig = torch.nn.functional.pad(sig, (0, self.pad), "constant")
        spec_f = torch.stft(sig,
                            n_fft=self.n_fft,
                            hop_length=self.hop,
                            win_length=self.ws,
                            window=self.window,
                            center=False,
                            normalized=True,
                            onesided=True,
                            return_complex=True).transpose(1, 2)  # (c, l, n_fft // 2 + 1)
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.abs().pow(2)  # power of the complex tensor (c, l, n_fft // 2 + 1)
        return spec_f if is_variable else spec_f.data