def testComputation(self):
    """Compare a ResidualCore-wrapped VanillaRNN with a NumPy reference.

    Builds the graph, evaluates it on random inputs, and checks that the
    wrapper's output equals ``tanh(x.W_in + b_in + h.W_hid + b_hid) + x``
    while the returned state equals the tanh term alone.
    """
    batch_dims = [self.batch_size, self.in_size]
    inputs = tf.placeholder(tf.float32, shape=batch_dims)
    prev_state = tf.placeholder(tf.float32, shape=batch_dims)

    rnn_core = snt.VanillaRNN(name="rnn", hidden_size=self.in_size)
    residual_core = snt.ResidualCore(rnn_core, name="residual")
    output_t, new_state_t = residual_core(inputs, prev_state)

    in_vars = rnn_core.in_to_hidden_variables
    hid_vars = rnn_core.hidden_to_hidden_variables

    with self.test_session() as sess:
        # With random data, check the TF calculation matches the Numpy
        # version.
        input_data = np.random.randn(self.batch_size, self.in_size)
        prev_state_data = np.random.randn(self.batch_size, self.in_size)
        tf.global_variables_initializer().run()

        # Fetch the outputs and the core's variables in a single run call so
        # the NumPy reference uses exactly the values the graph used.
        (output_v, new_state_v,
         in_to_hid_w, in_to_hid_b,
         hid_to_hid_w, hid_to_hid_b) = sess.run(
             [output_t, new_state_t,
              in_vars[0], in_vars[1],
              hid_vars[0], hid_vars[1]],
             feed_dict={inputs: input_data, prev_state: prev_state_data})

    # NumPy reference for the wrapped core, then the residual connection.
    in_term = input_data.dot(in_to_hid_w) + in_to_hid_b
    hid_term = prev_state_data.dot(hid_to_hid_w) + hid_to_hid_b
    expected_state = np.tanh(in_term + hid_term)
    expected_output = expected_state + input_data

    self.assertAllClose(expected_output, output_v)
    self.assertAllClose(expected_state, new_state_v)
def testShape(self):
    """Check the output size and the output/state shapes of ResidualCore.

    Wraps a VanillaRNN whose size argument is ``in_size`` and verifies that
    the wrapper reports ``in_size`` as its output size and that both returned
    tensors have shape ``(batch_size, in_size)``.
    """
    dims = (self.batch_size, self.in_size)
    inputs = tf.placeholder(tf.float32, shape=list(dims))
    prev_state = tf.placeholder(tf.float32, shape=list(dims))

    wrapper = snt.ResidualCore(snt.VanillaRNN(self.in_size), name="residual")
    output, next_state = wrapper(inputs, prev_state)

    # Only the shape of this array matters; its contents are never read.
    expected = np.empty(dims)

    self.assertEqual(self.in_size, wrapper.output_size)
    for tensor in (output, next_state):
        self.assertShapeEqual(expected, tensor)
def testHeterogeneousState(self):
    """Checks that the shape and type of the initial state are preserved.

    Compares the first two elements of the wrapped core's initial state with
    those of the ResidualCore built around it.
    """
    wrapped = HeterogeneousStateCore(name="rnn", hidden_size=self.in_size)
    residual = snt.ResidualCore(wrapped, name="residual")

    expected_state = wrapped.initial_state(self.batch_size)
    actual_state = residual.initial_state(self.batch_size)

    # Shapes first, then dtypes, for state elements 0 and 1.
    for idx in (0, 1):
        self.assertEqual(expected_state[idx].shape.as_list(),
                         actual_state[idx].shape.as_list())
    for idx in (0, 1):
        self.assertEqual(expected_state[idx].dtype, actual_state[idx].dtype)