Ejemplo n.º 1
0
def make_inorp(  # Input
        input_node_shape,
        input_edge_shape,
        input_state_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific parameter
        depth=3,
        use_set2set: bool = False,  # not in original paper
        node_mlp_args: dict = None,
        edge_mlp_args: dict = None,
        set2set_args: dict = None,
        pooling_args: dict = None):
    """
    Generate Interaction network.

    Each of the ``depth`` steps works as follows. It concatenates the outgoing and ingoing node
    features of every edge with the edge features and transforms them with an edge MLP. It pools
    the resulting messages onto the nodes. It then concatenates node features, pooled messages and
    the graph state gathered onto the nodes, and updates the node features with a node MLP.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_state_shape (list): Shape of state features. If shape is (,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}. Only "output_mode" is read here.
            'graph' pools the nodes into one vector per graph. Any other value gives node-level output.
        output_mlp (dict): Dictionary of arguments for final MLP regression or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int): Number of convolution layers. Default is 3.
        use_set2set (bool): Use set2set pooling for graph embedding. Default is False.
        node_mlp_args (dict): Dictionary of arguments for MLP for node update. Default is
            {"units": [100, 50], "use_bias": True, "activation": ['relu', "linear"]}
        edge_mlp_args (dict): Dictionary of arguments for MLP for interaction update. Default is
            {"units": [100, 100, 100, 100, 50],
            "activation": ['relu', 'relu', 'relu', 'relu', "linear"]}
        set2set_args (dict): Dictionary of set2set layer arguments. Default is
            {'channels': 32, 'T': 3, "pooling_method": "mean", "init_qstar": "mean"}.
        pooling_args (dict): Dictionary for message pooling arguments. It is also used for node
            pooling when ``use_set2set`` is False. Default is
            {'is_sorted': False, 'has_unconnected': True, 'pooling_method': "segment_mean"}

    Returns:
        model (ks.models.Model): Interaction model with inputs
            [node_input, edge_input, edge_index_input, env_input].
    """
    # default values
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'set2set_args': {
            'channels': 32,
            'T': 3,
            "pooling_method": "mean",
            "init_qstar": "mean"
        },
        'node_mlp_args': {
            "units": [100, 50],
            "use_bias": True,
            "activation": ['relu', "linear"]
        },
        'edge_mlp_args': {
            "units": [100, 100, 100, 100, 50],
            "activation": ['relu', 'relu', 'relu', 'relu', "linear"]
        },
        'pooling_args': {
            'is_sorted': False,
            'has_unconnected': True,
            'pooling_method': "segment_mean"
        }
    }

    # Update default values
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    set2set_args = update_model_args(model_default['set2set_args'],
                                     set2set_args)
    node_mlp_args = update_model_args(model_default['node_mlp_args'],
                                      node_mlp_args)
    edge_mlp_args = update_model_args(model_default['edge_mlp_args'],
                                      edge_mlp_args)
    pooling_args = update_model_args(model_default['pooling_args'],
                                     pooling_args)
    gather_args = {"node_indexing": "sample"}

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, env_input, uenv = generate_standard_graph_input(
        input_node_shape, input_edge_shape, input_state_shape, **input_embedd)

    # Preprocessing
    edi = edge_index_input
    # Gather the graph state onto the nodes once; it is reused in every step.
    ev = GatherState(**gather_args)([uenv, n])
    # n-Layer Step
    for _ in range(depth):
        # Node features at both ends of each edge (outgoing first), followed by the edge features.
        eu1 = GatherNodesIngoing(**gather_args)([n, edi])
        eu2 = GatherNodesOutgoing(**gather_args)([n, edi])
        upd = Concatenate(axis=-1)([eu2, eu1])
        eu = Concatenate(axis=-1)([upd, ed])

        eu = MLP(**edge_mlp_args)(eu)
        # Pool messages onto the nodes (method from pooling_args, default "segment_mean").
        nu = PoolingLocalEdges(**pooling_args)([n, eu, edi])
        # Concatenate node features with the pooled edge updates and the gathered state.
        nu = Concatenate(axis=-1)([n, nu, ev])

        n = MLP(**node_mlp_args)(nu)

    if output_embedd["output_mode"] == 'graph':
        if use_set2set:
            # Project to the set2set channel size before set2set pooling.
            outss = Dense(set2set_args["channels"], activation="linear")(n)
            out = Set2Set(**set2set_args)(outss)
        else:
            out = PoolingNodes(**pooling_args)(n)

        # After graph pooling the output MLP operates on a plain tensor.
        output_mlp.update({"input_tensor_type": "tensor"})
        main_output = MLP(**output_mlp)(out)

    else:  # Node labeling
        out = n
        main_output = MLP(**output_mlp)(out)

        # no ragged for distribution atm
        main_output = ChangeTensorType(
            input_tensor_type="ragged",
            output_tensor_type="tensor")(main_output)

    model = ks.models.Model(
        inputs=[node_input, edge_input, edge_index_input, env_input],
        outputs=main_output)

    return model
