Example #1
def testNegativeLogPerplexity(self):
  # Random integer class labels for a 4-way vocabulary.
  predictions = np.random.randint(4, size=(12, 12, 12, 1))
  targets = np.random.randint(4, size=(12, 12, 12, 1))
  with self.test_session() as session:
    # Score one-hot "logits" against the integer targets; the metric
    # returns per-position scores and a padding-weight tensor.
    scores, _ = metrics.padded_neg_log_perplexity(
        tf.one_hot(predictions, depth=4, dtype=tf.float32),
        tf.constant(targets, dtype=tf.int32))
    a = tf.reduce_mean(scores)
    session.run(tf.global_variables_initializer())
    actual = session.run(a)
  # Reducing with tf.reduce_mean must yield a scalar.
  self.assertEqual(actual.shape, ())
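
For intuition, padded_neg_log_perplexity is the sign-flipped cross-entropy: it returns the average log-probability of the targets under the softmax of the logits, with padding positions masked out. A minimal NumPy sketch of the unmasked core computation (an assumption: the real metric also applies a padding weight, and neg_log_perplexity here is a hypothetical helper, not part of the metrics module):

import numpy as np

def neg_log_perplexity(logits, targets):
    # logits: (batch, vocab) floats; targets: (batch,) integer class ids.
    log_z = np.log(np.sum(np.exp(logits), axis=-1))  # log partition function
    picked = np.take_along_axis(logits, targets[:, None], axis=-1)[:, 0]
    # Mean log-probability of the true classes, i.e. the negative of
    # log perplexity.
    return np.mean(picked - log_z)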
Example #3
def main(_):
    # Set the logging level.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Import module at usr_dir, if provided.
    if FLAGS.t2t_usr_dir is not None:
        usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
        
    # Read the parallel input files into lists.
    assert FLAGS.srcFile is not None
    assert FLAGS.firstPFile is not None
    assert FLAGS.tgtFile is not None
    
    [sorted_inputs, sorted_firstP, sorted_targets], sorted_keys = \
            get_sorted_inputs(FLAGS.srcFile, FLAGS.firstPFile, FLAGS.tgtFile)
            
    num_decode_batches = (len(sorted_inputs) - 1) // FLAGS.eval_batch + 1

    assert len(sorted_inputs) == len(sorted_firstP) == len(sorted_targets)
    
    tf.logging.info("Writing decodes into %s" % FLAGS.scoreFile)
    outfile = tf.gfile.Open(FLAGS.scoreFile, "w")
    
    # Generate hyper-parameters.
    hparams = utils.create_hparams(FLAGS.hparams_set, FLAGS.data_dir,
                                   passed_hparams=FLAGS.hparams)
    utils.add_problem_hparams(hparams, FLAGS.problems)
    # Create input function.
    num_datashards = utils.devices.data_parallelism().n
    mode = tf.estimator.ModeKeys.EVAL
    input_fn = utils.input_fn_builder.build_input_fn(mode, hparams,
                                                     data_dir=FLAGS.data_dir,
                                                     num_datashards=num_datashards,
                                                     worker_replicas=FLAGS.worker_replicas,
                                                     worker_id=FLAGS.worker_id,
                                                     batch_size=FLAGS.eval_batch)

    # Get wrappers for feeding data into the model.
    inputs, target = input_fn()
    features = inputs
    features['targets'] = target
    inputs_vocab = hparams.problems[0].vocabulary["inputs"]
    targets_vocab = hparams.problems[0].vocabulary["targets"]
    feed_iters = input_iter(0, num_decode_batches,
                            sorted_inputs, sorted_firstP, sorted_targets,
                            inputs_vocab, targets_vocab)

    model_fn = utils.model_builder.build_model_fn(
        FLAGS.model,
        problem_names=[FLAGS.problems],
        train_steps=FLAGS.train_steps,
        worker_id=FLAGS.worker_id,
        worker_replicas=FLAGS.worker_replicas,
        eval_run_autoregressive=FLAGS.eval_run_autoregressive,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))
    est_spec = model_fn(features, target, mode, hparams)
    # Per-position log-probabilities, summed over the length/space/depth
    # axes to yield one score per sequence in the batch.
    score, _ = metrics.padded_neg_log_perplexity(
        est_spec.predictions['predictions'], target)
    score = tf.reduce_sum(score, axis=[1, 2, 3])
    
    # Create session.
    sv = tf.train.Supervisor(
        logdir=FLAGS.output_dir,
        global_step=tf.Variable(0, dtype=tf.int64, trainable=False,
                                name='global_step'))
    sess = sv.PrepareSession(config=tf.ConfigProto(allow_soft_placement=True))
    sv.StartQueueRunners(
        sess, tf.get_default_graph().get_collection(tf.GraphKeys.QUEUE_RUNNERS))

    sumt = 0
    scores_list = []
    
    # Loop for batched translation.
    for i, features in enumerate(feed_iters):
        t = time.time()
        inputs_ = features["inputs"]
        firstP_ = features["firstP"]
        targets_ = features["targets"]
        
        # Pad the feature arrays out to rank 4; the model expects
        # [batch, length, height, channels]-shaped tensors.
        while inputs_.ndim < 4:
            inputs_ = np.expand_dims(inputs_, axis=-1)
        while firstP_.ndim < 4:
            firstP_ = np.expand_dims(firstP_, axis=-1)
        while targets_.ndim < 4:
            targets_ = np.expand_dims(targets_, axis=-1)
        scores = sess.run(score, feed_dict={inputs['inputs']: inputs_,
                                            inputs["firstP"]: firstP_,
                                            target: targets_})
        scores_list.extend(scores.tolist())
        dt = time.time() - t
        sumt += dt
        avgt = sumt / (i+1)
        needt = (num_decode_batches - (i + 1)) * avgt
        
        print("Batch %d/%d worktime=(%s), lefttime=(%s)" % (i+1, num_decode_batches, time.strftime('%H:%M:%S',time.gmtime(sumt)),time.strftime('%H:%M:%S',time.gmtime(needt))))
     
    # Reverse so that indexing by sorted_keys recovers the original order.
    scores_list.reverse()

    # Write to file with the original order.
    for index in range(len(sorted_inputs)):
        outfile.write("%.8f\n" % (scores_list[sorted_keys[index]]))
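
The helpers get_sorted_inputs and input_iter are referenced above but not shown. As a hypothetical sketch (the signature and exact sort order are assumptions, modeled on tensor2tensor's _get_sorted_inputs), get_sorted_inputs plausibly reads the three parallel files, orders them jointly by source length so each batch holds similar-length sequences, and returns sorted_keys so the scores can be written back in the original line order:

import operator
import tensorflow as tf

def get_sorted_inputs(src_file, firstP_file, tgt_file):
    # Hypothetical sketch; the real helper in this script may differ.
    def read_lines(path):
        with tf.gfile.Open(path) as f:
            return [line.strip() for line in f]

    inputs = read_lines(src_file)
    firstP = read_lines(firstP_file)
    targets = read_lines(tgt_file)

    # Jointly order the three lists by source length.
    order = sorted(enumerate(len(s.split()) for s in inputs),
                   key=operator.itemgetter(1))
    sorted_keys = {}
    sorted_inputs, sorted_firstP, sorted_targets = [], [], []
    for new_pos, (orig_pos, _) in enumerate(order):
        sorted_inputs.append(inputs[orig_pos])
        sorted_firstP.append(firstP[orig_pos])
        sorted_targets.append(targets[orig_pos])
        # Remember where each original line landed in the sorted order.
        sorted_keys[orig_pos] = new_pos
    return [sorted_inputs, sorted_firstP, sorted_targets], sorted_keys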