def testStateTupleDictConversion(self):
  """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

  Builds a flat state dict matching the state structure of a `MultiRNNCell`
  of `LSTMCell`s, converts it to a nested state with `dict_to_state_tuple`,
  converts that back with `state_tuple_to_dict`, and checks that:
    * the nested state has the expected structure and values, and
    * the round trip reproduces exactly the original keys and values.
  """
  cell_sizes = [5, 3, 7]
  # A MultiRNNCell of LSTMCells is both a common choice and an interesting
  # test case, because it has two levels of nesting, with an inner class that
  # is not a plain tuple.
  cell = core_rnn_cell_impl.MultiRNNCell(
      [core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])
  # Each LSTM layer's state is an LSTMStateTuple with two equally sized
  # components, so the flat component sizes are each cell size repeated
  # twice, in layer order (e.g. [5, 5, 3, 3, 7, 7]). Deriving this from
  # `cell_sizes` keeps the dict in sync with the cell definition.
  component_sizes = [size for size in cell_sizes for _ in range(2)]
  state_dict = {
      dynamic_rnn_estimator._get_state_name(i):
      array_ops.expand_dims(math_ops.range(cell_size), 0)
      for i, cell_size in enumerate(component_sizes)
  }
  expected_state = tuple(
      core_rnn_cell_impl.LSTMStateTuple(
          np.reshape(np.arange(size), [1, -1]),
          np.reshape(np.arange(size), [1, -1]))
      for size in cell_sizes)
  actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
  flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

  with self.test_session() as sess:
    (state_dict_val, actual_state_val, flattened_state_val) = sess.run(
        [state_dict, actual_state, flattened_state])

  def _recursive_assert_equal(x, y):
    """Assert nested tuples/lists of ndarrays match in type, length and value."""
    # Exact type check so an LSTMStateTuple is not accepted as a plain tuple.
    self.assertEqual(type(x), type(y))
    if isinstance(x, (list, tuple)):
      self.assertEqual(len(x), len(y))
      for x_item, y_item in zip(x, y):
        _recursive_assert_equal(x_item, y_item)
    elif isinstance(x, np.ndarray):
      np.testing.assert_array_equal(x, y)
    else:
      self.fail('Unexpected type: {}'.format(type(x)))

  # The round trip must yield exactly the original keys. Iterating only over
  # `state_dict_val` would miss extra components in `flattened_state_val`.
  self.assertEqual(set(state_dict_val), set(flattened_state_val))
  for k in state_dict_val:
    np.testing.assert_array_almost_equal(
        state_dict_val[k],
        flattened_state_val[k],
        err_msg='Wrong value for state component {}.'.format(k))
  _recursive_assert_equal(expected_state, actual_state_val)
def testStateTupleDictConversion(self):
  """Test `state_tuple_to_dict` and `dict_to_state_tuple`.

  Builds a flat state dict matching the state structure of a `MultiRNNCell`
  of `LSTMCell`s, converts it to a nested state with `dict_to_state_tuple`,
  converts that back with `state_tuple_to_dict`, and checks that:
    * the nested state has the expected structure and values, and
    * the round trip reproduces exactly the original keys and values.
  """
  cell_sizes = [5, 3, 7]
  # A MultiRNNCell of LSTMCells is both a common choice and an interesting
  # test case, because it has two levels of nesting, with an inner class that
  # is not a plain tuple.
  cell = core_rnn_cell_impl.MultiRNNCell(
      [core_rnn_cell_impl.LSTMCell(i) for i in cell_sizes])
  # Each LSTM layer's state is an LSTMStateTuple with two equally sized
  # components, so the flat component sizes are each cell size repeated
  # twice, in layer order (e.g. [5, 5, 3, 3, 7, 7]). Deriving this from
  # `cell_sizes` keeps the dict in sync with the cell definition.
  component_sizes = [size for size in cell_sizes for _ in range(2)]
  state_dict = {
      dynamic_rnn_estimator._get_state_name(i):
      array_ops.expand_dims(math_ops.range(cell_size), 0)
      for i, cell_size in enumerate(component_sizes)
  }
  expected_state = tuple(
      core_rnn_cell_impl.LSTMStateTuple(
          np.reshape(np.arange(size), [1, -1]),
          np.reshape(np.arange(size), [1, -1]))
      for size in cell_sizes)
  actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
  flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)

  with self.test_session() as sess:
    (state_dict_val, actual_state_val, flattened_state_val) = sess.run(
        [state_dict, actual_state, flattened_state])

  def _recursive_assert_equal(x, y):
    """Assert nested tuples/lists of ndarrays match in type, length and value."""
    # Exact type check so an LSTMStateTuple is not accepted as a plain tuple.
    self.assertEqual(type(x), type(y))
    if isinstance(x, (list, tuple)):
      self.assertEqual(len(x), len(y))
      for x_item, y_item in zip(x, y):
        _recursive_assert_equal(x_item, y_item)
    elif isinstance(x, np.ndarray):
      np.testing.assert_array_equal(x, y)
    else:
      self.fail('Unexpected type: {}'.format(type(x)))

  # The round trip must yield exactly the original keys. Iterating only over
  # `state_dict_val` would miss extra components in `flattened_state_val`.
  self.assertEqual(set(state_dict_val), set(flattened_state_val))
  for k in state_dict_val:
    np.testing.assert_array_almost_equal(
        state_dict_val[k],
        flattened_state_val[k],
        err_msg='Wrong value for state component {}.'.format(k))
  _recursive_assert_equal(expected_state, actual_state_val)