Ejemplo n.º 2
0
def make_gat(  # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific parameter
        depth=3,
        attention_heads_num=5,
        attention_heads_concat=False,
        attention_args: dict = None):
    """
    Generate graph attention network (GAT).

    The node features are first projected to ``attention_args["units"]``. Then ``depth`` times,
    ``attention_heads_num`` independent attention heads are applied, and their outputs are either
    concatenated or averaged.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}. Only "output_mode" is read here.
            'graph' pools the nodes into one vector per graph. Any other value gives node-level output.
        output_mlp (dict): Dictionary of arguments for final MLP regression or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int): Number of convolution layers. Default is 3.
        attention_heads_num (int): Number of attention heads. Default is 5.
        attention_heads_concat (bool): Concatenate the heads' outputs instead of averaging them.
            Default is False.
        attention_args (dict): Layer arguments for attention layer. Default is
            {"units": 32, 'is_sorted': False, 'has_unconnected': True}

    Returns:
        model (ks.models.Model): GAT model with inputs [node_input, edge_input, edge_index_input].
    """
    # default values
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'attention_args': {
            "units": 32,
            'is_sorted': False,
            'has_unconnected': True
        }
    }

    # Update default values
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    attention_args = update_model_args(model_default['attention_args'],
                                       attention_args)
    pooling_nodes_args = {}

    # Make input embedding, if no feature dimension (no state input for GAT)
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    edi = edge_index_input

    nk = Dense(units=attention_args["units"], activation="linear")(n)
    for _ in range(depth):
        # Every head is a separate layer instance, i.e. heads do not share weights.
        heads = [
            AttentionHeadGAT(**attention_args)([nk, ed, edi])
            for _ in range(attention_heads_num)
        ]
        if attention_heads_concat:
            nk = Concatenate(axis=-1)(heads)
        else:
            nk = Average()(heads)

    n = nk
    if output_embedd["output_mode"] == 'graph':
        out = PoolingNodes(**pooling_nodes_args)(n)
        # After graph pooling the output MLP operates on a plain tensor.
        output_mlp.update({"input_tensor_type": "tensor"})
        out = MLP(**output_mlp)(out)
        main_output = ks.layers.Flatten()(out)  # will be dense
    else:  # node embedding
        out = MLP(**output_mlp)(n)
        main_output = ChangeTensorType(input_tensor_type="ragged",
                                       output_tensor_type="tensor")(out)

    model = ks.models.Model(
        inputs=[node_input, edge_input, edge_index_input], outputs=main_output)

    return model
Ejemplo n.º 3
0
def make_unet(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific
        hidden_dim=32,
        depth=4,
        k=0.3,
        score_initializer='ones',
        use_bias=True,
        activation='relu',
        is_sorted=False,
        has_unconnected=True,
        use_reconnect=True
):
    """
    Make Graph U Net.

    All graph tensors are converted to the "values_partition" type with "batch" node indexing.
    In the down path, each of the ``depth`` levels applies a GCN-style layer, optionally squares
    the adjacency, and applies top-k pooling. The up path unpools level by level, adds the
    pre-pooling node features of that level (skip connection), and applies another GCN-style layer.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_type": 'padded'}. Only "output_mode" is read here.
            'graph' pools the nodes into one vector per graph. Any other value gives node-level output.
        output_mlp (dict, optional): Parameter for MLP output classification/ regression. Defaults to
            {"use_bias": [True, False], "units": [25, 1],
            "activation": ['relu', 'sigmoid']}
        hidden_dim (int): Hidden node feature dimension. Default is 32.
        depth (int): Depth of pooling steps. Default is 4.
        k (float): Pooling ratio. Default is 0.3.
        score_initializer (str): How to initialize score kernel. Default is 'ones'.
        use_bias (bool): Use bias. Default is True.
        activation (str): Activation function used. Default is 'relu'.
        is_sorted (bool, optional): Edge edge_indices are sorted. Defaults to False.
        has_unconnected (bool, optional): Has unconnected nodes. Defaults to True.
        use_reconnect (bool): Reconnect nodes after pooling. I.e. adj_matrix=adj_matrix^2. Default is True.

    Returns:
        model (ks.models.Model): Unet model with inputs [node_input, edge_input, edge_index_input].
    """
    # Default values update
    model_default = {'input_embedd': {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
                                      'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
                                      'input_tensor_type': 'ragged'},
                     'output_embedd': {"output_mode": 'graph', "output_type": 'padded'},
                     'output_mlp': {"use_bias": [True, False], "units": [25, 1], "activation": ['relu', 'sigmoid']}
                     }

    # Update model args
    input_embedd = update_model_args(model_default['input_embedd'], input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'], output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    pooling_args = {"pooling_method": 'segment_mean', "is_sorted": is_sorted, "has_unconnected": has_unconnected}

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(input_node_shape,
                                                                                          input_edge_shape, None,
                                                                                          **input_embedd)
    # Work on values/partition tensors with batch-wide (disjoint graph) node indices from here on.
    tens_type = "values_partition"
    node_indexing = "batch"
    n = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(n)
    ed = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(ed)
    edi = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(edge_index_input)
    edi = ChangeIndexing(input_tensor_type=tens_type, to_indexing=node_indexing)([n, edi])  # disjoint

    output_mlp.update({"input_tensor_type": tens_type})
    gather_args = {"input_tensor_type": tens_type, "node_indexing": node_indexing}
    pooling_args.update({"input_tensor_type": tens_type, "node_indexing": node_indexing})

    # Graph lists: graph_list[i] is the input graph of level i, map_list[i] the pooling map of level i.
    n = Dense(hidden_dim, use_bias=use_bias, activation='linear', input_tensor_type=tens_type)(n)
    in_graph = [n, ed, edi]
    graph_list = [in_graph]
    map_list = []

    # U Down
    i_graph = in_graph
    for _ in range(depth):

        n, ed, edi = i_graph
        # GCN layer; messages are pooled with pooling_args ("segment_mean").
        eu = GatherNodesOutgoing(**gather_args)([n, edi])
        eu = Dense(hidden_dim, use_bias=use_bias, activation='linear', input_tensor_type=tens_type)(eu)
        nu = PoolingLocalEdges(**pooling_args)([n, eu, edi])
        n = Activation(activation=activation, input_tensor_type=tens_type)(nu)

        if use_reconnect:
            ed, edi = AdjacencyPower(n=2, node_indexing=node_indexing, input_tensor_type=tens_type)([n, ed, edi])

        # Pooling
        i_graph, i_map = PoolingTopK(k=k, kernel_initializer=score_initializer,
                                     node_indexing=node_indexing, input_tensor_type=tens_type)([n, ed, edi])

        graph_list.append(i_graph)
        map_list.append(i_map)

    # U Up
    ui_graph = i_graph
    for i in range(depth, 0, -1):
        o_graph = graph_list[i - 1]
        i_map = map_list[i - 1]
        ui_graph = UnPoolingTopK(node_indexing=node_indexing, input_tensor_type=tens_type)(o_graph + i_map + ui_graph)

        n, ed, edi = ui_graph
        # skip connection
        n = Add(input_tensor_type=tens_type)([n, o_graph[0]])
        # GCN; messages are pooled with pooling_args ("segment_mean").
        eu = GatherNodesOutgoing(**gather_args)([n, edi])
        eu = Dense(hidden_dim, use_bias=use_bias, activation='linear', input_tensor_type=tens_type)(eu)
        nu = PoolingLocalEdges(**pooling_args)([n, eu, edi])
        n = Activation(activation=activation, input_tensor_type=tens_type)(nu)

        ui_graph = [n, ed, edi]

    # Output
    n = ui_graph[0]
    if output_embedd["output_mode"] == 'graph':
        out = PoolingNodes(**pooling_args)(n)

        # After graph pooling the output MLP operates on a plain tensor.
        output_mlp.update({"input_tensor_type": "tensor"})
        out = MLP(**output_mlp)(out)
        main_output = ks.layers.Flatten()(out)  # will be dense
    else:  # node embedding
        out = MLP(**output_mlp)(n)
        main_output = ChangeTensorType(input_tensor_type=tens_type, output_tensor_type="tensor")(out)

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input], outputs=main_output)

    return model
