Esempio n. 1
0
  def testComputation(self):
    """Checks one ResidualCore step against a NumPy re-computation.

    A VanillaRNN is wrapped in a ResidualCore and run once on random data.
    The expected values asserted below are:
      * new state  == tanh(x @ W_ih + b_ih + h @ W_hh + b_hh)
      * output     == new state + x   (the residual connection)
    """
    batch_shape = [self.batch_size, self.in_size]
    inputs = tf.placeholder(tf.float32, shape=batch_shape)
    prev_state = tf.placeholder(tf.float32, shape=batch_shape)

    core = snt.VanillaRNN(name="rnn", hidden_size=self.in_size)
    residual = snt.ResidualCore(core, name="residual")

    output, new_state = residual(inputs, prev_state)
    # Read the variables only after the core has been connected above.
    ih_vars = core.in_to_hidden_variables
    hh_vars = core.hidden_to_hidden_variables

    with self.test_session() as sess:
      # Random inputs: the TF graph must agree with the NumPy version below.
      input_data = np.random.randn(self.batch_size, self.in_size)
      prev_state_data = np.random.randn(self.batch_size, self.in_size)
      tf.global_variables_initializer().run()

      feed_dict = {inputs: input_data, prev_state: prev_state_data}
      (output_v, new_state_v,
       ih_w, ih_b,
       hh_w, hh_b) = sess.run([output, new_state,
                               ih_vars[0], ih_vars[1],
                               hh_vars[0], hh_vars[1]],
                              feed_dict)

    # Keep the (input term) + (recurrent term) grouping before tanh.
    expected_state = np.tanh(
        (np.dot(input_data, ih_w) + ih_b) +
        (np.dot(prev_state_data, hh_w) + hh_b))
    expected_output = expected_state + input_data

    self.assertAllClose(expected_output, output_v)
    self.assertAllClose(expected_state, new_state_v)
Esempio n. 2
0
  def testShape(self):
    """ResidualCore keeps the [batch_size, in_size] shape of its wrapped core."""
    batch_shape = [self.batch_size, self.in_size]
    inputs = tf.placeholder(tf.float32, shape=batch_shape)
    prev_state = tf.placeholder(tf.float32, shape=batch_shape)

    core = snt.VanillaRNN(self.in_size)
    residual_wrapper = snt.ResidualCore(core, name="residual")
    output, next_state = residual_wrapper(inputs, prev_state)

    # Uninitialized contents; this array serves only as a shape reference.
    reference = np.empty((self.batch_size, self.in_size))

    self.assertEqual(self.in_size, residual_wrapper.output_size)
    for tensor in (output, next_state):
      self.assertShapeEqual(reference, tensor)
Esempio n. 3
0
    def testHeterogeneousState(self):
        """Checks that the shape and type of the initial state are preserved.

        The ResidualCore's initial state is compared with the wrapped core's
        initial state, entry by entry, for state components 0 and 1.
        """
        core = HeterogeneousStateCore(name="rnn", hidden_size=self.in_size)
        residual = snt.ResidualCore(core, name="residual")

        expected = core.initial_state(self.batch_size)
        actual = residual.initial_state(self.batch_size)

        # Compare every shape first, then every dtype.
        for idx in (0, 1):
            self.assertEqual(expected[idx].shape.as_list(),
                             actual[idx].shape.as_list())
        for idx in (0, 1):
            self.assertEqual(expected[idx].dtype, actual[idx].dtype)