def train():
    """Train the LSTM separation model with per-epoch cross-validation.

    Reads train/cv tfrecord lists, builds two weight-sharing LSTM models
    (train + validation), then runs up to ``FLAGS.max_epochs`` epochs.
    After each epoch the checkpoint is accepted only if the validation
    loss improved; otherwise the best previous checkpoint is restored.
    The learning rate is halved once relative improvement drops below
    ``FLAGS.start_halving_impr``, and training stops below
    ``FLAGS.end_halving_impr`` (subject to ``FLAGS.min_epochs``).
    """
    tr_tfrecords_lst, tr_num_batches = read_list_file("tr_tf", FLAGS.batch_size)
    val_tfrecords_lst, val_num_batches = read_list_file("cv_tf", FLAGS.batch_size)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                # Inputs/labels are interleaved (real+imag or mag+phase halves),
                # hence the * 2 on the feature sizes -- TODO confirm layout.
                tr_mixed, tr_labels, tr_genders, tr_lengths = get_padded_batch(
                    tr_tfrecords_lst, FLAGS.batch_size,
                    FLAGS.input_size * 2, FLAGS.output_size * 2,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=FLAGS.max_epochs)
                # One extra epoch so the validation queue outlives training
                # (it is drained max_epochs + 1 times: prerun + each epoch).
                val_mixed, val_labels, val_genders, val_lengths = get_padded_batch(
                    val_tfrecords_lst, FLAGS.batch_size,
                    FLAGS.input_size * 2, FLAGS.output_size * 2,
                    num_enqueuing_threads=FLAGS.num_threads,
                    num_epochs=FLAGS.max_epochs + 1)
        # Only the first input_size features feed the network; the rest of
        # the mixed tensor (e.g. phase) is unused during training.
        tr_inputs = tf.slice(tr_mixed, [0, 0, 0], [-1, -1, FLAGS.input_size])
        val_inputs = tf.slice(val_mixed, [0, 0, 0], [-1, -1, FLAGS.input_size])
        with tf.name_scope('model'):
            tr_model = LSTM(FLAGS, tr_inputs, tr_labels, tr_lengths, tr_genders)
            # tr_model and val_model should share variables.
            tf.get_variable_scope().reuse_variables()
            val_model = LSTM(FLAGS, val_inputs, val_labels, val_lengths,
                             val_genders)
        show_all_variables()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        # Prevent exhausting all the GPU memory.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)
        if FLAGS.resume_training.lower() == 'true':
            ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
            if ckpt and ckpt.model_checkpoint_path:
                tf.logging.info("restore from" + ckpt.model_checkpoint_path)
                tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
                best_path = ckpt.model_checkpoint_path
            else:
                tf.logging.fatal("checkpoint not found")
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Cross validation before training.
            loss_prev = eval_one_epoch(sess, coord, val_model, val_num_batches)
            tf.logging.info("CROSSVAL PRERUN AVG.LOSS %.4F" % loss_prev)
            sess.run(tf.assign(tr_model.lr, FLAGS.learning_rate))
            for epoch in xrange(FLAGS.max_epochs):
                start_time = time.time()
                # Training.
                tr_loss = train_one_epoch(sess, coord, tr_model, tr_num_batches)
                # Validation.
                val_loss = eval_one_epoch(sess, coord, val_model,
                                          val_num_batches)
                end_time = time.time()
                # Determine checkpoint path.
                ckpt_name = "nnet_iter%d_lrate%e_tr%.4f_cv%.4f" % (
                    epoch + 1, FLAGS.learning_rate, tr_loss, val_loss)
                ckpt_dir = FLAGS.save_dir + '/nnet'
                if not os.path.exists(ckpt_dir):
                    os.makedirs(ckpt_dir)
                ckpt_path = os.path.join(ckpt_dir, ckpt_name)
                # Relative loss between previous and current val_loss.
                rel_impr = (loss_prev - val_loss) / loss_prev
                # Accept or reject new parameters.
                if val_loss < loss_prev:
                    tr_model.saver.save(sess, ckpt_path)
                    # Logging train loss along with validation loss.
                    loss_prev = val_loss
                    best_path = ckpt_path
                    tf.logging.info(
                        "ITERATION %d: TRAIN AVG.LOSS %.4f, (lrate%e) CROSSVAL"
                        " AVG.LOSS %.4f, %s (%s), TIME USED: %.2fs" % (
                            epoch + 1, tr_loss, FLAGS.learning_rate, val_loss,
                            "nnet accepted", ckpt_name,
                            (end_time - start_time) / 1))
                else:
                    # NOTE(review): if the very first epoch is rejected and we
                    # did not resume from a checkpoint, best_path is unbound
                    # here -- verify resume_training is set in that workflow.
                    tr_model.saver.restore(sess, best_path)
                    tf.logging.info(
                        "ITERATION %d: TRAIN AVG.LOSS %.4f, (lrate%e) CROSSVAL"
                        " AVG.LOSS %.4f, %s, (%s), TIME USED: %.2fs" % (
                            epoch + 1, tr_loss, FLAGS.learning_rate, val_loss,
                            "nnet rejected", ckpt_name,
                            (end_time - start_time) / 1))
                # Start halving when improvement is low.
                if rel_impr < FLAGS.start_halving_impr:
                    FLAGS.learning_rate *= FLAGS.halving_factor
                    sess.run(tf.assign(tr_model.lr, FLAGS.learning_rate))
                # Stopping criterion.
                if rel_impr < FLAGS.end_halving_impr:
                    if epoch < FLAGS.min_epochs:
                        tf.logging.info(
                            "we were supposed to finish, but we continue as "
                            "min_epochs : %s" % FLAGS.min_epochs)
                        continue
                    else:
                        tf.logging.info(
                            "finished, too small rel. improvement %g"
                            % rel_impr)
                        break
        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # BUG FIX: the original 'finally:' clause had an empty body.
            # Stop and join the queue-runner threads and release the session
            # so the process can exit cleanly.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def decode():
    """Decode the test set with the current model and write wav files.

    Restores the latest checkpoint from ``FLAGS.save_dir + '/nnet'``, runs
    the model over every test batch, recombines the estimated magnitudes
    with the mixture phase (``tt_angles``), and writes two separated
    waveforms per utterance into ``FLAGS.data_dir`` via istft/audiowrite.
    Exits the process if no checkpoint is found.
    """
    tfrecords_lst, num_batches = read_list_file('tt_tf', FLAGS.batch_size)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tt_mixed, tt_labels, tt_genders, tt_lengths = get_padded_batch(
                    tfrecords_lst, FLAGS.batch_size,
                    FLAGS.input_size * 2, FLAGS.output_size * 2,
                    num_enqueuing_threads=1, num_epochs=1, shuffle=False)
        # First half of the mixed features is the network input, the second
        # half holds the mixture phase angles used for resynthesis.
        tt_inputs = tf.slice(tt_mixed, [0, 0, 0], [-1, -1, FLAGS.input_size])
        tt_angles = tf.slice(tt_mixed, [0, 0, FLAGS.input_size], [-1, -1, -1])
        with tf.name_scope('model'):
            model = LSTM(FLAGS, tt_inputs, tt_labels, tt_lengths, tt_genders,
                         infer=True)
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # BUG FIX: corrected log-message typo ("fou1nd" -> "found").
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # NOTE(review): CMVN denormalization of the outputs was disabled here
        # (previously loaded train_cmvn.npz) -- confirm outputs are already
        # in the expected scale.
        data_dir = FLAGS.data_dir
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        # Build the optimal-assignment output tensors ONCE, outside the
        # batch loop.  BUG FIX: the original called model.get_opt_output()
        # every iteration, growing the TF graph each batch.
        if FLAGS.assign != 'def':
            opt1, opt2 = model.get_opt_output()
        processed = 0
        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break
                if FLAGS.assign == 'def':
                    fetches = [model._cleaned1, model._cleaned2,
                               tt_angles, tt_lengths]
                else:
                    # BUG FIX: the original fetched only the two outputs on
                    # this branch, leaving 'angles' and 'lengths' undefined
                    # and crashing with NameError below.
                    fetches = [opt1, opt2, tt_angles, tt_lengths]
                cleaned1, cleaned2, angles, lengths = sess.run(fetches)
                # Recombine estimated magnitudes with mixture phase.
                spec1 = cleaned1 * np.exp(angles * 1j)
                spec2 = cleaned2 * np.exp(angles * 1j)
                for i in range(0, FLAGS.batch_size):
                    tffilename = tfrecords_lst[i + processed]
                    (_, name) = os.path.split(tffilename)
                    (partname, _) = os.path.splitext(name)
                    wav_name1 = data_dir + '/' + partname + '_1.wav'
                    wav_name2 = data_dir + '/' + partname + '_2.wav'
                    # Trim padding frames before inverse STFT.
                    wav1 = istft(spec1[i, 0:lengths[i], :],
                                 size=256, shift=128)
                    wav2 = istft(spec2[i, 0:lengths[i], :],
                                 size=256, shift=128)
                    audiowrite(wav1, wav_name1, 8000, True, True)
                    audiowrite(wav2, wav_name2, 8000, True, True)
                processed = processed + FLAGS.batch_size
                if batch % 50 == 0:
                    print(batch)
        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Stop and join queue threads and release the session (the
            # original leaked both).
            coord.request_stop()
            coord.join(threads)
            sess.close()
def decode():
    """Decode the test set and write separated features as Kaldi ark/scp.

    Restores the latest checkpoint from ``FLAGS.save_dir + '/nnet'``, runs
    the model utterance-by-utterance (batch size 1) over the 'tt' list and
    writes two Kaldi ark/scp pairs per utterance into ``FLAGS.data_dir``.
    Exits the process if no checkpoint is found.

    NOTE(review): this redefines the wav-writing ``decode`` defined earlier
    in the file; only this later definition is live. Consider renaming one
    of them once callers are known.
    """
    tfrecords_lst, num_batches = read_list_file('tt', FLAGS.batch_size)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                # Batch size is fixed to 1: one utterance per step.
                (tt_mixed, tt_inputs, tt_labels1, tt_labels2,
                 tt_lengths) = get_padded_batch_v2(
                    tfrecords_lst, 1, FLAGS.input_size, FLAGS.output_size,
                    num_enqueuing_threads=1, num_epochs=1, shuffle=False)
        # Create the inference model.
        with tf.name_scope('model'):
            model = LSTM(FLAGS, tt_inputs, tt_mixed, tt_labels1, tt_labels2,
                         tt_lengths, infer=True)
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)
        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # BUG FIX: corrected log-message typo ("fou1nd" -> "found").
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # NOTE(review): CMVN denormalization of the outputs was disabled here
        # (previously loaded train_cmvn.npz) -- confirm outputs are already
        # in the expected scale.
        data_dir = FLAGS.data_dir
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        # Build the optimal-assignment output tensors ONCE, outside the
        # batch loop.  BUG FIX: the original called model.get_opt_output()
        # every iteration, growing the TF graph each batch.
        if FLAGS.assign != 'def':
            opt1, opt2 = model.get_opt_output()
        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break
                if FLAGS.assign == 'def':
                    cleaned1, cleaned2 = sess.run(
                        [model._cleaned1, model._cleaned2])
                else:
                    cleaned1, cleaned2 = sess.run([opt1, opt2])
                # Derive utterance id and output basename from the tfrecord
                # file name (two splitext calls strip a double extension).
                tffilename = tfrecords_lst[batch]
                (_, name) = os.path.split(tffilename)
                (uttid, _) = os.path.splitext(name)
                (partname, _) = os.path.splitext(uttid)
                kaldi_writer1 = kio.ArkWriter(
                    data_dir + '/' + partname + '_1.wav.scp')
                kaldi_writer2 = kio.ArkWriter(
                    data_dir + '/' + partname + '_2.wav.scp')
                # Batch dim is 1, so [0, :, :] is the whole utterance.
                kaldi_writer1.write_next_utt(
                    data_dir + '/' + partname + '_1.wav.ark', uttid,
                    cleaned1[0, :, :])
                kaldi_writer2.write_next_utt(
                    data_dir + '/' + partname + '_2.wav.ark', uttid,
                    cleaned2[0, :, :])
                kaldi_writer1.close()
                kaldi_writer2.close()
                if batch % 500 == 0:
                    print(batch)
        except Exception as e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Stop and join queue threads and release the session (the
            # original leaked both).
            coord.request_stop()
            coord.join(threads)
            sess.close()