Ejemplo n.º 4
0
def make_nmpn(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific
        depth=3,
        node_dim=128,
        edge_dense: dict = None,
        use_set2set=True,
        set2set_args: dict = None,
        pooling_args: dict = None):
    """
    Get Message passing model.

    A Dense layer maps every edge to ``node_dim * node_dim`` values, which form the edge matrix used by
    ``TrafoMatMulMessages``. Each of the ``depth`` steps transforms the outgoing node features with
    that edge matrix and pools the messages onto the nodes. It then updates the nodes with a single
    ``GRUupdate`` layer whose weights are shared across all steps.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_type": 'padded'}. Only "output_mode" is read here.
            'graph' pools the nodes into one vector per graph. Any other value gives node-level output.
        output_mlp (dict): Dictionary of MLP arguments for output regression or classification. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['selu', 'selu', 'sigmoid']}
        depth (int, optional): Depth. Defaults to 3.
        node_dim (int, optional): Dimension for hidden node representation. Defaults to 128.
        edge_dense (dict): Dictionary of arguments for NN to make edge matrix. Default is
            {'use_bias' : True, 'activation' : 'selu'}
        use_set2set (bool, optional): Use set2set layer. Defaults to True.
        set2set_args (dict): Dictionary of Set2Set Layer Arguments. Default is
            {'channels': 32, 'T': 3, "pooling_method": "sum", "init_qstar": "0"}
        pooling_args (dict): Dictionary for message pooling arguments. It is also used for node
            pooling when ``use_set2set`` is False. Default is
            {'is_sorted': False, 'has_unconnected': True, 'pooling_method': "segment_mean"}

    Returns:
        model (ks.models.Model): Message Passing model with inputs [node_input, edge_input, edge_index_input].
    """
    # Make default parameter
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['selu', 'selu', 'sigmoid']
        },
        'set2set_args': {
            'channels': 32,
            'T': 3,
            "pooling_method": "sum",
            "init_qstar": "0"
        },
        'pooling_args': {
            'is_sorted': False,
            'has_unconnected': True,
            'pooling_method': "segment_mean"
        },
        'edge_dense': {
            'use_bias': True,
            'activation': 'selu'
        }
    }

    # Update model args
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    set2set_args = update_model_args(model_default['set2set_args'],
                                     set2set_args)
    pooling_args = update_model_args(model_default['pooling_args'],
                                     pooling_args)
    edge_dense = update_model_args(model_default['edge_dense'], edge_dense)

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    # Work on values/partition tensors with batch-wide (disjoint graph) node indices from here on.
    tens_type = "values_partition"
    node_indexing = "batch"
    n = ChangeTensorType(input_tensor_type="ragged",
                         output_tensor_type=tens_type)(n)
    ed = ChangeTensorType(input_tensor_type="ragged",
                          output_tensor_type=tens_type)(ed)
    edi = ChangeTensorType(input_tensor_type="ragged",
                           output_tensor_type=tens_type)(edge_index_input)
    edi = ChangeIndexing(input_tensor_type=tens_type,
                         to_indexing=node_indexing)([n, edi])
    set2set_args.update({"input_tensor_type": tens_type})
    output_mlp.update({"input_tensor_type": tens_type})
    edge_dense.update({"input_tensor_type": tens_type})
    pooling_args.update({
        "input_tensor_type": tens_type,
        "node_indexing": node_indexing
    })

    n = Dense(node_dim, activation="linear", input_tensor_type=tens_type)(n)
    # node_dim * node_dim values per edge, i.e. a flattened edge matrix for TrafoMatMulMessages.
    edge_net = Dense(node_dim * node_dim, **edge_dense)(ed)
    # Created once outside the loop so all message passing steps share the same GRU weights.
    gru = GRUupdate(node_dim,
                    input_tensor_type=tens_type,
                    node_indexing=node_indexing)

    for _ in range(depth):
        eu = GatherNodesOutgoing(input_tensor_type=tens_type,
                                 node_indexing=node_indexing)([n, edi])
        eu = TrafoMatMulMessages(node_dim,
                                 input_tensor_type=tens_type,
                                 node_indexing=node_indexing)([edge_net, eu])
        # Pool messages onto the nodes (method from pooling_args, default "segment_mean").
        eu = PoolingLocalEdges(**pooling_args)([n, eu, edi])
        n = gru([n, eu])

    if output_embedd["output_mode"] == 'graph':
        if use_set2set:
            # Project to the set2set channel size before set2set pooling.
            outss = Dense(set2set_args['channels'],
                          activation="linear",
                          input_tensor_type=tens_type)(n)
            out = Set2Set(**set2set_args)(outss)
        else:
            out = PoolingNodes(**pooling_args)(n)

        # final dense layers; after graph pooling the MLP operates on a plain tensor.
        output_mlp.update({"input_tensor_type": "tensor"})
        main_output = MLP(**output_mlp)(out)

    else:  # Node labeling
        out = n
        main_output = MLP(**output_mlp)(out)
        # no ragged for distribution supported atm
        main_output = ChangeTensorType(
            input_tensor_type=tens_type,
            output_tensor_type="tensor")(main_output)

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input],
                            outputs=main_output)

    return model
