def __init__(self,
             units,
             pooling_method='sum',
             normalize_by_weights=False,
             activation='kgcnn>leaky_relu',
             use_bias=True,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             **kwargs):
    """Create the gather, dense, pooling and activation sub-layers.

    Args:
        units: Number of units of the internal dense layer.
        pooling_method: Pooling method name handed to the weighted edge
            pooling layer. Defaults to 'sum'.
        normalize_by_weights: Flag handed to the weighted edge pooling
            layer. Defaults to False.
        activation: Activation identifier used by the separate activation
            layer; the dense layer itself is always created as 'linear'.
        use_bias: Forwarded to the dense layer.
        kernel_regularizer: Forwarded to the dense layer.
        bias_regularizer: Forwarded to the dense layer.
        activity_regularizer: Forwarded to the dense layer.
        kernel_constraint: Forwarded to the dense layer.
        bias_constraint: Forwarded to the dense layer.
        kernel_initializer: Forwarded to the dense layer.
        bias_initializer: Forwarded to the dense layer.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(GCN, self).__init__(**kwargs)
    self.units = units
    self.pooling_method = pooling_method
    self.normalize_by_weights = normalize_by_weights

    # NOTE(review): `_kgcnn_info`, `input_tensor_type` and `ragged_validate`
    # are read from `self` but never assigned in this method; presumably the
    # parent constructor provides them - confirm.
    dense_options = dict(
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        use_bias=use_bias,
    )

    # Layers
    self.lay_gather = GatherNodesOutgoing(**self._kgcnn_info)
    self.lay_dense = Dense(
        units=self.units,
        activation='linear',
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
        **dense_options,
    )
    self.lay_pool = PoolingWeightedLocalEdges(
        pooling_method=pooling_method,
        normalize_by_weights=normalize_by_weights,
        **self._kgcnn_info,
    )
    self.lay_act = Activation(
        activation,
        ragged_validate=self.ragged_validate,
        input_tensor_type=self.input_tensor_type,
    )
def __init__(self,
             units,
             cfconv_pool='segment_sum',
             use_bias=True,
             activation='kgcnn>shifted_softplus',
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             **kwargs):
    """Create the two dense layers plus pooling, gather and multiply layers.

    Args:
        units: Number of units of both dense layers.
        cfconv_pool: Pooling method name handed to the edge pooling layer.
            Defaults to 'segment_sum'.
        use_bias: Forwarded to both dense layers.
        activation: Activation of the first dense layer; the second dense
            layer is always created as 'linear'.
        kernel_regularizer: Forwarded to both dense layers.
        bias_regularizer: Forwarded to both dense layers.
        activity_regularizer: Forwarded to both dense layers.
        kernel_constraint: Forwarded to both dense layers.
        bias_constraint: Forwarded to both dense layers.
        kernel_initializer: Forwarded to both dense layers.
        bias_initializer: Forwarded to both dense layers.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(SchNetCFconv, self).__init__(**kwargs)
    self.units = units
    self.use_bias = use_bias
    self.cfconv_pool = cfconv_pool

    # Everything the two dense layers have in common; only the activation
    # differs between them.
    # NOTE(review): `_kgcnn_info` is read from `self` but not assigned here;
    # presumably the parent constructor provides it - confirm.
    dense_options = dict(
        units=self.units,
        use_bias=self.use_bias,
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )

    # Layer
    self.lay_dense1 = Dense(activation=activation, **dense_options,
                            **self._kgcnn_info)
    self.lay_dense2 = Dense(activation='linear', **dense_options,
                            **self._kgcnn_info)
    self.lay_sum = PoolingLocalEdges(pooling_method=cfconv_pool,
                                     **self._kgcnn_info)
    self.gather_n = GatherNodesOutgoing(**self._kgcnn_info)
    self.lay_mult = Multiply(**self._kgcnn_info)
