def test_relational_layer_for_tensor(self):
  """Checks RelationalLayer on a (1, 2, 1) tensor.

  The edge function is the identity and the reduction sums over axis -2, so
  the expected result for nodes 1 and 2 is [[[2, 3], [4, 3]]]. This value is
  consistent with each node being paired with every node before the
  reduction (presumably how RelationalLayer builds its edges; see its
  implementation).
  """
  identity = tf.keras.layers.Lambda(lambda t: t)
  sum_over_pairs = tf.keras.layers.Lambda(
      lambda t: tf.reduce_sum(t, axis=-2))
  layer = relational_layers.RelationalLayer(identity, sum_over_pairs)

  inputs = tf.constant(np.array([[[1], [2]]]))
  expected = np.array([[[2, 3], [4, 3]]])
  self.assertAllClose(expected, self.evaluate(layer(inputs)))
  def __init__(self,
               edge_mlp=gin.REQUIRED,
               graph_mlp=gin.REQUIRED,
               dropout_in_last_graph_layer=gin.REQUIRED,
               name="OptimizedWildRelNet",
               **kwargs):
    """Constructs a OptimizedWildRelNet.

    The resulting `self.wildrelnet` pipeline is built in this order:
      1. Add a positional encoding.
      2. Apply a RelationalLayer, with the edge MLP as edge function and a sum
         over axis -2 as reduction.
      3. Sum again over axis -2.
      4. Apply the graph MLP, which ends in a single-unit Dense layer.
      5. Sum over the last (size-1) axis.

    Args:
      edge_mlp: List with number of latent nodes in different layers of the
        edge MLP.
      graph_mlp: List with number of latent nodes in different layers of the
        graph MLP. Must be non-empty when `dropout_in_last_graph_layer` is
        set, because its last entry sizes the dropout noise shape.
      dropout_in_last_graph_layer: Controls the dropout inserted after the
        last hidden layer of the graph MLP; a falsy value disables it. The
        layer is built as `tf.keras.layers.Dropout(1. - value)`, and Keras
        interprets that argument as the fraction of units to drop. This value
        therefore acts as the fraction of units to keep (a keep probability).
      name: String with the name of the model.
      **kwargs: Other keyword arguments passed to tf.keras.Model.
    """
    super(OptimizedWildRelNet, self).__init__(name=name, **kwargs)

    # Create the EdgeMLP.
    edge_layers = [
        tf.keras.layers.Dense(
            num_units,
            activation=get_activation(),
            kernel_initializer=get_kernel_initializer())
        for num_units in edge_mlp
    ]
    self.edge_layer = tf.keras.models.Sequential(edge_layers, "edge_mlp")

    # Create the GraphMLP.
    graph_layers = [
        tf.keras.layers.Dense(
            num_units,
            activation=get_activation(),
            kernel_initializer=get_kernel_initializer())
        for num_units in graph_mlp
    ]
    if dropout_in_last_graph_layer:
      # Keras' Dropout takes the *drop* rate, so `1. - value` treats the
      # configured value as a keep probability (kept as-is for compatibility
      # with existing configurations).
      # The leading 1s in noise_shape share one dropout mask across the first
      # two axes of the graph-MLP input (presumably batch and answer axes;
      # TODO confirm against the input layout).
      graph_layers.append(
          tf.keras.layers.Dropout(
              1. - dropout_in_last_graph_layer,
              noise_shape=[1, 1, graph_mlp[-1]]))
    # Final single-unit output layer; no activation, i.e. linear.
    graph_layers.append(
        tf.keras.layers.Dense(1, kernel_initializer=get_kernel_initializer()))
    self.graph_layer = tf.keras.models.Sequential(graph_layers, "graph_mlp")

    # Create the auxiliary layers.
    self.stacking_layer = relational_layers.StackAnswers()

    # Create the WildRelationNet.
    self.wildrelnet = tf.keras.models.Sequential([
        relational_layers.AddPositionalEncoding(),
        relational_layers.RelationalLayer(
            self.edge_layer,
            tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=-2))),
        tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=-2)),
        self.graph_layer,
        # The graph MLP ends in a 1-unit Dense layer, so this sum removes
        # that size-1 axis.
        tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=-1)),
    ], "wildrelnet")