Example #1
0
  def testGrid2BasicLSTMCellTied(self):
    """Tied Grid2BasicLSTMCell with a flat 8-wide state, evaluated twice.

    The cell is built with tied=True and a single (1, 8) state tensor; the
    'root' scope's default initializer is the constant 0.2, so the fetched
    values are deterministic. The second evaluation feeds back the state
    fetched by the first.
    """
    with self.test_session() as sess:
      root_init = init_ops.constant_initializer(0.2)
      with variable_scope.variable_scope('root', initializer=root_init):
        inputs = array_ops.zeros([1, 3])
        prev_state = array_ops.zeros([1, 8])
        cell = grid_rnn_cell.Grid2BasicLSTMCell(2, tied=True)
        self.assertEqual(cell.state_size, 8)

        output, new_state = cell(inputs, prev_state)
        self.assertEqual(output.get_shape(), (1, 2))
        self.assertEqual(new_state.get_shape(), (1, 8))

        def check_step(out_val, state_val, want_out, want_state):
          # Shape checks first, then numeric comparison, for one evaluation.
          self.assertEqual(out_val.shape, (1, 2))
          self.assertEqual(state_val.shape, (1, 8))
          self.assertAllClose(out_val, want_out)
          self.assertAllClose(state_val, want_state)

        sess.run([variables.global_variables_initializer()])
        out_val, state_val = sess.run(
            [output, new_state],
            {inputs: np.array([[1., 1., 1.]]),
             prev_state: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
        check_step(out_val, state_val,
                   [[0.36617181, 0.36617181]],
                   [[0.71053141, 0.71053141, 0.36617181, 0.36617181,
                     0.72320831, 0.80555487, 0.39102408, 0.42150158]])

        # Second step: the state fetched above becomes the input state.
        out_val, state_val = sess.run(
            [output, new_state],
            {inputs: np.array([[1., 1., 1.]]), prev_state: state_val})
        check_step(out_val, state_val,
                   [[0.36703536, 0.36703536]],
                   [[0.71200621, 0.71200621, 0.36703536, 0.36703536,
                     0.80941606, 0.87550586, 0.40108523, 0.42199609]])
Example #2
0
    def testGrid2BasicLSTMCellWithRelu(self):
        """Untied Grid2BasicLSTMCell whose non-recurrent function is relu.

        With non_recurrent_fn=nn_ops.relu the cell's state_size holds a single
        (2, 2) entry, so only one state pair is fed and fetched. The 'root'
        scope's default initializer is the constant 0.2.
        """
        with self.test_session(use_gpu=False) as sess:
            const_init = init_ops.constant_initializer(0.2)
            with variable_scope.variable_scope('root', initializer=const_init):
                inputs = array_ops.zeros([1, 3])
                prev_state = (
                    (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
                cell = grid_rnn_cell.Grid2BasicLSTMCell(
                    2, tied=False, non_recurrent_fn=nn_ops.relu)
                self.assertEqual(cell.state_size, ((2, 2), ))

                outputs, new_state = cell(inputs, prev_state)
                self.assertEqual(outputs[0].get_shape(), (1, 2))
                only_pair = new_state[0]
                self.assertEqual(only_pair.c.get_shape(), (1, 2))
                self.assertEqual(only_pair.h.get_shape(), (1, 2))

                sess.run([variables.global_variables_initializer()])
                feeds = {
                    inputs: np.array([[1., 1., 1.]]),
                    prev_state: ((np.array([[0.1, 0.2]]),
                                  np.array([[0.3, 0.4]])), ),
                }
                out_vals, state_vals = sess.run([outputs, new_state], feeds)
                self.assertEqual(out_vals[0].shape, (1, 2))
                self.assertAllClose(out_vals[0], [[0.31667367, 0.31667367]])
                expected_state = (([[0.29530135, 0.37520045]],
                                   [[0.17044567, 0.21292259]]), )
                self.assertAllClose(state_vals, expected_state)
  def testGrid2BasicLSTMCell(self):
    """Untied Grid2BasicLSTMCell with tuple state, called twice with reuse.

    The first call is evaluated with a fixed feed. The 'root' scope is then
    switched to reuse its variables and the cell is called again (emulating a
    step through an input sequence), fed the state fetched from the first
    evaluation. The scope's default initializer is the constant 0.2.
    """
    with self.test_session(use_gpu=False) as sess:
      const_init = init_ops.constant_initializer(0.2)
      with variable_scope.variable_scope(
          'root', initializer=const_init) as root_scope:
        inputs = array_ops.zeros([1, 3])
        prev_state = tuple(
            (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))
            for _ in range(2))
        cell = grid_rnn_cell.Grid2BasicLSTMCell(2)
        self.assertEqual(cell.state_size, ((2, 2), (2, 2)))

        def assert_static_shapes(out, state):
          # Graph-time shapes of the output and of both state pairs.
          self.assertEqual(out[0].get_shape(), (1, 2))
          for dim in (0, 1):
            self.assertEqual(state[dim].c.get_shape(), (1, 2))
            self.assertEqual(state[dim].h.get_shape(), (1, 2))

        def assert_value_shapes(out_val, state_val):
          # Shapes of the numpy arrays returned by sess.run.
          self.assertEqual(out_val[0].shape, (1, 2))
          for dim in (0, 1):
            self.assertEqual(state_val[dim].c.shape, (1, 2))
            self.assertEqual(state_val[dim].h.shape, (1, 2))

        out, new_state = cell(inputs, prev_state)
        assert_static_shapes(out, new_state)

        sess.run([variables.global_variables_initializer()])
        out_val, state_val = sess.run([out, new_state], {
            inputs: np.array([[1., 1., 1.]]),
            prev_state: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))),
        })
        assert_value_shapes(out_val, state_val)
        self.assertAllClose(out_val, ([[0.36617181, 0.36617181]],))
        self.assertAllClose(
            state_val,
            (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
             ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))

        # Emulate a loop over the input sequence: call the cell again with
        # the same variables and feed back the first evaluation's state.
        root_scope.reuse_variables()
        out2, new_state2 = cell(inputs, prev_state)
        assert_static_shapes(out2, new_state2)

        out2_val, state2_val = sess.run(
            [out2, new_state2],
            {inputs: np.array([[2., 2., 2.]]), prev_state: state_val})
        assert_value_shapes(out2_val, state2_val)
        self.assertAllClose(out2_val[0], [[0.58847463, 0.58847463]])
        self.assertAllClose(
            state2_val,
            (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
             ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
Example #4
0
    def testGrid2BasicLSTMCellTied(self):
        """Tied Grid2BasicLSTMCell with tuple state, evaluated for two steps.

        The 'root' scope's default initializer is the constant 0.2, so the
        fetched values are deterministic. The second evaluation feeds back
        the state fetched by the first. Both evaluations check the shapes of
        every fetched tensor as well as their values.
        """
        with self.test_session(use_gpu=False) as sess:
            with variable_scope.variable_scope(
                    'root', initializer=init_ops.constant_initializer(0.2)):
                x = array_ops.zeros([1, 3])
                m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
                     (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
                cell = grid_rnn_cell.Grid2BasicLSTMCell(2, tied=True)
                self.assertEqual(cell.state_size, ((2, 2), (2, 2)))

                g, s = cell(x, m)
                self.assertEqual(g[0].get_shape(), (1, 2))
                self.assertEqual(s[0].c.get_shape(), (1, 2))
                self.assertEqual(s[0].h.get_shape(), (1, 2))
                self.assertEqual(s[1].c.get_shape(), (1, 2))
                self.assertEqual(s[1].h.get_shape(), (1, 2))

                def assert_fetched_shapes(res_g, res_s):
                    # Shared by both evaluations so the second step is held
                    # to the same shape contract as the first.
                    self.assertEqual(res_g[0].shape, (1, 2))
                    for dim in (0, 1):
                        self.assertEqual(res_s[dim].c.shape, (1, 2))
                        self.assertEqual(res_s[dim].h.shape, (1, 2))

                sess.run([variables.global_variables_initializer()])
                res_g, res_s = sess.run(
                    [g, s], {
                        x:
                        np.array([[1., 1., 1.]]),
                        m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
                            (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
                    })
                assert_fetched_shapes(res_g, res_s)

                self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
                self.assertAllClose(
                    res_s,
                    (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
                     ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))

                # Second step: the state fetched above becomes the input state.
                res_g, res_s = sess.run([g, s], {
                    x: np.array([[1., 1., 1.]]),
                    m: res_s
                })
                assert_fetched_shapes(res_g, res_s)

                self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
                self.assertAllClose(
                    res_s,
                    (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
                     ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))