def training_process():
    """The training process of MaD TwinNet.

    Builds the masker (RNN encoder/decoder + FNN masker), the denoiser,
    and the TwinNet regularization modules, jointly optimizes all of
    them with Adam over the training data feeder, and finally saves the
    learned state dicts to the paths in ``output_states_path``.
    """
    print('\n-- Starting training process. Debug mode: {}'.format(debug))
    print('-- Setting up modules... ', end='')

    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length'])

    # Denoiser module
    # NOTE(review): training builds CNNDenoiser here, while the
    # usage/testing processes load the saved denoiser state into
    # FNNDenoiser -- confirm the two are state-dict compatible.
    denoiser = CNNDenoiser(hyper_parameters['original_input_dim'])

    # TwinNet regularization modules
    twin_net_rnn_dec = TwinRNNDec(hyper_parameters['rnn_enc_output_dim'],
                                  debug)
    twin_net_fnn_masker = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                                    hyper_parameters['original_input_dim'],
                                    hyper_parameters['context_length'])
    affine_transform = AffineTransform(hyper_parameters['rnn_enc_output_dim'])

    # Move everything to the GPU when available (and not debugging).
    if not debug and torch.has_cudnn:
        rnn_enc = rnn_enc.cuda()
        rnn_dec = rnn_dec.cuda()
        fnn = fnn.cuda()
        denoiser = denoiser.cuda()
        twin_net_rnn_dec = twin_net_rnn_dec.cuda()
        twin_net_fnn_masker = twin_net_fnn_masker.cuda()
        affine_transform = affine_transform.cuda()

    print('done.')
    # Fixed typo in the status message: 'optimizes' -> 'optimizers'.
    print('-- Setting up optimizers and losses... ', end='')

    # Objectives and penalties
    loss_masker = kl
    loss_denoiser = kl
    loss_twin = kl
    reg_twin = l2_loss
    reg_fnn_masker = sparsity_penalty
    reg_fnn_dec = l2_reg_squared

    # One optimizer over the parameters of every module.
    optimizer = optim.Adam(
        list(rnn_enc.parameters()) +
        list(rnn_dec.parameters()) +
        list(fnn.parameters()) +
        list(denoiser.parameters()) +
        list(twin_net_rnn_dec.parameters()) +
        list(twin_net_fnn_masker.parameters()) +
        list(affine_transform.parameters()),
        lr=hyper_parameters['learning_rate'])

    print('done.')

    # Initializing data feeder
    epoch_it = data_feeder_training(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=training_constants['batch_size'],
        files_per_pass=training_constants['files_per_pass'],
        debug=debug)

    print('-- Training starts\n')

    # Training loop starts
    for epoch in range(training_constants['epochs']):
        epoch_l_m = []
        epoch_l_d = []
        epoch_l_tw = []
        epoch_l_twin = []

        time_start = time.time()

        # Epoch loop
        for data in epoch_it():
            v_in = Variable(torch.from_numpy(data[0]))
            v_j = Variable(torch.from_numpy(data[1]))

            if not debug and torch.has_cudnn:
                v_in = v_in.cuda()
                v_j = v_j.cuda()

            # Masker pass
            h_enc = rnn_enc(v_in)
            h_dec = rnn_dec(h_enc)
            v_j_filt_prime = fnn(h_dec, v_in)

            # TwinNet pass
            h_dec_twin = twin_net_rnn_dec(h_enc)
            v_j_filt_prime_twin = twin_net_fnn_masker(h_dec_twin, v_in)

            # Twin net regularization: affine map of the forward hidden
            # state towards the (detached) twin hidden state.
            affine_output = affine_transform(h_dec)

            # Denoiser pass
            v_j_filt = denoiser(v_j_filt_prime)

            optimizer.zero_grad()

            # Calculate losses
            l_m = loss_masker(v_j_filt_prime, v_j)
            l_d = loss_denoiser(v_j_filt, v_j)
            l_tw = loss_twin(v_j_filt_prime_twin, v_j)
            # detach(): the twin hidden state is a fixed target here.
            l_twin = reg_twin(affine_output, h_dec_twin.detach())

            # Make MaD TwinNet objective
            loss = l_m + l_d + l_tw + \
                (hyper_parameters['lambda_l_twin'] * l_twin) + \
                (hyper_parameters['lambda_1'] *
                 reg_fnn_masker(fnn.linear_layer.weight)) + \
                (hyper_parameters['lambda_2'] *
                 reg_fnn_dec(denoiser.fnn_dec.weight))

            # Backward pass
            loss.backward()

            # Gradient norm clipping over all trained parameters.
            torch.nn.utils.clip_grad_norm(
                list(rnn_enc.parameters()) +
                list(rnn_dec.parameters()) +
                list(fnn.parameters()) +
                list(denoiser.parameters()) +
                list(twin_net_rnn_dec.parameters()) +
                list(twin_net_fnn_masker.parameters()) +
                list(affine_transform.parameters()),
                max_norm=hyper_parameters['max_grad_norm'], norm_type=2)

            # Optimize
            optimizer.step()

            # Log losses (old-PyTorch scalar access via .data[0],
            # consistent with the Variable API used above).
            epoch_l_m.append(l_m.data[0])
            epoch_l_d.append(l_d.data[0])
            epoch_l_tw.append(l_tw.data[0])
            epoch_l_twin.append(l_twin.data[0])

        time_end = time.time()

        # Tell us what happened
        print(training_output_string.format(
            ep=epoch,
            l_m=torch.mean(torch.FloatTensor(epoch_l_m)),
            l_d=torch.mean(torch.FloatTensor(epoch_l_d)),
            l_tw=torch.mean(torch.FloatTensor(epoch_l_tw)),
            l_twin=torch.mean(torch.FloatTensor(epoch_l_twin)),
            t=time_end - time_start))

    # Kindly end and save the model
    print('\n-- Training done.')
    print('-- Saving model.. ', end='')
    torch.save(rnn_enc.state_dict(), output_states_path['rnn_enc'])
    torch.save(rnn_dec.state_dict(), output_states_path['rnn_dec'])
    torch.save(fnn.state_dict(), output_states_path['fnn'])
    torch.save(denoiser.state_dict(), output_states_path['denoiser'])
    print('done.')
    print('-- That\'s all folks!')
