def mix_model(self):
    """Synthesize audio with the conditioning features swapped across the batch.

    For every batch yielded by ``data_gen(self.config)`` the conditioning
    features are reversed along the first (batch) axis, so example ``i`` is
    rendered from its own envelope but the features of example ``-1 - i``.
    Each generated waveform (clipped to [-1, 1]) and the corresponding
    ground-truth audio are written as WAV files to ``self.config.output_dir``.
    """
    sess = tf.Session()
    self.load_model(sess, log_dir=self.config.log_dir)
    val_generator = data_gen(self.config)

    for batch_count, (out_audios, out_envelopes, out_features,
                      total_count) in enumerate(val_generator):
        # Reverse the features along the batch axis, in place. The previous
        # swap loop ran `int(len / 2) - 1` times, which left the middle pair
        # unswapped (and swapped nothing at all for a batch of 2).
        out_features_copy = np.copy(out_features)
        out_features[:] = out_features_copy[::-1]

        feed_dict = {
            self.input_placeholder: out_envelopes[:, :, :self.config.rhyfeats],
            self.cond_placeholder: out_features,
            self.output_placeholder: out_audios,
            self.is_train: False,
        }
        output_full = sess.run(self.output_wav, feed_dict=feed_dict)

        for count in range(self.config.batch_size):
            if self.config.model == "spec":
                # exp(x) - 1 undoes a log1p-style compression, presumably
                # applied by the data pipeline (confirm); the spectrogram is
                # then inverted with Griffin-Lim.
                out_audio = utils.griffinlim(
                    np.exp(output_full[count]) - 1, self.config)
            else:
                out_audio = output_full[count]
            output_file = os.path.join(
                self.config.output_dir,
                'output_{}_{}_{}.wav'.format(batch_count, count,
                                             self.config.model))
            sf.write(output_file, np.clip(out_audio, -1, 1), self.config.fs)
            sf.write(
                os.path.join(self.config.output_dir,
                             'gt_{}_{}.wav'.format(batch_count, count)),
                out_audios[count], self.config.fs)
        utils.progress(batch_count, total_count)
def train(self):
    """
    Function to train the model, and save Tensorboard summary, for N epochs.

    Builds the loss and optimizer ops, restores any checkpoint found in
    ``config.log_dir`` and resumes from the epoch implied by the global step.
    Each epoch makes one pass over ``data_gen()``. It validates every
    ``config.validate_every`` epochs, prints a summary every
    ``config.print_every`` epochs, and saves every ``config.save_every``
    epochs as well as after the final epoch.
    """
    sess = tf.Session()
    self.loss_function()
    self.get_optimizers()
    self.load_model(sess, config.log_dir)
    self.get_summary(sess, config.log_dir)

    # Assumes the global step advances once per training batch, so dividing
    # by the batches per epoch gives the epoch to resume from.
    start_epoch = int(
        sess.run(tf.train.get_global_step()) /
        (config.batches_per_epoch_train))
    print("Start from: %d" % start_epoch)

    for epoch in range(start_epoch, config.num_epochs):
        data_generator = data_gen()
        start_time = time.time()
        batch_num = 0
        epoch_train_loss = 0

        with tf.variable_scope('Training'):
            for ins, outs in data_generator:
                step_loss, summary_str = self.train_model(ins, outs, sess)
                epoch_train_loss += step_loss
                self.train_summary_writer.add_summary(summary_str, epoch)
                self.train_summary_writer.flush()
                utils.progress(batch_num, config.batches_per_epoch_train,
                               suffix='training done')
                batch_num += 1

        # An empty generator previously raised ZeroDivisionError here.
        epoch_train_loss = epoch_train_loss / max(batch_num, 1)
        print_dict = {"Training Loss": epoch_train_loss}

        if (epoch + 1) % config.validate_every == 0:
            pre, acc, rec = self.validate_model(sess)
            print_dict["Validation Precision"] = pre
            print_dict["Validation Accuracy"] = acc
            print_dict["Validation Recall"] = rec

        end_time = time.time()
        if (epoch + 1) % config.print_every == 0:
            self.print_summary(print_dict, epoch, end_time - start_time)

        if (epoch + 1) % config.save_every == 0 or (
                epoch + 1) == config.num_epochs:
            self.save_model(sess, epoch + 1, config.log_dir)
