def predict(track):
    '''
    Function in accordance with MUSDB evaluation API. Takes MUSDB track object and computes corresponding source estimates,
    as well as calls the museval evaluation script.
    Model has to be saved beforehand into the pickle file "prediction_params.pkl" containing
    the model configuration dictionary and the checkpoint path!

    :param track: Track object; its audio (expected in (n_samples, n_channels) shape) and rate attributes are read
    :return: Source estimates dictionary, mapping each source name to its estimated audio
    :raises NotImplementedError: If model_config["network"] is neither "unet" nor "unet_spectrogram"
    '''
    # Load model hyper-parameters and model checkpoint path.
    # Pickle streams are binary, so the file is opened in "rb" mode (text mode breaks under Python 3).
    # NOTE(review): unpickling can execute arbitrary code - only use a parameter file from a trusted source
    with open("prediction_params.pkl", "rb") as params_file:
        [model_config, load_model] = pickle.load(params_file)

    # Determine input and output shapes, if we use U-net as separator
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            mono=model_config["mono_downmix"],
            num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    # Only the mixture placeholder is needed for prediction, the source placeholders are ignored
    mix_context, _ = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"],
                                                       sep_input_shape, "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, reuse=False)

    # Start session. The try/finally guarantees the session is closed and the graph reset
    # even when restoring, predicting or evaluating fails
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())

        # Load pretrained model
        restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
        print("Num of variables" + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored for song prediction')

        # Audio has (n_samples, n_channels) shape
        mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[1]
        separator_preds = predict_track(model_config, sess, mix_audio, orig_sr, sep_input_shape,
                                        sep_output_shape, separator_sources, mix_context)

        # Upsample predicted source audio and convert to stereo.
        # Sample rates are passed by keyword, which newer librosa versions require
        pred_audio = [librosa.resample(pred.T, orig_sr=model_config["expected_sr"], target_sr=orig_sr).T
                      for pred in separator_preds]

        if model_config["mono_downmix"] and mix_channels > 1:
            # Convert to multichannel if mixture input was multichannel by duplicating mono estimate
            pred_audio = [np.tile(pred, [1, mix_channels]) for pred in pred_audio]

        # Set estimates depending on estimation task (voice or multi-instrument separation)
        if model_config["task"] == "voice":  # [acc, vocals] order
            estimates = {
                'vocals': pred_audio[1],
                'accompaniment': pred_audio[0]
            }
        else:  # [bass, drums, other, vocals]
            estimates = {
                'bass': pred_audio[0],
                'drums': pred_audio[1],
                'other': pred_audio[2],
                'vocals': pred_audio[3]
            }

        # Evaluate using museval
        scores = museval.eval_mus_track(
            track, estimates, output_dir="/mnt/daten/Datasets/MUSDB18/eval",
            # SiSec should use longer win and hop parameters here to make evaluation more stable!
        )

        # print nicely formatted mean scores
        print(scores)
    finally:
        # Close session, clear computational graph
        sess.close()
        tf.reset_default_graph()

    return estimates
def train(model_config, experiment_id, sup_dataset, load_model=None):
    '''
    Trains the separator network in a supervised fashion for one epoch and saves a checkpoint afterwards.
    Training runs until more than model_config["epoch_it"] steps have been taken since the start of this call.

    :param model_config: Model configuration dictionary (network type, shapes, learning rate, log and model directories, ...)
    :param experiment_id: Identifier of the experiment, used as name of the log and checkpoint sub-folders
    :param sup_dataset: Dataset that is handed to the paired batch generator
    :param load_model: Optional checkpoint path to restore the variables from before training
    :return: Path of the checkpoint written at the end of the epoch
    :raises NotImplementedError: If model_config["network"] is neither "unet" nor "unet_spectrogram"
    '''
    # Determine input and output shapes
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            mono=model_config["mono_downmix"],
            num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Creating the batch generators
    assert((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)
    # Input context that the input audio has to be padded ON EACH SIDE
    pad_durations = np.array([float((sep_input_shape[1] - sep_output_shape[1])/2), 0, 0]) / float(model_config["expected_sr"])
    sup_batch_gen = batchgen.BatchGen_Paired(
        model_config,
        sup_dataset,
        sep_input_shape,
        sep_output_shape,
        pad_durations[0]
    )

    print("Starting worker")
    sup_batch_gen.start_workers()
    print("Started worker!")

    # From here on, resources are tracked so that the finally-block below can release them
    # no matter where graph construction or training fails
    workers_running = True
    sess = None
    writer = None
    try:
        # Placeholders and input normalisation
        mix_context, sources = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"],
                                                                 sep_input_shape, "sup")
        #tf.summary.audio("mix", mix_context, 22050, collections=["sup"])
        # Not used further in this function; call retained so the constructed graph stays unchanged
        mix = Utils.crop(mix_context, sep_output_shape)

        print("Training...")

        # BUILD MODELS
        # Separator
        # Sources are output in order [acc, voice] for voice separation, [bass, drums, other, vocals] for multi-instrument separation
        separator_sources = separator_func(mix_context, True, not model_config["raw_audio_loss"], reuse=False)

        # Supervised objective: L1 on STFT magnitudes for the spectrogram network (unless raw audio loss is requested),
        # MSE on the raw outputs otherwise
        separator_loss = 0
        for (real_source, sep_source) in zip(sources, separator_sources):
            if model_config["network"] == "unet_spectrogram" and not model_config["raw_audio_loss"]:
                window = functools.partial(window_ops.hann_window, periodic=True)
                stfts = tf.contrib.signal.stft(tf.squeeze(real_source, 2), frame_length=1024, frame_step=768,
                                               fft_length=1024, window_fn=window)
                real_mag = tf.abs(stfts)
                separator_loss += tf.reduce_mean(tf.abs(real_mag - sep_source))
            else:
                separator_loss += tf.reduce_mean(tf.square(real_source - sep_source))
        separator_loss = separator_loss / float(len(sources))  # Normalise by number of sources

        # TRAINING CONTROL VARIABLES
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                      trainable=False, dtype=tf.int64)
        increment_global_step = tf.assign(global_step, global_step + 1)

        # Set up optimizers
        separator_vars = Utils.getTrainableVariables("separator")
        print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))
        print("Num of variables" + str(len(tf.global_variables())))

        # Run the collected update ops together with every optimisation step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            with tf.variable_scope("separator_solver"):
                separator_solver = tf.train.AdamOptimizer(
                    learning_rate=model_config["init_sup_sep_lr"]).minimize(separator_loss, var_list=separator_vars)

        # SUMMARIES
        tf.summary.scalar("sep_loss", separator_loss, collections=["sup"])
        sup_summaries = tf.summary.merge_all(key='sup')

        # Start session and queue input threads
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep + str(experiment_id), graph=sess.graph)

        # CHECKPOINTING
        # Load pretrained model to continue training, if we are supposed to
        if load_model is not None:
            restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
            print("Num of variables" + str(len(tf.global_variables())))
            restorer.restore(sess, load_model)
            print('Pre-trained model restored from file ' + load_model)

        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)

        # Start training loop
        run = True
        _global_step = sess.run(global_step)
        _init_step = _global_step
        while run:
            # TRAIN SEPARATOR
            sup_batch = sup_batch_gen.get_batch()
            # First batch entry is the mixture, the remaining ones are the sources in placeholder order
            feed = dict(zip(sources, sup_batch[1:]))
            feed[mix_context] = sup_batch[0]
            _, _sup_summaries = sess.run([separator_solver, sup_summaries], feed)
            writer.add_summary(_sup_summaries, global_step=_global_step)

            # Increment step counter, check if maximum iterations per epoch is achieved and stop in that case
            _global_step = sess.run(increment_global_step)
            if _global_step - _init_step > model_config["epoch_it"]:
                run = False
                print("Finished training phase, stopping batch generators")
                # Flag is cleared first so the finally-block never stops the workers a second time
                workers_running = False
                sup_batch_gen.stop_workers()

        # Epoch finished - Save model
        print("Finished epoch!")
        save_path = saver.save(sess,
                               model_config["model_base_dir"] + os.path.sep + str(experiment_id) + os.path.sep + str(experiment_id),
                               global_step=int(_global_step))
        writer.flush()
    finally:
        # Stop the workers if training was aborted, then close writer and session and clear the computational graph
        if workers_running:
            sup_batch_gen.stop_workers()
        if writer is not None:
            writer.close()
        if sess is not None:
            sess.close()
        tf.reset_default_graph()

    return save_path