def __init__(self,
             units,
             pooling_method='sum',
             normalize_by_weights=False,
             activation=None,
             use_bias=True,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             **kwargs):
    """Create the gather, dense, pooling and activation sub-layers.

    Args:
        units: Number of units of the internal dense layer.
        pooling_method: Pooling method name handed to the weighted edge
            pooling layer. Defaults to 'sum'.
        normalize_by_weights: Flag handed to the weighted edge pooling
            layer. Defaults to False.
        activation: Activation identifier for the separate activation
            layer. If None, a 'leaky_relu' config with alpha 0.2 is used
            when 'leaky_relu' is contained in `kgcnn_custom_act`,
            otherwise "relu".
        use_bias: Forwarded to the dense layer.
        kernel_regularizer: Forwarded to the dense layer.
        bias_regularizer: Forwarded to the dense layer.
        activity_regularizer: Forwarded to the dense layer.
        kernel_constraint: Forwarded to the dense layer.
        bias_constraint: Forwarded to the dense layer.
        kernel_initializer: Forwarded to the dense layer.
        bias_initializer: Forwarded to the dense layer.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(GCN, self).__init__(**kwargs)
    self.units = units
    self.pooling_method = pooling_method
    self.normalize_by_weights = normalize_by_weights

    # Resolve the default activation only when the caller gave none.
    if activation is None:
        if 'leaky_relu' in kgcnn_custom_act:
            activation = {"class_name": "leaky_relu", "config": {"alpha": 0.2}}
        else:
            activation = "relu"

    dense_options = dict(
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        use_bias=use_bias,
    )
    # NOTE(review): `_all_kgcnn_info`, `input_tensor_type` and
    # `ragged_validate` are read from `self` but not assigned here;
    # presumably the parent constructor provides them - confirm.
    gather_options = self._all_kgcnn_info
    # Entries of `_all_kgcnn_info` take precedence over the two pooling
    # keys on a name clash, exactly like a dict update.
    pool_options = {
        "pooling_method": pooling_method,
        "normalize_by_weights": normalize_by_weights,
        **self._all_kgcnn_info,
    }

    # Layers
    self.lay_gather = GatherNodesOutgoing(**gather_options)
    self.lay_dense = Dense(
        units=self.units,
        activation='linear',
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
        **dense_options,
    )
    self.lay_pool = PoolingWeightedLocalEdges(**pool_options)
    self.lay_act = Activation(
        activation,
        ragged_validate=self.ragged_validate,
        input_tensor_type=self.input_tensor_type,
    )
def __init__(self,
             units,
             cfconv_pool='segment_sum',
             use_bias=True,
             activation=None,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             **kwargs):
    """Create the two dense layers plus pooling, gather and multiply layers.

    Args:
        units: Number of units of both dense layers.
        cfconv_pool: Pooling method name handed to the edge pooling layer.
            Defaults to 'segment_sum'.
        use_bias: Forwarded to both dense layers.
        activation: Activation of the first dense layer. If None,
            'shifted_softplus' is used when that name is contained in
            `kgcnn_custom_act`, otherwise "selu". The second dense layer is
            always created as 'linear'.
        kernel_regularizer: Forwarded to both dense layers.
        bias_regularizer: Forwarded to both dense layers.
        activity_regularizer: Forwarded to both dense layers.
        kernel_constraint: Forwarded to both dense layers.
        bias_constraint: Forwarded to both dense layers.
        kernel_initializer: Forwarded to both dense layers.
        bias_initializer: Forwarded to both dense layers.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(SchNetCFconv, self).__init__(**kwargs)
    self.units = units
    self.use_bias = use_bias
    self.cfconv_pool = cfconv_pool

    # Resolve the default activation only when the caller gave none.
    if activation is None:
        if 'shifted_softplus' in kgcnn_custom_act:
            activation = 'shifted_softplus'
        else:
            activation = "selu"

    # Everything the two dense layers share; only the activation differs.
    # NOTE(review): `_all_kgcnn_info`, `input_tensor_type` and
    # `ragged_validate` are read from `self` but not assigned here;
    # presumably the parent constructor provides them - confirm.
    dense_options = dict(
        units=self.units,
        use_bias=self.use_bias,
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )
    # Entries of `_all_kgcnn_info` override "pooling_method" on a clash.
    pool_options = {"pooling_method": cfconv_pool, **self._all_kgcnn_info}

    # Layer
    self.lay_dense1 = Dense(activation=activation, **dense_options)
    self.lay_dense2 = Dense(activation='linear', **dense_options)
    self.lay_sum = PoolingLocalEdges(**pool_options)
    self.gather_n = GatherNodesOutgoing(**self._all_kgcnn_info)
    self.lay_mult = Multiply(
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
    )