def train(_):
    """Jointly train the phoneme classifier, the decoder and a conditional WGAN.

    Graph (all in a fresh ``tf.Graph``):
      * ``phone_Model``: ``modules.phone_network`` phoneme classifier on the
        66-channel input features.
      * ``Final_Model``: ``modules.final_net`` decoder conditioned on singer,
        f0 and phonemes. It is applied a second time, sharing variables, to
        the predicted phoneme probabilities.
      * ``Generator`` / ``Discriminator``: conditional WGAN. The
        discriminator weights are clipped to [-0.01, 0.01] after every
        critic step.

    Each training batch runs ``n_critic`` discriminator steps. It then runs
    one step each of the decoder, generator and phoneme optimizers. Every
    epoch is followed by a validation pass. The function prints metrics every
    ``config.print_every`` epochs and saves a checkpoint every
    ``config.save_every`` epochs (and after the last epoch).

    The ignored positional argument presumably follows the ``tf.app.run``
    main-function convention; confirm against the caller.
    """
    # The feature ranges are read but not used below. Use a context manager
    # so the HDF5 handle is closed instead of leaking.
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])

    with tf.Graph().as_default():
        # ---- Inputs -------------------------------------------------------
        input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 66),
            name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)
        output_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 64),
            name='output_placeholder')
        f0_input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_input_placeholder')
        rand_input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 4),
            name='rand_input_placeholder')
        prob = tf.placeholder_with_default(1.0, shape=())
        phoneme_labels = tf.placeholder(
            tf.int32, shape=(config.batch_size, config.max_phr_len),
            name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(
            indices=tf.cast(phoneme_labels, tf.int32), depth=42)
        singer_labels = tf.placeholder(
            tf.float32, shape=(config.batch_size), name='singer_placeholder')
        singer_onehot_labels = tf.one_hot(
            indices=tf.cast(singer_labels, tf.int32), depth=12)
        phoneme_labels_shuffled = tf.placeholder(
            tf.int32, shape=(config.batch_size, config.max_phr_len),
            name='phoneme_placeholder_s')
        phone_onehot_labels_shuffled = tf.one_hot(
            indices=tf.cast(phoneme_labels_shuffled, tf.int32), depth=42)
        singer_labels_shuffled = tf.placeholder(
            tf.float32, shape=(config.batch_size), name='singer_placeholder_s')
        singer_onehot_labels_shuffled = tf.one_hot(
            indices=tf.cast(singer_labels_shuffled, tf.int32), depth=12)

        # ---- Networks -----------------------------------------------------
        with tf.variable_scope('phone_Model') as scope:
            pho_logits = modules.phone_network(input_placeholder)
            pho_classes = tf.argmax(pho_logits, axis=-1)
            pho_probs = tf.nn.softmax(pho_logits)

        with tf.variable_scope('Final_Model') as scope:
            voc_output = modules.final_net(
                singer_onehot_labels, f0_input_placeholder, phone_onehot_labels)
            voc_output_decoded = tf.nn.sigmoid(voc_output)
            scope.reuse_variables()
            voc_output_3 = modules.final_net(
                singer_onehot_labels, f0_input_placeholder, pho_probs)
            voc_output_3_decoded = tf.nn.sigmoid(voc_output_3)

        with tf.variable_scope('Generator') as scope:
            voc_output_2 = modules.GAN_generator(
                singer_onehot_labels, phone_onehot_labels,
                f0_input_placeholder, rand_input_placeholder)

        with tf.variable_scope('Discriminator') as scope:
            # Real targets are rescaled from [0, 1] to [-1, 1] before they
            # reach the discriminator.
            D_real = modules.GAN_discriminator(
                (output_placeholder - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(
                voc_output_2, singer_onehot_labels, phone_onehot_labels,
                f0_input_placeholder)
            D_fake_real = modules.GAN_discriminator(
                (voc_output_decoded - 0.5) * 2, singer_onehot_labels,
                phone_onehot_labels, f0_input_placeholder)

        # Collected before the optimizers exist, so slot variables are excluded.
        final_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="Final_Model")
        g_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Generator")
        d_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="Discriminator")
        phone_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope="phone_Model")

        # ---- Phoneme loss (class-weighted cross-entropy) ------------------
        pho_weights = tf.reduce_sum(
            config.phonemas_weights * phone_onehot_labels, axis=-1)
        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(
            labels=phone_onehot_labels, logits=pho_logits)
        weighted_losses = unweighted_losses * pho_weights
        pho_loss = tf.reduce_mean(weighted_losses)
        # tf.metrics.accuracy returns (accuracy, update_op).
        pho_acc = tf.metrics.accuracy(labels=phoneme_labels,
                                      predictions=pho_classes)
        pho_summary = tf.summary.scalar('pho_loss', pho_loss)
        pho_acc_summary = tf.summary.scalar('pho_accuracy', pho_acc[0])

        # ---- Discriminator (WGAN critic) loss -----------------------------
        D_correct_pred = tf.equal(tf.round(tf.sigmoid(D_real)),
                                  tf.ones_like(D_real))
        D_correct_pred_fake = tf.equal(tf.round(tf.sigmoid(D_fake_real)),
                                       tf.ones_like(D_fake_real))
        D_accuracy = tf.reduce_mean(tf.cast(D_correct_pred, tf.float32))
        D_accuracy_fake = tf.reduce_mean(
            tf.cast(D_correct_pred_fake, tf.float32))
        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake + 1e-12)
        dis_summary = tf.summary.scalar('dis_loss', D_loss)
        dis_acc_summary = tf.summary.scalar('dis_acc', D_accuracy)
        dis_acc_fake_summary = tf.summary.scalar('dis_acc_fake',
                                                 D_accuracy_fake)

        # ---- Generator loss: adversarial term + small L1 reconstruction ---
        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output_2 / 2 + 0.5))) * 0.00005
        # G_correct_pred used to be commented out while G_accuracy still
        # referenced it, which raised NameError during graph construction.
        G_correct_pred = tf.equal(tf.round(tf.sigmoid(D_fake)),
                                  tf.ones_like(D_real))
        G_accuracy = tf.reduce_mean(tf.cast(G_correct_pred, tf.float32))
        gen_summary = tf.summary.scalar('gen_loss', G_loss_GAN)
        gen_acc_summary = tf.summary.scalar('gen_acc', G_accuracy)

        # ---- Decoder reconstruction loss ----------------------------------
        final_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=output_placeholder, logits=voc_output))
        final_summary = tf.summary.scalar('final_loss', final_loss)

        summary = tf.summary.merge_all()

        # ---- Optimisation -------------------------------------------------
        global_step = tf.Variable(0, name='global_step', trainable=False)
        global_step_re = tf.Variable(0, name='global_step_re', trainable=False)
        global_step_dis = tf.Variable(0, name='global_step_dis',
                                      trainable=False)
        global_step_gen = tf.Variable(0, name='global_step_gen',
                                      trainable=False)

        pho_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)
        re_optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)
        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        pho_train_function = pho_optimizer.minimize(
            pho_loss, global_step=global_step, var_list=phone_params)
        re_train_function = re_optimizer.minimize(
            final_loss, global_step=global_step_re, var_list=final_params)
        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params)
        gen_train_function = gen_optimizer.minimize(
            G_loss_GAN, global_step=global_step_gen, var_list=g_params)

        # WGAN weight clipping for the critic.
        clip_discriminator_var_op = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in d_params
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()
        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(
            config.log_dir + 'train/', sess.graph)
        val_summary_writer = tf.summary.FileWriter(
            config.log_dir + 'val/', sess.graph)

        # `global_step` is advanced by the phoneme optimizer once per batch.
        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))
        print("Start from: %d" % start_epoch)

        def _prepare_batch(f0, phos, singer_ids):
            """Reshape f0 and return independently shuffled label copies."""
            f0 = f0.reshape([config.batch_size, config.max_phr_len, 1])
            sing_id_shu = np.copy(singer_ids)
            phos_shu = np.copy(phos)
            np.random.shuffle(sing_id_shu)
            np.random.shuffle(phos_shu)
            return f0, sing_id_shu, phos_shu

        def _feed(feats, f0, phos, singer_ids, phos_shu, sing_id_shu):
            """Build one feed dict; draws fresh generator noise on each call."""
            return {
                input_placeholder: feats,
                output_placeholder: feats[:, :, :-2],
                f0_input_placeholder: f0,
                # Was hard-coded to 30, which broke any other batch size.
                rand_input_placeholder: np.random.uniform(
                    -1.0, 1.0,
                    size=[config.batch_size, config.max_phr_len, 4]),
                phoneme_labels: phos,
                singer_labels: singer_ids,
                phoneme_labels_shuffled: phos_shu,
                singer_labels_shuffled: sing_id_shu,
            }

        for epoch in range(start_epoch, config.num_epochs):
            # Train the critic harder early on and every 100th epoch.
            if epoch < 25 or epoch % 100 == 0:
                n_critic = 25
            else:
                n_critic = 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()
            val_generator = data_gen(mode='val')

            batch_num = 0
            epoch_pho_loss = 0
            epoch_gen_loss = 0
            epoch_re_loss = 0
            epoch_dis_loss = 0
            epoch_pho_acc = 0
            epoch_gen_acc = 0
            epoch_dis_acc = 0
            epoch_dis_acc_fake = 0

            val_epoch_pho_loss = 0
            val_epoch_gen_loss = 0
            val_epoch_dis_loss = 0
            val_epoch_pho_acc = 0
            val_epoch_gen_acc = 0
            val_epoch_dis_acc = 0
            val_epoch_dis_acc_fake = 0

            with tf.variable_scope('Training'):
                for feats, f0, phos, singer_ids in data_generator:
                    f0, sing_id_shu, phos_shu = _prepare_batch(
                        f0, phos, singer_ids)

                    for critic_itr in range(n_critic):
                        feed_dict = _feed(feats, f0, phos, singer_ids,
                                          phos_shu, sing_id_shu)
                        sess.run(dis_train_function, feed_dict=feed_dict)
                        sess.run(clip_discriminator_var_op,
                                 feed_dict=feed_dict)

                    feed_dict = _feed(feats, f0, phos, singer_ids, phos_shu,
                                      sing_id_shu)
                    _, _, step_re_loss, step_gen_loss, step_gen_acc = sess.run(
                        [re_train_function, gen_train_function, final_loss,
                         G_loss_GAN, G_accuracy],
                        feed_dict=feed_dict)
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)
                    _, step_pho_loss, step_pho_acc = sess.run(
                        [pho_train_function, pho_loss, pho_acc],
                        feed_dict=feed_dict)

                    epoch_pho_loss += step_pho_loss
                    epoch_re_loss += step_re_loss
                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss
                    epoch_pho_acc += step_pho_acc[0]
                    epoch_gen_acc += step_gen_acc
                    epoch_dis_acc += step_dis_acc
                    epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(batch_num, config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

            epoch_pho_loss = epoch_pho_loss / config.batches_per_epoch_train
            epoch_re_loss = epoch_re_loss / config.batches_per_epoch_train
            epoch_gen_loss = epoch_gen_loss / config.batches_per_epoch_train
            epoch_dis_loss = epoch_dis_loss / config.batches_per_epoch_train
            epoch_dis_acc_fake = (epoch_dis_acc_fake /
                                  config.batches_per_epoch_train)
            epoch_pho_acc = epoch_pho_acc / config.batches_per_epoch_train
            epoch_gen_acc = epoch_gen_acc / config.batches_per_epoch_train
            epoch_dis_acc = epoch_dis_acc / config.batches_per_epoch_train

            summary_str = sess.run(summary, feed_dict=feed_dict)
            train_summary_writer.add_summary(summary_str, epoch)
            train_summary_writer.flush()

            # Validation gets its own counter and label; it previously reused
            # the training counter, total and "training done" suffix.
            val_batch_num = 0
            with tf.variable_scope('Validation'):
                for feats, f0, phos, singer_ids in val_generator:
                    f0, sing_id_shu, phos_shu = _prepare_batch(
                        f0, phos, singer_ids)
                    feed_dict = _feed(feats, f0, phos, singer_ids, phos_shu,
                                      sing_id_shu)

                    step_pho_loss, step_pho_acc = sess.run(
                        [pho_loss, pho_acc], feed_dict=feed_dict)
                    step_gen_loss, step_gen_acc = sess.run(
                        [final_loss, G_accuracy], feed_dict=feed_dict)
                    step_dis_loss, step_dis_acc, step_dis_acc_fake = sess.run(
                        [D_loss, D_accuracy, D_accuracy_fake],
                        feed_dict=feed_dict)

                    val_epoch_pho_loss += step_pho_loss
                    val_epoch_gen_loss += step_gen_loss
                    val_epoch_dis_loss += step_dis_loss
                    val_epoch_pho_acc += step_pho_acc[0]
                    val_epoch_gen_acc += step_gen_acc
                    val_epoch_dis_acc += step_dis_acc
                    val_epoch_dis_acc_fake += step_dis_acc_fake

                    utils.progress(val_batch_num, config.batches_per_epoch_val,
                                   suffix='validation done')
                    val_batch_num += 1

            val_epoch_pho_loss = (val_epoch_pho_loss /
                                  config.batches_per_epoch_val)
            val_epoch_gen_loss = (val_epoch_gen_loss /
                                  config.batches_per_epoch_val)
            val_epoch_dis_loss = (val_epoch_dis_loss /
                                  config.batches_per_epoch_val)
            val_epoch_pho_acc = val_epoch_pho_acc / config.batches_per_epoch_val
            val_epoch_gen_acc = val_epoch_gen_acc / config.batches_per_epoch_val
            val_epoch_dis_acc = val_epoch_dis_acc / config.batches_per_epoch_val
            val_epoch_dis_acc_fake = (val_epoch_dis_acc_fake /
                                      config.batches_per_epoch_val)

            summary_str = sess.run(summary, feed_dict=feed_dict)
            val_summary_writer.add_summary(summary_str, epoch)
            val_summary_writer.flush()

            duration = time.time() - start_time

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Phone Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_pho_loss, duration))
                print(' : Phone Accuracy = %.10f ' % (epoch_pho_acc))
                print(' : Recon Loss = %.10f ' % (epoch_re_loss))
                print(' : Gen Loss = %.10f ' % (epoch_gen_loss))
                print(' : Gen Accuracy = %.10f ' % (epoch_gen_acc))
                print(' : Dis Loss = %.10f ' % (epoch_dis_loss))
                print(' : Dis Accuracy = %.10f ' % (epoch_dis_acc))
                print(' : Dis Accuracy Fake = %.10f ' % (epoch_dis_acc_fake))
                print(' : Val Phone Accuracy = %.10f ' % (val_epoch_pho_acc))
                print(' : Val Gen Loss = %.10f ' % (val_epoch_gen_loss))
                print(' : Val Gen Accuracy = %.10f ' % (val_epoch_gen_acc))
                print(' : Val Dis Loss = %.10f ' % (val_epoch_dis_loss))
                print(' : Val Dis Accuracy = %.10f ' % (val_epoch_dis_acc))
                print(' : Val Dis Accuracy Fake = %.10f ' %
                      (val_epoch_dis_acc_fake))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
