# Example no. 1
# 0
def synth_file(file_name="nus_MCUR_sing_10.hdf5",
               singer_index=0,
               file_path=config.wav_dir,
               show_plots=True,
               save_file="GBO"):
    """Synthesise vocoder features for one song and optionally write audio.

    The pre-extracted features of ``file_name`` (read from
    ``config.voice_dir``) provide f0 and phoneme ids. These are cut into
    chunks with ``utils.generate_overlapadd`` and passed through the restored
    ``Final_Model`` (binary cross-entropy model) and ``Generator`` (GAN).
    The chunk outputs are recombined with ``utils.overlapadd``,
    denormalised with the global min/max statistics, optionally plotted,
    and written to WAV files on request.

    Args:
        file_name: HDF5 feature file inside ``config.voice_dir``; it must
            contain the datasets ``'feats'`` and ``'phonemes'``.
        singer_index: Singer id fed to the networks for every chunk.
        file_path: Unused; kept for backward compatibility.
        show_plots: If True, plot ground-truth vs. generated features before
            asking which files to write (the prompt then defaults to none).
        save_file: Default answer to the "which files" prompt when
            ``show_plots`` is False. It may be any combination of ``G``
            (GAN), ``B`` (binary cross-entropy model) and ``O`` (original
            features resynthesised).
    """
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])

    with h5py.File(config.voice_dir + file_name, "r") as voc_file:
        feats = np.array(voc_file['feats'])
        pho_target = np.array(voc_file["phonemes"])

    # NOTE(review): f0 is a view of feats[:, -2], so filling the zero-valued
    # frames below also writes the median into feats itself (kept as is).
    f0 = feats[:, -2]
    voiced = f0 > 0
    if voiced.any():  # np.median of an empty selection would yield NaN
        f0[f0 == 0] = np.median(f0[voiced])
    # Shift f0 down by 12 (one octave if f0 is in semitones -- TODO confirm).
    f0 = f0 - 12
    feats[:, -2] = feats[:, -2] - 12

    f0_nor = (f0 - min_feat[-2]) / (max_feat[-2] - min_feat[-2])

    in_batches_f0, nchunks_in = utils.generate_overlapadd(
        f0_nor.reshape(-1, 1))
    in_batches_pho, _ = utils.generate_overlapadd(pho_target.reshape(-1, 1))

    with tf.Graph().as_default():
        f0_input_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_input_placeholder')
        rand_input_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, 4),
            name='rand_input_placeholder')
        # Integer phoneme ids, one-hot encoded in the graph exactly as train()
        # does before feeding final_net / GAN_generator.
        phoneme_labels = tf.placeholder(
            tf.int32,
            shape=(config.batch_size, config.max_phr_len),
            name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(indices=phoneme_labels, depth=42)
        singer_labels = tf.placeholder(tf.float32,
                                       shape=(config.batch_size,),
                                       name='singer_placeholder')
        singer_onehot_labels = tf.one_hot(
            indices=tf.cast(singer_labels, tf.int32), depth=12)

        with tf.variable_scope('Final_Model'):
            voc_output_decoded = tf.nn.sigmoid(
                modules.final_net(singer_onehot_labels, f0_input_placeholder,
                                  phone_onehot_labels))

        with tf.variable_scope('Generator'):
            voc_output_2 = modules.GAN_generator(singer_onehot_labels,
                                                 phone_onehot_labels,
                                                 f0_input_placeholder,
                                                 rand_input_placeholder)

        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        out_bce_chunks = []
        out_gan_chunks = []
        with tf.Session() as sess:
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(config.log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Using the model in %s" % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("No checkpoint found in %s; using untrained weights" %
                      config.log_dir)

            singer_batch = np.full(config.batch_size, singer_index,
                                   dtype=np.float32)

            for in_batch_f0, in_batch_pho in zip(in_batches_f0,
                                                 in_batches_pho):
                in_batch_f0 = in_batch_f0.reshape(
                    [config.batch_size, config.max_phr_len, 1])
                # Round back to integer ids for the int32 placeholder; the
                # chunked array may be float depending on generate_overlapadd.
                in_batch_pho = np.rint(
                    in_batch_pho.reshape(
                        [config.batch_size,
                         config.max_phr_len])).astype(np.int32)
                # Same noise distribution as used for training in train().
                noise = np.random.uniform(
                    -1.0, 1.0,
                    size=[config.batch_size, config.max_phr_len, 4])

                output_bce, output_gan = sess.run(
                    [voc_output_decoded, voc_output_2],
                    feed_dict={
                        f0_input_placeholder: in_batch_f0,
                        phoneme_labels: in_batch_pho,
                        singer_labels: singer_batch,
                        rand_input_placeholder: noise,
                    })

                out_bce_chunks.append(output_bce)
                # The generator works on the [-1, 1] scale used in train();
                # map it back to [0, 1] like the sigmoid output.
                out_gan_chunks.append(output_gan / 2 + 0.5)

    n_frames = len(feats)
    feat_range = max_feat[:-2] - min_feat[:-2]

    def _assemble(chunks):
        # Overlap-add the chunks, denormalise the predicted (non-f0) columns
        # and trim the result to the input length.
        merged = utils.overlapadd(np.array(chunks), nchunks_in)
        return (merged * feat_range + min_feat[:-2])[:n_frames]

    out_batches_feats = _assemble(out_bce_chunks)
    out_batches_feats_gan = _assemble(out_gan_chunks)

    # Predicted columns + the original last two columns of feats.
    first_op = np.ascontiguousarray(
        np.concatenate([out_batches_feats, feats[:, -2:]], axis=-1))
    gan_op = np.ascontiguousarray(
        np.concatenate([out_batches_feats_gan, feats[:, -2:]], axis=-1))

    if show_plots:
        _plot_synthesis(feats, out_batches_feats, out_batches_feats_gan)
        save_file = input(
            "Which files to synthesise G for GAN, B for Binary Entropy, "
            "O for original, or any combination. Default is None").upper(
            ) or "N"
    else:
        default = save_file.upper()
        label = "all (GBO)" if default == "GBO" else default
        save_file = input(
            "Which files to synthesise G for GAN, B for Binary Entropy, "
            "O for original, or any combination. Default is " +
            label).upper() or default

    # NOTE(review): ".hdf5" is 5 characters, so [:-4] keeps a trailing "."
    # (e.g. "nus_MCUR_sing_10.gan_op.wav"); kept for output-name compatibility.
    stem = file_name[:-4]
    outputs = (
        ("G", gan_op, 'gan_op.wav', "GAN file saved to {}"),
        ("O", feats, 'ori_op.wav',
         "Original file, resynthesized via WORLD vocoder saved to {}"),
        ("B", first_op, 'bce_op.wav', "Binary cross entropy file saved to {}"),
    )
    for flag, op, suffix, message in outputs:
        if flag in save_file:
            out_name = stem + suffix
            utils.feats_to_audio(op, out_name)
            print(message.format(os.path.join(config.val_dir, out_name)))


def _plot_synthesis(feats, out_bce, out_gan):
    """Plot ground-truth vocoder features next to both model outputs."""
    plt.figure(1)
    ax1 = plt.subplot(311)
    plt.imshow(feats[:, :60].T, aspect='auto', origin='lower')
    ax1.set_title("Ground Truth Vocoder Features", fontsize=10)

    ax2 = plt.subplot(312, sharex=ax1, sharey=ax1)
    plt.imshow(out_bce[:, :60].T, aspect='auto', origin='lower')
    ax2.set_title("Cross Entropy Output Vocoder Features", fontsize=10)

    ax3 = plt.subplot(313, sharex=ax1, sharey=ax1)
    ax3.set_title("GAN Vocoder Output Features", fontsize=10)
    plt.imshow(out_gan[:, :60].T, aspect='auto', origin='lower')

    # Second figure: the remaining feature columns (60 up to the last two).
    plt.figure(2)
    plt.subplot(211)
    plt.imshow(feats[:, 60:-2].T, aspect='auto', origin='lower')
    plt.subplot(212)
    plt.imshow(out_bce[:, -4:].T, aspect='auto', origin='lower')

    plt.show()
# Example no. 2
# 0
def train(_):
    stat_file = h5py.File(config.stat_dir + 'stats.hdf5', mode='r')
    max_feat = np.array(stat_file["feats_maximus"])
    min_feat = np.array(stat_file["feats_minimus"])
    with tf.Graph().as_default():

        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(config.batch_size,
                                                  config.max_phr_len, 66),
                                           name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)

        output_placeholder = tf.placeholder(tf.float32,
                                            shape=(config.batch_size,
                                                   config.max_phr_len, 64),
                                            name='output_placeholder')

        f0_input_placeholder = tf.placeholder(tf.float32,
                                              shape=(config.batch_size,
                                                     config.max_phr_len, 1),
                                              name='f0_input_placeholder')

        rand_input_placeholder = tf.placeholder(tf.float32,
                                                shape=(config.batch_size,
                                                       config.max_phr_len, 4),
                                                name='rand_input_placeholder')

        # pho_input_placeholder = tf.placeholder(tf.float32, shape=(config.batch_size,config.max_phr_len, 42),name='pho_input_placeholder')

        prob = tf.placeholder_with_default(1.0, shape=())

        phoneme_labels = tf.placeholder(tf.int32,
                                        shape=(config.batch_size,
                                               config.max_phr_len),
                                        name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(indices=tf.cast(
            phoneme_labels, tf.int32),
                                         depth=42)

        singer_labels = tf.placeholder(tf.float32,
                                       shape=(config.batch_size),
                                       name='singer_placeholder')
        singer_onehot_labels = tf.one_hot(indices=tf.cast(
            singer_labels, tf.int32),
                                          depth=12)

        phoneme_labels_shuffled = tf.placeholder(tf.int32,
                                                 shape=(config.batch_size,
                                                        config.max_phr_len),
                                                 name='phoneme_placeholder_s')
        phone_onehot_labels_shuffled = tf.one_hot(indices=tf.cast(
            phoneme_labels_shuffled, tf.int32),
                                                  depth=42)

        singer_labels_shuffled = tf.placeholder(tf.float32,
                                                shape=(config.batch_size),
                                                name='singer_placeholder_s')
        singer_onehot_labels_shuffled = tf.one_hot(indices=tf.cast(
            singer_labels_shuffled, tf.int32),
                                                   depth=12)

        with tf.variable_scope('phone_Model') as scope:
            # regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)
            pho_logits = modules.phone_network(input_placeholder)
            pho_classes = tf.argmax(pho_logits, axis=-1)
            pho_probs = tf.nn.softmax(pho_logits)

        with tf.variable_scope('Final_Model') as scope:
            voc_output = modules.final_net(singer_onehot_labels,
                                           f0_input_placeholder,
                                           phone_onehot_labels)
            voc_output_decoded = tf.nn.sigmoid(voc_output)
            scope.reuse_variables()
            voc_output_3 = modules.final_net(singer_onehot_labels,
                                             f0_input_placeholder, pho_probs)
            voc_output_3_decoded = tf.nn.sigmoid(voc_output_3)

        # with tf.variable_scope('singer_Model') as scope:
        #     singer_embedding, singer_logits = modules.singer_network(input_placeholder, prob)
        #     singer_classes = tf.argmax(singer_logits, axis=-1)
        #     singer_probs = tf.nn.softmax(singer_logits)

        with tf.variable_scope('Generator') as scope:
            voc_output_2 = modules.GAN_generator(singer_onehot_labels,
                                                 phone_onehot_labels,
                                                 f0_input_placeholder,
                                                 rand_input_placeholder)
            # scope.reuse_variables()
            # voc_output_2_2 = modules.GAN_generator(voc_output_3_decoded, singer_onehot_labels, phone_onehot_labels, f0_input_placeholder, rand_input_placeholder)

        with tf.variable_scope('Discriminator') as scope:
            D_real = modules.GAN_discriminator(
                (output_placeholder - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(voc_output_2,
                                               singer_onehot_labels,
                                               phone_onehot_labels,
                                               f0_input_placeholder)
            # scope.reuse_variables()
            # epsilon = tf.random_uniform([], 0.0, 1.0)
            # x_hat = (output_placeholder-0.5)*2*epsilon + (1-epsilon)* voc_output_2
            # d_hat = modules.GAN_discriminator(x_hat, singer_onehot_labels, phone_onehot_labels, f0_input_placeholder)
            # scope.reuse_variables()
            # D_fake_2 = modules.GAN_discriminator(voc_output_2_2, singer_onehot_labels, phone_onehot_labels, f0_input_placeholder)
            scope.reuse_variables()
            D_fake_real = modules.GAN_discriminator(
                (voc_output_decoded - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)
        # import pdb;pdb.set_trace()

        # Get network parameters

        final_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="Final_Model")

        g_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Generator")

        d_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Discriminator")

        phone_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="phone_Model")

        # Phoneme network loss and summary

        pho_weights = tf.reduce_sum(config.phonemas_weights *
                                    phone_onehot_labels,
                                    axis=-1)

        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(
            labels=phone_onehot_labels, logits=pho_logits)

        weighted_losses = unweighted_losses * pho_weights

        pho_loss = tf.reduce_mean(weighted_losses)
        # +tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels= output_placeholder, logits=voc_output_3))*0.001

        # reconstruct_loss_pho = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels = output_placeholder, logits=voc_output_decoded_gen)) *0.00001

        # pho_loss+=reconstruct_loss_pho

        pho_acc = tf.metrics.accuracy(labels=phoneme_labels,
                                      predictions=pho_classes)

        pho_summary = tf.summary.scalar('pho_loss', pho_loss)

        pho_acc_summary = tf.summary.scalar('pho_accuracy', pho_acc[0])

        # Discriminator Loss

        # D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(D_real) , logits=D_real+1e-12))
        # D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_fake) , logits=D_fake+1e-12)) + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_fake_2) , logits=D_fake_2+1e-12)) *0.5
        # D_loss_fake_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(D_fake_real) , logits=D_fake_real+1e-12))

        # D_loss_real = tf.reduce_mean(D_real+1e-12)
        # D_loss_fake = - tf.reduce_mean(D_fake+1e-12)
        # D_loss_fake_real = - tf.reduce_mean(D_fake_real+1e-12)

        # gradients = tf.gradients(d_hat, x_hat)[0] + 1e-6
        # slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        # gradient_penalty = tf.reduce_mean((slopes-1.0)**2)
        # errD += gradient_penalty
        # D_loss_fake_real = - tf.reduce_mean(D_fake_real)

        D_correct_pred = tf.equal(tf.round(tf.sigmoid(D_real)),
                                  tf.ones_like(D_real))

        D_correct_pred_fake = tf.equal(tf.round(tf.sigmoid(D_fake_real)),
                                       tf.ones_like(D_fake_real))

        D_accuracy = tf.reduce_mean(tf.cast(D_correct_pred, tf.float32))

        D_accuracy_fake = tf.reduce_mean(
            tf.cast(D_correct_pred_fake, tf.float32))

        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake +
                                                                 1e-12)
        # -tf.reduce_mean(D_fake_real+1e-12)*0.001

        dis_summary = tf.summary.scalar('dis_loss', D_loss)

        dis_acc_summary = tf.summary.scalar('dis_acc', D_accuracy)

        dis_acc_fake_summary = tf.summary.scalar('dis_acc_fake',
                                                 D_accuracy_fake)

        #Final net loss

        # G_loss_GAN = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels= tf.ones_like(D_real), logits=D_fake+1e-12)) + tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels= tf.ones_like(D_fake_2), logits=D_fake_2+1e-12))
        # + tf.reduce_sum(tf.abs(output_placeholder- (voc_output_2/2+0.5))*(1-input_placeholder[:,:,-1:])) *0.00001

        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output_2 / 2 + 0.5))) * 0.00005
        # + tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels= output_placeholder, logits=voc_output)) *0.000005
        #

        G_correct_pred = tf.equal(tf.round(tf.sigmoid(D_fake)),
                                  tf.ones_like(D_real))

        # G_correct_pred_2 = tf.equal(tf.round(tf.sigmoid(D_fake_2)), tf.ones_like(D_real))

        G_accuracy = tf.reduce_mean(tf.cast(G_correct_pred, tf.float32))

        gen_summary = tf.summary.scalar('gen_loss', G_loss_GAN)

        gen_acc_summary = tf.summary.scalar('gen_acc', G_accuracy)

        final_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels= output_placeholder, logits=voc_output)) \
                           # +tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels= output_placeholder, logits=voc_output_3))*0.5

        # reconstruct_loss = tf.reduce_sum(tf.abs(output_placeholder- (voc_output_2/2+0.5)))

        final_summary = tf.summary.scalar('final_loss', final_loss)

        summary = tf.summary.merge_all()

        # summary_val = tf.summary.merge([f0_summary_midi, pho_summary, singer_summary, reconstruct_summary, pho_acc_summary_val,  f0_acc_summary_midi_val, singer_acc_summary_val ])

        # vuv_summary = tf.summary.scalar('vuv_loss', vuv_loss)

        # loss_summary = tf.summary.scalar('total_loss', loss)

        #Global steps

        global_step = tf.Variable(0, name='global_step', trainable=False)

        global_step_re = tf.Variable(0, name='global_step_re', trainable=False)

        global_step_dis = tf.Variable(0,
                                      name='global_step_dis',
                                      trainable=False)

        global_step_gen = tf.Variable(0,
                                      name='global_step_gen',
                                      trainable=False)

        #Optimizers

        pho_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)

        re_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)

        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        # GradientDescentOptimizer

        # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Training functions
        pho_train_function = pho_optimizer.minimize(pho_loss,
                                                    global_step=global_step,
                                                    var_list=phone_params)

        # with tf.control_dependencies(update_ops):
        re_train_function = re_optimizer.minimize(final_loss,
                                                  global_step=global_step_re,
                                                  var_list=final_params)

        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params)

        gen_train_function = gen_optimizer.minimize(
            G_loss_GAN, global_step=global_step_gen, var_list=g_params)

        clip_discriminator_var_op = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in d_params
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()

        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)

        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(config.log_dir + 'train/',
                                                     sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir + 'val/',
                                                   sess.graph)

        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))

        print("Start from: %d" % start_epoch)

        for epoch in xrange(start_epoch, config.num_epochs):

            if epoch < 25 or epoch % 100 == 0:
                n_critic = 25
            else:
                n_critic = 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()

            val_generator = data_gen(mode='val')

            batch_num = 0

            epoch_pho_loss = 0
            epoch_gen_loss = 0
            epoch_re_loss = 0
            epoch_dis_loss = 0

            epoch_pho_acc = 0
            epoch_gen_acc = 0
            epoch_dis_acc = 0
            epoch_dis_acc_fake = 0

            val_epoch_pho_loss = 0
            val_epoch_gen_loss = 0
            val_epoch_dis_loss = 0

            val_epoch_pho_acc = 0
            val_epoch_gen_acc = 0
            val_epoch_dis_acc = 0
            val_epoch_dis_acc_fake = 0

            with tf.variable_scope('Training'):

                for feats, f0, phos, singer_ids in data_generator:

                    # plt.imshow(feats.reshape(-1,66).T,aspect = 'auto', origin ='lower')

                    # plt.show()

                    # import pdb;pdb.set_trace()

                    pho_one_hot = one_hotize(phos, max_index=42)

                    f0 = f0.reshape([config.batch_size, config.max_phr_len, 1])

                    sing_id_shu = np.copy(singer_ids)

                    phos_shu = np.copy(phos)

                    np.random.shuffle(sing_id_shu)

                    np.random.shuffle(phos_shu)

                    for critic_itr in range(n_critic):
                        feed_dict = {
                            input_placeholder:
                            feats,
                            output_placeholder:
                            feats[:, :, :-2],
                            f0_input_placeholder:
                            f0,
                            rand_input_placeholder:
                            np.random.uniform(-1.0,
                                              1.0,
                                              size=[30, config.max_phr_len,
                                                    4]),
                            phoneme_labels:
                            phos,
                            singer_labels:
                            singer_ids,
                            phoneme_labels_shuffled:
                            phos_shu,
                            singer_labels_shuffled:
                            sing_id_shu
                        }
                        sess.run(dis_train_function, feed_dict=feed_dict)
                        sess.run(clip_discriminator_var_op,
                                 feed_dict=feed_dict)

                    feed_dict = {
                        input_placeholder:
                        feats,
                        output_placeholder:
                        feats[:, :, :-2],
                        f0_input_placeholder:
                        f0,
                        rand_input_placeholder:
                        np.random.uniform(-1.0,
                                          1.0,
                                          size=[30, config.max_phr_len, 4]),
                        phoneme_labels:
                        phos,
                        singer_labels:
                        singer_ids,
                        phoneme_labels_shuffled:
                        phos_shu,
                        singer_labels_shuffled:
                        sing_id_shu
                    }

                    _, _, step_re_loss, step_gen_loss, step_gen_acc = sess.run(
                        [
                            re_train_function, gen_train_function, final_loss,
                            G_loss_GAN, G_accuracy
                        ],
                        feed_dict=feed_dict)
                    # if step_gen_acc>0.3:
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)
                    _, step_pho_loss, step_pho_acc = sess.run(
                        [pho_train_function, pho_loss, pho_acc],
                        feed_dict=feed_dict)
                    # else:
                    # step_dis_loss, step_dis_acc = sess.run([D_loss, D_accuracy], feed_dict = feed_dict)

                    epoch_pho_loss += step_pho_loss
                    epoch_re_loss += step_re_loss
                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss

                    epoch_pho_acc += step_pho_acc[0]
                    epoch_gen_acc += step_gen_acc
                    epoch_dis_acc += step_dis_acc
                    epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                epoch_pho_loss = epoch_pho_loss / config.batches_per_epoch_train
                epoch_re_loss = epoch_re_loss / config.batches_per_epoch_train
                epoch_gen_loss = epoch_gen_loss / config.batches_per_epoch_train
                epoch_dis_loss = epoch_dis_loss / config.batches_per_epoch_train
                epoch_dis_acc_fake = epoch_dis_acc_fake / config.batches_per_epoch_train

                epoch_pho_acc = epoch_pho_acc / config.batches_per_epoch_train
                epoch_gen_acc = epoch_gen_acc / config.batches_per_epoch_train
                epoch_dis_acc = epoch_dis_acc / config.batches_per_epoch_train
                summary_str = sess.run(summary, feed_dict=feed_dict)
                # import pdb;pdb.set_trace()
                train_summary_writer.add_summary(summary_str, epoch)
                # # summary_writer.add_summary(summary_str_val, epoch)
                train_summary_writer.flush()

            with tf.variable_scope('Validation'):

                for feats, f0, phos, singer_ids in val_generator:

                    pho_one_hot = one_hotize(phos, max_index=42)

                    f0 = f0.reshape([config.batch_size, config.max_phr_len, 1])

                    sing_id_shu = np.copy(singer_ids)

                    phos_shu = np.copy(phos)

                    np.random.shuffle(sing_id_shu)

                    np.random.shuffle(phos_shu)

                    feed_dict = {
                        input_placeholder:
                        feats,
                        output_placeholder:
                        feats[:, :, :-2],
                        f0_input_placeholder:
                        f0,
                        rand_input_placeholder:
                        np.random.uniform(-1.0,
                                          1.0,
                                          size=[30, config.max_phr_len, 4]),
                        phoneme_labels:
                        phos,
                        singer_labels:
                        singer_ids,
                        phoneme_labels_shuffled:
                        phos_shu,
                        singer_labels_shuffled:
                        sing_id_shu
                    }

                    step_pho_loss, step_pho_acc = sess.run([pho_loss, pho_acc],
                                                           feed_dict=feed_dict)
                    step_gen_loss, step_gen_acc = sess.run(
                        [final_loss, G_accuracy], feed_dict=feed_dict)
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)

                    val_epoch_pho_loss += step_pho_loss
                    val_epoch_gen_loss += step_gen_loss
                    val_epoch_dis_loss += step_dis_loss

                    val_epoch_pho_acc += step_pho_acc[0]
                    val_epoch_gen_acc += step_gen_acc
                    val_epoch_dis_acc += step_dis_acc
                    val_epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                val_epoch_pho_loss = val_epoch_pho_loss / config.batches_per_epoch_val
                val_epoch_gen_loss = val_epoch_gen_loss / config.batches_per_epoch_val
                val_epoch_dis_loss = val_epoch_dis_loss / config.batches_per_epoch_val

                val_epoch_pho_acc = val_epoch_pho_acc / config.batches_per_epoch_val
                val_epoch_gen_acc = val_epoch_gen_acc / config.batches_per_epoch_val
                val_epoch_dis_acc = val_epoch_dis_acc / config.batches_per_epoch_val
                val_epoch_dis_acc_fake = val_epoch_dis_acc_fake / config.batches_per_epoch_val

                summary_str = sess.run(summary, feed_dict=feed_dict)
                # import pdb;pdb.set_trace()
                val_summary_writer.add_summary(summary_str, epoch)
                # # summary_writer.add_summary(summary_str_val, epoch)
                val_summary_writer.flush()
            duration = time.time() - start_time

            # np.save('./ikala_eval/accuracies', f0_accs)

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Phone Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_pho_loss, duration))
                print('        : Phone Accuracy = %.10f ' % (epoch_pho_acc))
                print('        : Recon Loss = %.10f ' % (epoch_re_loss))
                print('        : Gen Loss = %.10f ' % (epoch_gen_loss))
                print('        : Gen Accuracy = %.10f ' % (epoch_gen_acc))
                print('        : Dis Loss = %.10f ' % (epoch_dis_loss))
                print('        : Dis Accuracy = %.10f ' % (epoch_dis_acc))
                print('        : Dis Accuracy Fake = %.10f ' %
                      (epoch_dis_acc_fake))
                print('        : Val Phone Accuracy = %.10f ' %
                      (val_epoch_pho_acc))
                print('        : Val Gen Loss = %.10f ' % (val_epoch_gen_loss))
                print('        : Val Gen Accuracy = %.10f ' %
                      (val_epoch_gen_acc))
                print('        : Val Dis Loss = %.10f ' % (val_epoch_dis_loss))
                print('        : Val Dis Accuracy = %.10f ' %
                      (val_epoch_dis_acc))
                print('        : Val Dis Accuracy Fake = %.10f ' %
                      (val_epoch_dis_acc_fake))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                # utils.list_to_file(val_f0_accs,'./ikala_eval/accuracies_'+str(epoch+1)+'.txt')
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
Esempio n. 3
0
def _resolve_synth_source(file_name):
    """Split a possibly dataset-prefixed file name into (name, directory, mode).

    ``ikala_`` and ``mir_`` prefixes select ``config.wav_dir`` /
    ``config.wav_dir_mir`` with mode 0; ``med_`` selects ``config.wav_dir_med``
    with mode 2. For these three, the original recording is also written out
    via ``utils.write_ori_ikala`` / ``utils.write_ori_med``. Any other name is
    read from ``./`` with mode 1. ``mode`` is the value later passed to
    ``utils.file_to_stft`` and ``utils.input_to_feats``.
    """
    if file_name.startswith('ikala'):
        file_name = file_name[6:]
        file_path = config.wav_dir
        utils.write_ori_ikala(os.path.join(file_path, file_name), file_name)
        return file_name, file_path, 0
    if file_name.startswith('mir'):
        file_name = file_name[4:]
        file_path = config.wav_dir_mir
        utils.write_ori_ikala(os.path.join(file_path, file_name), file_name)
        return file_name, file_path, 0
    if file_name.startswith('med'):
        file_name = file_name[4:]
        file_path = config.wav_dir_med
        utils.write_ori_med(os.path.join(file_path, file_name), file_name)
        return file_name, file_path, 2
    return file_name, './', 1


