Example 1
        def inner_apply(edge_embeddings, node_embeddings):
            first_layer_dim = mlp_vtoe_dims[0]
            additional_layer_dims = mlp_vtoe_dims[1:]

            if allow_non_adjacent and edge_embeddings is not None:
                num_separate_mlps = 1 + edge_embeddings.shape[-1]
            elif allow_non_adjacent:
                num_separate_mlps = 1
            elif edge_embeddings is not None:
                num_separate_mlps = edge_embeddings.shape[-1]
            else:
                raise ValueError(
                    "Either allow_non_adjacent should be True, or "
                    "edge_embeddings should be provided")

            node_embedding_dim = node_embeddings.shape[-1]

            # First layer: process each node embedding.
            weight_from_source = self.param(
                "l0_weight_from_source",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            weight_from_dest = self.param(
                "l0_weight_from_dest",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            bias = self.param("l0_bias",
                              shape=(num_separate_mlps, first_layer_dim),
                              initializer=initializers.zeros)
            from_source = jnp.einsum("sx,kxy->sky", node_embeddings,
                                     weight_from_source)
            from_dest = jnp.einsum("dx,kxy->dky", node_embeddings,
                                   weight_from_dest)
            activations = jax.nn.relu(from_source[:, None, :, :] +
                                      from_dest[None, :, :, :] +
                                      bias[None, None, :, :])

            # Additional layers: MLP for each edge type.
            for i, layer_dim in enumerate(additional_layer_dims):
                weight = self.param(f"l{i+1}_weight",
                                    shape=(num_separate_mlps,
                                           activations.shape[-1], layer_dim),
                                    initializer=initializers.xavier_normal())
                bias = self.param(f"l{i+1}_bias",
                                  shape=(num_separate_mlps, layer_dim),
                                  initializer=initializers.zeros)
                activations = jax.nn.relu(
                    jnp.einsum("sdkx,kxy->sdky", activations, weight) +
                    bias[None, None, :, :])

            # Sum over edge types and possibly over source nodes.
            if edge_embeddings is None:
                result = activations.squeeze(axis=2)
                if mask is not None:
                    result = jnp.where(mask[:, :, None], result,
                                       jnp.zeros_like(result))
                if message_passing:
                    result = jnp.sum(result, axis=0)
            else:
                if allow_non_adjacent:
                    if mask is None:
                        pairwise = jnp.ones(edge_embeddings.shape[:2])
                    else:
                        pairwise = mask
                    mlp_weights = jnp.concatenate([
                        edge_embeddings,
                        pairwise.astype("float")[:, :, None]
                    ], -1)
                else:
                    mlp_weights = edge_embeddings

                if message_passing:
                    result = jnp.einsum("sdky,sdk->dy", activations,
                                        mlp_weights)
                else:
                    result = jnp.einsum("sdky,sdk->sdy", activations,
                                        mlp_weights)

            return result
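For reference, a minimal shape sketch of the pairwise MLP above, using plain jnp arrays in place of the Flax parameters. The sizes and array names below are illustrative assumptions, not values taken from the original module:

import jax
import jax.numpy as jnp

num_nodes, edge_dim, node_dim, hidden = 5, 3, 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

node_embeddings = jax.random.normal(k1, (num_nodes, node_dim))
edge_embeddings = jax.random.normal(k2, (num_nodes, num_nodes, edge_dim))

# One MLP per edge type (the `k` axis); the first layer mixes each
# (source, dest) pair of node embeddings, as in `inner_apply` above.
w_src = jax.random.normal(k3, (edge_dim, node_dim, hidden))
w_dst = jax.random.normal(k4, (edge_dim, node_dim, hidden))
from_source = jnp.einsum("sx,kxy->sky", node_embeddings, w_src)
from_dest = jnp.einsum("dx,kxy->dky", node_embeddings, w_dst)
activations = jax.nn.relu(from_source[:, None, :, :] + from_dest[None, :, :, :])
# activations: <float32[source, dest, edge_type, hidden]>

# Weight each per-edge-type output by the corresponding edge embedding entry
# and sum over edge types and sources (the message_passing branch).
received = jnp.einsum("sdky,sdk->dy", activations, edge_embeddings)
print(received.shape)  # (5, 16)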
Example 2
    def apply(self,
              edge_embeddings,
              node_embeddings,
              message_dim=gin.REQUIRED,
              scale_by_num_nodes=True,
              scale_by_edge_embedding=False,
              use_efficient_conv=True,
              just_use_xavier=False,
              with_bias=False):
        """Apply the linear message passing layer.

        Args:
          edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]>
            dense edge embedding matrix, where zeros indicate no edge.
          node_embeddings: <float32[num_nodes, node_embedding_dim]> node
            embedding matrix.
          message_dim: Dimension of the desired messages.
          scale_by_num_nodes: Whether to scale down the message tensor
            initialization by sqrt(num_nodes), to correct for some nodes
            getting many messages.
          scale_by_edge_embedding: Whether to scale down the message tensor
            initialization by sqrt(edge_embedding_dim), to correct for edge
            embeddings having high magnitude (i.e. with learned edge
            embeddings).
          use_efficient_conv: Whether to directly lower the einsum into an XLA
            convolution, to ensure it is memory efficient.
          just_use_xavier: Whether to do standard Xavier initialization instead
            of scaling down the parameter based on the above scaling factors.
          with_bias: Whether to add a bias term (which depends on edge type but
            not on source content).

        Returns:
          <float32[num_nodes, message_dim]> containing the sum of received
          messages.
        """
        edge_embedding_dim = edge_embeddings.shape[-1]
        node_embedding_dim = node_embeddings.shape[-1]
        num_nodes = node_embeddings.shape[0]
        if just_use_xavier:
            message_passing_tensor = self.param(
                "message_passing_tensor",
                shape=(edge_embedding_dim, node_embedding_dim, message_dim),
                initializer=initializers.xavier_normal())
            if with_bias:
                edge_bias_tensor = self.param(
                    "edge_bias_tensor",
                    shape=(edge_embedding_dim, message_dim),
                    initializer=initializers.xavier_normal())
        else:
            variance_correction = node_embedding_dim
            if scale_by_num_nodes:
                variance_correction *= num_nodes
            if scale_by_edge_embedding:
                variance_correction *= edge_embedding_dim
            message_passing_tensor = self.param(
                "message_passing_tensor",
                shape=(edge_embedding_dim, node_embedding_dim, message_dim),
                initializer=initializers.normal()) / np.sqrt(
                    variance_correction)
            if with_bias:
                edge_bias_tensor = self.param(
                    "edge_bias_tensor",
                    shape=(edge_embedding_dim, message_dim),
                    initializer=initializers.normal()) / np.sqrt(
                        variance_correction / node_embedding_dim)

        if use_efficient_conv:
            # Carefully choose conv axes so that the node axis is the feature axis.
            # First, sources compute the messages to send.
            # eim,si->sem
            # 0CN,0OI->C0N
            messages = jax.lax.conv_general_dilated(message_passing_tensor,
                                                    node_embeddings[None],
                                                    window_strides=(1, ),
                                                    padding="VALID",
                                                    dimension_numbers=("0CN",
                                                                       "0OI",
                                                                       "C0N"))
            # Next, messages are sent across edges.
            # sem,sde->dm
            # C0N,IO0->0CN
            received = jax.lax.conv_general_dilated(
                messages,
                edge_embeddings,
                window_strides=(edge_embedding_dim, ),
                padding="VALID",
                dimension_numbers=("C0N", "IO0", "0CN")).squeeze(0)
        else:
            # Let JAX handle the einsum implementation.
            received = jnp.einsum("sde,si,eim->dm", edge_embeddings,
                                  node_embeddings, message_passing_tensor)

        if with_bias:
            received = (
                received +
                jnp.einsum("sde,em->dm", edge_embeddings, edge_bias_tensor))

        return received
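A minimal sketch of the contraction used when use_efficient_conv=False, with illustrative (assumed) sizes; the convolution branch lowers the same contraction into XLA convolutions:

import jax
import jax.numpy as jnp

num_nodes, edge_dim, node_dim, message_dim = 6, 4, 8, 10
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

edge_embeddings = jax.random.normal(k1, (num_nodes, num_nodes, edge_dim))
node_embeddings = jax.random.normal(k2, (num_nodes, node_dim))
# Scaled-down initialization, mimicking the default (non-Xavier) branch with
# scale_by_num_nodes=True.
message_passing_tensor = jax.random.normal(
    k3, (edge_dim, node_dim, message_dim)) / jnp.sqrt(node_dim * num_nodes)

# Each source node s produces an edge-type-specific message from its own
# embedding; edge_embeddings[s, d, e] routes it to destination d, and all
# incoming messages are summed per destination.
received = jnp.einsum("sde,si,eim->dm", edge_embeddings, node_embeddings,
                      message_passing_tensor)
print(received.shape)  # (6, 10)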
Example 3
        def inner_apply(edge_embeddings, node_embeddings):
            # Einsum letters:
            # n: querying node, attends to others.
            # m: queried node, is attended to.
            # h: attention head
            # d: node embedding dimension
            # e: edge embedding dimension
            # q: query/key dimension
            # v: value dimension
            edge_embedding_dim = edge_embeddings.shape[-1]
            node_embedding_dim = node_embeddings.shape[-1]

            nonlocal query_key_dim, value_dim

            if query_key_dim is None:
                if node_embedding_dim % heads != 0:
                    raise ValueError(
                        "No query_key_dim provided, but node embedding dim "
                        f"({node_embedding_dim}) was not divisible by head count "
                        f"({heads})")
                query_key_dim = node_embedding_dim // heads

            if value_dim is None:
                value_dim = query_key_dim

            # Compute queries.
            query_tensor = self.param("query_tensor",
                                      shape=(heads, node_embedding_dim,
                                             query_key_dim),
                                      initializer=initializers.xavier_normal())
            query = jnp.einsum("nd,hdq->hnq", node_embeddings, query_tensor)

            # Dot-product the queries with the node and edge keys.
            node_key_tensor = self.param(
                "node_key_tensor",
                shape=(heads, node_embedding_dim, query_key_dim),
                initializer=initializers.xavier_normal())
            dot_product_logits = jnp.einsum("md,hdq,hnq->hnm", node_embeddings,
                                            node_key_tensor, query)

            if like_great:
                # Edges contribute based on key sums, as in Hellendoorn et al.
                # edge_biases: <float32[num_nodes, num_nodes]>, computed as `w^T e + b`
                edge_biases = flax.nn.Dense(edge_embeddings,
                                            features=1).squeeze(-1)
                # Einsum sums keys over `q` dim (equivalent to broadcasting out biases).
                edge_logits = jnp.einsum("md,hdq,nm->hnm", node_embeddings,
                                         node_key_tensor, edge_biases)
            else:
                # Queries attend to edge keys, as in Wang et al.
                edge_key_tensor = self.param(
                    "edge_key_tensor",
                    shape=(edge_embedding_dim, query_key_dim),
                    initializer=initializers.xavier_normal())
                edge_logits = jnp.einsum("nme,eq,hnq->hnm", edge_embeddings,
                                         edge_key_tensor, query)

            # Combine, normalize, and possibly mask.
            attention_logits = ((dot_product_logits + edge_logits) /
                                jnp.sqrt(query_key_dim))
            if mask is not None:
                attention_logits = attention_logits + jnp.log(mask)[None, :, :]

            attention_weights = jax.nn.softmax(attention_logits, axis=2)

            # Wrap attention weights with a Flax module so we can extract intermediate
            # outputs.
            attention_weights = jax_util.flax_tag(attention_weights,
                                                  name="attention_weights")

            # Compute values.
            node_value_tensor = self.param(
                "node_value_tensor",
                shape=(heads, node_embedding_dim, value_dim),
                initializer=initializers.xavier_normal())
            attention_node_value = jnp.einsum(
                "hnm,hmv->hnv", attention_weights,
                jnp.einsum("md,hdv->hmv", node_embeddings, node_value_tensor))

            if like_great:
                # Only nodes contribute to values, as in Hellendoorn et al.
                attention_value = attention_node_value
            else:
                # Edges also contribute to values, as in Wang et al.
                edge_value_tensor = self.param(
                    "edge_value_tensor",
                    shape=(edge_embedding_dim, value_dim),
                    initializer=initializers.xavier_normal())
                attention_edge_value = jnp.einsum("hnm,nme,ev->hnv",
                                                  attention_weights,
                                                  edge_embeddings,
                                                  edge_value_tensor)

                attention_value = attention_node_value + attention_edge_value

            # Project them back.
            output_tensor = self.param(
                "output_tensor",
                shape=(heads, value_dim, out_dim),
                initializer=initializers.xavier_normal())
            output = jnp.einsum("hnv,hvo->no", attention_value, output_tensor)
            return output
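A minimal sketch of the non-GREAT (Wang et al.) attention path above, using plain jnp arrays and assumed sizes, and omitting the mask, the edge value contribution, and the output projection:

import jax
import jax.numpy as jnp

num_nodes, node_dim, edge_dim, heads, qk_dim, v_dim = 5, 8, 3, 2, 4, 4
k = jax.random.split(jax.random.PRNGKey(0), 6)

node_embeddings = jax.random.normal(k[0], (num_nodes, node_dim))
edge_embeddings = jax.random.normal(k[1], (num_nodes, num_nodes, edge_dim))
query_tensor = jax.random.normal(k[2], (heads, node_dim, qk_dim))
node_key_tensor = jax.random.normal(k[3], (heads, node_dim, qk_dim))
edge_key_tensor = jax.random.normal(k[4], (edge_dim, qk_dim))
node_value_tensor = jax.random.normal(k[5], (heads, node_dim, v_dim))

# Queries attend to node keys plus per-edge keys, then weight node values.
query = jnp.einsum("nd,hdq->hnq", node_embeddings, query_tensor)
dot_product_logits = jnp.einsum("md,hdq,hnq->hnm", node_embeddings,
                                node_key_tensor, query)
edge_logits = jnp.einsum("nme,eq,hnq->hnm", edge_embeddings, edge_key_tensor,
                         query)
attention_weights = jax.nn.softmax(
    (dot_product_logits + edge_logits) / jnp.sqrt(qk_dim), axis=2)
values = jnp.einsum("md,hdv->hmv", node_embeddings, node_value_tensor)
attention_value = jnp.einsum("hnm,hmv->hnv", attention_weights, values)
print(attention_value.shape)  # (2, 5, 4)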
Example 4
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = ParamInit((out_size, in_size), init.xavier_normal())
        self.bias = ParamInit((out_size,), init.normal())
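A functional sketch of what this layer computes once initialized; ParamInit and the surrounding module API are not shown, and the sizes below simply mirror the parameter shapes declared in __init__:

import jax
import jax.numpy as jnp

in_size, out_size = 8, 4
k_w, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 3)

# Xavier-normal weight of shape (out_size, in_size) and a unit-normal bias,
# matching the ParamInit shapes above.
weight = jax.random.normal(k_w, (out_size, in_size)) * jnp.sqrt(
    2.0 / (in_size + out_size))
bias = jax.random.normal(k_b, (out_size,))

x = jax.random.normal(k_x, (in_size,))
y = weight @ x + bias  # standard dense layer: y = Wx + b
print(y.shape)  # (4,)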