def train(self):
    """
    Function to train the model, and save Tensorboard summary, for N epochs.

    Builds the loss and optimizer ops, restores any checkpoint in
    ``config.log_dir`` and resumes from the epoch implied by the global step.
    Each epoch makes one pass over ``data_gen()``. Every
    ``config.validate_every`` epochs it also makes one pass over
    ``data_gen(mode='Val')``. It prints a summary every
    ``config.print_every`` epochs and saves every ``config.save_every``
    epochs as well as after the final epoch.
    """
    sess = tf.Session()
    self.loss_function()
    self.get_optimizers()
    self.load_model(sess, config.log_dir)
    self.get_summary(sess, config.log_dir)

    # Assumes the global step advances once per training batch.
    start_epoch = int(sess.run(tf.train.get_global_step()) /
                      (config.batches_per_epoch_train))
    print("Start from: %d" % start_epoch)

    for epoch in range(start_epoch, config.num_epochs):
        data_generator = data_gen()
        val_generator = data_gen(mode='Val')
        start_time = time.time()

        batch_num = 0
        epoch_final_loss = 0
        val_final_loss = 0

        with tf.variable_scope('Training'):
            for voc, feat in data_generator:
                step_loss, summary_str = self.train_model(voc, feat, sess)
                epoch_final_loss += step_loss
                self.train_summary_writer.add_summary(summary_str, epoch)
                self.train_summary_writer.flush()
                utils.progress(batch_num, config.batches_per_epoch_train,
                               suffix='training done')
                batch_num += 1

        # An empty generator previously raised ZeroDivisionError here.
        epoch_final_loss = epoch_final_loss / max(batch_num, 1)
        print_dict = {"Final Loss": epoch_final_loss}

        if (epoch + 1) % config.validate_every == 0:
            batch_num = 0
            with tf.variable_scope('Validation'):
                for voc, feat in val_generator:
                    step_loss, summary_str = self.validate_model(
                        voc, feat, sess)
                    val_final_loss += step_loss
                    self.val_summary_writer.add_summary(summary_str, epoch)
                    self.val_summary_writer.flush()
                    batch_num += 1
                    utils.progress(batch_num, config.batches_per_epoch_val,
                                   suffix='validation done')
            val_final_loss = val_final_loss / max(batch_num, 1)
            print_dict["Val Final Loss"] = val_final_loss

        end_time = time.time()
        if (epoch + 1) % config.print_every == 0:
            self.print_summary(print_dict, epoch, end_time - start_time)

        if (epoch + 1) % config.save_every == 0 or (
                epoch + 1) == config.num_epochs:
            self.save_model(sess, epoch + 1, config.log_dir)
