Esempio n. 1
0
    def testStateTupleDictConversion(self):
        """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

        Builds a flat state dict for a MultiRNNCell of LSTMCells, converts it
        to a nested state tuple with `dict_to_state_tuple`, flattens that back
        with `state_tuple_to_dict`, and checks that:
          * the nested tuple has the expected structure and values, and
          * the round-tripped dict has exactly the original keys and values.
        """
        cell_sizes = [5, 3, 7]
        # A MultiRNNCell of LSTMCells is both a common choice and an interesting
        # test case, because it has two levels of nesting, with an inner class that
        # is not a plain tuple.
        cell = core_rnn_cell_impl.MultiRNNCell(
            [core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])

        # Each LSTMStateTuple holds two components of its cell's size, so the
        # flat component sizes are [5, 5, 3, 3, 7, 7]. Deriving them from
        # `cell_sizes` keeps the test consistent if the sizes are edited.
        component_sizes = [size for size in cell_sizes for _ in range(2)]
        state_dict = {
            dynamic_rnn_estimator._get_state_name(i):
            array_ops.expand_dims(math_ops.range(cell_size), 0)
            for i, cell_size in enumerate(component_sizes)
        }

        def _row_vector(n):
            # A [1, n] array holding 0..n-1, mirroring expand_dims(range(n), 0).
            return np.reshape(np.arange(n), [1, -1])

        expected_state = tuple(
            core_rnn_cell_impl.LSTMStateTuple(_row_vector(size),
                                              _row_vector(size))
            for size in cell_sizes)
        actual_state = dynamic_rnn_estimator.dict_to_state_tuple(
            state_dict, cell)
        flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(
            actual_state)

        with self.test_session() as sess:
            (state_dict_val, actual_state_val, flattened_state_val) = sess.run(
                [state_dict, actual_state, flattened_state])

        def _recursive_assert_equal(x, y):
            # Structural comparison: (named)tuples/lists must match in type and
            # length element-wise; numpy leaves are compared by value.
            self.assertEqual(type(x), type(y))
            if isinstance(x, (list, tuple)):
                self.assertEqual(len(x), len(y))
                for x_item, y_item in zip(x, y):
                    _recursive_assert_equal(x_item, y_item)
            elif isinstance(x, np.ndarray):
                np.testing.assert_array_equal(x, y)
            else:
                self.fail('Unexpected type: {}'.format(type(x)))

        # The round trip must neither drop nor add state components.
        self.assertEqual(set(state_dict_val), set(flattened_state_val))
        for k in state_dict_val:
            np.testing.assert_array_almost_equal(
                state_dict_val[k],
                flattened_state_val[k],
                err_msg='Wrong value for state component {}.'.format(k))
        _recursive_assert_equal(expected_state, actual_state_val)
  def testStateTupleDictConversion(self):
    """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

    Builds a flat state dict for a MultiRNNCell of LSTMCells, converts it to a
    nested state tuple with `dict_to_state_tuple`, flattens that back with
    `state_tuple_to_dict`, and checks that:
      * the nested tuple has the expected structure and values, and
      * the round-tripped dict has exactly the original keys and values.
    """
    cell_sizes = [5, 3, 7]
    # A MultiRNNCell of LSTMCells is both a common choice and an interesting
    # test case, because it has two levels of nesting, with an inner class that
    # is not a plain tuple.
    cell = core_rnn_cell_impl.MultiRNNCell(
        [core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])

    # Each LSTMStateTuple holds two components of its cell's size, so the flat
    # component sizes are [5, 5, 3, 3, 7, 7]. Deriving them from `cell_sizes`
    # keeps the test consistent if the sizes are edited.
    component_sizes = [size for size in cell_sizes for _ in range(2)]
    state_dict = {
        dynamic_rnn_estimator._get_state_name(i):
        array_ops.expand_dims(math_ops.range(cell_size), 0)
        for i, cell_size in enumerate(component_sizes)
    }

    def _row_vector(n):
      # A [1, n] array holding 0..n-1, mirroring expand_dims(range(n), 0).
      return np.reshape(np.arange(n), [1, -1])

    expected_state = tuple(
        core_rnn_cell_impl.LSTMStateTuple(_row_vector(size), _row_vector(size))
        for size in cell_sizes)
    actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
    flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

    with self.test_session() as sess:
      (state_dict_val, actual_state_val, flattened_state_val) = sess.run(
          [state_dict, actual_state, flattened_state])

    def _recursive_assert_equal(x, y):
      # Structural comparison: (named)tuples/lists must match in type and
      # length element-wise; numpy leaves are compared by value.
      self.assertEqual(type(x), type(y))
      if isinstance(x, (list, tuple)):
        self.assertEqual(len(x), len(y))
        for x_item, y_item in zip(x, y):
          _recursive_assert_equal(x_item, y_item)
      elif isinstance(x, np.ndarray):
        np.testing.assert_array_equal(x, y)
      else:
        self.fail('Unexpected type: {}'.format(type(x)))

    # The round trip must neither drop nor add state components.
    self.assertEqual(set(state_dict_val), set(flattened_state_val))
    for k in state_dict_val:
      np.testing.assert_array_almost_equal(
          state_dict_val[k],
          flattened_state_val[k],
          err_msg='Wrong value for state component {}.'.format(k))
    _recursive_assert_equal(expected_state, actual_state_val)