def decode():
    """Decode the 'tt_tf' test list with the current model and write the
    two separated waveforms per utterance into FLAGS.data_dir.

    Side effects:
        - Restores the latest checkpoint from FLAGS.save_dir + '/nnet'
          (exits with status -1 if none is found).
        - Runs the padded-batch input queue for one epoch.
        - Writes '<utt>_1.wav' and '<utt>_2.wav' (8 kHz) per utterance.
    """
    tfrecords_lst, num_batches = read_list_file('tt_tf', FLAGS.batch_size)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                # The mixed tensor packs magnitude and phase along the last
                # axis: [:input_size] = magnitudes, [input_size:] = phases.
                tt_mixed, tt_labels, tt_genders, tt_lengths = get_padded_batch(
                    tfrecords_lst, FLAGS.batch_size, FLAGS.input_size * 2,
                    FLAGS.output_size * 2, num_enqueuing_threads=1,
                    num_epochs=1, shuffle=False)
                tt_inputs = tf.slice(tt_mixed, [0, 0, 0],
                                     [-1, -1, FLAGS.input_size])
                tt_angles = tf.slice(tt_mixed, [0, 0, FLAGS.input_size],
                                     [-1, -1, -1])
        # Build the model in inference mode on the decode inputs.
        with tf.name_scope('model'):
            model = LSTM(FLAGS, tt_inputs, tt_labels, tt_lengths,
                         tt_genders, infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # FIX: log message previously read "checkpoint not fou1nd.".
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        data_dir = FLAGS.data_dir
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)

        processed = 0  # number of utterances already written
        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break
                if FLAGS.assign == 'def':
                    cleaned1, cleaned2, angles, lengths = sess.run(
                        [model._cleaned1, model._cleaned2,
                         tt_angles, tt_lengths])
                else:
                    # FIX: this branch previously fetched only x1/x2, so
                    # 'angles' and 'lengths' were undefined below (NameError).
                    # Fetch them in the same run so both branches agree.
                    x1, x2 = model.get_opt_output()
                    cleaned1, cleaned2, angles, lengths = sess.run(
                        [x1, x2, tt_angles, tt_lengths])

                # Re-attach the mixture phase to the estimated magnitudes
                # to get complex spectrograms for each separated source.
                spec1 = cleaned1 * np.exp(angles * 1j)
                spec2 = cleaned2 * np.exp(angles * 1j)

                for i in range(0, FLAGS.batch_size):
                    tffilename = tfrecords_lst[i + processed]
                    (_, name) = os.path.split(tffilename)
                    (partname, _) = os.path.splitext(name)
                    wav_name1 = data_dir + '/' + partname + '_1.wav'
                    wav_name2 = data_dir + '/' + partname + '_2.wav'
                    # Trim padding with the true frame count before iSTFT.
                    wav1 = istft(spec1[i, 0:lengths[i], :],
                                 size=256, shift=128)
                    wav2 = istft(spec2[i, 0:lengths[i], :],
                                 size=256, shift=128)
                    audiowrite(wav1, wav_name1, 8000, True, True)
                    audiowrite(wav2, wav_name2, 8000, True, True)
                processed = processed + FLAGS.batch_size
                if batch % 50 == 0:
                    print(batch)
        except Exception as e:
            # Report exceptions to the coordinator so queue threads stop.
            coord.request_stop(e)
def decode():
    """Decode the 'tt' test list with the current model and write each
    separated source as a Kaldi ark/scp pair into FLAGS.data_dir.

    NOTE(review): this redefines `decode` — if the earlier `decode` lives in
    the same module, this definition shadows it; confirm which one is meant
    to be the entry point.

    Side effects:
        - Restores the latest checkpoint from FLAGS.save_dir + '/nnet'
          (exits with status -1 if none is found).
        - Runs the padded-batch input queue for one epoch (batch size 1).
        - Writes '<utt>_1.wav.ark/.scp' and '<utt>_2.wav.ark/.scp' per
          utterance.
    """
    tfrecords_lst, num_batches = read_list_file('tt', FLAGS.batch_size)
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                # Batch size is fixed to 1 here so each dequeued batch maps
                # one-to-one onto an entry of tfrecords_lst.
                tt_mixed, tt_inputs, tt_labels1, tt_labels2, tt_lengths = \
                    get_padded_batch_v2(
                        tfrecords_lst, 1, FLAGS.input_size, FLAGS.output_size,
                        num_enqueuing_threads=1, num_epochs=1, shuffle=False)
        # Build the model in inference mode on the decode inputs.
        with tf.name_scope('model'):
            model = LSTM(FLAGS, tt_inputs, tt_mixed, tt_labels1,
                         tt_labels2, tt_lengths, infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess = tf.Session()
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # FIX: log message previously read "checkpoint not fou1nd.".
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        data_dir = FLAGS.data_dir
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)

        try:
            for batch in xrange(num_batches):
                if coord.should_stop():
                    break
                if FLAGS.assign == 'def':
                    cleaned1, cleaned2 = sess.run(
                        [model._cleaned1, model._cleaned2])
                else:
                    x1, x2 = model.get_opt_output()
                    cleaned1, cleaned2 = sess.run([x1, x2])

                # With batch size 1, 'batch' indexes the utterance directly.
                tffilename = tfrecords_lst[batch]
                (_, name) = os.path.split(tffilename)
                (uttid, _) = os.path.splitext(name)
                # Double splitext strips a compound extension (e.g. '.x.y').
                (partname, _) = os.path.splitext(uttid)

                kaldi_writer1 = kio.ArkWriter(
                    data_dir + '/' + partname + '_1.wav.scp')
                kaldi_writer2 = kio.ArkWriter(
                    data_dir + '/' + partname + '_2.wav.scp')
                # Drop the leading batch axis (size 1) when writing.
                kaldi_writer1.write_next_utt(
                    data_dir + '/' + partname + '_1.wav.ark',
                    uttid, cleaned1[0, :, :])
                kaldi_writer2.write_next_utt(
                    data_dir + '/' + partname + '_2.wav.ark',
                    uttid, cleaned2[0, :, :])
                kaldi_writer1.close()
                kaldi_writer2.close()

                if batch % 500 == 0:
                    print(batch)
        except Exception as e:
            # Report exceptions to the coordinator so queue threads stop.
            coord.request_stop(e)