def testRegularizers(self, trainable, state_size):
    """Check which regularization losses `initial_state` registers.

    Assigns `state_size` to `snt.RNNCore`, requests an initial state with an
    L1 regularizer supplied for every leaf of `state_size`, and then inspects
    the `tf.GraphKeys.REGULARIZATION_LOSSES` collection.

    Args:
      trainable: Forwarded to `initial_state`. When false, the collection is
        expected to be empty; otherwise it is expected to hold one entry per
        leaf of `state_size`.
      state_size: Possibly nested structure used as the core's state size.
    """
    batch_size = 6

    # Set the attribute on the class since we can't set properties of
    # abstract classes.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")

    # The same regularizer is repeated once per leaf and then packed back
    # into the structure of `state_size`.
    flat_regularizer = (
        [contrib_layers.l1_regularizer(scale=0.5)] * len(flat_state_size))
    trainable_regularizers = nest.pack_sequence_as(
        structure=state_size, flat_sequence=flat_regularizer)

    core.initial_state(batch_size, dtype=tf.float32, trainable=trainable,
                       trainable_regularizers=trainable_regularizers)

    graph_regularizers = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES)
    if not trainable:
        self.assertFalse(graph_regularizers)
    else:
        self.assertEqual(len(graph_regularizers), len(flat_state_size))
        # Names are only checked when not executing eagerly.
        if not tf.executing_eagerly():
            # The length assertion above makes iterating the collection
            # equivalent to indexing it by position in `flat_state_size`.
            # NOTE(review): `assertRegexpMatches` is a deprecated alias of
            # `assertRegex` in the standard `unittest` module; consider
            # migrating once the supported Python versions are confirmed.
            for regularizer in graph_regularizers:
                self.assertRegexpMatches(
                    regularizer.name, ".*l1_regularizer.*")
def testInitialStateTuple(self, trainable, use_custom_initial_value,
                          state_size):
    """Check structure, shapes and values of `RNNCore.initial_state`.

    Args:
      trainable: Forwarded to `initial_state`. When false, every state value
        is expected to be zero.
      use_custom_initial_value: When true, a `tf.constant_initializer(2)` is
        supplied for every leaf of `state_size`; a trainable state is then
        expected to be filled with twos.
      state_size: Possibly nested structure used as the core's state size.
    """
    batch_size = 6

    # The attribute is assigned on the class itself because properties of
    # abstract classes cannot be set.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")

    # Default to no custom initializers; override only when requested.
    trainable_initializers = None
    if use_custom_initial_value:
        trainable_initializers = nest.pack_sequence_as(
            structure=state_size,
            flat_sequence=(
                [tf.constant_initializer(2)] * len(flat_state_size)))

    initial_state = core.initial_state(
        batch_size, dtype=tf.float32, trainable=trainable,
        trainable_initializers=trainable_initializers)

    # The returned state must mirror the nesting of `state_size`.
    nest.assert_same_structure(initial_state, state_size)
    state_tensors = nest.flatten(initial_state)

    for tensor, size in zip(state_tensors, flat_state_size):
        self.assertEqual(tensor.get_shape(), [batch_size, size])

    self.evaluate(tf.global_variables_initializer())
    state_values = self.evaluate(state_tensors)

    for actual, size in zip(state_values, flat_state_size):
        if not trainable:
            expected = np.zeros([batch_size, size])
        elif use_custom_initial_value:
            expected = np.full([batch_size, size], 2.0)
        else:
            # No fixed value to compare against here: the expectation is the
            # first row repeated `batch_size` times, i.e. all rows must
            # match the first one.
            expected = np.tile(actual[0], (batch_size, 1))
        self.assertAllClose(actual, expected)