def _denormalize_f0_for_plot(norm_f0, vuv, max_feat, min_feat):
    """Invert the min-max scaling of the f0 column and hide masked frames.

    Inverts ``(x - min) / (max - min)`` using column -2 of the feature stats.
    The result is multiplied by ``1 - vuv`` (the same mask the training losses
    in this file apply), and frames that become 0 are set to NaN so
    matplotlib draws gaps there. Returns a new array; inputs are not modified.
    """
    f0 = norm_f0 * (max_feat[-2] - min_feat[-2]) + min_feat[-2]
    f0 = f0 * (1 - vuv)
    f0[f0 == 0] = np.nan
    return f0


def _plot_synth_results(val_outer, targs, gan_op, max_feat, min_feat):
    """Show predicted vs. ground-truth features (both in normalized scale).

    ``gan_op`` is the overlap-added GAN generator output, or None when the
    GAN is disabled. Blocks in ``plt.show()``.
    """
    ins = val_outer[:, :60]
    outs = targs[:, :60]
    plt.figure(1)
    ax1 = plt.subplot(211)
    plt.imshow(ins.T, origin='lower', aspect='auto')
    ax1.set_title("Predicted Harm ", fontsize=10)
    ax2 = plt.subplot(212)
    plt.imshow(outs.T, origin='lower', aspect='auto')
    ax2.set_title("Ground Truth Harm ", fontsize=10)

    if gan_op is not None:
        plt.figure(5)
        ax1 = plt.subplot(411)
        plt.imshow(ins.T, origin='lower', aspect='auto')
        ax1.set_title("Predicted Harm ", fontsize=10)
        ax2 = plt.subplot(414)
        plt.imshow(outs.T, origin='lower', aspect='auto')
        ax2.set_title("Ground Truth Harm ", fontsize=10)
        ax1 = plt.subplot(412)
        plt.imshow(gan_op.T, origin='lower', aspect='auto')
        ax1.set_title("GAN output ", fontsize=10)
        ax1 = plt.subplot(413)
        plt.imshow((gan_op[:ins.shape[0], :] + ins).T,
                   origin='lower',
                   aspect='auto')
        ax1.set_title("GAN output ", fontsize=10)

    plt.figure(2)
    ax1 = plt.subplot(211)
    plt.imshow(val_outer[:, 60:-2].T, origin='lower', aspect='auto')
    ax1.set_title("Predicted Aperiodic Part", fontsize=10)
    ax2 = plt.subplot(212)
    plt.imshow(targs[:, 60:-2].T, origin='lower', aspect='auto')
    ax2.set_title("Ground Truth Aperiodic Part", fontsize=10)

    plt.figure(3)
    # Both contours are masked with the ground-truth voicing column.
    f0_output = _denormalize_f0_for_plot(val_outer[:, -2], targs[:, -1],
                                         max_feat, min_feat)
    plt.plot(f0_output, label="Predicted Value")
    f0_gt = _denormalize_f0_for_plot(targs[:, -2], targs[:, -1], max_feat,
                                     min_feat)
    plt.plot(f0_gt, label="Ground Truth")
    # NaN (masked) frames become a difference of 0, i.e. they count as correct.
    f0_difference = np.nan_to_num(np.abs(f0_gt - f0_output))
    n_wrong = np.count_nonzero(f0_difference > config.f0_threshold)
    # float() prevents Python 2 integer division from truncating the ratio
    # to 0; max(..., 1) guards against an empty contour.
    diff_per = float(n_wrong) / max(len(f0_output), 1)
    plt.suptitle("Percentage correct = " + '{:.3%}'.format(1 - diff_per))
    plt.legend()

    plt.figure(4)
    ax1 = plt.subplot(211)
    plt.plot(val_outer[:, -1])
    ax1.set_title("Predicted Voiced/Unvoiced", fontsize=10)
    ax2 = plt.subplot(212)
    plt.plot(targs[:, -1])
    ax2.set_title("Ground Truth Voiced/Unvoiced", fontsize=10)
    plt.show()


