def __init__(self, pooling_method="sum", node_indexing="sample", partition_type="row_length",
             input_tensor_type="ragged", ragged_validate=False, is_sorted=False,
             has_unconnected=True, **kwargs):
    """Initialize layer.

    Args:
        pooling_method (str): Pooling method identifier. Default is "sum".
        node_indexing (str): Node index scheme. Default is "sample".
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
    """
    super(PoolingAdjacencyMatmul, self).__init__(**kwargs)
    # Graph-related configuration kept on the instance.
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.pooling_method = pooling_method
    self.ragged_validate = ragged_validate
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    # Static check that the requested representation is one of the implemented ones.
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
def __init__(self, concat_nodes=True, node_indexing='sample', partition_type="row_length",
             input_tensor_type="ragged", ragged_validate=False, is_sorted=False,
             has_unconnected=True, **kwargs):
    """Initialize layer.

    Args:
        concat_nodes (bool): Whether to concatenate gathered nodes. Default is True.
        node_indexing (str): Node index scheme. Default is 'sample'.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
    """
    super(GatherNodes, self).__init__(**kwargs)
    # Gather-specific flag plus the shared graph configuration.
    self.concat_nodes = concat_nodes
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    # Static check that the requested representation is one of the implemented ones.
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
def __init__(self, k=0.1, kernel_initializer='glorot_uniform', kernel_regularizer=None,
             kernel_constraint=None, partition_type="row_length", node_indexing="sample",
             input_tensor_type="ragged", ragged_validate=False, **kwargs):
    """Initialize Layer.

    Args:
        k (float): Top-k fraction/ratio parameter. Default is 0.1.
        kernel_initializer: Keras initializer identifier. Default is 'glorot_uniform'.
        kernel_regularizer: Keras regularizer identifier. Default is None.
        kernel_constraint: Keras constraint identifier. Default is None.
        partition_type (str): Partition tensor type. Default is "row_length".
        node_indexing (str): Node index scheme. Default is "sample".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(PoolingTopK, self).__init__(**kwargs)
    self.k = k
    # Deserialize keras identifiers into concrete objects.
    self.kernel_initializer = ks.initializers.get(kernel_initializer)
    self.kernel_regularizer = ks.regularizers.get(kernel_regularizer)
    self.kernel_constraint = ks.constraints.get(kernel_constraint)
    # Shared graph configuration.
    self.partition_type = partition_type
    self.node_indexing = node_indexing
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # Weight placeholders; populated later (presumably in build() — not visible here).
    self.units_p = None
    self.kernel_p = None
def __init__(self, units,
             pooling_method="LSTM",
             node_indexing="sample",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             activation='tanh', recurrent_activation='sigmoid',
             use_bias=True, kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros', unit_forget_bias=True,
             kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None,
             activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None,
             bias_constraint=None, dropout=0.0, recurrent_dropout=0.0,
             return_sequences=False, return_state=False, go_backwards=False,
             stateful=False, time_major=False, unroll=False,
             **kwargs):
    """Initialize layer.

    Args:
        units (int): Hidden units of the internal LSTM.
        pooling_method (str): Expected to be "LSTM"/"lstm"; a warning is printed otherwise.
        node_indexing (str): Node index scheme. Default is "sample".
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
        activation, recurrent_activation, use_bias, kernel_initializer,
        recurrent_initializer, bias_initializer, unit_forget_bias,
        kernel_regularizer, recurrent_regularizer, bias_regularizer,
        activity_regularizer, kernel_constraint, recurrent_constraint,
        bias_constraint, dropout, recurrent_dropout, return_sequences,
        return_state, go_backwards, stateful, time_major, unroll:
            Forwarded unchanged to ``ks.layers.LSTM``.
    """
    super(PoolingLocalEdgesLSTM, self).__init__(**kwargs)
    # Shared graph configuration.
    self.pooling_method = pooling_method
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._supports_ragged_inputs = True
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = ["ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(self.input_tensor_type,
                                                                      self._tensor_input_type_implemented,
                                                                      self.node_indexing)
    # Internal LSTM used for pooling; all recurrent kwargs are forwarded as-is.
    self.lstm_unit = ks.layers.LSTM(units=units, activation=activation, recurrent_activation=recurrent_activation,
                                    use_bias=use_bias, kernel_initializer=kernel_initializer,
                                    recurrent_initializer=recurrent_initializer, bias_initializer=bias_initializer,
                                    unit_forget_bias=unit_forget_bias, kernel_regularizer=kernel_regularizer,
                                    recurrent_regularizer=recurrent_regularizer, bias_regularizer=bias_regularizer,
                                    activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint,
                                    recurrent_constraint=recurrent_constraint, bias_constraint=bias_constraint,
                                    dropout=dropout, recurrent_dropout=recurrent_dropout,
                                    return_sequences=return_sequences, return_state=return_state,
                                    go_backwards=go_backwards, stateful=stateful,
                                    time_major=time_major, unroll=unroll)
    # Best-effort sanity notice only; does not raise.
    if self.pooling_method not in ["LSTM", "lstm"]:
        print("Warning: Pooling method does not match with layer, expected 'LSTM' but got", self.pooling_method)
def __init__(self, node_indexing="sample", partition_type="row_length", input_tensor_type="ragged",
             output_tensor_type=None, ragged_validate=False, is_sorted=False,
             has_unconnected=True, is_directed=True, **kwargs):
    """Initialize layer.

    Args:
        node_indexing (str): Node index scheme. Default is "sample".
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        output_tensor_type (str): Output representation; falls back to
            ``input_tensor_type`` when None. Default is None.
        ragged_validate (bool): Validate ragged tensors. Default is False.
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        is_directed (bool): Whether edges are directed. Default is True.
    """
    super(GraphBaseLayer, self).__init__(**kwargs)
    self.is_directed = is_directed
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    # Output representation defaults to the input representation.
    self.output_tensor_type = input_tensor_type if output_tensor_type is None else output_tensor_type
    self.ragged_validate = ragged_validate
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self._supports_ragged_inputs = True
    # Representations the base layer knows about; checked right below.
    self._tensor_input_type_implemented = ["ragged", "values_partition", "disjoint",
                                           "tensor", "RaggedTensor", "Tensor"]
    self._tensor_input_type_found = []
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # Snapshot of all graph-related keyword arguments in one dict.
    self._all_kgcnn_info = {
        "node_indexing": self.node_indexing,
        "partition_type": self.partition_type,
        "input_tensor_type": self.input_tensor_type,
        "ragged_validate": self.ragged_validate,
        "is_sorted": self.is_sorted,
        "has_unconnected": self.has_unconnected,
        "output_tensor_type": self.output_tensor_type,
        "is_directed": self.is_directed,
    }
def __init__(self, units,
             activation='tanh', recurrent_activation='sigmoid',
             use_bias=True, kernel_initializer='glorot_uniform',
             recurrent_initializer='orthogonal',
             bias_initializer='zeros', kernel_regularizer=None,
             recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None,
             recurrent_constraint=None, bias_constraint=None, dropout=0.0,
             recurrent_dropout=0.0, reset_after=True,
             node_indexing="sample",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Initialize layer.

    Args:
        units (int): Hidden units of the internal GRU cell.
        activation, recurrent_activation, use_bias, kernel_initializer,
        recurrent_initializer, bias_initializer, kernel_regularizer,
        recurrent_regularizer, bias_regularizer, kernel_constraint,
        recurrent_constraint, bias_constraint, dropout, recurrent_dropout,
        reset_after:
            Forwarded unchanged to ``tf.keras.layers.GRUCell``.
        node_indexing (str): Node index scheme. Default is "sample".
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(GRUupdate, self).__init__(**kwargs)
    self.units = units
    # Shared graph configuration.
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._supports_ragged_inputs = True
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = ["ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(self.input_tensor_type,
                                                                      self._tensor_input_type_implemented,
                                                                      self.node_indexing)
    # Internal GRU cell; all recurrent kwargs are forwarded as-is.
    self.gru_cell = tf.keras.layers.GRUCell(units=units,
                                            activation=activation,
                                            recurrent_activation=recurrent_activation,
                                            use_bias=use_bias,
                                            kernel_initializer=kernel_initializer,
                                            recurrent_initializer=recurrent_initializer,
                                            bias_initializer=bias_initializer,
                                            kernel_regularizer=kernel_regularizer,
                                            recurrent_regularizer=recurrent_regularizer,
                                            bias_regularizer=bias_regularizer,
                                            kernel_constraint=kernel_constraint,
                                            recurrent_constraint=recurrent_constraint,
                                            bias_constraint=bias_constraint,
                                            dropout=dropout,
                                            recurrent_dropout=recurrent_dropout,
                                            reset_after=reset_after)
def __init__(self, node_indexing="sample", partition_type="row_length",
             input_tensor_type="ragged", ragged_validate=False, **kwargs):
    """Initialize Layer.

    Args:
        node_indexing (str): Node index scheme. Default is "sample".
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(UnPoolingTopK, self).__init__(**kwargs)
    # Shared graph configuration.
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
def __init__(self, units,
             activation=None,
             use_bias=True,
             kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
             kernel_constraint=None, bias_constraint=None,
             kernel_initializer='glorot_uniform', bias_initializer='zeros',
             use_edge_features=False,
             node_indexing="sample",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Initialize layer.

    Args:
        units (int): Output units of the linear transformation; cast with int().
        activation: Activation for the attention/output layers. If None, uses
            the custom "leaky_relu" (alpha=0.2) when available, else "relu".
        use_bias (bool): Use bias in dense sub-layers. Default is True.
        kernel_regularizer, bias_regularizer, activity_regularizer,
        kernel_constraint, bias_constraint, kernel_initializer, bias_initializer:
            Forwarded to the internal Dense sub-layers.
        use_edge_features (bool): Whether edge features are used. Default is False.
        node_indexing (str): Node index scheme. Default is "sample".
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(AttentionHeadGAT, self).__init__(**kwargs)
    # graph args
    self.is_sorted = is_sorted
    self.use_edge_features = use_edge_features
    self.has_unconnected = has_unconnected
    self.node_indexing = node_indexing
    self.partition_type = partition_type
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._supports_ragged_inputs = True
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # dense args
    self.units = int(units)
    # Resolve default activation: prefer the custom leaky_relu when registered.
    if activation is None and "leaky_relu" in kgcnn_custom_act:
        activation = {"class_name": "leaky_relu", "config": {"alpha": 0.2}}
    elif activation is None:
        activation = "relu"
    self.use_bias = use_bias
    # Deserialize keras identifiers; kept for get_config()-style serialization.
    self.ath_activation = tf.keras.activations.get(activation)
    self.ath_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.ath_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.ath_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.ath_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.ath_activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self.ath_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.ath_bias_constraint = tf.keras.constraints.get(bias_constraint)
    # Keyword bundles shared by the sub-layers below.
    kernel_args = {"use_bias": use_bias, "kernel_regularizer": kernel_regularizer,
                   "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
                   "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
                   "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}
    dens_args = {"ragged_validate": self.ragged_validate, "input_tensor_type": self.input_tensor_type}
    dens_args.update(kernel_args)
    gather_args = {"input_tensor_type": self.input_tensor_type, "node_indexing": self.node_indexing}
    pooling_args = {"node_indexing": node_indexing, "partition_type": partition_type,
                    "has_unconnected": has_unconnected, "is_sorted": is_sorted,
                    "ragged_validate": self.ragged_validate, "input_tensor_type": self.input_tensor_type}
    # Sub-layers of the attention head.
    self.lay_linear_trafo = Dense(units, activation="linear", **dens_args)
    self.lay_alpha = Dense(1, activation=activation, **dens_args)
    self.lay_gather_in = GatherNodesIngoing(**gather_args)
    self.lay_gather_out = GatherNodesOutgoing(**gather_args)
    self.lay_concat = Concatenate(axis=-1, input_tensor_type=self.input_tensor_type)
    self.lay_pool_attention = PoolingLocalEdgesAttention(**pooling_args)
    self.lay_final_activ = Activation(activation=activation, input_tensor_type=self.input_tensor_type)
def __init__(self, node_embed=None,
             edge_embed=None,
             env_embed=None,
             use_bias=True,
             activation=None,
             kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
             kernel_constraint=None, bias_constraint=None,
             kernel_initializer='glorot_uniform', bias_initializer='zeros',
             pooling_method="mean",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             node_indexing='sample',
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Initialize layer.

    Args:
        node_embed (list): Units of the three node MLP layers. Default is [16, 16, 16].
        edge_embed (list): Units of the three edge MLP layers. Default is [16, 16, 16].
        env_embed (list): Units of the three environment MLP layers. Default is [16, 16, 16].
        use_bias (bool): Use bias in all dense sub-layers. Default is True.
        activation: Activation for the MLPs. If None, uses the custom
            'softplus2' when available, else "selu".
        kernel_regularizer, bias_regularizer, activity_regularizer,
        kernel_constraint, bias_constraint, kernel_initializer, bias_initializer:
            Forwarded to the dense sub-layers.
        pooling_method (str): Pooling method for edge/node pooling. Default is "mean".
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        node_indexing (str): Node index scheme. Default is 'sample'.
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(MEGnetBlock, self).__init__(**kwargs)
    # Shared graph configuration.
    self.pooling_method = pooling_method
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.partition_type = partition_type
    self.node_indexing = node_indexing
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # Avoid mutable default arguments: fill in the per-branch embedding sizes here.
    if node_embed is None:
        node_embed = [16, 16, 16]
    if env_embed is None:
        env_embed = [16, 16, 16]
    if edge_embed is None:
        edge_embed = [16, 16, 16]
    self.node_embed = node_embed
    self.edge_embed = edge_embed
    self.env_embed = env_embed
    self.use_bias = use_bias
    # Resolve default activation: prefer the custom softplus2 when registered.
    if activation is None and 'softplus2' in kgcnn_custom_act:
        activation = 'softplus2'
    elif activation is None:
        activation = "selu"
    # Deserialize keras identifiers; kept for get_config()-style serialization.
    self.megnet_activation = tf.keras.activations.get(activation)
    self.megnet_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.megnet_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.megnet_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.megnet_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.megnet_activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self.megnet_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.megnet_bias_constraint = tf.keras.constraints.get(bias_constraint)
    # Keyword bundles shared by the sub-layers below.
    kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                   "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                   "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                   "bias_initializer": bias_initializer}
    mlp_args = {"input_tensor_type": self.input_tensor_type, "ragged_validate": self.ragged_validate}
    mlp_args.update(kernel_args)
    pool_args = {"pooling_method": self.pooling_method, "is_sorted": self.is_sorted,
                 "has_unconnected": self.has_unconnected, "input_tensor_type": self.input_tensor_type,
                 "ragged_validate": self.ragged_validate, "partition_type": self.partition_type,
                 "node_indexing": self.node_indexing}
    gather_args = {"is_sorted": self.is_sorted, "has_unconnected": self.has_unconnected,
                   "input_tensor_type": self.input_tensor_type, "ragged_validate": self.ragged_validate,
                   "partition_type": self.partition_type, "node_indexing": self.node_indexing}
    # Node branch: three-layer MLP plus edge pooling and state gathering.
    self.lay_phi_n = Dense(units=self.node_embed[0], activation=activation, use_bias=self.use_bias, **mlp_args)
    self.lay_phi_n_1 = Dense(units=self.node_embed[1], activation=activation, use_bias=self.use_bias, **mlp_args)
    self.lay_phi_n_2 = Dense(units=self.node_embed[2], activation='linear', use_bias=self.use_bias, **mlp_args)
    self.lay_esum = PoolingLocalEdges(**pool_args)
    self.lay_gather_un = GatherState(**gather_args)
    self.lay_conc_nu = Concatenate(axis=-1, input_tensor_type=self.input_tensor_type)
    # Edge branch: three-layer MLP plus node/state gathering.
    self.lay_phi_e = Dense(units=self.edge_embed[0], activation=activation, use_bias=self.use_bias, **mlp_args)
    self.lay_phi_e_1 = Dense(units=self.edge_embed[1], activation=activation, use_bias=self.use_bias, **mlp_args)
    self.lay_phi_e_2 = Dense(units=self.edge_embed[2], activation='linear', use_bias=self.use_bias, **mlp_args)
    self.lay_gather_n = GatherNodes(**gather_args)
    self.lay_gather_ue = GatherState(**gather_args)
    self.lay_conc_enu = Concatenate(axis=-1, input_tensor_type=self.input_tensor_type)
    # Environment branch: global pooling plus plain keras Dense layers
    # (note: these operate on "tensor" type, hence ks.layers.Dense).
    self.lay_usum_e = PoolingGlobalEdges(**pool_args)
    self.lay_usum_n = PoolingNodes(**pool_args)
    self.lay_conc_u = Concatenate(axis=-1, input_tensor_type="tensor")
    self.lay_phi_u = ks.layers.Dense(units=self.env_embed[0], activation=activation, use_bias=self.use_bias,
                                     **kernel_args)
    self.lay_phi_u_1 = ks.layers.Dense(units=self.env_embed[1], activation=activation, use_bias=self.use_bias,
                                       **kernel_args)
    self.lay_phi_u_2 = ks.layers.Dense(units=self.env_embed[2], activation='linear', use_bias=self.use_bias,
                                       **kernel_args)
def __init__(self, units,
             activation=None,
             use_bias=True,
             kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
             kernel_constraint=None, bias_constraint=None,
             kernel_initializer='glorot_uniform', bias_initializer='zeros',
             node_indexing='sample',
             pooling_method='sum',
             is_sorted=False,
             has_unconnected=True,
             normalize_by_weights=False,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Initialize layer.

    Args:
        units (int): Output units of the internal dense transformation.
        activation: Output activation. If None, uses the custom
            "leaky_relu" (alpha=0.2) when available, else "relu".
        use_bias (bool): Use bias in the dense sub-layer. Default is True.
        kernel_regularizer, bias_regularizer, activity_regularizer,
        kernel_constraint, bias_constraint, kernel_initializer, bias_initializer:
            Forwarded to the internal Dense sub-layer.
        node_indexing (str): Node index scheme. Default is 'sample'.
        pooling_method (str): Pooling method for edge aggregation. Default is 'sum'.
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        normalize_by_weights (bool): Normalize pooled values by edge weights.
            Default is False.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(GCN, self).__init__(**kwargs)
    # Shared graph configuration.
    self.node_indexing = node_indexing
    self.normalize_by_weights = normalize_by_weights
    self.partition_type = partition_type
    self.pooling_method = pooling_method
    self.has_unconnected = has_unconnected
    self.is_sorted = is_sorted
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # Resolve default activation: prefer the custom leaky_relu when registered.
    if activation is None and 'leaky_relu' in kgcnn_custom_act:
        activation = {"class_name": "leaky_relu", "config": {"alpha": 0.2}}
    elif activation is None:
        activation = "relu"
    self.units = units
    self.use_bias = use_bias
    # Deserialize keras identifiers; kept for get_config()-style serialization.
    self.gcn_activation = tf.keras.activations.get(activation)
    self.gcn_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.gcn_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.gcn_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.gcn_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.gcn_activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self.gcn_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.gcn_bias_constraint = tf.keras.constraints.get(bias_constraint)
    # Keyword bundle shared by the dense sub-layer below.
    kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                   "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                   "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                   "bias_initializer": bias_initializer}
    # Layers: gather -> dense -> weighted pooling -> activation.
    self.lay_gather = GatherNodesOutgoing(node_indexing=self.node_indexing,
                                          partition_type=self.partition_type,
                                          is_sorted=self.is_sorted,
                                          has_unconnected=self.has_unconnected,
                                          ragged_validate=self.ragged_validate,
                                          input_tensor_type=self.input_tensor_type)
    self.lay_dense = Dense(units=self.units, use_bias=self.use_bias, activation='linear',
                           input_tensor_type=self.input_tensor_type,
                           ragged_validate=self.ragged_validate, **kernel_args)
    self.lay_pool = PoolingWeightedLocalEdges(pooling_method=self.pooling_method,
                                              is_sorted=self.is_sorted,
                                              has_unconnected=self.has_unconnected,
                                              node_indexing=self.node_indexing,
                                              normalize_by_weights=self.normalize_by_weights,
                                              partition_type=self.partition_type,
                                              ragged_validate=self.ragged_validate,
                                              input_tensor_type=self.input_tensor_type)
    self.lay_act = Activation(activation, ragged_validate=self.ragged_validate,
                              input_tensor_type=self.input_tensor_type)
def __init__(self, units=128,
             use_bias=True,
             activation=None,
             kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
             kernel_constraint=None, bias_constraint=None,
             kernel_initializer='glorot_uniform', bias_initializer='zeros',
             cfconv_pool='sum',
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             node_indexing='sample',
             input_tensor_type="ragged",
             ragged_validate=False,
             **kwargs):
    """Initialize Layer.

    Args:
        units (int): Units of the CFconv and dense sub-layers. Default is 128.
        use_bias (bool): Use bias (dense1 is always bias-free). Default is True.
        activation: Activation for the sub-layers. If None, uses the custom
            'shifted_softplus' when available, else "selu".
        kernel_regularizer, bias_regularizer, activity_regularizer,
        kernel_constraint, bias_constraint, kernel_initializer, bias_initializer:
            Forwarded to the CFconv and dense sub-layers.
        cfconv_pool (str): Pooling method of the CFconv layer. Default is 'sum'.
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        node_indexing (str): Node index scheme. Default is 'sample'.
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
    """
    super(SchNetInteraction, self).__init__(**kwargs)
    # Shared graph configuration.
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.partition_type = partition_type
    self.node_indexing = node_indexing
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._supports_ragged_inputs = True
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    self.cfconv_pool = cfconv_pool
    self.use_bias = use_bias
    self.units = units
    # Resolve default activation: prefer the custom shifted_softplus when registered.
    if activation is None and 'shifted_softplus' in kgcnn_custom_act:
        activation = 'shifted_softplus'
    elif activation is None:
        activation = "selu"
    # Deserialize keras identifiers; kept for get_config()-style serialization.
    self.schnet_activation = tf.keras.activations.get(activation)
    self.schnet_kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self.schnet_bias_initializer = tf.keras.initializers.get(bias_initializer)
    self.schnet_kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self.schnet_bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self.schnet_activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self.schnet_kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self.schnet_bias_constraint = tf.keras.constraints.get(bias_constraint)
    # Keyword bundle shared by the sub-layers below.
    kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                   "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                   "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                   "bias_initializer": bias_initializer}
    # Layers: continuous-filter convolution plus residual-style dense stack.
    self.lay_cfconv = SchNetCFconv(units=self.units, activation=activation, use_bias=self.use_bias,
                                   cfconv_pool=self.cfconv_pool, has_unconnected=self.has_unconnected,
                                   is_sorted=self.is_sorted, partition_type=self.partition_type,
                                   input_tensor_type=self.input_tensor_type,
                                   ragged_validate=self.ragged_validate,
                                   node_indexing=self.node_indexing, **kernel_args)
    # NOTE: dense1 deliberately has no bias.
    self.lay_dense1 = Dense(units=self.units, activation='linear', use_bias=False,
                            input_tensor_type=self.input_tensor_type,
                            ragged_validate=self.ragged_validate, **kernel_args)
    self.lay_dense2 = Dense(units=self.units, activation=activation, use_bias=self.use_bias,
                            input_tensor_type=self.input_tensor_type,
                            ragged_validate=self.ragged_validate, **kernel_args)
    self.lay_dense3 = Dense(units=self.units, activation='linear', use_bias=self.use_bias,
                            input_tensor_type=self.input_tensor_type,
                            ragged_validate=self.ragged_validate, **kernel_args)
    self.lay_add = Add(input_tensor_type=self.input_tensor_type, ragged_validate=self.ragged_validate)
def __init__(self,
             # Arg
             channels,
             T=3,
             pooling_method='mean',
             init_qstar='mean',
             node_indexing="sample",
             is_sorted=False,
             has_unconnected=True,
             partition_type="row_length",
             input_tensor_type="ragged",
             ragged_validate=False,
             # Args for LSTM
             activation="tanh",
             recurrent_activation="sigmoid",
             use_bias=True,
             kernel_initializer="glorot_uniform",
             recurrent_initializer="orthogonal",
             bias_initializer="zeros",
             unit_forget_bias=True,
             kernel_regularizer=None,
             recurrent_regularizer=None,
             bias_regularizer=None,
             activity_regularizer=None,
             kernel_constraint=None,
             recurrent_constraint=None,
             bias_constraint=None,
             dropout=0.0,
             recurrent_dropout=0.0,
             implementation=2,
             return_sequences=False,  # Should not be changed here
             return_state=False,  # Should not be changed here
             go_backwards=False,  # Should not be changed here
             stateful=False,
             time_major=False,
             unroll=False,
             **kwargs):
    """Init layer.

    Args:
        channels (int): Channels of the internal LSTM.
        T (int): Number of set2set iterations to work on memory. Default is 3.
        pooling_method (str): Either 'mean' or 'sum'; anything else raises
            TypeError. Default is 'mean'.
        init_qstar (str): Initialization scheme for q-star; 'mean' selects
            init_qstar_mean, anything else selects init_qstar_0. Default is 'mean'.
        node_indexing (str): Node index scheme. Default is "sample".
        is_sorted (bool): Whether edge indices are sorted. Default is False.
        has_unconnected (bool): Whether unconnected nodes can occur. Default is True.
        partition_type (str): Partition tensor type. Default is "row_length".
        input_tensor_type (str): Input tensor representation. Default is "ragged".
        ragged_validate (bool): Validate ragged tensors. Default is False.
        activation, recurrent_activation, use_bias, kernel_initializer,
        recurrent_initializer, bias_initializer, unit_forget_bias,
        kernel_regularizer, recurrent_regularizer, bias_regularizer,
        activity_regularizer, kernel_constraint, recurrent_constraint,
        bias_constraint, dropout, recurrent_dropout, implementation,
        return_sequences, return_state, go_backwards, stateful, time_major,
        unroll:
            Forwarded unchanged to ``ks.layers.LSTM``.

    Raises:
        TypeError: If ``pooling_method`` is not 'mean' or 'sum'.
    """
    super(Set2Set, self).__init__(**kwargs)
    # Number of Channels to use in LSTM
    self.channels = channels
    self.T = T  # Number of Iterations to work on memory
    self.pooling_method = pooling_method
    self.init_qstar = init_qstar
    # Shared graph configuration.
    self.partition_type = partition_type
    self.is_sorted = is_sorted
    self.has_unconnected = has_unconnected
    self.node_indexing = node_indexing
    self.input_tensor_type = input_tensor_type
    self.ragged_validate = ragged_validate
    self._supports_ragged_inputs = True
    # Tensor representations this layer implements; checked right below.
    self._tensor_input_type_implemented = [
        "ragged", "values_partition", "disjoint", "tensor", "RaggedTensor"]
    self._test_tensor_input = kgcnn_ops_static_test_tensor_input_type(
        self.input_tensor_type, self._tensor_input_type_implemented, self.node_indexing)
    # Select the backend pooling function up front.
    if self.pooling_method == 'mean':
        self._pool = ksb.mean
    elif self.pooling_method == 'sum':
        self._pool = ksb.sum
    else:
        raise TypeError("Unknown pooling, choose: 'mean', 'sum', ...")
    # Select the q-star initialization method.
    if self.init_qstar == 'mean':
        self.qstar0 = self.init_qstar_mean
    else:
        self.qstar0 = self.init_qstar_0
    # ...
    # LSTM Layer to run on m
    self.lay_lstm = ks.layers.LSTM(channels,
                                   activation=activation,
                                   recurrent_activation=recurrent_activation,
                                   use_bias=use_bias,
                                   kernel_initializer=kernel_initializer,
                                   recurrent_initializer=recurrent_initializer,
                                   bias_initializer=bias_initializer,
                                   unit_forget_bias=unit_forget_bias,
                                   kernel_regularizer=kernel_regularizer,
                                   recurrent_regularizer=recurrent_regularizer,
                                   bias_regularizer=bias_regularizer,
                                   activity_regularizer=activity_regularizer,
                                   kernel_constraint=kernel_constraint,
                                   recurrent_constraint=recurrent_constraint,
                                   bias_constraint=bias_constraint,
                                   dropout=dropout,
                                   recurrent_dropout=recurrent_dropout,
                                   implementation=implementation,
                                   return_sequences=return_sequences,
                                   return_state=return_state,
                                   go_backwards=go_backwards,
                                   stateful=stateful,
                                   time_major=time_major,
                                   unroll=unroll)