コード例 #1
0
 def testNeuralGPU(self):
     """Verify a TRAIN-mode NeuralGPU emits logits with a vocab axis.

     Random token ids for a tiny synthetic problem are pushed through a
     freshly constructed NeuralGPU, and the evaluated logits must have shape
     (batch, target_length, 1, 1, target_vocab_size).
     """
     hparams = common_hparams.basic_params1()
     batch, in_len = 3, 5
     out_len = in_len  # targets are as long as the inputs
     in_vocab, out_vocab = 9, 11
     problem = problem_hparams.test_problem_hparams(in_vocab, out_vocab,
                                                    hparams)
     # Inputs are drawn before targets, keeping the NumPy RNG consumption order.
     input_ids = np.random.randint(in_vocab, size=(batch, in_len, 1, 1))
     target_ids = np.random.randint(out_vocab, size=(batch, out_len, 1, 1))
     expected_shape = (batch, out_len, 1, 1, out_vocab)

     with self.test_session() as sess:
         feats = {
             "inputs": tf.constant(input_ids, dtype=tf.int32),
             "targets": tf.constant(target_ids, dtype=tf.int32),
         }
         net = neural_gpu.NeuralGPU(hparams,
                                    tf.estimator.ModeKeys.TRAIN,
                                    problem)
         # Calling the model yields a pair; only the logits are inspected.
         out_logits, _unused = net(feats)
         sess.run(tf.global_variables_initializer())
         evaluated = sess.run(out_logits)
     self.assertEqual(evaluated.shape, expected_shape)
コード例 #2
0
 def testNeuralGPU(self):
    """Check that NeuralGPU's sharded logits concatenate to the target shape.

    Feeds random token ids for a small synthetic problem through
    ``NeuralGPU.model_fn``, concatenates the returned logit shards along
    axis 0 and asserts the evaluated tensor has shape
    ``(batch_size, target_length, 1, 1, target_vocab_size)``.
    """
    hparams = common_hparams.basic_params1()
    batch_size = 3
    input_length = 5
    target_length = input_length
    input_vocab_size = 9
    target_vocab_size = 11
    p_hparams = problem_hparams.test_problem_hparams(hparams, input_vocab_size,
                                                     target_vocab_size)
    # Token ids in [0, vocab_size). randint's upper bound is exclusive, so this
    # gives the same range as the old `-1 + np.random.random_integers(n, ...)`
    # (inclusive [1, n] shifted down by one) without relying on an API that
    # NumPy has deprecated since 1.11.
    inputs = np.random.randint(
        input_vocab_size, size=(batch_size, input_length, 1, 1))
    targets = np.random.randint(
        target_vocab_size, size=(batch_size, target_length, 1, 1))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32)
      }
      model = neural_gpu.NeuralGPU(hparams, p_hparams)
      # model_fn yields three values; only the first (the logit shards) is
      # used here. NOTE(review): the positional `True` is presumably a
      # training-mode flag -- confirm against NeuralGPU.model_fn.
      sharded_logits, _, _ = model.model_fn(features, True)
      # Concatenate the shards along axis 0; the assertion below expects
      # batch_size entries on that axis.
      logits = tf.concat(sharded_logits, 0)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, target_length, 1, 1,
                                 target_vocab_size))