def _save_synth_audio(val_outer, targs, file_name, max_feat, min_feat):
    """Denormalize features and synthesize two audio files via utils.

    First with the predicted features as-is, then with the last two columns
    (f0 and voicing) replaced by the ground truth. A failed synthesis is
    reported (with its cause) and does not stop the other one.
    """
    val_outer = np.ascontiguousarray(val_outer * (max_feat - min_feat) +
                                     min_feat)
    targs = np.ascontiguousarray(targs * (max_feat - min_feat) + min_feat)
    out_stem = file_name[:-4]  # strip a 4-character extension such as '.wav'

    try:
        utils.feats_to_audio(val_outer, out_stem + '_synth_pred_f0')
        print("File saved to %s" %
              (config.val_dir + out_stem + '_synth_pred_f0.wav'))
    except Exception as err:
        print("Couldn't synthesize with predicted f0")
        print("    reason: %r" % (err,))
    try:
        val_outer[:, -2:] = targs[:, -2:]
        utils.feats_to_audio(val_outer, out_stem + '_synth_ori_f0')
        print("File saved to %s" %
              (config.val_dir + out_stem + '_synth_ori_f0.wav'))
    except Exception as err:
        print("Couldn't synthesize with original f0")
        print("    reason: %r" % (err,))


def synth_file(file_name,
               file_path=config.wav_dir,
               show_plots=True,
               save_file=True):
    """Run the First_Model network on one file, then plot and/or synthesize.

    Args:
        file_name: Audio file name, optionally prefixed with ``ikala_``,
            ``mir_`` or ``med_`` to pick a dataset directory (see
            ``_resolve_synth_source``).
        file_path: Kept for interface compatibility. It is currently ignored:
            the directory is always derived from ``file_name``'s prefix, and
            unprefixed names are read from ``./``.
        show_plots: If true, display predicted vs. ground-truth features.
        save_file: If true, write audio synthesized from the predicted
            features, once with predicted and once with original f0/voicing.
    """
    file_name, file_path, mode = _resolve_synth_source(file_name)
    in_file = os.path.join(file_path, file_name)

    # The arrays are copied into memory, so the HDF5 handle can be closed
    # right away instead of being leaked.
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])
        max_voc = np.array(stat_file["voc_stft_maximus"])
        max_back = np.array(stat_file["back_stft_maximus"])
    # Input STFT chunks are scaled by the sum of vocal and backing maxima.
    max_mix = max_voc + max_back

    use_gan = config.use_gan

    with tf.Graph().as_default():

        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(config.batch_size,
                                                  config.max_phr_len,
                                                  config.input_features),
                                           name='input_placeholder')

        with tf.variable_scope('First_Model'):
            harm, ap, f0, vuv = modules.nr_wavenet(input_placeholder)

        fetches = [harm, ap, f0, vuv]
        if use_gan:
            with tf.variable_scope('Generator'):
                gen_op = modules.GAN_generator(harm)
            # Fetched in the same run as the main outputs, not a second pass.
            fetches.append(gen_op)

        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # The context manager closes the session (releasing its resources)
        # before the potentially blocking plotting step.
        with tf.Session() as sess:
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(config.log_dir_m1)
            if ckpt and ckpt.model_checkpoint_path:
                print("Using the model in %s" % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("No checkpoint found in %s; using untrained weights" %
                      config.log_dir_m1)

            mix_stft = utils.file_to_stft(in_file, mode=mode)
            targs = utils.input_to_feats(in_file, mode=mode)

            in_batches, nchunks_in = utils.generate_overlapadd(mix_stft)
            in_batches = in_batches / max_mix

            val_outer = []
            gan_chunks = []
            for in_batch in in_batches:
                outs = sess.run(fetches,
                                feed_dict={input_placeholder: in_batch})
                # harm, ap, f0, vuv stacked along the feature axis.
                val_outer.append(np.concatenate(outs[:4], axis=-1))
                if use_gan:
                    gan_chunks.append(outs[4])

    val_outer = utils.overlapadd(np.array(val_outer), nchunks_in)
    # Round the voiced/unvoiced column, trim to the target length, clip.
    val_outer[:, -1] = np.round(val_outer[:, -1])
    val_outer = val_outer[:targs.shape[0], :]
    val_outer = np.clip(val_outer, 0.0, 1.0)

    gan_op = None
    if use_gan:
        gan_op = utils.overlapadd(np.array(gan_chunks), nchunks_in)

    # Bring the ground-truth features into the same normalized range.
    targs = (targs - min_feat) / (max_feat - min_feat)

    if show_plots:
        _plot_synth_results(val_outer, targs, gan_op, max_feat, min_feat)
    if save_file:
        _save_synth_audio(val_outer, targs, file_name, max_feat, min_feat)
Esempio n. 4
0
def train(_):
    stat_file = h5py.File(config.stat_dir + 'stats.hdf5', mode='r')
    max_feat = np.array(stat_file["feats_maximus"])
    min_feat = np.array(stat_file["feats_minimus"])
    with tf.Graph().as_default():

        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(config.batch_size,
                                                  config.max_phr_len,
                                                  config.input_features),
                                           name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(config.batch_size,
                                                   config.max_phr_len,
                                                   config.output_features),
                                            name='target_placeholder')
        tf.summary.histogram('targets', target_placeholder)

        with tf.variable_scope('First_Model') as scope:
            harm, ap, f0, vuv = modules.nr_wavenet(input_placeholder)

            # tf.summary.histogram('initial_output', op)

            tf.summary.histogram('harm', harm)

            tf.summary.histogram('ap', ap)

            tf.summary.histogram('f0', f0)

            tf.summary.histogram('vuv', vuv)

        if config.use_gan:

            with tf.variable_scope('Generator') as scope:
                gen_op = modules.GAN_generator(harm)
            with tf.variable_scope('Discriminator') as scope:
                D_real = modules.GAN_discriminator(
                    target_placeholder[:, :, :60], input_placeholder)
                scope.reuse_variables()
                D_fake = modules.GAN_discriminator(gen_op + harmy,
                                                   input_placeholder)

            # Comment out these lines to train without GAN

            D_loss_real = -tf.reduce_mean(tf.log(D_real + 1e-12))
            D_loss_fake = -tf.reduce_mean(tf.log(1. - (D_fake + 1e-12)))

            D_loss = D_loss_real + D_loss_fake

            D_summary_real = tf.summary.scalar('Discriminator_Loss_Real',
                                               D_loss_real)
            D_summary_fake = tf.summary.scalar('Discriminator_Loss_Fake',
                                               D_loss_fake)

            G_loss_GAN = -tf.reduce_mean(tf.log(D_fake + 1e-12))
            G_loss_diff = tf.reduce_sum(
                tf.abs(gen_op + harmy - target_placeholder[:, :, :60]) *
                (1 - target_placeholder[:, :, -1:])) * 0.5
            G_loss = G_loss_GAN + G_loss_diff

            G_summary_GAN = tf.summary.scalar('Generator_Loss_GAN', G_loss_GAN)
            G_summary_diff = tf.summary.scalar('Generator_Loss_diff',
                                               G_loss_diff)

            vars = tf.trainable_variables()

            # import pdb;pdb.set_trace()

            d_params = [
                v for v in vars if v.name.startswith('Discriminator/D')
            ]
            g_params = [v for v in vars if v.name.startswith('Generator/G')]

            # import pdb;pdb.set_trace()

            # d_optimizer_grad = tf.train.GradientDescentOptimizer(learning_rate=config.gan_lr).minimize(D_loss, var_list=d_params)
            # g_optimizer = tf.train.GradientDescentOptimizer(learning_rate=config.gan_lr).minimize(G_loss, var_list=g_params)

            d_optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=config.gan_lr).minimize(D_loss,
                                                      var_list=d_params)
            # g_optimizer_diff = tf.train.AdamOptimizer(learning_rate=config.gan_lr).minimize(G_loss_diff, var_list=g_params)
            g_optimizer = tf.train.AdamOptimizer(
                learning_rate=config.gan_lr).minimize(G_loss,
                                                      var_list=g_params)

        # initial_loss = tf.reduce_sum(tf.abs(op - target_placeholder[:,:,:60])*np.linspace(1.0,0.7,60)*(1-target_placeholder[:,:,-1:]))

        harm_loss = tf.reduce_sum(
            tf.abs(harm - target_placeholder[:, :, :60]) *
            np.linspace(1.0, 0.7, 60) * (1 - target_placeholder[:, :, -1:]))

        ap_loss = tf.reduce_sum(
            tf.abs(ap - target_placeholder[:, :, 60:-2]) *
            (1 - target_placeholder[:, :, -1:]))

        f0_loss = tf.reduce_sum(
            tf.abs(f0 - target_placeholder[:, :, -2:-1]) *
            (1 - target_placeholder[:, :, -1:]))

        # vuv_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=, logits=vuv))

        vuv_loss = tf.reduce_mean(
            tf.reduce_sum(binary_cross(target_placeholder[:, :, -1:], vuv)))

        loss = harm_loss + ap_loss + vuv_loss + f0_loss * config.f0_weight

        # initial_summary = tf.summary.scalar('initial_loss', initial_loss)

        harm_summary = tf.summary.scalar('harm_loss', harm_loss)

        ap_summary = tf.summary.scalar('ap_loss', ap_loss)

        f0_summary = tf.summary.scalar('f0_loss', f0_loss)

        vuv_summary = tf.summary.scalar('vuv_loss', vuv_loss)

        loss_summary = tf.summary.scalar('total_loss', loss)

        global_step = tf.Variable(0, name='global_step', trainable=False)

        optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)

        # optimizer_f0 = tf.train.AdamOptimizer(learning_rate = config.init_lr)

        train_function = optimizer.minimize(loss, global_step=global_step)

        # train_f0 = optimizer.minimize(f0_loss, global_step= global_step)

        # train_harm = optimizer.minimize(harm_loss, global_step= global_step)

        # train_ap = optimizer.minimize(ap_loss, global_step= global_step)

        # train_f0 = optimizer.minimize(f0_loss, global_step= global_step)

        # train_vuv = optimizer.minimize(vuv_loss, global_step= global_step)

        summary = tf.summary.merge_all()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()

        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir_m1)

        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(
            config.log_dir_m1 + 'train/', sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir_m1 + 'val/',
                                                   sess.graph)

        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))

        print("Start from: %d" % start_epoch)
        f0_accs = []
        for epoch in xrange(start_epoch, config.num_epochs):
            val_f0_accs = []

            data_generator = data_gen()
            start_time = time.time()

            epoch_loss_harm = 0
            epoch_loss_ap = 0
            epoch_loss_f0 = 0
            epoch_loss_vuv = 0
            epoch_total_loss = 0
            # epoch_initial_loss = 0

            epoch_loss_harm_val = 0
            epoch_loss_ap_val = 0
            epoch_loss_f0_val = 0
            epoch_loss_vuv_val = 0
            epoch_total_loss_val = 0
            # epoch_initial_loss_val = 0

            if config.use_gan:
                epoch_loss_generator_GAN = 0
                epoch_loss_generator_diff = 0
                epoch_loss_discriminator_real = 0
                epoch_loss_discriminator_fake = 0

                val_epoch_loss_generator_GAN = 0
                val_epoch_loss_generator_diff = 0
                val_epoch_loss_discriminator_real = 0
                val_epoch_loss_discriminator_fake = 0

            batch_num = 0
            batch_num_val = 0
            val_generator = data_gen(mode='val')

            # val_generator = get_batches(train_filename=config.h5py_file_val, batches_per_epoch=config.batches_per_epoch_val_m1)

            with tf.variable_scope('Training'):

                for voc, feat in data_generator:
                    voc = np.clip(
                        voc + np.random.rand(config.max_phr_len,
                                             config.input_features) *
                        np.clip(np.random.rand(1), 0.0,
                                config.noise_threshold), 0.0, 1.0)

                    _, step_loss_harm, step_loss_ap, step_loss_f0, step_loss_vuv, step_total_loss = sess.run(
                        [
                            train_function, harm_loss, ap_loss, f0_loss,
                            vuv_loss, loss
                        ],
                        feed_dict={
                            input_placeholder: voc,
                            target_placeholder: feat
                        })
                    # _, step_loss_f0 = sess.run([train_f0, f0_loss], feed_dict={input_placeholder: voc,target_placeholder: feat})

                    if config.use_gan:
                        _, step_dis_loss_real, step_dis_loss_fake = sess.run(
                            [d_optimizer, D_loss_real, D_loss_fake],
                            feed_dict={
                                input_placeholder: voc,
                                target_placeholder: feat
                            })
                        _, step_gen_loss_GAN, step_gen_loss_diff = sess.run(
                            [g_optimizer, G_loss_GAN, G_loss_diff],
                            feed_dict={
                                input_placeholder: voc,
                                target_placeholder: feat
                            })
                    # else :
                    #     _, step_dis_loss_real, step_dis_loss_fake = sess.run([d_optimizer_grad, D_loss_real,D_loss_fake], feed_dict={input_placeholder: voc,target_placeholder: feat})
                    #     _, step_gen_loss_diff = sess.run([g_optimizer_diff, G_loss_diff], feed_dict={input_placeholder: voc,target_placeholder: feat})
                    #     step_gen_loss_GAN = 0

                    # _, step_loss_harm = sess.run([train_harm, harm_loss], feed_dict={input_placeholder: voc,target_placeholder: feat})
                    # _, step_loss_ap = sess.run([train_ap, ap_loss], feed_dict={input_placeholder: voc,target_placeholder: feat})
                    # _, step_loss_f0 = sess.run([train_f0, f0_loss], feed_dict={input_placeholder: voc,target_placeholder: feat})
                    # _, step_loss_vuv = sess.run([train_vuv, vuv_loss], feed_dict={input_placeholder: voc,target_placeholder: feat})

                    # epoch_initial_loss+=step_initial_loss
                    epoch_loss_harm += step_loss_harm
                    epoch_loss_ap += step_loss_ap
                    epoch_loss_f0 += step_loss_f0
                    epoch_loss_vuv += step_loss_vuv
                    epoch_total_loss += step_total_loss

                    if config.use_gan:

                        epoch_loss_generator_GAN += step_gen_loss_GAN
                        epoch_loss_generator_diff += step_gen_loss_diff
                        epoch_loss_discriminator_real += step_dis_loss_real
                        epoch_loss_discriminator_fake += step_dis_loss_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                # epoch_initial_loss = epoch_initial_loss/(config.batches_per_epoch_train *config.batch_size*config.max_phr_len*60)
                epoch_loss_harm = epoch_loss_harm / (
                    config.batches_per_epoch_train * config.batch_size *
                    config.max_phr_len * 60)
                epoch_loss_ap = epoch_loss_ap / (
                    config.batches_per_epoch_train * config.batch_size *
                    config.max_phr_len * 4)
                epoch_loss_f0 = epoch_loss_f0 / (
                    config.batches_per_epoch_train * config.batch_size *
                    config.max_phr_len)
                epoch_loss_vuv = epoch_loss_vuv / (
                    config.batches_per_epoch_train * config.batch_size *
                    config.max_phr_len)
                epoch_total_loss = epoch_total_loss / (
                    config.batches_per_epoch_train * config.batch_size *
                    config.max_phr_len * 66)

                if config.use_gan:

                    epoch_loss_generator_GAN = epoch_loss_generator_GAN / (
                        config.batches_per_epoch_train * config.batch_size)
                    epoch_loss_generator_diff = epoch_loss_generator_diff / (
                        config.batches_per_epoch_train * config.batch_size *
                        config.max_phr_len * 60)
                    epoch_loss_discriminator_real = epoch_loss_discriminator_real / (
                        config.batches_per_epoch_train * config.batch_size)
                    epoch_loss_discriminator_fake = epoch_loss_discriminator_fake / (
                        config.batches_per_epoch_train * config.batch_size)

                summary_str = sess.run(summary,
                                       feed_dict={
                                           input_placeholder: voc,
                                           target_placeholder: feat
                                       })
                train_summary_writer.add_summary(summary_str, epoch)
                # summary_writer.add_summary(summary_str_val, epoch)
                train_summary_writer.flush()

            with tf.variable_scope('Validation'):

                for voc, feat in val_generator:

                    step_loss_harm_val = sess.run(harm_loss,
                                                  feed_dict={
                                                      input_placeholder: voc,
                                                      target_placeholder: feat
                                                  })
                    step_loss_ap_val = sess.run(ap_loss,
                                                feed_dict={
                                                    input_placeholder: voc,
                                                    target_placeholder: feat
                                                })
                    step_loss_f0_val = sess.run(f0_loss,
                                                feed_dict={
                                                    input_placeholder: voc,
                                                    target_placeholder: feat
                                                })
                    step_loss_vuv_val = sess.run(vuv_loss,
                                                 feed_dict={
                                                     input_placeholder: voc,
                                                     target_placeholder: feat
                                                 })
                    step_total_loss_val = sess.run(loss,
                                                   feed_dict={
                                                       input_placeholder: voc,
                                                       target_placeholder: feat
                                                   })

                    epoch_loss_harm_val += step_loss_harm_val
                    epoch_loss_ap_val += step_loss_ap_val
                    epoch_loss_f0_val += step_loss_f0_val
                    epoch_loss_vuv_val += step_loss_vuv_val
                    epoch_total_loss_val += step_total_loss_val

                    if config.use_gan:

                        val_epoch_loss_generator_GAN += step_gen_loss_GAN
                        val_epoch_loss_generator_diff += step_gen_loss_diff
                        val_epoch_loss_discriminator_real += step_dis_loss_real
                        val_epoch_loss_discriminator_fake += step_dis_loss_fake

                    utils.progress(batch_num_val,
                                   config.batches_per_epoch_val_m1,
                                   suffix='validiation done')
                    batch_num_val += 1

                # f0_accs.append(np.mean(val_f0_accs))

                # epoch_initial_loss_val = epoch_initial_loss_val/(config.batches_per_epoch_val_m1 *config.batch_size*config.max_phr_len*60)
                epoch_loss_harm_val = epoch_loss_harm_val / (
                    batch_num_val * config.batch_size * config.max_phr_len *
                    60)
                epoch_loss_ap_val = epoch_loss_ap_val / (
                    batch_num_val * config.batch_size * config.max_phr_len * 4)
                epoch_loss_f0_val = epoch_loss_f0_val / (
                    batch_num_val * config.batch_size * config.max_phr_len)
                epoch_loss_vuv_val = epoch_loss_vuv_val / (
                    batch_num_val * config.batch_size * config.max_phr_len)
                epoch_total_loss_val = epoch_total_loss_val / (
                    batch_num_val * config.batch_size * config.max_phr_len *
                    66)

                if config.use_gan:

                    val_epoch_loss_generator_GAN = val_epoch_loss_generator_GAN / (
                        config.batches_per_epoch_val_m1 * config.batch_size)
                    val_epoch_loss_generator_diff = val_epoch_loss_generator_diff / (
                        config.batches_per_epoch_val_m1 * config.batch_size *
                        config.max_phr_len * 60)
                    val_epoch_loss_discriminator_real = val_epoch_loss_discriminator_real / (
                        config.batches_per_epoch_val_m1 * config.batch_size)
                    val_epoch_loss_discriminator_fake = val_epoch_loss_discriminator_fake / (
                        config.batches_per_epoch_val_m1 * config.batch_size)

                summary_str = sess.run(summary,
                                       feed_dict={
                                           input_placeholder: voc,
                                           target_placeholder: feat
                                       })
                val_summary_writer.add_summary(summary_str, epoch)
                # summary_writer.add_summary(summary_str_val, epoch)
                val_summary_writer.flush()

            duration = time.time() - start_time

            # np.save('./ikala_eval/accuracies', f0_accs)

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Harm Training Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_loss_harm, duration))
                print('        : Ap Training Loss = %.10f ' % (epoch_loss_ap))
                print('        : F0 Training Loss = %.10f ' % (epoch_loss_f0))
                print('        : VUV Training Loss = %.10f ' %
                      (epoch_loss_vuv))
                # print('        : Initial Training Loss = %.10f ' % (epoch_initial_loss))

                if config.use_gan:

                    print('        : Gen GAN Training Loss = %.10f ' %
                          (epoch_loss_generator_GAN))
                    print('        : Gen diff Training Loss = %.10f ' %
                          (epoch_loss_generator_diff))
                    print(
                        '        : Discriminator Training Loss Real = %.10f ' %
                        (epoch_loss_discriminator_real))
                    print(
                        '        : Discriminator Training Loss Fake = %.10f ' %
                        (epoch_loss_discriminator_fake))

                print('        : Harm Validation Loss = %.10f ' %
                      (epoch_loss_harm_val))
                print('        : Ap Validation Loss = %.10f ' %
                      (epoch_loss_ap_val))
                print('        : F0 Validation Loss = %.10f ' %
                      (epoch_loss_f0_val))
                print('        : VUV Validation Loss = %.10f ' %
                      (epoch_loss_vuv_val))

                # if (epoch + 1) % config.save_every == 0 or (epoch + 1) == config.num_epochs:
                # print('        : Mean F0 IKala Accuracy  = %.10f ' % (np.mean(val_f0_accs)))

                # print('        : Mean F0 IKala Accuracy = '+'%{1:.{0}f}%'.format(np.mean(val_f0_accs)))
                # print('        : Initial Validation Loss = %.10f ' % (epoch_initial_loss_val))

                if config.use_gan:

                    print('        : Gen GAN Validation Loss = %.10f ' %
                          (val_epoch_loss_generator_GAN))
                    print('        : Gen diff Validation Loss = %.10f ' %
                          (val_epoch_loss_generator_diff))
                    print(
                        '        : Discriminator Validation Loss Real = %.10f '
                        % (val_epoch_loss_discriminator_real))
                    print(
                        '        : Discriminator Validation Loss Fake = %.10f '
                        % (val_epoch_loss_discriminator_fake))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                utils.list_to_file(
                    val_f0_accs,
                    './ikala_eval/accuracies_' + str(epoch + 1) + '.txt')
                checkpoint_file = os.path.join(config.log_dir_m1, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
Esempio n. 5
0
def synth_file(file_name="015.hdf5",
               singer_index=0,
               file_path=config.wav_dir,
               show_plots=True):
    """Synthesise one song with the trained feature and F0 generators.

    Restores the latest checkpoint from ``config.log_dir`` (if one exists),
    runs ``Generator_feats`` on the phoneme/note conditioning stored in
    ``config.feats_dir + file_name`` and feeds its output to ``Generator_f0``
    to predict F0. Prints the mean/std deviation of the predicted and the
    original F0 from the input notes, optionally plots the F0 contours, and
    writes two outputs through ``utils.feats_to_audio``:

    * ``<stem>_gan_op``: generated vocoder features + predicted F0.
    * ``<stem>_F0_op``: ground-truth vocoder features + predicted F0.

    where ``<stem>`` is ``file_name`` without its extension.

    Args:
        file_name: Feature file inside ``config.feats_dir``; it must contain
            the ``world_feats``, ``phonemes`` and ``notes`` datasets.
        singer_index: Unused; kept for interface compatibility.
        file_path: Unused; kept for interface compatibility.
        show_plots: If true, display the F0 contour plot with ``plt.show()``
            before the audio files are written.
    """
    # Per-dimension min/max statistics used for min-max feature scaling.
    with h5py.File('./stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])

    with tf.Graph().as_default():

        output_placeholder = tf.placeholder(tf.float32,
                                            shape=(config.batch_size,
                                                   config.max_phr_len, 64),
                                            name='output_placeholder')

        f0_input_placeholder = tf.placeholder(tf.float32,
                                              shape=(config.batch_size,
                                                     config.max_phr_len),
                                              name='f0_input_placeholder')
        f0_onehot_labels = tf.one_hot(indices=tf.cast(f0_input_placeholder,
                                                      tf.int32),
                                      depth=len(config.notes))

        f0_context_placeholder = tf.placeholder(tf.float32,
                                                shape=(config.batch_size,
                                                       config.max_phr_len, 1),
                                                name='f0_context_placeholder')

        phone_context_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, 1),
            name='phone_context_placeholder')

        phoneme_labels = tf.placeholder(tf.int32,
                                        shape=(config.batch_size,
                                               config.max_phr_len),
                                        name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(indices=tf.cast(
            phoneme_labels, tf.int32),
                                         depth=len(config.phonemas))

        with tf.variable_scope('Generator_feats'):
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder
            ],
                               axis=-1)
            voc_output = modules.GAN_generator(inputs)

        with tf.variable_scope('Generator_f0') as scope:
            # The first call creates the Generator_f0 variables, conditioned
            # on ground-truth features as during training. The second call
            # reuses those variables, conditioned on the feature generator's
            # output mapped with /2 + 0.5. Presumably this maps [-1, 1] to
            # [0, 1]; TODO confirm against modules.GAN_generator.
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                output_placeholder
            ],
                               axis=-1)
            f0_output = modules.GAN_generator_f0(inputs)

            scope.reuse_variables()

            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                (voc_output / 2) + 0.5
            ],
                               axis=-1)
            f0_output_2 = modules.GAN_generator_f0(inputs)

        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        with h5py.File(config.feats_dir + file_name, "r") as feat_file:
            feats = feat_file["world_feats"][()]
            phones = feat_file["phonemes"][()]
            notes = feat_file["notes"][()]

        # Keep the raw features for evaluation and output. Only the
        # network input is min-max scaled, which avoids a lossy
        # scale/unscale round trip.
        feats_norm = (feats - min_feat) / (max_feat - min_feat)

        conds_all = np.concatenate([phones, notes], axis=-1)

        in_batches_pho, nchunks_in = utils.generate_overlapadd(conds_all)
        in_batches_feat, _ = utils.generate_overlapadd(feats_norm)

        # Per-frame input note value (first notes column), as a column vector.
        noters = np.expand_dims(
            np.array([config.notes[int(x)] for x in notes[:, 0]]), 1)

        out_batches_feats = []
        out_batches_f0 = []

        with tf.Session() as sess:
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(config.log_dir)

            if ckpt and ckpt.model_checkpoint_path:
                print("Using the model in %s" % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)

            for conds, feat in zip(in_batches_pho, in_batches_feat):
                # conds columns 0, 1, 2 and -1 feed the phoneme id, phoneme
                # context, note index and F0 context inputs respectively.
                feed_dict = {
                    f0_input_placeholder: conds[:, :, 2],
                    phoneme_labels: conds[:, :, 0],
                    phone_context_placeholder: conds[:, :, 1:2],
                    f0_context_placeholder: conds[:, :, -1:],
                    output_placeholder: feat[:, :, :-2]
                }

                output_feats_gan, output_f0 = sess.run(
                    [voc_output, f0_output_2], feed_dict=feed_dict)

                # Same /2 + 0.5 mapping the graph applies to voc_output.
                out_batches_feats.append(output_feats_gan / 2 + 0.5)
                out_batches_f0.append(output_f0 / 2 + 0.5)

        out_batches_feats = utils.overlapadd(np.array(out_batches_feats),
                                             nchunks_in)
        out_batches_f0 = utils.overlapadd(np.array(out_batches_f0),
                                          nchunks_in)

        # Undo min-max scaling. The generated features correspond to all
        # but the last two stat dimensions, and the predicted F0 to the
        # second-to-last one. Overlap-add padding is trimmed to the song
        # length.
        out_batches_feats = out_batches_feats * (max_feat[:-2] -
                                                 min_feat[:-2]) + min_feat[:-2]
        out_batches_feats = out_batches_feats[:len(feats)]

        out_batches_f0 = out_batches_f0 * (max_feat[-2] -
                                           min_feat[-2]) + min_feat[-2]
        out_batches_f0 = out_batches_f0[:len(feats)]

        # Frames are weighted by (1 - last feature column). Presumably that
        # column is the voicing flag; NOTE(review) confirm which value marks
        # voiced frames.
        frame_mask = 1 - feats[:, -1:]
        diff_1 = (out_batches_f0 - noters) * frame_mask
        diff_2 = (feats[:, -2:-1] - noters) * frame_mask

        print("Mean predicted note deviation {}".format(diff_1.mean()))
        print("Mean original note deviation {}".format(diff_2.mean()))

        print("STD predicted note deviation {}".format(diff_1.std()))
        print("STD original note deviation {}".format(diff_2.std()))

        if show_plots:
            plt.figure(1)
            plt.suptitle("F0 contour")
            plt.plot(out_batches_f0, label='Predicted F0')
            plt.plot(feats[:, -2], label="Ground Truth F0")
            plt.plot(noters, label="Input Midi Note")
            plt.legend()
            plt.show()

        # Both outputs share the layout [64 vocoder features, F0, last column].
        # The first uses generated features; the second keeps the ground-truth
        # features so that only the F0 is predicted.
        first_op = np.ascontiguousarray(
            np.concatenate([out_batches_feats, out_batches_f0, feats[:, -1:]],
                           axis=-1))
        second_op = np.ascontiguousarray(
            np.concatenate([feats[:, :-2], out_batches_f0, feats[:, -1:]],
                           axis=-1))

        # splitext drops the whole extension; slicing [:-4] left a stray "."
        # for ".hdf5" names.
        file_stem = os.path.splitext(file_name)[0]

        utils.feats_to_audio(first_op, file_stem + '_gan_op')
        print("Full output saved to {}".format(
            os.path.join(config.val_dir, file_stem + '_gan_op.wav')))
        utils.feats_to_audio(second_op, file_stem + '_F0_op')
        print("Only F0 saved to {}".format(
            os.path.join(config.val_dir, file_stem + '_F0_op.wav')))
