Exemplo n.º 1
0
    def testGrid2LSTMCellLegacy(self):
        """Checks Grid2LSTMCell with flat (non-tuple) state and output.

        With state_is_tuple=False and output_is_tuple=False the cell exposes
        one state tensor of width 8 and one output tensor of width 2; the
        numeric results are compared against fixed reference values.
        """
        with self.test_session() as sess:
            init = init_ops.constant_initializer(0.5)
            with variable_scope.variable_scope('root', initializer=init):
                inputs = array_ops.zeros([1, 3])
                prev_state = array_ops.zeros([1, 8])
                cell = grid_rnn_cell.Grid2LSTMCell(
                    2,
                    use_peepholes=True,
                    state_is_tuple=False,
                    output_is_tuple=False)
                self.assertEqual(cell.state_size, 8)

                # Static shape checks before anything is evaluated.
                output, new_state = cell(inputs, prev_state)
                self.assertEqual(output.get_shape(), (1, 2))
                self.assertEqual(new_state.get_shape(), (1, 8))

                sess.run([variables.global_variables_initializer()])
                feed = {
                    inputs: np.array([[1., 1., 1.]]),
                    prev_state: np.array(
                        [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]]),
                }
                output_val, state_val = sess.run([output, new_state], feed)

                self.assertEqual(output_val.shape, (1, 2))
                self.assertEqual(state_val.shape, (1, 8))
                self.assertAllClose(output_val, [[0.95686918, 0.95686918]])
                self.assertAllClose(state_val, [[
                    2.41515064, 2.41515064, 0.95686918, 0.95686918,
                    1.38917875, 1.49043763, 0.83884692, 0.86036491
                ]])
Exemplo n.º 2
0
    def testGrid2LSTMCellWithRelu(self):
        """Checks Grid2LSTMCell configured with a ReLU ``non_recurrent_fn``.

        In this configuration the cell carries a single (c, h) state pair,
        as shown by the asserted ``state_size`` of ``((2, 2), )``.
        """
        with self.test_session() as sess:
            init = init_ops.constant_initializer(0.5)
            with variable_scope.variable_scope('root', initializer=init):
                inputs = array_ops.zeros([1, 3])
                prev_c = array_ops.zeros([1, 2])
                prev_h = array_ops.zeros([1, 2])
                prev_state = ((prev_c, prev_h), )
                cell = grid_rnn_cell.Grid2LSTMCell(
                    2, use_peepholes=True, non_recurrent_fn=nn_ops.relu)
                self.assertEqual(cell.state_size, ((2, 2), ))

                output, new_state = cell(inputs, prev_state)
                self.assertEqual(output[0].get_shape(), (1, 2))
                for part in (new_state[0].c, new_state[0].h):
                    self.assertEqual(part.get_shape(), (1, 2))

                sess.run([variables.global_variables_initializer()])
                # The nested state tuple is used directly as a feed key.
                feed = {
                    inputs: np.array([[1., 1., 1.]]),
                    prev_state: ((np.array([[0.1, 0.2]]),
                                  np.array([[0.3, 0.4]])), ),
                }
                output_val, state_val = sess.run([output, new_state], feed)

                self.assertEqual(output_val[0].shape, (1, 2))
                self.assertAllClose(output_val[0], [[2.1831727, 2.1831727]])
                expected_state = (([[0.92270052, 1.02325559]],
                                   [[0.66159075, 0.70475441]]), )
                self.assertAllClose(state_val, expected_state)