def train(_):
    """Train a feature WGAN and an f0 WGAN with weight-clipped critics.

    Graph (all in a fresh ``tf.Graph``):
      * ``Generator_feats`` / ``Discriminator_feats``: generate 64-channel
        features from one-hot phonemes, one-hot notes and the two context
        channels.
      * ``Generator_f0`` / ``Discriminator_f0``: generate f0 from the same
        conditioning plus features. They are built twice with shared
        variables: once on ground-truth features (``output_placeholder``)
        and once on generated features (``voc_output / 2 + 0.5``).

    For epochs <= 1000 the f0 networks train on ground-truth features. After
    epoch 1000 they train on generated features, using ``n_critic_f0``
    critic steps. The function prints losses every ``config.print_every``
    epochs and checkpoints every ``config.save_every`` epochs (and after the
    last epoch).

    The ignored positional argument presumably follows the ``tf.app.run``
    main-function convention; confirm against the caller.
    """
    with tf.Graph().as_default():
        # ---- Inputs -------------------------------------------------------
        output_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 64),
            name='output_placeholder')
        f0_output_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_output_placeholder')
        f0_input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len),
            name='f0_input_placeholder')
        f0_onehot_labels = tf.one_hot(
            indices=tf.cast(f0_input_placeholder, tf.int32),
            depth=len(config.notes))
        f0_context_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='f0_context_placeholder')
        uv_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='uv_placeholder')
        phone_context_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 1),
            name='phone_context_placeholder')
        rand_input_placeholder = tf.placeholder(
            tf.float32, shape=(config.batch_size, config.max_phr_len, 64),
            name='rand_input_placeholder')
        prob = tf.placeholder_with_default(1.0, shape=())
        phoneme_labels = tf.placeholder(
            tf.int32, shape=(config.batch_size, config.max_phr_len),
            name='phoneme_placeholder')
        phone_onehot_labels = tf.one_hot(
            indices=tf.cast(phoneme_labels, tf.int32),
            depth=len(config.phonemas))

        # ---- Feature GAN --------------------------------------------------
        with tf.variable_scope('Generator_feats') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder
            ], axis=-1)
            voc_output = modules.GAN_generator(inputs)

        with tf.variable_scope('Discriminator_feats') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder
            ], axis=-1)
            # Real targets are rescaled from [0, 1] to [-1, 1].
            D_real = modules.GAN_discriminator(
                (output_placeholder - 0.5) * 2, inputs)
            scope.reuse_variables()
            D_fake = modules.GAN_discriminator(voc_output, inputs)

        # ---- f0 GAN -------------------------------------------------------
        with tf.variable_scope('Generator_f0') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                output_placeholder
            ], axis=-1)
            f0_output = modules.GAN_generator_f0(inputs)
            scope.reuse_variables()
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                (voc_output / 2) + 0.5
            ], axis=-1)
            f0_output_2 = modules.GAN_generator_f0(inputs)

        with tf.variable_scope('Discriminator_f0') as scope:
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                output_placeholder
            ], axis=-1)
            D_real_f0 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            scope.reuse_variables()
            D_fake_f0 = modules.GAN_discriminator_f0(f0_output, inputs)
            inputs = tf.concat([
                phone_onehot_labels, f0_onehot_labels,
                phone_context_placeholder, f0_context_placeholder,
                (voc_output / 2) + 0.5
            ], axis=-1)
            D_real_f0_2 = modules.GAN_discriminator_f0(
                (f0_output_placeholder - 0.5) * 2, inputs)
            D_fake_f0_2 = modules.GAN_discriminator_f0(f0_output_2, inputs)

        # Collected before the optimizers exist, so slot variables are excluded.
        g_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Generator_feats")
        d_params_feats = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope="Discriminator_feats")
        g_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Generator_f0")
        d_params_f0 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="Discriminator_f0")

        # ---- Losses: WGAN critic / generator + small L1 reconstruction ----
        D_loss = tf.reduce_mean(D_real + 1e-12) - tf.reduce_mean(D_fake + 1e-12)
        dis_summary = tf.summary.scalar('dis_loss', D_loss)
        G_loss_GAN = tf.reduce_mean(D_fake + 1e-12) + tf.reduce_sum(
            tf.abs(output_placeholder - (voc_output / 2 + 0.5))) * 0.00005
        gen_summary = tf.summary.scalar('gen_loss', G_loss_GAN)

        D_loss_f0 = (tf.reduce_mean(D_real_f0 + 1e-12) -
                     tf.reduce_mean(D_fake_f0 + 1e-12))
        dis_summary_f0 = tf.summary.scalar('dis_loss_f0', D_loss_f0)
        G_loss_GAN_f0 = tf.reduce_mean(D_fake_f0 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output / 2 + 0.5))) * 0.00005

        D_loss_f0_2 = (tf.reduce_mean(D_real_f0_2 + 1e-12) -
                       tf.reduce_mean(D_fake_f0_2 + 1e-12))
        G_loss_GAN_f0_2 = tf.reduce_mean(D_fake_f0_2 + 1e-12) + tf.reduce_sum(
            tf.abs(f0_output_placeholder - (f0_output_2 / 2 + 0.5))) * 0.00005
        gen_summary_f0 = tf.summary.scalar('gen_loss_f0', G_loss_GAN_f0)

        summary = tf.summary.merge_all()

        # ---- Optimisation -------------------------------------------------
        global_step = tf.Variable(0, name='global_step', trainable=False)
        global_step_dis = tf.Variable(0, name='global_step_dis',
                                      trainable=False)
        global_step_f0 = tf.Variable(0, name='global_step_f0', trainable=False)
        global_step_dis_f0 = tf.Variable(0, name='global_step_dis_f0',
                                         trainable=False)
        global_step_f0_2 = tf.Variable(0, name='global_step_f0_2',
                                       trainable=False)
        global_step_dis_f0_2 = tf.Variable(0, name='global_step_dis_f0_2',
                                           trainable=False)

        dis_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        gen_optimizer = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        # NOTE(review): the *_f0 and *_f0_2 optimizers are declared but never
        # used. The f0 train ops below call dis_optimizer / gen_optimizer, so
        # they share RMSProp slot state with one another. Switching to the
        # dedicated optimizers would change the set of slot variables, which
        # may break restoring existing checkpoints, so the wiring is
        # unchanged.
        dis_optimizer_f0 = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        gen_optimizer_f0 = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        dis_optimizer_f0_2 = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        gen_optimizer_f0_2 = tf.train.RMSPropOptimizer(learning_rate=5e-5)

        dis_train_function = dis_optimizer.minimize(
            D_loss, global_step=global_step_dis, var_list=d_params_feats)
        gen_train_function = gen_optimizer.minimize(
            G_loss_GAN, global_step=global_step, var_list=g_params_feats)
        dis_train_function_f0 = dis_optimizer.minimize(
            D_loss_f0, global_step=global_step_dis_f0, var_list=d_params_f0)
        gen_train_function_f0 = gen_optimizer.minimize(
            G_loss_GAN_f0, global_step=global_step_f0, var_list=g_params_f0)
        dis_train_function_f0_2 = dis_optimizer.minimize(
            D_loss_f0_2, global_step=global_step_dis_f0_2,
            var_list=d_params_f0)
        gen_train_function_f0_2 = gen_optimizer.minimize(
            G_loss_GAN_f0_2, global_step=global_step_f0_2,
            var_list=g_params_f0)

        # WGAN weight clipping for both critics.
        clip_discriminator_var_op_feats = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_feats
        ]
        clip_discriminator_var_op_f0 = [
            var.assign(tf.clip_by_value(var, -0.01, 0.01))
            for var in d_params_f0
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()
        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(
            config.log_dir + 'train/', sess.graph)
        val_summary_writer = tf.summary.FileWriter(
            config.log_dir + 'val/', sess.graph)

        # `global_step` is advanced by the feature generator once per batch.
        start_epoch = int(
            sess.run(tf.train.get_global_step()) /
            (config.batches_per_epoch_train))
        print("Start from: %d" % start_epoch)

        for epoch in range(start_epoch, config.num_epochs):
            # Train the critics harder early on and every 100th epoch.
            if epoch < 25 or epoch % 100 == 0:
                n_critic = 25
            else:
                n_critic = 5
            if epoch < 1025 or epoch % 100 == 0:
                n_critic_f0 = 25
            else:
                n_critic_f0 = 5

            data_generator = data_gen(sec_mode=0)
            start_time = time.time()

            batch_num = 0
            epoch_gen_loss = 0
            epoch_dis_loss = 0
            epoch_gen_loss_f0 = 0
            epoch_dis_loss_f0 = 0

            with tf.variable_scope('Training'):
                for feats, conds in data_generator:
                    # Conditioning channel layout as used here: 0 = phoneme
                    # id, 1 = phone context, 2 = note/f0 id, last = f0
                    # context.
                    f0 = conds[:, :, 2]
                    phones = conds[:, :, 0]
                    f0_context = conds[:, :, -1:]
                    phones_context = conds[:, :, 1:2]

                    # feats: the first 64 channels are the synthesis
                    # features; the last two are presumably f0 and a
                    # voiced/unvoiced flag (confirm).
                    feed_dict = {
                        f0_output_placeholder: feats[:, :, -2:-1],
                        f0_input_placeholder: f0,
                        phoneme_labels: phones,
                        phone_context_placeholder: phones_context,
                        f0_context_placeholder: f0_context,
                        output_placeholder: feats[:, :, :64],
                        uv_placeholder: feats[:, :, -1:],
                    }

                    for critic_itr in range(n_critic):
                        sess.run(dis_train_function, feed_dict=feed_dict)
                        sess.run(clip_discriminator_var_op_feats,
                                 feed_dict=feed_dict)
                    _, step_gen_loss = sess.run(
                        [gen_train_function, G_loss_GAN], feed_dict=feed_dict)
                    step_dis_loss = sess.run(D_loss, feed_dict=feed_dict)

                    if epoch > 1000:
                        # f0 networks conditioned on generated features.
                        for critic_itr in range(n_critic_f0):
                            sess.run(dis_train_function_f0_2,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)
                        _, step_gen_loss_f0 = sess.run(
                            [gen_train_function_f0_2, G_loss_GAN_f0_2],
                            feed_dict=feed_dict)
                        step_dis_loss_f0 = sess.run(D_loss_f0_2,
                                                    feed_dict=feed_dict)
                    else:
                        # f0 networks conditioned on ground-truth features.
                        for critic_itr in range(n_critic):
                            sess.run(dis_train_function_f0,
                                     feed_dict=feed_dict)
                            sess.run(clip_discriminator_var_op_f0,
                                     feed_dict=feed_dict)
                        _, step_gen_loss_f0 = sess.run(
                            [gen_train_function_f0, G_loss_GAN_f0],
                            feed_dict=feed_dict)
                        step_dis_loss_f0 = sess.run(D_loss_f0,
                                                    feed_dict=feed_dict)

                    epoch_gen_loss += step_gen_loss
                    epoch_dis_loss += step_dis_loss
                    # The f0 losses were previously computed but never
                    # accumulated, so their epoch totals always stayed 0.
                    epoch_gen_loss_f0 += step_gen_loss_f0
                    epoch_dis_loss_f0 += step_dis_loss_f0

                    utils.progress(batch_num, config.batches_per_epoch_train,
                                   suffix='training done')
                    batch_num += 1

            epoch_gen_loss = epoch_gen_loss / config.batches_per_epoch_train
            epoch_dis_loss = epoch_dis_loss / config.batches_per_epoch_train
            epoch_gen_loss_f0 = (epoch_gen_loss_f0 /
                                 config.batches_per_epoch_train)
            epoch_dis_loss_f0 = (epoch_dis_loss_f0 /
                                 config.batches_per_epoch_train)

            summary_str = sess.run(summary, feed_dict=feed_dict)
            train_summary_writer.add_summary(summary_str, epoch)
            train_summary_writer.flush()

            duration = time.time() - start_time

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: Gen Loss = %.10f (%.3f sec)' %
                      (epoch + 1, epoch_gen_loss, duration))
                print(' : Dis Loss = %.10f ' % (epoch_dis_loss))
                print(' : Gen Loss f0 = %.10f ' % (epoch_gen_loss_f0))
                print(' : Dis Loss f0 = %.10f ' % (epoch_dis_loss_f0))

            if (epoch + 1) % config.save_every == 0 or (
                    epoch + 1) == config.num_epochs:
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)
def _ae_loss_terms(inputs, targets, loss_func, autoencoder):
    """Compute the seven loss tensors ``trainNetwork`` tracks for one batch.

    Returns ``(total, vocals, drums, bass, alpha_diff, beta_other,
    beta_other_voc)``. The last six come straight from ``loss_calc``; the
    total is ``abs(vocals + drums + bass - beta_other - alpha_diff -
    beta_other_voc)``.
    """
    (loss_vocals, loss_drums, loss_bass,
     alpha_diff, beta_other, beta_other_voc) = loss_calc(inputs, targets, loss_func, autoencoder)
    # add regularization terms from paper
    total = abs(loss_vocals + loss_drums + loss_bass - beta_other - alpha_diff - beta_other_voc)
    return total, loss_vocals, loss_drums, loss_bass, alpha_diff, beta_other, beta_other_voc


