import tensorflow as tf
from tensorflow.keras import backend as K

from . import ops  # library helper module with the sparse utilities used below (assumed import path)


def _call_single(self, X, A):
    # Single-graph forward pass of a GAT-style attention layer
    # Reshape kernels for efficient message-passing
    kernel = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
    attn_kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
    attn_kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

    # Prepare message-passing
    indices = A.indices
    N = tf.shape(X, out_type=indices.dtype)[0]
    indices = ops.sparse_add_self_loops(indices, N)
    targets, sources = indices[:, -2], indices[:, -1]

    # Update node features
    X = ops.dot(X, kernel)
    X = tf.reshape(X, (-1, self.attn_heads, self.channels))

    # Compute attention: one logit per edge and per head, combining the
    # target's and the source's contributions
    attn_for_self = tf.reduce_sum(X * attn_kernel_self, -1)
    attn_for_self = tf.gather(attn_for_self, targets)
    attn_for_neighs = tf.reduce_sum(X * attn_kernel_neighs, -1)
    attn_for_neighs = tf.gather(attn_for_neighs, sources)
    attn_coef = attn_for_self + attn_for_neighs
    attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
    # Normalize attention over each target node's incoming edges
    attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
    attn_coef = self.dropout(attn_coef)
    attn_coef = attn_coef[..., None]

    # Update representation: weight each source's features by its attention
    # coefficient and sum the messages at the target node
    output = attn_coef * tf.gather(X, sources)
    output = ops.scatter_sum(output, targets, N)

    return output, attn_coef
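# The helpers `ops.unsorted_segment_softmax` and `ops.scatter_sum` are not
# shown in this excerpt; a minimal sketch of their assumed behavior, built
# only on TensorFlow segment ops, could look like this (reference sketch,
# not the library's actual implementation):


def unsorted_segment_softmax_sketch(x, indices, n_nodes):
    # Softmax within each segment (here: over each target node's incoming
    # edges). Subtracting the per-segment max keeps the exp numerically stable.
    x_max = tf.gather(tf.math.unsorted_segment_max(x, indices, n_nodes), indices)
    e_x = tf.exp(x - x_max)
    e_sum = tf.gather(tf.math.unsorted_segment_sum(e_x, indices, n_nodes), indices)
    return e_x / (e_sum + 1e-9)


def scatter_sum_sketch(messages, targets, n_nodes):
    # Sum all messages that share a target index into one row per node.
    return tf.math.unsorted_segment_sum(messages, targets, n_nodes)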
def call(self, inputs):
    # Forward pass of a GraphSAGE-style layer (concatenation variant)
    features = inputs[0]
    fltr = inputs[1]

    # Enforce sparse representation
    if not K.is_sparse(fltr):
        fltr = ops.dense_to_sparse(fltr)

    # Propagation
    indices = fltr.indices
    N = tf.shape(features, out_type=indices.dtype)[0]
    indices = ops.sparse_add_self_loops(indices, N)
    targets, sources = indices[:, -2], indices[:, -1]
    messages = tf.gather(features, sources)
    aggregated = self.aggregate_op(messages, targets, N)
    # Concatenate each node's own features with its aggregated neighborhood
    output = K.concatenate([features, aggregated])

    output = ops.dot(output, self.kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    # L2-normalize the output embeddings, as in the original GraphSAGE
    output = K.l2_normalize(output, axis=-1)
    if self.activation is not None:
        output = self.activation(output)

    return output
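# Minimal runnable sketch of the gather/aggregate pattern used in `call`
# above, on a toy 3-node path graph. All names here are illustrative, and
# mean aggregation stands in for `self.aggregate_op`:


def _demo_sage_aggregate():
    features = tf.constant([[1.0, 0.0],
                            [0.0, 1.0],
                            [1.0, 1.0]])
    # Directed edges as (target, source) pairs: 0<-1, 1<-0, 1<-2, 2<-1
    indices = tf.constant([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=tf.int64)
    targets, sources = indices[:, 0], indices[:, 1]
    messages = tf.gather(features, sources)  # one message per edge
    aggregated = tf.math.unsorted_segment_mean(messages, targets, 3)
    # Concatenate self features with the aggregated neighborhood -> (3, 4)
    return tf.concat([features, aggregated], axis=-1)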