def sparse_gnn_edge_mlp_layer(node_embeddings: tf.Tensor,
                              adjacency_lists: List[tf.Tensor],
                              type_to_num_incoming_edges: tf.Tensor,
                              state_dim: Optional[int],
                              num_timesteps: int = 1,
                              activation_function: Optional[str] = "ReLU",
                              message_aggregation_function: str = "sum",
                              normalize_by_num_incoming: bool = False,
                              use_target_state_as_input: bool = True,
                              num_edge_hidden_layers: int = 1) -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing using one MLP per edge type.

    Given node states h^t_v and per-edge-type adjacency lists A_\ell, each
    timestep computes

        h^{t+1}_v := \sum_\ell \sum_{(u, v) \in A_\ell}
                         \sigma(1/c_{v,\ell} * MLP_\ell(h^t_u || h^t_v))

    where the scaling by 1/c_{v,\ell} is only applied if
    normalize_by_num_incoming is set, and the concatenation with h^t_v is only
    performed if use_target_state_as_input is set. The outer sum is whatever
    aggregation message_aggregation_function selects.
    The learnable parameters are those of the per-edge-type MLPs.

    Abbreviations used in shape comments:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Arguments:
        node_embeddings: tensor of shape [V, D] holding the initial node states.
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        type_to_num_incoming_edges: tensor of shape [L, V]; entry [l, v] is
            used as the normalisation constant for messages of type l sent to
            node v.
        state_dim: Output size of the edge MLPs. If None, the second dimension
            of node_embeddings is used.
            NOTE(review): in that case the value is a dynamic tf.shape tensor,
            not a Python int - confirm MLP accepts that.
        num_timesteps: Number of times the message passing step is repeated
            (the same MLPs are reused in every step).
        activation_function: Name handed to get_activation; the result is
            applied to the edge MLP outputs before aggregation.
        message_aggregation_function: Name handed to get_aggregation_function.
        normalize_by_num_incoming: If True, scale each message by
            1 / (number of incoming edges of that type at the target).
        use_target_state_as_input: If True the edge MLP sees source and target
            state concatenated, otherwise only the source state.
        num_edge_hidden_layers: Number of hidden layers of each edge MLP.

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, state_dim] when num_timesteps >= 1).
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Set up everything that is shared between timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)

    edge_mlps = []  # One message-computing MLP per edge type
    targets_per_type = []  # Target node ids of the edges of each type
    for edge_type_idx, edges in enumerate(adjacency_lists):
        # The hidden activation of the edge MLP is fixed to ELU here,
        # independent of `activation_function`.
        edge_mlp = MLP(out_size=state_dim,
                       hidden_layers=num_edge_hidden_layers,
                       activation_fun=tf.nn.elu,
                       name="Edge_%i_MLP" % edge_type_idx)
        edge_mlps.append(edge_mlp)
        targets_per_type.append(edges[:, 1])

    # Targets of all messages, in the order the messages are concatenated below.
    message_targets = tf.concat(targets_per_type, axis=0)  # Shape [M]

    node_states = node_embeddings
    for _ in range(num_timesteps):
        per_type_messages = []  # tensors of shape [E, D], one per edge type

        for edge_type_idx, edges in enumerate(adjacency_lists):
            source_ids = edges[:, 0]
            target_ids = edges[:, 1]

            source_states = tf.nn.embedding_lookup(
                params=node_states, ids=source_ids)  # Shape [E, D]
            if use_target_state_as_input:
                target_states = tf.nn.embedding_lookup(
                    params=node_states, ids=target_ids)  # Shape [E, D]
                mlp_inputs = tf.concat(
                    [source_states, target_states], axis=1)  # Shape [E, 2*D]
            else:
                mlp_inputs = source_states

            type_messages = edge_mlps[edge_type_idx](mlp_inputs)  # Shape [E, D]

            if normalize_by_num_incoming:
                incoming_counts = tf.nn.embedding_lookup(
                    params=type_to_num_incoming_edges[edge_type_idx, :],
                    ids=target_ids)  # Shape [E]
                # SMALL_NUMBER guards against division by zero.
                scale = tf.expand_dims(
                    1.0 / (incoming_counts + SMALL_NUMBER), axis=-1)  # Shape [E, 1]
                type_messages = scale * type_messages

            per_type_messages.append(type_messages)

        stacked_messages = tf.concat(per_type_messages, axis=0)  # Shape [M, D]
        # The nonlinearity is applied per message, i.e. before aggregation.
        stacked_messages = activation_fn(stacked_messages)  # Shape [M, D]

        node_states = message_aggregation_fn(
            data=stacked_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, D]

    return node_states
def sparse_rgcn_layer(
        node_embeddings: tf.Tensor,
        adjacency_lists: List[tf.Tensor],
        type_to_num_incoming_edges: tf.Tensor,
        state_dim: Optional[int],
        num_timesteps: int = 1,
        activation_function: Optional[str] = "tanh",
        message_aggregation_function: str = "sum",
        normalize_by_num_incoming: bool = True,
        use_both_source_and_target: bool = False,
) -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing in the style of R-GCN.

    This follows the R-GCN model (Schlichtkrull et al.,
    https://arxiv.org/pdf/1703.06103.pdf) for the case of few relations / edge
    types, i.e. without the dimensionality-reduction tricks from section 2.2
    of that paper.

    Given node states h^t_v and per-edge-type adjacency lists A_\ell, each
    timestep computes

        h^{t+1}_v := \sigma(\sum_\ell \sum_{(u, v) \in A_\ell}
                                1/c_{v,\ell} * (W_\ell * h^t_u))

    where the 1/c_{v,\ell} factor is only applied if normalize_by_num_incoming
    is set. If use_both_source_and_target is set, W_\ell is applied to the
    concatenation h^t_u || h^t_v instead of h^t_u alone.
    The learnable parameters are the bias-free linear maps W_\ell.

    Abbreviations used in shape comments:
    * V: number of nodes
    * H: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Arguments:
        node_embeddings: tensor of shape [V, H] holding the initial node states.
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        type_to_num_incoming_edges: tensor of shape [L, V]; entry [l, v] is
            used as the normalisation constant c_{v,l}.
        state_dim: Number of units of the per-edge-type dense layers. If None,
            the second dimension of node_embeddings is used.
            NOTE(review): in that case the value is a dynamic tf.shape tensor,
            not a Python int - confirm tf.keras.layers.Dense accepts that.
        num_timesteps: Number of times the message passing step is repeated
            (the same layers are reused in every step).
        activation_function: Name handed to get_activation; the result is
            applied to the aggregated messages.
        message_aggregation_function: Name handed to get_aggregation_function.
        normalize_by_num_incoming: If True, scale each message by
            1 / (number of incoming edges of that type at the target).
        use_both_source_and_target: If True, messages are computed from the
            concatenated source and target states rather than from the source
            state only.

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, state_dim] when num_timesteps >= 1).
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Set up everything that is shared between timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)

    message_layers = []  # One linear message transformation per edge type
    targets_per_type = []  # Target node ids of the edges of each type
    for edge_type_idx, edges in enumerate(adjacency_lists):
        message_layers.append(
            tf.keras.layers.Dense(units=state_dim,
                                  use_bias=False,
                                  activation=None,
                                  name="Edge_%i_Weight" % edge_type_idx))
        targets_per_type.append(edges[:, 1])

    # Targets of all messages, in the order the messages are concatenated below.
    message_targets = tf.concat(targets_per_type, axis=0)  # Shape [M]

    node_states = node_embeddings
    for _ in range(num_timesteps):
        per_type_messages = []  # tensors of shape [E, H], one per edge type

        for edge_type_idx, edges in enumerate(adjacency_lists):
            source_ids = edges[:, 0]
            target_ids = edges[:, 1]
            source_states = tf.nn.embedding_lookup(
                params=node_states, ids=source_ids)  # Shape [E, H]

            if use_both_source_and_target:
                target_states = tf.nn.embedding_lookup(
                    params=node_states, ids=target_ids)  # Shape [E, H]
                layer_inputs = tf.concat(
                    [source_states, target_states], axis=-1)  # Shape [E, 2H]
            else:
                layer_inputs = source_states

            type_messages = message_layers[edge_type_idx](layer_inputs)  # Shape [E, H]

            if normalize_by_num_incoming:
                incoming_counts = tf.nn.embedding_lookup(
                    params=type_to_num_incoming_edges[edge_type_idx, :],
                    ids=target_ids)  # Shape [E]
                # SMALL_NUMBER guards against division by zero.
                scale = tf.expand_dims(
                    1.0 / (incoming_counts + SMALL_NUMBER), axis=-1)  # Shape [E, 1]
                type_messages = scale * type_messages

            per_type_messages.append(type_messages)

        stacked_messages = tf.concat(per_type_messages, axis=0)  # Shape [M, H]
        aggregated = message_aggregation_fn(
            data=stacked_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, H]

        # Unlike the edge-MLP variants, the nonlinearity follows aggregation.
        node_states = activation_fn(aggregated)  # Shape [V, H]

    return node_states
def sparse_ggnn_layer(node_embeddings: tf.Tensor,
                      adjacency_lists: List[tf.Tensor],
                      state_dim: Optional[int],
                      num_timesteps: int = 1,
                      gated_unit_type: str = "gru",
                      activation_function: str = "tanh",
                      message_aggregation_function: str = "sum") -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing and gated units on the nodes.

    Given node states h^t_v and per-edge-type adjacency lists A_\ell, each
    timestep computes

        h^{t+1}_v := Cell(h^t_v, \sum_\ell \sum_{(u, v) \in A_\ell} W_\ell * h^t_u)

    The learnable parameters are the recurrent Cell and the bias-free linear
    maps W_\ell.

    NOTE(review): this file contains a second definition of
    `sparse_ggnn_layer` further down; being defined later, it replaces this one
    when the module is imported. One of the two should probably be removed.

    Abbreviations used in shape comments:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Arguments:
        node_embeddings: tensor of shape [V, D] holding the initial node states.
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        state_dim: Size handed to the gated unit and to the per-edge-type dense
            layers. If None, the second dimension of node_embeddings is used.
            NOTE(review): in that case the value is a dynamic tf.shape tensor,
            not a Python int - confirm the layer constructors accept that.
        num_timesteps: Number of times the message passing step is repeated
            (the same layers and cell are reused in every step).
        gated_unit_type: Kind of recurrent unit, forwarded to get_gated_unit
            (documented upstream as one of RNN, GRU and LSTM).
        activation_function: Activation name, forwarded to get_gated_unit.
        message_aggregation_function: Name handed to get_aggregation_function.

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, state_dim] when num_timesteps >= 1).
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)
    gated_cell = get_gated_unit(state_dim, gated_unit_type, activation_function)

    # Layers to compute the message from a source state, one per edge type:
    edge_type_to_message_transformation_layers = []
    # Target node ids of the edges of each type:
    edge_type_to_message_targets = []
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(units=state_dim,
                                  use_bias=False,
                                  activation=None,
                                  name="Edge_%i_Weight" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets, axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]

        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_source_states = tf.nn.embedding_lookup(
                params=cur_node_states, ids=edge_sources)  # Shape [E, D]
            messages_per_type.append(
                edge_type_to_message_transformation_layers[edge_type_idx](
                    edge_source_states))  # Shape [E, D]

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        aggregated_messages = message_aggregation_fn(
            data=all_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, D]

        # The aggregated messages are the cell input, the previous node states
        # its recurrent state; element 0 of the result is used as the new state.
        cur_node_states = gated_cell(
            aggregated_messages, [cur_node_states])[0]  # Shape [V, D]

    return cur_node_states
def sparse_rgin_layer(
        node_embeddings: tf.Tensor,
        adjacency_lists: List[tf.Tensor],
        state_dim: Optional[int],
        num_timesteps: int = 1,
        activation_function: Optional[str] = "ReLU",
        message_aggregation_function: str = "sum",
        use_target_state_as_input: bool = False,
        num_edge_MLP_hidden_layers: Optional[int] = 1,
        num_aggr_MLP_hidden_layers: Optional[int] = None,
) -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing using MLPs for state
    updates and message computation.

    Given node states h^t_v and per-edge-type adjacency lists A_\ell, each
    timestep computes

        h^{t+1}_v := \sigma(MLP_{aggr}(\sum_\ell \sum_{(u, v) \in A_\ell} MLP_\ell(h^t_u)))

    followed by a layer normalisation. The learnable parameters are those of
    the MLPs.

    This is derived from Cor. 6 of arXiv:1810.00826, instantiating the
    functions f, \phi with _separate_ MLPs. It is more powerful than the GIN
    formulation in Eq. (4.1) of arXiv:1810.00826, because we want to tell apart
    graphs such as
        G_1 = (V={1, 2, 3}, E_1={(1, 2)}, E_2={(3, 2)})
    and
        G_2 = (V={1, 2, 3}, E_1={(3, 2)}, E_2={(1, 2)}).
    If all edges were treated alike,
    G_1.E_1 \cup G_1.E_2 == G_2.E_1 \cup G_2.E_2 would make the two graphs
    indistinguishable. Hence per-edge-type MLPs are used, which in turn means
    dropping the optimisation of modelling f \circ \phi by a single MLP that
    the original GIN formulation uses.

    Abbreviations used in shape comments:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Arguments:
        node_embeddings: tensor of shape [V, D] holding the initial node states.
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        state_dim: Output size of the MLPs. If None, the second dimension of
            node_embeddings is used.
            NOTE(review): in that case the value is a dynamic tf.shape tensor,
            not a Python int - confirm MLP accepts that.
        num_timesteps: Number of times the message passing step is repeated
            (the same MLPs are reused in every step).
        activation_function: Name handed to get_activation; the result is used
            inside the MLPs, on the edge MLP outputs and on the new node states.
        message_aggregation_function: Name handed to get_aggregation_function.
        use_target_state_as_input: If True the message is computed from source
            and target state concatenated, otherwise from the source state only.
        num_edge_MLP_hidden_layers: Number of hidden layers of the per-edge-type
            MLPs. If None, no edge MLPs are built and the raw (possibly
            concatenated) states are used as messages.
        num_aggr_MLP_hidden_layers: Number of hidden layers of the MLP applied
            to the aggregated messages. If None, the aggregated messages are
            used directly.

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, state_dim] in the default configuration).
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]

    # === Set up everything that is shared between timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)

    aggregation_MLP = None  # type: Optional[MLP]
    if num_aggr_MLP_hidden_layers is not None:
        aggregation_MLP = MLP(out_size=state_dim,
                              hidden_layers=num_aggr_MLP_hidden_layers,
                              activation_fun=activation_fn,
                              name="Aggregation_MLP")

    # None signals "no edge MLPs, pass states through unchanged".
    use_edge_mlps = num_edge_MLP_hidden_layers is not None
    edge_mlps = [] if use_edge_mlps else None  # type: Optional[List[MLP]]

    targets_per_type = []  # Target node ids of the edges of each type
    for edge_type_idx, edges in enumerate(adjacency_lists):
        if edge_mlps is not None:
            edge_mlps.append(
                MLP(out_size=state_dim,
                    hidden_layers=num_edge_MLP_hidden_layers,
                    activation_fun=activation_fn,
                    name="Edge_%i_MLP" % edge_type_idx))
        targets_per_type.append(edges[:, 1])

    # Targets of all messages, in the order the messages are concatenated below.
    message_targets = tf.concat(targets_per_type, axis=0)  # Shape [M]

    node_states = node_embeddings
    for _ in range(num_timesteps):
        per_type_messages = []  # tensors of shape [E, D], one per edge type

        for edge_type_idx, edges in enumerate(adjacency_lists):
            source_ids = edges[:, 0]
            target_ids = edges[:, 1]

            source_states = tf.nn.embedding_lookup(
                params=node_states, ids=source_ids)  # Shape [E, D]
            if use_target_state_as_input:
                target_states = tf.nn.embedding_lookup(
                    params=node_states, ids=target_ids)  # Shape [E, D]
                mlp_inputs = tf.concat(
                    [source_states, target_states], axis=1)  # Shape [E, 2*D]
            else:
                mlp_inputs = source_states

            if edge_mlps is None:
                type_messages = mlp_inputs
            else:
                type_messages = edge_mlps[edge_type_idx](mlp_inputs)  # Shape [E, D]
            per_type_messages.append(type_messages)

        stacked_messages = tf.concat(per_type_messages, axis=0)  # Shape [M, D]
        if edge_mlps is not None:
            # Apply the nonlinearity to the edge MLP outputs as well.
            stacked_messages = activation_fn(stacked_messages)  # Shape [M, D]

        updated_states = message_aggregation_fn(
            data=stacked_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, D]

        if aggregation_MLP is not None:
            updated_states = aggregation_MLP(updated_states)
        # The final MLP layer has no activation, so it is applied explicitly.
        updated_states = activation_fn(updated_states)
        updated_states = tf.contrib.layers.layer_norm(updated_states)

        node_states = updated_states

    return node_states