def use_me_process(sources_list, output_file_names):
    """The usage process.

    Loads the trained MaD TwinNet state, separates each file in
    ``sources_list`` into voice and background music, and hands the
    results to ``data_process_results_testing`` for writing.

    :param sources_list: The file names to be used.
    :type sources_list: list[str]
    :param output_file_names: The output file names to be used.
    :type output_file_names: list[list[str]]
    """
    print('\n-- Welcome to MaD TwinNet.')

    if debug:
        print('\n-- Cannot proceed in debug mode. Please set debug=False at the settings file.')
        print('-- Exiting.')
        exit(-1)

    print('-- Now I will extract the voice and the background music from the provided files')

    hp = hyper_parameters

    # Masker modules
    encoder = RNNEnc(hp['reduced_dim'], hp['context_length'], debug)
    decoder = RNNDec(hp['rnn_enc_output_dim'], debug)
    masker = FNNMasker(hp['rnn_enc_output_dim'],
                       hp['original_input_dim'],
                       hp['context_length'])

    # Denoiser module
    denoiser = FNNDenoiser(hp['original_input_dim'])

    # Restore the trained weights.
    encoder.load_state_dict(torch.load(output_states_path['rnn_enc']))
    decoder.load_state_dict(torch.load(output_states_path['rnn_dec']))
    masker.load_state_dict(torch.load(output_states_path['fnn']))
    denoiser.load_state_dict(torch.load(output_states_path['denoiser']))

    use_gpu = not debug and torch.has_cudnn
    if use_gpu:
        encoder = encoder.cuda()
        decoder = decoder.cuda()
        masker = masker.cuda()
        denoiser = denoiser.cuda()

    usage_it = data_feeder_testing(
        window_size=hp['window_size'],
        fft_size=hp['fft_size'],
        hop_size=hp['hop_size'],
        seq_length=hp['seq_length'],
        context_length=hp['context_length'],
        batch_size=1,
        debug=debug,
        sources_list=sources_list)

    print('-- Let\'s go!\n')

    total_time = 0
    b_size = training_constants['batch_size']

    for index, data in enumerate(usage_it()):
        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data

        # Output buffer for the predicted voice magnitude.
        voice_predicted = np.zeros(
            (mix_magnitude.shape[0],
             hp['seq_length'] - hp['context_length'] * 2,
             hp['window_size']),
            dtype=np.float32)

        # NOTE(review): sequences that do not fill a whole batch are
        # skipped and their rows stay zero -- presumably the feeder
        # pads to a multiple of the batch size; confirm.
        for batch in range(int(mix_magnitude.shape[0] / b_size)):
            b_start = batch * b_size
            b_end = b_start + b_size

            v_in = Variable(
                torch.from_numpy(mix_magnitude[b_start:b_end, :, :]))
            if use_gpu:
                v_in = v_in.cuda()

            # Masker (enc -> dec -> fnn) followed by the denoiser.
            prediction = denoiser(masker(decoder(encoder(v_in)), v_in))
            voice_predicted[b_start:b_end, :, :] = \
                prediction.data.cpu().numpy()

        data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hp['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hp['hop_size'],
            context_length=hp['context_length'],
            output_file_name=output_file_names[index])

        e_time = time.time()
        print(usage_output_string_per_example.format(
            f=sources_list[index], t=e_time - s_time))
        total_time += e_time - s_time

    print('\n-- Testing finished\n')
    print(usage_output_string_total.format(t=total_time))
    print('-- That\'s all folks!')
