Example #1
    def _call_single(self, x, a):
        # Reshape kernels for efficient message-passing
        kernel = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
        attn_kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
        attn_kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

        # Prepare message-passing
        indices = a.indices
        N = tf.shape(x, out_type=indices.dtype)[-2]
        indices = ops.add_self_loops_indices(indices, N)
        targets, sources = indices[:, 1], indices[:, 0]

        # Update node features
        x = K.dot(x, kernel)
        x = tf.reshape(x, (-1, self.attn_heads, self.channels))

        # Compute attention
        attn_for_self = tf.reduce_sum(x * attn_kernel_self, -1)
        attn_for_self = tf.gather(attn_for_self, targets)
        attn_for_neighs = tf.reduce_sum(x * attn_kernel_neighs, -1)
        attn_for_neighs = tf.gather(attn_for_neighs, sources)

        attn_coef = attn_for_self + attn_for_neighs
        attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
        attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
        attn_coef = self.dropout(attn_coef)
        attn_coef = attn_coef[..., None]

        # Update representation
        output = attn_coef * tf.gather(x, sources)
        output = tf.math.unsorted_segment_sum(output, targets, N)

        return output, attn_coef
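Every example on this page normalizes attention logits with ops.unsorted_segment_softmax, a softmax taken over each target node's incoming edges. As a point of reference, here is a minimal sketch of such a segment-wise softmax built from standard TensorFlow segment ops; it is not necessarily the library's exact implementation:

    import tensorflow as tf

    def unsorted_segment_softmax(x, indices, n_nodes):
        # Numerically stable softmax within each segment: subtract the
        # per-segment max, exponentiate, then divide by the per-segment sum.
        x_max = tf.math.unsorted_segment_max(x, indices, n_nodes)
        x = tf.exp(x - tf.gather(x_max, indices))
        x_sum = tf.math.unsorted_segment_sum(x, indices, n_nodes)
        return x / tf.gather(x_sum, indices)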
Example #2
    def _call_single(self, X, A):
        # Reshape kernels for efficient message-passing
        kernel = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
        attn_kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
        attn_kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

        # Prepare message-passing
        indices = A.indices
        N = tf.shape(X, out_type=indices.dtype)[0]
        indices = ops.sparse_add_self_loops(indices, N)
        targets, sources = indices[:, -2], indices[:, -1]

        # Update node features
        X = ops.dot(X, kernel)
        X = tf.reshape(X, (-1, self.attn_heads, self.channels))

        # Compute attention
        attn_for_self = tf.reduce_sum(X * attn_kernel_self, -1)
        attn_for_self = tf.gather(attn_for_self, targets)
        attn_for_neighs = tf.reduce_sum(X * attn_kernel_neighs, -1)
        attn_for_neighs = tf.gather(attn_for_neighs, sources)

        attn_coef = attn_for_self + attn_for_neighs
        attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
        attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
        attn_coef = self.dropout(attn_coef)
        attn_coef = attn_coef[..., None]

        # Update representation
        output = attn_coef * tf.gather(X, sources)
        output = ops.scatter_sum(targets, output, N)

        return output, attn_coef
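Both variants insert self-loops into the edge list before computing attention, so every node also attends to itself. A minimal sketch of what an op like ops.add_self_loops_indices / ops.sparse_add_self_loops might do, assuming the edge list is an [n_edges, 2] tensor (a real implementation would also deduplicate edges that are already self-loops):

    import tensorflow as tf

    def add_self_loops_indices(indices, n_nodes):
        # Append one (i, i) edge per node to the [n_edges, 2] edge list.
        loop = tf.range(n_nodes, dtype=indices.dtype)
        loops = tf.stack([loop, loop], axis=1)
        return tf.concat([indices, loops], axis=0)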
Example #3
    def message(self, X, X_norm=None):
        # Gather per-edge rows: j indexes the source node of each edge,
        # i the target node.
        X_j = self.get_j(X)
        X_norm_i = self.get_i(X_norm)
        X_norm_j = self.get_j(X_norm)
        # Attention logits: beta-scaled dot product of the normalized
        # endpoint features.
        alpha = self.beta * tf.reduce_sum(X_norm_i * X_norm_j, axis=-1)
        # Softmax over each target node's incoming edges.
        alpha = ops.unsorted_segment_softmax(alpha, self.index_i, self.N)
        alpha = alpha[:, None]

        return alpha * X_j
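X_norm is presumably the L2-normalized node features, which makes the per-edge dot product in message() a cosine similarity scaled by beta. A small runnable illustration of that reduction (the shapes are made up):

    import tensorflow as tf

    X = tf.random.normal((5, 8))               # node features
    X_norm = tf.math.l2_normalize(X, axis=-1)  # unit-norm rows
    # The cosine similarity between two endpoints is then just the dot
    # product of their normalized feature rows:
    cos_01 = tf.reduce_sum(X_norm[0] * X_norm[1])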
Example #4
    def message(self, x, x_norm=None):
        # Same scheme as Example #3: beta-scaled similarity of normalized
        # endpoint features, softmaxed over each target's incoming edges.
        x_j = self.get_j(x)
        x_norm_i = self.get_i(x_norm)
        x_norm_j = self.get_j(x_norm)
        alpha = self.beta * tf.reduce_sum(x_norm_i * x_norm_j, axis=-1)
        alpha = ops.unsorted_segment_softmax(alpha, self.index_i, self.n_nodes)
        alpha = alpha[:, None]

        return alpha * x_j
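The get_i / get_j helpers gather per-edge rows from a node-level tensor using the target and source indices of each edge. A minimal sketch, assuming index_i and index_j are the two columns of the edge list (not necessarily the library's exact helpers):

    import tensorflow as tf

    def get_i(x, index_i):
        # One row per edge: features of the target node of that edge.
        return tf.gather(x, index_i, axis=-2)

    def get_j(x, index_j):
        # One row per edge: features of the source node of that edge.
        return tf.gather(x, index_j, axis=-2)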
Example #5
    def message(self, x, x_norm=None):
        x_j = self.get_j(x)
        x_norm_i = self.get_i(x_norm)
        x_norm_j = self.get_j(x_norm)
        # Beta-scaled similarity of the normalized endpoint features.
        alpha = self.beta * tf.reduce_sum(x_norm_i * x_norm_j, axis=-1)

        # Segment ops reduce over the leading axis, so in mixed mode the
        # [batch, n_edges] logits are transposed before the segment softmax
        # and transposed back afterwards.
        if len(alpha.shape) == 2:
            alpha = tf.transpose(alpha)  # For mixed mode
        alpha = ops.unsorted_segment_softmax(alpha, self.index_i, self.n_nodes)
        if len(alpha.shape) == 2:
            alpha = tf.transpose(alpha)  # For mixed mode
        alpha = alpha[..., None]

        return alpha * x_j
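A toy demonstration of the same transpose trick with unsorted_segment_sum, showing why the edge axis must lead before a segment op is applied (all shapes below are made up):

    import tensorflow as tf

    alpha = tf.random.uniform((4, 10))                  # [batch, n_edges]
    index_i = tf.constant([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
    alpha_t = tf.transpose(alpha)                       # [n_edges, batch]
    out = tf.math.unsorted_segment_sum(alpha_t, index_i, 5)  # [n_nodes, batch]
    out = tf.transpose(out)                             # [batch, n_nodes]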
Example #6
    def call(self, inputs):
        # In disjoint mode, inputs also carry the graph-membership vector I.
        if self.data_mode == "disjoint":
            X, I = inputs
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X = inputs
        # One attention logit per node.
        attn_coeff = K.dot(X, self.attn_kernel)
        attn_coeff = K.squeeze(attn_coeff, -1)
        if self.data_mode == "single":
            # Softmax over all nodes, then a weighted sum of node features.
            attn_coeff = K.softmax(attn_coeff)
            output = K.dot(attn_coeff[None, ...], X)
        elif self.data_mode == "batch":
            attn_coeff = K.softmax(attn_coeff)
            output = K.batch_dot(attn_coeff, X)
        else:
            # Disjoint mode: softmax within each graph's segment, then sum
            # the weighted node features per graph. tf.math.segment_sum
            # requires I to be sorted, which holds in disjoint mode.
            attn_coeff = ops.unsorted_segment_softmax(attn_coeff, I,
                                                      K.shape(X)[0])
            output = attn_coeff[:, None] * X
            output = tf.math.segment_sum(output, I)

        return output
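A hedged usage sketch for the disjoint branch: X stacks the nodes of several graphs and I maps each node to its graph, so the layer returns one pooled row per graph. The layer name below is illustrative, not confirmed by the source:

    import tensorflow as tf

    X = tf.random.normal((7, 16))            # 7 nodes across 3 graphs
    I = tf.constant([0, 0, 0, 1, 1, 2, 2])   # sorted graph id per node
    # pool = GlobalAttnSumPool()             # illustrative layer name
    # out = pool([X, I])                     # -> shape [3, 16], one row per graph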