def test_stft_istft_identity(ctx, window_size, stride, fft_size, window_type,
                             center, pad_mode):
    backend = ctx.backend[0].split(":")[0]
    if backend == 'cuda':
        pytest.skip(
            'CUDA Convolution N-D is only supported in CUDNN extension')

    x_shape = create_stft_input_shape(window_size)
    x = np.random.randn(*x_shape)

    # Skip for NOLA condition violation
    length = x_shape[1]
    if is_nola_violation(window_type, window_size, stride, fft_size, length,
                         center):
        pytest.skip('NOLA condition violation.')
        return

    x = nn.Variable.from_numpy_array(x)
    with nn.context_scope(ctx):
        yr, yi = F.stft(x, window_size, stride, fft_size, window_type,
                        center, pad_mode)
        z = F.istft(yr, yi, window_size, stride, fft_size, window_type,
                    center, pad_mode="constant")
    z.forward()

    assert np.allclose(x.d, z.d, atol=1e-5, rtol=1e-5)

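# `create_stft_input_shape` and `is_nola_violation` are helpers from this
# repository's test utilities and are not shown here. As a rough, hypothetical
# sketch (an assumption, not the repository's implementation), a NOLA check
# can sum the squared, hop-shifted analysis window and verify that the
# overlap-add envelope never reaches zero:
def _nola_ok_sketch(window, stride, length):
    import numpy as np
    envelope = np.zeros(length)
    for start in range(0, length - len(window) + 1, stride):
        envelope[start:start + len(window)] += np.square(window)
    # NOLA (Nonzero Overlap Add) holds if the envelope is strictly positive.
    return np.all(envelope > 1e-11)
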
def test_istft(ctx, window_size, stride, fft_size, window_type, center):
    backend = ctx.backend[0].split(":")[0]
    if backend == 'cuda':
        pytest.skip('CUDA Convolution N-D is only supported in CUDNN extension')

    # Clear all previous STFT conv/deconv kernels.
    nn.clear_parameters()

    # Make sure that iSTFT(STFT(x)) = x.
    x = np.random.randn(1, window_size * 10)

    nx = nn.Variable.from_numpy_array(x)
    with nn.context_scope(ctx):
        nyr, nyi = F.stft(nx, window_size=window_size, stride=stride,
                          fft_size=fft_size, window_type=window_type,
                          center=center)
        nz = F.istft(nyr, nyi, window_size=window_size, stride=stride,
                     fft_size=fft_size, window_type=window_type, center=center)
    nz.forward()

    # The first and last `window_size - stride` samples are not fully covered
    # by overlap-add, so compare only the interior region.
    invalid = window_size - stride
    assert np.allclose(nx.d[:, invalid:-invalid], nz.d[:, invalid:-invalid],
                       atol=1e-5, rtol=1e-5)

def stft_backward(inputs, window_size, stride, fft_size, window_type='hanning',
                  center=True, pad_mode='reflect', as_istft_backward=False):
    """
    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
        kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
        list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dyr = inputs[0]
    dyi = inputs[1]

    dx = F.istft(dyr, dyi, window_size, stride, fft_size, window_type, center,
                 pad_mode, as_stft_backward=not as_istft_backward)
    return dx

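# A minimal sketch (assuming the numpy/nnabla imports used throughout these
# files and the `as_stft_backward` flag shown above) that numerically checks
# that the built-in STFT backward pass matches `stft_backward`, i.e. iSTFT run
# in backward mode. The toy sizes are assumptions for illustration only.
def _check_stft_backward_sketch(window_size=16, stride=8, fft_size=16):
    x = nn.Variable.from_numpy_array(np.random.randn(1, window_size * 8))
    x.need_grad = True
    y_r, y_i = F.stft(x, window_size, stride, fft_size, 'hanning', True,
                      'reflect')
    # Seed the output gradients with random values.
    dy_r = np.random.randn(*y_r.shape)
    dy_i = np.random.randn(*y_i.shape)
    y = F.sink(y_r, y_i, one_input_grad=False)
    y.forward()
    x.grad.zero()
    y_r.g = dy_r
    y_i.g = dy_i
    y.backward()
    # Reference gradient computed by stft_backward (iSTFT in backward mode).
    dx = stft_backward([nn.Variable.from_numpy_array(dy_r),
                        nn.Variable.from_numpy_array(dy_i)],
                       window_size, stride, fft_size)
    dx.forward()
    assert np.allclose(x.g, dx.d, atol=1e-5)
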
def test_istft(window_size, stride, fft_size, window_type, center):
    # Clear all previous STFT conv/deconv kernels.
    nn.clear_parameters()

    # Make sure that iSTFT(STFT(x)) = x.
    x = np.random.randn(1, window_size * 10)

    nx = nn.Variable.from_numpy_array(x)
    nyr, nyi = F.stft(nx, window_size=window_size, stride=stride,
                      fft_size=fft_size, window_type=window_type,
                      center=center)
    nz = F.istft(nyr, nyi, window_size=window_size, stride=stride,
                 fft_size=fft_size, window_type=window_type, center=center)
    nz.forward()

    invalid = window_size - stride
    assert np.allclose(nx.d[:, invalid:-invalid], nz.d[:, invalid:-invalid],
                       atol=1e-5, rtol=1e-5)

def check_nola_violation(y_r, y_i, window_size, stride, fft_size, window_type,
                         center, pad_mode, as_stft_backward):
    # Check that the reference implementation raises.
    try:
        # Use PyTorch to check the NOLA condition because librosa does not raise.
        # If PyTorch is not installed, the NOLA condition test for the reference
        # is skipped.
        import torch

        def ref_istft_torch(y_r, y_i, window_size, stride, fft_size,
                            window_type, center):
            y_r = np.reshape(y_r, y_r.shape + (1,))
            y_i = np.reshape(y_i, y_i.shape + (1,))
            y = np.concatenate((y_r, y_i), axis=3)
            y = torch.tensor(y)
            y = torch.view_as_complex(y)

            window = torch.tensor(create_window_func(window_type, window_size))

            x_shape = create_stft_input_shape(window_size)
            length = x_shape[1]

            x = torch.istft(y, n_fft=fft_size, hop_length=stride,
                            win_length=window_size, window=window,
                            center=center, length=length)
            return x

        with pytest.raises(RuntimeError, match=r"window overlap add"):
            # NOLA condition is checked during forward execution.
            ref_istft_torch(y_r, y_i, window_size, stride, fft_size,
                            window_type, center)
    except ImportError:
        # Install PyTorch to check NOLA condition validation of the reference istft.
        pass

    # Check that NNabla raises.
    y_r = nn.Variable.from_numpy_array(y_r)
    y_i = nn.Variable.from_numpy_array(y_i)
    with pytest.raises(
            RuntimeError,
            match=r"NOLA\(Nonzero Overlap Add\) condition is not met."):
        # NOLA condition is checked during setup.
        _ = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode, as_stft_backward)

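# For reference, a standalone, hypothetical illustration (parameters chosen
# here, not taken from the test suite) of the PyTorch NOLA check relied on
# above: a periodic Hann window with hop length equal to the window length has
# zero overlap-add weight at frame boundaries, so torch.istft raises.
def _torch_nola_demo_sketch():
    import torch
    n_fft = 32
    y = torch.randn(n_fft // 2 + 1, 10, dtype=torch.complex64)
    window = torch.hann_window(n_fft)  # periodic Hann: window[0] == 0
    with pytest.raises(RuntimeError, match=r"window overlap add"):
        torch.istft(y, n_fft=n_fft, hop_length=n_fft, win_length=n_fft,
                    window=window, center=True)
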
def ref_istft(y_r, y_i, window_size, stride, fft_size, window_type, center,
              pad_mode, as_stft_backward):
    if not as_stft_backward:
        # Use librosa.istft as the forward reference.

        # Convert to librosa.istft input format.
        y = y_r + 1j * y_i

        # Get original signal length.
        x_shape = create_stft_input_shape(window_size)
        length = x_shape[1]

        # librosa.istft does not support batched input.
        b = y.shape[0]
        xs = []
        for i in range(b):
            x = librosa.istft(y[i], hop_length=stride, win_length=window_size,
                              window=window_type, center=center, length=length)
            xs.append(x)
        return np.array(xs)
    else:
        # Use F.stft backward as the reference.
        y_r = nn.Variable.from_numpy_array(y_r)
        y_i = nn.Variable.from_numpy_array(y_i)

        # Just create stft inputs.
        x = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode, True)

        # Execute istft backward.
        x.need_grad = True
        x.grad.zero()

        z_r, z_i = F.stft(x, window_size, stride, fft_size, window_type,
                          center, pad_mode)
        z_r.g = y_r.d
        z_i.g = y_i.d

        z = F.sink(z_r, z_i, one_input_grad=False)
        z.forward()
        z.backward()

        return x.g

def ref_stft(x, window_size, stride, fft_size, window_type, center, pad_mode,
             as_istft_backward):
    if not as_istft_backward:
        # Use librosa.stft as the forward reference.
        # librosa.stft does not support batched input.
        window_type = 'hann' if window_type == 'hanning' else window_type
        b = x.shape[0]
        ys = []
        for i in range(b):
            y = librosa.stft(x[i], n_fft=fft_size, hop_length=stride,
                             win_length=window_size, window=window_type,
                             center=center, pad_mode=pad_mode)
            ys.append(y)

        # Convert to nnabla stft output format.
        ys = np.array(ys)
        y_r = ys.real
        y_i = ys.imag
        return y_r, y_i
    else:
        # Use F.istft backward as the reference.
        x = nn.Variable.from_numpy_array(x)

        # Just create istft inputs.
        y_r, y_i = F.stft(x, window_size, stride, fft_size, window_type,
                          center, pad_mode)

        # Execute istft backward.
        y_r.need_grad = True
        y_i.need_grad = True
        y_r.grad.zero()
        y_i.grad.zero()

        z = F.istft(y_r, y_i, window_size, stride, fft_size, window_type,
                    center, pad_mode)
        z.forward()
        z.backward(x.data)

        return y_r.g, y_i.g

def istft(y_r, y_i, window_size, stride, fft_size, window_type='hanning',
          center=True):
    '''
    Workaround wrapper of ISTFT for fixing a bug in nnabla<=1.15.0
    '''
    # Capture the call arguments before any local imports pollute locals().
    kwargs = dict(locals())
    from utils import get_nnabla_version_integer
    if get_nnabla_version_integer() > 11500:
        return F.istft(**kwargs)

    import numpy as np
    from nnabla.parameter import get_parameter, get_parameter_or_create
    conv_cos = get_parameter('conv_cos')
    conv_sin = get_parameter('conv_sin')

    if conv_cos is None or conv_sin is None:
        if window_type == 'hanning':
            window_func = np.hanning(window_size + 1)[:-1]
        elif window_type == 'hamming':
            window_func = np.hamming(window_size + 1)[:-1]
        elif window_type == 'rectangular' or window_type is None:
            window_func = np.ones(window_size)
        else:
            raise ValueError("Unknown window type {}.".format(window_type))

        # Pad the window if `fft_size > window_size`.
        if fft_size > window_size:
            diff = fft_size - window_size
            window_func = np.pad(
                window_func, (diff // 2, diff - diff // 2), mode='constant')
        elif fft_size < window_size:
            raise ValueError(
                "FFT size has to be at least as large as window size.")

        # Compute inverse STFT filter coefficients.
        if fft_size % stride != 0:
            raise ValueError("FFT size needs to be a multiple of stride.")

        inv_window_func = np.zeros_like(window_func)
        for s in range(0, fft_size, stride):
            inv_window_func += np.roll(np.square(window_func), s)

        mat_cos = np.zeros((fft_size // 2 + 1, 1, fft_size))
        mat_sin = np.zeros((fft_size // 2 + 1, 1, fft_size))

        for w in range(fft_size // 2 + 1):
            alpha = 1.0 if w == 0 or w == fft_size // 2 else 2.0
            alpha /= fft_size
            for t in range(fft_size):
                mat_cos[w, 0, t] = alpha * np.cos(2. * np.pi * w * t / fft_size)
                mat_sin[w, 0, t] = alpha * np.sin(2. * np.pi * w * t / fft_size)
        mat_cos = mat_cos * window_func / inv_window_func
        mat_sin = mat_sin * window_func / inv_window_func

        conv_cos = get_parameter_or_create(
            'conv_cos', initializer=mat_cos, need_grad=False)
        conv_sin = get_parameter_or_create(
            'conv_sin', initializer=mat_sin, need_grad=False)

    # Compute inverse STFT via deconvolution with the synthesis filters.
    x_cos = F.deconvolution(y_r, conv_cos, stride=(stride,))
    x_sin = F.deconvolution(y_i, conv_sin, stride=(stride,))

    x = F.reshape(x_cos - x_sin, (x_cos.shape[0], x_cos.shape[2]))

    if center:
        x = x[:, fft_size // 2:-fft_size // 2]

    return x

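# A minimal usage sketch of the wrapper above. The spectra follow the F.stft
# output layout (batch, fft_size // 2 + 1, n_frames); the concrete sizes here
# are assumptions for illustration only.
def _istft_wrapper_usage_sketch():
    window_size, stride, fft_size = 1024, 256, 1024
    x = nn.Variable.from_numpy_array(np.random.randn(1, stride * 64))
    y_r, y_i = F.stft(x, window_size, stride, fft_size,
                      window_type='hanning', center=True)
    # Round trip through the workaround wrapper; on nnabla > 1.15.0 this call
    # simply forwards to F.istft.
    z = istft(y_r, y_i, window_size, stride, fft_size,
              window_type='hanning', center=True)
    z.forward()
    return z
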
def __call__(self, x, test=False):
    '''
    Input: (nb_samples, nb_channels, nb_timesteps)
        or (nb_frames, nb_samples, nb_channels, nb_bins)
    Outputs: Input Power/Mag Spectrogram, Output Power/Mag Spectrogram
        and Predicted sources
    '''
    self.test = test
    fft_real, fft_imag = get_stft(x, n_fft=self.n_fft, n_hop=self.n_hop)
    x_theta = F.atan2(fft_imag, fft_real)

    x = get_spectogram(fft_real, fft_imag, mono=(self.nb_channels == 1))
    nb_frames, nb_samples, nb_channels, nb_bins = x.shape

    mix_spec = F.identity(x)
    x = x[..., :self.nb_bins]

    # clone
    x_bass = F.identity(x)
    x_drums = F.identity(x)
    x_vocals = F.identity(x)
    x_other = F.identity(x)

    # shift and scale input to mean=0 std=1 (across all bins)
    x_bass += self.input_mean_bass
    x_drums += self.input_mean_drums
    x_vocals += self.input_mean_vocals
    x_other += self.input_mean_other

    x_bass *= self.input_scale_bass
    x_drums *= self.input_scale_drums
    x_vocals *= self.input_scale_vocals
    x_other *= self.input_scale_other

    # encode and normalize every instance in a batch
    x_bass = self.fc_bn(x_bass, self.hidden_size, "fc1_bass", activation='tanh')
    x_drums = self.fc_bn(x_drums, self.hidden_size, "fc1_drums", activation='tanh')
    x_vocals = self.fc_bn(x_vocals, self.hidden_size, "fc1_vocals", activation='tanh')
    x_other = self.fc_bn(x_other, self.hidden_size, "fc1_other", activation='tanh')

    # Average the sources
    cross_1 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # apply 3-layers of stacked LSTM
    lstm_out_bass = self.lstm(cross_1, nb_samples, "lstm_bass")
    lstm_out_drums = self.lstm(cross_1, nb_samples, "lstm_drums")
    lstm_out_vocals = self.lstm(cross_1, nb_samples, "lstm_vocals")
    lstm_out_other = self.lstm(cross_1, nb_samples, "lstm_other")

    # lstm skip connection
    x_bass = F.concatenate(x_bass, lstm_out_bass)
    x_drums = F.concatenate(x_drums, lstm_out_drums)
    x_vocals = F.concatenate(x_vocals, lstm_out_vocals)
    x_other = F.concatenate(x_other, lstm_out_other)

    cross_2 = (x_bass + x_drums + x_vocals + x_other) / 4.0

    # first dense stage + batch norm
    x_bass = self.fc_bn(cross_2, self.hidden_size, "fc2_bass", activation='relu')
    x_drums = self.fc_bn(cross_2, self.hidden_size, "fc2_drums", activation='relu')
    x_vocals = self.fc_bn(cross_2, self.hidden_size, "fc2_vocals", activation='relu')
    x_other = self.fc_bn(cross_2, self.hidden_size, "fc2_other", activation='relu')

    # second dense stage + batch norm
    x_bass = self.fc_bn(x_bass, nb_channels * nb_bins, "fc3_bass")
    x_drums = self.fc_bn(x_drums, nb_channels * nb_bins, "fc3_drums")
    x_vocals = self.fc_bn(x_vocals, nb_channels * nb_bins, "fc3_vocals")
    x_other = self.fc_bn(x_other, nb_channels * nb_bins, "fc3_other")

    # reshape back to original dim
    x_bass = F.reshape(
        x_bass, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_drums = F.reshape(
        x_drums, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_vocals = F.reshape(
        x_vocals, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))
    x_other = F.reshape(
        x_other, (nb_frames, nb_samples, nb_channels, self.nb_output_bins))

    # apply output scale and shift
    x_bass *= self.output_scale_bass
    x_drums *= self.output_scale_drums
    x_vocals *= self.output_scale_vocals
    x_other *= self.output_scale_other

    x_bass += self.output_mean_bass
    x_drums += self.output_mean_drums
    x_vocals += self.output_mean_vocals
    x_other += self.output_mean_other

    # since our output is non-negative, we can apply ReLU
    mask_bass = F.relu(x_bass)
    mask_drums = F.relu(x_drums)
    mask_vocals = F.relu(x_vocals)
    mask_other = F.relu(x_other)

    # (Frames, Bsize, Channels, Fbins)
    x_bass = mask_bass * mix_spec
    x_drums = mask_drums * mix_spec
    x_vocals = mask_vocals * mix_spec
    x_other = mask_other * mix_spec

    if not self.is_predict:
        tmp = F.stack(*[x_bass, x_drums, x_vocals, x_other], axis=0)
        # (4(sources), Frames, Bsize(16), 2(channels), Fbins)
        #   ==> (4, Bsize, Channels, Fbins, Frames)
        tmp = F.transpose(tmp, (0, 2, 3, 4, 1))
        pred_r, pred_i = [], []
        for i in range(tmp.shape[0]):
            pred_r.append(tmp[i] * F.cos(x_theta))
            pred_i.append(tmp[i] * F.sin(x_theta))
        pred_r = F.stack(*pred_r, axis=0)
        pred_i = F.stack(*pred_i, axis=0)
        pred_r = F.reshape(
            pred_r, (4 * nb_samples * nb_channels, self.nb_output_bins,
                     nb_frames))
        pred_i = F.reshape(
            pred_i, (4 * nb_samples * nb_channels, self.nb_output_bins,
                     nb_frames))
        pred = F.istft(pred_r, pred_i, self.n_fft, self.n_hop, self.n_fft,
                       window_type='hanning', center=True)
        pred = F.reshape(pred, (4, nb_samples, nb_channels, -1))
    else:
        pred = None

    return mix_spec, F.concatenate(mask_bass, mask_drums, mask_vocals,
                                   mask_other, axis=2), pred

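# A small NumPy-only sketch of the shape bookkeeping used in the branch above
# (the sizes are hypothetical): four masked spectrograms in
# (sources, batch, channels, bins, frames) layout are flattened to
# (sources * batch * channels, bins, frames) for F.istft, and the resulting
# waveforms are folded back to (sources, batch, channels, time).
def _shape_bookkeeping_sketch():
    import numpy as np
    sources, batch, channels, bins, frames = 4, 2, 2, 2049, 100
    hop = 1024  # assumed hop size, for illustration only
    spec = np.zeros((sources, batch, channels, bins, frames))
    flat = spec.reshape(sources * batch * channels, bins, frames)
    # F.istft consumes (batch-like, bins, frames) and returns (batch-like, time);
    # a dummy length stands in for the reconstructed time axis here.
    time = (frames - 1) * hop
    wave = np.zeros((flat.shape[0], time))
    return wave.reshape(sources, batch, channels, time)
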