コード例 #1
0
    def testResidualWrapperWithSlice(self):
        """ResidualWrapper applies a custom residual function to (input, output).

        The input is 5 columns wide while the 3-unit GRU output is 3 wide, so
        the custom residual function slices the first 3 input columns before
        adding them to the cell output. The cell state must pass through the
        wrapper unchanged.
        """
        with self.test_session() as sess:
            with variable_scope.variable_scope(
                    "root", initializer=init_ops.constant_initializer(0.5)):
                x = array_ops.zeros([1, 5])
                m = array_ops.zeros([1, 3])
                base_cell = rnn_cell_impl.GRUCell(3)
                g, m_new = base_cell(x, m)
                variable_scope.get_variable_scope().reuse_variables()

                def residual_with_slice_fn(inp, out):
                    # Keep only the first 3 input columns so the shape matches
                    # the 3-unit GRU output.
                    inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
                    return inp_sliced + out

                g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
                    base_cell, residual_with_slice_fn)(x, m)
                sess.run([variables_lib.global_variables_initializer()])
                # Use distinct per-column input values: with an all-ones input
                # every 3-column slice is identical, so slicing the wrong
                # columns would go unnoticed.
                x_value = np.array([[1., 2., 3., 4., 5.]])
                res_g, res_g_res, res_m_new, res_m_new_res = sess.run(
                    [g, g_res, m_new, m_new_res], {
                        x: x_value,
                        m: np.array([[0.1, 0.1, 0.1]])
                    })
                # Residual connection adds exactly the first 3 input columns.
                self.assertAllClose(res_g_res, res_g + x_value[:, :3])
                # States are left untouched
                self.assertAllClose(res_m_new, res_m_new_res)
コード例 #2
0
    def testResidualWrapper(self):
        """ResidualWrapper tracks its cell and adds the input to the cell output.

        Verifies that the wrapped GRU cell is the wrapper's single checkpoint
        dependency (registered under the name "cell"), that get_config() runs
        without raising, that the wrapped output equals the base cell output
        plus the input, and that the state is passed through unchanged.
        """
        with self.test_session() as sess:
            scope_init = init_ops.constant_initializer(0.5)
            with variable_scope.variable_scope("root", initializer=scope_init):
                inputs = array_ops.zeros([1, 3])
                state = array_ops.zeros([1, 3])
                base_cell = rnn_cell_impl.GRUCell(3)
                base_out, base_state = base_cell(inputs, state)
                variable_scope.get_variable_scope().reuse_variables()

                wrapper = rnn_cell_impl.ResidualWrapper(base_cell)
                # Exactly one dependency is expected; the unpack enforces it.
                [(dep_name, dep_obj)] = wrapper._checkpoint_dependencies
                wrapper.get_config()  # Should not throw an error
                self.assertIs(dep_obj, base_cell)
                self.assertEqual("cell", dep_name)

                wrapped_out, wrapped_state = wrapper(inputs, state)
                sess.run([variables_lib.global_variables_initializer()])
                feed = {
                    inputs: np.array([[1., 1., 1.]]),
                    state: np.array([[0.1, 0.1, 0.1]]),
                }
                (out_value, wrapped_out_value,
                 state_value, wrapped_state_value) = sess.run(
                     [base_out, wrapped_out, base_state, wrapped_state], feed)
                # Residual connection: wrapped output = base output + input.
                self.assertAllClose(wrapped_out_value, out_value + [1., 1., 1.])
                # The wrapper must not alter the cell state.
                self.assertAllClose(state_value, wrapped_state_value)
コード例 #3
0
 def testResidualWrapper(self):
   """Wrapped GRU output equals base output plus input; state is unchanged."""
   with self.test_session() as sess:
     root_init = init_ops.constant_initializer(0.5)
     with variable_scope.variable_scope("root", initializer=root_init):
       inputs = array_ops.zeros([1, 3])
       state = array_ops.zeros([1, 3])
       base_cell = rnn_cell_impl.GRUCell(3)
       base_out, base_state = base_cell(inputs, state)
       variable_scope.get_variable_scope().reuse_variables()
       residual_cell = rnn_cell_impl.ResidualWrapper(base_cell)
       res_out, res_state = residual_cell(inputs, state)
       sess.run([variables_lib.global_variables_initializer()])
       fetched = sess.run(
           [base_out, res_out, base_state, res_state],
           feed_dict={
               inputs: np.array([[1., 1., 1.]]),
               state: np.array([[0.1, 0.1, 0.1]]),
           })
       base_out_v, res_out_v, base_state_v, res_state_v = fetched
       # Residual connection: wrapped output = base output + input.
       self.assertAllClose(res_out_v, base_out_v + [1., 1., 1.])
       # The wrapper must not alter the cell state.
       self.assertAllClose(base_state_v, res_state_v)