def sparse_gnn_edge_mlp_layer(node_embeddings: tf.Tensor,
                              adjacency_lists: List[tf.Tensor],
                              type_to_num_incoming_edges: tf.Tensor,
                              state_dim: Optional[int],
                              num_timesteps: int = 1,
                              activation_function: Optional[str] = "ReLU",
                              message_aggregation_function: str = "sum",
                              normalize_by_num_incoming: bool = False,
                              use_target_state_as_input: bool = True,
                              num_edge_hidden_layers: int = 1) -> tf.Tensor:
    """
    Compute new graph states by neural message passing using an edge MLP.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := \sum_\ell
                     \sum_{(u, v) \in A_\ell}
                        \sigma(1/c_{v,\ell} * MLP_\ell(h^t_u || h^t_v))
        c_{v,\ell} is usually 1 (but could also be the number of incoming edges).
    The learnable parameters of this are the per-edge-type MLPs MLP_\ell.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        type_to_num_incoming_edges: float32 tensor of shape [L, V] representing the number
            of incoming edges of a given type. Concretely, type_to_num_incoming_edges[l, v]
            is the number of edges of type l connecting to node v.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.
        normalize_by_num_incoming: Flag indicating if messages should be scaled by 1/(number
            of incoming edges).
        use_target_state_as_input: Flag indicating if the edge MLP should consume both
            source and target state (True) or only source state (False).
        num_edge_hidden_layers: Number of hidden layers of the edge MLP.

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    edge_type_to_edge_mlp = []  # MLPs to compute the edge messages
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        edge_type_to_edge_mlp.append(
            MLP(out_size=state_dim,
                hidden_layers=num_edge_hidden_layers,
                activation_fun=tf.nn.elu,
                name="Edge_%i_MLP" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]
        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_targets = adjacency_list_for_edge_type[:, 1]
            edge_source_states = \
                tf.nn.embedding_lookup(params=cur_node_states,
                                       ids=edge_sources)  # Shape [E, D]

            edge_mlp_inputs = edge_source_states
            if use_target_state_as_input:
                edge_target_states = \
                    tf.nn.embedding_lookup(params=cur_node_states,
                                           ids=edge_targets)  # Shape [E, D]
                edge_mlp_inputs = tf.concat(
                    [edge_source_states, edge_target_states],
                    axis=1)  # Shape [E, 2*D]

            messages = edge_type_to_edge_mlp[edge_type_idx](
                edge_mlp_inputs)  # Shape [E, D]

            if normalize_by_num_incoming:
                per_message_num_incoming_edges = \
                    tf.nn.embedding_lookup(params=type_to_num_incoming_edges[edge_type_idx, :],
                                           ids=edge_targets)  # Shape [E]
                messages = tf.expand_dims(
                    1.0 / (per_message_num_incoming_edges + SMALL_NUMBER),
                    axis=-1) * messages
            messages_per_type.append(messages)

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        all_messages = activation_fn(
            all_messages
        )  # Shape [M, D]  (Apply nonlinearity to Edge-MLP outputs as well)
        aggregated_messages = \
            message_aggregation_fn(data=all_messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, D]

        new_node_states = aggregated_messages
        cur_node_states = new_node_states

    return cur_node_states
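
A minimal usage sketch for this layer, assuming TF 2.x eager execution and that the helpers referenced above (MLP, get_activation, get_aggregation_function, SMALL_NUMBER) are available in the same module. It builds a toy graph with a single edge type in which nodes 0 and 1 both send a message to node 2:

import tensorflow as tf

node_embeddings = tf.random.normal(shape=(3, 4))                    # [V=3, D=4]
adjacency_lists = [tf.constant([[0, 2], [1, 2]], dtype=tf.int32)]   # L=1, E=2
# type_to_num_incoming_edges[l, v] = number of type-l edges ending at node v:
type_to_num_incoming_edges = tf.constant([[0.0, 0.0, 2.0]])         # [L=1, V=3]

new_states = sparse_gnn_edge_mlp_layer(
    node_embeddings=node_embeddings,
    adjacency_lists=adjacency_lists,
    type_to_num_incoming_edges=type_to_num_incoming_edges,
    state_dim=4,
    num_timesteps=1,
    normalize_by_num_incoming=True)
print(new_states.shape)  # (3, 4)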
Example #2
def sparse_rgcn_layer(
    node_embeddings: tf.Tensor,
    adjacency_lists: List[tf.Tensor],
    type_to_num_incoming_edges: tf.Tensor,
    state_dim: Optional[int],
    num_timesteps: int = 1,
    activation_function: Optional[str] = "tanh",
    message_aggregation_function: str = "sum",
    normalize_by_num_incoming: bool = True,
    use_both_source_and_target: bool = False,
) -> tf.Tensor:
    """
    Compute new graph states by neural message passing.
    This implements the R-GCN model (Schlichtkrull et al., https://arxiv.org/pdf/1703.06103.pdf)
    for the case of few relations / edge types, i.e., we do not use the dimensionality-reduction
    tricks from section 2.2 of that paper.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := \sigma(\sum_\ell
                            \sum_{(u, v) \in A_\ell}
                               1/c_{v,\ell} * (W_\ell * h^t_u))
    c_{v,\ell} is usually the number of \ell edges going into v.
    The learnable parameters of this are the W_\ell \in R^{D,D}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        type_to_num_incoming_edges: float32 tensor of shape [L, V] representing the number
            of incoming edges of a given type. Concretely, type_to_num_incoming_edges[l, v]
            is the number of edges of type l connecting to node v.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.
        normalize_by_num_incoming: Flag indicating if messages should be scaled by 1/(number
            of incoming edges).

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    edge_type_to_message_transformation_layers = [
    ]  # Layers to compute the message from a source state
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(units=state_dim,
                                  use_bias=False,
                                  activation=None,
                                  name="Edge_%i_Weight" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, H]
        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_targets = adjacency_list_for_edge_type[:, 1]
            edge_source_states = \
                tf.nn.embedding_lookup(params=cur_node_states,
                                       ids=edge_sources)  # Shape [E, H]

            if use_both_source_and_target:
                edge_target_states = \
                    tf.nn.embedding_lookup(params=cur_node_states,
                                           ids=edge_targets)  # Shape [E, H]
                edge_state_pairs = tf.concat(
                    [edge_source_states, edge_target_states],
                    axis=-1)  # Shape [E, 2H]
                messages = edge_type_to_message_transformation_layers[
                    edge_type_idx](edge_state_pairs)  # Shape [E, H]
            else:
                messages = edge_type_to_message_transformation_layers[
                    edge_type_idx](edge_source_states)  # Shape [E, H]

            if normalize_by_num_incoming:
                num_incoming_to_node_per_message = \
                    tf.nn.embedding_lookup(params=type_to_num_incoming_edges[edge_type_idx, :],
                                           ids=edge_targets)  # Shape [E]
                messages = tf.expand_dims(
                    1.0 / (num_incoming_to_node_per_message + SMALL_NUMBER),
                    axis=-1) * messages

            messages_per_type.append(messages)

        cur_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, H]
        aggregated_messages = \
            message_aggregation_fn(data=cur_messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, H]

        new_node_states = activation_fn(aggregated_messages)  # Shape [V, H]
        cur_node_states = new_node_states

    return cur_node_states
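
The aggregation step above relies on the message_aggregation_fn returned by get_aggregation_function; for "sum" this is presumably an unsorted segment sum. A standalone sketch of the pattern, written directly against tf.math.unsorted_segment_sum (an assumption about what the helper resolves to): three messages are sent to targets [2, 0, 2], and each row of the result collects the messages addressed to that node.

import tensorflow as tf

all_messages = tf.constant([[1.0, 1.0],
                            [2.0, 2.0],
                            [3.0, 3.0]])          # Shape [M=3, D=2]
message_targets = tf.constant([2, 0, 2])          # Shape [M], target node ids
aggregated = tf.math.unsorted_segment_sum(
    data=all_messages, segment_ids=message_targets, num_segments=4)
print(aggregated.numpy())
# [[2. 2.]    <- message 1
#  [0. 0.]    <- no incoming messages
#  [4. 4.]    <- messages 0 and 2 summed
#  [0. 0.]]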
Example #3
def sparse_ggnn_layer(node_embeddings: tf.Tensor,
                      adjacency_lists: List[tf.Tensor],
                      state_dim: Optional[int],
                      num_timesteps: int = 1,
                      gated_unit_type: str = "gru",
                      activation_function: str = "tanh",
                      message_aggregation_function: str = "sum") -> tf.Tensor:
    """
    Compute new graph states by neural message passing and gated units on the nodes.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := Cell(h^t_v, \sum_\ell
                                 \sum_{(u, v) \in A_\ell}
                                     W_\ell * h^t_u)
    The learnable parameters of this are the recurrent Cell and the W_\ell \in R^{D,D}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        gated_unit_type: Type of the recurrent unit used (one of RNN, GRU and LSTM).
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    gated_cell = get_gated_unit(state_dim, gated_unit_type,
                                activation_function)
    edge_type_to_message_transformation_layers = [
    ]  # Layers to compute the message from a source state
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(units=state_dim,
                                  use_bias=False,
                                  activation=None,
                                  name="Edge_%i_Weight" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages = []  # list of tensors of messages of shape [E, D]
        message_source_states = [
        ]  # list of tensors of edge source states of shape [E, D]

        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_source_states = tf.nn.embedding_lookup(
                params=cur_node_states, ids=edge_sources)  # Shape [E, D]
            all_messages_for_edge_type = \
                edge_type_to_message_transformation_layers[edge_type_idx](edge_source_states)  # Shape [E,D]
            messages.append(all_messages_for_edge_type)
            message_source_states.append(edge_source_states)

        messages = tf.concat(messages, axis=0)  # Shape [M, D]
        aggregated_messages = \
            message_aggregation_fn(data=messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, D]

        # pass updated vertex features into RNN cell
        new_node_states = gated_cell(aggregated_messages,
                                     [cur_node_states])[0]  # Shape [V, D]
        cur_node_states = new_node_states

    return cur_node_states
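
A minimal usage sketch for the GGNN layer, assuming TF 2.x eager execution and that get_gated_unit and get_aggregation_function from this module are available. Two edge types over four nodes; num_timesteps > 1 is allowed here because state_dim matches the input dimension:

import tensorflow as tf

node_embeddings = tf.random.normal(shape=(4, 8))                  # [V=4, D=8]
adjacency_lists = [
    tf.constant([[0, 1], [2, 3]], dtype=tf.int32),                # edge type 0
    tf.constant([[1, 0]], dtype=tf.int32),                        # edge type 1
]
new_states = sparse_ggnn_layer(node_embeddings=node_embeddings,
                               adjacency_lists=adjacency_lists,
                               state_dim=8,
                               num_timesteps=2,
                               gated_unit_type="gru")
print(new_states.shape)  # (4, 8)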
Example #4
def sparse_rgin_layer(
    node_embeddings: tf.Tensor,
    adjacency_lists: List[tf.Tensor],
    state_dim: Optional[int],
    num_timesteps: int = 1,
    activation_function: Optional[str] = "ReLU",
    message_aggregation_function: str = "sum",
    use_target_state_as_input: bool = False,
    num_edge_MLP_hidden_layers: Optional[int] = 1,
    num_aggr_MLP_hidden_layers: Optional[int] = None,
) -> tf.Tensor:
    """
    Compute new graph states by neural message passing using MLPs for state updates
    and message computation.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := \sigma(MLP_{aggr}(\sum_\ell \sum_{(u, v) \in A_\ell} MLP_\ell(h^t_u)))
    The learnable parameters of this are the MLPs MLP_\ell.
    This is derived from Cor. 6 of arXiv:1810.00826, instantiating the functions f, \phi
    with _separate_ MLPs. This is more powerful than the GIN formulation in Eq. (4.1) of
    arXiv:1810.00826, as we want to be able to distinguish graphs of the form
     G_1 = (V={1, 2, 3}, E_1={(1, 2)}, E_2={(3, 2)})
    and
     G_2 = (V={1, 2, 3}, E_1={(3, 2)}, E_2={(1, 2)})
    from each other. If we treated all edges the same,
    G_1.E_1 \cup G_1.E_2 == G_2.E_1 \cup G_2.E_2 would imply that the two graphs
    become indistinguishable.
    Hence, we introduce per-edge-type MLPs, which also means that we have to drop
    the optimisation of modelling f \circ \phi by a single MLP used in the original
    GIN formulation.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.
        use_target_state_as_input: Flag indicating if the edge MLP should consume both
            source and target state (True) or only source state (False).
        num_edge_MLP_hidden_layers: Number of hidden layers of the MLPs used to transform
            messages from neighbouring nodes. If None, the raw states are used directly.
        num_aggr_MLP_hidden_layers: Number of hidden layers of the MLPs used on the
            aggregation of messages from neighbouring nodes. If None, the aggregated messages
            are used directly.

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)

    if num_aggr_MLP_hidden_layers is not None:
        aggregation_MLP = MLP(out_size=state_dim,
                              hidden_layers=num_aggr_MLP_hidden_layers,
                              activation_fun=activation_fn,
                              name="Aggregation_MLP")  # type: Optional[MLP]
    else:
        aggregation_MLP = None

    if num_edge_MLP_hidden_layers is not None:
        edge_type_to_edge_mlp = [
        ]  # type: Optional[List[MLP]]  # MLPs to compute the edge messages
    else:
        edge_type_to_edge_mlp = None
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        if edge_type_to_edge_mlp is not None and num_edge_MLP_hidden_layers is not None:
            edge_type_to_edge_mlp.append(
                MLP(out_size=state_dim,
                    hidden_layers=num_edge_MLP_hidden_layers,
                    activation_fun=activation_fn,
                    name="Edge_%i_MLP" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]
        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_targets = adjacency_list_for_edge_type[:, 1]
            edge_source_states = \
                tf.nn.embedding_lookup(params=cur_node_states,
                                       ids=edge_sources)  # Shape [E, D]

            edge_mlp_inputs = edge_source_states
            if use_target_state_as_input:
                edge_target_states = \
                    tf.nn.embedding_lookup(params=cur_node_states,
                                           ids=edge_targets)  # Shape [E, D]
                edge_mlp_inputs = tf.concat(
                    [edge_source_states, edge_target_states],
                    axis=1)  # Shape [E, 2*D]

            if edge_type_to_edge_mlp is not None:
                messages = edge_type_to_edge_mlp[edge_type_idx](
                    edge_mlp_inputs)  # Shape [E, D]
            else:
                messages = edge_mlp_inputs
            messages_per_type.append(messages)

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        if edge_type_to_edge_mlp is not None:
            all_messages = activation_fn(
                all_messages
            )  # Shape [M, D]  (Apply nonlinearity to Edge-MLP outputs as well)
        aggregated_messages = \
            message_aggregation_fn(data=all_messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, D]

        new_node_states = aggregated_messages
        if aggregation_MLP is not None:
            new_node_states = aggregation_MLP(new_node_states)
        new_node_states = activation_fn(
            new_node_states
        )  # Note that the final MLP layer has no activation, so we do that here explicitly
        new_node_states = tf.contrib.layers.layer_norm(new_node_states)
        cur_node_states = new_node_states

    return cur_node_states
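
To make the G_1 / G_2 argument from the docstring concrete in this adjacency-list encoding (a sketch only; node ids 0..2 stand in for the docstring's 1..3):

import tensorflow as tf

# G_1: E_1 = {(1, 2)}, E_2 = {(3, 2)}
g1_adjacency_lists = [
    tf.constant([[0, 1]], dtype=tf.int32),  # type-1 edge 0 -> 1
    tf.constant([[2, 1]], dtype=tf.int32),  # type-2 edge 2 -> 1
]
# G_2: E_1 = {(3, 2)}, E_2 = {(1, 2)}
g2_adjacency_lists = [
    tf.constant([[2, 1]], dtype=tf.int32),  # type-1 edge 2 -> 1
    tf.constant([[0, 1]], dtype=tf.int32),  # type-2 edge 0 -> 1
]
# Ignoring edge types, both graphs contain exactly the edges {0 -> 1, 2 -> 1};
# only the per-edge-type MLPs MLP_1 and MLP_2 can tell them apart.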
Example #5
def sparse_rgdcn_layer(node_embeddings: tf.Tensor,
                       adjacency_lists: List[tf.Tensor],
                       type_to_num_incoming_edges: tf.Tensor,
                       num_channels: int = 8,
                       channel_dim: int = 16,
                       num_timesteps: int = 1,
                       use_full_state_for_channel_weights: bool = False,
                       tie_channel_weights: bool = False,
                       activation_function: Optional[str] = "tanh",
                       message_aggregation_function: str = "sum",
                       normalize_by_num_incoming: bool = True,
                       ) -> tf.Tensor:
    """
    Compute new graph states by message passing using dynamic convolutions for edge kernels.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.
    We split each state h^t_v into C "channels" of dimension K, and use h^t_{v,c,:} to refer to
    the slice of the node state corresponding to the c-th channel.

    Four variants of the model are implemented:

    (1) Edge kernels computed from full target node state using weights shared across all channels:
        [use_full_state_for_channel_weights = True, tie_channel_weights = True]
          h^{t+1}_v := \sigma(Concat(\sum_\ell \sum_{(u, v) \in A_\ell} W^t_{\ell,v} * h^t_{u,c,:}
                                     | 1 <= c <= C))
          W^t_{\ell,v} := F_\ell * h^t_{v,:,:}
        The learnable parameters of this are the F_\ell \in R^{C*K, K*K}.

    (2) Edge kernels computed from full target node state using separate weights for each channel:
        [use_full_state_for_channel_weights = True, tie_channel_weights = False]
          h^{t+1}_v := \sigma(Concat(\sum_\ell \sum_{(u, v) \in A_\ell} W^t_{\ell,v,c} * h^t_{u,c,:}
                                     | 1 <= c <= C))
          W^t_{\ell,v,c} := F_{\ell,c} * h^t_{v,:,:}
        The learnable parameters of this are the F_{\ell,c} \in R^{C*K, K*K}.

    (3) Edge kernels computed from corresponding channel of target node using weights shared across all channels:
        [use_full_state_for_channel_weights = False, tie_channel_weights = True]
          h^{t+1}_v := \sigma(Concat(\sum_\ell \sum_{(u, v) \in A_\ell} W^t_{\ell,v,c} * h^t_{u,c,:}
                                     | 1 <= c <= C))
          W^t_{\ell,v,c} := F_{\ell} * h^t_{v,c,:}
        The learnable parameters of this are the F_\ell \in R^{K, K*K}.

    (4) Edge kernels computed from corresponding channel of target node using separate weights for each channel:
        [use_full_state_for_channel_weights = False, tie_channel_weights = False]
          h^{t+1}_v := \sigma(Concat(\sum_\ell \sum_{(u, v) \in A_\ell} W^t_{\ell,v,c} * h^t_{u,c,:}
                                     | 1 <= c <= C))
          W^t_{\ell,v,c} := F_{\ell,c} * h^t_{v,c,:}
        The learnable parameters of this are the F_{\ell,c} \in R^{K, K*K}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * C: number of "channels"
    * K: dimension of each "channel"
    * D: state dimension, fixed to C * K.
    * L: number of different edge types
    * E: number of edges of a given edge type

    Args:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        type_to_num_incoming_edges: float32 tensor of shape [L, V] representing the number
            of incoming edges of a given type. Concretely, type_to_num_incoming_edges[l, v]
            is the number of edges of type l connecting to node v.
        num_channels: Number of "channels" to split state information into.
        channel_dim: Size of each "channel"
        num_timesteps: Number of repeated applications of this message passing layer.
        use_full_state_for_channel_weights: Flag indicating if the full state is used to
            compute the weights for individual channels, or only the corresponding channel.
        tie_channel_weights: Flag indicating if the weights for computing the per-channel
            linear layer are shared or not.
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.
        normalize_by_num_incoming: Flag indicating if messages should be scaled by 1/(number
            of incoming edges).

    Returns:
        float32 tensor of shape [V, D]
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(message_aggregation_function)
    edge_type_to_channel_to_weight_computation_layers = []  # Layers to compute the dynamic computation weights
    edge_type_to_message_targets = []  # List of tensors of message targets

    for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
        channel_to_weight_computation_layers = []
        for channel in range(num_channels):
            if channel == 0 or not(tie_channel_weights):
                channel_to_weight_computation_layers.append(
                    tf.keras.layers.Dense(
                        units=channel_dim * channel_dim,
                        use_bias=False,
                        kernel_initializer=tf.initializers.truncated_normal(mean=0.0, stddev=1.0 / (channel_dim**2)),
                        activation=activation_fn,
                        name="Edge_%i_Channel_%i_Weight_Computation" % (edge_type_idx, channel)))
            else:  # Case channel > 0 and tie_channel_weights
                channel_to_weight_computation_layers.append(
                    channel_to_weight_computation_layers[-1])
        edge_type_to_channel_to_weight_computation_layers.append(channel_to_weight_computation_layers)

        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets, axis=0)  # Shape [M]

    cur_node_states = node_embeddings  # Shape [V, D]
    for _ in range(num_timesteps):
        node_states_chunked = tf.reshape(cur_node_states,
                                         shape=(-1, num_channels, channel_dim))  # shape [V, C, K]

        new_node_states_chunked = []  # type: List[tf.Tensor]  # C tensors of shape [V, K]
        for channel_idx in range(num_channels):
            cur_channel_node_states = node_states_chunked[:, channel_idx, :]  # shape [V, K]
            cur_channel_message_per_type = []  # list of tensors of messages of shape [E, K]

            # Collect incoming messages per edge type
            for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
                edge_sources = adjacency_list_for_edge_type[:, 0]
                edge_targets = adjacency_list_for_edge_type[:, 1]
                edge_source_states = \
                    tf.nn.embedding_lookup(params=cur_channel_node_states,
                                           ids=edge_sources)  # Shape [E, K]

                if use_full_state_for_channel_weights:
                    weight_computation_input = cur_node_states
                else:
                    weight_computation_input = cur_channel_node_states
                # TODO: In the tie_channel_weights && use_full_state_for_channel_weights case,
                # this is the same for each channel:
                weight_compute_layer = edge_type_to_channel_to_weight_computation_layers[edge_type_idx][channel_idx]
                edge_weights = weight_compute_layer(weight_computation_input)  # Shape [V, K*K]
                edge_weights = tf.reshape(edge_weights, shape=(-1, channel_dim, channel_dim))  # Shape [V, K, K]
                edge_weights_for_targets = \
                    tf.nn.embedding_lookup(params=edge_weights, ids=edge_targets)  # Shape [E, K, K]

                # Matrix multiply between edge_source_states[v] and edge_weights_for_targets[v]:
                messages = tf.einsum('vi,vij->vj', edge_source_states, edge_weights_for_targets)  # Shape [E, K]
                if normalize_by_num_incoming:
                    num_incoming_to_node_per_message = \
                        tf.nn.embedding_lookup(params=type_to_num_incoming_edges[edge_type_idx, :],
                                               ids=edge_targets)  # Shape [E]
                    messages = tf.expand_dims(1.0 / (num_incoming_to_node_per_message + SMALL_NUMBER), axis=-1) * messages

                cur_channel_message_per_type.append(messages)

            cur_channel_messages = tf.concat(cur_channel_message_per_type, axis=0)  # Shape [M, K]
            cur_channel_aggregated_incoming_messages = \
                message_aggregation_fn(data=cur_channel_messages,
                                       segment_ids=message_targets,
                                       num_segments=num_nodes)  # Shape [V, K]
            cur_channel_aggregated_incoming_messages = activation_fn(cur_channel_aggregated_incoming_messages)

            new_node_states_chunked.append(cur_channel_aggregated_incoming_messages)

        new_node_states = tf.concat(new_node_states_chunked, axis=1)  # Shape [V, C * K]
        cur_node_states = new_node_states

    return cur_node_states
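
The core of the dynamic convolution above is the tf.einsum('vi,vij->vj', ...) call, which applies a different [K, K] kernel to each message. A small standalone illustration:

import tensorflow as tf

edge_source_states = tf.constant([[1.0, 2.0],
                                  [3.0, 4.0]])            # Shape [E=2, K=2]
edge_weights_for_targets = tf.stack([tf.eye(2),           # identity for edge 0
                                     2.0 * tf.eye(2)])    # doubling for edge 1
messages = tf.einsum('vi,vij->vj',
                     edge_source_states, edge_weights_for_targets)
print(messages.numpy())
# [[1. 2.]    <- edge 0: unchanged
#  [6. 8.]]   <- edge 1: scaled by 2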
Example #6
def sparse_gnn_film_layer(
    node_embeddings: tf.Tensor,
    adjacency_lists: List[tf.Tensor],
    type_to_num_incoming_edges: tf.Tensor,
    state_dim: Optional[int],
    num_timesteps: int = 1,
    activation_function: Optional[str] = "ReLU",
    message_aggregation_function: str = "sum",
    normalize_by_num_incoming: bool = False,
) -> tf.Tensor:
    """
    Compute new graph states by neural message passing modulated by the target state.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := \sum_\ell
                     \sum_{(u, v) \in A_\ell}
                        \sigma(1/c_{v,\ell} * \alpha_{\ell,v} * (W_\ell * h^t_u) + \beta_{\ell,v})
        \alpha_{\ell,v} := F_{\ell,\alpha} * h^t_v
        \beta_{\ell,v} := F_{\ell,\beta} * h^t_v
        c_{v,\ell} is usually 1 (but could also be the number of incoming edges).
    The learnable parameters of this are the W_\ell, F_{\ell,\alpha}, F_{\ell,\beta} \in R^{D, D}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        type_to_num_incoming_edges: float32 tensor of shape [L, V] representing the number
            of incoming edges of a given type. Concretely, type_to_num_incoming_edges[l, v]
            is the number of edges of type l connecting to node v.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.
        normalize_by_num_incoming: Flag indicating if messages should be scaled by 1/(number
            of incoming edges).

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(input=node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(input=node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    edge_type_to_message_transformation_layers = [
    ]  # Layers to compute the message from a source state
    edge_type_to_film_computation_layers = [
    ]  # Layers to compute the \beta/\gamma weights for FiLM
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(
                units=state_dim,
                use_bias=False,
                activation=None,  # Activation only after FiLM modulation
                name="Edge_%i_Weight" % edge_type_idx))
        edge_type_to_film_computation_layers.append(
            tf.keras.layers.Dense(
                units=2 * state_dim,  # Computes \gamma, \beta in one go
                use_bias=False,
                activation=None,
                name="Edge_%i_FiLM_Computations" % edge_type_idx))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]
        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_targets = adjacency_list_for_edge_type[:, 1]
            edge_source_states = \
                tf.nn.embedding_lookup(params=cur_node_states,
                                       ids=edge_sources)  # Shape [E, D]
            messages = edge_type_to_message_transformation_layers[
                edge_type_idx](edge_source_states)  # Shape [E, D]

            if normalize_by_num_incoming:
                per_message_num_incoming_edges = \
                    tf.nn.embedding_lookup(params=type_to_num_incoming_edges[edge_type_idx, :],
                                           ids=edge_targets)  # Shape [E]
                messages = tf.expand_dims(
                    1.0 / (per_message_num_incoming_edges + SMALL_NUMBER),
                    axis=-1) * messages

            film_weights = edge_type_to_film_computation_layers[edge_type_idx](
                cur_node_states)
            per_message_film_weights = \
                tf.nn.embedding_lookup(params=film_weights, ids=edge_targets)
            per_message_film_gamma_weights = per_message_film_weights[:, :
                                                                      state_dim]  # Shape [E, D]
            per_message_film_beta_weights = per_message_film_weights[:,
                                                                     state_dim:]  # Shape [E, D]

            modulated_messages = per_message_film_gamma_weights * messages + per_message_film_beta_weights
            messages_per_type.append(modulated_messages)

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        all_messages = activation_fn(all_messages)  # Shape [M, D]
        aggregated_messages = \
            message_aggregation_fn(data=all_messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, D]
        new_node_states = aggregated_messages
        # new_node_states = activation_fn(new_node_states)
        cur_node_states = new_node_states

    return cur_node_states
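
The per-message FiLM modulation above splits each row of film_weights into a scale (the docstring's \alpha, called \gamma in the code comments) and a shift (\beta), each of dimension state_dim. A small standalone illustration of the split and the modulation:

import tensorflow as tf

state_dim = 2
messages = tf.constant([[1.0, -1.0]])                     # Shape [E=1, D=2]
per_message_film_weights = tf.constant([[2.0, 0.5,        # gamma (scale)
                                          0.1, -0.1]])    # beta (shift); Shape [E, 2*D]
gamma = per_message_film_weights[:, :state_dim]           # Shape [E, D]
beta = per_message_film_weights[:, state_dim:]            # Shape [E, D]
modulated = gamma * messages + beta
print(modulated.numpy())  # [[ 2.1 -0.6]]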
Example #7
def sparse_ggnn_layer(node_embeddings: tf.Tensor,
                      adjacency_lists: List[tf.Tensor],
                      state_dim: Optional[int],
                      num_timesteps: int = 1,
                      gated_unit_type: str = "gru",
                      activation_function: str = "tanh",
                      message_aggregation_function: str = "sum") -> tf.Tensor:
    """
    Compute new graph states by neural message passing and gated units on the nodes.
    For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
    matrices A_\ell.

    We compute new states as follows:
        h^{t+1}_v := Cell(h^t_v, \sum_\ell
                                 \sum_{(u, v) \in A_\ell}
                                     W_\ell * h^t_u)
    The learnable parameters of this are the recurrent Cell and the W_\ell \in R^{D,D}.

    We use the following abbreviations in shape descriptions:
    * V: number of nodes
    * D: state dimension
    * L: number of different edge types
    * E: number of edges of a given edge type

    Arguments:
        node_embeddings: float32 tensor of shape [V, D], the original representation of
            each node in the graph.
        adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
            [E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
            of type l connects node v to node u.
        state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
            to D, the dimensionality of the input. If different from the input dimension,
            parameter num_timesteps has to be 1.
        num_timesteps: Number of repeated applications of this message passing layer.
        gated_unit_type: Type of the recurrent unit used (one of RNN, GRU and LSTM).
        activation_function: Type of activation function used.
        message_aggregation_function: Type of aggregation function used for messages.

    Returns:
        float32 tensor of shape [V, state_dim]
    """
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    message_aggregation_fn = get_aggregation_function(
        message_aggregation_function)
    gated_cell = get_gated_unit(state_dim, gated_unit_type,
                                activation_function)
    edge_type_to_message_transformation_layers = [
    ]  # Layers to compute the message from a source state
    edge_type_to_message_targets = []  # List of tensors of message targets
    # for each edge type, create a dense linear layer to project into the state dimension.
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(
            adjacency_lists):
        edge_type_to_message_transformation_layers.append(
            tf.keras.layers.Dense(units=state_dim,
                                  use_bias=False,
                                  activation=None,
                                  name="Edge_%i_Weight" % edge_type_idx))
        # Append the target column (i.e., the second column of the adjacency list)
        # for this edge type to edge_type_to_message_targets.
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])
        # This produces a 1-D tensor of targets (e.g., [[1, 2], [1, 5]] becomes [2, 5]).

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets,
                                axis=0)  # Shape [M]
    # A single 1-D tensor containing all message targets, of shape [M].

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages = []  # list of tensors of messages of shape [E, D]
        message_source_states = [
        ]  # list of tensors of edge source states of shape [E, D]

        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(
                adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]  # Shape [E]
            edge_source_states = tf.nn.embedding_lookup(
                params=cur_node_states, ids=edge_sources)  # Shape [E, D]
            all_messages_for_edge_type = \
                edge_type_to_message_transformation_layers[edge_type_idx](edge_source_states)  # Shape [E, D]
            # This just projects the edge source states to the desired dimension through the
            # per-edge-type linear layers created above.

            messages.append(
                all_messages_for_edge_type)  # List of tensors of shape [E,D]
            message_source_states.append(
                edge_source_states)  # List of tensors of shape [E, D]

        messages = tf.concat(messages, axis=0)  # Shape [M, D]
        aggregated_messages = \
            message_aggregation_fn(data=messages,
                                   segment_ids=message_targets,
                                   num_segments=num_nodes)  # Shape [V, D]
        # This aggregates the messages by their target node: all messages sent to the
        # same target node are combined (summed, for the "sum" aggregation). The result
        # has one row per node (num_segments=num_nodes), i.e., shape [V, D].

        # Pass the aggregated messages into the gated (RNN-style) cell.
        # The first argument is the input, of shape [batch, features], here [V, D];
        # the second argument is the list of previous states. After the first timestep
        # this is the output of the previous application; at timestep 0 it is the
        # initial node embeddings.

        # We index [0] into the return value because the cell returns
        # (new_node_states, [new internal states]), and we only need the output here.
        new_node_states = gated_cell(aggregated_messages,
                                     [cur_node_states])[0]  # Shape [V, D]
        cur_node_states = new_node_states

    return cur_node_states
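
A standalone illustration of the gated-cell call convention described in the comments above, assuming get_gated_unit returns a Keras-style recurrent cell (tf.keras.layers.GRUCell is used here as a stand-in): the cell is called as cell(inputs, [state]) and returns (output, new_states).

import tensorflow as tf

cell = tf.keras.layers.GRUCell(units=4)
aggregated_messages = tf.random.normal(shape=(5, 4))   # [V=5, D=4] inputs
cur_node_states = tf.random.normal(shape=(5, 4))       # [V=5, D=4] previous states
output, new_states = cell(aggregated_messages, [cur_node_states])
print(output.shape)  # (5, 4); for a GRU cell, output equals new_states[0]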