# 예제 #1 (Example #1)
# 0
def decode():
    """Decode the test set with the current model and write separated wavs.

    Reads padded batches from the 'tt_tf' TFRecord list, restores the
    latest checkpoint from ``FLAGS.save_dir/nnet``, runs the LSTM
    separator, recombines the estimated magnitudes with the mixture
    phase, and writes two wav files (8 kHz) per utterance into
    ``FLAGS.data_dir`` via ISTFT.

    Exits the process with status -1 if no checkpoint is found.
    """
    tfrecords_lst, num_batches = read_list_file('tt_tf', FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                # Magnitudes and phase angles are packed side by side on
                # the feature axis: [:input_size] = magnitudes,
                # [input_size:] = angles (hence the `* 2` widths below).
                tt_mixed, tt_labels, tt_genders, tt_lengths = get_padded_batch(
                    tfrecords_lst,
                    FLAGS.batch_size,
                    FLAGS.input_size * 2,
                    FLAGS.output_size * 2,
                    num_enqueuing_threads=1,
                    num_epochs=1,
                    shuffle=False)
                tt_inputs = tf.slice(tt_mixed, [0, 0, 0],
                                     [-1, -1, FLAGS.input_size])
                tt_angles = tf.slice(tt_mixed, [0, 0, FLAGS.input_size],
                                     [-1, -1, -1])

        # Build the model in inference mode on the decoding inputs.
        with tf.name_scope('model'):
            model = LSTM(FLAGS,
                         tt_inputs,
                         tt_labels,
                         tt_lengths,
                         tt_genders,
                         infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        sess = tf.Session()
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Fixed typo in the original message ("fou1nd").
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    data_dir = FLAGS.data_dir
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    processed = 0
    try:
        for batch in xrange(num_batches):
            if coord.should_stop():
                break
            if FLAGS.assign == 'def':
                cleaned1, cleaned2, angles, lengths = sess.run(
                    [model._cleaned1, model._cleaned2, tt_angles, tt_lengths])
            else:
                # BUG FIX: `angles` and `lengths` were undefined on this
                # branch in the original code, causing a NameError below.
                # Fetch them together with the optimal-assignment outputs.
                x1, x2 = model.get_opt_output()
                cleaned1, cleaned2, angles, lengths = sess.run(
                    [x1, x2, tt_angles, tt_lengths])
            # Recombine estimated magnitudes with the mixture phase.
            spec1 = cleaned1 * np.exp(angles * 1j)
            spec2 = cleaned2 * np.exp(angles * 1j)
            for i in range(0, FLAGS.batch_size):
                tffilename = tfrecords_lst[i + processed]
                (_, name) = os.path.split(tffilename)
                (partname, _) = os.path.splitext(name)
                wav_name1 = data_dir + '/' + partname + '_1.wav'
                wav_name2 = data_dir + '/' + partname + '_2.wav'
                # Trim the zero-padding back to each utterance's true length.
                wav1 = istft(spec1[i, 0:lengths[i], :], size=256, shift=128)
                wav2 = istft(spec2[i, 0:lengths[i], :], size=256, shift=128)
                audiowrite(wav1, wav_name1, 8000, True, True)
                audiowrite(wav2, wav_name2, 8000, True, True)
            processed = processed + FLAGS.batch_size

            if batch % 50 == 0:
                print(batch)

    except Exception as e:
        # Report exceptions to the coordinator so queue threads stop cleanly.
        coord.request_stop(e)
    finally:
        # Always stop and join the queue-runner threads and release the
        # session; the original leaked both.
        coord.request_stop()
        coord.join(threads)
        sess.close()
# 예제 #2 (Example #2)
# 0
def decode():
    """Decode the test set with the current model.

    Reads padded batches (one utterance at a time) from the 'tt'
    TFRecord list, restores the latest checkpoint from
    ``FLAGS.save_dir/nnet``, runs the LSTM separator, and writes the two
    separated feature streams per utterance as Kaldi ark/scp pairs into
    ``FLAGS.data_dir``.

    Exits the process with status -1 if no checkpoint is found.
    """
    tfrecords_lst, num_batches = read_list_file('tt', FLAGS.batch_size)

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                tt_mixed, tt_inputs, tt_labels1, tt_labels2, tt_lengths = get_padded_batch_v2(
                    tfrecords_lst,
                    1,  # decode a single utterance per step
                    FLAGS.input_size,
                    FLAGS.output_size,
                    num_enqueuing_threads=1,
                    num_epochs=1,
                    shuffle=False)

        # Build the model in inference mode on the decoding inputs.
        with tf.name_scope('model'):
            model = LSTM(FLAGS,
                         tt_inputs,
                         tt_mixed,
                         tt_labels1,
                         tt_labels2,
                         tt_lengths,
                         infer=True)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        sess = tf.Session()
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir + '/nnet')
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # Fixed typo in the original message ("fou1nd").
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    data_dir = FLAGS.data_dir
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    try:
        for batch in xrange(num_batches):
            if coord.should_stop():
                break
            if FLAGS.assign == 'def':
                cleaned1, cleaned2 = sess.run(
                    [model._cleaned1, model._cleaned2])
            else:
                x1, x2 = model.get_opt_output()
                cleaned1, cleaned2 = sess.run([x1, x2])

            # List entries carry two extensions (e.g. '<uttid>.<a>.<b>');
            # strip both to recover the utterance id / output basename.
            tffilename = tfrecords_lst[batch]
            (_, name) = os.path.split(tffilename)
            (uttid, _) = os.path.splitext(name)
            (partname, _) = os.path.splitext(uttid)
            kaldi_writer1 = kio.ArkWriter(data_dir + '/' + partname +
                                          '_1.wav.scp')
            kaldi_writer2 = kio.ArkWriter(data_dir + '/' + partname +
                                          '_2.wav.scp')
            # Batch size is 1, so index 0 selects the only utterance.
            kaldi_writer1.write_next_utt(
                data_dir + '/' + partname + '_1.wav.ark', uttid,
                cleaned1[0, :, :])
            kaldi_writer2.write_next_utt(
                data_dir + '/' + partname + '_2.wav.ark', uttid,
                cleaned2[0, :, :])
            kaldi_writer1.close()
            kaldi_writer2.close()
            if batch % 500 == 0:
                print(batch)

    except Exception as e:
        # Report exceptions to the coordinator so queue threads stop cleanly.
        coord.request_stop(e)
    finally:
        # Always stop and join the queue-runner threads and release the
        # session; the original leaked both.
        coord.request_stop()
        coord.join(threads)
        sess.close()