Ejemplo n.º 5
0
def make_megnet(
        # Input
        input_node_shape,
        input_edge_shape,
        input_state_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,  # Only graph possible for megnet
        output_mlp: dict = None,
        # Model specs
        meg_block_args: dict = None,
        node_ff_args: dict = None,
        edge_ff_args: dict = None,
        state_ff_args: dict = None,
        set2set_args: dict = None,
        nblocks: int = 3,
        has_ff: bool = True,
        dropout: float = None,
        use_set2set: bool = True,
):
    """
    Get Megnet model.

    Nodes, edges and state are first passed through feed-forward MLPs. Then ``nblocks`` MEGnet blocks
    are applied, each with a residual (skip) connection and optional dropout. Nodes and edges are
    pooled (Set2Set or plain pooling) and concatenated with the state. A final MLP is applied to the
    result.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_state_shape (list): Shape of state features. If shape is (,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}. It is not read further,
            because MEGnet always produces graph-level output.
        output_mlp (dict): Dictionary of MLP arguments for output regression or classification. Default is
            {"use_bias": [True, True, True], "units": [32, 16, 1],
            "activation": ['kgcnn>softplus2', 'kgcnn>softplus2', 'linear']}.
            "input_tensor_type" is always set to "tensor".
        meg_block_args (dict): Dictionary of MegBlock arguments. Default is
            {'node_embed': [64, 32, 32], 'edge_embed': [64, 32, 32],
            'env_embed': [64, 32, 32], 'activation': 'kgcnn>softplus2', 'is_sorted': False,
            'has_unconnected': True}.
        node_ff_args (dict): Dictionary of Feed-Forward Layer arguments. Default is
            {"units": [64, 32], "activation": "kgcnn>softplus2"}.
        edge_ff_args (dict): Dictionary of  Feed-Forward Layer arguments. Default is
            {"units": [64, 32], "activation": "kgcnn>softplus2"}.
        state_ff_args (dict): Dictionary of Feed-Forward Layer arguments. Default is
            {"units": [64, 32], "activation": "kgcnn>softplus2"}.
            "input_tensor_type" is always set to "tensor".
        set2set_args (dict): Dictionary of Set2Set Layer Arguments. Default is
            {'channels': 16, 'T': 3, "pooling_method": "sum", "init_qstar": "0"}
        nblocks (int): Number of block. Default is 3.
        has_ff (bool): Use a Feed-Forward layer before every block after the first. Default is True.
        dropout (float): Dropout rate. No dropout is applied if None. Default is None.
        use_set2set (bool): Use set2set. Default is True.

    Returns:
        model (ks.models.Model): MEGnet model with inputs
            [node_input, edge_input, edge_index_input, env_input].
    """
    # Default arguments if None
    model_default = {'input_embedd': {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
                                      'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
                                      'input_tensor_type': 'ragged'},
                     'output_embedd': {"output_mode": 'graph', "output_tensor_type": 'padded'},
                     'output_mlp': {"use_bias": [True, True, True], "units": [32, 16, 1],
                                    "activation": ['kgcnn>softplus2', 'kgcnn>softplus2', 'linear']},
                     'meg_block_args': {'node_embed': [64, 32, 32], 'edge_embed': [64, 32, 32],
                                        'env_embed': [64, 32, 32], 'activation': 'kgcnn>softplus2', 'is_sorted': False,
                                        'has_unconnected': True},
                     'set2set_args': {'channels': 16, 'T': 3, "pooling_method": "sum", "init_qstar": "0"},
                     'node_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"},
                     'edge_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"},
                     'state_ff_args': {"units": [64, 32], "activation": "kgcnn>softplus2"}
                     }

    # Update default arguments
    input_embedd = update_model_args(model_default['input_embedd'], input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'], output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    meg_block_args = update_model_args(model_default['meg_block_args'], meg_block_args)
    set2set_args = update_model_args(model_default['set2set_args'], set2set_args)
    node_ff_args = update_model_args(model_default['node_ff_args'], node_ff_args)
    edge_ff_args = update_model_args(model_default['edge_ff_args'], edge_ff_args)
    state_ff_args = update_model_args(model_default['state_ff_args'], state_ff_args)
    # State and the final concatenated vector are plain tensors, not ragged.
    # Set the key via update (instead of passing it again as a keyword argument) so a
    # user-supplied "input_tensor_type" cannot cause a duplicate-keyword TypeError.
    state_ff_args.update({"input_tensor_type": "tensor"})
    output_mlp.update({"input_tensor_type": "tensor"})

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, env_input, uenv = generate_standard_graph_input(input_node_shape,
                                                                                                     input_edge_shape,
                                                                                                     input_state_shape,
                                                                                                     **input_embedd)

    edi = edge_index_input

    # starting
    vp = n
    ep = ed
    up = uenv
    vp = MLP(**node_ff_args)(vp)
    ep = MLP(**edge_ff_args)(ep)
    up = MLP(**state_ff_args)(up)
    vp2 = vp
    ep2 = ep
    up2 = up
    for i in range(0, nblocks):
        if has_ff and i > 0:
            vp2 = MLP(**node_ff_args)(vp)
            ep2 = MLP(**edge_ff_args)(ep)
            up2 = MLP(**state_ff_args)(up)

        # MEGnetBlock
        vp2, ep2, up2 = MEGnetBlock(**meg_block_args)(
            [vp2, ep2, edi, up2])

        if dropout is not None:
            vp2 = Dropout(dropout, name='dropout_atom_%d' % i)(vp2)
            ep2 = Dropout(dropout, name='dropout_bond_%d' % i)(ep2)
            # The state is a plain tensor (see Add below), so use the standard Keras layer.
            up2 = ks.layers.Dropout(dropout, name='dropout_state_%d' % i)(up2)

        # skip connection
        vp = Add()([vp2, vp])
        ep = Add()([ep2, ep])
        up = Add(input_tensor_type="tensor")([up2, up])

    if use_set2set:
        vp = Dense(set2set_args["channels"], activation='linear')(vp)  # to match units
        ep = Dense(set2set_args["channels"], activation='linear')(ep)  # to match units
        vp = Set2Set(**set2set_args)(vp)
        ep = Set2Set(**set2set_args)(ep)
    else:
        vp = PoolingNodes()(vp)
        ep = PoolingGlobalEdges()(ep)

    ep = ks.layers.Flatten()(ep)
    vp = ks.layers.Flatten()(vp)
    final_vec = ks.layers.Concatenate(axis=-1)([vp, ep, up])

    if dropout is not None:
        final_vec = ks.layers.Dropout(dropout, name='dropout_final')(final_vec)

    # final dense layers
    main_output = MLP(**output_mlp)(final_vec)

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input, env_input], outputs=main_output)

    return model
