Beispiel #1
0
 def test_call_without_circular_padding(self):
     """Checks `FlattenList` outputs when `circular_padding=False`.

     The expected values below show each context row repeated once per
     example slot, and the example features collapsed into a single
     leading dimension with every original value kept in place.
     """
     # Inputs: two lists with three example slots each, plus a boolean mask.
     per_list_inputs = {
         'context_feature_1': tf.constant([[1], [0]], dtype=tf.float32),
     }
     per_item_inputs = {
         'example_feature_1': tf.constant(
             [[[1], [0], [-1]], [[0], [1], [0]]], dtype=tf.float32),
     }
     list_mask = tf.constant(
         [[True, True, False], [True, False, False]], dtype=tf.bool)

     # Expected flattened outputs (six rows = 2 lists x 3 slots).
     want_context = {
         'context_feature_1': tf.constant(
             [[1], [1], [1], [0], [0], [0]], dtype=tf.float32),
     }
     want_examples = {
         'example_feature_1': tf.constant(
             [[1], [0], [-1], [0], [1], [0]], dtype=tf.float32),
     }

     flatten_layer = layers.FlattenList(circular_padding=False)
     got_context, got_examples = flatten_layer(
         inputs=(per_list_inputs, per_item_inputs, list_mask))

     self.assertAllClose(want_context, got_context)
     self.assertAllClose(want_examples, got_examples)
Beispiel #2
0
 def test_call_raise_error(self):
   """Expects `ValueError` when `FlattenList` gets an empty example dict."""
   list_mask = tf.constant(
       [[True, True, False], [True, False, False]], dtype=tf.bool)
   per_list_inputs = {
       'context_feature_1': tf.constant([[1], [0]], dtype=tf.float32),
   }
   # No example features at all; this is the condition under test.
   empty_item_inputs = {}
   layer_inputs = (per_list_inputs, empty_item_inputs, list_mask)

   # Build and invoke the layer inside the context so an error raised at
   # either step is captured.
   with self.assertRaises(ValueError):
     layers.FlattenList()(inputs=layer_inputs)
Beispiel #3
0
    def __call__(
        self,
        context_features: TensorDict,
        example_features: TensorDict,
        mask: tf.Tensor,
    ) -> Union[tf.Tensor, TensorDict]:
        """Scores the inputs by flattening, scoring, then restoring the list.

        See `Scorer`.

        The features are passed through `layers.FlattenList`, scored by
        `self._score_flattened`, and the resulting logits are passed through
        `layers.RestoreList` together with `mask`.

        Args:
          context_features: Dict of context features fed to `FlattenList`.
          example_features: Dict of example features fed to `FlattenList`.
          mask: Mask tensor used both for flattening and for restoring.

        Returns:
          The restored logits. If `_score_flattened` returns a dict, a dict
          with the same keys is returned, each value restored separately;
          otherwise a single restored tensor.
        """
        flat_context, flat_examples = layers.FlattenList()(
            inputs=(context_features, example_features, mask))
        flat_scores = self._score_flattened(flat_context, flat_examples)

        # Single-output case: restore the one logits tensor directly.
        if not isinstance(flat_scores, dict):
            return layers.RestoreList()(inputs=(flat_scores, mask))

        # Multi-task case: restore each entry with a layer named after its key.
        restored = {}
        for task_name, task_scores in flat_scores.items():
            restore_layer = layers.RestoreList(name=task_name)
            restored[task_name] = restore_layer(inputs=(task_scores, mask))
        return restored
Beispiel #4
0
 def test_serialization(self):
     """Round-trips `FlattenList` through Keras serialize/deserialize.

     The rebuilt layer must report the same config as the original.
     """
     original = layers.FlattenList()
     rebuilt = tf.keras.layers.deserialize(
         tf.keras.layers.serialize(original))
     self.assertAllEqual(rebuilt.get_config(), original.get_config())