Beispiel #1
0
    def mix_model(self):
        """Synthesise audio with the conditioning features swapped across the batch.

        For every batch from ``data_gen(self.config)`` the conditioning features
        are reversed along the batch axis, so item ``j`` is rendered with the
        features of item ``-1 - j``; for an odd batch size the middle item keeps
        its own features. The model output for each item is written to
        ``output_{batch}_{item}_{model}.wav`` and the ground-truth audio to
        ``gt_{batch}_{item}.wav`` in ``self.config.output_dir``.
        """
        sess = tf.Session()
        self.load_model(sess, log_dir=self.config.log_dir)
        val_generator = data_gen(self.config)
        for batch_count, [
                out_audios, out_envelopes, out_features, total_count
        ] in enumerate(val_generator):
            # Reverse the batch order of the conditioning features (copying so
            # the generator's buffer is not aliased). The previous manual swap
            # loop iterated range(n // 2 - 1), which left the middle pair
            # unswapped and did nothing at all for a batch of 2.
            out_features = np.copy(out_features[::-1])

            feed_dict = {
                # Only the first `rhyfeats` channels of the envelope are used.
                self.input_placeholder:
                out_envelopes[:, :, :self.config.rhyfeats],
                self.cond_placeholder: out_features,
                self.output_placeholder: out_audios,
                self.is_train: False,
            }
            output_full = sess.run(self.output_wav, feed_dict=feed_dict)

            for count in range(self.config.batch_size):
                if self.config.model == "spec":
                    # exp(x) - 1 inverts a log1p compression (presumably of a
                    # magnitude spectrogram -- TODO confirm); Griffin-Lim then
                    # recovers a waveform.
                    out_audio = utils.griffinlim(
                        np.exp(output_full[count]) - 1, self.config)
                else:
                    out_audio = output_full[count]
                output_file = os.path.join(
                    self.config.output_dir,
                    'output_{}_{}_{}.wav'.format(batch_count, count,
                                                 self.config.model))
                # Clip to [-1, 1] before writing to avoid wrap-around/overflow.
                sf.write(output_file, np.clip(out_audio, -1, 1),
                         self.config.fs)
                sf.write(
                    os.path.join(self.config.output_dir,
                                 'gt_{}_{}.wav'.format(batch_count, count)),
                    out_audios[count], self.config.fs)
            utils.progress(batch_count, total_count)
Beispiel #2
0
    def train(self):
        """
        Function to train the model, and save Tensorboard summary, for N epochs.

        Resumes from the epoch implied by the global step restored by
        ``self.load_model``. Each epoch makes one pass over ``data_gen()`` via
        ``self.train_model``. Every ``config.validate_every`` epochs it adds
        precision/accuracy/recall from ``self.validate_model``. It prints a
        summary every ``config.print_every`` epochs and saves the model every
        ``config.save_every`` epochs and after the last one.

        Raises:
            RuntimeError: if the training generator yields no batches in an
                epoch (this used to surface as a bare ZeroDivisionError).
        """
        sess = tf.Session()

        self.loss_function()
        self.get_optimizers()
        self.load_model(sess, config.log_dir)
        self.get_summary(sess, config.log_dir)
        # The global step counts batches, so convert it back to an epoch index.
        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))

        print("Start from: %d" % start_epoch)

        for epoch in range(start_epoch, config.num_epochs):
            data_generator = data_gen()
            start_time = time.time()

            batch_num = 0
            epoch_train_loss = 0

            with tf.variable_scope('Training'):
                for ins, outs in data_generator:

                    step_loss, summary_str = self.train_model(ins, outs, sess)
                    epoch_train_loss += step_loss

                    self.train_summary_writer.add_summary(summary_str, epoch)
                    self.train_summary_writer.flush()

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')

                    batch_num += 1

                if batch_num == 0:
                    raise RuntimeError(
                        "data_gen() yielded no training batches in epoch %d"
                        % epoch)
                epoch_train_loss = epoch_train_loss / batch_num
                print_dict = {"Training Loss": epoch_train_loss}

            if (epoch + 1) % config.validate_every == 0:
                pre, acc, rec = self.validate_model(sess)
                print_dict["Validation Precision"] = pre
                print_dict["Validation Accuracy"] = acc
                print_dict["Validation Recall"] = rec

            end_time = time.time()
            if (epoch + 1) % config.print_every == 0:
                self.print_summary(print_dict, epoch, end_time - start_time)
            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                self.save_model(sess, epoch + 1, config.log_dir)