def test(model_config, audio_list, model_folder, load_model):
    '''
    Evaluates a trained separator on a list of tracks and logs the resulting mean loss as "test_loss" summary.
    For the spectrogram network without raw audio loss, the loss is the absolute difference of STFT magnitudes,
    otherwise it is the squared error between ground truth and predicted source audio.

    :param model_config: Model configuration dictionary
    :param audio_list: List of samples; for each sample, entry 0 is the mixture and the remaining entries are the sources (each offering a path attribute)
    :param model_folder: Name of the log sub-folder the test summary is written to
    :param load_model: Checkpoint path of the model to evaluate
    :return: Mean loss over all entries of all sources of all tracks
    :raises NotImplementedError: If model_config["network"] is neither "unet" nor "unet_spectrogram"
    '''
    # Determine input and output shapes
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            mono=model_config["mono_downmix"],
            num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Input and output length have to differ by an even number of samples
    assert ((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    # Only the mixture placeholder is needed for testing, the source placeholders are ignored
    mix_context, _ = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"],
                                                       sep_input_shape, "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, False, reuse=False)

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int64)

    # Start session. The try/finally guarantees that writer and session are closed and the graph is reset
    # even when restoring or evaluating fails
    sess = tf.Session()
    writer = None
    try:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep + model_folder, graph=sess.graph)

        # CHECKPOINTING
        # Load pretrained model to test
        restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
        print("Num of variables" + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored for testing')

        # Small graph computing the STFT magnitude of a single-channel signal, used for the spectrogram loss
        input_audio = tf.placeholder(tf.float32, shape=[None, 1])
        window = functools.partial(window_ops.hann_window, periodic=True)
        stft = tf.contrib.signal.stft(tf.squeeze(input_audio, 1), frame_length=1024, frame_step=768,
                                      fft_length=1024, window_fn=window)
        mag = tf.abs(stft)

        # Start evaluation loop
        _global_step = sess.run(global_step)
        print("Starting!")

        total_loss = 0.0
        total_samples = 0
        for sample in audio_list:  # Go through all tracks
            # Load mixture and fetch prediction for mixture
            mix_audio, mix_sr = Utils.load(sample[0].path, sr=None, mono=False)
            sources_pred = Evaluate.predict_track(model_config, sess, mix_audio, mix_sr, sep_input_shape,
                                                  sep_output_shape, separator_sources, mix_context)

            # Load original sources
            sources_gt = list()
            for s in sample[1:]:
                s_audio, _ = Utils.load(s.path, sr=model_config["expected_sr"],
                                        mono=model_config["mono_downmix"], res_type="kaiser_fast")
                sources_gt.append(s_audio)

            # Accumulate the loss over all sources of this track
            for (source_gt, source_pred) in zip(sources_gt, sources_pred):
                if model_config["network"] == "unet_spectrogram" and not model_config["raw_audio_loss"]:
                    real_mag = sess.run(mag, feed_dict={input_audio: source_gt})
                    pred_mag = sess.run(mag, feed_dict={input_audio: source_pred})
                    total_loss += np.sum(np.abs(real_mag - pred_mag))
                    total_samples += np.prod(real_mag.shape)  # Number of entries is product of number of sources and number of outputs per source
                else:
                    total_loss += np.sum(np.square(source_gt - source_pred))
                    total_samples += np.prod(source_gt.shape)  # Number of entries is product of number of sources and number of outputs per source

            # Note that this is the running mean over all tracks processed so far
            print("MSE for track " + sample[0].path + ": " + str(total_loss / float(total_samples)))

        mean_mse_loss = total_loss / float(total_samples)
        summary = tf.Summary(value=[tf.Summary.Value(tag="test_loss", simple_value=mean_mse_loss)])
        writer.add_summary(summary, global_step=_global_step)
        writer.flush()
    finally:
        # Close writer and session, clear computational graph
        if writer is not None:
            writer.close()
        sess.close()
        tf.reset_default_graph()

    print("Finished testing - Mean MSE: " + str(mean_mse_loss))

    return mean_mse_loss
def train(model_config, sup_dataset, model_folder, unsup_dataset=None, load_model=None):
    '''
    Trains the spectrogram U-Net separator for one epoch and saves a checkpoint afterwards.
    Without unsup_dataset, training is purely supervised. With unsup_dataset, two WGAN critics (for accompaniment
    and drums) are trained in addition, and the separator minimises the sum of supervised and unsupervised loss.
    Training runs until more than model_config["epoch_it"] steps have been taken since the start of this call.

    :param model_config: Model configuration dictionary
    :param sup_dataset: Dataset that is handed to the paired (supervised) batch generator
    :param model_folder: Name of the log and checkpoint sub-folders
    :param unsup_dataset: Optional sequence of three datasets; entry 0 feeds the mixture input, entries 1 and 2 feed the accompaniment and drums critics
    :param load_model: Optional checkpoint path to restore the variables from before training
    :return: Path of the checkpoint written at the end of the epoch
    '''
    # Determine input and output shapes
    # Floor division keeps the bin count an integer under Python 3 as well (true division would yield a float shape entry)
    freq_bins = model_config["num_fft"] // 2 + 1
    # Dropping one bin makes the number of freq bins even
    disc_input_shape = [model_config["batch_size"], freq_bins - 1, model_config["num_frames"], 1]  # Shape of discriminator input

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch input workers
    # Creating the batch generators
    # Input context that the input audio has to be padded with while reading audio files
    padding_durations = [
        float(sep_input_shape[2] - sep_output_shape[2]) * model_config["num_hop"] / model_config["expected_sr"] / 2.0,
        0,
        0
    ]
    sup_batch_gen = batchgen.BatchGen_Paired(model_config, sup_dataset, sep_input_shape, sep_output_shape,
                                             padding_durations[0])

    # Creating unsupervised batch generators if needed: mixture generator uses the input shape, the source generators the output shape
    if unsup_dataset is not None:
        unsup_batch_gens = [
            batchgen.BatchGen_Single(model_config, unsup_dataset[i],
                                     sep_input_shape if i == 0 else sep_output_shape,
                                     padding_durations[i])
            for i in range(3)
        ]

    # Generators whose workers are currently running; the finally-block below stops whatever is left in here
    started_gens = []
    sess = None
    writer = None
    try:
        print("Starting worker")
        sup_batch_gen.start_workers()
        started_gens.append(sup_batch_gen)
        print("Started worker!")

        if unsup_dataset is not None:
            for gen in unsup_batch_gens:
                print("Starting worker")
                gen.start_workers()
                started_gens.append(gen)
                print("Started worker!")

        # Placeholders and input normalisation
        mix_context, acc, drums = Input.get_multitrack_placeholders(sep_output_shape, sep_input_shape, "sup")
        mix = Input.crop(mix_context, sep_output_shape)
        mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(mix), Input.norm(mix_context), Input.norm(acc), Input.norm(drums)

        if unsup_dataset is not None:
            mix_context_u, acc_u, drums_u = Input.get_multitrack_placeholders(sep_output_shape, sep_input_shape, "unsup")
            mix_u = Input.crop(mix_context_u, sep_output_shape)
            mix_norm_u, mix_context_norm_u, acc_norm_u, drums_norm_u = Input.norm(mix_u), Input.norm(mix_context_u), Input.norm(acc_u), Input.norm(drums_u)

        print("Training...")

        # BUILD MODELS
        # Separator
        separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm, reuse=False)
        separator_acc, separator_drums = Input.denorm(separator_acc_norm), Input.denorm(separator_drums_norm)

        if unsup_dataset is not None:
            # Same separator (shared variables) applied to the unsupervised mixture
            separator_acc_norm_u, separator_drums_norm_u = separator_func(mix_context_norm_u, reuse=True)
            separator_acc_u, separator_drums_u = Input.denorm(separator_acc_norm_u), Input.denorm(separator_drums_norm_u)
            mask_loss_u = tf.reduce_mean(tf.square(mix_u - separator_acc_u - separator_drums_u))
        mask_loss = tf.reduce_mean(tf.square(mix - separator_acc - separator_drums))

        # SUMMARIES FOR INPUT AND SEPARATOR
        tf.summary.scalar("mask_loss", mask_loss, collections=["sup", "unsup"])
        if unsup_dataset is not None:
            tf.summary.scalar("mask_loss_u", mask_loss_u, collections=["unsup"])
            tf.summary.scalar("acc_norm_mean_u", tf.reduce_mean(acc_norm_u), collections=["acc_disc"])
            tf.summary.scalar("drums_norm_mean_u", tf.reduce_mean(drums_norm_u), collections=["drums_disc"])
            tf.summary.scalar("acc_sep_norm_mean_u", tf.reduce_mean(separator_acc_norm_u), collections=["acc_disc"])
            tf.summary.scalar("drums_sep_norm_mean_u", tf.reduce_mean(separator_drums_norm_u), collections=["drums_disc"])
        tf.summary.scalar("acc_norm_mean", tf.reduce_mean(acc_norm), collections=['sup'])
        tf.summary.scalar("drums_norm_mean", tf.reduce_mean(drums_norm), collections=['sup'])
        tf.summary.scalar("acc_sep_norm_mean", tf.reduce_mean(separator_acc_norm), collections=['sup'])
        tf.summary.scalar("drums_sep_norm_mean", tf.reduce_mean(separator_drums_norm), collections=['sup'])

        tf.summary.image("sep_acc_norm", separator_acc_norm, collections=["sup", "unsup"])
        tf.summary.image("sep_drums_norm", separator_drums_norm, collections=["sup", "unsup"])

        # BUILD DISCRIMINATORS, if unsupervised training
        unsup_separator_loss = 0
        if unsup_dataset is not None:
            disc_func = Models.WGAN_Critic.dcgan

            # Define real and fake inputs for both discriminators - if separator output and discriminator input shapes
            # do not fit perfectly, we will do a centre crop and only discriminate that part
            acc_real_input = Input.crop(acc_norm_u, disc_input_shape)
            acc_fake_input = Input.crop(separator_acc_norm_u, disc_input_shape)
            drums_real_input = Input.crop(drums_norm_u, disc_input_shape)
            drums_fake_input = Input.crop(separator_drums_norm_u, disc_input_shape)

            # WGAN
            acc_disc_loss, acc_disc_real, acc_disc_fake, acc_grad_pen, acc_wasserstein_dist = \
                Models.WGAN_Critic.create_critic(model_config, real_input=acc_real_input, fake_input=acc_fake_input,
                                                 scope="acc_disc", network_func=disc_func)
            drums_disc_loss, drums_disc_real, drums_disc_fake, drums_grad_pen, drums_wasserstein_dist = \
                Models.WGAN_Critic.create_critic(model_config, real_input=drums_real_input, fake_input=drums_fake_input,
                                                 scope="drums_disc", network_func=disc_func)

            # WGAN based loss for separator (L_u in paper)
            L_u = -tf.reduce_mean(drums_disc_fake) - tf.reduce_mean(acc_disc_fake)
            # Unsupervised loss for separator: WGAN-based loss L_u and additive penalty term (mask loss),
            # weighted by alpha and beta (hyperparameters)
            unsup_separator_loss = model_config["alpha"] * L_u + model_config["beta"] * mask_loss_u

        # Supervised objective: MSE in log-normalized magnitude space
        sup_separator_loss = tf.reduce_mean(tf.square(separator_drums_norm - drums_norm)) + \
                             tf.reduce_mean(tf.square(separator_acc_norm - acc_norm))

        separator_loss = sup_separator_loss + unsup_separator_loss  # Total separator loss: Supervised + unsupervised loss

        # TRAINING CONTROL VARIABLES
        global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                      trainable=False, dtype=tf.int64)
        increment_global_step = tf.assign(global_step, global_step + 1)
        disc_lr = tf.get_variable('disc_lr', [],
                                  initializer=tf.constant_initializer(model_config["init_disc_lr"], dtype=tf.float32),
                                  trainable=False)
        unsup_sep_lr = tf.get_variable('unsup_sep_lr', [],
                                       initializer=tf.constant_initializer(model_config["init_unsup_sep_lr"], dtype=tf.float32),
                                       trainable=False)
        sup_sep_lr = tf.get_variable('sup_sep_lr', [],
                                     initializer=tf.constant_initializer(model_config["init_sup_sep_lr"], dtype=tf.float32),
                                     trainable=False)

        # Set up optimizers
        separator_vars = Utils.getTrainableVariables("separator")
        print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))

        acc_disc_vars, drums_disc_vars = Utils.getTrainableVariables("acc_disc"), Utils.getTrainableVariables("drums_disc")
        print("Drums_Disc_Vars: " + str(Utils.getNumParams(drums_disc_vars)))
        print("Acc_Disc_Vars: " + str(Utils.getNumParams(acc_disc_vars)))

        if unsup_dataset is not None:
            with tf.variable_scope("drums_disc_solver"):
                drums_disc_solver = tf.train.AdamOptimizer(learning_rate=disc_lr).minimize(
                    drums_disc_loss, var_list=drums_disc_vars, colocate_gradients_with_ops=True)
            with tf.variable_scope("acc_disc_solver"):
                acc_disc_solver = tf.train.AdamOptimizer(learning_rate=disc_lr).minimize(
                    acc_disc_loss, var_list=acc_disc_vars, colocate_gradients_with_ops=True)
            with tf.variable_scope("unsup_separator_solver"):
                unsup_separator_solver = tf.train.AdamOptimizer(learning_rate=unsup_sep_lr).minimize(
                    separator_loss, var_list=separator_vars, colocate_gradients_with_ops=True)
        else:
            with tf.variable_scope("separator_solver"):
                sup_separator_solver = (tf.train.AdamOptimizer(learning_rate=sup_sep_lr).minimize(
                    sup_separator_loss, var_list=separator_vars, colocate_gradients_with_ops=True))

        # SUMMARIES FOR DISCRIMINATORS AND LOSSES
        acc_disc_summaries = tf.summary.merge_all(key="acc_disc")
        drums_disc_summaries = tf.summary.merge_all(key="drums_disc")
        tf.summary.scalar("sup_sep_loss", sup_separator_loss, collections=['sup', "unsup"])
        tf.summary.scalar("unsup_sep_loss", unsup_separator_loss, collections=['unsup'])
        tf.summary.scalar("sep_loss", separator_loss, collections=["sup", "unsup"])
        sup_summaries = tf.summary.merge_all(key='sup')
        unsup_summaries = tf.summary.merge_all(key='unsup')

        # Start session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep + model_folder, graph=sess.graph)

        # CHECKPOINTING
        # Load pretrained model to continue training, if we are supposed to
        if load_model is not None:
            restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
            print("Num of variables: " + str(len(tf.global_variables())))
            restorer.restore(sess, load_model)
            print('Pre-trained model restored from file ' + load_model)

        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)

        # Start training loop
        run = True
        _global_step = sess.run(global_step)
        _init_step = _global_step
        it = 0  # Counts discriminator updates, used as step for the discriminator summaries
        while run:
            if unsup_dataset is not None:
                # TRAIN DISCRIMINATORS
                for disc_it in range(model_config["num_disc"]):
                    batches = [gen.get_batch() for gen in unsup_batch_gens]

                    _, _acc_disc_summaries = sess.run(
                        [acc_disc_solver, acc_disc_summaries],
                        feed_dict={mix_context_u: batches[0], acc_u: batches[1]})
                    _, _drums_disc_summaries = sess.run(
                        [drums_disc_solver, drums_disc_summaries],
                        feed_dict={mix_context_u: batches[0], drums_u: batches[2]})

                    writer.add_summary(_acc_disc_summaries, global_step=it)
                    writer.add_summary(_drums_disc_summaries, global_step=it)

                    it += 1

            # TRAIN SEPARATOR
            sup_batch = sup_batch_gen.get_batch()

            if unsup_dataset is not None:
                # SUP + UNSUPERVISED TRAINING
                unsup_batches = [gen.get_batch() for gen in unsup_batch_gens]

                _, _unsup_summaries, _sup_summaries = sess.run(
                    [unsup_separator_solver, unsup_summaries, sup_summaries],
                    feed_dict={
                        mix_context: sup_batch[0], acc: sup_batch[1], drums: sup_batch[2],
                        mix_context_u: unsup_batches[0], acc_u: unsup_batches[1], drums_u: unsup_batches[2]
                    })
                writer.add_summary(_unsup_summaries, global_step=_global_step)
            else:
                # PURELY SUPERVISED TRAINING
                _, _sup_summaries = sess.run(
                    [sup_separator_solver, sup_summaries],
                    feed_dict={mix_context: sup_batch[0], acc: sup_batch[1], drums: sup_batch[2]})
            # Supervised summaries are fetched in both branches, so they are written in both cases
            writer.add_summary(_sup_summaries, global_step=_global_step)

            # Increment step counter, check if maximum iterations per epoch is achieved and stop in that case
            _global_step = sess.run(increment_global_step)

            if _global_step - _init_step > model_config["epoch_it"]:
                run = False
                print("Finished training phase, stopping batch generators")
                # The tracking list is emptied first so the finally-block never stops a generator twice
                gens_to_stop, started_gens = started_gens, []
                for gen in gens_to_stop:
                    gen.stop_workers()

        # Epoch finished - Save model
        print("Finished epoch!")
        save_path = saver.save(sess,
                               model_config["model_base_dir"] + os.path.sep + model_folder + os.path.sep + model_folder,
                               global_step=int(_global_step))
        writer.flush()
    finally:
        # Stop workers that are still running (only the case if training was aborted),
        # then close writer and session and clear the computational graph
        for gen in started_gens:
            gen.stop_workers()
        if writer is not None:
            writer.close()
        if sess is not None:
            sess.close()
        tf.reset_default_graph()

    return save_path