Esempio n. 6
0
def train(_):
    """Train the vocoder-feature WGAN and the F0 WGAN (TF1 graph mode).

    Builds two generator/critic pairs. Both are conditioned on one-hot
    phonemes, one-hot note indices and the phoneme / F0 context columns:

    * ``Generator_feats`` / ``Discriminator_feats`` model the 64 vocoder
      features.
    * ``Generator_f0`` / ``Discriminator_f0`` model F0 and are additionally
      conditioned on vocoder features. The ``*_f0`` ops use ground-truth
      features; the ``*_f0_2`` ops use the feature generator's output.

    Critics minimise ``mean(D_real) - mean(D_fake)`` with RMSProp (lr 5e-5).
    Their weights are clipped to [-0.01, 0.01] after every critic step.
    Generator losses are ``mean(D_fake)`` plus an L1 term against the
    targets scaled by 5e-5. For epochs <= 1000 the F0 branch trains the
    ``*_f0`` ops; for later epochs it trains the ``*_f0_2`` ops.

    Resumes from the latest checkpoint in ``config.log_dir``. Saves a
    checkpoint every ``config.save_every`` epochs and at the final epoch, and
    writes training summaries to ``config.log_dir + 'train/'``.

    Args:
        _: Ignored. Presumably the argv list passed by ``tf.app.run``;
            TODO confirm against the caller.
    """
    # stat_file = h5py.File(config.stat_dir+'stats.hdf5', mode='r')
    # max_feat = np.array(stat_file["feats_maximus"])
    # min_feat = np.array(stat_file["feats_minimus"])
    with tf.Graph().as_default():

        # Targets. The training loop fills them from slices of the data
        # generator's feats: [:64] -> output_placeholder, [-2:-1] ->
        # f0_output_placeholder, [-1:] -> uv_placeholder.
        output_placeholder = tf.placeholder(tf.float32,
                                            shape=(config.batch_size,
                                                   config.max_phr_len, 64),
                                            name='output_placeholder')

        f0_output_placeholder = tf.placeholder(tf.float32,
                                               shape=(config.batch_size,
                                                      config.max_phr_len, 1),
                                               name='f0_output_placeholder')

        # Per-frame note index, one-hot encoded over config.notes.
        f0_input_placeholder = tf.placeholder(tf.float32,
                                              shape=(config.batch_size,
                                                     config.max_phr_len),
                                              name='f0_input_placeholder')
        f0_onehot_labels = tf.one_hot(indices=tf.cast(f0_input_placeholder,
                                                      tf.int32),
                                      depth=len(config.notes))

        f0_context_placeholder = tf.placeholder(tf.float32,
                                                shape=(config.batch_size,
                                                       config.max_phr_len, 1),
                                                name='f0_context_placeholder')

        # NOTE(review): uv_placeholder is fed but no op below consumes it.
        uv_placeholder = tf.placeholder(tf.float32,
                                        shape=(config.batch_size,
                                               config.max_phr_len, 1),
                                        name='uv_placeholder')

        phone_context_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, 1),
            name='phone_context_placeholder')

        # NOTE(review): rand_input_placeholder and prob are not used by any
        # op in this function.
        rand_input_placeholder = tf.placeholder(tf.float32,
                                                shape=(config.batch_size,
                                                       config.max_phr_len, 64),
                                                name='rand_input_placeholder')

        prob = tf.placeholder_with_default(1.0, shape=())

        phoneme_labels = tf.placeholder(tf.int32,
                                        shape=(config.batch_size,
                                               config.max_phr_len),
                                        name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(indices=tf.cast(
            phoneme_labels, tf.int32),
                                         depth=len(config.phonemas))

        with tf.variable_scope('Generator_feats') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder
            ],
                               axis=-1)
            voc_output = modules.GAN_generator(inputs)

        # Real targets are mapped with (x - 0.5) * 2 before reaching the
        # critic, which puts them on the same scale the generator output is
        # compared on (generator output is mapped back with /2 + 0.5 below).
        # Presumably that scale is [-1, 1]; TODO confirm.
        with tf.variable_scope('Discriminator_feats') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder
            ],
                               axis=-1)
            D_real = modules.GAN_discriminator((output_placeholder - 0.5) * 2,
                                               inputs)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(voc_output, inputs)

        # The first GAN_generator_f0 call creates the variables and is
        # conditioned on ground-truth features. The second (after
        # reuse_variables) shares them and is conditioned on voc_output
        # mapped with /2 + 0.5.
        with tf.variable_scope('Generator_f0') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                output_placeholder
            ],
                               axis=-1)
            # inputs = tf.concat([phone_onehot_labels, f0_onehot_labels, phone_context_placeholder, f0_context_placeholder, (voc_output/2)+0.5], axis = -1)
            f0_output = modules.GAN_generator_f0(inputs)

            scope.reuse_variables()

            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                (voc_output / 2) + 0.5
            ],
                               axis=-1)
            f0_output_2 = modules.GAN_generator_f0(inputs)

        # One shared set of F0 critic variables, evaluated under both
        # conditionings (ground-truth features, and generated features
        # for the *_2 tensors).
        with tf.variable_scope('Discriminator_f0') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                output_placeholder
            ],
                               axis=-1)
            D_real_f0 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            scope.reuse_variables()
            D_fake_f0 = modules.GAN_discriminator_f0(f0_output, inputs)

            scope.reuse_variables()

            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                (voc_output / 2) + 0.5
            ],
                               axis=-1)
            D_real_f0_2 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            scope.reuse_variables()
            D_fake_f0_2 = modules.GAN_discriminator_f0(f0_output_2, inputs)

        # Variable lists restrict each train op to its own sub-network.
        g_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Generator_feats")

        d_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Discriminator_feats")

        g_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Generator_f0")

        d_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Discriminator_f0")

        # WGAN-style critic loss (minimised): mean(D_real) - mean(D_fake).
        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake +
                                                                 1e-12)

        dis_summary = tf.summary.scalar('dis_loss', D_loss)

        # Generator loss: mean(D_fake) plus an L1 reconstruction term on the
        # [0, 1]-mapped output, scaled by 5e-5.
        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output / 2 + 0.5))) * 0.00005

        gen_summary = tf.summary.scalar('gen_loss', G_loss_GAN)

        D_loss_f0 = tf.reduce_mean(D_real_f0 +
                                   1e-12) - tf.reduce_mean(D_fake_f0 + 1e-12)

        dis_summary_f0 = tf.summary.scalar('dis_loss_f0', D_loss_f0)

        G_loss_GAN_f0 = tf.reduce_mean(D_fake_f0 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output / 2 + 0.5))) * 0.00005
        # + tf.reduce_mean(D_fake_f0_2+1e-12) + tf.reduce_sum(tf.abs(f0_output_placeholder- (f0_output_2/2+0.5))) *0.00005

        D_loss_f0_2 = tf.reduce_mean(D_real_f0_2 +
                                     1e-12) - tf.reduce_mean(D_fake_f0_2 +
                                                             1e-12)

        G_loss_GAN_f0_2 = tf.reduce_mean(D_fake_f0_2 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output_2 / 2 + 0.5))) * 0.00005

        gen_summary_f0 = tf.summary.scalar('gen_loss_f0', G_loss_GAN_f0)

        summary = tf.summary.merge_all()

        # One step counter per train op. Only global_step (advanced by
        # gen_train_function) is used to compute start_epoch below.
        global_step = tf.Variable(0, name='global_step', trainable=False)

        global_step_dis = tf.Variable(0,
                                      name='global_step_dis',
                                      trainable=False)

        global_step_f0 = tf.Variable(0, name='global_step_f0', trainable=False)
        global_step_dis_f0 = tf.Variable(0,
                                         name='global_step_dis_f0',
                                         trainable=False)

        global_step_f0_2 = tf.Variable(0,
                                       name='global_step_f0_2',
                                       trainable=False)

        global_step_dis_f0_2 = tf.Variable(0,
                                           name='global_step_dis_f0_2',
                                           trainable=False)

        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        dis_optimizer_f0 = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        gen_optimizer_f0 = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        dis_optimizer_f0_2 = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        gen_optimizer_f0_2 = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        # GradientDescentOptimizer

        # NOTE(review): dis/gen_optimizer_f0 and dis/gen_optimizer_f0_2 are
        # created but never used. Every minimize() below goes through
        # dis_optimizer / gen_optimizer. As a result, the *_f0 and *_f0_2
        # train ops (same var_lists) share one optimizer object. Switching to
        # the dedicated optimizers would likely add new RMSProp slot
        # variables and break restoring existing checkpoints, so confirm the
        # intent before changing this.
        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params_feats)

        gen_train_function = gen_optimizer.minimize(G_loss_GAN,
                                                    global_step=global_step,
                                                    var_list=g_params_feats)

        dis_train_function_f0 = dis_optimizer.minimize(
            D_loss_f0, global_step=global_step_dis_f0, var_list=d_params_f0)

        gen_train_function_f0 = gen_optimizer.minimize(
            G_loss_GAN_f0, global_step=global_step_f0, var_list=g_params_f0)

        dis_train_function_f0_2 = dis_optimizer.minimize(
            D_loss_f0_2,
            global_step=global_step_dis_f0_2,
            var_list=d_params_f0)

        gen_train_function_f0_2 = gen_optimizer.minimize(
            G_loss_GAN_f0_2,
            global_step=global_step_f0_2,
            var_list=g_params_f0)

        # WGAN weight clipping: assign ops that clamp every critic variable
        # to [-0.01, 0.01].
        clip_discriminator_var_op_feats = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_feats
        ]

        clip_discriminator_var_op_f0 = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_f0
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()

        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)

        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        # NOTE(review): val_summary_writer is never written to in this
        # function.
        train_summary_writer = tf.summary.FileWriter(config.log_dir + 'train/',
                                                     sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir + 'val/',
                                                   sess.graph)

        # get_global_step() is not given a tensor explicitly. Presumably it
        # resolves to the 'global_step' variable (feature-generator steps)
        # by name, because that variable was not added to the GLOBAL_STEP
        # collection. The epoch is derived from that step count.
        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))

        print("Start from: %d" % start_epoch)

        for epoch in xrange(start_epoch, config.num_epochs):

            # Critic iterations per generator step: 25 during warm-up and on
            # every 100th epoch, otherwise 5. The F0 schedule's warm-up lasts
            # until epoch 1025.
            if epoch < 25 or epoch % 100 == 0:
                n_critic = 25
            else:
                n_critic = 5

            if epoch < 1025 or epoch % 100 == 0:
                n_critic_f0 = 25
            else:
                n_critic_f0 = 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()

            # NOTE(review): val_generator is created every epoch but never
            # consumed.
            val_generator = data_gen(mode='val')

            batch_num = 0

            # epoch_pho_loss = 0
            epoch_gen_loss = 0
            epoch_dis_loss = 0
            # NOTE(review): the two f0 accumulators below are never updated
            # or reported.
            epoch_gen_loss_f0 = 0
            epoch_dis_loss_f0 = 0

            with tf.variable_scope('Training'):

                for feats, conds in data_generator:
                    # conds columns: 0 phoneme id, 1 phoneme context,
                    # 2 note index, last F0 context.
                    f0 = conds[:, :, 2]
                    phones = conds[:, :, 0]
                    f0_context = conds[:, :, -1:]
                    phones_context = conds[:, :, 1:2]

                    feed_dict = {
                        f0_output_placeholder: feats[:, :, -2:-1],
                        f0_input_placeholder: f0,
                        phoneme_labels: phones,
                        phone_context_placeholder: phones_context,
                        f0_context_placeholder: f0_context,
                        output_placeholder: feats[:, :, :64],
                        uv_placeholder: feats[:, :, -1:]
                    }

                    # Feature critic: n_critic steps, clipping after each.
                    for critic_itr in range(n_critic):

                        sess.run(dis_train_function, feed_dict=feed_dict)
                        sess.run(clip_discriminator_var_op_feats,
                                 feed_dict=feed_dict)

                    # feed_dict = {input_placeholder: feats, output_placeholder: feats[:,:,:-2], f0_input_placeholder: f0, rand_input_placeholder: np.random.uniform(-1.0, 1.0, size=[30,config.max_phr_len,4]),
                    # phoneme_labels:phos, singer_labels: singer_ids, phoneme_labels_shuffled:phos_shu, singer_labels_shuffled:sing_id_shu}

                    # One feature-generator step, then the critic loss after it.
                    _, step_gen_loss = sess.run(
                        [gen_train_function, G_loss_GAN], feed_dict=feed_dict)
                    # import pdb;pdb.set_trace()
                    # if step_gen_acc>0.3:
                    step_dis_loss = sess.run(D_loss, feed_dict=feed_dict)

                    # feed_dict = {input_placeholder: feats, output_placeholder: feats[:,:,:-2], f0_input_placeholder: f0, rand_input_placeholder: np.random.uniform(-1.0, 1.0, size=[30,config.max_phr_len,4]),
                    # phoneme_labels:phos, singer_labels: singer_ids, phoneme_labels_shuffled:phos_shu, singer_labels_shuffled:sing_id_shu}

                    # After epoch 1000 the F0 GAN is conditioned on the
                    # feature generator's output (*_f0_2 ops). Before that it
                    # is conditioned on ground-truth features (*_f0 ops).
                    # NOTE(review): the step losses computed in both branches
                    # are never accumulated or printed.
                    if epoch > 1000:
                        for critic_itr in range(n_critic_f0):
                            sess.run(dis_train_function_f0_2,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)

                    # feed_dict = {input_placeholder: feats, output_placeholder: feats[:,:,:-2], f0_input_placeholder: f0, rand_input_placeholder: np.random.uniform(-1.0, 1.0, size=[30,config.max_phr_len,4]),
                    # phoneme_labels:phos, singer_labels: singer_ids, phoneme_labels_shuffled:phos_shu, singer_labels_shuffled:sing_id_shu}

                        _, step_gen_loss_2 = sess.run(
                            [gen_train_function_f0_2, G_loss_GAN_f0_2],
                            feed_dict=feed_dict)
                        # import pdb;pdb.set_trace()
                        # if step_gen_acc>0.3:
                        step_dis_loss_2 = sess.run(D_loss_f0_2,
                                                   feed_dict=feed_dict)
                    else:
                        # NOTE(review): this branch uses n_critic (the feature
                        # schedule), not n_critic_f0; confirm this is intended.
                        for critic_itr in range(n_critic):
                            sess.run(dis_train_function_f0,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)
                        _, step_gen_loss_f0 = sess.run(
                            [gen_train_function_f0, G_loss_GAN_f0],
                            feed_dict=feed_dict)
                        # import pdb;pdb.set_trace()
                        # if step_gen_acc>0.3:
                        step_dis_loss_f0 = sess.run(D_loss_f0,
                                                    feed_dict=feed_dict)

                    # _, step_pho_loss, step_pho_acc = sess.run([pho_train_function, pho_loss, pho_acc], feed_dict= feed_dict)
                    # else:
                    # step_dis_loss, step_dis_acc = sess.run([D_loss, D_accuracy], feed_dict = feed_dict)

                    # epoch_pho_loss+=step_pho_loss
                    # epoch_re_loss+=step_re_loss
                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss

                    # epoch_pho_acc+=step_pho_acc[0]
                    # epoch_gen_acc+=step_gen_acc
                    # epoch_dis_acc+=step_dis_acc
                    # epoch_dis_acc_fake+=step_dis_acc_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                # Per-batch averages. NOTE(review): the divisor is the
                # configured batch count, not batch_num, so the average is off
                # if the generator yields a different number of batches.
                # epoch_pho_loss = epoch_pho_loss/config.batches_per_epoch_train
                # epoch_re_loss = epoch_re_loss/config.batches_per_epoch_train
                epoch_gen_loss = epoch_gen_loss / config.batches_per_epoch_train
                epoch_dis_loss = epoch_dis_loss / config.batches_per_epoch_train
                # epoch_dis_acc_fake = epoch_dis_acc_fake/config.batches_per_epoch_train

                # epoch_pho_acc = epoch_pho_acc/config.batches_per_epoch_train
                # epoch_gen_acc = epoch_gen_acc/config.batches_per_epoch_train
                # epoch_dis_acc = epoch_dis_acc/config.batches_per_epoch_train
                # Summaries are evaluated on the last training batch's
                # feed_dict. If the generator yields no batches, feed_dict is
                # unbound on the first epoch or stale on later ones.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                # import pdb;pdb.set_trace()
                train_summary_writer.add_summary(summary_str, epoch)
                # # summary_writer.add_summary(summary_str_val, epoch)
                train_summary_writer.flush()

            duration = time.time() - start_time

            # np.save('./ikala_eval/accuracies', f0_accs)

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Gen Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_gen_loss, duration))
                # print('        : Phone Accuracy = %.10f ' % (epoch_pho_acc))
                # print('        : Recon Loss = %.10f ' % (epoch_re_loss))
                # print('        : Gen Loss = %.10f ' % (epoch_gen_loss))
                # print('        : Gen Accuracy = %.10f ' % (epoch_gen_acc))
                print('        : Dis Loss = %.10f ' % (epoch_dis_loss))
                # print('        : Dis Accuracy = %.10f ' % (epoch_dis_acc))
                # print('        : Dis Accuracy Fake = %.10f ' % (epoch_dis_acc_fake))
                # print('        : Val Phone Accuracy = %.10f ' % (val_epoch_pho_acc))
                # print('        : Val Gen Loss = %.10f ' % (val_epoch_gen_loss))
                # print('        : Val Gen Accuracy = %.10f ' % (val_epoch_gen_acc))
                # print('        : Val Dis Loss = %.10f ' % (val_epoch_dis_loss))
                # print('        : Val Dis Accuracy = %.10f ' % (val_epoch_dis_acc))
                # print('        : Val Dis Accuracy Fake = %.10f ' % (val_epoch_dis_acc_fake))

            # Checkpoint is tagged with the 0-based epoch index.
            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                # utils.list_to_file(val_f0_accs,'./ikala_eval/accuracies_'+str(epoch+1)+'.txt')
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)