Example #1

# Imports assumed from the surrounding project: Models, Input, batchgen and Utils
# are project-local modules referenced in the code below.
import os

import numpy as np
import tensorflow as tf

import batchgen
import Input
import Models.Unet
import Models.WGAN_Critic
import Utils


def train(model_config,
          sup_dataset,
          model_folder,
          unsup_dataset=None,
          load_model=None):
    # Determine input and output shapes
    freq_bins = model_config["num_fft"] // 2 + 1  # num_fft/2 + 1 freq bins (an odd number); one bin is dropped below to get an even count
    disc_input_shape = [
        model_config["batch_size"], freq_bins - 1, model_config["num_frames"],
        1
    ]  # Shape of discriminator input

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Batch input workers
    # Creating the batch generators
    padding_durations = [
        float(sep_input_shape[2] - sep_output_shape[2]) *
        model_config["num_hop"] / model_config["expected_sr"] / 2.0, 0, 0
    ]  # Input context that the input audio has to be padded with while reading audio files
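    # Illustrative only (the numbers are assumptions, not taken from model_config):
    # with a context of sep_input_shape[2] - sep_output_shape[2] = 16 frames,
    # num_hop = 256 and expected_sr = 22050 Hz, each side is padded by
    # 16 * 256 / 22050 / 2 ≈ 0.093 s of audio.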
    sup_batch_gen = batchgen.BatchGen_Paired(model_config, sup_dataset,
                                             sep_input_shape, sep_output_shape,
                                             padding_durations[0])

    # Creating unsupervised batch generator if needed
    if unsup_dataset is not None:
        unsup_batch_gens = list()
        for i in range(3):
            shape = (sep_input_shape if i == 0 else sep_output_shape)
            unsup_batch_gens.append(
                batchgen.BatchGen_Single(model_config, unsup_dataset[i], shape,
                                         padding_durations[i]))

    print("Starting worker")
    sup_batch_gen.start_workers()
    print("Started worker!")

    if unsup_dataset is not None:
        for gen in unsup_batch_gens:
            print("Starting worker")
            gen.start_workers()
            print("Started worker!")

    # Placeholders and input normalisation
    mix_context, acc, drums = Input.get_multitrack_placeholders(
        sep_output_shape, sep_input_shape, "sup")
    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(
        mix), Input.norm(mix_context), Input.norm(acc), Input.norm(drums)

    if unsup_dataset is not None:
        mix_context_u, acc_u, drums_u = Input.get_multitrack_placeholders(
            sep_output_shape, sep_input_shape, "unsup")
        mix_u = Input.crop(mix_context_u, sep_output_shape)
        mix_norm_u, mix_context_norm_u, acc_norm_u, drums_norm_u = Input.norm(
            mix_u), Input.norm(mix_context_u), Input.norm(acc_u), Input.norm(
                drums_u)

    print("Training...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)
    if unsup_dataset is not None:
        separator_acc_norm_u, separator_drums_norm_u = separator_func(
            mix_context_norm_u, reuse=True)
        separator_acc_u, separator_drums_u = Input.denorm(
            separator_acc_norm_u), Input.denorm(separator_drums_norm_u)
        mask_loss_u = tf.reduce_mean(
            tf.square(mix_u - separator_acc_u - separator_drums_u))
    mask_loss = tf.reduce_mean(tf.square(mix - separator_acc -
                                         separator_drums))

    # SUMMARIES FOR INPUT AND SEPARATOR
    tf.summary.scalar("mask_loss", mask_loss, collections=["sup", "unsup"])
    if unsup_dataset is not None:
        tf.summary.scalar("mask_loss_u", mask_loss_u, collections=["unsup"])
        tf.summary.scalar("acc_norm_mean_u",
                          tf.reduce_mean(acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_norm_mean_u",
                          tf.reduce_mean(drums_norm_u),
                          collections=["drums_disc"])
        tf.summary.scalar("acc_sep_norm_mean_u",
                          tf.reduce_mean(separator_acc_norm_u),
                          collections=["acc_disc"])
        tf.summary.scalar("drums_sep_norm_mean_u",
                          tf.reduce_mean(separator_drums_norm_u),
                          collections=["drums_disc"])
    tf.summary.scalar("acc_norm_mean",
                      tf.reduce_mean(acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_norm_mean",
                      tf.reduce_mean(drums_norm),
                      collections=['sup'])
    tf.summary.scalar("acc_sep_norm_mean",
                      tf.reduce_mean(separator_acc_norm),
                      collections=['sup'])
    tf.summary.scalar("drums_sep_norm_mean",
                      tf.reduce_mean(separator_drums_norm),
                      collections=['sup'])

    tf.summary.image("sep_acc_norm",
                     separator_acc_norm,
                     collections=["sup", "unsup"])
    tf.summary.image("sep_drums_norm",
                     separator_drums_norm,
                     collections=["sup", "unsup"])

    # BUILD DISCRIMINATORS, if unsupervised training
    unsup_separator_loss = 0
    if unsup_dataset is not None:
        disc_func = Models.WGAN_Critic.dcgan

        # Define real and fake inputs for both discriminators - if the separator output and discriminator input shapes do not fit perfectly, we do a centre crop and only discriminate that part
        acc_real_input = Input.crop(acc_norm_u, disc_input_shape)
        acc_fake_input = Input.crop(separator_acc_norm_u, disc_input_shape)
        drums_real_input = Input.crop(drums_norm_u, disc_input_shape)
        drums_fake_input = Input.crop(separator_drums_norm_u, disc_input_shape)

        #WGAN
        acc_disc_loss, acc_disc_real, acc_disc_fake, acc_grad_pen, acc_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config, real_input=acc_real_input, fake_input=acc_fake_input, scope="acc_disc", network_func=disc_func)
        drums_disc_loss, drums_disc_real, drums_disc_fake, drums_grad_pen, drums_wasserstein_dist = \
            Models.WGAN_Critic.create_critic(model_config, real_input=drums_real_input, fake_input=drums_fake_input, scope="drums_disc", network_func=disc_func)

        L_u = -tf.reduce_mean(drums_disc_fake) - tf.reduce_mean(
            acc_disc_fake)  # WGAN based loss for separator (L_u in paper)
        unsup_separator_loss = model_config["alpha"] * L_u + model_config[
            "beta"] * mask_loss_u  # Unsupervised loss for separator: WGAN-based loss L_u and additive penalty term (mask loss), weighted by alpha and beta (hyperparameters)

    # Supervised objective: MSE in log-normalized magnitude space
    sup_separator_loss = tf.reduce_mean(tf.square(separator_drums_norm - drums_norm)) + \
                         tf.reduce_mean(tf.square(separator_acc_norm - acc_norm))

    separator_loss = sup_separator_loss + unsup_separator_loss  # Total separator loss: Supervised + unsupervised loss

    # TRAINING CONTROL VARIABLES
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False,
                                  dtype=tf.int64)
    increment_global_step = tf.assign(global_step, global_step + 1)
    disc_lr = tf.get_variable('disc_lr', [],
                              initializer=tf.constant_initializer(
                                  model_config["init_disc_lr"],
                                  dtype=tf.float32),
                              trainable=False)
    unsup_sep_lr = tf.get_variable('unsup_sep_lr', [],
                                   initializer=tf.constant_initializer(
                                       model_config["init_unsup_sep_lr"],
                                       dtype=tf.float32),
                                   trainable=False)
    sup_sep_lr = tf.get_variable('sup_sep_lr', [],
                                 initializer=tf.constant_initializer(
                                     model_config["init_sup_sep_lr"],
                                     dtype=tf.float32),
                                 trainable=False)

    # Set up optimizers
    separator_vars = Utils.getTrainableVariables("separator")
    print("Sep_Vars: " + str(Utils.getNumParams(separator_vars)))

    acc_disc_vars, drums_disc_vars = Utils.getTrainableVariables(
        "acc_disc"), Utils.getTrainableVariables("drums_disc")
    print("Drums_Disc_Vars: " + str(Utils.getNumParams(drums_disc_vars)))
    print("Acc_Disc_Vars: " + str(Utils.getNumParams(acc_disc_vars)))

    if unsup_dataset is not None:
        with tf.variable_scope("drums_disc_solver"):
            drums_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    drums_disc_loss,
                    var_list=drums_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("acc_disc_solver"):
            acc_disc_solver = tf.train.AdamOptimizer(
                learning_rate=disc_lr).minimize(
                    acc_disc_loss,
                    var_list=acc_disc_vars,
                    colocate_gradients_with_ops=True)
        with tf.variable_scope("unsup_separator_solver"):
            unsup_separator_solver = tf.train.AdamOptimizer(
                learning_rate=unsup_sep_lr).minimize(
                    separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True)
    else:
        with tf.variable_scope("separator_solver"):
            sup_separator_solver = (tf.train.AdamOptimizer(
                learning_rate=sup_sep_lr).minimize(
                    sup_separator_loss,
                    var_list=separator_vars,
                    colocate_gradients_with_ops=True))

    # SUMMARIES FOR DISCRIMINATORS AND LOSSES
    acc_disc_summaries = tf.summary.merge_all(key="acc_disc")
    drums_disc_summaries = tf.summary.merge_all(key="drums_disc")
    tf.summary.scalar("sup_sep_loss",
                      sup_separator_loss,
                      collections=['sup', "unsup"])
    tf.summary.scalar("unsup_sep_loss",
                      unsup_separator_loss,
                      collections=['unsup'])
    tf.summary.scalar("sep_loss", separator_loss, collections=["sup", "unsup"])
    sup_summaries = tf.summary.merge_all(key='sup')
    unsup_summaries = tf.summary.merge_all(key='unsup')

    # Start session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter(model_config["log_dir"] + os.path.sep +
                                   model_folder,
                                   graph=sess.graph)

    # CHECKPOINTING
    # Load pretrained model to continue training, if we are supposed to
    if load_model is not None:
        restorer = tf.train.Saver(tf.global_variables(),
                                  write_version=tf.train.SaverDef.V2)
        print("Num of variables: " + str(len(tf.global_variables())))
        restorer.restore(sess, load_model)
        print('Pre-trained model restored from file ' + load_model)

    saver = tf.train.Saver(tf.global_variables(),
                           write_version=tf.train.SaverDef.V2)

    # Start training loop
    run = True
    _global_step = sess.run(global_step)
    _init_step = _global_step
    it = 0
    while run:
        if unsup_dataset is not None:
            # TRAIN DISCRIMINATORS
            for disc_it in range(model_config["num_disc"]):
                batches = list()
                for gen in unsup_batch_gens:
                    batches.append(gen.get_batch())

                _, _acc_disc_summaries = sess.run(
                    [acc_disc_solver, acc_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        acc_u: batches[1]
                    })

                _, _drums_disc_summaries = sess.run(
                    [drums_disc_solver, drums_disc_summaries],
                    feed_dict={
                        mix_context_u: batches[0],
                        drums_u: batches[2]
                    })

                writer.add_summary(_acc_disc_summaries, global_step=it)
                writer.add_summary(_drums_disc_summaries, global_step=it)

                it += 1

        # TRAIN SEPARATOR
        sup_batch = sup_batch_gen.get_batch()

        if unsup_dataset is not None:
            # SUP + UNSUPERVISED TRAINING
            unsup_batches = list()
            for gen in unsup_batch_gens:
                unsup_batches.append(gen.get_batch())

            _, _unsup_summaries, _sup_summaries = sess.run(
                [unsup_separator_solver, unsup_summaries, sup_summaries],
                feed_dict={
                    mix_context: sup_batch[0],
                    acc: sup_batch[1],
                    drums: sup_batch[2],
                    mix_context_u: unsup_batches[0],
                    acc_u: unsup_batches[1],
                    drums_u: unsup_batches[2]
                })
            writer.add_summary(_unsup_summaries, global_step=_global_step)
            writer.add_summary(_sup_summaries, global_step=_global_step)
        else:
            # PURELY SUPERVISED TRAINING
            _, _sup_summaries = sess.run([sup_separator_solver, sup_summaries],
                                         feed_dict={
                                             mix_context: sup_batch[0],
                                             acc: sup_batch[1],
                                             drums: sup_batch[2]
                                         })
            writer.add_summary(_sup_summaries, global_step=_global_step)

        # Increment the step counter and stop once the maximum number of iterations per epoch has been reached
        _global_step = sess.run(increment_global_step)

        if _global_step - _init_step > model_config["epoch_it"]:
            run = False
            print("Finished training phase, stopping batch generators")
            sup_batch_gen.stop_workers()

            if unsup_dataset is not None:
                for gen in unsup_batch_gens:
                    gen.stop_workers()

    # Epoch finished - Save model
    print("Finished epoch!")
    save_path = saver.save(sess,
                           model_config["model_base_dir"] + os.path.sep +
                           model_folder + os.path.sep + model_folder,
                           global_step=int(_global_step))

    # Close session, clear computational graph
    writer.flush()
    writer.close()
    sess.close()
    tf.reset_default_graph()

    return save_path
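

# --- Hedged usage sketch (not part of the original code): the dict below lists only
# the model_config keys that train() reads above; the concrete values, folder names
# and dataset arguments are illustrative assumptions. ---
def _example_train_call(sup_data, unsup_data):
    example_config = {
        "num_fft": 512, "num_hop": 256, "num_frames": 64, "num_layers": 4,
        "batch_size": 16, "expected_sr": 22050,
        "alpha": 0.01, "beta": 0.01, "num_disc": 5,
        "init_disc_lr": 1e-4, "init_unsup_sep_lr": 1e-5, "init_sup_sep_lr": 1e-5,
        "epoch_it": 1000, "log_dir": "logs", "model_base_dir": "checkpoints",
    }
    # sup_data / unsup_data are the dataset objects consumed by the batch generators above.
    return train(example_config, sup_data, "example_run", unsup_dataset=unsup_data)
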
Example #2

# Imports assumed from the surrounding project: Models, Input and Utils are
# project-local modules; bss_eval_sources and validate are presumably the
# mir_eval.separation routines. alpha_snr and experiment_id are referenced below
# but not defined in this snippet (a hedged alpha_snr sketch follows the function).
import os
import pickle

import librosa
import numpy as np
import tensorflow as tf
from mir_eval.separation import bss_eval_sources, validate

import Input
import Models.Unet
import Utils


def bss_evaluate(model_config, dataset, load_model):
    '''
    Calculates source separation evaluation metrics of a given separator model on the test set using BSS-Eval
    :param model_config: Separation network configuration required to build the symbolic computation graph of the network
    :param dataset: Test dataset
    :param load_model: Path to the separator model checkpoint containing the network weights
    :return: List of per-song dicts containing evaluation metrics; wav files of the predicted drums and accompaniment are written to the results directory
    '''
    # Determine input and output shapes, if we use U-net as separator
    track_number = 1
    freq_bins = model_config["num_fft"] // 2 + 1  # num_fft/2 + 1 freq bins (an odd number); one bin is dropped below to get an even count
    disc_input_shape = [1, freq_bins - 1, model_config["num_frames"],
                        1]  # Shape of discriminator input

    separator_class = Models.Unet.Unet(model_config["num_layers"])
    sep_input_shape, sep_output_shape = separator_class.getUnetPadding(
        np.array(disc_input_shape))
    separator_func = separator_class.get_output

    # Placeholders and input normalisation
    input_ph, queue, [mix_context, acc, drums] = Input.get_multitrack_input(
        sep_output_shape[1:],
        1,
        name="input_batch",
        input_shape=sep_input_shape[1:])

    mix = Input.crop(mix_context, sep_output_shape)
    mix_norm, mix_context_norm, acc_norm, drums_norm = Input.norm(mix), Input.norm(mix_context), \
                                                        Input.norm(acc), Input.norm(drums)

    print("Testing...")

    # BUILD MODELS
    # Separator
    separator_acc_norm, separator_drums_norm = separator_func(mix_context_norm,
                                                              reuse=False)
    separator_acc, separator_drums = Input.denorm(
        separator_acc_norm), Input.denorm(separator_drums_norm)

    # Start session and queue input threads
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load model
    # Load pretrained model to continue training, if we are supposed to
    restorer = tf.train.Saver(None, write_version=tf.train.SaverDef.V2)
    print("Num of variables" + str(len(tf.global_variables())))
    restorer.restore(sess, load_model)
    print('Pre-trained model restored for testing')

    # Initialize total score object
    song_scores = list()

    for multitrack in dataset[8:15]:
        filename = multitrack[0].path
        print("Evaluating file: " + filename + "Track number: " + track_number)
        if "DSD100" in filename:
            db = "DSD100"
        elif "Kala" in filename:
            db = "IKala"
        elif "ccmixter" in filename:
            db = "CCMixter"
        elif "MedleyDB" in filename:
            db = "MedleyDB"
        else:
            db = "musdb18orNIStems"
        song_info = {"Title": filename, "Database": db}

        # Load mixture and pad it so that output sources have the same length after STFT/ISTFT
        mix_audio, mix_sr = librosa.load(multitrack[0].path,
                                         sr=model_config["expected_sr"])
        mix_length = len(mix_audio)
        # Pad input so that ISTFT later leads to same-length audio
        mix_audio_pad = librosa.util.fix_length(
            mix_audio, mix_length + model_config["num_fft"] // 2)
        mix_mag, mix_ph = Input.audioFileToSpectrogram(mix_audio_pad,
                                                       model_config["num_fft"],
                                                       model_config["num_hop"])
        source_time_frames = mix_mag.shape[1]

        # Preallocate source predictions (same shape as input mixture)
        acc_pred_mag = np.zeros(mix_mag.shape, np.float32)
        drums_pred_mag = np.zeros(mix_mag.shape, np.float32)

        input_time_frames = sep_input_shape[2]
        output_time_frames = sep_output_shape[2]

        # Pad the mix spectrogram across time at the beginning and end so the network can make predictions at the edges of the signal
        pad_time_frames = (input_time_frames - output_time_frames) // 2
        mix_mag = np.pad(mix_mag, [(0, 0), (pad_time_frames, pad_time_frames)],
                         mode="constant",
                         constant_values=0.0)

        # Iterate over mixture magnitudes, fetch network prediction
        for source_pos in range(0, source_time_frames, output_time_frames):
            # If the output patch would reach past the end of the source spectrogram, clamp the position so the last prediction ends exactly at the end
            if source_pos + output_time_frames > source_time_frames:
                source_pos = source_time_frames - output_time_frames

            # Prepare mixture excerpt by selecting time interval
            mix_mag_part = mix_mag[:,
                                   source_pos:source_pos + input_time_frames]
            mix_mag_part = Utils.pad_freqs(
                mix_mag_part, sep_input_shape[1:3])  # Pad along frequency axis
            mix_mag_part = mix_mag_part[np.newaxis, :, :, np.newaxis]

            # Fetch network prediction
            acc_mag_part, drums_mag_part = sess.run(
                [separator_acc, separator_drums],
                feed_dict={mix_context: mix_mag_part})

            # Save predictions
            #source_shape = [1, freq_bins, acc_mag_part.shape[2], 1]
            acc_pred_mag[:, source_pos:source_pos +
                         output_time_frames] = acc_mag_part[0, :-1, :, 0]
            drums_pred_mag[:, source_pos:source_pos +
                           output_time_frames] = drums_mag_part[0, :-1, :, 0]

        # Spectrograms to audio, using mixture phase
        acc_pred_audio = Input.spectrogramToAudioFile(acc_pred_mag,
                                                      model_config["num_fft"],
                                                      model_config["num_hop"],
                                                      phase=mix_ph,
                                                      length=mix_length,
                                                      phaseIterations=0)

        drums_pred_audio = Input.spectrogramToAudioFile(
            drums_pred_mag,
            model_config["num_fft"],
            model_config["num_hop"],
            phase=mix_ph,
            length=mix_length,
            phaseIterations=0)
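        # Note (inferred from the parameter names above, not verified against the
        # Input module): phaseIterations=0 suggests the mixture phase is used
        # directly for the inverse STFT, without Griffin-Lim refinement.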

        # Load original sources
        if isinstance(multitrack[1], float):
            acc_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            acc_audio, _ = librosa.load(multitrack[1].path,
                                        sr=model_config["expected_sr"])
        if isinstance(multitrack[2], float):
            drum_audio = np.zeros(mix_audio.shape, np.float32)
        else:
            drum_audio, _ = librosa.load(multitrack[2].path,
                                         sr=model_config["expected_sr"])

        # Check if any reference source is completely silent, if so, inject some very slight noise to avoid problems during SDR
        reference_zero = False
        if np.max(np.abs(acc_audio)) == 0.0:
            acc_audio += np.random.uniform(-1e-10, 1e-10, size=acc_audio.shape)
            reference_zero = True
        if np.max(np.abs(drum_audio)) == 0.0:
            drum_audio += np.random.uniform(-1e-10,
                                            1e-10,
                                            size=drum_audio.shape)
            reference_zero = True

        # Evaluate BSS according to MIREX separation method # http://www.music-ir.org/mirex/wiki/2016:Singing_Voice_Separation
        #/ np.linalg.norm(acc_audio + drum_audio) # Normalized audio
        ref_sources = np.vstack([acc_audio, drum_audio])
        #/ np.linalg.norm(acc_pred_audio + drums_pred_audio) # Normalized estimates
        pred_sources = np.vstack([acc_pred_audio, drums_pred_audio])
        validate(ref_sources, pred_sources)
        scores = bss_eval_sources(ref_sources,
                                  pred_sources,
                                  compute_permutation=False)

        song_info["SDR"] = scores[0]
        song_info["SIR"] = scores[1]
        song_info["SAR"] = scores[2]

        # Compute reference scores and SNR only if both sources are not silent, since they are undefined otherwise
        if not reference_zero:
            mix_ref = np.vstack([mix_audio, mix_audio
                                 ])  #/ np.linalg.norm(mix_audio + mix_audio)
            mix_scores = bss_eval_sources(ref_sources,
                                          mix_ref,
                                          compute_permutation=False)
            norm_scores = np.array(scores) - np.array(mix_scores)

            # Compute SNR: 10 log_10 ( ||s_target||^2 / ||s_target - alpha * s_estimate||^2 ), with the estimate scaled by alpha for optimal SNR (a hedged alpha_snr sketch follows this function)
            drums_snr = alpha_snr(drum_audio, drums_pred_audio)
            acc_snr = alpha_snr(acc_audio, acc_pred_audio)
            drums_ref_snr = alpha_snr(drum_audio, mix_audio)
            acc_ref_snr = alpha_snr(acc_audio, mix_audio)

            song_info["NSDR"] = norm_scores[0]
            song_info["NSIR"] = norm_scores[1]
            song_info["SNR"] = np.array([acc_snr, drums_snr])
            song_info["NSNR"] = np.array(
                [acc_snr - acc_ref_snr, drums_snr - drums_ref_snr])
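            # NSDR/NSIR/NSNR measure the improvement of the separated estimates over
            # simply using the mixture itself as the estimate (scores minus
            # mixture-as-estimate scores).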

        song_scores.append(song_info)
        print(song_info)
        dirpath = os.path.join(os.getcwd(), 'results')
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        try:
            fn = dirpath + "tr_" + track_number + "_drums.wav"
            librosa.output.write_wav(fn,
                                     drums_pred_audio,
                                     sr=model_config["expected_sr"])
            fn = dirpath + "tr_" + track_number + "_acc.wav"
            librosa.output.write_wav(fn,
                                     acc_pred_audio,
                                     sr=model_config["expected_sr"])
        except Exception as e:
            print("Failed to write wav files, error: " + str(e))

        track_number = track_number + 1
    try:
        # experiment_id is assumed to be defined elsewhere in the module
        with open(os.path.join(dirpath, str(experiment_id) + "BSS_eval.pkl"),
                  "wb") as file:
            pickle.dump(song_scores, file)
    except Exception as e:
        print("Failed to write score files, error: " + str(e))

    #Close session, clear computational graph
    sess.close()
    tf.reset_default_graph()

    return song_scores
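

# --- Hedged sketch (alpha_snr is called in bss_evaluate but not defined in this
# snippet): a minimal implementation of the scaled SNR described in the comment
# above, SNR = 10*log10(||s_target||^2 / ||s_target - alpha*s_estimate||^2), with
# alpha chosen to maximise the SNR (least-squares projection of the target onto
# the estimate). ---
def alpha_snr(s_target, s_estimate):
    s_target = np.asarray(s_target, dtype=np.float64)
    s_estimate = np.asarray(s_estimate, dtype=np.float64)
    n = min(len(s_target), len(s_estimate))  # guard against small length mismatches
    s_target, s_estimate = s_target[:n], s_estimate[:n]
    # Optimal scaling factor: alpha = <s_target, s_estimate> / ||s_estimate||^2
    alpha = np.dot(s_target, s_estimate) / (np.dot(s_estimate, s_estimate) + 1e-12)
    noise = s_target - alpha * s_estimate
    return 10.0 * np.log10(np.sum(s_target ** 2) / (np.sum(noise ** 2) + 1e-12))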