Ejemplo n.º 6
0
def make_schnet(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_mlp: dict = None,
        output_dense: dict = None,
        output_embedd: dict = None,
        # Model specific
        depth=4,
        out_scale_pos=0,
        interaction_args: dict = None,
        node_pooling_args: dict = None):
    """
    Make uncompiled SchNet model.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}
        output_mlp (dict, optional): Parameter for MLP output classification/ regression. Defaults to
            {"use_bias": [True, True], "units": [128, 64],
            "activation": ['shifted_softplus', 'shifted_softplus']}
        output_dense (dict): Parameter for Dense scaling layer. Defaults to {"units": 1, "activation": 'linear',
             "use_bias": True}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
             {"output_mode": 'graph', "output_type": 'padded'}. Only "output_mode" is read here:
             'graph' pools the nodes into one vector per graph, any other value gives per-node output.
        depth (int, optional): Number of Interaction units. Defaults to 4.
        out_scale_pos (int, optional): Position of the final Dense scaling layer in graph mode:
             0 applies it to the node features before pooling, 1 applies it to the pooled vector.
             Ignored for node output, where the Dense layer is always applied to the nodes. Defaults to 0.
        interaction_args (dict): Interaction Layer arguments. Defaults to {"units": 128, "use_bias": True,
             "activation": 'shifted_softplus', "cfconv_pool": 'sum',
             "is_sorted": False, "has_unconnected": True}
        node_pooling_args (dict, optional): Node pooling arguments. Defaults to {"pooling_method": "sum"}.

    Returns:
        model (tf.keras.models.Model): SchNet.

    Raises:
        ValueError: If output_mode is 'graph' and out_scale_pos is neither 0 nor 1.
    """
    # Make default values if None
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_type": 'padded'
        },
        'interaction_args': {
            "units": 128,
            "use_bias": True,
            "activation": 'shifted_softplus',
            "cfconv_pool": 'sum',
            "is_sorted": False,
            "has_unconnected": True
        },
        'output_mlp': {
            "use_bias": [True, True],
            "units": [128, 64],
            "activation": ['shifted_softplus', 'shifted_softplus']
        },
        'output_dense': {
            "units": 1,
            "activation": 'linear',
            "use_bias": True
        },
        'node_pooling_args': {
            "pooling_method": "sum"
        }
    }

    # Update args
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    interaction_args = update_model_args(model_default['interaction_args'],
                                         interaction_args)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    output_dense = update_model_args(model_default['output_dense'],
                                     output_dense)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    node_pooling_args = update_model_args(model_default['node_pooling_args'],
                                          node_pooling_args)

    # Any other position would silently drop the final Dense scaling layer in graph mode.
    if output_embedd["output_mode"] == 'graph' and out_scale_pos not in (0, 1):
        raise ValueError(
            "out_scale_pos must be 0 (scale nodes before pooling) or 1 (scale after pooling), "
            "but got: {}".format(out_scale_pos))

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    # Use representation: convert ragged inputs to a values/partition representation
    # and switch edge indices to batch-wise node indexing.
    tens_type = "values_partition"
    node_indexing = "batch"
    n = ChangeTensorType(input_tensor_type="ragged",
                         output_tensor_type=tens_type)(n)
    ed = ChangeTensorType(input_tensor_type="ragged",
                          output_tensor_type=tens_type)(ed)
    edi = ChangeTensorType(input_tensor_type="ragged",
                           output_tensor_type=tens_type)(edge_index_input)
    edi = ChangeIndexing(input_tensor_type=tens_type,
                         to_indexing=node_indexing)([n, edi])

    # Project node features to the interaction width.
    n = Dense(interaction_args["units"],
              activation='linear',
              input_tensor_type=tens_type)(n)

    for _ in range(depth):
        n = SchNetInteraction(input_tensor_type=tens_type,
                              node_indexing=node_indexing,
                              **interaction_args)([n, ed, edi])

    n = MLP(input_tensor_type=tens_type, **output_mlp)(n)

    # Single Dense instance; where it is applied depends on out_scale_pos.
    mlp_last = Dense(input_tensor_type=tens_type, **output_dense)

    if output_embedd["output_mode"] == 'graph':
        if out_scale_pos == 0:
            n = mlp_last(n)
        out = PoolingNodes(input_tensor_type=tens_type,
                           node_indexing=node_indexing,
                           **node_pooling_args)(n)
        if out_scale_pos == 1:
            out = mlp_last(out)
        main_output = ks.layers.Flatten()(out)  # will be dense
    else:  # node embedding
        out = mlp_last(n)
        main_output = ChangeTensorType(
            input_tensor_type="values_partition",
            output_tensor_type="tensor")(out)  # no ragged for distribution atm

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input],
                            outputs=main_output)

    return model