def __init__(self,
             emb_size,
             int_emb_size,
             basis_emb_size,
             num_before_skip,
             num_after_skip,
             use_bias=True,
             pooling_method="sum",
             activation=None,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='orthogonal',
             bias_initializer='zeros',
             **kwargs):
    """Create all sub-layers of the interaction block.

    Args:
        emb_size: Units of the message dense layers, the second radial
            basis layer, the up projection and the residual layers.
        int_emb_size: Units of the down projection and of the second
            spherical basis layer.
        basis_emb_size: Units of the first radial and first spherical
            basis layers.
        num_before_skip: Number of residual layers created for
            `layers_before_skip`.
        num_after_skip: Number of residual layers created for
            `layers_after_skip`.
        use_bias: Stored on the instance. The layers created here use fixed
            `use_bias` values instead of this argument.
        pooling_method: Pooling method name handed to the edge pooling
            layer. Defaults to "sum".
        activation: Activation for the non-basis dense and residual layers.
            If None, 'swish' is used when that name is contained in
            `kgcnn_custom_act`, otherwise "selu".
        kernel_regularizer: Forwarded to every dense/residual layer.
        bias_regularizer: Forwarded to every dense/residual layer.
        activity_regularizer: Forwarded to every dense/residual layer.
        kernel_constraint: Forwarded to every dense/residual layer.
        bias_constraint: Forwarded to every dense/residual layer.
        kernel_initializer: Forwarded to every dense/residual layer.
        bias_initializer: Forwarded to every dense/residual layer.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(DimNetInteractionPPBlock, self).__init__(**kwargs)
    self.emb_size = emb_size
    self.int_emb_size = int_emb_size
    self.basis_emb_size = basis_emb_size
    self.num_before_skip = num_before_skip
    self.num_after_skip = num_after_skip
    self.use_bias = use_bias
    self.pooling_method = pooling_method

    # Resolve the default activation only when the caller gave none.
    if activation is None:
        if 'swish' in kgcnn_custom_act:
            activation = 'swish'
        else:
            activation = "selu"

    weight_options = dict(
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )
    # NOTE(review): `_all_kgcnn_info` is read from `self` but not assigned
    # here; presumably the parent constructor provides it - confirm.
    # Its entries override "pooling_method" on a name clash.
    pool_options = {"pooling_method": pooling_method, **self._all_kgcnn_info}
    gather_options = self._all_kgcnn_info

    # Transformations of Bessel and spherical basis representations
    self.dense_rbf1 = Dense(basis_emb_size, use_bias=False, **weight_options)
    self.dense_rbf2 = Dense(emb_size, use_bias=False, **weight_options)
    self.dense_sbf1 = Dense(basis_emb_size, use_bias=False, **weight_options)
    self.dense_sbf2 = Dense(int_emb_size, use_bias=False, **weight_options)

    # Dense transformations of input messages
    self.dense_ji = Dense(emb_size, activation=activation, use_bias=True,
                          **weight_options)
    self.dense_kj = Dense(emb_size, activation=activation, use_bias=True,
                          **weight_options)

    # Embedding projections for interaction triplets
    self.down_projection = Dense(int_emb_size, activation=activation,
                                 use_bias=False, **weight_options)
    self.up_projection = Dense(emb_size, activation=activation,
                               use_bias=False, **weight_options)

    # Residual layers before skip connection
    self.layers_before_skip = [
        ResidualLayer(emb_size, activation=activation, use_bias=True,
                      **weight_options)
        for _ in range(num_before_skip)
    ]
    self.final_before_skip = Dense(emb_size, activation=activation,
                                   use_bias=True, **weight_options)

    # Residual layers after skip connection
    self.layers_after_skip = [
        ResidualLayer(emb_size, activation=activation, use_bias=True,
                      **weight_options)
        for _ in range(num_after_skip)
    ]

    self.lay_add1 = Add()
    self.lay_add2 = Add()
    self.lay_mult1 = Multiply()
    self.lay_mult2 = Multiply()

    self.lay_gather = GatherNodesOutgoing(**gather_options)  # Are edges here
    self.lay_pool = PoolingLocalEdges(**pool_options)
def __init__(self,
             units,
             activation=None,
             use_bias=True,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             use_edge_features=False,
             node_indexing="sample",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Store the configuration and create the attention head sub-layers.

    Args:
        units: Units of the linear transformation layer; also stored as
            `int(units)` on the instance.
        activation: Activation for the attention coefficient layer and the
            final activation layer. If None, a 'leaky_relu' config with
            alpha 0.2 is used when "leaky_relu" is contained in
            `kgcnn_custom_act`, otherwise "relu".
        use_bias: Forwarded to both dense layers.
        kernel_regularizer: Forwarded to both dense layers.
        bias_regularizer: Forwarded to both dense layers.
        activity_regularizer: Forwarded to both dense layers.
        kernel_constraint: Forwarded to both dense layers.
        bias_constraint: Forwarded to both dense layers.
        kernel_initializer: Forwarded to both dense layers.
        bias_initializer: Forwarded to both dense layers.
        use_edge_features: Stored on the instance; not used in this method.
        node_indexing: Forwarded to the gather and pooling layers.
        is_sorted: Forwarded to the pooling layer.
        has_unconnected: Forwarded to the pooling layer.
        partition_type: Forwarded to the pooling layer.
        input_tensor_type: Forwarded to every sub-layer created here.
        ragged_validate: Forwarded to the dense and pooling layers.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(AttentionHeadGAT, self).__init__(**kwargs)

    # graph args
    self.use_edge_features = use_edge_features
    self.node_indexing = node_indexing
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._supports_ragged_inputs = True
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"
    ]
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type,
        self._tensor_input_type_implemented,
        self.node_indexing,
    )

    # dense args
    self.units = int(units)
    if activation is None:
        if "leaky_relu" in kgcnn_custom_act:
            activation = {"class_name": "leaky_relu", "config": {"alpha": 0.2}}
        else:
            activation = "relu"
    self.use_bias = use_bias

    # Resolved Keras objects kept on the instance under `ath_*` names.
    self.ath_activation = tf.keras.activations.get(activation)
    self.ath_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.ath_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.ath_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.ath_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.ath_activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self.ath_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.ath_bias_constraint = tf.keras.constraints.get(bias_constraint)

    # The sub-layers receive the raw constructor arguments, not the
    # resolved `ath_*` objects above.
    weight_options = dict(
        use_bias=use_bias,
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )
    dense_options = {
        "ragged_validate": self.ragged_validate,
        "input_tensor_type": self.input_tensor_type,
        **weight_options,
    }
    gather_options = dict(
        input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing,
    )
    pool_options = dict(
        node_indexing=node_indexing,
        partition_type=partition_type,
        has_unconnected=has_unconnected,
        is_sorted=is_sorted,
        ragged_validate=self.ragged_validate,
        input_tensor_type=self.input_tensor_type,
    )

    self.lay_linear_trafo = Dense(units, activation="linear", **dense_options)
    self.lay_alpha = Dense(1, activation=activation, **dense_options)
    self.lay_gather_in = GatherNodesIngoing(**gather_options)
    self.lay_gather_out = GatherNodesOutgoing(**gather_options)
    self.lay_concat = Concatenate(axis=-1,
                                  input_tensor_type=self.input_tensor_type)
    self.lay_pool_attention = PoolingLocalEdgesAttention(**pool_options)
    self.lay_final_activ = Activation(
        activation=activation,
        input_tensor_type=self.input_tensor_type,
    )
def make_nmpn(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific
        depth=3,
        node_dim=128,
        edge_dense: dict = None,
        use_set2set=True,
        set2set_args: dict = None,
        pooling_args: dict = None):
    """Get Message passing model.

    The inputs are converted from ragged to "values_partition" tensors with
    "batch" node indexing. Then `depth` rounds of gather -> message
    transformation -> edge pooling -> GRU update are stacked, followed by
    either a graph-level read-out or a per-node MLP.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,)
            embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,)
            embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if
            input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5,
            'input_state_vocab': 100, 'input_node_embedd': 64,
            'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}
        output_embedd (dict): Dictionary of embedding parameters of the
            graph network. Only "output_mode" is read by this function.
            Default is {"output_mode": 'graph', "output_type": 'padded'}
        output_mlp (dict): Dictionary of MLP arguments for output regression
            or classification. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['selu', 'selu', 'sigmoid']}
        depth (int, optional): Number of message passing steps.
            Defaults to 3.
        node_dim (int, optional): Dimension for hidden node representation.
            Defaults to 128.
        edge_dense (dict): Dictionary of arguments for NN to make edge
            matrix. Default is {'use_bias': True, 'activation': 'selu'}
        use_set2set (bool, optional): Use set2set layer for the graph-level
            read-out. Defaults to True.
        set2set_args (dict): Dictionary of Set2Set Layer Arguments. Default
            is {'channels': 32, 'T': 3, "pooling_method": "sum",
            "init_qstar": "0"}
        pooling_args (dict): Dictionary for message pooling arguments. Also
            used for node pooling when `use_set2set` is False. Default is
            {'is_sorted': False, 'has_unconnected': True,
            'pooling_method': "segment_mean"}

    Returns:
        model (ks.models.Model): Message Passing model.
    """
    # Make default parameter
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['selu', 'selu', 'sigmoid']
        },
        'set2set_args': {
            'channels': 32,
            'T': 3,
            "pooling_method": "sum",
            "init_qstar": "0"
        },
        'pooling_args': {
            'is_sorted': False,
            'has_unconnected': True,
            'pooling_method': "segment_mean"
        },
        'edge_dense': {
            'use_bias': True,
            'activation': 'selu'
        }
    }

    # Update model args
    input_embedd = update_model_args(model_default['input_embedd'], input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'], output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    set2set_args = update_model_args(model_default['set2set_args'], set2set_args)
    pooling_args = update_model_args(model_default['pooling_args'], pooling_args)
    edge_dense = update_model_args(model_default['edge_dense'], edge_dense)

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)

    # All layers below operate on this tensor representation and indexing.
    tens_type = "values_partition"
    node_indexing = "batch"
    n = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(n)
    ed = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(ed)
    edi = ChangeTensorType(input_tensor_type="ragged", output_tensor_type=tens_type)(edge_index_input)
    edi = ChangeIndexing(input_tensor_type=tens_type, to_indexing=node_indexing)([n, edi])

    # Tell the configurable layers which tensor representation they receive.
    set2set_args.update({"input_tensor_type": tens_type})
    output_mlp.update({"input_tensor_type": tens_type})
    edge_dense.update({"input_tensor_type": tens_type})
    pooling_args.update({"input_tensor_type": tens_type, "node_indexing": node_indexing})

    n = Dense(node_dim, activation="linear", input_tensor_type=tens_type)(n)
    edge_net = Dense(node_dim * node_dim, **edge_dense)(ed)
    # One GRU update layer instance is reused in every step.
    gru = GRUupdate(node_dim, input_tensor_type=tens_type, node_indexing=node_indexing)

    for _ in range(depth):
        eu = GatherNodesOutgoing(input_tensor_type=tens_type, node_indexing=node_indexing)([n, edi])
        eu = TrafoMatMulMessages(node_dim, input_tensor_type=tens_type,
                                 node_indexing=node_indexing)([edge_net, eu])
        # Pool the messages for each node with the configured pooling method
        eu = PoolingLocalEdges(**pooling_args)([n, eu, edi])
        n = gru([n, eu])

    if output_embedd["output_mode"] == 'graph':
        if use_set2set:
            # output
            outss = Dense(set2set_args['channels'], activation="linear", input_tensor_type=tens_type)(n)
            out = Set2Set(**set2set_args)(outss)
        else:
            out = PoolingNodes(**pooling_args)(n)

        # final dense layers; the MLP is configured for plain tensor input here
        output_mlp.update({"input_tensor_type": "tensor"})
        main_output = MLP(**output_mlp)(out)
    else:  # Node labeling
        out = n
        main_output = MLP(**output_mlp)(out)
        # no ragged for distribution supported atm
        main_output = ChangeTensorType(input_tensor_type=tens_type, output_tensor_type="tensor")(main_output)

    model = ks.models.Model(inputs=[node_input, edge_input, edge_index_input], outputs=main_output)
    return model
def __init__(self,
             units,
             activation=None,
             use_bias=True,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             node_indexing='sample',
             pooling_method='sum',
             is_sorted=False,
             has_unconnected=True,
             normalize_by_weights=False,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Store the configuration and create the GCN sub-layers.

    Args:
        units: Number of units of the internal dense layer.
        activation: Activation identifier for the separate activation
            layer. If None, a 'leaky_relu' config with alpha 0.2 is used
            when 'leaky_relu' is contained in `kgcnn_custom_act`,
            otherwise "relu". The dense layer is always 'linear'.
        use_bias: Forwarded to the dense layer.
        kernel_regularizer: Forwarded to the dense layer.
        bias_regularizer: Forwarded to the dense layer.
        activity_regularizer: Forwarded to the dense layer.
        kernel_constraint: Forwarded to the dense layer.
        bias_constraint: Forwarded to the dense layer.
        kernel_initializer: Forwarded to the dense layer.
        bias_initializer: Forwarded to the dense layer.
        node_indexing: Forwarded to the gather and pooling layers.
        pooling_method: Forwarded to the pooling layer.
        is_sorted: Forwarded to the gather and pooling layers.
        has_unconnected: Forwarded to the gather and pooling layers.
        normalize_by_weights: Forwarded to the pooling layer.
        partition_type: Forwarded to the gather and pooling layers.
        input_tensor_type: Forwarded to every sub-layer created here.
        ragged_validate: Forwarded to every sub-layer created here.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(GCN, self).__init__(**kwargs)

    # Graph related configuration
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.pooling_method = pooling_method
    self.normalize_by_weights = normalize_by_weights
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"
    ]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type,
        self._tensor_input_type_implemented,
        self.node_indexing,
    )

    # Resolve the default activation only when the caller gave none.
    if activation is None:
        if 'leaky_relu' in kgcnn_custom_act:
            activation = {"class_name": "leaky_relu", "config": {"alpha": 0.2}}
        else:
            activation = "relu"

    self.units = units
    self.use_bias = use_bias

    # Resolved Keras objects kept on the instance under `gcn_*` names.
    self.gcn_activation = tf.keras.activations.get(activation)
    self.gcn_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.gcn_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.gcn_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.gcn_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.gcn_activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self.gcn_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.gcn_bias_constraint = tf.keras.constraints.get(bias_constraint)

    # The dense layer receives the raw constructor arguments, not the
    # resolved `gcn_*` objects above.
    weight_options = dict(
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )
    # Keyword arguments shared by the gather and the pooling layer.
    graph_options = dict(
        node_indexing=self.node_indexing,
        partition_type=self.partition_type,
        is_sorted=self.is_sorted,
        has_unconnected=self.has_unconnected,
        ragged_validate=self.ragged_validate,
        input_tensor_type=self.input_tensor_type,
    )

    # Layers
    self.lay_gather = GatherNodesOutgoing(**graph_options)
    self.lay_dense = Dense(
        units=self.units,
        use_bias=self.use_bias,
        activation='linear',
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
        **weight_options,
    )
    self.lay_pool = PoolingWeightedLocalEdges(
        pooling_method=self.pooling_method,
        normalize_by_weights=self.normalize_by_weights,
        **graph_options,
    )
    self.lay_act = Activation(
        activation,
        ragged_validate=self.ragged_validate,
        input_tensor_type=self.input_tensor_type,
    )
def __init__(self,
             units,
             use_bias=True,
             activation=None,
             kernel_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             bias_constraint=None,
             kernel_initializer='glorot_uniform',
             bias_initializer='zeros',
             cfconv_pool='segment_sum',
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             node_indexing='sample',
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Store the configuration and create the convolution sub-layers.

    Args:
        units: Number of units of both dense layers.
        use_bias: Forwarded to both dense layers.
        activation: Activation of the first dense layer. If None,
            'shifted_softplus' is used when that name is contained in
            `kgcnn_custom_act`, otherwise "selu". The second dense layer is
            always created as 'linear'.
        kernel_regularizer: Forwarded to both dense layers.
        bias_regularizer: Forwarded to both dense layers.
        activity_regularizer: Forwarded to both dense layers.
        kernel_constraint: Forwarded to both dense layers.
        bias_constraint: Forwarded to both dense layers.
        kernel_initializer: Forwarded to both dense layers.
        bias_initializer: Forwarded to both dense layers.
        cfconv_pool: Pooling method name handed to the edge pooling layer.
        is_sorted: Forwarded to the pooling and gather layers.
        has_unconnected: Forwarded to the pooling and gather layers.
        partition_type: Forwarded to the pooling and gather layers.
        node_indexing: Forwarded to the pooling and gather layers.
        input_tensor_type: Forwarded to every sub-layer created here.
        ragged_validate: Forwarded to every sub-layer created here.
        **kwargs: Forwarded to the parent class constructor.
    """
    super(SchNetCFconv, self).__init__(**kwargs)

    # Graph related configuration
    self.cfconv_pool = cfconv_pool
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"
    ]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type,
        self._tensor_input_type_implemented,
        self.node_indexing,
    )

    self.units = units
    self.use_bias = use_bias

    # Resolve the default activation only when the caller gave none.
    if activation is None:
        if 'shifted_softplus' in kgcnn_custom_act:
            activation = 'shifted_softplus'
        else:
            activation = "selu"

    # Resolved Keras objects kept on the instance under `cfc_*` names.
    self.cfc_activation = tf.keras.activations.get(activation)
    self.cfc_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.cfc_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.cfc_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.cfc_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.cfc_activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self.cfc_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.cfc_bias_constraint = tf.keras.constraints.get(bias_constraint)

    # Everything the two dense layers share; only the activation differs.
    # The raw constructor arguments are passed on, not the `cfc_*` objects.
    dense_options = dict(
        units=self.units,
        use_bias=self.use_bias,
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer,
        bias_regularizer=bias_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
    )
    # Keyword arguments shared by the pooling and the gather layer.
    graph_options = dict(
        is_sorted=self.is_sorted,
        has_unconnected=self.has_unconnected,
        partition_type=self.partition_type,
        node_indexing=self.node_indexing,
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
    )

    # Layer
    self.lay_dense1 = Dense(activation=activation, **dense_options)
    self.lay_dense2 = Dense(activation='linear', **dense_options)
    self.lay_sum = PoolingLocalEdges(pooling_method=self.cfconv_pool,
                                     **graph_options)
    self.gather_n = GatherNodesOutgoing(**graph_options)
    self.lay_mult = Multiply(
        input_tensor_type=self.input_tensor_type,
        ragged_validate=self.ragged_validate,
    )