def trainNetwork(save_name='model_e' + str(config.num_epochs) + '_b' + str(config.batches_per_epoch_train) + '_bs' + str(config.batch_size)):
    """Continue training the source-separation ``AutoEncoder`` on CUDA.

    Loads the weights in ``./log/model_e8000_b50_bs5_3469.pt``. For each of
    ``config.num_epochs`` epochs it runs one pass over ``data_gen()`` with
    Adadelta updates and one pass over ``data_gen(mode="Val")`` without
    updates. Per-epoch normalised losses are appended to two histories.
    Every ``config.save_every`` epochs those histories are saved together
    with a checkpoint. A final checkpoint is written after the last epoch.

    Args:
        save_name: prefix of the checkpoint file names written to
            ``config.log_dir``.
    """
    assert torch.cuda.is_available(), "Code only usable with cuda"
    autoencoder = AutoEncoder().cuda()
    # NOTE(review): hard-coded resume checkpoint; the "+ 3470" epoch offset in
    # the periodic checkpoint names below appears to continue its numbering.
    autoencoder.load_state_dict(torch.load('./log/model_e8000_b50_bs5_3469.pt'))
    optimizer = torch.optim.Adadelta(autoencoder.parameters(), lr=1, rho=0.95)
    loss_func = nn.MSELoss(size_average=False)  # summed (not averaged) squared error

    # Each history row: [total, vocals, drums, bass, alpha, beta, beta2].
    train_evol = []
    val_evol = []
    term_names = ('vocal loss', 'drums loss', 'bass loss',
                  'alpha diff', 'beta diff', 'beta2 diff')

    for epoch in range(config.num_epochs):
        start_time = time.time()
        generator = data_gen()
        val_gen = data_gen(mode="Val")

        # ---- training pass ----
        train_sums = np.zeros(7)
        count = 0
        for inputs, targets in generator:
            # Clear gradients every step: clearing them once per epoch made each
            # optimizer.step() apply the sum of all earlier batches' gradients.
            optimizer.zero_grad()
            terms = _ae_loss_terms(inputs, targets, loss_func, autoencoder)
            values = np.array([term.item() for term in terms])
            if np.isnan(values[0]):
                # Skip the update so NaN gradients never reach the weights, and
                # keep this batch out of the epoch averages.
                print("error output contains NaN")
            else:
                train_sums += values
                terms[0].backward()
                # Element-wise gradient clipping to [-1, 1]. The previous
                # out-of-place clamp() discarded its result, i.e. was a no-op.
                for p in autoencoder.parameters():
                    if p.grad is not None:
                        p.grad.data.clamp_(-1, 1)
                optimizer.step()
            utils.progress(count, config.batches_per_epoch_train, suffix='training done')
            count += 1
        # NOTE(review): ``count`` is the number of batches just seen, so the
        # batch count enters this normaliser twice; ``config.batch_size`` may
        # have been intended. Kept unchanged so logged values stay comparable.
        train_means = train_sums / (config.batches_per_epoch_train * max(count, 1) * config.max_phr_len * 513)
        train_evol.append(train_means.tolist())

        # ---- validation pass (no updates, so no autograd graph is needed) ----
        val_sums = np.zeros(7)
        count = 0
        with torch.no_grad():
            for inputs, targets in val_gen:
                terms = _ae_loss_terms(inputs, targets, loss_func, autoencoder)
                val_sums += [term.item() for term in terms]
                utils.progress(count, config.batches_per_epoch_val, suffix='validation done')
                count += 1
        val_means = val_sums / (config.batches_per_epoch_val * max(count, 1) * config.max_phr_len * 513)
        val_evol.append(val_means.tolist())

        duration = time.time() - start_time
        if (epoch + 1) % config.print_every == 0:
            print('epoch %d/%d, took %.2f seconds, epoch total loss: %.7f' % (epoch + 1, config.num_epochs, duration, train_means[0]))
            for name, value in zip(term_names, train_means[1:]):
                print(' epoch %s: %.7f' % (name, value))
            print(' validation total loss: %.7f' % (val_means[0]))
            for name, value in zip(term_names, val_means[1:]):
                print(' validation %s: %.7f' % (name, value))

        if (epoch + 1) % config.save_every == 0:
            torch.save(autoencoder.state_dict(), config.log_dir + save_name + '_' + str(epoch + 3470) + '.pt')
            np.save(config.log_dir + 'train_loss', np.array(train_evol))
            np.save(config.log_dir + 'val_loss', np.array(val_evol))

    # NOTE(review): this final checkpoint uses a "+ 99" epoch offset while the
    # periodic ones use "+ 3470"; confirm which numbering is intended.
    torch.save(autoencoder.state_dict(), config.log_dir + save_name + '_' + str(epoch + 99) + '.pt')
