Exemplo n.º 1
0
 def test_call_without_circular_padding(self):
     """Checks ConcatFeatures output when `circular_padding=False`.

     Every example slot is expected to become `[context_value, example_value]`.
     Slots whose mask is False keep their original example values: for
     instance, the third slot of the first list keeps `-1` and yields
     `[1., -1.]`.
     """
     # Shape [2, 1]: a single context value per list.
     context = {'context_feature_1': tf.constant([[1], [0]], dtype=tf.float32)}
     # Shape [2, 3, 1]: three example slots per list, one scalar feature each.
     examples = {
         'example_feature_1': tf.constant(
             [[[1], [0], [-1]], [[0], [1], [0]]], dtype=tf.float32),
     }
     # Shape [2, 3]: True for the first two slots of list 0 and the first
     # slot of list 1.
     slot_mask = tf.constant(
         [[True, True, False], [True, False, False]], dtype=tf.bool)

     concat_layer = layers.ConcatFeatures(circular_padding=False)
     actual = concat_layer((context, examples, slot_mask))

     # Shape [2, 3, 2]: context value first, example value second.
     expected = tf.constant(
         [[[1., 1.], [1., 0.], [1., -1.]],
          [[0., 0.], [0., 1.], [0., 0.]]],
         dtype=tf.float32)
     self.assertAllClose(expected, actual)
Exemplo n.º 2
0
 def test_serialization(self):
     """Round-trips ConcatFeatures through Keras layer (de)serialization.

     Both the default layer and one built with a non-default
     `circular_padding` are checked. A default-constructed layer alone cannot
     reveal whether an explicitly passed constructor argument survives the
     round trip.
     """
     for kwargs in ({}, {'circular_padding': False}):
         with self.subTest(**kwargs):
             layer = layers.ConcatFeatures(**kwargs)
             serialized = tf.keras.layers.serialize(layer)
             loaded = tf.keras.layers.deserialize(serialized)
             loaded_config = loaded.get_config()
             # get_config() returns a dict; assertEqual compares it key by
             # key and prints a readable diff on mismatch.
             self.assertEqual(loaded_config, layer.get_config())
             # Keras rebuilds layers via cls(**config), so every explicitly
             # passed argument must be present in the restored config.
             for key, value in kwargs.items():
                 self.assertEqual(loaded_config[key], value)