def sparse_rgdcn_layer(node_embeddings: tf.Tensor,
                       adjacency_lists: List[tf.Tensor],
                       type_to_num_incoming_edges: tf.Tensor,
                       num_channels: int = 8,
                       channel_dim: int = 16,
                       num_timesteps: int = 1,
                       use_full_state_for_channel_weights: bool = False,
                       tie_channel_weights: bool = False,
                       activation_function: Optional[str] = "tanh",
                       message_aggregation_function: str = "sum",
                       normalize_by_num_incoming: bool = True,
                       ) -> tf.Tensor:
    r"""
    Compute new graph states by message passing using dynamic convolutions for
    edge kernels.

    Each node state h^t_v is split into C "channels" of dimension K;
    h^t_{v,c,:} denotes the slice belonging to channel c. Given per-edge-type
    adjacency lists A_\ell, each timestep computes, for every channel c,

        h^{t+1}_{v,c,:} := \sigma(\sum_\ell \sum_{(u, v) \in A_\ell}
                                      1/c_{v,\ell} * (h^t_{u,c,:} * W^t_{\ell,v,c}))
        W^t_{\ell,v,c}  := \sigma(F_{\ell,c} * x^t_{v,c})

    and concatenates the channels again. The 1/c_{v,\ell} factor is only
    applied if normalize_by_num_incoming is set. Two flags select one of four
    variants:

    * use_full_state_for_channel_weights: the kernel input x^t_{v,c} is the
      full target state h^t_{v,:,:} (True, F in R^{C*K, K*K}) or only the
      target's channel h^t_{v,c,:} (False, F in R^{K, K*K}).
    * tie_channel_weights: a single F_\ell is shared by all channels (True) or
      every channel has its own F_{\ell,c} (False).

    Abbreviations used in shape comments:
    * V: number of nodes
    * C: number of "channels"
    * K: dimension of each "channel"
    * D: state dimension, fixed to C * K.
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Args:
        node_embeddings: tensor of shape [V, D] holding the initial node
            states; it is reshaped to [V, num_channels, channel_dim].
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        type_to_num_incoming_edges: tensor of shape [L, V]; entry [l, v] is
            used as the normalisation constant c_{v,l}.
        num_channels: Number of "channels" to split state information into.
        channel_dim: Size of each "channel".
        num_timesteps: Number of times the message passing step is repeated
            (the same layers are reused in every step).
        use_full_state_for_channel_weights: If True the full node state is the
            input for computing a channel's kernel, otherwise only the
            corresponding channel.
        tie_channel_weights: If True all channels of an edge type share one
            kernel-computation layer.
        activation_function: Name handed to get_activation; the result is used
            both on the kernel-computation layers and on the aggregated
            messages.
        message_aggregation_function: Name handed to get_aggregation_function.
        normalize_by_num_incoming: If True, scale each message by
            1 / (number of incoming edges of that type at the target).

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, C * K] when num_timesteps >= 1).
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]

    # === Set up everything that is shared between timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)

    # kernel_layers[edge_type][channel] computes the dynamic kernel weights.
    kernel_layers = []
    targets_per_type = []  # Target node ids of the edges of each type
    for edge_type_idx, edges in enumerate(adjacency_lists):
        layers_for_type = []
        for channel in range(num_channels):
            if tie_channel_weights and channel > 0:
                # Tied weights: every later channel reuses the first layer.
                layers_for_type.append(layers_for_type[-1])
            else:
                layers_for_type.append(
                    tf.keras.layers.Dense(
                        units=channel_dim * channel_dim,
                        use_bias=False,
                        kernel_initializer=tf.initializers.truncated_normal(
                            mean=0.0, stddev=1.0 / (channel_dim**2)),
                        activation=activation_fn,
                        name="Edge_%i_Channel_%i_Weight_Computation" % (edge_type_idx, channel)))
        kernel_layers.append(layers_for_type)
        targets_per_type.append(edges[:, 1])

    # Targets of all messages, in the order the messages are concatenated below.
    message_targets = tf.concat(targets_per_type, axis=0)  # Shape [M]

    node_states = node_embeddings  # Shape [V, D]
    for _ in range(num_timesteps):
        chunked_states = tf.reshape(
            node_states, shape=(-1, num_channels, channel_dim))  # Shape [V, C, K]

        new_channel_states = []  # type: List[tf.Tensor]  # C tensors of shape [V, K]
        for channel_idx in range(num_channels):
            channel_states = chunked_states[:, channel_idx, :]  # Shape [V, K]

            per_type_messages = []  # tensors of shape [E, K], one per edge type
            for edge_type_idx, edges in enumerate(adjacency_lists):
                source_ids = edges[:, 0]
                target_ids = edges[:, 1]
                source_states = tf.nn.embedding_lookup(
                    params=channel_states, ids=source_ids)  # Shape [E, K]

                kernel_input = (node_states
                                if use_full_state_for_channel_weights
                                else channel_states)

                # TODO: In the tie_channel_weights && use_full_state_for_channel_weights
                # case, this is the same for each channel:
                kernel_layer = kernel_layers[edge_type_idx][channel_idx]
                node_kernels = kernel_layer(kernel_input)  # Shape [V, K*K]
                node_kernels = tf.reshape(
                    node_kernels, shape=(-1, channel_dim, channel_dim))  # Shape [V, K, K]
                target_kernels = tf.nn.embedding_lookup(
                    params=node_kernels, ids=target_ids)  # Shape [E, K, K]

                # Per edge: row vector source_states[e] times matrix target_kernels[e].
                type_messages = tf.einsum(
                    'vi,vij->vj', source_states, target_kernels)  # Shape [E, K]

                if normalize_by_num_incoming:
                    incoming_counts = tf.nn.embedding_lookup(
                        params=type_to_num_incoming_edges[edge_type_idx, :],
                        ids=target_ids)  # Shape [E]
                    # SMALL_NUMBER guards against division by zero.
                    scale = tf.expand_dims(
                        1.0 / (incoming_counts + SMALL_NUMBER), axis=-1)  # Shape [E, 1]
                    type_messages = scale * type_messages

                per_type_messages.append(type_messages)

            channel_messages = tf.concat(per_type_messages, axis=0)  # Shape [M, K]
            channel_aggregated = message_aggregation_fn(
                data=channel_messages,
                segment_ids=message_targets,
                num_segments=num_nodes)  # Shape [V, K]
            channel_aggregated = activation_fn(channel_aggregated)

            new_channel_states.append(channel_aggregated)

        node_states = tf.concat(new_channel_states, axis=1)  # Shape [V, C * K]

    return node_states
def sparse_gnn_film_layer(
        node_embeddings: tf.Tensor,
        adjacency_lists: List[tf.Tensor],
        type_to_num_incoming_edges: tf.Tensor,
        state_dim: Optional[int],
        num_timesteps: int = 1,
        activation_function: Optional[str] = "ReLU",
        message_aggregation_function: str = "sum",
        normalize_by_num_incoming: bool = False,
) -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing modulated by the target state.

    Given node states h^t_v and per-edge-type adjacency lists A_\ell, each
    timestep computes

        h^{t+1}_v := \sum_\ell \sum_{(u, v) \in A_\ell}
                         \sigma(1/c_{v,\ell} * \gamma_{\ell,v} * (W_\ell * h^t_u) + \beta_{\ell,v})
        \gamma_{\ell,v} := F_{\ell,\gamma} * h^t_v
        \beta_{\ell,v}  := F_{\ell,\beta} * h^t_v

    where the 1/c_{v,\ell} factor is only applied if normalize_by_num_incoming
    is set, and the product with \gamma is element-wise.
    The learnable parameters are the bias-free linear maps W_\ell,
    F_{\ell,\gamma} and F_{\ell,\beta} (the latter two are computed by a single
    layer with 2 * state_dim units).

    Abbreviations used in shape comments:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: total number of messages (sum of all E)

    Arguments:
        node_embeddings: tensor of shape [V, D] holding the initial node states.
        adjacency_lists: List of L tensors of shape [E, 2]. Column 0 is read as
            the edge source, column 1 as the edge target, i.e.
            adjacency_lists[l][k, :] == [v, u] is an edge from v to u.
        type_to_num_incoming_edges: tensor of shape [L, V]; entry [l, v] is
            used as the normalisation constant c_{v,l}.
        state_dim: Number of units of the message layers (the FiLM layers use
            twice as many). If None, the second dimension of node_embeddings
            is used.
            NOTE(review): in that case the value is a dynamic tf.shape tensor,
            not a Python int - confirm tf.keras.layers.Dense accepts that.
        num_timesteps: Number of times the message passing step is repeated
            (the same layers are reused in every step).
        activation_function: Name handed to get_activation; the result is
            applied to the modulated messages before aggregation.
        message_aggregation_function: Name handed to get_aggregation_function.
        normalize_by_num_incoming: If True, scale each message by
            1 / (number of incoming edges of that type at the target).

    Returns:
        Tensor of node states after num_timesteps rounds of message passing
        (shape [V, state_dim] when num_timesteps >= 1).
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)

    # Layers to compute the message from a source state:
    edge_type_to_message_transformation_layers = []
    # Layers to compute the \gamma/\beta weights for FiLM:
    edge_type_to_film_computation_layers = []
    # Target node ids of the edges of each type:
    edge_type_to_message_targets = []
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(
                units=state_dim,
                use_bias=False,
                activation=None,  # Activation only after FiLM modulation
                name="Edge_%i_Weight" % edge_type_idx))
        edge_type_to_film_computation_layers.append(
            tf.keras.layers.Dense(
                units=2 * state_dim,  # Computes \gamma, \beta in one go
                use_bias=False,
                activation=None,
                name="Edge_%i_FiLM_Computations" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets, axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]

        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_targets = adjacency_list_for_edge_type[:, 1]
            edge_source_states = tf.nn.embedding_lookup(
                params=cur_node_states, ids=edge_sources)  # Shape [E, D]
            messages = edge_type_to_message_transformation_layers[edge_type_idx](
                edge_source_states)  # Shape [E, D]

            if normalize_by_num_incoming:
                per_message_num_incoming_edges = tf.nn.embedding_lookup(
                    params=type_to_num_incoming_edges[edge_type_idx, :],
                    ids=edge_targets)  # Shape [E]
                # SMALL_NUMBER guards against division by zero.
                messages = tf.expand_dims(
                    1.0 / (per_message_num_incoming_edges + SMALL_NUMBER),
                    axis=-1) * messages

            # FiLM weights are computed once per node from its own state and
            # then gathered for the target node of every message.
            film_weights = edge_type_to_film_computation_layers[edge_type_idx](
                cur_node_states)  # Shape [V, 2*D]
            per_message_film_weights = tf.nn.embedding_lookup(
                params=film_weights, ids=edge_targets)  # Shape [E, 2*D]
            per_message_film_gamma_weights = \
                per_message_film_weights[:, :state_dim]  # Shape [E, D]
            per_message_film_beta_weights = \
                per_message_film_weights[:, state_dim:]  # Shape [E, D]

            modulated_messages = \
                per_message_film_gamma_weights * messages + per_message_film_beta_weights
            messages_per_type.append(modulated_messages)

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        # The nonlinearity is applied per message, i.e. before aggregation.
        all_messages = activation_fn(all_messages)  # Shape [M, D]
        aggregated_messages = message_aggregation_fn(
            data=all_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, D]

        # Carry the result into the next timestep and the return value.
        # Without this assignment the function returned `node_embeddings`
        # unchanged and every computed message was discarded.
        cur_node_states = aggregated_messages

    return cur_node_states
def sparse_ggnn_layer(node_embeddings: tf.Tensor,
                      adjacency_lists: List[tf.Tensor],
                      state_dim: Optional[int],
                      num_timesteps: int = 1,
                      gated_unit_type: str = "gru",
                      activation_function: str = "tanh",
                      message_aggregation_function: str = "sum") -> tf.Tensor:
    r"""
    Compute new graph states by neural message passing and gated units on the nodes.

    For this, we assume existing node states h^t_v and a list of per-edge-type
    adjacency matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := Cell(h^t_v, \sum_\ell
                                 \sum_{(u, v) \in A_\ell}
                                     W_\ell * h^t_u)

    The learnable parameters of this are the recurrent Cell and the
    W_\ell \in R^{D,D}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type
    * M: number of messages (sum of E over all edge types)

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original
            representation of each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32
            tensors of shape [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u]
            means that the k-th edge of type l connects node v to node u, i.e.
            column 0 holds the message source and column 1 the message target.
        state_dim: Optional size of output dimension of the GNN layer. If not
            set, defaults to D, the dimensionality of the input. If different
            from the input dimension, parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing
            layer.
        gated_unit_type: Type of the recurrent unit used (one of RNN, GRU and
            LSTM).
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for
            messages.

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        # NOTE(review): this makes state_dim a scalar int32 tensor rather than
        # a Python int - confirm that the layer constructors below accept it.
        state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    gated_cell = get_gated_unit(state_dim, gated_unit_type,
                                activation_function)

    # One bias-free linear layer per edge type (the W_\ell of the docstring);
    # it maps a source state to the message sent along an edge of that type.
    edge_type_to_message_transformation_layers = [
        tf.keras.layers.Dense(units=state_dim,
                              use_bias=False,
                              activation=None,
                              name="Edge_%i_Weight" % edge_type_idx)
        for edge_type_idx in range(len(adjacency_lists))
    ]

    # The second column of every adjacency list holds the edge targets.
    # Concatenating them in edge-type order yields, for each message, the node
    # that receives it; the per-type messages below are concatenated in the
    # same order so that both tensors line up.
    message_targets = tf.concat(
        [adjacency_list[:, 1] for adjacency_list in adjacency_lists],
        axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        # Collect the messages sent along the edges of each type.
        messages_per_type = []  # list of tensors of shape [E, D]
        for message_layer, adjacency_list in zip(
                edge_type_to_message_transformation_layers, adjacency_lists):
            edge_sources = adjacency_list[:, 0]  # Shape [E]
            edge_source_states = tf.nn.embedding_lookup(
                params=cur_node_states, ids=edge_sources)  # Shape [E, D]
            # Apply the edge-type-specific linear layer to the source states.
            messages_per_type.append(
                message_layer(edge_source_states))  # Shape [E, D]

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]

        # Combine all messages that share a target node, producing one row
        # per node in the graph.
        aggregated_messages = message_aggregation_fn(
            data=all_messages,
            segment_ids=message_targets,
            num_segments=num_nodes)  # Shape [V, D]

        # The gated cell receives the aggregated messages as its input and the
        # current node states as its previous state. Its result is indexed
        # with [0] to obtain the new node states; the remaining part of the
        # result is not used here.
        # NOTE(review): because cur_node_states (size D) is fed in as the
        # cell state, a state_dim different from D presumably requires a cell
        # that tolerates that mismatch - confirm against get_gated_unit.
        cur_node_states = gated_cell(
            aggregated_messages, [cur_node_states])[0]  # Shape [V, D]

    return cur_node_states