def make_graph_sage(
        # Input
        input_node_shape,
        input_edge_shape,
        input_embedd: dict = None,
        # Output
        output_embedd: dict = None,
        output_mlp: dict = None,
        # Model specific parameter
        depth=3,
        use_edge_features=False,
        node_mlp_args: dict = None,
        edge_mlp_args: dict = None,
        pooling_args: dict = None):
    """Generate GraphSAGE network.

    Each of the `depth` steps gathers node features along the edges,
    optionally concatenates the edge features, applies the edge MLP, pools
    the messages per node, concatenates the result with the node features
    and applies the node MLP followed by layer normalization.

    Args:
        input_node_shape (list): Shape of node features. If shape is (None,)
            embedding layer is used.
        input_edge_shape (list): Shape of edge features. If shape is (None,)
            embedding layer is used.
        input_embedd (dict): Dictionary of embedding parameters used if
            input shape is None. Default is
            {'input_node_vocab': 95, 'input_edge_vocab': 5,
            'input_state_vocab': 100, 'input_node_embedd': 64,
            'input_edge_embedd': 64, 'input_state_embedd': 64,
            'input_tensor_type': 'ragged'}.
        output_embedd (dict): Dictionary of embedding parameters of the
            graph network. Only "output_mode" is read by this function.
            Default is
            {"output_mode": 'graph', "output_tensor_type": 'padded'}.
        output_mlp (dict): Dictionary of arguments for final MLP regression
            or classification layer. Default is
            {"use_bias": [True, True, False], "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']}.
        depth (int): Number of convolution layers. Default is 3.
        use_edge_features (bool): Whether to concatenate edges with nodes in
            aggregate. Default is False.
        node_mlp_args (dict): Dictionary of arguments for MLP for node
            update. Default is
            {"units": [100, 50], "use_bias": True,
            "activation": ['relu', "linear"]}
        edge_mlp_args (dict): Dictionary of arguments for MLP for
            interaction update. Default is
            {"units": [100, 50], "use_bias": True,
            "activation": ['relu', "linear"]}
        pooling_args (dict): Dictionary for message pooling arguments. If
            'pooling_method' is "LSTM" or "lstm" the LSTM edge pooling layer
            is used. Default is
            {'is_sorted': False, 'has_unconnected': True,
            'pooling_method': "segment_mean"}

    Returns:
        model (tf.keras.models.Model): GraphSAGE model.
    """
    # default values
    model_default = {
        'input_embedd': {
            'input_node_vocab': 95,
            'input_edge_vocab': 5,
            'input_state_vocab': 100,
            'input_node_embedd': 64,
            'input_edge_embedd': 64,
            'input_state_embedd': 64,
            'input_tensor_type': 'ragged'
        },
        'output_embedd': {
            "output_mode": 'graph',
            "output_tensor_type": 'padded'
        },
        'output_mlp': {
            "use_bias": [True, True, False],
            "units": [25, 10, 1],
            "activation": ['relu', 'relu', 'sigmoid']
        },
        'node_mlp_args': {
            "units": [100, 50],
            "use_bias": True,
            "activation": ['relu', "linear"]
        },
        'edge_mlp_args': {
            "units": [100, 50],
            "use_bias": True,
            "activation": ['relu', "linear"]
        },
        'pooling_args': {
            'is_sorted': False,
            'has_unconnected': True,
            'pooling_method': "segment_mean"
        }
    }

    # Update default values
    input_embedd = update_model_args(model_default['input_embedd'], input_embedd)
    output_embedd = update_model_args(model_default['output_embedd'], output_embedd)
    output_mlp = update_model_args(model_default['output_mlp'], output_mlp)
    node_mlp_args = update_model_args(model_default['node_mlp_args'], node_mlp_args)
    edge_mlp_args = update_model_args(model_default['edge_mlp_args'], edge_mlp_args)
    pooling_args = update_model_args(model_default['pooling_args'], pooling_args)

    # Fixed (not user-configurable) arguments of the helper layers
    pooling_nodes_args = {
        "input_tensor_type": 'ragged',
        "node_indexing": 'sample',
        'pooling_method': "mean"
    }
    gather_args = {"node_indexing": 'sample'}
    concat_args = {"axis": -1, "input_tensor_type": 'ragged'}

    # Make input embedding, if no feature dimension
    node_input, n, edge_input, ed, edge_index_input, _, _ = generate_standard_graph_input(
        input_node_shape, input_edge_shape, None, **input_embedd)
    edi = edge_index_input

    for _ in range(depth):
        eu = GatherNodesOutgoing(**gather_args)([n, edi])
        if use_edge_features:
            eu = Concatenate(**concat_args)([eu, ed])

        eu = MLP(**edge_mlp_args)(eu)

        # Pool message
        if pooling_args['pooling_method'] in ["LSTM", "lstm"]:
            nu = PoolingLocalEdgesLSTM(**pooling_args)([n, eu, edi])
        else:
            # Pool the messages for each node with the configured method
            nu = PoolingLocalMessages(**pooling_args)([n, eu, edi])

        # Concatenate node features with new edge updates
        nu = Concatenate(**concat_args)([n, nu])

        n = MLP(**node_mlp_args)(nu)
        n = LayerNormalization(axis=-1)(n)  # Normalize

    # Regression layer on output
    if output_embedd["output_mode"] == 'graph':
        out = PoolingNodes(**pooling_nodes_args)(n)
        # The MLP is configured for plain tensor input after node pooling
        output_mlp.update({"input_tensor_type": "tensor"})
        out = MLP(**output_mlp)(out)
        main_output = ks.layers.Flatten()(out)  # will be tensor
    else:  # node embedding
        out = MLP(**output_mlp)(n)
        main_output = ChangeTensorType(input_tensor_type='ragged', output_tensor_type="tensor")(out)

    model = tf.keras.models.Model(
        inputs=[node_input, edge_input, edge_index_input],
        outputs=main_output)
    return model