Example #1
def predict(track):
    '''
    Function in accordance with the MUSDB evaluation API. Takes a MUSDB track object, computes the corresponding source estimates, and calls the evaluation script.
    The model has to be saved beforehand into a pickle file containing the model configuration dictionary and the checkpoint path!
    :param track: Track object
    :return: Source estimates dictionary
    '''
    '''if track.filename[:4] == "test" or int(track.filename[:3]) > 53:
        return {
            'vocals': np.zeros(track.audio.shape),
            'accompaniment': np.zeros(track.audio.shape)
        }'''
    # Load model hyper-parameters and model checkpoint path
    with open("prediction_params.pkl", "r") as file:
        [model_config, load_model] = pickle.load(file)

    # Determine input and output shapes, if we use U-net as separator
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(model_config["num_layers"], model_config["num_initial_filters"],
                                                                   output_type=model_config["output_type"],
                                                                   context=model_config["context"],
                                                                   mono=model_config["mono_downmix"],
                                                                   upsampling=model_config["upsampling"],
                                                                   num_sources=model_config["num_sources"],
                                                                   filter_size=model_config["filter_size"],
                                                                   merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(model_config["num_layers"], model_config["num_initial_filters"],
                                                                       mono=model_config["mono_downmix"],
                                                                       num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    mix_context, sources = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"], sep_input_shape, "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, reuse=False)

    # Start session and queue input threads
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load pretrained model for prediction
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for song prediction')

    mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[1] # Audio has (n_samples, n_channels) shape
    separator_preds = predict_track(model_config, sess, mix_audio, orig_sr, sep_input_shape, sep_output_shape, separator_sources, mix_context)

    # Upsample predicted source audio and convert to stereo
    pred_audio = [librosa.resample(pred.T, model_config["expected_sr"], orig_sr).T for pred in separator_preds]

    if model_config["mono_downmix"] and mix_channels > 1: # Convert to multichannel if mixture input was multichannel by duplicating mono estimate
        pred_audio = [np.tile(pred, [1, mix_channels]) for pred in pred_audio]

    # Set estimates depending on estimation task (voice or multi-instrument separation)
    if model_config["task"] == "voice": # [acc, vocals] order
        estimates = {
            'vocals' : pred_audio[1],
            'accompaniment' : pred_audio[0]
        }
    else: # [bass, drums, other, vocals]
        estimates = {
            'bass' : pred_audio[0],
            'drums' : pred_audio[1],
            'other' : pred_audio[2],
            'vocals' : pred_audio[3]
        }

    # Evaluate using museval
    scores = museval.eval_mus_track(
        track, estimates, output_dir="/mnt/daten/Datasets/MUSDB18/eval", # SiSec should use longer win and hop parameters here to make evaluation more stable!
    )

    # print nicely formatted mean scores
    print(scores)

    # Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()

    return estimates
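
# Usage sketch (not part of the original code): one minimal way to drive predict() above
# through the musdb package, assuming the musdb 0.2.x API that track.audio / track.rate
# imply. load_trained_config() and all paths below are hypothetical placeholders;
# model_config must be the dict the checkpoint was trained with.
import pickle
import musdb

model_config = load_trained_config()                      # hypothetical helper
checkpoint_path = "checkpoints/unet_voice/model-100000"   # hypothetical checkpoint path

# predict() reads its parameters from this pickle file (see its docstring)
with open("prediction_params.pkl", "wb") as f:
    pickle.dump([model_config, checkpoint_path], f)

mus = musdb.DB(root_dir="/path/to/MUSDB18")               # hypothetical dataset location
mus.run(predict, estimates_dir="estimates", subsets="test")
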
Example #2
def train(model_config, experiment_id, sup_dataset, load_model=None):
    # Determine input and output shapes
    disc_input_shape = [model_config["batch_size"], model_config["num_frames"], 0]  # Shape of input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(model_config["num_layers"], model_config["num_initial_filters"],
                                                                   output_type=model_config["output_type"],
                                                                   context=model_config["context"],
                                                                   mono=model_config["mono_downmix"],
                                                                   upsampling=model_config["upsampling"],
                                                                   num_sources=model_config["num_sources"],
                                                                   filter_size=model_config["filter_size"],
                                                                   merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(model_config["num_layers"], model_config["num_initial_filters"],
                                                                       mono=model_config["mono_downmix"],
                                                                       num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Creating the batch generators
    assert((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)
    pad_durations = np.array([float((sep_input_shape[1] - sep_output_shape[1])/2), 0, 0]) / float(model_config["expected_sr"])  # Input context that the input audio has to be padded ON EACH SIDE
    sup_batch_gen = batchgen.BatchGen_Paired(
        model_config,
        sup_dataset,
        sep_input_shape,
        sep_output_shape,
        pad_durations[0]
    )

    print("Starting worker")
    sup_batch_gen.start_workers()
    print("Started worker!")

    # Placeholders and input normalisation
    mix_context, sources = Input.get_multitrack_placeholders(sep_output_shape, model_config["num_sources"], sep_input_shape, "sup")
    #tf.summary.audio("mix", mix_context, 22050, collections=["sup"])
    mix = Utils.crop(mix_context, sep_output_shape)

    print("Training...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, True, not model_config["raw_audio_loss"], reuse=False) # Sources are output in order [acc, voice] for voice separation, [bass, drums, other, vocals] for multi-instrument separation

    # Supervised objective: MSE in log-normalized magnitude space
    separator_loss = 0
    for (real_source, sep_source) in zip(sources, separator_sources):
        if model_config["network"] == "unet_spectrogram" and not model_config["raw_audio_loss"]:
            window = functools.partial(window_ops.hann_window, periodic=True)
            stfts = tf.contrib.signal.stft(tf.squeeze(real_source, 2), frame_length=1024, frame_step=768,
                                           fft_length=1024, window_fn=window)
            real_mag = tf.abs(stfts)
            separator_loss += tf.reduce_mean(tf.abs(real_mag - sep_source))
        else:
            separator_loss += tf.reduce_mean(tf.square(real_source - sep_source))
    separator_loss = separator_loss / float(len(sources)) # Normalise by number of sources

    # TRAINING CONTROL VARIABLES
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False, dtype=tf.int64)
    increment_global_step = tf.assign(global_step, global_step + 1)

    # Set up optimizers
    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))
    print("Num of variables" + str(len(tf.global_variables())))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        with tf.variable_scope("separator_solver"):
            separator_solver = tf.train.AdamOptimizer(learning_rate=model_config["init_sup_sep_lr"]).minimize(separator_loss, var_list=separator_vars)

    # SUMMARIES
    tf.summary.scalar("sep_loss", separator_loss, collections=["sup"])
    sup_summaries = tf.summary.merge_all(key='sup')

    # Start session and queue input threads
    config = tf.ConfigProto()
    config.gpu_options.allow_growth=True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep + str(experiment_id),graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to continue training, if we are supposed to
    if load_model is not None:
        restorer = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)
        print("Num of variables" + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored from file ' + load_model)

    saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V2)

    # Start training loop
    run = True
    _global_step = sess.run(global_step)
    _init_step = _global_step
    it = 0
    while run:
        # TRAIN SEPARATOR
        sup_batch = sup_batch_gen.get_batch()
        feed = {i:d for i,d in zip(sources, sup_batch[1:])}
        feed.update({mix_context : sup_batch[0]})
        _, _sup_summaries = sess.run([separator_solver, sup_summaries], feed)
        writer.add_summary(_sup_summaries, global_step=_global_step)

        # Increment step counter, check if maximum iterations per epoch is achieved and stop in that case
        _global_step = sess.run(increment_global_step)

        if _global_step - _init_step > model_config["epoch_it"]:
            run = False
            print("Finished training phase, stopping batch generators")
            sup_batch_gen.stop_workers()

    # Epoch finished - Save model
    print("Finished epoch!")
    save_path = saver.save(sess, model_config["model_base_dir"] + os.path.sep + str(experiment_id) + os.path.sep + str(experiment_id), global_step=int(_global_step))

    # Close session, clear computational graph
    writer.flush()
    writer.close()
    sess.close()
    tf.reset_default_graph()

    return save_path
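
# Hedged sketch (not from the original code): the hyper-parameter keys that train() above
# actually reads, filled with illustrative values, plus two calls covering a first epoch
# and a continuation from the saved checkpoint. sup_dataset is whatever structure
# batchgen.BatchGen_Paired expects and is assumed to be prepared elsewhere.
model_config = {
    "network": "unet",            # or "unet_spectrogram"
    "batch_size": 16,
    "num_frames": 16384,
    "num_layers": 12,
    "num_initial_filters": 24,
    "filter_size": 15,
    "merge_filter_size": 5,
    "output_type": "direct",
    "context": False,
    "mono_downmix": True,
    "upsampling": "linear",
    "num_sources": 2,             # [acc, vocals] for the voice separation task
    "expected_sr": 22050,
    "raw_audio_loss": True,
    "init_sup_sep_lr": 1e-4,
    "epoch_it": 2000,             # iterations per call to train()
    "log_dir": "logs",
    "model_base_dir": "checkpoints",
}

save_path = train(model_config, experiment_id=1, sup_dataset=sup_dataset)
save_path = train(model_config, experiment_id=1, sup_dataset=sup_dataset,
                  load_model=save_path)   # continue training from the saved checkpoint
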
Example #3
def test(model_config, audio_list, model_folder, load_model):
    # Determine input and output shapes
    disc_input_shape = [
        model_config["batch_size"], model_config["num_frames"], 0
    ]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"],
            model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])
    elif model_config["network"] == "unet_spectrogram":
        separator_class = Models.UnetSpectrogramSeparator.UnetSpectrogramSeparator(
            model_config["num_layers"],
            model_config["num_initial_filters"],
            mono=model_config["mono_downmix"],
            num_sources=model_config["num_sources"])
    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Creating the batch generators
    assert ((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    mix_context, sources = Input.get_multitrack_placeholders(
        sep_output_shape, model_config["num_sources"], sep_input_shape,
        "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, False, reuse=False)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False,
                                  dtype=tf.int64)

    # Start session and queue input threads
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep +
                                   model_folder,
                                   graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to test
    restorer = tf.train.Saver(tf.global_variables(),
                              write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    input_audio = tf.placeholder(tf.float32, shape=[None, 1])
    window = functools.partial(window_ops.hann_window, periodic=True)
    stft = tf.contrib.signal.stft(tf.squeeze(input_audio, 1),
                                  frame_length=1024,
                                  frame_step=768,
                                  fft_length=1024,
                                  window_fn=window)
    mag = tf.abs(stft)

    # Start training loop
    _global_step = sess.run(global_step)
    print("Starting!")

    total_loss = 0.0
    total_samples = 0
    for sample in audio_list:  # Go through all tracks
        # Load mixture and fetch prediction for mixture
        mix_audio, mix_sr = Utils.load(sample[0].path, sr=None, mono=False)
        sources_pred = Evaluate.predict_track(model_config, sess, mix_audio,
                                              mix_sr, sep_input_shape,
                                              sep_output_shape,
                                              separator_sources, mix_context)

        # Load original sources
        sources_gt = list()
        for s in sample[1:]:
            s_audio, _ = Utils.load(s.path,
                                    sr=model_config["expected_sr"],
                                    mono=model_config["mono_downmix"],
                                    res_type="kaiser_fast")
            sources_gt.append(s_audio)

        # Determine mean squared error
        for (source_gt, source_pred) in zip(sources_gt, sources_pred):
            if model_config[
                    "network"] == "unet_spectrogram" and not model_config[
                        "raw_audio_loss"]:
                real_mag = sess.run(mag, feed_dict={input_audio: source_gt})
                pred_mag = sess.run(mag, feed_dict={input_audio: source_pred})
                total_loss += np.sum(np.abs(real_mag - pred_mag))
                total_samples += np.prod(
                    real_mag.shape
                )  # Number of entries is product of number of sources and number of outputs per source
            else:
                total_loss += np.sum(np.square(source_gt - source_pred))
                total_samples += np.prod(
                    source_gt.shape
                )  # Number of entries is product of number of sources and number of outputs per source

        print("MSE for track " + sample[0].path + ": " +
              str(total_loss / float(total_samples)))
    mean_mse_loss = total_loss / float(total_samples)

    summary = tf.Summary(
        value=[tf.Summary.Value(tag="test_loss", simple_value=mean_mse_loss)])
    writer.add_summary(summary, global_step=_global_step)

    writer.flush()
    writer.close()

    print("Finished testing - Mean MSE: " + str(mean_mse_loss))

    # Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()

    return mean_mse_loss
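
# Hedged sketch (not from the original code): one plausible driver that alternates
# train() and test() with simple early stopping, keeping the checkpoint with the lowest
# validation MSE. model_config, train_dataset and val_list are assumed to be prepared
# elsewhere (val_list in the audio_list format that test() expects).
best_mse = float("inf")
best_model = None
model_path = None
worse_epochs = 0
while worse_epochs < 3:   # stop after 3 epochs without improvement
    model_path = train(model_config, "exp1", train_dataset, load_model=model_path)
    mse = test(model_config, val_list, "exp1", model_path)
    if mse < best_mse:
        best_mse, best_model, worse_epochs = mse, model_path, 0
    else:
        worse_epochs += 1
print("Best validation MSE: " + str(best_mse) + " (checkpoint: " + str(best_model) + ")")
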
Example #4
def train(model_config,
          sup_dataset,
          model_folder,
          unsup_dataset=None,
          load_model=None):
    # Determine input and output shapes
    freq_bins = model_config["num_fft"] // 2 + 1  # Number of frequency bins (num_fft/2 + 1); one bin is dropped below to get an even number
    disc_input_shape = [
        model_config["batch_size"], freq_bins - 1, model_config["num_frames"],
        1
    ]  # Shape of discriminator input

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch input workers
    # Creating the batch generators
    padding_durations = [
        float(sep_input_shape[2] - sep_output_shape[2]) *
        model_config["num_hop"] / model_config["expected_sr"] / 2.0, 0, 0
    ]  # Input context that the input audio has to be padded with while reading audio files
    sup_batch_gen = batchgen.BatchGen_Paired(model_config, sup_dataset,
                                             sep_input_shape, sep_output_shape,
                                             padding_durations[0])

    # Creating unsupervised batch generator if needed
    if unsup_dataset is not None:
        unsup_batch_gens = list()
        for i in range(3):
            shape = (sep_input_shape if i == 0 else sep_output_shape)
            unsup_batch_gens.append(
                batchgen.BatchGen_Single(model_config, unsup_dataset[i], shape,
                                         padding_durations[i]))

    print("Starting worker")
    sup_batch_gen.start_workers()
    print("Started worker!")

    if unsup_dataset is not None:
        for gen in unsup_batch_gens:
            print("Starting worker")
            gen.start_workers()
            print("Started worker!")

    # Placeholders and input normalisation
    mix_context, acc, drums = Input.get_multitrack_placeholders(
        sep_output_shape, sep_input_shape, "sup")
    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(
        mix), Input.norm(mix_context), Input.norm(acc), Input.norm(drums)

    if unsup_dataset is not None:
        mix_context_u, acc_u, drums_u = Input.get_multitrack_placeholders(
            sep_output_shape, sep_input_shape, "unsup")
        mix_u = Input.crop(mix_context_u, sep_output_shape)
        mix_norm_u, mix_context_norm_u, acc_norm_u, drums_norm_u = Input.norm(
            mix_u), Input.norm(mix_context_u), Input.norm(acc_u), Input.norm(
                drums_u)

    print("Training...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)
    if unsup_dataset is not None:
        separator_acc_norm_u, separator_drums_norm_u = separator_func(
            mix_context_norm_u, reuse=True)
        separator_acc_u, separator_drums_u = Input.denorm(
            separator_acc_norm_u), Input.denorm(separator_drums_norm_u)
        mask_loss_u = tf.reduce_mean(
            tf.square(mix_u - separator_acc_u - separator_drums_u))
    mask_loss = tf.reduce_mean(tf.square(mix - separator_acc -
                                         separator_drums))

    # SUMMARIES FOR INPUT AND SEPARATOR
    tf.summary.scalar("mask_loss", mask_loss, collections=["sup", "unsup"])
    if unsup_dataset is not None:
        tf.summary.scalar("mask_loss_u", mask_loss_u, collections=["unsup"])
        tf.summary.scalar("acc_norm_mean_u",
                          tf.reduce_mean(acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_norm_mean_u",
                          tf.reduce_mean(drums_norm_u),
                          collections=["drums_disc"])
        tf.summary.scalar("acc_sep_norm_mean_u",
                          tf.reduce_mean(separator_acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_sep_norm_mean_u",
                          tf.reduce_mean(separator_drums_norm_u),
                          collections=["drums_disc"])
    tf.summary.scalar("acc_norm_mean",
                      tf.reduce_mean(acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_norm_mean",
                      tf.reduce_mean(drums_norm),
                      collections=['sup'])
    tf.summary.scalar("acc_sep_norm_mean",
                      tf.reduce_mean(separator_acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_sep_norm_mean",
                      tf.reduce_mean(separator_drums_norm),
                      collections=['sup'])

    tf.summary.image("sep_acc_norm",
                     separator_acc_norm,
                     collections=["sup", "unsup"])
    tf.summary.image("sep_drums_norm",
                     separator_drums_norm,
                     collections=["sup", "unsup"])

    # BUILD DISCRIMINATORS, if unsupervised training
    unsup_separator_loss = 0
    if unsup_dataset is not None:
        disc_func = Models.WGAN_Critic.dcgan

        # Define real and fake inputs for both discriminators - if separator output and dsicriminator input shapes do not fit perfectly, we will do a centre crop and only discriminate that part
        acc_real_input = Input.crop(acc_norm_u, disc_input_shape)
        acc_fake_input = Input.crop(separator_acc_norm_u, disc_input_shape)
        drums_real_input = Input.crop(drums_norm_u, disc_input_shape)
        drums_fake_input = Input.crop(separator_drums_norm_u, disc_input_shape)

        #WGAN
        acc_disc_loss, acc_disc_real, acc_disc_fake, acc_grad_pen, acc_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config, real_input=acc_real_input, fake_input=acc_fake_input, scope="acc_disc", network_func=disc_func)
        drums_disc_loss, drums_disc_real, drums_disc_fake, drums_grad_pen, drums_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config, real_input=drums_real_input, fake_input=drums_fake_input, scope="drums_disc", network_func=disc_func)

        L_u = -tf.reduce_mean(drums_disc_fake) - tf.reduce_mean(
            acc_disc_fake)  # WGAN based loss for separator (L_u in paper)
        unsup_separator_loss = model_config["alpha"] * L_u + model_config[
            "beta"] * mask_loss_u  # Unsupervised loss for separator: WGAN-based loss L_u and additive penalty term (mask loss), weighted by alpha and beta (hyperparameters)

    # Supervised objective: MSE in log-normalized magnitude space
    sup_separator_loss = tf.reduce_mean(tf.square(separator_drums_norm - drums_norm)) + \
                         tf.reduce_mean(tf.square(separator_acc_norm - acc_norm))

    separator_loss = sup_separator_loss + unsup_separator_loss  # Total separator loss: Supervised + unsupervised loss

    # TRAINING CONTROL VARIABLES
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False,
                                  dtype=tf.int64)
    increment_global_step = tf.assign(global_step, global_step + 1)
    disc_lr = tf.get_variable('disc_lr', [],
                              initializer=tf.constant_initializer(
                                  model_config["init_disc_lr"],
                                  dtype=tf.float32),
                              trainable=False)
    unsup_sep_lr = tf.get_variable('unsup_sep_lr', [],
                                   initializer=tf.constant_initializer(
                                       model_config["init_unsup_sep_lr"],
                                       dtype=tf.float32),
                                   trainable=False)
    sup_sep_lr = tf.get_variable('sup_sep_lr', [],
                                 initializer=tf.constant_initializer(
                                     model_config["init_sup_sep_lr"],
                                     dtype=tf.float32),
                                 trainable=False)

    # Set up optimizers
    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))

    acc_disc_vars, drums_disc_vars = Utils.getTrainableVariables(
        "acc_disc"), Utils.getTrainableVariables("drums_disc")
    print("Drums_Disc_Vars: " + str(Utils.getNumParams(drums_disc_vars)))
    print("Acc_Disc_Vars: " + str(Utils.getNumParams(acc_disc_vars)))

    if unsup_dataset is not None:
        with tf.variable_scope("drums_disc_solver"):
            drums_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    drums_disc_loss,
                    var_list=drums_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("acc_disc_solver"):
            acc_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    acc_disc_loss,
                    var_list=acc_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("unsup_separator_solver"):
            unsup_separator_solver = tf.train.AdamOptimizer(
                learning_rate=unsup_sep_lr).minimize(
                    separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True)
    else:
        with tf.variable_scope("separator_solver"):
            sup_separator_solver = (tf.train.AdamOptimizer(
                learning_rate=sup_sep_lr).minimize(
                    sup_separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True))

    # SUMMARIES FOR DISCRIMINATORS AND LOSSES
    acc_disc_summaries = tf.summary.merge_all(key="acc_disc")
    drums_disc_summaries = tf.summary.merge_all(key="drums_disc")
    tf.summary.scalar("sup_sep_loss",
                      sup_separator_loss,
                      collections=['sup', "unsup"])
    tf.summary.scalar("unsup_sep_loss",
                      unsup_separator_loss,
                      collections=['unsup'])
    tf.summary.scalar("sep_loss", separator_loss, collections=["sup", "unsup"])
    sup_summaries = tf.summary.merge_all(key='sup')
    unsup_summaries = tf.summary.merge_all(key='unsup')

    # Start session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep +
                                   model_folder,
                                   graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to continue training, if we are supposed to
    if load_model is not None:
        restorer = tf.train.Saver(tf.global_variables(),
                                  write_version=tf.train.SaverDef.V2)
        print("Num of variables: " + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored from file ' + load_model)

    saver = tf.train.Saver(tf.global_variables(),
                           write_version=tf.train.SaverDef.V2)

    # Start training loop
    run = True
    _global_step = sess.run(global_step)
    _init_step = _global_step
    it = 0
    while run:
        if unsup_dataset is not None:
            # TRAIN DISCRIMINATORS
            for disc_it in range(model_config["num_disc"]):
                batches = list()
                for gen in unsup_batch_gens:
                    batches.append(gen.get_batch())

                _, _acc_disc_summaries = sess.run(
                    [acc_disc_solver, acc_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        acc_u: batches[1]
                    })

                _, _drums_disc_summaries = sess.run(
                    [drums_disc_solver, drums_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        drums_u: batches[2]
                    })

                writer.add_summary(_acc_disc_summaries, global_step=it)
                writer.add_summary(_drums_disc_summaries, global_step=it)

                it += 1

        # TRAIN SEPARATOR
        sup_batch = sup_batch_gen.get_batch()

        if unsup_dataset is not None:
            # SUP + UNSUPERVISED TRAINING
            unsup_batches = list()
            for gen in unsup_batch_gens:
                unsup_batches.append(gen.get_batch())

            _, _unsup_summaries, _sup_summaries = sess.run(
                [unsup_separator_solver, unsup_summaries, sup_summaries],
                feed_dict={
                    mix_context: sup_batch[0],
                    acc: sup_batch[1],
                    drums: sup_batch[2],
                    mix_context_u: unsup_batches[0],
                    acc_u: unsup_batches[1],
                    drums_u: unsup_batches[2]
                })
            writer.add_summary(_unsup_summaries, global_step=_global_step)
        else:
            # PURELY SUPERVISED TRAINING
            _, _sup_summaries = sess.run([sup_separator_solver, sup_summaries],
                                         feed_dict={
                                             mix_context: sup_batch[0],
                                             acc: sup_batch[1],
                                             drums: sup_batch[2]
                                         })
            writer.add_summary(_sup_summaries, global_step=_global_step)

        # Increment step counter, check if maximum iterations per epoch is achieved and stop in that case
        _global_step = sess.run(increment_global_step)

        if _global_step - _init_step > model_config["epoch_it"]:
            run = False
            print("Finished training phase, stopping batch generators")
            sup_batch_gen.stop_workers()

            if unsup_dataset is not None:
                for gen in unsup_batch_gens:
                    gen.stop_workers()

    # Epoch finished - Save model
    print("Finished epoch!")
    save_path = saver.save(sess,
                           model_config["model_base_dir"] + os.path.sep +
                           model_folder + os.path.sep + model_folder,
                           global_step=int(_global_step))

    # Close session, clear computational graph
    writer.flush()
    writer.close()
    sess.close()
    tf.reset_default_graph()

    return save_path
def predict(model_config,
            input_path,
            output_path=None,
            output_name=None,
            load_model=None):
    # Determine input and output shapes

    print("Producing source estimates for input mixture file " + input_path)
    # Prepare input audio for prediction function
    #     audio, sr = Utils.load(input_path, sr=None, mono=False)
    audio, sr = librosa.load(input_path,
                             sr=model_config["sample_rate"],
                             mono=False)
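    # Note: the padding/reshaping below assumes a mono 1-D signal; with mono=False,
    # librosa.load returns a (channels, samples) array for multi-channel files, which
    # this code does not handle.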
    expon = math.ceil(np.log(len(audio)) / np.log(2))
    padded_total_len = int(2**expon)
    padded_len = padded_total_len - len(audio)
    padded_audio = np.pad(audio, (0, padded_len), 'constant')
    padded_audio = np.reshape(padded_audio, (1, len(padded_audio), 1))

    disc_input_shape = [1, padded_audio.shape[1], 0]  # Shape of input

    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"],
            model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])

    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output
    print(sep_input_shape, sep_output_shape)
    # Creating the batch generators
    assert ((sep_input_shape[1] - sep_output_shape[1]) % 2 == 0)

    # Placeholders and input normalisation
    mix_context, sources = Input.get_multitrack_placeholders(
        sep_output_shape, model_config["num_sources"], sep_input_shape, "sup")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(
        mix_context, False, not model_config["raw_audio_loss"], reuse=False
    )  # Sources are output in order [noise, speech] for speech enhancement
    #     separator_loss = tf.reduce_mean(tf.abs(sources - separator_sources[0]))

    # Set up optimizers
    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))
    print("Num of variables " + str(len(tf.global_variables())))

    # Start session and queue input threads
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # CHECKPOINTING
    # Load pretrained model for prediction, if one is given
    if load_model is not None:
        restorer = tf.train.Saver(tf.global_variables(),
                                  write_version=tf.train.SaverDef.V2)
        print("Num of variables" + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored from file ' + load_model)

    separator_sources_value = sess.run([separator_sources],
                                       feed_dict={mix_context: padded_audio})

    sess.close()
    tf.reset_default_graph()

    output_audio = separator_sources_value[0][0][0, :len(audio), 0]

    # Save source estimates as audio files into output dictionary
    input_folder, input_filename = os.path.split(input_path)
    if output_name is None:
        output_name = input_filename
    if output_path is None:
        # By default, set it to the input_path folder
        output_path = input_folder
        output_filename = os.path.join(output_path, output_name)
    else:
        output_filename = os.path.join(output_path, output_name)

    if not os.path.exists(output_path):
        print("WARNING: Given output path " + output_path +
              " does not exist. Trying to create it...")
        os.makedirs(output_path)
    assert (os.path.exists(output_path))

    sf.write(output_filename, output_audio, model_config["sample_rate"])
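
# Hedged usage sketch (not from the original code): enhancing a single noisy recording
# with the speech-enhancement predict() above. The paths and checkpoint below are
# placeholders; model_config must match the configuration the checkpoint was trained with.
predict(model_config,
        input_path="recordings/noisy_speech.wav",
        output_path="enhanced",
        output_name="speech_estimate.wav",
        load_model="checkpoints/speech_unet/model-50000")
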
def predict(track, model_config, load_model):
    '''
    Takes audio track and computes corresponding source estimates.
    :param track: Track object
    :return: Source estimates dictionary
    '''

    # Determine input and output shapes, if we use U-net as separator
    disc_input_shape = [
        model_config["batch_size"], model_config["num_frames"], 0
    ]  # Shape of discriminator input
    if model_config["network"] == "unet":
        separator_class = Models.UnetAudioSeparator.UnetAudioSeparator(
            model_config["num_layers"],
            model_config["num_initial_filters"],
            output_type=model_config["output_type"],
            context=model_config["context"],
            mono=model_config["mono_downmix"],
            upsampling=model_config["upsampling"],
            num_sources=model_config["num_sources"],
            filter_size=model_config["filter_size"],
            merge_filter_size=model_config["merge_filter_size"])

    else:
        raise NotImplementedError

    sep_input_shape, sep_output_shape = separator_class.get_padding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch size of 1
    sep_input_shape[0] = 1
    sep_output_shape[0] = 1

    mix_context, sources = Input.get_multitrack_placeholders(
        sep_output_shape, model_config["num_sources"], sep_input_shape,
        "input")

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_sources = separator_func(mix_context, False, reuse=False)

    # Start session and queue input threads
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load pretrained model for prediction
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for prediction')

    mix_audio, orig_sr, mix_channels = track.audio, track.rate, track.audio.shape[
        1]  # Audio has (n_samples, n_channels) shape
    separator_preds = predict_track(model_config, sess, mix_audio, orig_sr,
                                    sep_input_shape, sep_output_shape,
                                    separator_sources, mix_context)

    # Upsample predicted source audio and convert to stereo
    pred_audio = [
        Utils.resample(pred, model_config["expected_sr"], orig_sr)
        for pred in separator_preds
    ]

    if model_config[
            "mono_downmix"] and mix_channels > 1:  # Convert to multichannel if mixture input was multichannel by duplicating mono estimate
        pred_audio = [np.tile(pred, [1, mix_channels]) for pred in pred_audio]

    # Set estimates for the speech enhancement task
    estimates = {  # [noise, speech] order
        'speech': pred_audio[1],
        'noise': pred_audio[0]  # comment out this line to only yield the speech estimate
    }

    # Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()

    return estimates
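
# Hedged follow-up sketch (not from the original code): writing the estimates dictionary
# returned by the predict() above to disk, one file per source, via the soundfile package
# used elsewhere in this code (imported as sf). track, model_config and the checkpoint
# path are assumed to be set up as for the function above; the estimates are written at
# the track's original sample rate because predict() resamples them back to orig_sr.
estimates = predict(track, model_config, load_model="checkpoints/speech_unet/model-50000")
for source_name, source_audio in estimates.items():
    sf.write(source_name + ".wav", source_audio, track.rate)
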