Beispiel #3
0
def train(_):
    """Build the phoneme / vocoder / WGAN graph and train it for config.num_epochs.

    The graph holds four sub-networks. Each is optimised only on its own
    variables:

    * ``phone_Model``: phoneme classifier over the 66-channel input features.
    * ``Final_Model``: predicts the 64 output channels from singer id, f0 and
      phoneme labels. It is also re-run, with shared weights, on the predicted
      phoneme posteriors.
    * ``Generator`` / ``Discriminator``: a WGAN-style pair. Critic weights are
      clipped to +/-0.01 after every critic step.

    Per training batch it runs ``n_critic`` critic steps, then one step each of
    ``Final_Model``, the generator and the phoneme network. A validation pass
    follows each epoch. Summaries go to ``config.log_dir + 'train/'`` and
    ``+ 'val/'``; checkpoints go to ``config.log_dir``. Training resumes from
    the latest checkpoint there, if one exists.

    Args:
        _: Unused (presumably argv passed by ``tf.app.run`` -- TODO confirm).
    """
    # The stats are read but not currently used below. Use a context manager
    # so the HDF5 handle is released instead of leaking.
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])  # noqa: F841
        min_feat = np.array(stat_file["feats_minimus"])  # noqa: F841

    with tf.Graph().as_default():

        # ---- Placeholders -------------------------------------------------
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(config.batch_size,
                                                  config.max_phr_len, 66),
                                           name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)

        output_placeholder = tf.placeholder(tf.float32,
                                            shape=(config.batch_size,
                                                   config.max_phr_len, 64),
                                            name='output_placeholder')

        f0_input_placeholder = tf.placeholder(tf.float32,
                                              shape=(config.batch_size,
                                                     config.max_phr_len, 1),
                                              name='f0_input_placeholder')

        rand_input_placeholder = tf.placeholder(tf.float32,
                                                shape=(config.batch_size,
                                                       config.max_phr_len, 4),
                                                name='rand_input_placeholder')

        prob = tf.placeholder_with_default(1.0, shape=())  # noqa: F841

        phoneme_labels = tf.placeholder(tf.int32,
                                        shape=(config.batch_size,
                                               config.max_phr_len),
                                        name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(indices=tf.cast(
            phoneme_labels, tf.int32),
                                         depth=42)

        singer_labels = tf.placeholder(tf.float32,
                                       shape=(config.batch_size),
                                       name='singer_placeholder')
        singer_onehot_labels = tf.one_hot(indices=tf.cast(
            singer_labels, tf.int32),
                                          depth=12)

        phoneme_labels_shuffled = tf.placeholder(tf.int32,
                                                 shape=(config.batch_size,
                                                        config.max_phr_len),
                                                 name='phoneme_placeholder_s')
        phone_onehot_labels_shuffled = tf.one_hot(  # noqa: F841
            indices=tf.cast(phoneme_labels_shuffled, tf.int32), depth=42)

        singer_labels_shuffled = tf.placeholder(tf.float32,
                                                shape=(config.batch_size),
                                                name='singer_placeholder_s')
        singer_onehot_labels_shuffled = tf.one_hot(  # noqa: F841
            indices=tf.cast(singer_labels_shuffled, tf.int32), depth=12)

        # ---- Networks -----------------------------------------------------
        with tf.variable_scope('phone_Model'):
            pho_logits = modules.phone_network(input_placeholder)
            pho_classes = tf.argmax(pho_logits, axis=-1)
            pho_probs = tf.nn.softmax(pho_logits)

        with tf.variable_scope('Final_Model') as scope:
            voc_output = modules.final_net(singer_onehot_labels,
                                           f0_input_placeholder,
                                           phone_onehot_labels)
            voc_output_decoded = tf.nn.sigmoid(voc_output)
            # Same weights, driven by predicted phoneme posteriors.
            scope.reuse_variables()
            voc_output_3 = modules.final_net(singer_onehot_labels,
                                             f0_input_placeholder, pho_probs)
            voc_output_3_decoded = tf.nn.sigmoid(voc_output_3)  # noqa: F841

        with tf.variable_scope('Generator'):
            voc_output_2 = modules.GAN_generator(singer_onehot_labels,
                                                 phone_onehot_labels,
                                                 f0_input_placeholder,
                                                 rand_input_placeholder)

        with tf.variable_scope('Discriminator') as scope:
            # Targets in [0, 1] are rescaled to [-1, 1] before the critic.
            D_real = modules.GAN_discriminator(
                (output_placeholder - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(voc_output_2,
                                               singer_onehot_labels,
                                               phone_onehot_labels,
                                               f0_input_placeholder)
            scope.reuse_variables()
            D_fake_real = modules.GAN_discriminator(
                (voc_output_decoded - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)

        # ---- Per-network variable lists -----------------------------------
        final_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="Final_Model")

        g_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Generator")

        d_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Discriminator")

        phone_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="phone_Model")

        # ---- Phoneme network loss: per-class weighted cross-entropy --------
        pho_weights = tf.reduce_sum(config.phonemas_weights *
                                    phone_onehot_labels,
                                    axis=-1)

        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(
            labels=phone_onehot_labels, logits=pho_logits)

        weighted_losses = unweighted_losses * pho_weights

        pho_loss = tf.reduce_mean(weighted_losses)

        # NOTE(review): tf.metrics.accuracy is a streaming metric. Its local
        # counters are initialised once and never reset, so the reported
        # value is a running accuracy since session start. It also mixes
        # training and validation batches.
        pho_acc = tf.metrics.accuracy(labels=phoneme_labels,
                                      predictions=pho_classes)

        tf.summary.scalar('pho_loss', pho_loss)

        tf.summary.scalar('pho_accuracy', pho_acc[0])

        # ---- Critic loss / accuracies --------------------------------------
        D_correct_pred = tf.equal(tf.round(tf.sigmoid(D_real)),
                                  tf.ones_like(D_real))

        D_correct_pred_fake = tf.equal(tf.round(tf.sigmoid(D_fake_real)),
                                       tf.ones_like(D_fake_real))

        D_accuracy = tf.reduce_mean(tf.cast(D_correct_pred, tf.float32))

        D_accuracy_fake = tf.reduce_mean(
            tf.cast(D_correct_pred_fake, tf.float32))

        # Critic minimises mean(D_real) - mean(D_fake); the generator below
        # minimises mean(D_fake). The sign convention is flipped w.r.t. the
        # usual WGAN formulation but is self-consistent.
        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake +
                                                                 1e-12)

        tf.summary.scalar('dis_loss', D_loss)

        tf.summary.scalar('dis_acc', D_accuracy)

        tf.summary.scalar('dis_acc_fake', D_accuracy_fake)

        # ---- Generator loss: adversarial term + small L1 reconstruction ----
        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output_2 / 2 + 0.5))) * 0.00005

        G_correct_pred = tf.equal(tf.round(tf.sigmoid(D_fake)),
                                  tf.ones_like(D_real))

        G_accuracy = tf.reduce_mean(tf.cast(G_correct_pred, tf.float32))

        tf.summary.scalar('gen_loss', G_loss_GAN)

        tf.summary.scalar('gen_acc', G_accuracy)

        # ---- Final_Model reconstruction loss ------------------------------
        final_loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=output_placeholder,
                                                    logits=voc_output))

        tf.summary.scalar('final_loss', final_loss)

        summary = tf.summary.merge_all()

        # ---- Global steps --------------------------------------------------
        global_step = tf.Variable(0, name='global_step', trainable=False)

        global_step_re = tf.Variable(0, name='global_step_re', trainable=False)

        global_step_dis = tf.Variable(0,
                                      name='global_step_dis',
                                      trainable=False)

        global_step_gen = tf.Variable(0,
                                      name='global_step_gen',
                                      trainable=False)

        # ---- Optimizers ----------------------------------------------------
        # Construction and minimize() order is kept unchanged so optimizer
        # variable names stay compatible with existing checkpoints.
        pho_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)

        re_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)

        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        pho_train_function = pho_optimizer.minimize(pho_loss,
                                                    global_step=global_step,
                                                    var_list=phone_params)

        re_train_function = re_optimizer.minimize(final_loss,
                                                  global_step=global_step_re,
                                                  var_list=final_params)

        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params)

        gen_train_function = gen_optimizer.minimize(
            G_loss_GAN, global_step=global_step_gen, var_list=g_params)

        # Weight clipping keeps the critic (approximately) Lipschitz.
        clip_discriminator_var_op = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in d_params
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()

        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)

        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(config.log_dir + 'train/',
                                                     sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir + 'val/',
                                                   sess.graph)

        # The phoneme step counter ('global_step') counts batches; convert it
        # back to an epoch index to resume.
        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))

        print("Start from: %d" % start_epoch)

        def prepare_batch(f0, phos, singer_ids):
            """Reshape f0 and build shuffled copies of the phoneme/singer labels."""
            f0 = f0.reshape([config.batch_size, config.max_phr_len, 1])
            sing_id_shu = np.copy(singer_ids)
            phos_shu = np.copy(phos)
            # Singers are shuffled before phonemes, preserving the RNG order.
            np.random.shuffle(sing_id_shu)
            np.random.shuffle(phos_shu)
            return f0, phos_shu, sing_id_shu

        def make_feed_dict(feats, f0, phos, singer_ids, phos_shu, sing_id_shu):
            """Map one batch onto the placeholders, drawing fresh generator noise."""
            return {
                input_placeholder: feats,
                # The last two input channels are not part of the target.
                output_placeholder: feats[:, :, :-2],
                f0_input_placeholder: f0,
                # Noise is sized from config so it always matches the
                # placeholder (it was hard-coded to a batch of 30).
                rand_input_placeholder: np.random.uniform(
                    -1.0, 1.0, size=[config.batch_size, config.max_phr_len, 4]),
                phoneme_labels: phos,
                singer_labels: singer_ids,
                phoneme_labels_shuffled: phos_shu,
                singer_labels_shuffled: sing_id_shu,
            }

        for epoch in range(start_epoch, config.num_epochs):

            # Critic steps per generator step: more at the start and on every
            # 100th epoch.
            if epoch < 25 or epoch % 100 == 0:
                n_critic = 25
            else:
                n_critic = 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()

            val_generator = data_gen(mode='val')

            batch_num = 0

            epoch_pho_loss = 0
            epoch_gen_loss = 0
            epoch_re_loss = 0
            epoch_dis_loss = 0

            epoch_pho_acc = 0
            epoch_gen_acc = 0
            epoch_dis_acc = 0
            epoch_dis_acc_fake = 0

            val_epoch_pho_loss = 0
            val_epoch_gen_loss = 0
            val_epoch_dis_loss = 0

            val_epoch_pho_acc = 0
            val_epoch_gen_acc = 0
            val_epoch_dis_acc = 0
            val_epoch_dis_acc_fake = 0

            with tf.variable_scope('Training'):

                for feats, f0, phos, singer_ids in data_generator:

                    f0, phos_shu, sing_id_shu = prepare_batch(
                        f0, phos, singer_ids)

                    for _critic_itr in range(n_critic):
                        feed_dict = make_feed_dict(feats, f0, phos, singer_ids,
                                                   phos_shu, sing_id_shu)
                        sess.run(dis_train_function, feed_dict=feed_dict)
                        # Assign ops read only variables; no feed needed.
                        sess.run(clip_discriminator_var_op)

                    feed_dict = make_feed_dict(feats, f0, phos, singer_ids,
                                               phos_shu, sing_id_shu)

                    _, _, step_re_loss, step_gen_loss, step_gen_acc = sess.run(
                        [
                            re_train_function, gen_train_function, final_loss,
                            G_loss_GAN, G_accuracy
                        ],
                        feed_dict=feed_dict)
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)
                    _, step_pho_loss, step_pho_acc = sess.run(
                        [pho_train_function, pho_loss, pho_acc],
                        feed_dict=feed_dict)

                    epoch_pho_loss += step_pho_loss
                    epoch_re_loss += step_re_loss
                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss

                    epoch_pho_acc += step_pho_acc[0]
                    epoch_gen_acc += step_gen_acc
                    epoch_dis_acc += step_dis_acc
                    epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                epoch_pho_loss = epoch_pho_loss / config.batches_per_epoch_train
                epoch_re_loss = epoch_re_loss / config.batches_per_epoch_train
                epoch_gen_loss = epoch_gen_loss / config.batches_per_epoch_train
                epoch_dis_loss = epoch_dis_loss / config.batches_per_epoch_train
                epoch_dis_acc_fake = epoch_dis_acc_fake / config.batches_per_epoch_train

                epoch_pho_acc = epoch_pho_acc / config.batches_per_epoch_train
                epoch_gen_acc = epoch_gen_acc / config.batches_per_epoch_train
                epoch_dis_acc = epoch_dis_acc / config.batches_per_epoch_train
                # Summaries are evaluated on the last training batch.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                train_summary_writer.add_summary(summary_str, epoch)
                train_summary_writer.flush()

            with tf.variable_scope('Validation'):

                # Restart the counter so validation progress is reported
                # against the validation set, not continued from training.
                batch_num = 0

                for feats, f0, phos, singer_ids in val_generator:

                    f0, phos_shu, sing_id_shu = prepare_batch(
                        f0, phos, singer_ids)

                    feed_dict = make_feed_dict(feats, f0, phos, singer_ids,
                                               phos_shu, sing_id_shu)

                    step_pho_loss, step_pho_acc = sess.run([pho_loss, pho_acc],
                                                           feed_dict=feed_dict)
                    step_gen_loss, step_gen_acc = sess.run(
                        [final_loss, G_accuracy], feed_dict=feed_dict)
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)

                    val_epoch_pho_loss += step_pho_loss
                    val_epoch_gen_loss += step_gen_loss
                    val_epoch_dis_loss += step_dis_loss

                    val_epoch_pho_acc += step_pho_acc[0]
                    val_epoch_gen_acc += step_gen_acc
                    val_epoch_dis_acc += step_dis_acc
                    val_epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(batch_num,
                                   config.batches_per_epoch_val,
                                   suffix='validation done')
                    batch_num += 1

                val_epoch_pho_loss = val_epoch_pho_loss / config.batches_per_epoch_val
                val_epoch_gen_loss = val_epoch_gen_loss / config.batches_per_epoch_val
                val_epoch_dis_loss = val_epoch_dis_loss / config.batches_per_epoch_val

                val_epoch_pho_acc = val_epoch_pho_acc / config.batches_per_epoch_val
                val_epoch_gen_acc = val_epoch_gen_acc / config.batches_per_epoch_val
                val_epoch_dis_acc = val_epoch_dis_acc / config.batches_per_epoch_val
                val_epoch_dis_acc_fake = val_epoch_dis_acc_fake / config.batches_per_epoch_val

                summary_str = sess.run(summary, feed_dict=feed_dict)
                val_summary_writer.add_summary(summary_str, epoch)
                val_summary_writer.flush()
            duration = time.time() - start_time

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Phone Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_pho_loss, duration))
                print('        : Phone Accuracy = %.10f ' % (epoch_pho_acc))
                print('        : Recon Loss = %.10f ' % (epoch_re_loss))
                print('        : Gen Loss = %.10f ' % (epoch_gen_loss))
                print('        : Gen Accuracy = %.10f ' % (epoch_gen_acc))
                print('        : Dis Loss = %.10f ' % (epoch_dis_loss))
                print('        : Dis Accuracy = %.10f ' % (epoch_dis_acc))
                print('        : Dis Accuracy Fake = %.10f ' %
                      (epoch_dis_acc_fake))
                print('        : Val Phone Accuracy = %.10f ' %
                      (val_epoch_pho_acc))
                # NOTE(review): "Val Gen Loss" reports final_loss (the
                # Final_Model reconstruction loss), not G_loss_GAN.
                print('        : Val Gen Loss = %.10f ' % (val_epoch_gen_loss))
                print('        : Val Gen Accuracy = %.10f ' %
                      (val_epoch_gen_acc))
                print('        : Val Dis Loss = %.10f ' % (val_epoch_dis_loss))
                print('        : Val Dis Accuracy = %.10f ' %
                      (val_epoch_dis_acc))
                print('        : Val Dis Accuracy Fake = %.10f ' %
                      (val_epoch_dis_acc_fake))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
