def test_call_without_circular_padding(self):
  """ConcatFeatures(circular_padding=False) concatenates context onto examples.

  The expected output pairs each example's feature with its batch entry's
  context feature, and positions whose mask is False still hold that plain
  concatenation.
  """
  # Per the literal shapes: batch of 2, one context value per batch entry.
  context = {
      'context_feature_1': tf.constant([[1], [0]], dtype=tf.float32),
  }
  # Per the literal shapes: batch of 2, 3 list positions, 1 value each.
  examples = {
      'example_feature_1': tf.constant(
          [[[1], [0], [-1]],
           [[0], [1], [0]]],
          dtype=tf.float32),
  }
  # NOTE(review): True presumably marks a valid list position — confirm
  # against ConcatFeatures.
  mask = tf.constant(
      [[True, True, False],
       [True, False, False]],
      dtype=tf.bool)

  layer = layers.ConcatFeatures(circular_padding=False)
  actual = layer((context, examples, mask))

  # Each output row is [context_feature_1, example_feature_1]; the masked
  # positions ([1., -1.], [0., 1.], [0., 0.]) keep their original values.
  expected = tf.constant(
      [[[1., 1.], [1., 0.], [1., -1.]],
       [[0., 0.], [0., 1.], [0., 0.]]],
      dtype=tf.float32)
  self.assertAllClose(expected, actual)
def test_serialization(self):
  """Round-trips ConcatFeatures through Keras serialize/deserialize.

  Covers the default construction and both explicit ``circular_padding``
  values. Checking only the default would miss a ``get_config`` that drops
  the constructor argument, because deserialization would silently fall back
  to the default.
  """
  candidates = [
      layers.ConcatFeatures(),
      layers.ConcatFeatures(circular_padding=True),
      layers.ConcatFeatures(circular_padding=False),
  ]
  for layer in candidates:
    with self.subTest(config=layer.get_config()):
      serialized = tf.keras.layers.serialize(layer)
      loaded = tf.keras.layers.deserialize(serialized)
      # assertEqual compares the config dicts directly and reports a
      # key-by-key diff on mismatch. assertAllEqual would first coerce each
      # dict into a 0-d numpy object array.
      self.assertEqual(loaded.get_config(), layer.get_config())