def source_separate(self):
    """Resynthesise each batch once per envelope band and write the WAV files.

    For every batch from ``data_gen(self.config)``, the envelope input is
    copied three times. In each copy two of the first three envelope channels
    are zeroed, so only one band drives the network: bass keeps channel 0,
    mid keeps 1 and high keeps 2. Channels from index 3 on are never zeroed.
    Each band's output is clipped to [-1, 1] and written to
    ``output_<batch>_<item>_<model>_<band>.wav``. The ground truth for the
    same item is written to ``gt_<batch>_<item>.wav``.
    """
    sess = tf.Session()
    self.load_model(sess, log_dir=self.config.log_dir)
    val_generator = data_gen(self.config)

    # (output file-name template, envelope channels silenced for that band)
    bands = (
        ('output_{}_{}_{}_bass.wav', [1, 2]),
        ('output_{}_{}_{}_mid.wav', [0, 2]),
        ('output_{}_{}_{}_high.wav', [0, 1]),
    )

    def to_audio(prediction):
        """Turn one network output into a waveform.

        "spec" models are passed through exp(x) - 1 and ``utils.griffinlim``;
        other models' outputs are used as-is.
        """
        if self.config.model == "spec":
            return utils.griffinlim(np.exp(prediction) - 1, self.config)
        return prediction

    for batch_count, [out_audios, out_envelopes, out_features, total_count] in enumerate(val_generator):
        # One forward pass per band, in table order.
        band_outputs = []
        for _, silenced in bands:
            envelopes = np.copy(out_envelopes)
            envelopes[:, :, silenced] = 0
            feed_dict = {self.input_placeholder: envelopes[:, :, :self.config.rhyfeats],
                         self.cond_placeholder: out_features,
                         self.output_placeholder: out_audios,
                         self.is_train: False}
            band_outputs.append(sess.run(self.output_wav, feed_dict=feed_dict))

        for count in range(self.config.batch_size):
            for (template, _), output in zip(bands, band_outputs):
                output_file = os.path.join(
                    self.config.output_dir,
                    template.format(batch_count, count, self.config.model))
                sf.write(output_file, np.clip(to_audio(output[count]), -1, 1), self.config.fs)
            sf.write(
                os.path.join(self.config.output_dir, 'gt_{}_{}.wav'.format(batch_count, count)),
                out_audios[count], self.config.fs)
        utils.progress(batch_count, total_count)
def _dn_masked_vocals(output):
    """Ratio-mask the vocal channels of a separator output.

    Slices ``output`` along dim 1 as vocals 0:2, drums 2:4, bass 4:6 and
    others 6:. Returns ``vocals * (vocals / (vocals + bass + drums + others))``.
    NOTE(review): the denominator can be 0 (giving NaN) where all sources are 0.
    """
    vocals = output[:, :2, :, :]
    drums = output[:, 2:4, :, :]
    bass = output[:, 4:6, :, :]
    others = output[:, 6:, :, :]
    total_sources = vocals + bass + drums + others
    return vocals * (vocals / total_sources)


def trainNetwork(dataset='model6'):
    """Train a vocal denoiser on top of a frozen, pretrained ``AutoEncoder``.

    Loads ``config.log_dir + dataset + '.pt'`` into an ``AutoEncoder``. The
    autoencoder is kept on the CPU and is never optimised. Its vocal output
    is ratio-masked, and ``Encoder`` (on CUDA, SGD with lr=1e-6) is trained
    to map that masked output to the target vocals (channels 0:2 of
    ``targets``) under a summed L1 loss. Every ``config.save_every`` epochs
    the denoiser checkpoint and the raw per-epoch loss sums are saved under
    ``config.dn_log_dir``.

    Args:
        dataset: base name (without ``.pt``) of the autoencoder checkpoint.
    """
    save_name = 'dn_model'
    denoiser_vocals = Encoder().cuda()
    autoencoder = AutoEncoder()
    autoencoder.load_state_dict(torch.load(config.log_dir + dataset + '.pt'))
    # Only the denoiser's parameters are optimised.
    optimizer = torch.optim.SGD(denoiser_vocals.parameters(), 1e-6)
    loss_func = nn.L1Loss(size_average=False)
    train_evol = []
    eval_evol = []

    def vocals_loss(inputs, targets):
        """Summed L1 loss between denoised masked vocals and target vocals."""
        # The autoencoder is frozen, so no gradient has to flow through it.
        with torch.no_grad():
            output = autoencoder(torch.FloatTensor(inputs)).cuda()
            input_vocals = _dn_masked_vocals(output)
        denoised_vocals = denoiser_vocals(input_vocals).cuda()
        target_vocals = torch.cuda.FloatTensor(targets[:, :2, :, :])
        return loss_func(denoised_vocals, target_vocals)

    for epoch in range(config.dn_num_epochs):
        start_time = time.time()
        train_gen = data_gen()
        val_gen = data_gen(mode="Val")
        train_loss = 0
        eval_loss = 0

        train_count = 0
        for inputs, targets in train_gen:
            # Clear gradients every step (previously only once per epoch, so
            # every update applied the sum of all earlier batches' gradients).
            optimizer.zero_grad()
            step_loss = vocals_loss(inputs, targets)
            train_loss += step_loss.item()
            step_loss.backward()
            optimizer.step()
            utils.progress(train_count, config.batches_per_epoch_train, suffix='training done')
            train_count += 1
        train_evol.append(train_loss)

        val_count = 0
        with torch.no_grad():
            for inputs, targets in val_gen:
                # Score this validation batch; the old code reused the last
                # *training* batch's autoencoder output here.
                eval_loss += vocals_loss(inputs, targets).item()
                utils.progress(val_count, config.batches_per_epoch_val, suffix='validation done')
                val_count += 1
        eval_evol.append(eval_loss)

        duration = time.time() - start_time
        if (epoch + 1) % config.print_every == 0:
            print('epoch %d/%d, took %.2f seconds, epoch total loss: %.7f' % (epoch + 1, config.dn_num_epochs, duration, train_loss / (config.batches_per_epoch_train * max(train_count, 1) * config.max_phr_len * 513)))
            print(' validation total loss: %.7f' % (eval_loss / (config.batches_per_epoch_val * max(val_count, 1) * config.max_phr_len * 513)))
        if (epoch + 1) % config.save_every == 0:
            torch.save(denoiser_vocals.state_dict(), config.dn_log_dir + save_name + '_' + str(epoch) + '.pt')
            np.save(config.dn_log_dir + 'dn_train_loss', np.array(train_evol))
            np.save(config.dn_log_dir + 'dn_val_loss', np.array(eval_evol))
