Code example #1
    def _call_single(self, X, A):
        # Reshape kernels for efficient message-passing
        kernel = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
        attn_kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
        attn_kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

        # Prepare message-passing
        indices = A.indices
        N = tf.shape(X, out_type=indices.dtype)[0]
        indices = ops.sparse_add_self_loops(indices, N)
        targets, sources = indices[:, -2], indices[:, -1]

        # Update node features
        X = ops.dot(X, kernel)
        X = tf.reshape(X, (-1, self.attn_heads, self.channels))

        # Compute attention
        attn_for_self = tf.reduce_sum(X * attn_kernel_self, -1)
        attn_for_self = tf.gather(attn_for_self, targets)
        attn_for_neighs = tf.reduce_sum(X * attn_kernel_neighs, -1)
        attn_for_neighs = tf.gather(attn_for_neighs, sources)

        attn_coef = attn_for_self + attn_for_neighs
        attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
        # Normalize the logits over each target node's incoming edges
        attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
        attn_coef = self.dropout(attn_coef)
        attn_coef = attn_coef[..., None]  # trailing dim for broadcasting over channels

        # Update representation
        output = attn_coef * tf.gather(X, sources)
        output = ops.scatter_sum(output, targets, N)

        return output, attn_coef
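
The method above relies on Spektral's `ops` helpers. As a point of reference, below is a minimal, self-contained sketch of the same gather/softmax/scatter pattern in plain TensorFlow, simplified to a single attention head; the function name, argument layout, and the explicit numerically-stable segment softmax are illustrative assumptions, not Spektral code.

import tensorflow as tf

def gat_message_passing(X, edge_index, W, a_self, a_neigh, N):
    # Illustrative single-head GAT-style step (hypothetical helper, not Spektral's API).
    # X: [N, F] node features; edge_index: [E, 2] (target, source) pairs
    # with self-loops already added; W: [F, C]; a_self, a_neigh: [C].
    targets, sources = edge_index[:, 0], edge_index[:, 1]

    H = tf.matmul(X, W)  # transform node features: [N, C]

    # Per-edge attention logits: a_self . h_target + a_neigh . h_source
    logits = (tf.gather(tf.reduce_sum(H * a_self, -1), targets)
              + tf.gather(tf.reduce_sum(H * a_neigh, -1), sources))
    logits = tf.nn.leaky_relu(logits, alpha=0.2)

    # Softmax over each target node's incoming edges (segment-wise, max-shifted
    # for numerical stability; self-loops guarantee no empty segment)
    logits -= tf.gather(tf.math.unsorted_segment_max(logits, targets, N), targets)
    exp = tf.exp(logits)
    denom = tf.gather(tf.math.unsorted_segment_sum(exp, targets, N), targets)
    alpha = exp / denom  # [E]

    # Weighted sum of source-node features per target node
    messages = alpha[:, None] * tf.gather(H, sources)  # [E, C]
    return tf.math.unsorted_segment_sum(messages, targets, N)  # [N, C]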
Code example #2
File: graphsage_conv.py  Project: wxiangs/spektral
    def call(self, inputs):
        # Node features and the (possibly dense) adjacency/filter matrix
        features = inputs[0]
        fltr = inputs[1]

        # Enforce sparse representation
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        # Propagation
        indices = fltr.indices
        N = tf.shape(features, out_type=indices.dtype)[0]
        indices = ops.sparse_add_self_loops(indices, N)
        targets, sources = indices[:, -2], indices[:, -1]
        messages = tf.gather(features, sources)
        aggregated = self.aggregate_op(messages, targets, N)
        # Concatenate each node's own features with its aggregated neighborhood
        output = K.concatenate([features, aggregated])
        output = ops.dot(output, self.kernel)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        output = K.l2_normalize(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output
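
The propagation here follows the same gather/segment pattern, with concatenation in place of attention. Below is a minimal self-contained sketch in plain TensorFlow, assuming mean aggregation and omitting the bias and activation for brevity; the function name and argument layout are illustrative assumptions, not Spektral's API.

import tensorflow as tf

def graphsage_mean_step(X, edge_index, W, N):
    # Illustrative GraphSAGE-style step with mean aggregation (hypothetical helper).
    # X: [N, F] node features; edge_index: [E, 2] (target, source) pairs
    # with self-loops already added; W: [2*F, C].
    targets, sources = edge_index[:, 0], edge_index[:, 1]

    # Mean of source-node features per target node
    messages = tf.gather(X, sources)  # [E, F]
    aggregated = tf.math.unsorted_segment_mean(messages, targets, N)  # [N, F]

    # Concatenate each node's own features with its aggregated neighborhood
    output = tf.concat([X, aggregated], axis=-1)  # [N, 2*F]
    output = tf.matmul(output, W)  # [N, C]

    # L2-normalize the embeddings, as in the GraphSAGE paper
    return tf.math.l2_normalize(output, axis=-1)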