예제 #1
0
 def dropout(x):
     """Return `x` with dropout applied, or `x` unchanged if `dropout_rate` is falsy.

     `dropout_rate`, `is_sparse_tensor` and `SparseDropout` are free names
     defined outside this function. Sparse inputs are routed through
     `SparseDropout`; all other inputs use a standard Keras `Dropout` layer.
     A fresh dropout layer is constructed on every call.
     """
     # Guard clause: a zero/None rate means no dropout at all.
     if not dropout_rate:
         return x
     # Keras' Dropout layer is not used for sparse tensors; pick the
     # sparse-aware implementation instead.
     layer_cls = (
         SparseDropout if is_sparse_tensor(x) else tf.keras.layers.Dropout
     )
     return layer_cls(dropout_rate)(x)
예제 #2
0
파일: models.py 프로젝트: jackd/graph-tf
    def propagate(x, adj, filters, activation=None, kernel_regularizer=None):
        """Apply optional input dropout followed by one graph convolution.

        `dropout_rate`, `linear_skip_connections`, `graph_conv_factory`,
        `is_sparse_tensor` and `SparseDropout` are free names defined outside
        this function.

        Args:
            x: Node features. If `is_sparse_tensor(x)` is true, `SparseDropout`
                is used instead of the Keras `Dropout` layer.
            adj: Adjacency passed alongside `x` to the graph convolution layer.
            filters: Number of output units of the convolution (and of the
                skip projection, when enabled).
            activation: Optional callable applied to the output.
            kernel_regularizer: Regularizer for the convolution kernel (and the
                skip-projection kernel, when enabled).

        Returns:
            The graph-convolution output. When `linear_skip_connections` is set,
            a bias-free linear projection of the (dropped-out) input is added
            before `activation` is applied.
        """
        if dropout_rate:
            if is_sparse_tensor(x):
                x = SparseDropout(dropout_rate)(x)
            else:
                x = tf.keras.layers.Dropout(dropout_rate)(x)

        kwargs = dict(kernel_regularizer=kernel_regularizer, use_bias=False)
        if linear_skip_connections:
            # Project the convolution *input* to `filters` channels. The Dense
            # layer must be called on `x` here, before `x` is overwritten by the
            # convolution output; otherwise the sum below would add a Layer
            # object (not a tensor) to `x`.
            skip = tf.keras.layers.Dense(filters, **kwargs)(x)
            x = graph_conv_factory(filters, **kwargs)([x, adj])
            x = x + skip
            # The activation is applied after the residual sum, so it is not
            # passed to the convolution layer in this branch.
            if activation is not None:
                x = activation(x)
        else:
            x = graph_conv_factory(filters, activation=activation,
                                   **kwargs)([x, adj])
        return x
예제 #3
0
    def __init__(
        self,
        channels,
        dropout_rate=0.5,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        attn_kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        bias_regularizer=None,
        attn_kernel_regularizer=None,
        attn_bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        attn_kernel_constraint=None,
        **kwargs,
    ):
        """Store configuration and create the sub-layers of this attention layer.

        Args:
            channels: Number of output units of the values projection
                (`values_dense`).
            dropout_rate: Rate shared by the input, attention and values
                dropout layers.
            activation: Activation identifier, resolved with `activations.get`.
            use_bias: Stored on the instance. No bias variable is created here;
                `self.bias` starts as None (presumably created in `build` --
                verify against the rest of the class).
            kernel_initializer, kernel_regularizer, kernel_constraint: Settings
                for the `values_dense` kernel.
            bias_initializer, bias_regularizer, bias_constraint: Resolved and
                stored for this layer's own bias.
            attn_kernel_initializer, attn_kernel_regularizer,
            attn_kernel_constraint: Settings for the `key_dense` and
                `query_dense` kernels.
            attn_bias_regularizer: Stored, and forwarded as `bias_regularizer`
                to `key_dense`/`query_dense`. Those layers are built with
                `use_bias=False`, so it has no effect on them.
            activity_regularizer: Forwarded to the base-class constructor.
            **kwargs: Forwarded to the base-class constructor.
        """
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self.channels = channels
        self.dropout_rate = dropout_rate
        self.attn_kernel_initializer = initializers.get(
            attn_kernel_initializer)
        self.attn_kernel_regularizer = regularizers.get(
            attn_kernel_regularizer)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)
        self.attn_bias_regularizer = regularizers.get(attn_bias_regularizer)
        self.activation = activations.get(activation)

        # NOTE: this local rebinds the name `kwargs`, shadowing the `**kwargs`
        # parameter (already consumed by `super().__init__` above). It holds
        # the shared settings for the two attention-score projections.
        kwargs = dict(
            kernel_regularizer=attn_kernel_regularizer,
            kernel_constraint=attn_kernel_constraint,
            kernel_initializer=attn_kernel_initializer,
            bias_regularizer=attn_bias_regularizer,
        )
        # Key and query projections each map features to a single unit
        # (presumably one scalar attention score per node -- confirm in `call`).
        self.key_dense = tf.keras.layers.Dense(1,
                                               use_bias=False,
                                               name="key-dense",
                                               **kwargs)
        self.query_dense = tf.keras.layers.Dense(1,
                                                 use_bias=False,
                                                 name="query-dense",
                                                 **kwargs)
        # Values projection to `channels` units; bias-free here, so any bias
        # is handled separately via `self.bias`.
        self.values_dense = tf.keras.layers.Dense(
            channels,
            kernel_regularizer=kernel_regularizer,
            kernel_initializer=kernel_initializer,
            kernel_constraint=kernel_constraint,
            use_bias=False,
            name="values-dense",
        )
        # The attention dropout uses SparseDropout, suggesting the attention
        # coefficients are held in a sparse tensor (assumption -- verify).
        self.input_dropout = tf.keras.layers.Dropout(dropout_rate)
        self.attn_dropout = SparseDropout(dropout_rate)
        self.values_dropout = tf.keras.layers.Dropout(dropout_rate)

        self.use_bias = use_bias
        self.bias_constraint = constraints.get(bias_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias = None