Beispiel #1
0
 def testByteNet(self):
     """Build ByteNet in TRAIN mode on random ids and check the logits shape.

     Random integer ids in [1, vocab_size - 1] are used as both inputs and
     targets; only the shape of the evaluated logits is asserted, never their
     values.

     NOTE(review): another ``testByteNet`` is defined later at the same
     indentation. If both live in one class, the later definition replaces
     this one and this test never runs — consider giving one a distinct name.
     """
     vocab_size = 9
     batch = 3
     # randint's `high` is exclusive, so every id lands in [1, vocab_size - 1].
     inputs_np = np.random.randint(1, high=vocab_size, size=(batch, 5, 1, 1))
     targets_np = np.random.randint(1, high=vocab_size, size=(batch, 6, 1, 1))
     model_hparams = bytenet.bytenet_base()
     problem_hp = problem_hparams.test_problem_hparams(
         vocab_size, vocab_size, model_hparams)
     with self.test_session() as sess:
         feature_map = dict(
             inputs=tf.constant(inputs_np, dtype=tf.int32),
             targets=tf.constant(targets_np, dtype=tf.int32),
         )
         net = bytenet.ByteNet(
             model_hparams, tf.estimator.ModeKeys.TRAIN, problem_hp)
         logits, _ = net(feature_map)
         sess.run(tf.global_variables_initializer())
         logits_value = sess.run(logits)
     # The expected length (50) differs from both the input (5) and target (6)
     # lengths; presumably ByteNet pads the target sequence — TODO confirm.
     expected_shape = (batch, 50, 1, 1, vocab_size)
     self.assertEqual(logits_value.shape, expected_shape)
 def testByteNet(self):
     """Run ByteNet's ``model_fn`` on random ids and check the logits shape.

     Random integer ids in [1, vocab_size - 1] are used as both inputs and
     targets. The per-shard logits returned by ``model_fn`` are concatenated
     along axis 0 and only their evaluated shape is asserted.

     NOTE(review): an earlier ``testByteNet`` is defined at the same
     indentation. If both live in one class, this definition replaces the
     earlier one — consider giving one a distinct name.
     """
     vocab_size = 9
     # np.random.random_integers is deprecated (NumPy >= 1.11). randint's
     # `high` is exclusive, so randint(1, high=vocab_size) samples the same
     # inclusive range [1, vocab_size - 1] as random_integers(1, vocab_size - 1).
     x = np.random.randint(1, high=vocab_size, size=(3, 5, 1, 1))
     y = np.random.randint(1, high=vocab_size, size=(3, 6, 1, 1))
     hparams = bytenet.bytenet_base()
     # NOTE(review): model hparams are passed first here, unlike the other
     # testByteNet; presumably this matches an older problem_hparams API —
     # confirm against the installed version.
     p_hparams = problem_hparams.test_problem_hparams(
         hparams, vocab_size, vocab_size)
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
         }
         model = bytenet.ByteNet(hparams, p_hparams)
         # Second argument presumably selects training mode — TODO confirm.
         sharded_logits, _, _ = model.model_fn(features, True)
         # Join per-shard logits along axis 0 (presumably the batch axis).
         logits = tf.concat(sharded_logits, 0)
         session.run(tf.global_variables_initializer())
         res = session.run(logits)
     # Expected length 50 differs from the input/target lengths (5/6);
     # presumably ByteNet pads the target sequence — TODO confirm.
     self.assertEqual(res.shape, (3, 50, 1, 1, vocab_size))