import time

import numpy as np
import tensorflow as tf

# NOTE: the import paths below follow the tensor2tensor 1.x layout and are
# assumed from the symbols used in this file.
from tensor2tensor.utils import decoding
from tensor2tensor.utils import metrics
from tensor2tensor.utils import trainer_utils as utils
from tensor2tensor.utils import usr_dir

# Flag definitions (srcFile, firstPFile, tgtFile, scoreFile, eval_batch, ...)
# are assumed to be declared elsewhere in the original script.
FLAGS = tf.flags.FLAGS


class MetricsTest(tf.test.TestCase):

  def testNegativeLogPerplexity(self):
    predictions = np.random.randint(4, size=(12, 12, 12, 1))
    targets = np.random.randint(4, size=(12, 12, 12, 1))
    with self.test_session() as session:
      scores, _ = metrics.padded_neg_log_perplexity(
          tf.one_hot(predictions, depth=4, dtype=tf.float32),
          tf.constant(targets, dtype=tf.int32))
      a = tf.reduce_mean(scores)
      session.run(tf.global_variables_initializer())
      actual = session.run(a)
    self.assertEqual(actual.shape, ())
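# For intuition, the metric exercised above can be cross-checked against a
# plain numpy computation: negative log-perplexity is the padding-masked mean
# of the log-probabilities the model assigns to the target tokens. A minimal
# sketch (illustrative only, not tensor2tensor's implementation), assuming
# the tensor2tensor convention that target id 0 marks padding:
def np_neg_log_perplexity(logits, targets, pad_id=0):
  """Reference negative log-perplexity in numpy.

  Args:
    logits: float array of shape [..., vocab_size].
    targets: int array of shape [...], token ids.
    pad_id: token id treated as padding and excluded from the average.
  """
  # Numerically stable log-softmax over the vocabulary axis.
  logits = logits - logits.max(axis=-1, keepdims=True)
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
  # Pick out the log-probability of each target token.
  token_lp = np.take_along_axis(
      log_probs, targets[..., None], axis=-1).squeeze(-1)
  # Mask padding positions and average over the rest.
  weights = (targets != pad_id).astype(np.float32)
  return (token_lp * weights).sum() / weights.sum()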
def main(_):
  # Set the logging level.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Import module at usr_dir, if provided.
  if FLAGS.t2t_usr_dir is not None:
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Get inputs (list formatted) from file.
  assert FLAGS.srcFile is not None
  assert FLAGS.firstPFile is not None
  assert FLAGS.tgtFile is not None
  [sorted_inputs, sorted_firstP, sorted_targets], sorted_keys = \
      get_sorted_inputs(FLAGS.srcFile, FLAGS.firstPFile, FLAGS.tgtFile)
  num_decode_batches = (len(sorted_inputs) - 1) // FLAGS.eval_batch + 1
  assert len(sorted_inputs) == len(sorted_firstP) == len(sorted_targets)
  tf.logging.info("Writing scores into %s" % FLAGS.scoreFile)
  outfile = tf.gfile.Open(FLAGS.scoreFile, "w")

  # Generate hyper-parameters.
  hparams = utils.create_hparams(
      FLAGS.hparams_set, FLAGS.data_dir, passed_hparams=FLAGS.hparams)
  utils.add_problem_hparams(hparams, FLAGS.problems)

  # Create the input function.
  num_datashards = utils.devices.data_parallelism().n
  mode = tf.estimator.ModeKeys.EVAL
  input_fn = utils.input_fn_builder.build_input_fn(
      mode,
      hparams,
      data_dir=FLAGS.data_dir,
      num_datashards=num_datashards,
      worker_replicas=FLAGS.worker_replicas,
      worker_id=FLAGS.worker_id,
      batch_size=FLAGS.eval_batch)

  # Get wrappers for feeding data into the model.
  inputs, target = input_fn()
  features = inputs
  features["targets"] = target
  inputs_vocab = hparams.problems[0].vocabulary["inputs"]
  targets_vocab = hparams.problems[0].vocabulary["targets"]
  feed_iters = input_iter(0, num_decode_batches, sorted_inputs, sorted_firstP,
                          sorted_targets, inputs_vocab, targets_vocab)

  # Build the model and the per-sequence score.
  model_fn = utils.model_builder.build_model_fn(
      FLAGS.model,
      problem_names=[FLAGS.problems],
      train_steps=FLAGS.train_steps,
      worker_id=FLAGS.worker_id,
      worker_replicas=FLAGS.worker_replicas,
      eval_run_autoregressive=FLAGS.eval_run_autoregressive,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))
  est_spec = model_fn(features, target, mode, hparams)
  score, _ = metrics.padded_neg_log_perplexity(
      est_spec.predictions["predictions"], target)
  # Sum per-token scores over the non-batch dims: one score per sequence.
  score = tf.reduce_sum(score, axis=[1, 2, 3])

  # Create the session.
  sv = tf.train.Supervisor(
      logdir=FLAGS.output_dir,
      global_step=tf.Variable(
          0, dtype=tf.int64, trainable=False, name="global_step"))
  sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True))
  sv.StartQueueRunners(
      sess, tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))

  sumt = 0
  scores_list = []
  # Loop for batched scoring.
  for i, features in enumerate(feed_iters):
    t = time.time()
    inputs_ = features["inputs"]
    firstP_ = features["firstP"]
    targets_ = features["targets"]
    # The model expects 4-D tensors; add trailing singleton dims as needed.
    while inputs_.ndim < 4:
      inputs_ = np.expand_dims(inputs_, axis=-1)
    while firstP_.ndim < 4:
      firstP_ = np.expand_dims(firstP_, axis=-1)
    while targets_.ndim < 4:
      targets_ = np.expand_dims(targets_, axis=-1)
    scores = sess.run(score, feed_dict={inputs["inputs"]: inputs_,
                                        inputs["firstP"]: firstP_,
                                        target: targets_})
    scores_list.extend(scores.tolist())
    # Report elapsed time and an estimate of the time remaining.
    dt = time.time() - t
    sumt += dt
    avgt = sumt / (i + 1)
    needt = (num_decode_batches - (i + 1)) * avgt
    print("Batch %d/%d worktime=(%s), lefttime=(%s)" %
          (i + 1, num_decode_batches,
           time.strftime("%H:%M:%S", time.gmtime(sumt)),
           time.strftime("%H:%M:%S", time.gmtime(needt))))
  scores_list.reverse()

  # Write to file in the original (pre-sort) order.
  for index in range(len(sorted_inputs)):
    outfile.write("%.8f\n" % scores_list[sorted_keys[index]])
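# A hypothetical invocation of this scorer (the script name and flag values
# below are placeholders; only the flag names referenced in main() are
# assumed to be defined by this script or by tensor2tensor):
#
#   python score.py \
#     --data_dir=$DATA_DIR \
#     --output_dir=$TRAIN_DIR \
#     --model=transformer \
#     --hparams_set=transformer_base \
#     --problems=translate_ende_wmt32k \
#     --srcFile=src.txt --firstPFile=firstP.txt --tgtFile=tgt.txt \
#     --scoreFile=scores.txt --eval_batch=32

# Standard tf.app entry point implied by the main(_) signature above.
if __name__ == "__main__":
  tf.app.run()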