def test_multi_dim_batch_apply_for_tensors(self):
    """Checks MultiDimBatchApply against a direct NumPy matmul reference."""
    # Rank-3 input batch and a kernel for a plain matrix multiply.
    inputs = np.arange(24).reshape((2, 3, 4))
    kernel = np.arange(24, 36).reshape((4, 3))
    # Reference solution computed directly with NumPy broadcasting.
    expected = np.matmul(inputs, kernel)
    # Wrap a matmul Lambda in MultiDimBatchApply, keeping one trailing dim;
    # the layer's output should match the broadcasted reference exactly.
    matmul_lambda = tf.keras.layers.Lambda(
        lambda x: tf.matmul(x, tf.constant(kernel)))
    layer = relational_layers.MultiDimBatchApply(
        matmul_lambda, num_dims_to_keep=1)
    actual = self.evaluate(layer(tf.constant(inputs)))
    self.assertAllClose(expected, actual)
  def __init__(self, hub_path=gin.REQUIRED, name="HubEmbedding", **kwargs):
    """Constructs a HubEmbedding.

    Args:
      hub_path: Path to the TFHub module.
      name: String with the name of the model.
      **kwargs: Other keyword arguments passed to tf.keras.Model.
    """
    super(HubEmbedding, self).__init__(name=name, **kwargs)

    def _apply_hub_module(x):
      # Instantiate the TFHub module at call time and query its
      # "representation" signature on the incoming image batch.
      module = hub.Module(hub_path)
      return module(dict(images=x), signature="representation")

    self.embedding_layer = relational_layers.MultiDimBatchApply(
        tf.keras.layers.Lambda(_apply_hub_module))
  def __init__(self,
               num_latent=gin.REQUIRED,
               name="BaselineCNNEmbedder",
               **kwargs):
    """Constructs a BaselineCNNEmbedder.

    Args:
      num_latent: Integer with the number of latent dimensions.
        NOTE(review): not referenced in this constructor — presumably
        consumed via gin elsewhere; confirm before removing.
      name: String with the name of the model.
      **kwargs: Other keyword arguments passed to tf.keras.Model.
    """
    super(BaselineCNNEmbedder, self).__init__(name=name, **kwargs)
    # Four stride-2 (4, 4) conv layers — two with 32 filters, two with 64 —
    # built in a loop instead of four copy-pasted literals. Each iteration
    # calls get_activation()/get_kernel_initializer() afresh, matching the
    # per-layer calls in the unrolled form.
    embedding_layers = [
        tf.keras.layers.Conv2D(
            num_filters, (4, 4),
            2,
            activation=get_activation(),
            padding="same",
            kernel_initializer=get_kernel_initializer())
        for num_filters in (32, 32, 64, 64)
    ]
    embedding_layers.append(tf.keras.layers.Flatten())
    self.embedding_layer = relational_layers.MultiDimBatchApply(
        tf.keras.models.Sequential(embedding_layers, "embedding_cnn"))