Ejemplo n.º 7
0
def make_gcn(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific
        depth=3,
        gcn_args: dict = None):
    """
    Make GCN model.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. Must have last dimension 1, since GCN only
            accepts edge weights of a pre-scaled adjacency matrix.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {"input_node_vocab": 100, "input_edge_vocab": 10, "input_state_vocab": 100,
            "input_node_embedd": 64, "input_edge_embedd": 64, "input_state_embedd": 64,
            "input_tensor_type": 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'masked'}. Only "output_mode" is read here:
            'graph' pools nodes before the MLP, any other value applies the MLP per node.
        output_mlp (dict): Dictionary of arguments for final MLP regression or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int, optional): Number of convolutions. Defaults to 3.
        gcn_args (dict): Dictionary of arguments for the GCN convolutional unit. Defaults to
            {"units": 100, "use_bias": True, "activation": 'relu', "pooling_method": 'sum',
            "is_sorted": False, "has_unconnected": True}.

    Returns:
        model (tf.keras.models.Model): uncompiled model.

    Raises:
        ValueError: If the last dimension of input_edge_shape is not 1.
    """

    if input_edge_shape[-1] != 1:
        raise ValueError(
            "No edge features available for GCN, only edge weights of pre-scaled adjacency matrix, "
            "must be shape (batch, None, 1), but got (without batch-dimension): {}".format(input_edge_shape))
    # Make default args
    model_default = {
        'input_embedd': {
            "input_node_vocab": 100,
            "input_edge_vocab": 10,
            "input_state_vocab": 100,
            "input_node_embedd": 64,
            "input_edge_embedd": 64,
            "input_state_embedd": 64,
            "input_tensor_type": 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'masked'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'gcn_args': {
            "units": 100,
            "use_bias": True,
            "activation": 'relu',
            "pooling_method": 'sum',
            "is_sorted": False,
            "has_unconnected": True
        }
    }

    # Update model parameter
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    gcn_args = update_model_args(model_default['gcn_args'], gcn_args)

    # Make input embedding, if no feature dimension (state input is not used by GCN).
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    edi = edge_index_input

    # Map to units
    n = Dense(gcn_args["units"], use_bias=True, activation='linear')(n)

    # n-Layer Step
    for _ in range(depth):
        n = GCN(**gcn_args)([n, ed, edi])

    if output_embedd["output_mode"] == "graph":
        out = PoolingNodes()(n)  # will return tensor
        # Copy instead of mutating output_mlp, which may be shared with the caller.
        out = MLP(**{**output_mlp, "input_tensor_type": "tensor"})(out)
    else:  # Node labeling
        out = MLP(**output_mlp)(n)
        out = ChangeTensorType(
            input_tensor_type='ragged', output_tensor_type="tensor")(
                out)  # no ragged for distribution supported atm

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input],
                            outputs=out)

    return model
