def circular_correlation(self, left_x, right_x):
    """
    Computes the circular correlation of two vectors a and b via their fast
    fourier transforms.

    In NumPy terms: ``ifft(np.conj(fft(a)) * fft(b)).real`` with
    ``a = left_x`` and ``b = right_x``.

    Writing fft(a) = p + j * q and fft(b) = r + j * s, the spectral product is
    (p - j * q) * (r + j * s) = (p * r + q * s) + j * (p * s - q * r).

    :param left_x: real-valued input ``a`` (array or chainer.Variable).
    :param right_x: real-valued input ``b`` (array or chainer.Variable);
        its spectrum is multiplied elementwise with that of ``left_x``, so
        the shapes must be compatible.
    :return: real part of the inverse FFT of the spectral product.
    """
    def _fft_of_real(x):
        # F.fft takes a (real, imag) pair whose dtypes must match, so the zero
        # imaginary part takes x's own shape and dtype instead of a hard-coded
        # float32. x.shape / x.dtype exist on both raw arrays and Variables,
        # whereas xp.zeros_like is meant for raw arrays.
        imag = chainer.as_variable(self.xp.zeros(x.shape, dtype=x.dtype))
        return functions.fft((x, imag))

    a_real, a_imag = _fft_of_real(left_x)
    b_real, b_imag = _fft_of_real(right_x)

    # conj(fft(a)) * fft(b), expanded into real and imaginary parts.
    prod_real = a_real * b_real + a_imag * b_imag
    prod_imag = a_real * b_imag - a_imag * b_real

    # Only the real part of the inverse transform is the correlation.
    result_real, _ = functions.ifft((prod_real, prod_imag))
    return result_real
def fourier_transform(self, x):
    """
    FFT of a real-valued input using chainer's ``functions.fft``.

    :param x: (mb, N, hidden_dim) real-valued array or chainer.Variable
    :return: tuple of x_fft_real and x_fft_imag, (mb, N, hidden_dim)
    """
    # functions.fft takes a (real, imag) pair whose dtypes must match, so the
    # zero imaginary part uses x's own shape and dtype rather than a hard-coded
    # float32 (which rejects float16/float64 inputs). x.shape / x.dtype exist
    # on raw arrays and on Variables alike, unlike xp.zeros_like's input.
    x_imag = chainer.as_variable(self.xp.zeros(x.shape, dtype=x.dtype))
    x_fft_real, x_fft_imag = functions.fft((x, x_imag))
    return x_fft_real, x_fft_imag
def chainer_fft_spectrogram(xr, xi, forward=True):
    """Forward or inverse FFT of a 2-D (real, imaginary) pair, row by row.

    forward=True (labelled STFT in the original): FFT each row, keep the
    first ``width // 2 + 1`` bins and return (real, imag) transposed.

    forward=False (labelled iSTFT): treat each row as the non-redundant half
    of a spectrum, append the mirrored complex conjugate of its interior bins
    (giving ``2 * width - 2`` bins per row), inverse-FFT each row and return
    (real, imag) transposed.

    Presumably the rows are already-framed signals or spectra; confirm with
    callers.
    """
    _, width = xr.shape  # unpacking also enforces a 2-D input

    if not forward:
        # Bins 1 .. width-2, reversed and conjugated, complete the spectrum.
        interior = slice(1, width - 1)
        full_re = F.concat((xr, F.fliplr(xr[:, interior])), axis=1)
        full_im = F.concat((xi, -F.fliplr(xi[:, interior])), axis=1)
        out_re, out_im = F.ifft((full_re, full_im))
        return F.transpose(out_re), F.transpose(out_im)

    spec_re, spec_im = F.fft((xr, xi))
    n_bins = width // 2 + 1
    return F.transpose(spec_re[:, :n_bins]), F.transpose(spec_im[:, :n_bins])
def stft(x, frame_length=1024, hop_length=512):  # ..., FFT axis
    """Hamming-windowed short-time Fourier transform along the last axis.

    The signal is zero-padded on the right, split into frames of
    ``frame_length`` samples whose starts are ``hop_length`` apart (via the
    module's ``frame`` helper), multiplied by a Hamming window and passed
    through ``F.fft``.

    :param x: array or chainer.Variable of shape (..., time).
    :param frame_length: samples per frame (also the FFT size).
    :param hop_length: samples between consecutive frame starts.
    :return: tuple (real, imag) keeping the first ``frame_length // 2 + 1``
        frequency bins. Assuming ``frame`` returns (frame_length, n_frames)
        like ``librosa.util.frame``, each has shape
        (..., n_frames, frame_length // 2 + 1) — confirm.
    """
    if not isinstance(x, chainer.Variable):
        x = chainer.as_variable(x)
    xp = x.xp

    n_samples = x.shape[-1]
    pad_len = ((n_samples // hop_length - frame_length // hop_length + 1)
               * hop_length + frame_length)
    # For short signals the formula above can be shorter than one frame,
    # which leaves nothing to frame; always pad to at least one full frame.
    pad_len = max(pad_len, frame_length)
    pad = pad_len - n_samples
    if pad > 0:
        zeros = xp.zeros(tuple(x.shape[:-1]) + (pad,), dtype=x.dtype)
        x = F.concat((x, zeros), -1)

    # Row j of `index` lists the sample positions of frame j.
    index = frame(np.arange(x.shape[-1]), frame_length, hop_length).T
    window = xp.hamming(frame_length).astype(x.dtype)
    frames = x[..., index] * window

    # F.fft needs a zero imaginary part of the same shape and dtype.
    yr, yi = F.fft((frames, xp.zeros(frames.shape, dtype=x.dtype)))
    n_bins = frame_length // 2 + 1
    return yr[..., :n_bins], yi[..., :n_bins]
def stft(x, window):
    """Dense (hop of one sample) windowed FFT of a batch of signals.

    Every length-``window.shape[0]`` slice of each signal, starting at every
    sample position, is multiplied by ``window`` and passed through ``F.fft``.

    :param x: batch of signals; reshaped to (batch, 1, time) using
        ``x.shape[0]`` and ``x.shape[-1]``, so presumably (batch, time) —
        confirm against callers.
    :param window: taper whose first dimension sets the frame size; it is
        multiplied onto the last axis of the framed signal.
    :return: Variable of shape (batch, 2, window_size, n_frames) with
        n_frames = time - window_size + 1; channel 0 holds the real part and
        channel 1 the imaginary part of the FFT.
    """
    # NOTE(review): `xp` is not defined here; presumably a module-level
    # numpy/cupy alias — confirm.
    w_size = window.shape[0]
    x_size = x.shape[-1]
    b_size = x.shape[0]
    n_frames = x_size - w_size + 1
    x = x.reshape(b_size, 1, x_size)

    # One shifted slice per window offset i, stacked on axis 1:
    # h[b, i, t] = x[b, 0, t + i].
    h = F.concat([x[:, :, i:i + n_frames] for i in range(w_size)], axis=1)
    h = F.transpose(h, axes=(0, 2, 1))
    h = h.reshape(b_size, 1, n_frames, w_size)
    h = h * window

    # F.fft requires the imaginary part to share the real part's dtype, so
    # follow h's dtype instead of hard-coding float32.
    real, imag = F.fft((h, xp.zeros(h.shape, dtype=h.dtype)))
    h = F.concat((real, imag), axis=1)
    return F.transpose(h, axes=(0, 1, 3, 2))