Beispiel #4
0
	def train(self):
		"""
		Function to train the model, and save Tensorboard summary, for N epochs.

		Resumes from the epoch implied by the global step restored by
		``self.load_model``. Each epoch makes one pass over ``data_gen()`` via
		``self.train_model``. Every ``config.validate_every`` epochs it also makes
		a pass over ``data_gen(mode = 'Val')`` via ``self.validate_model``.
		Per-batch summaries are written at step ``epoch``. The model is saved
		every ``config.save_every`` epochs and after the final epoch.

		Raises:
			RuntimeError: if the training or validation generator yields no
				batches (this used to surface as a bare ZeroDivisionError).
		"""
		sess = tf.Session()

		self.loss_function()
		self.get_optimizers()
		self.load_model(sess, config.log_dir)
		self.get_summary(sess, config.log_dir)
		# The global step counts batches, so convert it back to an epoch index.
		start_epoch = int(sess.run(tf.train.get_global_step()) / (config.batches_per_epoch_train))

		print("Start from: %d" % start_epoch)

		for epoch in range(start_epoch, config.num_epochs):
			data_generator = data_gen()
			val_generator = data_gen(mode = 'Val')
			start_time = time.time()

			batch_num = 0
			epoch_final_loss = 0

			with tf.variable_scope('Training'):
				for voc, feat in data_generator:

					final_loss, summary_str = self.train_model(voc, feat, sess)

					epoch_final_loss += final_loss

					self.train_summary_writer.add_summary(summary_str, epoch)
					self.train_summary_writer.flush()

					utils.progress(batch_num, config.batches_per_epoch_train, suffix = 'training done')

					batch_num += 1

				if batch_num == 0:
					raise RuntimeError("data_gen() yielded no training batches in epoch %d" % epoch)
				epoch_final_loss = epoch_final_loss / batch_num

				print_dict = {"Final Loss": epoch_final_loss}

			if (epoch + 1) % config.validate_every == 0:
				batch_num = 0
				val_final_loss = 0
				with tf.variable_scope('Validation'):
					for voc, feat in val_generator:

						final_loss, summary_str = self.validate_model(voc, feat, sess)
						val_final_loss += final_loss

						self.val_summary_writer.add_summary(summary_str, epoch)
						self.val_summary_writer.flush()
						batch_num += 1

						utils.progress(batch_num, config.batches_per_epoch_val, suffix='validation done')

					if batch_num == 0:
						raise RuntimeError("data_gen(mode='Val') yielded no validation batches in epoch %d" % epoch)
					val_final_loss = val_final_loss / batch_num

					print_dict["Val Final Loss"] = val_final_loss

			end_time = time.time()
			if (epoch + 1) % config.print_every == 0:
				self.print_summary(print_dict, epoch, end_time - start_time)
			if (epoch + 1) % config.save_every == 0 or (epoch + 1) == config.num_epochs:
				self.save_model(sess, epoch + 1, config.log_dir)