Ejemplo n.º 8
0
def make_attentiveFP(  # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific parameter
        depth=3,
        attention_args: dict = None):
    """
    Generate AttentiveFP network.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}. Only "output_mode" is read here:
            'graph' uses attentive node pooling, any other value applies the MLP per node.
        output_mlp (dict): Dictionary of arguments for final MLP regression or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int): Number of attention + GRU update steps. Default is 3.
        attention_args (dict): Layer arguments for attention layer. Default is
            {"units": 32, 'is_sorted': False, 'has_unconnected': True}.
            The first attention layer is always created with use_edge_features=True, so this dict
            must not itself contain 'use_edge_features'.

    Returns:
        model (tf.keras.models.Model): AttentiveFP model.
    """
    import warnings
    warnings.warn("make_attentiveFP: model has not been tested yet.", stacklevel=2)
    # default values
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'attention_args': {
            "units": 32,
            'is_sorted': False,
            'has_unconnected': True
        }
    }

    # Update default values
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    attention_args = update_model_args(model_default['attention_args'],
                                       attention_args)

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    edi = edge_index_input
    # First step: attention that also uses edge features, followed by a GRU node update.
    nk = Dense(units=attention_args['units'])(n)
    Ck = AttentiveHeadFP(use_edge_features=True,
                         **attention_args)([nk, ed, edi])
    nk = GRUupdate(units=attention_args['units'])([nk, Ck])

    # Remaining depth - 1 steps use attention_args as given.
    for _ in range(1, depth):
        Ck = AttentiveHeadFP(**attention_args)([nk, ed, edi])
        nk = GRUupdate(units=attention_args['units'])([nk, Ck])

    n = nk
    if output_embedd["output_mode"] == 'graph':
        out = AttentiveNodePooling(units=attention_args['units'])(n)
        # Copy instead of mutating output_mlp, which may be shared with the caller.
        out = MLP(**{**output_mlp, "input_tensor_type": "tensor"})(out)
        main_output = ks.layers.Flatten()(out)  # will be dense
    else:  # node embedding
        out = MLP(**output_mlp)(n)
        main_output = ChangeTensorType(input_tensor_type="ragged",
                                       output_tensor_type="tensor")(out)

    model = tf.keras.models.Model(
        inputs=[node_input, edge_input, edge_index_input], outputs=main_output)

    return model
Ejemplo n.º 9
0
def make_dimnet_pp(
        # Input
        input_node_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        # Model specific parameter
        emb_size=128,
        out_emb_size=256,
        int_emb_size=64,
        basis_emb_size=8,
        num_blocks=4,
        num_spherical=7,
        num_radial=6,
        cutoff=5.0,
        envelope_exponent=5,
        num_before_skip=1,
        num_after_skip=2,
        num_dense_output=3,
        num_targets=12,
        activation="swish",
        extensive=True,
        output_init='zeros',
):
    """
    Make uncompiled DimeNet++ model.

    The model takes four inputs: node features, node coordinates (shape [None, 3]),
    bond indices (shape [None, 2]) and angle indices (shape [None, 2]).

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Embedding parameters. Default is
            {'input_node_vocab': 95, 'input_node_embedd': 64, 'input_tensor_type': 'ragged'}.
        output_embedd (dict): Accepted for interface compatibility; not used by this function.
        emb_size (int): Width of the edge embedding and interaction blocks. Default is 128.
        out_emb_size (int): Width passed to each DimNetOutputBlock. Default is 256.
        int_emb_size (int): Interaction embedding size for DimNetInteractionPPBlock. Default is 64.
        basis_emb_size (int): Basis embedding size for DimNetInteractionPPBlock. Default is 8.
        num_blocks (int): Number of interaction blocks. Default is 4.
        num_spherical (int): Number of spherical basis functions. Default is 7.
        num_radial (int): Number of radial basis functions. Default is 6.
        cutoff (float): Cutoff passed to the radial and spherical basis layers. Default is 5.0.
        envelope_exponent (int): Envelope exponent for the basis layers. Default is 5.
        num_before_skip (int): Layers before the skip connection in each interaction block. Default is 1.
        num_after_skip (int): Layers after the skip connection in each interaction block. Default is 2.
        num_dense_output (int): Dense layers in each output block. Default is 3.
        num_targets (int): Number of output targets. Default is 12.
        activation (str): Activation of the embedding Dense layers. Default is "swish".
        extensive (bool): Sum-pool nodes if True, mean-pool otherwise. Default is True.
        output_init (str): Kernel initializer of the output blocks' final layer. Default is 'zeros'.

    Returns:
        model (tf.keras.models.Model): DimeNet++ model.
    """
    defaults = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_node_embedd': 64,
            'input_tensor_type': 'ragged'
        }
    }
    input_embedd = update_model_args(defaults['input_embedd'], input_embedd)

    (node_input, node_feat, xyz_input, bond_index_input,
     angle_index_input, _) = generate_mol_graph_input(input_node_shape,
                                                      [None, 3],
                                                      [None, 2],
                                                      [None, 2],
                                                      **input_embedd)

    # Radial part: bond distances fed through the Bessel basis layer.
    dist = NodeDistance()([xyz_input, bond_index_input])
    rbf = BesselBasisLayer(num_radial=num_radial, cutoff=cutoff,
                           envelope_exponent=envelope_exponent)(dist)

    # Angular part: angles between bonds fed through the spherical basis layer.
    angles = EdgeAngle()([xyz_input, bond_index_input, angle_index_input])
    sbf = SphericalBasisLayer(num_spherical=num_spherical, num_radial=num_radial, cutoff=cutoff,
                              envelope_exponent=envelope_exponent)([dist, angles, angle_index_input])

    # Embedding block: node features gathered along bond_index, joined with the embedded rbf.
    rbf_emb = Dense(emb_size, use_bias=True, activation=activation, kernel_initializer="orthogonal")(rbf)
    pair_feat = GatherNodes()([node_feat, bond_index_input])
    msg = Concatenate(axis=-1)([pair_feat, rbf_emb])
    msg = Dense(emb_size, use_bias=True, activation=activation, kernel_initializer="orthogonal")(msg)
    node_out = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets,
                                 output_kernel_initializer=output_init)([node_feat, msg, rbf, bond_index_input])

    # Interaction blocks: each block's output contribution is added to the running total.
    accumulate = Add()
    for _ in range(num_blocks):
        msg = DimNetInteractionPPBlock(emb_size, int_emb_size, basis_emb_size, num_before_skip, num_after_skip)(
            [msg, rbf, sbf, angle_index_input])
        block_out = DimNetOutputBlock(emb_size, out_emb_size, num_dense_output, num_targets=num_targets,
                                      output_kernel_initializer=output_init)([node_feat, msg, rbf, bond_index_input])
        node_out = accumulate([node_out, block_out])

    pool_method = "sum" if extensive else "mean"
    main_output = PoolingNodes(pooling_method=pool_method)(node_out)

    return tf.keras.models.Model(inputs=[node_input, xyz_input, bond_index_input, angle_index_input],
                                 outputs=main_output)
