Exemplo n.º 1
0
    def testGrid2BasicRNNCellTied(self):
        """Checks Grid2BasicRNNCell built with ``tied=True`` on a batch of two.

        Every variable created under the 'root' scope is initialised to 0.5, so
        the numeric expectations below are fixed.  Shapes are checked both on
        the symbolic tensors and on the evaluated numpy results.
        """
        with self.test_session() as sess:
            const_init = init_ops.constant_initializer(0.5)
            with variable_scope.variable_scope('root', initializer=const_init):
                inputs = array_ops.zeros([2, 2])
                prev_state = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
                cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
                self.assertEqual(cell.state_size, (2, 2))

                output, next_state = cell(inputs, prev_state)
                self.assertEqual(output[0].get_shape(), (2, 2))
                for idx in (0, 1):
                    self.assertEqual(next_state[idx].get_shape(), (2, 2))

                sess.run([variables.global_variables_initializer()])

                # prev_state is a 2-tuple of tensors, so it is fed a 2-tuple of
                # arrays with the same structure.
                state_feed = np.array([[0.1, 0.1], [0.2, 0.2]])
                feeds = {
                    inputs: np.array([[1., 1.], [2., 2.]]),
                    prev_state: (state_feed, state_feed.copy()),
                }
                output_val, state_val = sess.run([output, next_state], feeds)

                self.assertEqual(output_val[0].shape, (2, 2))
                for idx in (0, 1):
                    self.assertEqual(state_val[idx].shape, (2, 2))

                # The first state component is expected to equal the output.
                expected_output = (
                    [[0.94685763, 0.94685763], [0.99480951, 0.99480951]], )
                expected_state = (
                    [[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
                    [[0.80049908, 0.80049908], [0.97574311, 0.97574311]],
                )
                self.assertAllClose(output_val, expected_output)
                self.assertAllClose(state_val, expected_state)
Exemplo n.º 2
0
  def testGrid2BasicRNNCellWithRelu(self):
    """Checks Grid2BasicRNNCell with ``non_recurrent_fn=nn_ops.relu``.

    Uses a batch of one. Every variable created under the 'root' scope is
    initialised to 0.5, so the numeric expectations below are fixed.
    """
    with self.test_session() as sess:
      with variable_scope.variable_scope('root',
          initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        # The recurrent state is a 1-tuple, matching cell.state_size == (2, ).
        m = (array_ops.zeros([1, 2]), )
        cell = grid_rnn_cell.Grid2BasicRNNCell(
            2, non_recurrent_fn=nn_ops.relu)
        self.assertEqual(cell.state_size, (2, ))

        g, s = cell(x, m)
        self.assertEqual(g[0].get_shape(), (1, 2))
        self.assertEqual(s[0].get_shape(), (1, 2))

        sess.run([variables.global_variables_initializer()])
        # `m` is a tuple of tensors, so feed it a tuple of values with the same
        # nested structure.  Do not rely on a bare ndarray being flattened as a
        # single leaf.
        res_g, res_s = sess.run([g, s], {x: np.array([[1., 1.]]),
                                         m: (np.array([[0.1, 0.1]]), )})
        self.assertEqual(res_g[0].shape, (1, 2))
        self.assertEqual(res_s[0].shape, (1, 2))
        self.assertAllClose(res_g, ([[1.80049896, 1.80049896]], ))
        self.assertAllClose(res_s, ([[0.80049896, 0.80049896]], ))