Пример #1
0
    def testComputation(self):
        """Checks that ModelRNN's output equals prev_state and next_state.

        self.model is configured outside this view; the assertions below
        imply it passes the previous state through unchanged and ignores
        the (random) inputs.
        """
        rnn_core = snt.ModelRNN(self.model)
        random_inputs = tf.random_normal([self.batch_size, 5])
        state_array = np.random.randn(self.batch_size, self.hidden_size)
        state_tensor = tf.convert_to_tensor(state_array)

        output_tensor, state_out_tensor = rnn_core(random_inputs, state_tensor)

        # Evaluate both results in one call and unpack them immediately.
        output_array, state_out_array = self.evaluate(
            [output_tensor, state_out_tensor])

        self.assertAllClose(state_array, output_array)
        self.assertAllClose(output_array, state_out_array)
Пример #2
0
    def testShape(self):
        """Checks the static shapes produced by ModelRNN.

        The expected shape is [batch_size] followed by self.model.output_size.
        prev_state, next_state and outputs must all have it, while the inputs
        tensor must not.
        """
        rnn_core = snt.ModelRNN(self.model)
        random_inputs = tf.random_normal([self.batch_size, 5])
        state_placeholder = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.hidden_size])

        output_tensor, state_out_tensor = rnn_core(
            random_inputs, state_placeholder)
        expected_shape = tf.TensorShape(self.batch_size).concatenate(
            self.model.output_size)

        # Inputs deliberately have a different shape, so matching shapes
        # below presumably cannot be inherited from the inputs.
        self.assertNotEqual(expected_shape, random_inputs.get_shape())
        for tensor in (state_placeholder, state_out_tensor, output_tensor):
            self.assertEqual(expected_shape, tensor.get_shape())
Пример #3
0
  def testComputation(self):
    """Runs ModelRNN in a session; output must equal prev and next state."""
    rnn_core = snt.ModelRNN(self.model)
    random_inputs = tf.random_normal([self.batch_size, 5])
    state_placeholder = tf.placeholder(
        tf.float32, shape=[self.batch_size, self.hidden_size])

    output_tensor, state_out_tensor = rnn_core(random_inputs, state_placeholder)

    with self.test_session() as session:
      state_array = np.random.randn(self.batch_size, self.hidden_size)
      session.run(tf.global_variables_initializer())
      # Feed the random previous state and fetch both results at once.
      output_array, state_out_array = session.run(
          [output_tensor, state_out_tensor],
          feed_dict={state_placeholder: state_array})

    self.assertAllClose(state_array, output_array)
    self.assertAllClose(output_array, state_out_array)
Пример #4
0
 def testBadArguments(self):
     """Checks that ModelRNN rejects objects it cannot wrap as a model.

     Passing a plain function (tf.identity) raises AttributeError; passing
     a numpy array raises TypeError.
     """
     self.assertRaises(AttributeError, snt.ModelRNN, tf.identity)
     self.assertRaises(TypeError, snt.ModelRNN, np.array([42]))