def predict(model_config, input_path, output_path=None, output_name=None, load_model=None):
    '''
    Produces the estimate of the second separator output for an input audio file and writes it to an audio file.
    The input is zero-padded at the end to the next power-of-two length, passed through the network in one go,
    and the output is cropped back to the original length.

    NOTE: If neither output_path nor output_name are given, the output file path equals input_path,
    so the input file gets overwritten.

    :param model_config: Model configuration dictionary
    :param input_path: Path of the input audio file. NOTE(review): the padding logic treats the loaded audio as one-dimensional, so a mono file seems to be expected - confirm
    :param output_path: Folder to write the output file to (created if missing). Defaults to the folder of the input file
    :param output_name: File name of the output file. Defaults to the file name of the input file
    :param load_model: Checkpoint path to restore the variables from. If None, the freshly initialised variables are used
    :raises NotImplementedError: If model_config["network"] is not "unet"
    :raises ValueError: If the input file contains no audio samples
    '''
    print("Producing source estimates for input mixture file " + input_path)

    # Prepare input audio for prediction function
    audio, _ = librosa.load(input_path, sr=model_config["sample_rate"], mono=False)

    num_samples = len(audio)
    if num_samples == 0:
        raise ValueError("Input audio file " + input_path + " contains no samples")

    # Next power of two >= num_samples. Computed with integer arithmetic, since a ratio of float logarithms
    # can end up slightly above an exact integer and would then needlessly double the padded length
    padded_total_len = 1 << (num_samples - 1).bit_length()
    padded_audio = np.pad(audio, (0, padded_total_len - num_samples), 'constant')
    padded_audio = np.reshape(padded_audio, (1, len(padded_audio), 1))

    # Determine input and output shapes
    disc_input_shape = [1, padded_audio.shape[1], 0]  # Shape of input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output
    print(sep_input_shape, sep_output_shape)

    # Input and output length have to differ by an even number of samples
    assert ((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)

    # Placeholders; only the mixture placeholder is needed for prediction
    mix_context, _ = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"],
                                                       sep_input_shape, "sup")

    # BUILD MODELS
    # Separator
    # Sources are output in order [noise, speech] for speech enhancement
    separator_sources = separator_func(mix_context, False, not model_config["raw_audio_loss"], reuse=False)

    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))
    print("Num of variables " + str(len(tf.global_variables())))

    # Start session. The try/finally guarantees the session is closed and the graph reset
    # even when restoring or running the network fails
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        sess.run(tf.global_variables_initializer())

        # CHECKPOINTING
        # Load pretrained model, if we are supposed to
        if load_model is not None:
            restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
            print("Num of variables" + str(len(tf.global_variables())))
            restorer.restore(sess, load_model)
            print('Pre-trained model restored from file ' + load_model)

        separator_sources_value = sess.run([separator_sources], feed_dict={mix_context: padded_audio})
    finally:
        sess.close()
        tf.reset_default_graph()

    # Take the first source output of the first (and only) batch entry and remove the zero-padding again
    output_audio = separator_sources_value[0][0][0, :num_samples, 0]

    # Save source estimate as audio file into output folder
    input_folder, input_filename = os.path.split(input_path)
    if output_name is None:
        output_name = input_filename
    if output_path is None:
        # By default, set it to the input_path folder
        output_path = input_folder
    # An empty folder name refers to the current working directory, which needs no creation.
    # os.makedirs raises on its own if the folder cannot be created
    if output_path and not os.path.exists(output_path):
        print("WARNING: Given output path " + output_path + " does not exist. Trying to create it...")
        os.makedirs(output_path)
    output_filename = os.path.join(output_path, output_name)

    sf.write(output_filename, output_audio, model_config["sample_rate"])
