def testResidualWrapperWithSlice(self):
  """ResidualWrapper with a custom residual_fn that slices the input.

  The input is 5 columns wide, while GRUCell(3) produces 3 columns.
  The residual function therefore keeps only the leading 3 input columns
  before adding them to the cell output. The returned state must equal
  the unwrapped cell's state.
  """

  def add_leading_columns(inp, out):
    # Keep every row (-1) but only the first 3 columns, matching GRUCell(3).
    leading = array_ops.slice(inp, [0, 0], [-1, 3])
    return leading + out

  with self.test_session() as sess:
    half_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope("root", initializer=half_init):
      inputs = array_ops.zeros([1, 5])
      state = array_ops.zeros([1, 3])
      gru = rnn_cell_impl.GRUCell(3)
      plain_out, plain_state = gru(inputs, state)
      # The wrapped call below must reuse the GRU variables created above.
      variable_scope.get_variable_scope().reuse_variables()
      wrapped_cell = rnn_cell_impl.ResidualWrapper(gru, add_leading_columns)
      wrapped_out, wrapped_state = wrapped_cell(inputs, state)
      sess.run([variables_lib.global_variables_initializer()])
      feeds = {
          inputs: np.array([[1., 1., 1., 1., 1.]]),
          state: np.array([[0.1, 0.1, 0.1]]),
      }
      fetches = [plain_out, wrapped_out, plain_state, wrapped_state]
      (got_plain_out, got_wrapped_out,
       got_plain_state, got_wrapped_state) = sess.run(fetches, feeds)
      # Residual connections: the sliced input is all ones.
      self.assertAllClose(got_wrapped_out, got_plain_out + [1., 1., 1.])
      # States are left untouched
      self.assertAllClose(got_plain_state, got_wrapped_state)
def testResidualWrapper(self):
  """ResidualWrapper tracks its inner cell and adds inputs to outputs.

  The test checks three things:
  - The wrapper exposes exactly one checkpoint dependency, named "cell",
    which is the wrapped GRU cell.
  - get_config() can be called without raising.
  - Calling the wrapper yields cell output + input, with the state unchanged.
  """
  with self.test_session() as sess:
    half_init = init_ops.constant_initializer(0.5)
    with variable_scope.variable_scope("root", initializer=half_init):
      inputs = array_ops.zeros([1, 3])
      state = array_ops.zeros([1, 3])
      gru = rnn_cell_impl.GRUCell(3)
      plain_out, plain_state = gru(inputs, state)
      # The wrapped call below must reuse the GRU variables created above.
      variable_scope.get_variable_scope().reuse_variables()
      residual_cell = rnn_cell_impl.ResidualWrapper(gru)
      # Single-element pattern: raises if there is not exactly one dependency.
      [(dep_name, dep_obj)] = residual_cell._checkpoint_dependencies
      residual_cell.get_config()  # Should not throw an error
      self.assertIs(dep_obj, gru)
      self.assertEqual("cell", dep_name)
      wrapped_out, wrapped_state = residual_cell(inputs, state)
      sess.run([variables_lib.global_variables_initializer()])
      feeds = {
          inputs: np.array([[1., 1., 1.]]),
          state: np.array([[0.1, 0.1, 0.1]]),
      }
      got_plain_out, got_wrapped_out, got_plain_state, got_wrapped_state = (
          sess.run([plain_out, wrapped_out, plain_state, wrapped_state],
                   feeds))
      # Residual connections
      self.assertAllClose(got_wrapped_out, got_plain_out + [1., 1., 1.])
      # States are left untouched
      self.assertAllClose(got_plain_state, got_wrapped_state)
# Named distinctly from testResidualWrapper. Re-using that name here would
# rebind it in the class body and silently drop the checkpoint/get_config
# variant of the test.
def testResidualWrapperBasicCall(self):
  """Calling ResidualWrapper(cell) adds the input to the cell output.

  A GRUCell(3) is run directly, then again through ResidualWrapper using
  the same (reused) variables. With an all-ones input, the wrapped output
  must equal the plain output plus one. The state returned by the wrapper
  must equal the plain cell's state.
  """
  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.zeros([1, 3])
      m = array_ops.zeros([1, 3])
      base_cell = rnn_cell_impl.GRUCell(3)
      g, m_new = base_cell(x, m)
      # The wrapped call below must reuse the GRU variables created above.
      variable_scope.get_variable_scope().reuse_variables()
      g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
      sess.run([variables_lib.global_variables_initializer()])
      res = sess.run([g, g_res, m_new, m_new_res], {
          x: np.array([[1., 1., 1.]]),
          m: np.array([[0.1, 0.1, 0.1]])
      })
      # Residual connections
      self.assertAllClose(res[1], res[0] + [1., 1., 1.])
      # States are left untouched
      self.assertAllClose(res[2], res[3])