def testing_process():
    """The testing process.

    Loads the trained MaD TwinNet state, separates every file produced
    by the testing data feeder, reports per-example and overall median
    SDR/SIR, and pickles the metrics to ``metrics_paths``.
    """
    device = 'cuda' if not debug and torch.cuda.is_available() else 'cpu'

    print('\n-- Starting testing process. Debug mode: {}'.format(debug))
    print('-- Process on: {}'.format(device), end='\n\n')
    print('-- Setting up modules... ', end='')

    # Masker modules
    rnn_enc = RNNEnc(hyper_parameters['reduced_dim'],
                     hyper_parameters['context_length'], debug)
    rnn_dec = RNNDec(hyper_parameters['rnn_enc_output_dim'], debug)
    fnn = FNNMasker(hyper_parameters['rnn_enc_output_dim'],
                    hyper_parameters['original_input_dim'],
                    hyper_parameters['context_length'])

    # Denoiser modules
    denoiser = FNNDenoiser(hyper_parameters['original_input_dim'])

    # BUG FIX: load_state_dict() does not return the module, so the
    # original `module.load_state_dict(...).to(device)` chained `.to`
    # on the wrong object and never moved the module. Load the weights
    # first, then move each module to the target device.
    rnn_enc.load_state_dict(torch.load(output_states_path['rnn_enc']))
    rnn_dec.load_state_dict(torch.load(output_states_path['rnn_dec']))
    fnn.load_state_dict(torch.load(output_states_path['fnn']))
    denoiser.load_state_dict(torch.load(output_states_path['denoiser']))

    rnn_enc = rnn_enc.to(device)
    rnn_dec = rnn_dec.to(device)
    fnn = fnn.to(device)
    denoiser = denoiser.to(device)

    print('done.')

    testing_it = data_feeder_testing(
        window_size=hyper_parameters['window_size'],
        fft_size=hyper_parameters['fft_size'],
        hop_size=hyper_parameters['hop_size'],
        seq_length=hyper_parameters['seq_length'],
        context_length=hyper_parameters['context_length'],
        batch_size=1,
        debug=debug)

    print('-- Testing starts\n')

    sdr = []
    sir = []
    total_time = 0

    for index, data in enumerate(testing_it()):
        s_time = time.time()

        mix, mix_magnitude, mix_phase, voice_true, bg_true = data

        # Output buffer for the predicted voice magnitude.
        voice_predicted = np.zeros(
            (mix_magnitude.shape[0],
             hyper_parameters['seq_length'] -
             hyper_parameters['context_length'] * 2,
             hyper_parameters['window_size']),
            dtype=np.float32)

        # NOTE(review): sequences that do not fill a whole batch are
        # skipped and their rows stay zero -- presumably the feeder
        # pads to a multiple of the batch size; confirm.
        for batch in range(int(mix_magnitude.shape[0] /
                               training_constants['batch_size'])):
            b_start = batch * training_constants['batch_size']
            b_end = (batch + 1) * training_constants['batch_size']

            v_in = torch.from_numpy(
                mix_magnitude[b_start:b_end, :, :]).to(device)

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                tmp_voice_predicted = rnn_enc(v_in)
                tmp_voice_predicted = rnn_dec(tmp_voice_predicted)
                tmp_voice_predicted = fnn(tmp_voice_predicted, v_in)
                tmp_voice_predicted = denoiser(tmp_voice_predicted)

            voice_predicted[b_start:b_end, :, :] = \
                tmp_voice_predicted.data.cpu().numpy()

        tmp_sdr, tmp_sir = data_process_results_testing(
            index=index,
            voice_true=voice_true,
            bg_true=bg_true,
            voice_predicted=voice_predicted,
            window_size=hyper_parameters['window_size'],
            mix=mix,
            mix_magnitude=mix_magnitude,
            mix_phase=mix_phase,
            hop=hyper_parameters['hop_size'],
            context_length=hyper_parameters['context_length'])

        e_time = time.time()

        # NaN metric frames are dropped before taking the median.
        print(testing_output_string_per_example.format(
            e=index,
            sdr=np.median([i for i in tmp_sdr[0] if not np.isnan(i)]),
            sir=np.median([i for i in tmp_sir[0] if not np.isnan(i)]),
            t=e_time - s_time))

        total_time += e_time - s_time
        sdr.append(tmp_sdr)
        sir.append(tmp_sir)

    print('\n-- Testing finished\n')
    print(testing_output_string_all.format(
        sdr=np.median([ii for i in sdr for ii in i[0] if not np.isnan(ii)]),
        sir=np.median([ii for i in sir for ii in i[0] if not np.isnan(ii)]),
        t=total_time))

    print('\n-- Saving results... ', end='')
    # protocol=2 keeps the pickles readable from Python 2 as well.
    with open(metrics_paths['sdr'], 'wb') as f:
        pickle.dump(sdr, f, protocol=2)
    with open(metrics_paths['sir'], 'wb') as f:
        pickle.dump(sir, f, protocol=2)
    print('done!')
    print('-- That\'s all folks!')