def testRegularizers(self, trainable, state_size):
        """Checks that `trainable_regularizers` end up as regularization losses.

        Args:
          trainable: Passed to `core.initial_state`. When False, no entries
            are expected in the REGULARIZATION_LOSSES collection.
          state_size: Nested state-size structure assigned to
            `snt.RNNCore.state_size`. One L1 regularizer is created per
            flattened element of this structure.
        """
        batch_size = 6

        # Set the attribute on the class since we can't set properties of
        # abstract classes.
        # NOTE(review): this class-level assignment is not restored here, so
        # it remains in effect after the test returns.
        snt.RNNCore.state_size = state_size
        flat_state_size = nest.flatten(state_size)
        core = snt.RNNCore(name="dummy_core")
        # One L1 regularizer per flattened state element, repacked into the
        # same nested structure as `state_size`.
        flat_regularizer = ([contrib_layers.l1_regularizer(scale=0.5)] *
                            len(flat_state_size))
        trainable_regularizers = nest.pack_sequence_as(
            structure=state_size, flat_sequence=flat_regularizer)

        core.initial_state(batch_size,
                           dtype=tf.float32,
                           trainable=trainable,
                           trainable_regularizers=trainable_regularizers)

        graph_regularizers = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        if not trainable:
            self.assertFalse(graph_regularizers)
        else:
            self.assertEqual(len(graph_regularizers), len(flat_state_size))
            # The name check is skipped under eager execution, where tensor
            # names are not meaningful.
            if not tf.executing_eagerly():
                for regularizer in graph_regularizers:
                    self.assertRegexpMatches(regularizer.name,
                                             ".*l1_regularizer.*")
    def testInitialStateTuple(self, trainable, use_custom_initial_value,
                              state_size):
        """Checks structure, shapes and values returned by `initial_state`.

        Args:
          trainable: Passed to `core.initial_state`. When False, every
            initial-state tensor is expected to be all zeros.
          use_custom_initial_value: If True, each flattened state element is
            given `tf.constant_initializer(2)`, and (when `trainable` is True)
            is expected to be filled with 2.
          state_size: Nested state-size structure assigned to
            `snt.RNNCore.state_size`.
        """
        batch_size = 6

        # Set the attribute on the class since we can't set properties of
        # abstract classes.
        # NOTE(review): this class-level assignment is not restored here, so
        # it remains in effect after the test returns.
        snt.RNNCore.state_size = state_size
        flat_state_size = nest.flatten(state_size)
        core = snt.RNNCore(name="dummy_core")
        if use_custom_initial_value:
            flat_initializer = ([tf.constant_initializer(2)] *
                                len(flat_state_size))
            trainable_initializers = nest.pack_sequence_as(
                structure=state_size, flat_sequence=flat_initializer)
        else:
            trainable_initializers = None
        initial_state = core.initial_state(
            batch_size,
            dtype=tf.float32,
            trainable=trainable,
            trainable_initializers=trainable_initializers)

        nest.assert_same_structure(initial_state, state_size)
        flat_initial_state = nest.flatten(initial_state)

        for state, size in zip(flat_initial_state, flat_state_size):
            self.assertEqual(state.get_shape(), [batch_size, size])

        self.evaluate(tf.global_variables_initializer())
        flat_initial_state_value = self.evaluate(flat_initial_state)
        for value, size in zip(flat_initial_state_value, flat_state_size):
            shape = [batch_size, size]
            if not trainable:
                expected_initial_state = np.zeros(shape)
            elif use_custom_initial_value:
                expected_initial_state = np.full(shape, 2.0)
            else:
                # Without a custom initializer the exact values are not
                # pinned; only require every batch row to equal the first.
                expected_initial_state = np.tile(value[0], (batch_size, 1))
            self.assertAllClose(value, expected_initial_state)