def train(_):
    """Train the F0 / voicing network, logging losses and F0 accuracy.

    Builds a TF1 graph around ``modules.f0_network``, which has two F0
    outputs and one voicing output. If ``config.log_dir`` holds a
    checkpoint, the newest one is restored. Training then runs from the
    epoch implied by the restored global step up to ``config.num_epochs``.

    Every epoch writes train and val summaries. On print epochs and on the
    final epoch it also scores frame-level F0 accuracy on whole validation
    tracks and saves the accuracy history to ``./ikala_eval/accuracies``.

    Args:
        _: unused (presumably the argv passed by ``tf.app.run``; confirm).
    """
    # Feature min/max statistics, used to undo min-max scaling of F0.
    with h5py.File(config.stat_dir + 'stats.hdf5', mode='r') as stat_file:
        max_feat = np.array(stat_file["feats_maximus"])
        min_feat = np.array(stat_file["feats_minimus"])
    f0_min = min_feat[-2]
    f0_range = max_feat[-2] - min_feat[-2]

    def denormalize_f0(values, vuv_flags):
        """Undo min-max scaling of F0 and zero the frames where vuv_flags is 1."""
        # Was ``values * ((max - min) + min)``, which is just ``values * max``;
        # the inverse of min-max scaling is ``values * (max - min) + min``.
        return (values * f0_range + f0_min) * (1 - vuv_flags)

    def frame_accuracy(reference, estimate):
        """Fraction of frames with |reference - estimate| <= config.f0_threshold."""
        difference = np.nan_to_num(abs(reference - estimate))
        n_wrong = np.where(difference > config.f0_threshold)[0].shape[0]
        # float(): with Python 2 integer division this ratio was always 0.
        return 1 - float(n_wrong) / len(estimate)

    with tf.Graph().as_default():
        input_placeholder = tf.placeholder(tf.float32, shape=(config.batch_size, config.max_phr_len, config.input_features), name='input_placeholder')
        tf.summary.histogram('inputs', input_placeholder)
        # Target columns as used below: -3 = F0 target of the first F0 output,
        # -2 = F0 target of the second, -1 = voicing flag. Frames where the flag
        # is 1 are excluded from the F0 losses (presumably 1 = unvoiced; confirm).
        target_placeholder = tf.placeholder(tf.float32, shape=(config.batch_size, config.max_phr_len, 3), name='target_placeholder')
        tf.summary.histogram('targets', target_placeholder)

        with tf.variable_scope('First_Model'):
            f0, f0_1, vuv = modules.f0_network(input_placeholder)
            tf.summary.histogram('f0', f0)
            tf.summary.histogram('vuv', vuv)

        # Masked L1 losses on both F0 outputs plus the voicing loss.
        f0_mask = 1 - target_placeholder[:, :, -1:]
        f0_loss_1 = tf.reduce_sum(tf.abs(f0 - target_placeholder[:, :, -3:-2]) * f0_mask)
        f0_loss_2 = tf.reduce_sum(tf.abs(f0_1 - target_placeholder[:, :, -2:-1]) * f0_mask)
        vuv_loss = tf.reduce_sum(binary_cross(target_placeholder[:, :, -1:], vuv))
        loss = f0_loss_1 + vuv_loss + f0_loss_2

        # Scalar summaries are picked up by merge_all() below.
        tf.summary.scalar('f0_loss_1', f0_loss_1)
        tf.summary.scalar('f0_loss_2', f0_loss_2)
        tf.summary.scalar('vuv_loss', vuv_loss)
        tf.summary.scalar('total_loss', loss)

        global_step = tf.Variable(0, name='global_step', trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=config.init_lr)
        train_function = optimizer.minimize(loss, global_step=global_step)

        summary = tf.summary.merge_all()
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        saver = tf.train.Saver(max_to_keep=config.max_models_to_keep)
        sess = tf.Session()
        sess.run(init_op)

        ckpt = tf.train.get_checkpoint_state(config.log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Using the model in %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        train_summary_writer = tf.summary.FileWriter(config.log_dir + 'train/', sess.graph)
        val_summary_writer = tf.summary.FileWriter(config.log_dir + 'val/', sess.graph)

        # global_step is incremented once per training batch.
        start_epoch = int(sess.run(global_step) / (config.batches_per_epoch_train))
        print("Start from: %d" % start_epoch)

        # NOTE(review): d_optimizer, g_optimizer, D_loss_real, D_loss_fake,
        # G_loss_GAN and G_loss_diff are not defined in this function; with
        # config.use_gan enabled the GAN branches below raise NameError unless
        # those names exist at module scope — confirm.
        f0_accs = []
        for epoch in range(start_epoch, config.num_epochs):
            val_f0_accs_1 = []
            val_f0_accs_2 = []
            data_generator = data_gen()
            start_time = time.time()
            # F0 accuracy on whole tracks is only computed on these epochs.
            is_report_epoch = (epoch + 1) % config.print_every == 0 or (epoch + 1) == config.num_epochs

            epoch_loss_f0_1 = 0
            epoch_loss_f0_2 = 0
            epoch_loss_vuv = 0
            epoch_total_loss = 0
            epoch_loss_f0_val_1 = 0
            epoch_loss_f0_val_2 = 0
            epoch_loss_vuv_val = 0
            epoch_total_loss_val = 0
            if config.use_gan:
                epoch_loss_generator_GAN = 0
                epoch_loss_generator_diff = 0
                epoch_loss_discriminator_real = 0
                epoch_loss_discriminator_fake = 0
                val_epoch_loss_generator_GAN = 0
                val_epoch_loss_generator_diff = 0
                val_epoch_loss_discriminator_real = 0
                val_epoch_loss_discriminator_fake = 0
            batch_num = 0
            batch_num_val = 0
            val_generator = data_gen(mode='val')

            with tf.variable_scope('Training'):
                for voc, feat in data_generator:
                    feed_dict = {input_placeholder: voc, target_placeholder: feat}
                    _, step_loss_f0_1, step_loss_f0_2, step_loss_vuv, step_total_loss = sess.run(
                        [train_function, f0_loss_1, f0_loss_2, vuv_loss, loss], feed_dict=feed_dict)
                    if config.use_gan:
                        _, step_dis_loss_real, step_dis_loss_fake = sess.run([d_optimizer, D_loss_real, D_loss_fake], feed_dict=feed_dict)
                        _, step_gen_loss_GAN, step_gen_loss_diff = sess.run([g_optimizer, G_loss_GAN, G_loss_diff], feed_dict=feed_dict)
                    epoch_loss_f0_1 += step_loss_f0_1
                    epoch_loss_f0_2 += step_loss_f0_2
                    epoch_loss_vuv += step_loss_vuv
                    epoch_total_loss += step_total_loss
                    if config.use_gan:
                        epoch_loss_generator_GAN += step_gen_loss_GAN
                        epoch_loss_generator_diff += step_gen_loss_diff
                        epoch_loss_discriminator_real += step_dis_loss_real
                        epoch_loss_discriminator_fake += step_dis_loss_fake
                    utils.progress(batch_num, config.batches_per_epoch_train, suffix='training done')
                    batch_num += 1

                # Per-frame means (the loss sums run over batch * frames).
                train_norm = config.batches_per_epoch_train * config.batch_size * config.max_phr_len
                epoch_loss_f0_1 = epoch_loss_f0_1 / train_norm
                epoch_loss_f0_2 = epoch_loss_f0_2 / train_norm
                epoch_loss_vuv = epoch_loss_vuv / train_norm
                epoch_total_loss = epoch_total_loss / (train_norm * 3)
                if config.use_gan:
                    epoch_loss_generator_GAN = epoch_loss_generator_GAN / (config.batches_per_epoch_train * config.batch_size)
                    epoch_loss_generator_diff = epoch_loss_generator_diff / (train_norm * 60)
                    epoch_loss_discriminator_real = epoch_loss_discriminator_real / (config.batches_per_epoch_train * config.batch_size)
                    epoch_loss_discriminator_fake = epoch_loss_discriminator_fake / (config.batches_per_epoch_train * config.batch_size)

                # Summaries are evaluated on the last training batch.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                train_summary_writer.add_summary(summary_str, epoch)
                train_summary_writer.flush()

            with tf.variable_scope('Validation'):
                for voc, feat, nchunks_in, lent, county, max_count in val_generator:
                    feed_dict = {input_placeholder: voc, target_placeholder: feat}
                    if is_report_epoch:
                        # Collect chunks of one track (county runs 1..max_count),
                        # then overlap-add and score the whole track.
                        if county == 1:
                            f0_gt = []
                            vuv_gt = []
                            f0_output_1 = []
                            f0_output_2 = []
                        f0_op_1, f0_op_2 = sess.run([f0, f0_1], feed_dict=feed_dict)
                        f0_output_1.append(f0_op_1)
                        f0_output_2.append(f0_op_2)
                        f0_gt.append(feat[:, :, -2:-1])
                        vuv_gt.append(feat[:, :, -1:])
                        if county == max_count:
                            track_out_1 = utils.overlapadd(np.array(f0_output_1), nchunks_in)[:lent]
                            track_out_2 = utils.overlapadd(np.array(f0_output_2), nchunks_in)[:lent]
                            track_gt = utils.overlapadd(np.array(f0_gt), nchunks_in)[:lent]
                            track_vuv = utils.overlapadd(np.array(vuv_gt), nchunks_in)[:lent]
                            track_gt = denormalize_f0(track_gt, track_vuv)
                            val_f0_accs_1.append(frame_accuracy(track_gt, denormalize_f0(track_out_1, track_vuv)))
                            val_f0_accs_2.append(frame_accuracy(track_gt, denormalize_f0(track_out_2, track_vuv)))

                    # One forward pass for all validation losses.
                    step_loss_f0_val_1, step_loss_f0_val_2, step_loss_vuv_val, step_total_loss_val = sess.run(
                        [f0_loss_1, f0_loss_2, vuv_loss, loss], feed_dict=feed_dict)
                    if config.use_gan:
                        step_gen_loss_GAN, step_gen_loss_diff = sess.run([G_loss_GAN, G_loss_diff], feed_dict=feed_dict)
                        step_dis_loss_real, step_dis_loss_fake = sess.run([D_loss_real, D_loss_fake], feed_dict=feed_dict)
                    epoch_loss_f0_val_1 += step_loss_f0_val_1
                    epoch_loss_f0_val_2 += step_loss_f0_val_2
                    epoch_loss_vuv_val += step_loss_vuv_val
                    epoch_total_loss_val += step_total_loss_val
                    if config.use_gan:
                        val_epoch_loss_generator_GAN += step_gen_loss_GAN
                        val_epoch_loss_generator_diff += step_gen_loss_diff
                        val_epoch_loss_discriminator_real += step_dis_loss_real
                        val_epoch_loss_discriminator_fake += step_dis_loss_fake
                    utils.progress(batch_num_val, config.batches_per_epoch_val, suffix='validiation done')
                    batch_num_val += 1

                if is_report_epoch:
                    f0_accs.append(np.mean(val_f0_accs_2))

                val_norm = batch_num_val * config.batch_size * config.max_phr_len
                epoch_loss_f0_val_1 = epoch_loss_f0_val_1 / val_norm
                epoch_loss_f0_val_2 = epoch_loss_f0_val_2 / val_norm
                epoch_loss_vuv_val = epoch_loss_vuv_val / val_norm
                # Three loss terms per frame, matching the training normaliser
                # (this previously used a stale factor of 66).
                epoch_total_loss_val = epoch_total_loss_val / (val_norm * 3)
                if config.use_gan:
                    val_epoch_loss_generator_GAN = val_epoch_loss_generator_GAN / (config.batches_per_epoch_val * config.batch_size)
                    val_epoch_loss_generator_diff = val_epoch_loss_generator_diff / (config.batches_per_epoch_val * config.batch_size * config.max_phr_len * 60)
                    val_epoch_loss_discriminator_real = val_epoch_loss_discriminator_real / (config.batches_per_epoch_val * config.batch_size)
                    val_epoch_loss_discriminator_fake = val_epoch_loss_discriminator_fake / (config.batches_per_epoch_val * config.batch_size)

                summary_str = sess.run(summary, feed_dict=feed_dict)
                val_summary_writer.add_summary(summary_str, epoch)
                val_summary_writer.flush()

            duration = time.time() - start_time
            np.save('./ikala_eval/accuracies', f0_accs)

            if (epoch + 1) % config.print_every == 0:
                print('epoch %d: F0 Training Loss = %.10f (%.3f sec)' % (epoch + 1, epoch_loss_f0_1, duration))
                print(' : VUV Training Loss = %.10f ' % (epoch_loss_vuv))
                if config.use_gan:
                    print(' : Gen GAN Training Loss = %.10f ' % (epoch_loss_generator_GAN))
                    print(' : Gen diff Training Loss = %.10f ' % (epoch_loss_generator_diff))
                    print(' : Discriminator Training Loss Real = %.10f ' % (epoch_loss_discriminator_real))
                    print(' : Discriminator Training Loss Fake = %.10f ' % (epoch_loss_discriminator_fake))
                print(' : F0 Validation Loss_1 = %.10f ' % (epoch_loss_f0_val_1))
                print(' : F0 Validation Loss_2 = %.10f ' % (epoch_loss_f0_val_2))
                print(' : VUV Validation Loss = %.10f ' % (epoch_loss_vuv_val))
                # Print epochs are always report epochs, so accuracies exist here.
                print(' : Mean F0 IKala Accuracy_1 = %.10f ' % (np.mean(val_f0_accs_1)))
                print(' : Mean F0 IKala Accuracy_2 = %.10f ' % (np.mean(val_f0_accs_2)))
                if config.use_gan:
                    print(' : Gen GAN Validation Loss = %.10f ' % (val_epoch_loss_generator_GAN))
                    print(' : Gen diff Validation Loss = %.10f ' % (val_epoch_loss_generator_diff))
                    print(' : Discriminator Validation Loss Real = %.10f ' % (val_epoch_loss_discriminator_real))
                    print(' : Discriminator Validation Loss Fake = %.10f ' % (val_epoch_loss_discriminator_fake))

            if (epoch + 1) % config.save_every == 0 or (epoch + 1) == config.num_epochs:
                checkpoint_file = os.path.join(config.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=epoch)