def run_steps(node_embeddings):
  """Runs message-passing steps for the configured model type.

  Dispatches on the enclosing scope's `model_type` to one of the supported
  propagation models (GGNN, transformer, or NRI encoder).

  Args:
    node_embeddings: Initial node embeddings to propagate.

  Returns:
    Final node embeddings after propagation.

  Raises:
    ValueError: If `model_type` is not one of the supported values.
  """
  if model_type == "ggnn":
    final_embeddings = ggnn_steps(node_embeddings, edge_embeddings)
  elif model_type == "transformer":
    # NOTE(review): this branch mixes `graph_metadata.num_nodes` and
    # `example.graph_metadata.num_nodes`; presumably they refer to the same
    # value — confirm against the enclosing scope.
    neighbor_mask = graph_layers.edge_mask(
        edges=example.edges,
        num_nodes=graph_metadata.num_nodes,
        num_edge_types=num_edge_types,
        forward_edge_type_indices=forward_edge_type_indices,
        reverse_edge_type_indices=reverse_edge_type_indices)
    # Allow nodes to attend to themselves.
    neighbor_mask = jnp.maximum(neighbor_mask,
                                jnp.eye(graph_metadata.num_nodes))
    final_embeddings = transformer_steps(
        node_embeddings,
        edge_embeddings,
        neighbor_mask,
        num_real_nodes_per_graph=example.graph_metadata.num_nodes)
  elif model_type == "nri_encoder":
    final_embeddings = nri_steps(
        node_embeddings,
        edge_embeddings,
        num_real_nodes_per_graph=example.graph_metadata.num_nodes)
  else:
    # Previously an unknown model_type fell through and raised an opaque
    # UnboundLocalError on `final_embeddings`; fail fast with a clear
    # message instead.
    raise ValueError(f"Unknown model_type: {model_type!r}")
  return final_embeddings
def test_edge_nonzeros(self, which):
  """Checks that edge encodings are nonzero exactly at the expected edges."""
  edges, fwd_indices, rev_indices = self._setup_edges()

  if which == "edge_mask":
    nonzeros = graph_layers.edge_mask(
        edges,
        num_nodes=4,
        num_edge_types=3,
        forward_edge_type_indices=fwd_indices,
        reverse_edge_type_indices=rev_indices)
  elif which == "LearnableEdgeEmbeddings":
    embeddings, _ = graph_layers.LearnableEdgeEmbeddings.init(
        jax.random.PRNGKey(0),
        edges=edges,
        num_nodes=4,
        num_edge_types=3,
        forward_edge_type_indices=fwd_indices,
        reverse_edge_type_indices=rev_indices,
        embedding_dim=EDGE_EMBEDDING_DIM)
    # An edge exists wherever the embedding vector is not identically zero.
    nonzeros = jnp.any(embeddings != 0, axis=-1).astype(jnp.float32)

  # Expected nonzero (source, dest) pairs for the setup edges.
  expected = np.zeros([4, 4], np.float32)
  for source, dest in [(2, 0), (0, 1), (1, 2), (0, 2), (1, 0), (2, 1)]:
    expected[source, dest] = 1

  np.testing.assert_allclose(nonzeros, expected)
def apply(
    self,
    graph_context,
    node_embeddings,
    edge_embeddings,
    forward_edge_types=gin.REQUIRED,
    reverse_edge_types=gin.REQUIRED,
    walk_length_log2=gin.REQUIRED,
):
  """Modifies edge embeddings using a uniform random walk.

  Uses an efficient repeated-squaring technique to compute the absorbing
  distribution.

  Args:
    graph_context: Input graph for this example.
    node_embeddings: Current node embeddings, as <float32[num_nodes,
      node_embedding_dim]>
    edge_embeddings: Current edge embeddings, as <float32[num_nodes,
      num_nodes, edge_embedding_dim]>
    forward_edge_types: Edge types to use in the forward direction. As a
      list of lists to allow configuring groups of edges in config files;
      this will be flattened before use.
    reverse_edge_types: Edge types to use in the reverse direction. Note
      that reversed edge types are given a separate embedding from forward
      edge types; undirected edges should be represented by adding two
      edges in opposite directions and then only using
      `forward_edge_types`. Also a list of lists, as above.
    walk_length_log2: Base-2 logarithm of maximum walk length; this
      determines how many times we will square the transition matrix
      (doubling the walk length).

  Returns:
    New node and edge embeddings. Node embeddings will not be modified.
    Edge embeddings will be modified by adding a new edge type (either
    embedded or concatenated based on graph_context.edges_are_embedded).
  """
  num_nodes = node_embeddings.shape[0]
  # Flatten the configured groups of edge types into flat index lists.
  # pylint: disable=g-complex-comprehension
  forward_edge_type_indices = [
      graph_context.edge_types_to_indices[type_str]
      for group in forward_edge_types
      for type_str in group
  ]
  reverse_edge_type_indices = [
      graph_context.edge_types_to_indices[type_str]
      for group in reverse_edge_types
      for type_str in group
  ]
  # pylint: enable=g-complex-comprehension
  # Binary adjacency mask over the selected edge types.
  adjacency = graph_layers.edge_mask(
      edges=graph_context.bundle.edges,
      num_nodes=num_nodes,
      num_edge_types=len(graph_context.edge_types_to_indices),
      forward_edge_type_indices=forward_edge_type_indices,
      reverse_edge_type_indices=reverse_edge_type_indices)
  # Add self-loops so the walk can remain at a node.
  adjacency = jnp.maximum(adjacency, jnp.eye(num_nodes))
  # Learnable per-step absorbing probability, initialized to 0.1 by
  # storing its logit.
  absorbing_logit = self.param(
      "absorbing_logit",
      shape=(),
      initializer=lambda *_: jax.scipy.special.logit(0.1))
  absorbing_prob = jax.nn.sigmoid(absorbing_logit)
  # sigmoid(-x) == 1 - sigmoid(x), i.e. the probability of continuing.
  nonabsorbing_prob = jax.nn.sigmoid(-absorbing_logit)
  # Row-normalized uniform transition matrix, scaled by the probability of
  # not being absorbed at each step.
  walk_matrix = nonabsorbing_prob * adjacency / jnp.sum(
      adjacency, axis=1, keepdims=True)

  # Repeated squaring: each scan step doubles the walk length covered by
  # the partial sum of powers.
  # A, I
  # A^2, A + I
  # (A^2)^2 = A^4, (A + I)A^2 + (A + I) = A^3 + A^2 + A + I
  # ...
  def step(state, _):
    nth_power, nth_partial_sum = state
    return (nth_power @ nth_power,
            nth_power @ nth_partial_sum + nth_partial_sum), None

  (_, partial_sum), _ = jax.lax.scan(
      step, (walk_matrix, jnp.eye(num_nodes)), None,
      length=walk_length_log2)
  # Expected visit counts under absorption, then mapped back to logit
  # space and rescaled with learned parameters before squashing to [0, 1].
  approx_visits = absorbing_prob * partial_sum
  logits = model_util.safe_logit(approx_visits)
  logits = model_util.ScaleAndShift(logits)
  edge_weights = jax.nn.sigmoid(logits)
  return (node_embeddings,
          _add_edges(edge_embeddings, edge_weights[:, :, None],
                     graph_context.edges_are_embedded))