def inner_apply(edge_embeddings, node_embeddings):
  first_layer_dim = mlp_vtoe_dims[0]
  additional_layer_dims = mlp_vtoe_dims[1:]

  if allow_non_adjacent and edge_embeddings is not None:
    num_separate_mlps = 1 + edge_embeddings.shape[-1]
  elif allow_non_adjacent:
    num_separate_mlps = 1
  elif edge_embeddings is not None:
    num_separate_mlps = edge_embeddings.shape[-1]
  else:
    raise ValueError("Either allow_non_adjacent should be True, or "
                     "edge_embeddings should be provided")

  node_embedding_dim = node_embeddings.shape[-1]

  # First layer: process each node embedding.
  weight_from_source = self.param(
      "l0_weight_from_source",
      shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
      initializer=initializers.xavier_normal())
  weight_from_dest = self.param(
      "l0_weight_from_dest",
      shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
      initializer=initializers.xavier_normal())
  bias = self.param(
      "l0_bias",
      shape=(num_separate_mlps, first_layer_dim),
      initializer=initializers.zeros)
  from_source = jnp.einsum("sx,kxy->sky", node_embeddings, weight_from_source)
  from_dest = jnp.einsum("dx,kxy->dky", node_embeddings, weight_from_dest)
  activations = jax.nn.relu(from_source[:, None, :, :] +
                            from_dest[None, :, :, :] +
                            bias[None, None, :, :])

  # Additional layers: MLP for each edge type.
  for i, layer_dim in enumerate(additional_layer_dims):
    weight = self.param(
        f"l{i+1}_weight",
        shape=(num_separate_mlps, activations.shape[-1], layer_dim),
        initializer=initializers.xavier_normal())
    bias = self.param(
        f"l{i+1}_bias",
        shape=(num_separate_mlps, layer_dim),
        initializer=initializers.zeros)
    activations = jax.nn.relu(
        jnp.einsum("sdkx,kxy->sdky", activations, weight) +
        bias[None, None, :, :])

  # Sum over edge types and possibly over source nodes.
  if edge_embeddings is None:
    result = activations.squeeze(axis=2)
    if mask is not None:
      result = jnp.where(mask[:, :, None], result, jnp.zeros_like(result))
    if message_passing:
      result = jnp.sum(result, axis=0)
  else:
    if allow_non_adjacent:
      if mask is None:
        # All pairs are allowed to interact when no mask is given.
        pairwise = jnp.ones(edge_embeddings.shape[:2])
      else:
        pairwise = mask
      mlp_weights = jnp.concatenate(
          [edge_embeddings, pairwise.astype("float")[:, :, None]], -1)
    else:
      mlp_weights = edge_embeddings

    if message_passing:
      result = jnp.einsum("sdky,sdk->dy", activations, mlp_weights)
    else:
      result = jnp.einsum("sdky,sdk->sdy", activations, mlp_weights)

  return result
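

# --- Standalone sketch (not part of the layer above) ---
# A minimal illustration of what inner_apply computes when message_passing is
# True and allow_non_adjacent is False: one small MLP per edge-embedding
# channel is applied to every (source, dest) node pair, each MLP's output is
# weighted by that pair's edge-embedding channel, and the result is summed
# into the destination node. The random weights below are stand-ins for the
# learned self.param values, and the function name is hypothetical.
import jax
import jax.numpy as jnp


def pairwise_mlp_sketch(edge_embeddings, node_embeddings, hidden_dim, rng):
  num_mlps = edge_embeddings.shape[-1]
  node_dim = node_embeddings.shape[-1]
  k1, k2 = jax.random.split(rng)
  w_src = jax.random.normal(k1, (num_mlps, node_dim, hidden_dim))
  w_dst = jax.random.normal(k2, (num_mlps, node_dim, hidden_dim))
  # One hidden activation per (source s, dest d, mlp k) triple.
  from_source = jnp.einsum("sx,kxy->sky", node_embeddings, w_src)
  from_dest = jnp.einsum("dx,kxy->dky", node_embeddings, w_dst)
  activations = jax.nn.relu(
      from_source[:, None, :, :] + from_dest[None, :, :, :])
  # Weight each MLP's output by its edge-embedding channel and sum over
  # sources and channels, mirroring the message_passing=True branch above.
  return jnp.einsum("sdky,sdk->dy", activations, edge_embeddings)


# Example usage (5 nodes, 3 edge channels, 8-dim node embeddings):
#   out = pairwise_mlp_sketch(
#       jnp.ones((5, 5, 3)), jnp.ones((5, 8)), hidden_dim=16,
#       rng=jax.random.PRNGKey(0))  # -> shape (5, 16)
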
def apply(self,
          edge_embeddings,
          node_embeddings,
          message_dim=gin.REQUIRED,
          scale_by_num_nodes=True,
          scale_by_edge_embedding=False,
          use_efficient_conv=True,
          just_use_xavier=False,
          with_bias=False):
  """Apply the linear message passing layer.

  Args:
    edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]> dense
      edge embedding matrix, where zeros indicate no edge.
    node_embeddings: <float32[num_nodes, node_embedding_dim]> node embedding
      matrix.
    message_dim: Dimension of the desired messages.
    scale_by_num_nodes: Whether to scale down the message tensor
      initialization by sqrt(num_nodes), to correct for some nodes getting
      many messages.
    scale_by_edge_embedding: Whether to scale down the message tensor
      initialization by sqrt(edge_embedding_dim), to correct for edge
      embeddings having high magnitude (i.e. with learned edge embeddings).
    use_efficient_conv: Whether to directly lower the einsum into an XLA
      convolution, to ensure it is memory efficient.
    just_use_xavier: Whether to do standard Xavier initialization instead of
      scaling down the parameter based on the above scaling factors.
    with_bias: Whether to add a bias term (which depends on edge type but not
      source content).

  Returns:
    <float32[num_nodes, message_dim]> containing the sum of received messages.
  """
  edge_embedding_dim = edge_embeddings.shape[-1]
  node_embedding_dim = node_embeddings.shape[-1]
  num_nodes = node_embeddings.shape[0]

  if just_use_xavier:
    message_passing_tensor = self.param(
        "message_passing_tensor",
        shape=(edge_embedding_dim, node_embedding_dim, message_dim),
        initializer=initializers.xavier_normal())
    if with_bias:
      edge_bias_tensor = self.param(
          "edge_bias_tensor",
          shape=(edge_embedding_dim, message_dim),
          initializer=initializers.xavier_normal())
  else:
    variance_correction = node_embedding_dim
    if scale_by_num_nodes:
      variance_correction *= num_nodes
    if scale_by_edge_embedding:
      variance_correction *= edge_embedding_dim
    message_passing_tensor = self.param(
        "message_passing_tensor",
        shape=(edge_embedding_dim, node_embedding_dim, message_dim),
        initializer=initializers.normal()) / np.sqrt(variance_correction)
    if with_bias:
      edge_bias_tensor = self.param(
          "edge_bias_tensor",
          shape=(edge_embedding_dim, message_dim),
          initializer=initializers.normal()) / np.sqrt(
              variance_correction / node_embedding_dim)

  if use_efficient_conv:
    # Carefully choose conv axes so that the node axis is the feature axis.
    # First, sources compute the messages to send.
    #   eim,si->sem
    #   0CN,0OI->C0N
    messages = jax.lax.conv_general_dilated(
        message_passing_tensor,
        node_embeddings[None],
        window_strides=(1,),
        padding="VALID",
        dimension_numbers=("0CN", "0OI", "C0N"))
    # Next, messages are sent across edges.
    #   sem,sde->dm
    #   C0N,IO0->0CN
    received = jax.lax.conv_general_dilated(
        messages,
        edge_embeddings,
        window_strides=(edge_embedding_dim,),
        padding="VALID",
        dimension_numbers=("C0N", "IO0", "0CN")).squeeze(0)
  else:
    # Let JAX handle the einsum implementation.
    received = jnp.einsum("sde,si,eim->dm", edge_embeddings, node_embeddings,
                          message_passing_tensor)

  if with_bias:
    received = (
        received + jnp.einsum("sde,em->dm", edge_embeddings, edge_bias_tensor))

  return received
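

# --- Standalone sketch (not part of the layer above) ---
# A loop-based reference for the einsum "sde,si,eim->dm" used when
# use_efficient_conv is False: for every edge (s, d), the edge embedding mixes
# the slices of message_passing_tensor into a single linear map, the source
# node embedding is pushed through that map, and the destination sums what it
# receives. The helper name and the random test inputs are illustrative only.
import jax
import jax.numpy as jnp


def linear_message_passing_reference(edge_embeddings, node_embeddings,
                                     message_passing_tensor):
  num_nodes = node_embeddings.shape[0]
  message_dim = message_passing_tensor.shape[-1]
  received = jnp.zeros((num_nodes, message_dim))
  for s in range(num_nodes):
    for d in range(num_nodes):
      # Per-edge linear map: contract the edge embedding into the tensor.
      weight = jnp.einsum("e,eim->im", edge_embeddings[s, d],
                          message_passing_tensor)
      received = received.at[d].add(node_embeddings[s] @ weight)
  return received


# Example check against the einsum path:
#   k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
#   edges = jax.random.normal(k1, (4, 4, 3))
#   nodes = jax.random.normal(k2, (4, 8))
#   tensor = jax.random.normal(k3, (3, 8, 5))
#   fast = jnp.einsum("sde,si,eim->dm", edges, nodes, tensor)
#   assert jnp.allclose(
#       linear_message_passing_reference(edges, nodes, tensor), fast,
#       atol=1e-4)
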
def inner_apply(edge_embeddings, node_embeddings):
  # Einsum letters:
  #   n: querying node, attends to others.
  #   m: queried node, is attended to.
  #   h: attention head
  #   d: node embedding dimension
  #   e: edge embedding dimension
  #   q: query/key dimension
  #   v: value dimension
  edge_embedding_dim = edge_embeddings.shape[-1]
  node_embedding_dim = node_embeddings.shape[-1]

  nonlocal query_key_dim, value_dim
  if query_key_dim is None:
    if node_embedding_dim % heads != 0:
      raise ValueError(
          "No query_key_dim provided, but node embedding dim "
          f"({node_embedding_dim}) was not divisible by head count "
          f"({heads})")
    query_key_dim = node_embedding_dim // heads

  if value_dim is None:
    value_dim = query_key_dim

  # Compute queries.
  query_tensor = self.param(
      "query_tensor",
      shape=(heads, node_embedding_dim, query_key_dim),
      initializer=initializers.xavier_normal())
  query = jnp.einsum("nd,hdq->hnq", node_embeddings, query_tensor)

  # Dot-product the queries with the node and edge keys.
  node_key_tensor = self.param(
      "node_key_tensor",
      shape=(heads, node_embedding_dim, query_key_dim),
      initializer=initializers.xavier_normal())
  dot_product_logits = jnp.einsum("md,hdq,hnq->hnm", node_embeddings,
                                  node_key_tensor, query)

  if like_great:
    # Edges contribute based on key sums, as in Hellendoorn et al.
    # edge_biases: <float32[num_nodes, num_nodes]>, computed as `w^T e + b`
    edge_biases = flax.nn.Dense(edge_embeddings, features=1).squeeze(-1)
    # Einsum sums keys over `q` dim (equivalent to broadcasting out biases).
    edge_logits = jnp.einsum("md,hdq,nm->hnm", node_embeddings,
                             node_key_tensor, edge_biases)
  else:
    # Queries attend to edge keys, as in Wang et al.
    edge_key_tensor = self.param(
        "edge_key_tensor",
        shape=(edge_embedding_dim, query_key_dim),
        initializer=initializers.xavier_normal())
    edge_logits = jnp.einsum("nme,eq,hnq->hnm", edge_embeddings,
                             edge_key_tensor, query)

  # Combine, normalize, and possibly mask.
  attention_logits = ((dot_product_logits + edge_logits) /
                      jnp.sqrt(query_key_dim))
  if mask is not None:
    attention_logits = attention_logits + jnp.log(mask)[None, :, :]

  attention_weights = jax.nn.softmax(attention_logits, axis=2)
  # Wrap attention weights with a Flax module so we can extract intermediate
  # outputs.
  attention_weights = jax_util.flax_tag(
      attention_weights, name="attention_weights")

  # Compute values.
  node_value_tensor = self.param(
      "node_value_tensor",
      shape=(heads, node_embedding_dim, value_dim),
      initializer=initializers.xavier_normal())
  attention_node_value = jnp.einsum(
      "hnm,hmv->hnv", attention_weights,
      jnp.einsum("md,hdv->hmv", node_embeddings, node_value_tensor))

  if like_great:
    # Only nodes contribute to values, as in Hellendoorn et al.
    attention_value = attention_node_value
  else:
    # Edges also contribute to values, as in Wang et al.
    edge_value_tensor = self.param(
        "edge_value_tensor",
        shape=(edge_embedding_dim, value_dim),
        initializer=initializers.xavier_normal())
    attention_edge_value = jnp.einsum("hnm,nme,ev->hnv", attention_weights,
                                      edge_embeddings, edge_value_tensor)
    attention_value = attention_node_value + attention_edge_value

  # Project them back.
  output_tensor = self.param(
      "output_tensor",
      shape=(heads, value_dim, out_dim),
      initializer=initializers.xavier_normal())
  output = jnp.einsum("hnv,hvo->no", attention_value, output_tensor)
  return output
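

# --- Standalone sketch (not part of the layer above) ---
# Illustrates the masking trick used in inner_apply: adding jnp.log(mask) to
# the attention logits drives masked-out pairs to -inf, so they get exactly
# zero weight after the softmax over axis 2 (the queried-node axis), while
# each remaining row still normalizes to one. The logits and mask below are
# small hand-written stand-ins, not values from the module above.
import jax
import jax.numpy as jnp

example_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 3))  # [h,n,m]
example_mask = jnp.array([[1., 1., 0.],
                          [1., 1., 1.],
                          [0., 1., 1.]])  # [n, m]; 0 = pair may not attend
masked_logits = example_logits + jnp.log(example_mask)[None, :, :]
example_weights = jax.nn.softmax(masked_logits, axis=2)
# Masked pairs receive zero attention; unmasked rows still sum to one.
assert jnp.all(example_weights[:, example_mask == 0.] == 0.)
assert jnp.allclose(example_weights.sum(axis=2), 1.0)
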
def __init__(self, in_size, out_size):
  super().__init__()
  self.weight = ParamInit((out_size, in_size), init.xavier_normal())
  self.bias = ParamInit((out_size,), init.normal())