Ejemplo n.º 10
0
def make_graph_sage(  # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific parameter
        depth=3,
        use_edge_features=False,
        node_mlp_args: dict = None,
        edge_mlp_args: dict = None,
        pooling_args: dict = None):
    """
    Generate GraphSAGE network.

    Each of the `depth` layers gathers outgoing node features per edge, optionally concatenates
    edge features, applies the edge MLP, pools the messages back to nodes, concatenates them with
    the node features, applies the node MLP and a LayerNormalization.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,) embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,) embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5, 'input_state_vocab': 100,
            'input_node_embedd': 64, 'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the graph network. Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}. Only "output_mode" is read here:
            'graph' mean-pools nodes before the MLP, any other value applies the MLP per node.
        output_mlp (dict): Dictionary of arguments for final MLP regression or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int): Number of convolution layers. Default is 3.
        use_edge_features (bool): Whether to concatenate edges with nodes in aggregate. Default is False.
        node_mlp_args (dict): Dictionary of arguments for MLP for node update. Default is
            {"units": [100, 50], "use_bias": True, "activation": ['relu', "linear"]}
        edge_mlp_args (dict): Dictionary of arguments for MLP for the edge messages. Default is
            {"units": [100, 50], "use_bias": True, "activation": ['relu', "linear"]}
        pooling_args (dict): Dictionary for message pooling arguments. Default is
            {'is_sorted': False, 'has_unconnected': True, 'pooling_method': "segment_mean"}.
            A pooling_method of "LSTM" or "lstm" selects PoolingLocalEdgesLSTM.

    Returns:
        model (tf.keras.models.Model): GraphSAGE model.

    """
    # default values
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'node_mlp_args': {
            "units": [100, 50],
            "use_bias": True,
            "activation": ['relu', "linear"]
        },
        'edge_mlp_args': {
            "units": [100, 50],
            "use_bias": True,
            "activation": ['relu', "linear"]
        },
        'pooling_args': {
            'is_sorted': False,
            'has_unconnected': True,
            'pooling_method': "segment_mean"
        }
    }

    # Update default values
    input_embedd = update_model_args(model_default['input_embedd'],
                                     input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'],
                                      output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    node_mlp_args = update_model_args(model_default['node_mlp_args'],
                                      node_mlp_args)
    edge_mlp_args = update_model_args(model_default['edge_mlp_args'],
                                      edge_mlp_args)
    pooling_args = update_model_args(model_default['pooling_args'],
                                     pooling_args)
    # Fixed settings for the graph-level readout, gathering and concatenation.
    pooling_nodes_args = {
        "input_tensor_type": 'ragged',
        "node_indexing": 'sample',
        'pooling_method': "mean"
    }
    gather_args = {"node_indexing": 'sample'}
    concat_args = {"axis": -1, "input_tensor_type": 'ragged'}

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)
    edi = edge_index_input

    for _ in range(depth):
        eu = GatherNodesOutgoing(**gather_args)([n, edi])
        if use_edge_features:
            eu = Concatenate(**concat_args)([eu, ed])

        eu = MLP(**edge_mlp_args)(eu)
        # Pool message
        if pooling_args['pooling_method'] in ["LSTM", "lstm"]:
            nu = PoolingLocalEdgesLSTM(**pooling_args)([n, eu, edi])
        else:
            nu = PoolingLocalMessages(**pooling_args)(
                [n, eu, edi])  # Summing for each node connection

        nu = Concatenate(**concat_args)(
            [n, nu])  # Concatenate node features with new edge updates

        n = MLP(**node_mlp_args)(nu)
        n = LayerNormalization(axis=-1)(n)  # Normalize

    # Regression layer on output
    if output_embedd["output_mode"] == 'graph':
        out = PoolingNodes(**pooling_nodes_args)(n)
        # Copy instead of mutating output_mlp, which may be shared with the caller.
        out = MLP(**{**output_mlp, "input_tensor_type": "tensor"})(out)
        main_output = ks.layers.Flatten()(out)  # will be tensor
    else:  # node embedding
        out = MLP(**output_mlp)(n)
        main_output = ChangeTensorType(input_tensor_type='ragged',
                                       output_tensor_type="tensor")(out)

    model = tf.keras.models.Model(
        inputs=[node_input, edge_input, edge_index_input], outputs=main_output)

    return model