Exemplo n.º 3
0
  def testGrid2LSTMCellTied(self):
    """Checks Grid2LSTMCell built with tied=True and peepholes.

    The cell holds two (c, h) state pairs (state_size ((2, 2), (2, 2))); both
    static and evaluated shapes are checked, then the values are compared
    against fixed reference numbers.
    """
    with self.test_session(use_gpu=False) as sess:
      init = init_ops.constant_initializer(0.5)
      with variable_scope.variable_scope('root', initializer=init):
        inputs = array_ops.zeros([1, 3])
        # Two (c, h) pairs of zero tensors, created in the order c0, h0, c1, h1.
        prev_state = tuple(
            (array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))
            for _ in range(2))
        cell = grid_rnn_cell.Grid2LSTMCell(2, tied=True, use_peepholes=True)
        self.assertEqual(cell.state_size, ((2, 2), (2, 2)))

        output, new_state = cell(inputs, prev_state)
        self.assertEqual(output[0].get_shape(), (1, 2))
        for idx in (0, 1):
          self.assertEqual(new_state[idx].c.get_shape(), (1, 2))
          self.assertEqual(new_state[idx].h.get_shape(), (1, 2))

        sess.run([variables.global_variables_initializer()])
        feed = {
            inputs: np.array([[1., 1., 1.]]),
            prev_state: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]]))),
        }
        output_val, state_val = sess.run([output, new_state], feed)

        self.assertEqual(output_val[0].shape, (1, 2))
        for idx in (0, 1):
          self.assertEqual(state_val[idx].c.shape, (1, 2))
          self.assertEqual(state_val[idx].h.shape, (1, 2))

        self.assertAllClose(output_val[0], [[0.95686918, 0.95686918]])
        expected_state = (([[2.41515064, 2.41515064]],
                           [[0.95686918, 0.95686918]]),
                          ([[1.38917875, 1.49043763]],
                           [[0.83884692, 0.86036491]]))
        self.assertAllClose(state_val, expected_state)
Exemplo n.º 4
0
  def testGrid2LSTMCellReLUWithRNN(self):
    """Unrolls a ReLU Grid2LSTMCell with static_rnn and sanity-checks it.

    Checks the following:
      * each step's output matches its input's batch size and dtype;
      * each step's output has width ``num_units``;
      * the final state has the flat shape ``(batch_size, 4)``;
      * every evaluated value is finite.
    """
    batch_size = 3
    input_size = 5
    max_length = 6  # number of unrolled time steps
    num_units = 2

    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
      cell = grid_rnn_cell.Grid2LSTMCell(
          num_units=num_units, non_recurrent_fn=nn_ops.relu)

      # NOTE: one placeholder object is reused for every time step, so
      # feeding inputs[0] below supplies the input for all steps.
      step_input = array_ops.placeholder(
          dtypes.float32, shape=(batch_size, input_size))
      inputs = [step_input] * max_length

      outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

    self.assertEqual(len(outputs), len(inputs))
    self.assertEqual(state.get_shape(), (batch_size, 4))

    for step_out, step_in in zip(outputs, inputs):
      out_shape = step_out.get_shape()
      self.assertEqual(out_shape[0], step_in.get_shape()[0])
      self.assertEqual(out_shape[1], num_units)
      self.assertEqual(step_out.dtype, step_in.dtype)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())

      input_value = np.ones((batch_size, input_size))
      fetches = outputs + [state]
      fetched = sess.run(fetches, feed_dict={inputs[0]: input_value})
      for value in fetched:
        self.assertTrue(np.all(np.isfinite(value)))
Exemplo n.º 5
0
  def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
    """Test for #4296: static_rnn over Grid2LSTMCell with an unknown batch size.

    The input placeholder leaves its batch dimension as ``None``. Each step's
    output must keep that dimension unknown while its second dimension equals
    ``num_units``. A batch of 3 is then fed, and every fetched value must be
    finite.
    """
    input_size = 5
    max_length = 6  # number of unrolled time steps
    num_units = 2

    with variable_scope.variable_scope(
        'root', initializer=init_ops.constant_initializer(0.5)):
      cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)

      # One placeholder shared by every step, with the batch size left unknown.
      step_input = array_ops.placeholder(
          dtypes.float32, shape=(None, input_size))
      inputs = [step_input] * max_length

      outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)

    self.assertEqual(len(outputs), len(inputs))

    for step_out, step_in in zip(outputs, inputs):
      # Each step's output is a length-1 sequence of tensors.
      self.assertEqual(len(step_out), 1)
      first = step_out[0]
      self.assertTrue(first.get_shape()[0].value is None)
      self.assertEqual(first.get_shape()[1], num_units)
      self.assertEqual(first.dtype, step_in.dtype)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())

      input_value = np.ones((3, input_size))
      fetched = sess.run(outputs + [state],
                         feed_dict={inputs[0]: input_value})
      step_values, final_state = fetched[:-1], fetched[-1]
      for step_tuple in step_values:
        for arr in step_tuple:
          self.assertTrue(np.all(np.isfinite(arr)))
      for state_tuple in final_state:
        for part in state_tuple:
          for row in part:
            self.assertTrue(np.all(np.isfinite(row)))