Esempio n. 1
0
 def testBlueNet(self):
   """Checks the shape of the logits BlueNet produces in TRAIN mode.

   Builds the model from `bluenet.bluenet_tiny()` hparams, feeds it random
   integer inputs of shape (3, 5, 1, 1) and targets of shape (3, 1, 1, 1),
   and asserts that the evaluated logits have shape
   (3, 5, 1, 1, vocab_size).
   """
   vocab_size = 9
   # np.random.random_integers is deprecated; randint has an exclusive upper
   # bound, so (1, vocab_size) draws from the same inclusive range
   # [1, vocab_size - 1] as before.
   x = np.random.randint(1, vocab_size, size=(3, 5, 1, 1))
   y = np.random.randint(1, vocab_size, size=(3, 1, 1, 1))
   hparams = bluenet.bluenet_tiny()
   p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size)
   with self.test_session() as session:
     tf.train.get_or_create_global_step()
     features = {
         "inputs": tf.constant(x, dtype=tf.int32),
         "targets": tf.constant(y, dtype=tf.int32),
     }
     model = bluenet.BlueNet(
         hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
     # Only the first element of the returned pair (the logits) is checked.
     logits, _ = model(features)
     # Variables must be initialized before the logits can be evaluated.
     session.run(tf.global_variables_initializer())
     res = session.run(logits)
   self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))
Esempio n. 2
0
 def testBlueNet(self):
   """Checks the shape of the logits BlueNet produces in TRAIN mode.

   Builds the model from `bluenet.bluenet_tiny()` hparams, runs
   `model.model_fn` on random integer inputs of shape (3, 5, 1, 1) and
   targets of shape (3, 1, 1, 1), concatenates the returned logits along
   axis 0, and asserts the result has shape (3, 5, 1, 1, vocab_size).
   """
   vocab_size = 9
   # np.random.random_integers is deprecated; randint has an exclusive upper
   # bound, so (1, vocab_size) draws from the same inclusive range
   # [1, vocab_size - 1] as before.
   x = np.random.randint(1, vocab_size, size=(3, 5, 1, 1))
   y = np.random.randint(1, vocab_size, size=(3, 1, 1, 1))
   hparams = bluenet.bluenet_tiny()
   p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
                                                    vocab_size)
   with self.test_session() as session:
     tf.train.get_or_create_global_step()
     features = {
         "inputs": tf.constant(x, dtype=tf.int32),
         "targets": tf.constant(y, dtype=tf.int32),
     }
     model = bluenet.BlueNet(
         hparams, tf.contrib.learn.ModeKeys.TRAIN, p_hparams)
     # Only the first of the three returned values (the logits) is checked.
     sharded_logits, _, _ = model.model_fn(features)
     # Join the returned pieces along axis 0 so the result can be compared
     # against a single full-batch shape.
     logits = tf.concat(sharded_logits, 0)
     # Variables must be initialized before the logits can be evaluated.
     session.run(tf.global_variables_initializer())
     res = session.run(logits)
   self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))