def train(model_config, sup_dataset, model_folder, unsup_dataset=None, load_model=None):
    '''
    Trains the U-net separator, optionally in a semi-supervised setting with
    WGAN critics on unsupervised accompaniment/drums data.
    :param model_config: Dict of hyperparameters and paths (num_fft, num_hop,
           batch_size, num_frames, num_layers, expected_sr, learning rates,
           alpha/beta loss weights, epoch_it, log_dir, model_base_dir, num_disc)
    :param sup_dataset: Paired (mix, acc, drums) dataset for supervised training
    :param model_folder: Name of subfolder used for logs and checkpoints
    :param unsup_dataset: Optional list of three unpaired datasets
           [mixes, accompaniments, drums] enabling adversarial training
    :param load_model: Optional checkpoint path to resume training from
    :return: Path of the checkpoint saved at the end of the epoch
    '''
    # Determine input and output shapes.
    # NOTE: integer division is required here - under Python 3, "/" would yield
    # a float and corrupt the integer shape list given to the U-net.
    freq_bins = model_config["num_fft"] // 2 + 1  # Make even number of freq bins
    disc_input_shape = [
        model_config["batch_size"], freq_bins - 1, model_config["num_frames"],
        1
    ]  # Shape of discriminator input

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch input workers
    # Creating the batch generators. Only the mixture (index 0) needs temporal
    # input context, so sources get zero padding duration.
    padding_durations = [
        float(sep_input_shape[2] - sep_output_shape[2]) *
        model_config["num_hop"] / model_config["expected_sr"] / 2.0, 0, 0
    ]  # Input context that the input audio has to be padded with while reading audio files
    sup_batch_gen = batchgen.BatchGen_Paired(model_config, sup_dataset,
                                             sep_input_shape,
                                             sep_output_shape,
                                             padding_durations[0])

    # Creating unsupervised batch generator if needed
    if unsup_dataset is not None:
        unsup_batch_gens = list()
        for i in range(3):
            # Index 0 is the mixture (needs full input context); 1 and 2 are
            # the sources, which only need the cropped output shape.
            shape = (sep_input_shape if i == 0 else sep_output_shape)
            unsup_batch_gens.append(
                batchgen.BatchGen_Single(model_config, unsup_dataset[i],
                                         shape, padding_durations[i]))

    print("Starting worker")
    sup_batch_gen.start_workers()
    print("Started worker!")

    if unsup_dataset is not None:
        for gen in unsup_batch_gens:
            print("Starting worker")
            gen.start_workers()
            print("Started worker!")

    # Placeholders and input normalisation
    mix_context, acc, drums = Input.get_multitrack_placeholders(
        sep_output_shape, sep_input_shape, "sup")
    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(
        mix), Input.norm(mix_context), Input.norm(acc), Input.norm(drums)

    if unsup_dataset is not None:
        mix_context_u, acc_u, drums_u = Input.get_multitrack_placeholders(
            sep_output_shape, sep_input_shape, "unsup")
        mix_u = Input.crop(mix_context_u, sep_output_shape)
        mix_norm_u, mix_context_norm_u, acc_norm_u, drums_norm_u = Input.norm(
            mix_u), Input.norm(mix_context_u), Input.norm(acc_u), Input.norm(
                drums_u)

    print("Training...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)
    if unsup_dataset is not None:
        # Same separator weights applied to the unsupervised mixtures
        separator_acc_norm_u, separator_drums_norm_u = separator_func(
            mix_context_norm_u, reuse=True)
        separator_acc_u, separator_drums_u = Input.denorm(
            separator_acc_norm_u), Input.denorm(separator_drums_norm_u)
        # Additive penalty: estimated sources should sum to the mixture
        mask_loss_u = tf.reduce_mean(
            tf.square(mix_u - separator_acc_u - separator_drums_u))
    mask_loss = tf.reduce_mean(tf.square(mix - separator_acc - separator_drums))

    # SUMMARIES FOR INPUT AND SEPARATOR
    tf.summary.scalar("mask_loss", mask_loss, collections=["sup", "unsup"])
    if unsup_dataset is not None:
        tf.summary.scalar("mask_loss_u", mask_loss_u, collections=["unsup"])
        tf.summary.scalar("acc_norm_mean_u",
                          tf.reduce_mean(acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_norm_mean_u",
                          tf.reduce_mean(drums_norm_u),
                          collections=["drums_disc"])
        tf.summary.scalar("acc_sep_norm_mean_u",
                          tf.reduce_mean(separator_acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_sep_norm_mean_u",
                          tf.reduce_mean(separator_drums_norm_u),
                          collections=["drums_disc"])
    tf.summary.scalar("acc_norm_mean",
                      tf.reduce_mean(acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_norm_mean",
                      tf.reduce_mean(drums_norm),
                      collections=['sup'])
    tf.summary.scalar("acc_sep_norm_mean",
                      tf.reduce_mean(separator_acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_sep_norm_mean",
                      tf.reduce_mean(separator_drums_norm),
                      collections=['sup'])
    tf.summary.image("sep_acc_norm",
                     separator_acc_norm,
                     collections=["sup", "unsup"])
    tf.summary.image("sep_drums_norm",
                     separator_drums_norm,
                     collections=["sup", "unsup"])

    # BUILD DISCRIMINATORS, if unsupervised training
    unsup_separator_loss = 0
    if unsup_dataset is not None:
        disc_func = Models.WGAN_Critic.dcgan
        # Define real and fake inputs for both discriminators - if separator
        # output and discriminator input shapes do not fit perfectly, we will
        # do a centre crop and only discriminate that part
        acc_real_input = Input.crop(acc_norm_u, disc_input_shape)
        acc_fake_input = Input.crop(separator_acc_norm_u, disc_input_shape)
        drums_real_input = Input.crop(drums_norm_u, disc_input_shape)
        drums_fake_input = Input.crop(separator_drums_norm_u, disc_input_shape)

        # WGAN
        acc_disc_loss, acc_disc_real, acc_disc_fake, acc_grad_pen, acc_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config,
                                             real_input=acc_real_input,
                                             fake_input=acc_fake_input,
                                             scope="acc_disc",
                                             network_func=disc_func)
        drums_disc_loss, drums_disc_real, drums_disc_fake, drums_grad_pen, drums_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config,
                                             real_input=drums_real_input,
                                             fake_input=drums_fake_input,
                                             scope="drums_disc",
                                             network_func=disc_func)

        L_u = -tf.reduce_mean(drums_disc_fake) - tf.reduce_mean(
            acc_disc_fake)  # WGAN based loss for separator (L_u in paper)
        unsup_separator_loss = model_config["alpha"] * L_u + model_config[
            "beta"] * mask_loss_u  # Unsupervised loss for separator: WGAN-based loss L_u and additive penalty term (mask loss), weighted by alpha and beta (hyperparameters)

    # Supervised objective: MSE in log-normalized magnitude space
    sup_separator_loss = tf.reduce_mean(tf.square(separator_drums_norm - drums_norm)) + \
                         tf.reduce_mean(tf.square(separator_acc_norm - acc_norm))
    separator_loss = sup_separator_loss + unsup_separator_loss  # Total separator loss: Supervised + unsupervised loss

    # TRAINING CONTROL VARIABLES
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False,
                                  dtype=tf.int64)
    increment_global_step = tf.assign(global_step, global_step + 1)
    disc_lr = tf.get_variable('disc_lr', [],
                              initializer=tf.constant_initializer(
                                  model_config["init_disc_lr"],
                                  dtype=tf.float32),
                              trainable=False)
    unsup_sep_lr = tf.get_variable('unsup_sep_lr', [],
                                   initializer=tf.constant_initializer(
                                       model_config["init_unsup_sep_lr"],
                                       dtype=tf.float32),
                                   trainable=False)
    sup_sep_lr = tf.get_variable('sup_sep_lr', [],
                                 initializer=tf.constant_initializer(
                                     model_config["init_sup_sep_lr"],
                                     dtype=tf.float32),
                                 trainable=False)

    # Set up optimizers
    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))
    acc_disc_vars, drums_disc_vars = Utils.getTrainableVariables(
        "acc_disc"), Utils.getTrainableVariables("drums_disc")
    print("Drums_Disc_Vars: " + str(Utils.getNumParams(drums_disc_vars)))
    print("Acc_Disc_Vars: " + str(Utils.getNumParams(acc_disc_vars)))

    if unsup_dataset is not None:
        with tf.variable_scope("drums_disc_solver"):
            drums_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    drums_disc_loss,
                    var_list=drums_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("acc_disc_solver"):
            acc_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    acc_disc_loss,
                    var_list=acc_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("unsup_separator_solver"):
            # Optimizes the TOTAL loss (sup + unsup) w.r.t. separator weights
            unsup_separator_solver = tf.train.AdamOptimizer(
                learning_rate=unsup_sep_lr).minimize(
                    separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True)
    else:
        with tf.variable_scope("separator_solver"):
            sup_separator_solver = (tf.train.AdamOptimizer(
                learning_rate=sup_sep_lr).minimize(
                    sup_separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True))

    # SUMMARIES FOR DISCRIMINATORS AND LOSSES
    acc_disc_summaries = tf.summary.merge_all(key="acc_disc")
    drums_disc_summaries = tf.summary.merge_all(key="drums_disc")
    tf.summary.scalar("sup_sep_loss",
                      sup_separator_loss,
                      collections=['sup', "unsup"])
    tf.summary.scalar("unsup_sep_loss",
                      unsup_separator_loss,
                      collections=['unsup'])
    tf.summary.scalar("sep_loss", separator_loss, collections=["sup", "unsup"])
    sup_summaries = tf.summary.merge_all(key='sup')
    unsup_summaries = tf.summary.merge_all(key='unsup')

    # Start session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Don't grab all GPU memory up front
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep +
                                   model_folder,
                                   graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to continue training, if we are supposed to
    if load_model is not None:
        restorer = tf.train.Saver(tf.global_variables(),
                                  write_version=tf.train.SaverDef.V2)
        print("Num of variables: " + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored from file ' + load_model)
    saver = tf.train.Saver(tf.global_variables(),
                           write_version=tf.train.SaverDef.V2)

    # Start training loop
    run = True
    _global_step = sess.run(global_step)
    _init_step = _global_step
    it = 0
    while run:
        if unsup_dataset is not None:
            # TRAIN DISCRIMINATORS: num_disc critic updates per separator update
            for disc_it in range(model_config["num_disc"]):
                batches = list()
                for gen in unsup_batch_gens:
                    batches.append(gen.get_batch())
                _, _acc_disc_summaries = sess.run(
                    [acc_disc_solver, acc_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        acc_u: batches[1]
                    })
                _, _drums_disc_summaries = sess.run(
                    [drums_disc_solver, drums_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        drums_u: batches[2]
                    })
                writer.add_summary(_acc_disc_summaries, global_step=it)
                writer.add_summary(_drums_disc_summaries, global_step=it)
                it += 1

        # TRAIN SEPARATOR
        sup_batch = sup_batch_gen.get_batch()
        if unsup_dataset is not None:
            # SUP + UNSUPERVISED TRAINING
            unsup_batches = list()
            for gen in unsup_batch_gens:
                unsup_batches.append(gen.get_batch())
            _, _unsup_summaries, _sup_summaries = sess.run(
                [unsup_separator_solver, unsup_summaries, sup_summaries],
                feed_dict={
                    mix_context: sup_batch[0],
                    acc: sup_batch[1],
                    drums: sup_batch[2],
                    mix_context_u: unsup_batches[0],
                    acc_u: unsup_batches[1],
                    drums_u: unsup_batches[2]
                })
            writer.add_summary(_unsup_summaries, global_step=_global_step)
        else:
            # PURELY SUPERVISED TRAINING
            _, _sup_summaries = sess.run([sup_separator_solver, sup_summaries],
                                         feed_dict={
                                             mix_context: sup_batch[0],
                                             acc: sup_batch[1],
                                             drums: sup_batch[2]
                                         })
        writer.add_summary(_sup_summaries, global_step=_global_step)

        # Increment step counter, check if maximum iterations per epoch is
        # achieved and stop in that case
        _global_step = sess.run(increment_global_step)
        if _global_step - _init_step > model_config["epoch_it"]:
            run = False
            print("Finished training phase, stopping batch generators")
            sup_batch_gen.stop_workers()
            if unsup_dataset is not None:
                for gen in unsup_batch_gens:
                    gen.stop_workers()

    # Epoch finished - Save model
    print("Finished epoch!")
    save_path = saver.save(sess,
                           model_config["model_base_dir"] + os.path.sep +
                           model_folder + os.path.sep + model_folder,
                           global_step=int(_global_step))

    # Close session, clear computational graph
    writer.flush()
    writer.close()
    sess.close()
    tf.reset_default_graph()

    return save_path
def bss_evaluate(model_config, dataset, load_model):
    '''
    Calculates source separation evaluation metrics of a given separator model
    on the test set using BSS-Eval
    :param model_config: Separation network configuration required to build
           symbolic computation graph of network
    :param dataset: Test dataset
    :param load_model: Path to separator model checkpoint containing the
           network weights
    :return: Dict containing evaluation metrics, wav files of predicted drum
             and acc are written to results directory
    '''
    # Determine input and output shapes, if we use U-net as separator.
    # NOTE: "//" (integer division) is required - float shapes would break the
    # shape computation under Python 3.
    track_number = 1
    freq_bins = model_config["num_fft"] // 2 + 1  # Make even number of freq bins
    disc_input_shape = [1, freq_bins - 1, model_config["num_frames"],
                        1]  # Shape of discriminator input
    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Placeholders and input normalisation
    input_ph, queue, [mix_context, acc, drums] = Input.get_multitrack_input(
        sep_output_shape[1:],
        1,
        name="input_batch",
        input_shape=sep_input_shape[1:])
    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(mix), Input.norm(mix_context), \
                                                       Input.norm(acc), Input.norm(drums)

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)

    # Start session and queue input threads
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load pretrained model weights for testing
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    # Output directory for predicted wav files and score pickle.
    # Defined (and created) before the loop so the pickle dump below works
    # even when the dataset slice is empty.
    dirpath = os.path.join(os.getcwd(), 'results')
    os.makedirs(dirpath, exist_ok=True)

    # Initialize total score object
    song_scores = list()
    for multitrack in dataset[8:15]:
        filename = multitrack[0].path
        # BUGFIX: track_number is an int - must be converted before
        # concatenating with strings (was a TypeError).
        print("Evaluating file: " + filename + "Track number: " +
              str(track_number))
        if "DSD100" in filename:
            db = "DSD100"
        elif "Kala" in filename:
            db = "IKala"
        elif "ccmixter" in filename:
            db = "CCMixter"
        elif "MedleyDB" in filename:
            db = "MedleyDB"
        else:
            db = 'musdb18orNIStems'
        song_info = {"Title": filename, "Database": db}

        # Load mixture and pad it so that output sources have the same length
        # after STFT/ISTFT
        mix_audio, mix_sr = librosa.load(multitrack[0].path,
                                         sr=model_config["expected_sr"])
        mix_length = len(mix_audio)
        # Pad input so that ISTFT later leads to same-length audio
        mix_audio_pad = librosa.util.fix_length(
            mix_audio, mix_length + model_config["num_fft"] // 2)
        mix_mag, mix_ph = Input.audioFileToSpectrogram(mix_audio_pad,
                                                       model_config["num_fft"],
                                                       model_config["num_hop"])
        source_time_frames = mix_mag.shape[1]

        # Preallocate source predictions (same shape as input mixture)
        acc_pred_mag = np.zeros(mix_mag.shape, np.float32)
        drums_pred_mag = np.zeros(mix_mag.shape, np.float32)

        input_time_frames = sep_input_shape[2]
        output_time_frames = sep_output_shape[2]

        # Pad mix spectrogram across time at beg and end so network can make
        # prediction at the beginning and end of signal.
        # BUGFIX: "//" - np.pad requires integer pad widths; "/" yields a
        # float under Python 3.
        pad_time_frames = (input_time_frames - output_time_frames) // 2
        mix_mag = np.pad(mix_mag, [(0, 0),
                                   (pad_time_frames, pad_time_frames)],
                         mode="constant",
                         constant_values=0.0)

        # Iterate over mixture magnitudes, fetch network prediction
        for source_pos in range(0, source_time_frames, output_time_frames):
            # If output patch reaches over the end of source spectrogram, set
            # it so we predict at end of the output, then stop
            if source_pos + output_time_frames > source_time_frames:
                source_pos = source_time_frames - output_time_frames

            # Prepare mixture excerpt by selecting time interval
            mix_mag_part = mix_mag[:, source_pos:source_pos +
                                   input_time_frames]
            mix_mag_part = Utils.pad_freqs(
                mix_mag_part, sep_input_shape[1:3])  # Pad along frequency axis
            mix_mag_part = mix_mag_part[np.newaxis, :, :, np.newaxis]

            # Fetch network prediction
            acc_mag_part, drums_mag_part = sess.run(
                [separator_acc, separator_drums],
                feed_dict={mix_context: mix_mag_part})

            # Save predictions (drop the extra padded frequency bin)
            acc_pred_mag[:, source_pos:source_pos +
                         output_time_frames] = acc_mag_part[0, :-1, :, 0]
            drums_pred_mag[:, source_pos:source_pos +
                           output_time_frames] = drums_mag_part[0, :-1, :, 0]

        # Spectrograms to audio, using mixture phase
        acc_pred_audio = Input.spectrogramToAudioFile(
            acc_pred_mag,
            model_config["num_fft"],
            model_config["num_hop"],
            phase=mix_ph,
            length=mix_length,
            phaseIterations=0)
        drums_pred_audio = Input.spectrogramToAudioFile(
            drums_pred_mag,
            model_config["num_fft"],
            model_config["num_hop"],
            phase=mix_ph,
            length=mix_length,
            phaseIterations=0)

        # Load original sources; a float entry marks a missing source (silence)
        if isinstance(multitrack[1], float):
            acc_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            acc_audio, _ = librosa.load(multitrack[1].path,
                                        sr=model_config["expected_sr"])
        if isinstance(multitrack[2], float):
            drum_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            drum_audio, _ = librosa.load(multitrack[2].path,
                                         sr=model_config["expected_sr"])

        # Check if any reference source is completely silent, if so, inject
        # some very slight noise to avoid problems during SDR
        reference_zero = False
        if np.max(np.abs(acc_audio)) == 0.0:
            acc_audio += np.random.uniform(-1e-10, 1e-10,
                                           size=acc_audio.shape)
            reference_zero = True
        if np.max(np.abs(drum_audio)) == 0.0:
            drum_audio += np.random.uniform(-1e-10, 1e-10,
                                            size=drum_audio.shape)
            reference_zero = True

        # Evaluate BSS according to MIREX separation method
        # http://www.music-ir.org/mirex/wiki/2016:Singing_Voice_Separation
        ref_sources = np.vstack([acc_audio, drum_audio])
        pred_sources = np.vstack([acc_pred_audio, drums_pred_audio])
        validate(ref_sources, pred_sources)
        scores = bss_eval_sources(ref_sources,
                                  pred_sources,
                                  compute_permutation=False)
        song_info["SDR"] = scores[0]
        song_info["SIR"] = scores[1]
        song_info["SAR"] = scores[2]

        # Compute reference scores and SNR only if both sources are not
        # silent, since they are undefined otherwise
        if not reference_zero:
            mix_ref = np.vstack([mix_audio, mix_audio])
            mix_scores = bss_eval_sources(ref_sources,
                                          mix_ref,
                                          compute_permutation=False)
            norm_scores = np.array(scores) - np.array(mix_scores)

            # Compute SNR: 10 log_10 ( ||s_target||^2 / ||s_target - alpha * s_estimate||^2 ),
            # scale target for optimal SNR
            drums_snr = alpha_snr(drum_audio, drums_pred_audio)
            acc_snr = alpha_snr(acc_audio, acc_pred_audio)
            drums_ref_snr = alpha_snr(drum_audio, mix_audio)
            acc_ref_snr = alpha_snr(acc_audio, mix_audio)

            song_info["NSDR"] = norm_scores[0]
            song_info["NSIR"] = norm_scores[1]
            song_info["SNR"] = np.array([acc_snr, drums_snr])
            song_info["NSNR"] = np.array(
                [acc_snr - acc_ref_snr, drums_snr - drums_ref_snr])

        song_scores.append(song_info)
        print(song_info)

        try:
            # BUGFIX: was "dirpath + 'tr_' + track_number + ..." which both
            # concatenated str + int (TypeError) and dropped the path
            # separator, so no wav was ever written.
            fn = os.path.join(dirpath,
                              "tr_" + str(track_number) + "_drums.wav")
            librosa.output.write_wav(fn,
                                     drums_pred_audio,
                                     sr=model_config["expected_sr"])
            fn = os.path.join(dirpath,
                              "tr_" + str(track_number) + "_acc.wav")
            librosa.output.write_wav(fn,
                                     acc_pred_audio,
                                     sr=model_config["expected_sr"])
        except Exception as e:
            print("Failed to write wav files, error: " + str(e))

        track_number = track_number + 1

    try:
        # NOTE(review): experiment_id is not defined in this function - it is
        # presumably a module-level global; verify it exists, otherwise this
        # dump silently fails with a NameError caught below.
        with open(dirpath + '/' + str(experiment_id) + "BSS_eval.pkl",
                  "wb") as file:
            pickle.dump(song_scores, file)
    except Exception as e:
        print("Failed to write score files, error: " + str(e))

    # Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()

    return song_scores