Beispiel #5
0
def train(_):
    """Train the feature and F0 WGAN-style generator/critic pairs (TF1 graph).

    Two generator/critic pairs are built:

    * ``Generator_feats`` / ``Discriminator_feats`` produce and score the 64
      vocoder features, conditioned on one-hot phonemes, one-hot F0 notes and
      the phoneme/F0 context channels.
    * ``Generator_f0`` / ``Discriminator_f0`` produce and score the F0 track,
      additionally conditioned on vocoder features.  Each is wired twice with
      shared weights: on the ground-truth features (``*_f0`` ops) and on the
      generated features (``*_f0_2`` ops).

    Critic weights are clipped to [-0.01, 0.01] after every critic update and
    all models use RMSProp (learning rate 5e-5).  Training resumes from the
    latest checkpoint in ``config.log_dir`` when one exists.  Through epoch
    1000 the F0 pair trains on ground-truth features, afterwards on the
    generated ones.

    Args:
        _: Unused (presumably the argv passed by ``tf.app.run`` - confirm).
    """
    with tf.Graph().as_default():

        output_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 64),
            name='output_placeholder')

        f0_output_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_output_placeholder')

        f0_input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len),
            name='f0_input_placeholder')
        f0_onehot_labels = tf.one_hot(
            indices=tf.cast(f0_input_placeholder, tf.int32),
            depth=len(config.notes))

        f0_context_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_context_placeholder')

        # NOTE(review): uv_placeholder is fed during training but no op in
        # this graph consumes it.
        uv_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='uv_placeholder')

        phone_context_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='phone_context_placeholder')

        phoneme_labels = tf.placeholder(
            tf.int32, shape=(config.batch_size, config.max_phr_len),
            name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(
            indices=tf.cast(phoneme_labels, tf.int32),
            depth=len(config.phonemas))

        # Conditioning channels shared by every generator and critic.
        cond_inputs = [
            phone_onehot_labels, f0_onehot_labels, phone_context_placeholder,
            f0_context_placeholder
        ]

        with tf.variable_scope('Generator_feats'):
            inputs = tf.concat(cond_inputs, axis=-1)
            voc_output = modules.GAN_generator(inputs)

        with tf.variable_scope('Discriminator_feats') as scope:
            inputs = tf.concat(cond_inputs, axis=-1)
            # Real targets are mapped with (x - 0.5) * 2, the inverse of the
            # voc_output / 2 + 0.5 mapping used in the generator loss.
            D_real = modules.GAN_discriminator((output_placeholder - 0.5) * 2,
                                               inputs)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(voc_output, inputs)

        with tf.variable_scope('Generator_f0') as scope:
            # F0 generator on ground-truth features ...
            inputs = tf.concat(cond_inputs + [output_placeholder], axis=-1)
            f0_output = modules.GAN_generator_f0(inputs)

            scope.reuse_variables()

            # ... and, with the same weights, on generated features.
            inputs = tf.concat(cond_inputs + [(voc_output / 2) + 0.5],
                               axis=-1)
            f0_output_2 = modules.GAN_generator_f0(inputs)

        with tf.variable_scope('Discriminator_f0') as scope:
            inputs = tf.concat(cond_inputs + [output_placeholder], axis=-1)
            D_real_f0 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            # Reuse stays enabled for the rest of this scope.
            scope.reuse_variables()
            D_fake_f0 = modules.GAN_discriminator_f0(f0_output, inputs)

            inputs = tf.concat(cond_inputs + [(voc_output / 2) + 0.5],
                               axis=-1)
            D_real_f0_2 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            D_fake_f0_2 = modules.GAN_discriminator_f0(f0_output_2, inputs)

        g_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Generator_feats")
        d_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Discriminator_feats")
        g_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Generator_f0")
        d_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Discriminator_f0")

        # Critic losses: mean score on real minus mean score on fake.
        # Generator losses: mean critic score on fakes plus a 5e-5-weighted
        # summed L1 distance to the ground truth.
        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake +
                                                                 1e-12)
        tf.summary.scalar('dis_loss', D_loss)

        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output / 2 + 0.5))) * 0.00005
        tf.summary.scalar('gen_loss', G_loss_GAN)

        D_loss_f0 = tf.reduce_mean(D_real_f0 +
                                   1e-12) - tf.reduce_mean(D_fake_f0 + 1e-12)
        tf.summary.scalar('dis_loss_f0', D_loss_f0)

        G_loss_GAN_f0 = tf.reduce_mean(D_fake_f0 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output / 2 + 0.5))) * 0.00005

        D_loss_f0_2 = tf.reduce_mean(D_real_f0_2 + 1e-12) - tf.reduce_mean(
            D_fake_f0_2 + 1e-12)

        G_loss_GAN_f0_2 = tf.reduce_mean(D_fake_f0_2 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output_2 / 2 + 0.5))) * 0.00005

        tf.summary.scalar('gen_loss_f0', G_loss_GAN_f0)

        summary = tf.summary.merge_all()

        # One step counter per train op.  `global_step` (feature generator)
        # drives the resume epoch below.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        global_step_dis = tf.Variable(0,
                                      name='global_step_dis',
                                      trainable=False)
        global_step_f0 = tf.Variable(0, name='global_step_f0', trainable=False)
        global_step_dis_f0 = tf.Variable(0,
                                         name='global_step_dis_f0',
                                         trainable=False)
        global_step_f0_2 = tf.Variable(0,
                                       name='global_step_f0_2',
                                       trainable=False)
        global_step_dis_f0_2 = tf.Variable(0,
                                           name='global_step_dis_f0_2',
                                           trainable=False)

        # NOTE(review): the F0 train ops reuse these two optimizer objects.
        # TF1 optimizers key slot variables by the variable being updated, so
        # feature and F0 parameters still get separate RMSProp accumulators,
        # while the *_f0 and *_f0_2 ops share the F0 ones.  Separate instances
        # would create extra slot variables and break restoring existing
        # checkpoints.
        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params_feats)

        gen_train_function = gen_optimizer.minimize(G_loss_GAN,
                                                    global_step=global_step,
                                                    var_list=g_params_feats)

        dis_train_function_f0 = dis_optimizer.minimize(
            D_loss_f0, global_step=global_step_dis_f0, var_list=d_params_f0)

        gen_train_function_f0 = gen_optimizer.minimize(
            G_loss_GAN_f0, global_step=global_step_f0, var_list=g_params_f0)

        dis_train_function_f0_2 = dis_optimizer.minimize(
            D_loss_f0_2,
            global_step=global_step_dis_f0_2,
            var_list=d_params_f0)

        gen_train_function_f0_2 = gen_optimizer.minimize(
            G_loss_GAN_f0_2,
            global_step=global_step_f0_2,
            var_list=g_params_f0)

        # Weight clipping applied to the critics after each critic update.
        clip_discriminator_var_op_feats = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_feats
        ]
        clip_discriminator_var_op_f0 = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_f0
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()

        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(config.log_dir + 'train/',
                                                     sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir + 'val/',
                                                   sess.graph)

        start_epoch = int(
            sess.run(global_step)) // config.batches_per_epoch_train

        print("Start from: %d" % start_epoch)

        for epoch in range(start_epoch, config.num_epochs):

            # More critic iterations early on and on every 100th epoch.
            n_critic = 25 if epoch < 25 or epoch % 100 == 0 else 5
            n_critic_f0 = 25 if epoch < 1025 or epoch % 100 == 0 else 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()

            # NOTE(review): the validation generator is never consumed.
            val_generator = data_gen(mode='val')

            batch_num = 0
            epoch_gen_loss = 0
            epoch_dis_loss = 0
            epoch_gen_loss_f0 = 0
            epoch_dis_loss_f0 = 0
            feed_dict = None

            with tf.variable_scope('Training'):

                for feats, conds in data_generator:
                    # conds channels: 0 phoneme id, 1 phoneme context,
                    # 2 F0 note, last F0 context.  feats: first 64 vocoder
                    # features, second-to-last F0, last voicing.
                    f0 = conds[:, :, 2]
                    phones = conds[:, :, 0]
                    f0_context = conds[:, :, -1:]
                    phones_context = conds[:, :, 1:2]

                    feed_dict = {
                        f0_output_placeholder: feats[:, :, -2:-1],
                        f0_input_placeholder: f0,
                        phoneme_labels: phones,
                        phone_context_placeholder: phones_context,
                        f0_context_placeholder: f0_context,
                        output_placeholder: feats[:, :, :64],
                        uv_placeholder: feats[:, :, -1:]
                    }

                    # Feature GAN: n_critic clipped critic updates, then one
                    # generator update.
                    for _ in range(n_critic):
                        sess.run(dis_train_function, feed_dict=feed_dict)
                        sess.run(clip_discriminator_var_op_feats,
                                 feed_dict=feed_dict)

                    _, step_gen_loss = sess.run(
                        [gen_train_function, G_loss_GAN], feed_dict=feed_dict)
                    step_dis_loss = sess.run(D_loss, feed_dict=feed_dict)

                    if epoch > 1000:
                        # F0 GAN conditioned on the generated features.
                        for _ in range(n_critic_f0):
                            sess.run(dis_train_function_f0_2,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)
                        _, step_gen_loss_f0 = sess.run(
                            [gen_train_function_f0_2, G_loss_GAN_f0_2],
                            feed_dict=feed_dict)
                        step_dis_loss_f0 = sess.run(D_loss_f0_2,
                                                    feed_dict=feed_dict)
                    else:
                        # F0 GAN conditioned on the ground-truth features.
                        # NOTE(review): this branch uses n_critic rather than
                        # n_critic_f0 - confirm that is intended.
                        for _ in range(n_critic):
                            sess.run(dis_train_function_f0,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)
                        _, step_gen_loss_f0 = sess.run(
                            [gen_train_function_f0, G_loss_GAN_f0],
                            feed_dict=feed_dict)
                        step_dis_loss_f0 = sess.run(D_loss_f0,
                                                    feed_dict=feed_dict)

                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss
                    epoch_gen_loss_f0 += step_gen_loss_f0
                    epoch_dis_loss_f0 += step_dis_loss_f0

                    utils.progress(batch_num,
                                   config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                # Average over the batches actually processed.
                n_batches = max(batch_num, 1)
                epoch_gen_loss = epoch_gen_loss / n_batches
                epoch_dis_loss = epoch_dis_loss / n_batches
                epoch_gen_loss_f0 = epoch_gen_loss_f0 / n_batches
                epoch_dis_loss_f0 = epoch_dis_loss_f0 / n_batches

                if feed_dict is not None:
                    # Summaries are evaluated on the epoch's last batch only.
                    summary_str = sess.run(summary, feed_dict=feed_dict)
                    train_summary_writer.add_summary(summary_str, epoch)
                    train_summary_writer.flush()

            duration = time.time() - start_time

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Gen Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_gen_loss, duration))
                print('        : Dis Loss = %.10f ' % (epoch_dis_loss))
                print('        : F0 Gen Loss = %.10f ' % (epoch_gen_loss_f0))
                print('        : F0 Dis Loss = %.10f ' % (epoch_dis_loss_f0))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