def predict(track, model_config, load_model):
    '''
    Takes audio track and computes corresponding source estimates for the speech enhancement task.

    :param track: Track object; its audio (expected in (n_samples, n_channels) shape) and rate attributes are read
    :param model_config: Model configuration dictionary
    :param load_model: Checkpoint path of the model to restore
    :return: Source estimates dictionary with the keys 'speech' and 'noise'
    :raises NotImplementedError: If model_config["network"] is not "unet"
    '''
    # Determine input and output shapes, if we use U-net as separator
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"], model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    # Only the mixture placeholder is needed for prediction, the source placeholders are ignored
    mix_context, _ = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"],
                                                       sep_input_shape, "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, reuse=False)

    # Start session. The try/finally guarantees the session is closed and the graph reset
    # even when restoring or predicting fails
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())

        # Load pretrained model
        restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
        print("Num of variables" + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored for prediction')

        # Audio has (n_samples, n_channels) shape
        mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[1]
        separator_preds = predict_track(model_config, sess, mix_audio, orig_sr, sep_input_shape,
                                        sep_output_shape, separator_sources, mix_context)

        # Upsample predicted source audio and convert to stereo
        pred_audio = [Utils.resample(pred, model_config["expected_sr"], orig_sr) for pred in separator_preds]

        if model_config["mono_downmix"] and mix_channels > 1:
            # Convert to multichannel if mixture input was multichannel by duplicating mono estimate
            pred_audio = [np.tile(pred, [1, mix_channels]) for pred in pred_audio]

        # Set estimates for source separation task for speech enhancement
        estimates = {  # [noise, speech] order
            'speech': pred_audio[1],
            'noise': pred_audio[0]  # comment-out this line to only yield speech file
        }
    finally:
        # Close session, clear computational graph
        sess.close()
        tf.reset_default_graph()

    return estimates