def get_logpriors():
    """Return log-prior counts used to convert posteriors to likelihoods.

    When FLAGS.post2like is set, reads the occupancy counts from the Kaldi
    data directory; otherwise returns 0 so that subtracting the priors
    downstream is a no-op.
    """
    if not FLAGS.post2like:
        return 0
    counts_path = os.path.join(FLAGS.data_dir, FLAGS.occupances)
    with kaldi_data(counts_path) as kd:
        return kd.read_counts()
def produce_likelihoods():
    """Run the model over all utterances and write prior-scaled scores.

    Restores the TF checkpoint from FLAGS.data_dir/FLAGS.tf_ckpt_path, reads
    feature batches from FLAGS.features_rspec, subtracts the log-priors from
    the logits (posterior -> scaled log-likelihood when FLAGS.post2like is
    set), and writes each result to FLAGS.prob_wspec.
    """
    tf_model = import_tf_model()
    logpriors = get_logpriors()
    with tf.Graph().as_default():
        # Both dims unknown: (num_frames, feature_dim) varies per utterance.
        features = tf.placeholder(tf.float32, shape=(None, None))
        logits = tf_model.inference(features)
        ckpt_path = os.path.join(FLAGS.data_dir, FLAGS.tf_ckpt_path)
        saver = tf.train.Saver()
        # Context manager guarantees the session is closed even if reading,
        # inference, or writing raises (the original leaked it on error).
        with tf.Session() as sess:
            saver.restore(sess, ckpt_path)
            with kaldi_data(FLAGS.features_rspec) as kd_reader:
                with kaldi_data(FLAGS.prob_wspec, 'w') as kd_writer:
                    for d in kd_reader.read_utterance(FLAGS.batch_size):
                        utterance_id = d[0]
                        batch = tf_model.process_data(d[1], FLAGS.transpose_input)
                        # Run the tensor directly; no [logits] list + r[0].
                        scores = sess.run(logits, feed_dict={features: batch})
                        kd_writer.write_batches([utterance_id, scores - logpriors])
def produce_likelihoods():
    """Debug variant: cross-check TFRecord pipeline output against a Kaldi ark.

    NOTE(review): this REDEFINES produce_likelihoods() from earlier in the
    file and silently shadows it. Rename one of the two (e.g. this one to
    debug_produce_likelihoods) once callers are confirmed.
    """
    with kaldi_helpers.kaldi_data('./t.ark') as kd:
        batch1 = kd.read_utterance(-1)
        # next(it), not it.next(): the latter is Python-2 only and raises
        # AttributeError on Python 3.
        u1 = next(batch1)
        u2 = next(batch1)  # advances past the second utterance (value unused)
        print(np.shape(u1[1]))
    with kaldi_helpers.kaldi_data(FLAGS.occupances) as kd:
        logpriors = kd.read_counts()  # fixed local typo: was 'logprioirs'
    with tf.Graph().as_default():
        val_images, val_labels = eval_inputs()
        images = tf.placeholder(tf.float32, shape=(None, 1320))
        labels = tf.placeholder(tf.int32, shape=(None))
        logits = tf_model.inference(images, 2048, 2048, 2048)
        loss = tf_model.loss(logits, labels)
        saver = tf.train.Saver()
        sess = tf.Session()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver.restore(sess, '../../../data/tf_fbank_deltas_nocmvn/cnnmodel-31530')
            vi, vl = sess.run([val_images, val_labels])
            r = sess.run([logits], feed_dict={images: vi})
            l = r[0] - logpriors
            with kaldi_helpers.kaldi_data('./t_like_u1.ark', 'w') as kd:
                kd.write_utterance([[u1[0], l]])
            # Sanity check: loss on the queue-fed batch vs. the ark batch.
            r = sess.run([loss], feed_dict={images: vi, labels: vl})
            print(r)
            r = sess.run([loss], feed_dict={images: u1[1], labels: vl})
            print(r)
        finally:
            # Always stop and join the queue-runner threads and close the
            # session, even when a run above raises (the original skipped
            # cleanup on error, leaking threads and the session).
            coord.request_stop()
            coord.join(threads)
            sess.close()