Beispiel #6
0
def _separation_losses(inputs, targets, loss_func, autoencoder):
    """Compute the combined training objective for one batch.

    Returns:
        ``(step_loss, parts)``.  ``parts`` is the tuple
        ``(vocals, drums, bass, alpha_diff, beta_other, beta_other_voc)``
        returned by ``loss_calc``.  ``step_loss`` is
        ``abs(vocals + drums + bass - beta_other - alpha_diff -
        beta_other_voc)``.
    """
    parts = tuple(loss_calc(inputs, targets, loss_func, autoencoder))
    (loss_vocals, loss_drums, loss_bass, alpha_diff, beta_other,
     beta_other_voc) = parts
    # add regularization terms from paper
    step_loss = abs(loss_vocals + loss_drums + loss_bass - beta_other -
                    alpha_diff - beta_other_voc)
    return step_loss, parts


def trainNetwork(save_name='model_e' + str(config.num_epochs) + '_b' +
                 str(config.batches_per_epoch_train) + '_bs' +
                 str(config.batch_size),
                 resume_path='./log/model_e8000_b50_bs5_3469.pt',
                 epoch_offset=3470):
    """Fine-tune the source-separation autoencoder from a saved state dict.

    Each epoch trains on ``data_gen()`` and then evaluates on
    ``data_gen(mode="Val")``.  Per-epoch normalised losses (the total plus
    the six components from ``loss_calc``) are collected in ``train_evol`` /
    ``val_evol``.  Every ``config.save_every`` epochs they are saved to
    ``config.log_dir`` with ``np.save`` as ``train_loss`` / ``val_loss``,
    together with a model checkpoint.  A final checkpoint is written after
    the last epoch.

    Args:
        save_name: Filename prefix for checkpoints in ``config.log_dir``.
        resume_path: ``state_dict`` file the autoencoder is initialised from.
        epoch_offset: Added to the 0-based epoch index in checkpoint names,
            so that numbering continues after the resumed checkpoint.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "Code only usable with cuda"

    autoencoder = AutoEncoder().cuda()
    autoencoder.load_state_dict(torch.load(resume_path))

    optimizer = torch.optim.Adadelta(autoencoder.parameters(), lr=1, rho=0.95)

    # Summed squared error (equivalent to the deprecated size_average=False).
    loss_func = nn.MSELoss(reduction='sum')

    train_evol = []
    val_evol = []

    for epoch in range(config.num_epochs):

        start_time = time.time()

        generator = data_gen()
        val_gen = data_gen(mode="Val")

        # Running sums of [total, vocals, drums, bass, alpha_diff,
        # beta_other, beta_other_voc] over the epoch.
        train_sums = np.zeros(7)
        count = 0

        for inputs, targets in generator:
            # Clear the previous step's gradients so that each update uses
            # only this batch's gradient.
            optimizer.zero_grad()

            step_loss, parts = _separation_losses(inputs, targets, loss_func,
                                                  autoencoder)
            step_value = step_loss.item()
            if np.isnan(step_value):
                # Skip this update: back-propagating a NaN loss would
                # propagate NaN into the parameters.
                print("error output contains NaN")
                continue
            train_sums += [step_value] + [part.item() for part in parts]

            step_loss.backward()
            # Clip every gradient element to [-1, 1] in place.
            torch.nn.utils.clip_grad_value_(autoencoder.parameters(), 1)
            optimizer.step()

            utils.progress(count,
                           config.batches_per_epoch_train,
                           suffix='training done')
            count += 1

        # NOTE(review): the normaliser multiplies by both the configured and
        # the observed number of batches - confirm this scaling is intended.
        train_avg = train_sums / (config.batches_per_epoch_train *
                                  max(count, 1) * config.max_phr_len * 513)
        train_evol.append(train_avg.tolist())
        (train_loss, train_loss_vocals, train_loss_drums, train_loss_bass,
         train_alpha_diff, train_beta_other, train_beta_other_voc) = train_avg

        val_sums = np.zeros(7)
        count = 0

        # Evaluation only, so no autograd graph is needed.
        with torch.no_grad():
            for inputs, targets in val_gen:
                step_loss, parts = _separation_losses(inputs, targets,
                                                      loss_func, autoencoder)
                val_sums += [step_loss.item()] + [part.item() for part in parts]

                utils.progress(count,
                               config.batches_per_epoch_val,
                               suffix='validation done')
                count += 1

        val_avg = val_sums / (config.batches_per_epoch_val * max(count, 1) *
                              config.max_phr_len * 513)
        val_evol.append(val_avg.tolist())
        (val_loss, val_loss_vocals, val_loss_drums, val_loss_bass,
         val_alpha_diff, val_beta_other, val_beta_other_voc) = val_avg

        duration = time.time() - start_time

        if (epoch + 1) % config.print_every == 0:
            print('epoch %d/%d, took %.2f seconds, epoch total loss: %.7f' %
                  (epoch + 1, config.num_epochs, duration, train_loss))
            print('                                  epoch vocal loss: %.7f' %
                  (train_loss_vocals))
            print('                                  epoch drums loss: %.7f' %
                  (train_loss_drums))
            print('                                  epoch bass  loss: %.7f' %
                  (train_loss_bass))
            print('                                  epoch alpha diff: %.7f' %
                  (train_alpha_diff))
            print('                                  epoch beta  diff: %.7f' %
                  (train_beta_other))
            print('                                  epoch beta2 diff: %.7f' %
                  (train_beta_other_voc))

            print(
                '                                  validation total loss: %.7f'
                % (val_loss))
            print(
                '                                  validation vocal loss: %.7f'
                % (val_loss_vocals))
            print(
                '                                  validation drums loss: %.7f'
                % (val_loss_drums))
            print(
                '                                  validation bass  loss: %.7f'
                % (val_loss_bass))
            print(
                '                                  validation alpha diff: %.7f'
                % (val_alpha_diff))
            print(
                '                                  validation beta  diff: %.7f'
                % (val_beta_other))
            print(
                '                                  validation beta2 diff: %.7f'
                % (val_beta_other_voc))

        if (epoch + 1) % config.save_every == 0:
            torch.save(
                autoencoder.state_dict(),
                config.log_dir + save_name + '_' + str(epoch + epoch_offset) +
                '.pt')
            np.save(config.log_dir + 'train_loss', np.array(train_evol))
            np.save(config.log_dir + 'val_loss', np.array(val_evol))

    # Final checkpoint, numbered with the same offset as the periodic ones.
    torch.save(
        autoencoder.state_dict(), config.log_dir + save_name + '_' +
        str(config.num_epochs - 1 + epoch_offset) + '.pt')
Beispiel #7
0
    def source_separate(self):
        """Resynthesise every validation batch once per envelope band.

        For each batch from ``data_gen(self.config)``, three copies of the
        envelopes are made.  In each copy, two of channels 0-2 are zeroed,
        leaving channel 0 for "bass", channel 1 for "mid" and channel 2 for
        "high".  Each masked copy, truncated to ``config.rhyfeats`` channels,
        is run through ``self.output_wav`` with ``is_train=False``.  Files
        are written to ``self.config.output_dir``:

        * ``output_<batch>_<item>_<model>_<band>.wav``: clipped to [-1, 1];
        * ``gt_<batch>_<item>.wav``: the ground-truth audio.

        When ``config.model == "spec"``, each output is passed through
        ``utils.griffinlim`` after applying ``exp(x) - 1``.
        """
        # NOTE(review): this session is never closed.
        sess = tf.Session()
        self.load_model(sess, log_dir=self.config.log_dir)
        val_generator = data_gen(self.config)

        # (band name, envelope channels zeroed for that band), in write order.
        band_masks = (('bass', [1, 2]), ('mid', [0, 2]), ('high', [0, 1]))
        is_spec = self.config.model == "spec"

        for batch_count, [
                out_audios, out_envelopes, out_features, total_count
        ] in enumerate(val_generator):
            band_outputs = []
            for band, zeroed_channels in band_masks:
                envelopes = np.copy(out_envelopes)
                envelopes[:, :, zeroed_channels] = 0
                feed_dict = {
                    self.input_placeholder:
                    envelopes[:, :, :self.config.rhyfeats],
                    self.cond_placeholder: out_features,
                    self.output_placeholder: out_audios,
                    self.is_train: False
                }
                band_outputs.append(
                    (band, sess.run(self.output_wav, feed_dict=feed_dict)))

            for count in range(self.config.batch_size):
                for band, output in band_outputs:
                    if is_spec:
                        out_audio = utils.griffinlim(
                            np.exp(output[count]) - 1, self.config)
                    else:
                        out_audio = output[count]
                    output_file = os.path.join(
                        self.config.output_dir,
                        'output_{}_{}_{}_{}.wav'.format(
                            batch_count, count, self.config.model, band))
                    sf.write(output_file, np.clip(out_audio, -1, 1),
                             self.config.fs)

                sf.write(
                    os.path.join(self.config.output_dir,
                                 'gt_{}_{}.wav'.format(batch_count, count)),
                    out_audios[count], self.config.fs)
            utils.progress(batch_count, total_count)
Beispiel #8
0
def _estimate_vocals(autoencoder, inputs):
    """Run the frozen separation autoencoder and return its ratio-masked vocals.

    The autoencoder output is split along dim 1 as vocals ``[:2]``, drums
    ``[2:4]``, bass ``[4:6]`` and others ``[6:]``; the vocal channels are
    multiplied by their share of the summed sources,
    ``vocals * vocals / (vocals + bass + drums + others)``.

    The autoencoder is not optimised by the caller, so the forward pass runs
    under ``torch.no_grad()`` to avoid building a graph and accumulating
    gradients on its parameters.
    """
    with torch.no_grad():
        output = autoencoder(torch.FloatTensor(inputs)).cuda()
        vocals = output[:, :2, :, :]
        drums = output[:, 2:4, :, :]
        bass = output[:, 4:6, :, :]
        others = output[:, 6:, :, :]
        total_sources = vocals + bass + drums + others
        # NOTE(review): no epsilon, so bins where all sources are 0 give NaN
        # (unchanged from the original behaviour).
        return vocals * (vocals / total_sources)


def trainNetwork(dataset='model6'):
    """Train ``Encoder`` as a denoiser for the vocals of a pre-trained autoencoder.

    The separation autoencoder is loaded from ``config.log_dir + dataset +
    '.pt'`` and kept frozen. For every batch its ratio-masked vocal estimate
    is fed to the denoiser, which is trained with SGD (lr 1e-6) on a summed
    L1 loss against ``targets[:, :2]``. Validation uses ``data_gen(mode="Val")``.

    Every ``config.print_every`` epochs the mean absolute error per target
    element is printed; every ``config.save_every`` epochs the denoiser's
    state dict and the per-epoch summed train/validation losses are written
    to ``config.dn_log_dir``.

    Args:
        dataset: Basename of the pre-trained autoencoder checkpoint.
    """
    save_name = 'dn_model'
    denoiser_vocals = Encoder().cuda()
    autoencoder = AutoEncoder()
    autoencoder.load_state_dict(torch.load(config.log_dir + dataset + '.pt'))

    optimizer = torch.optim.SGD(denoiser_vocals.parameters(), 1e-6)
    # Summed absolute error; identical to the deprecated size_average=False.
    loss_func = nn.L1Loss(reduction='sum')

    train_evol = []
    eval_evol = []

    for epoch in range(config.dn_num_epochs):
        start_time = time.time()
        train_gen = data_gen()
        val_gen = data_gen(mode="Val")

        train_loss = 0
        train_elems = 0
        for count, (inputs, targets) in enumerate(train_gen):
            input_vocals = _estimate_vocals(autoencoder, inputs)
            target_vocals = torch.cuda.FloatTensor(targets[:, :2, :, :])

            # Reset gradients every step; otherwise they accumulate across
            # batches and each SGD step applies the running sum.
            optimizer.zero_grad()
            denoised_vocals = denoiser_vocals(input_vocals).cuda()
            step_loss = loss_func(denoised_vocals, target_vocals)
            step_loss.backward()
            optimizer.step()

            train_loss += step_loss.item()
            train_elems += target_vocals.numel()
            utils.progress(count,
                           config.batches_per_epoch_train,
                           suffix='training done')
        train_evol.append(train_loss)

        eval_loss = 0
        eval_elems = 0
        with torch.no_grad():
            for count, (inputs, targets) in enumerate(val_gen):
                # Use the validation batch's own separation output (the old
                # code reused the last training batch's output here).
                input_vocals = _estimate_vocals(autoencoder, inputs)
                target_vocals = torch.cuda.FloatTensor(targets[:, :2, :, :])
                denoised_vocals = denoiser_vocals(input_vocals).cuda()

                step_loss = loss_func(denoised_vocals, target_vocals)
                eval_loss += step_loss.item()
                eval_elems += target_vocals.numel()
                utils.progress(count,
                               config.batches_per_epoch_val,
                               suffix='validation done')
        eval_evol.append(eval_loss)
        duration = time.time() - start_time

        if (epoch + 1) % config.print_every == 0:
            # Report mean absolute error per target element; max(.., 1)
            # guards against an empty generator.
            print('epoch %d/%d, took %.2f seconds, epoch total loss: %.7f' %
                  (epoch + 1, config.dn_num_epochs, duration,
                   train_loss / max(train_elems, 1)))
            print(
                '                                  validation total loss: %.7f'
                % (eval_loss / max(eval_elems, 1)))

        if (epoch + 1) % config.save_every == 0:
            torch.save(
                denoiser_vocals.state_dict(),
                config.dn_log_dir + save_name + '_' + str(epoch) + '.pt')
            np.save(config.dn_log_dir + 'dn_train_loss', np.array(train_evol))
            np.save(config.dn_log_dir + 'dn_val_loss', np.array(eval_evol))


def _denormalize_f0(values, max_val, min_val, vuv):
    """Undo min-max scaling of F0 and zero the frames flagged by ``vuv``.

    Returns ``(values * (max_val - min_val) + min_val) * (1 - vuv)``.
    NOTE(review): assumes the features were scaled as
    ``(x - min) / (max - min)`` with the statistics from ``stats.hdf5``.
    """
    return (values * (max_val - min_val) + min_val) * (1 - vuv)


def _f0_accuracy(f0_gt, f0_pred, threshold):
    """Return ``1 - (#entries with |f0_gt - f0_pred| > threshold) / len(f0_pred)``.

    NaN differences are mapped to 0 by ``np.nan_to_num`` and so count as
    within threshold. Float division keeps the result fractional even under
    Python 2 integer-division semantics.
    """
    difference = np.nan_to_num(np.abs(f0_gt - f0_pred))
    num_off = np.count_nonzero(difference > threshold)
    return 1.0 - float(num_off) / len(f0_pred)


def train(_):
    """Train the two-head F0 / voicing network and track validation F0 accuracy.

    Builds ``modules.f0_network`` on ``(batch_size, max_phr_len,
    input_features)`` inputs with 3-channel targets: channel -3 supervises the
    first F0 head, channel -2 the second, and channel -1 is the voicing flag.
    F0 losses are masked by ``1 - target[..., -1]``; the voicing loss uses
    ``binary_cross``. Training resumes from the latest checkpoint in
    ``config.log_dir`` if one exists; TensorBoard summaries go to
    ``log_dir + 'train/'`` and ``log_dir + 'val/'``.

    On every ``config.print_every``-th epoch and on the last epoch, the
    validation F0 predictions are buffered over the chunks the validation
    generator yields for one item (``county`` from 1 to ``max_count``),
    overlap-added, de-normalised with ``config.stat_dir + 'stats.hdf5'`` and
    scored against ``config.f0_threshold``. The mean accuracy of the second
    head is appended to a list that is saved every epoch to
    ``./ikala_eval/accuracies.npy``.

    Args:
        _: Unused positional argument (presumably argv from an entry-point
            runner — confirm against the caller).

    Raises:
        NotImplementedError: if ``config.use_gan`` is set; the GAN ops are not
            built in this graph.
    """
    if config.use_gan:
        # d_optimizer / G_loss_* / D_loss_* are never created in this graph,
        # so the former GAN branches could only fail mid-training.
        raise NotImplementedError(
            'config.use_gan is not supported by this training graph')

    # Read the normalisation statistics and release the HDF5 handle at once.
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])

    with tf.Graph().as_default():
        input_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, config.input_features),
            name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)
        target_placeholder = tf.placeholder(
            tf.float32,
            shape=(config.batch_size, config.max_phr_len, 3),
            name='target_placeholder')
        tf.summary.histogram('targets', target_placeholder)

        with tf.variable_scope('First_Model'):
            f0, f0_1, vuv = modules.f0_network(input_placeholder)
            tf.summary.histogram('f0', f0)
            tf.summary.histogram('vuv', vuv)

        # F0 errors only count where the voicing flag (last channel) is 0.
        voiced = 1 - target_placeholder[:, :, -1:]
        f0_loss_1 = tf.reduce_sum(
            tf.abs(f0 - target_placeholder[:, :, -3:-2]) * voiced)
        f0_loss_2 = tf.reduce_sum(
            tf.abs(f0_1 - target_placeholder[:, :, -2:-1]) * voiced)
        vuv_loss = tf.reduce_sum(
            binary_cross(target_placeholder[:, :, -1:], vuv))
        loss = f0_loss_1 + vuv_loss + f0_loss_2

        tf.summary.scalar('f0_loss_1', f0_loss_1)
        tf.summary.scalar('f0_loss_2', f0_loss_2)
        tf.summary.scalar('vuv_loss', vuv_loss)
        tf.summary.scalar('total_loss', loss)

        global_step = tf.Variable(0, name='global_step', trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)
        train_function = optimizer.minimize(loss, global_step=global_step)

        summary = tf.summary.merge_all()
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)

        with tf.Session() as sess:
            sess.run(init_op)

            ckpt = tf.train.get_checkpoint_state(config.log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Using the model in %s" % ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)

            train_summary_writer = tf.summary.FileWriter(
                config.log_dir + 'train/', sess.graph)
            val_summary_writer = tf.summary.FileWriter(
                config.log_dir + 'val/', sess.graph)

            # global_step counts training batches; convert it back to epochs.
            start_epoch = int(
                sess.run(global_step) // config.batches_per_epoch_train)
            print("Start from: %d" % start_epoch)

            f0_accs = []
            for epoch in range(start_epoch, config.num_epochs):
                # F0 accuracy is only computed on print epochs and the last one.
                is_eval_epoch = ((epoch + 1) % config.print_every == 0
                                 or (epoch + 1) == config.num_epochs)
                val_f0_accs_1 = []
                val_f0_accs_2 = []

                data_generator = data_gen()
                start_time = time.time()

                # ---- training ----
                epoch_loss_f0_1 = 0
                epoch_loss_vuv = 0
                batch_num = 0
                for voc, feat in data_generator:
                    feed = {input_placeholder: voc, target_placeholder: feat}
                    _, step_loss_f0_1, step_loss_vuv = sess.run(
                        [train_function, f0_loss_1, vuv_loss], feed_dict=feed)
                    epoch_loss_f0_1 += step_loss_f0_1
                    epoch_loss_vuv += step_loss_vuv
                    utils.progress(batch_num, config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

                train_norm = (config.batches_per_epoch_train * config.batch_size
                              * config.max_phr_len)
                epoch_loss_f0_1 = epoch_loss_f0_1 / train_norm
                epoch_loss_vuv = epoch_loss_vuv / train_norm

                # Summaries are evaluated on the last training batch.
                summary_str = sess.run(summary, feed_dict=feed)
                train_summary_writer.add_summary(summary_str, epoch)
                train_summary_writer.flush()

                # ---- validation ----
                val_generator = data_gen(mode='val')
                epoch_loss_f0_val_1 = 0
                epoch_loss_f0_val_2 = 0
                epoch_loss_vuv_val = 0
                batch_num_val = 0
                for voc, feat, nchunks_in, lent, county, max_count in val_generator:
                    feed = {input_placeholder: voc, target_placeholder: feat}

                    if is_eval_epoch:
                        # Buffer the chunks of one item; score on its last chunk.
                        if county == 1:
                            f0_gt = []
                            vuv_gt = []
                            f0_output_1 = []
                            f0_output_2 = []

                        f0_op_1, f0_op_2 = sess.run([f0, f0_1], feed_dict=feed)
                        f0_output_1.append(f0_op_1)
                        f0_output_2.append(f0_op_2)
                        f0_gt.append(feat[:, :, -2:-1])
                        vuv_gt.append(feat[:, :, -1:])

                        if county == max_count:
                            pred_1 = utils.overlapadd(np.array(f0_output_1), nchunks_in)[:lent]
                            pred_2 = utils.overlapadd(np.array(f0_output_2), nchunks_in)[:lent]
                            truth = utils.overlapadd(np.array(f0_gt), nchunks_in)[:lent]
                            unvoiced = utils.overlapadd(np.array(vuv_gt), nchunks_in)[:lent]

                            # Fixed precedence: the old code computed
                            # f0 * ((max - min) + min), i.e. f0 * max.
                            pred_1 = _denormalize_f0(pred_1, max_feat[-2], min_feat[-2], unvoiced)
                            pred_2 = _denormalize_f0(pred_2, max_feat[-2], min_feat[-2], unvoiced)
                            truth = _denormalize_f0(truth, max_feat[-2], min_feat[-2], unvoiced)

                            val_f0_accs_1.append(
                                _f0_accuracy(truth, pred_1, config.f0_threshold))
                            val_f0_accs_2.append(
                                _f0_accuracy(truth, pred_2, config.f0_threshold))

                    step_loss_f0_val_1, step_loss_f0_val_2, step_loss_vuv_val = sess.run(
                        [f0_loss_1, f0_loss_2, vuv_loss], feed_dict=feed)
                    epoch_loss_f0_val_1 += step_loss_f0_val_1
                    epoch_loss_f0_val_2 += step_loss_f0_val_2
                    epoch_loss_vuv_val += step_loss_vuv_val

                    utils.progress(batch_num_val, config.batches_per_epoch_val,
                                   suffix='validiation done')
                    batch_num_val += 1

                if is_eval_epoch:
                    f0_accs.append(np.mean(val_f0_accs_2))

                val_norm = batch_num_val * config.batch_size * config.max_phr_len
                epoch_loss_f0_val_1 = epoch_loss_f0_val_1 / val_norm
                epoch_loss_f0_val_2 = epoch_loss_f0_val_2 / val_norm
                epoch_loss_vuv_val = epoch_loss_vuv_val / val_norm

                summary_str = sess.run(summary, feed_dict=feed)
                val_summary_writer.add_summary(summary_str, epoch)
                val_summary_writer.flush()

                duration = time.time() - start_time

                np.save('./ikala_eval/accuracies', f0_accs)

                if (epoch + 1) % config.print_every == 0:
                    print('epoch %d: F0 Training Loss = %.10f (%.3f sec)' % (epoch+1, epoch_loss_f0_1, duration))
                    print('        : VUV Training Loss = %.10f ' % (epoch_loss_vuv))
                    print('        : F0 Validation Loss_1 = %.10f ' % (epoch_loss_f0_val_1))
                    print('        : F0 Validation Loss_2 = %.10f ' % (epoch_loss_f0_val_2))
                    print('        : VUV Validation Loss = %.10f ' % (epoch_loss_vuv_val))
                    print('        : Mean F0 IKala Accuracy_1  = %.10f ' % (np.mean(val_f0_accs_1)))
                    print('        : Mean F0 IKala Accuracy_2  = %.10f ' % (np.mean(val_f0_accs_2)))

                if (epoch + 1) % config.save_every == 0 or (epoch + 1) == config.num_epochs:
                    checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_file, global_step=epoch)

            train_summary_writer.close()
            val_summary_writer.close()