def dropout(x):
    """Return ``x`` with dropout applied when ``dropout_rate`` is truthy.

    ``dropout_rate`` is not a parameter; it is read from the enclosing scope.
    When it is falsy (e.g. ``0`` or ``None``), ``x`` is returned unchanged.
    Otherwise a new dropout layer is created and called on ``x``:
    ``SparseDropout`` when ``is_sparse_tensor(x)`` is true, and
    ``tf.keras.layers.Dropout`` otherwise.

    Args:
        x: tensor to regularize; may be sparse or dense.

    Returns:
        The dropped-out tensor, or ``x`` itself when dropout is disabled.
    """
    if not dropout_rate:
        return x
    # A fresh layer is built on every call. NOTE(review): presumably fine
    # because this runs while wiring a functional model; confirm with callers.
    layer_cls = SparseDropout if is_sparse_tensor(x) else tf.keras.layers.Dropout
    return layer_cls(dropout_rate)(x)
def propagate(x, adj, filters, activation=None, kernel_regularizer=None):
    """Apply optional dropout, then one graph-convolution step, to ``x``.

    Dropout is delegated to the sibling ``dropout`` helper, which is a no-op
    when the enclosing ``dropout_rate`` is falsy.

    If the enclosing ``linear_skip_connections`` flag is set:
      * a bias-free ``Dense(filters)`` projection of the (dropped-out) input is
        added to the graph-convolution output;
      * ``activation``, if given, is applied after that sum.
    Otherwise ``activation`` is passed straight to ``graph_conv_factory``.

    Args:
        x: node features; may be sparse (see ``dropout``).
        adj: adjacency input passed together with ``x`` to the graph conv.
        filters: output width of the graph conv and of the skip projection.
        activation: optional callable applied to the result; ``None`` = linear.
        kernel_regularizer: regularizer forwarded to every kernel built here.

    Returns:
        The propagated node features.
    """
    x = dropout(x)
    kwargs = dict(kernel_regularizer=kernel_regularizer, use_bias=False)
    if linear_skip_connections:
        # The Dense layer must be *called* on the input, and before x is
        # rebound to the convolution output. Adding the bare layer object to
        # a tensor would fail and would never project anything.
        skip = tf.keras.layers.Dense(filters, **kwargs)(x)
        x = graph_conv_factory(filters, **kwargs)([x, adj])
        x = x + skip
        if activation is not None:
            x = activation(x)
    else:
        x = graph_conv_factory(filters, activation=activation, **kwargs)([x, adj])
    return x
def __init__(
    self,
    channels,
    dropout_rate=0.5,
    activation=None,
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    attn_kernel_initializer="glorot_uniform",
    kernel_regularizer=None,
    bias_regularizer=None,
    attn_kernel_regularizer=None,
    attn_bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    attn_kernel_constraint=None,
    **kwargs,
):
    """Store hyperparameters and create the attention sub-layers.

    Args:
        channels: number of units of the ``values-dense`` projection.
        dropout_rate: rate shared by the input, attention (sparse) and values
            dropout layers.
        activation: resolved with ``activations.get`` and stored on
            ``self.activation``; presumably applied in ``call`` (not visible
            here).
        use_bias: stored on ``self.use_bias``. No bias weight is created here
            (``self.bias`` starts as ``None``).
        kernel_initializer, kernel_regularizer, kernel_constraint: settings for
            the ``values-dense`` kernel.
        bias_initializer, bias_regularizer, bias_constraint: resolved through
            the Keras registries and stored for later use.
        attn_kernel_initializer, attn_kernel_regularizer,
        attn_kernel_constraint: settings for the ``key-dense`` and
            ``query-dense`` kernels. The resolved objects are also stored on
            ``self``.
        attn_bias_regularizer: stored on ``self`` and forwarded to the key and
            query Dense layers.
        activity_regularizer: forwarded to the base-class constructor.
        **kwargs: forwarded to the base-class constructor.
    """
    super().__init__(activity_regularizer=activity_regularizer, **kwargs)
    self.channels = channels
    self.dropout_rate = dropout_rate

    # Attention-kernel settings, resolved via the Keras object registries.
    self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)
    self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
    self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)
    self.attn_bias_regularizer = regularizers.get(attn_bias_regularizer)
    self.activation = activations.get(activation)

    # Shared config for the two single-unit attention projections. Named so
    # it does not shadow the **kwargs parameter.
    # NOTE(review): these layers are built with use_bias=False, so the
    # bias_regularizer entry has no effect on them.
    attn_dense_config = {
        "kernel_regularizer": attn_kernel_regularizer,
        "kernel_constraint": attn_kernel_constraint,
        "kernel_initializer": attn_kernel_initializer,
        "bias_regularizer": attn_bias_regularizer,
    }

    def single_unit_dense(layer_name):
        # One-unit, bias-free Dense layer carrying the attention settings.
        return tf.keras.layers.Dense(
            1, use_bias=False, name=layer_name, **attn_dense_config
        )

    # Sub-layers are created in a fixed order. Keras tracks sub-layers (and
    # auto-names the unnamed Dropout layers) in the order they are created.
    self.key_dense = single_unit_dense("key-dense")
    self.query_dense = single_unit_dense("query-dense")
    self.values_dense = tf.keras.layers.Dense(
        channels,
        use_bias=False,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_constraint=kernel_constraint,
        name="values-dense",
    )
    self.input_dropout = tf.keras.layers.Dropout(dropout_rate)
    self.attn_dropout = SparseDropout(dropout_rate)
    self.values_dropout = tf.keras.layers.Dropout(dropout_rate)

    # Output-bias settings. The bias weight itself is presumably created
    # later (e.g. in build()); that code is not visible here.
    self.use_bias = use_bias
    self.bias_constraint = constraints.get(bias_constraint)
    self.bias_initializer = initializers.get(bias_initializer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.bias = None