def test_nnet(nnet, seqlen=100, olap=40, wsz=2049, N=4096, hop=384, B=16):
    """
        Method to test the model on some data. Writes the outcomes in ".wav" format
        and stores them under the defined results path.

        Args:
            nnet    : (List) A list containing the Pytorch modules of the skip-filtering model.
            seqlen  : (int)  Length of the time-sequence.
            olap    : (int)  Overlap between spectrogram time-sequences
                             (to recover the missing information from the context information).
            wsz     : (int)  Window size in samples.
            N       : (int)  The FFT size.
            hop     : (int)  Hop size in samples.
            B       : (int)  Batch size.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()
    L = olap / 2
    w = tf.hamming(wsz, True)

    x, fs = Io.wavRead('results/test_files/test.wav', mono=True)
    mx, px = tf.TimeFrequencyDecomposition.STFT(x, w, N, hop)
    mx, px, _ = prepare_overlap_sequences(mx, px, mx, seqlen, olap, B)

    vs_out = np.zeros((mx.shape[0], seqlen - olap, wsz), dtype=np.float32)

    for batch in xrange(mx.shape[0] / B):
        # Mixture to singing voice
        H_enc = nnet[0](mx[batch * B: (batch + 1) * B, :, :])
        H_j_dec = it_infer.iterative_recurrent_inference(nnet[1], H_enc,
                                                         criterion=None, tol=1e-3, max_iter=10)
        vs_hat, mask = nnet[2](H_j_dec, mx[batch * B: (batch + 1) * B, :, :])
        y_out = nnet[3](vs_hat)
        vs_out[batch * B: (batch + 1) * B, :, :] = y_out.data.cpu().numpy()

    vs_out.shape = (vs_out.shape[0] * vs_out.shape[1], wsz)

    if olap == 1:
        mx = np.ascontiguousarray(mx, dtype=np.float32)
        px = np.ascontiguousarray(px, dtype=np.float32)
    else:
        mx = np.ascontiguousarray(mx[:, olap / 2: -olap / 2, :], dtype=np.float32)
        px = np.ascontiguousarray(px[:, olap / 2: -olap / 2, :], dtype=np.float32)

    mx.shape = (mx.shape[0] * mx.shape[1], wsz)
    px.shape = (px.shape[0] * px.shape[1], wsz)

    # Approximated sources
    # Iterative G-L algorithm
    for GLiter in range(10):
        y_recb = tf.TimeFrequencyDecomposition.iSTFT(vs_out, px, wsz, hop, True)
        _, px = tf.TimeFrequencyDecomposition.STFT(y_recb, tf.hamming(wsz, True), N, hop)

    # Remove the leading samples for which no estimation exists
    x = x[olap / 2 * hop:]

    Io.wavWrite(y_recb, 44100, 16, 'results/test_files/test_sv.wav')
    Io.wavWrite(x[:len(y_recb)], 44100, 16, 'results/test_files/test_mix.wav')

    return None
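
# --- Illustrative helper (a sketch, not part of the original pipeline) -------
# Both test_nnet() and test_eval() refine the phase of the estimated magnitude
# spectrogram with a few Griffin-Lim iterations. The function below is a minimal
# sketch that factors out that loop, assuming the same tf.TimeFrequencyDecomposition
# STFT/iSTFT interface used above; the name `griffin_lim_phase_refinement` is
# hypothetical and introduced only for illustration.
def griffin_lim_phase_refinement(mag, px, wsz, N, hop, iterations=10):
    """
        Refines the phase `px` of the magnitude spectrogram `mag` by alternating
        iSTFT synthesis and STFT analysis (iterative Griffin-Lim). Returns the
        synthesised waveform and the refined phase.
    """
    y = tf.TimeFrequencyDecomposition.iSTFT(mag, px, wsz, hop, True)
    for _ in range(iterations):
        _, px = tf.TimeFrequencyDecomposition.STFT(y, tf.hamming(wsz, True), N, hop)
        y = tf.TimeFrequencyDecomposition.iSTFT(mag, px, wsz, hop, True)
    return y, px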
def test_eval(nnet, B, T, N, L, wsz, hop):
    """
        Method to test the model on the test data. Writes the outcomes in ".wav" format
        and stores them under the defined results path. Optionally, it performs BSS-Eval
        using the MIREval python toolbox (used only for comparison to the BSSEval Matlab
        implementation). The evaluation results are stored under the defined save path.

        Args:
            nnet    : (List) A list containing the Pytorch modules of the skip-filtering model.
            B       : (int)  Batch size.
            T       : (int)  Length of the time-sequence.
            N       : (int)  The FFT size.
            L       : (int)  Number of context frames from the time-sequence.
            wsz     : (int)  Window size in samples.
            hop     : (int)  Hop size in samples.
    """
    nnet[0].eval()
    nnet[1].eval()
    nnet[2].eval()
    nnet[3].eval()

    def my_res(mx, vx, L, wsz):
        """
            A helper function to reshape data according to the context frame.
        """
        mx = np.ascontiguousarray(mx[:, L:-L, :], dtype=np.float32)
        mx.shape = (mx.shape[0] * mx.shape[1], wsz)
        vx = np.ascontiguousarray(vx[:, L:-L, :], dtype=np.float32)
        vx.shape = (vx.shape[0] * vx.shape[1], wsz)
        return mx, vx

    # Paths for loading and storing the test-set
    # Generate full paths for test
    test_sources_list = sorted(os.listdir(sources_path + foldersList[1]))
    test_sources_list = [sources_path + foldersList[1] + '/' + i for i in test_sources_list]

    # Initializing the containers of the metrics
    sdr = []
    sir = []
    sar = []

    for indx in xrange(len(test_sources_list)):
        print('Reading: ' + test_sources_list[indx])
        # Reading
        bass, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[0]), mono=False)
        drums, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[1]), mono=False)
        oth, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[2]), mono=False)
        vox, _ = Io.wavRead(os.path.join(test_sources_list[indx], keywords[3]), mono=False)

        bk_true = np.sum(bass + drums + oth, axis=-1) * 0.5
        mix = np.sum(bass + drums + oth + vox, axis=-1) * 0.5
        sv_true = np.sum(vox, axis=-1) * 0.5

        # STFT analysis
        mx, px = tf.TimeFrequencyDecomposition.STFT(mix, tf.hamming(wsz, True), N, hop)
        # Data reshaping (magnitude and phase)
        mx, px, _ = prepare_overlap_sequences(mx, px, px, T, 2 * L, B)

        # The actual "denoising" part
        vx_hat = np.zeros((mx.shape[0], T - L * 2, wsz), dtype=np.float32)

        for batch in xrange(mx.shape[0] / B):
            H_enc = nnet[0](mx[batch * B: (batch + 1) * B, :, :])
            H_j_dec = it_infer.iterative_recurrent_inference(nnet[1], H_enc,
                                                             criterion=None, tol=1e-3, max_iter=10)
            vs_hat, mask = nnet[2](H_j_dec, mx[batch * B: (batch + 1) * B, :, :])
            y_out = nnet[3](vs_hat)
            vx_hat[batch * B: (batch + 1) * B, :, :] = y_out.data.cpu().numpy()

        # Final reshaping
        vx_hat.shape = (vx_hat.shape[0] * vx_hat.shape[1], wsz)
        mx, px = my_res(mx, px, L, wsz)

        # Time-domain recovery
        # Iterative G-L algorithm
        for GLiter in range(10):
            sv_hat = tf.TimeFrequencyDecomposition.iSTFT(vx_hat, px, wsz, hop, True)
            _, px = tf.TimeFrequencyDecomposition.STFT(sv_hat, tf.hamming(wsz, True), N, hop)

        # Remove the samples for which no estimation exists
        mix = mix[L * hop:]
        sv_true = sv_true[L * hop:]
        bk_true = bk_true[L * hop:]

        # Background music estimation
        if len(sv_true) > len(sv_hat):
            bk_hat = mix[:len(sv_hat)] - sv_hat
        else:
            bk_hat = mix - sv_hat[:len(mix)]

        # Disk writing for external BSS_eval using DSD100-tools (used in our paper)
        Io.wavWrite(sv_true, 44100, 16, os.path.join(save_path, 'tf_true_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_true, 44100, 16, os.path.join(save_path, 'tf_true_bk_' + str(indx) + '.wav'))
        Io.wavWrite(sv_hat, 44100, 16, os.path.join(save_path, 'tf_hat_sv_' + str(indx) + '.wav'))
        Io.wavWrite(bk_hat, 44100, 16, os.path.join(save_path, 'tf_hat_bk_' + str(indx) + '.wav'))
        Io.wavWrite(mix, 44100, 16, os.path.join(save_path, 'tf_mix_' + str(indx) + '.wav'))

        # Internal BSSEval using librosa (just for comparison)
        if len(sv_true) > len(sv_hat):
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true[:len(sv_hat)], bk_true[:len(sv_hat)]],
                [sv_hat, bk_hat])
        else:
            c_sdr, _, c_sir, c_sar, _ = bss_eval.bss_eval_images_framewise(
                [sv_true, bk_true],
                [sv_hat[:len(sv_true)], bk_hat[:len(sv_true)]])

        sdr.append(c_sdr)
        sir.append(c_sir)
        sar.append(c_sar)

        # Storing the results iteratively
        pickle.dump(sdr, open(os.path.join(save_path, 'SDR.p'), 'wb'))
        pickle.dump(sir, open(os.path.join(save_path, 'SIR.p'), 'wb'))
        pickle.dump(sar, open(os.path.join(save_path, 'SAR.p'), 'wb'))

    return None
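
# --- Illustrative helper (a sketch, not part of the original pipeline) -------
# test_eval() pickles the per-track SDR/SIR/SAR lists under `save_path`. The
# function below is a minimal sketch of how those pickles could be loaded back
# and summarised; the name `summarize_bss_eval` and the choice of the median as
# the aggregate statistic are assumptions introduced only for illustration.
def summarize_bss_eval(results_path):
    """
        Loads 'SDR.p', 'SIR.p' and 'SAR.p' from `results_path` and prints the
        median value of each metric over all stored tracks and frames.
    """
    for metric in ('SDR', 'SIR', 'SAR'):
        with open(os.path.join(results_path, metric + '.p'), 'rb') as f:
            values = pickle.load(f)
        pooled = np.hstack([np.asarray(v).ravel() for v in values])
        print(metric + ' (median): ' + str(np.median(pooled)))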
def main(training, apply_sparsity):
    """
        The main function to train and test.
    """
    # Reproducible results
    np.random.seed(218)
    torch.manual_seed(218)
    torch.cuda.manual_seed(218)
    # Torch model
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Analysis
    wsz = 2049   # Window-size
    Ns = 4096    # FFT size
    hop = 384    # Hop size
    fs = 44100   # Sampling frequency

    # Parameters
    B = 16       # Batch-size
    T = 60       # Length of the sequence
    N = 2049     # Frequency sub-bands to be processed
    F = 744      # Frequency sub-bands for encoding
    L = 10       # Context parameter (2*L frames will be removed)

    epochs = 100        # Epochs
    init_lr = 1e-4      # Initial learning rate
    mnorm = 0.5         # L2-based norm clipping
    mask_loss_threshold = 1.5   # Scalar indicating the threshold for the time-frequency masking module
    good_loss_threshold = 0.25  # Scalar indicating the threshold for the source enhancement module

    # Data (predefined by the DSD100 dataset and the non-instrumental/non-bleeding stems of MedleyDB)
    totTrainFiles = 116
    numFilesPerTr = 4

    print('------------ Building model ------------')
    encoder = s_s_net.BiGRUEncoder(B, T, N, F, L)
    decoder = s_s_net.Decoder(B, T, N, F, L, infr=True)
    sp_decoder = s_s_net.SparseDecoder(B, T, N, F, L)
    source_enhancement = s_s_net.SourceEnhancement(B, T, N, F, L)

    encoder.train(mode=True)
    decoder.train(mode=True)
    sp_decoder.train(mode=True)
    source_enhancement.train(mode=True)

    if torch.has_cudnn:
        print('------------ CUDA Enabled --------------')
        encoder.cuda()
        decoder.cuda()
        sp_decoder.cuda()
        source_enhancement.cuda()

    # Defining objectives
    rec_criterion = loss_functions.kullback_leibler   # Reconstruction criterion

    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()) +
                           list(sp_decoder.parameters()) +
                           list(source_enhancement.parameters()),
                           lr=init_lr)

    if training:
        win_viz, winb_viz = visualize.init_visdom()
        batch_loss = []
        batch_index = 0

        # Over epochs
        for epoch in range(epochs):
            print('Epoch: ' + str(epoch + 1))
            epoch_loss = []

            # Over the set of files
            for index in range(totTrainFiles / numFilesPerTr):
                # Get Data
                ms, vs = nnet_helpers.get_data(index + 1, numFilesPerTr, wsz, Ns, hop, T, L, B)

                # Shuffle data
                shf_indices = np.random.permutation(ms.shape[0])
                ms = ms[shf_indices]
                vs = vs[shf_indices]

                # Over batches
                for batch in tqdm(range(ms.shape[0] / B)):
                    # Mixture to Singing voice
                    H_enc = encoder(ms[batch * B: (batch + 1) * B, :, :])

                    # Iterative inference
                    H_j_dec = it_infer.iterative_recurrent_inference(decoder, H_enc,
                                                                     criterion=None, tol=1e-3, max_iter=10)

                    vs_hat_b = sp_decoder(H_j_dec, ms[batch * B: (batch + 1) * B, :, :])[0]
                    vs_hat_b_filt = source_enhancement(vs_hat_b)

                    # Loss
                    if torch.has_cudnn:
                        loss = rec_criterion(Variable(torch.from_numpy(vs[batch * B: (batch + 1) * B, L:-L, :]).cuda()),
                                             vs_hat_b_filt)
                        loss_mask = rec_criterion(Variable(torch.from_numpy(vs[batch * B: (batch + 1) * B, L:-L, :]).cuda()),
                                                  vs_hat_b)
                        if loss_mask.data[0] >= mask_loss_threshold and loss.data[0] >= good_loss_threshold:
                            loss += loss_mask
                    else:
                        loss = rec_criterion(Variable(torch.from_numpy(vs[batch * B: (batch + 1) * B, L:-L, :])),
                                             vs_hat_b_filt)
                        loss_mask = rec_criterion(Variable(torch.from_numpy(vs[batch * B: (batch + 1) * B, L:-L, :])),
                                                  vs_hat_b)
                        if loss_mask.data[0] >= mask_loss_threshold and loss.data[0] >= good_loss_threshold:
                            loss += loss_mask

                    # Store loss for display and scheduler
                    batch_loss += [loss.data[0]]
                    epoch_loss += [loss.data[0]]

                    # Sparsity term
                    if apply_sparsity:
                        sparsity_penalty = torch.sum(torch.abs(torch.diag(sp_decoder.ffDec.weight.data))) * 1e-2 + \
                                           torch.sum(torch.pow(source_enhancement.ffSe_dec.weight, 2.)) * 1e-4
                        loss += sparsity_penalty

                        winb_viz = visualize.viz.line(X=np.arange(batch_index, batch_index + 1),
                                                      Y=np.reshape(sparsity_penalty.data[0], (1,)),
                                                      win=winb_viz, update='append')

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm(list(encoder.parameters()) +
                                                  list(decoder.parameters()) +
                                                  list(sp_decoder.parameters()) +
                                                  list(source_enhancement.parameters()),
                                                  max_norm=mnorm, norm_type=2)
                    optimizer.step()

                    # Update graphs
                    win_viz = visualize.viz.line(X=np.arange(batch_index, batch_index + 1),
                                                 Y=np.reshape(batch_loss[batch_index], (1,)),
                                                 win=win_viz, update='append')
                    batch_index += 1

            if (epoch + 1) % 40 == 0:
                print('------------ Saving model ------------')
                torch.save(encoder.state_dict(), 'results/torch_sps_encoder_' + str(epoch + 1) + '.pytorch')
                torch.save(decoder.state_dict(), 'results/torch_sps_decoder_' + str(epoch + 1) + '.pytorch')
                torch.save(sp_decoder.state_dict(), 'results/torch_sps_sp_decoder_' + str(epoch + 1) + '.pytorch')
                torch.save(source_enhancement.state_dict(), 'results/torch_sps_se_' + str(epoch + 1) + '.pytorch')
                print('------------ Done ------------')
    else:
        print('------- Loading pre-trained model -------')
        print('------- Loading inference weights -------')
        encoder.load_state_dict(torch.load('results/results_inference/torch_sps_encoder.pytorch'))
        decoder.load_state_dict(torch.load('results/results_inference/torch_sps_decoder.pytorch'))
        sp_decoder.load_state_dict(torch.load('results/results_inference/torch_sps_sp_decoder.pytorch'))
        source_enhancement.load_state_dict(torch.load('results/results_inference/torch_sps_se.pytorch'))
        print('------------- Done -------------')

    return encoder, decoder, sp_decoder, source_enhancement
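
# --- Illustrative entry point (a sketch, not part of the original script) ----
# The block below shows one way the routines above could be wired together:
# train (or load) the model via main() and then run the evaluation. The flag
# values are arbitrary, and the hyper-parameters passed to test_eval() simply
# mirror the constants defined inside main() (B=16, T=60, FFT size Ns=4096,
# L=10, wsz=2049, hop=384); all of this is an assumption for illustration only.
if __name__ == '__main__':
    training = False        # Set to True to train from scratch
    apply_sparsity = True   # Enable the additional sparsity penalty during training

    nnet = main(training, apply_sparsity)
    test_eval(nnet, B=16, T=60, N=4096, L=10, wsz=2049, hop=384)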