def _call_single(self, x, a):
    """Single-mode forward pass: multi-head attention message passing on one graph.

    Returns the aggregated node features and the per-edge attention
    coefficients (shape: edges x heads x 1).
    """
    # Flatten the weight tensor so the feature projection is one dense matmul,
    # and move the head axis first on the attention kernels so they broadcast
    # against (nodes, heads, channels).
    w = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
    a_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
    a_neigh = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

    # Edge list with self-loops added; column 1 holds targets, column 0 sources.
    edge_idx = a.indices
    n_nodes = tf.shape(x, out_type=edge_idx.dtype)[-2]
    edge_idx = ops.add_self_loops_indices(edge_idx, n_nodes)
    tgt = edge_idx[:, 1]
    src = edge_idx[:, 0]

    # Project node features and split them per head: (nodes, heads, channels).
    x = tf.reshape(K.dot(x, w), (-1, self.attn_heads, self.channels))

    # Per-edge attention logits: target contribution plus source contribution.
    logits = tf.gather(tf.reduce_sum(x * a_self, -1), tgt) + tf.gather(
        tf.reduce_sum(x * a_neigh, -1), src
    )
    coef = tf.nn.leaky_relu(logits, alpha=0.2)
    # Softmax over the incoming edges of each target node.
    coef = ops.unsorted_segment_softmax(coef, tgt, n_nodes)
    coef = self.dropout(coef)[..., None]

    # Weighted sum of source-node features per target node.
    out = tf.math.unsorted_segment_sum(coef * tf.gather(x, src), tgt, n_nodes)
    return out, coef
def _call_single(self, X, A):
    """Single-mode forward pass: multi-head attention message passing on one graph.

    Returns the aggregated node features and the per-edge attention
    coefficients (shape: edges x heads x 1).
    """
    # Collapse the kernel so the projection is a single matmul; bring the head
    # axis to the front of the attention kernels for broadcasting.
    w = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
    a_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
    a_neigh = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

    # Sparse edge list augmented with self-loops.
    edge_idx = A.indices
    n_nodes = tf.shape(X, out_type=edge_idx.dtype)[0]
    edge_idx = ops.sparse_add_self_loops(edge_idx, n_nodes)
    tgt = edge_idx[:, -2]
    src = edge_idx[:, -1]

    # Project node features and split per head: (nodes, heads, channels).
    X = tf.reshape(ops.dot(X, w), (-1, self.attn_heads, self.channels))

    # Per-edge attention logits from the target and source contributions.
    logits = tf.gather(tf.reduce_sum(X * a_self, -1), tgt) + tf.gather(
        tf.reduce_sum(X * a_neigh, -1), src
    )
    coef = tf.nn.leaky_relu(logits, alpha=0.2)
    # Normalize over the incoming edges of each target node.
    coef = ops.unsorted_segment_softmax(coef, tgt, n_nodes)
    coef = self.dropout(coef)[..., None]

    # Attention-weighted aggregation of source features onto targets.
    out = ops.scatter_sum(tgt, coef * tf.gather(X, src), n_nodes)
    return out, coef
def message(self, X, X_norm=None):
    """Compute attention-weighted messages from source nodes.

    Scores are dot products of the normalized target/source features,
    scaled by ``self.beta`` and softmax-normalized per target node.
    """
    neigh_feats = self.get_j(X)
    # Similarity of normalized features across each edge, scaled by beta.
    scores = self.beta * tf.reduce_sum(
        self.get_i(X_norm) * self.get_j(X_norm), axis=-1
    )
    # Softmax over the incoming edges of each target node.
    weights = ops.unsorted_segment_softmax(scores, self.index_i, self.N)
    return weights[:, None] * neigh_feats
def message(self, x, x_norm=None):
    """Compute attention-weighted messages from source nodes.

    Scores are dot products of the normalized target/source features,
    scaled by ``self.beta`` and softmax-normalized per target node.
    """
    neigh_feats = self.get_j(x)
    # Similarity of normalized features across each edge, scaled by beta.
    scores = self.beta * tf.reduce_sum(
        self.get_i(x_norm) * self.get_j(x_norm), axis=-1
    )
    # Softmax over the incoming edges of each target node.
    weights = ops.unsorted_segment_softmax(scores, self.index_i, self.n_nodes)
    return weights[:, None] * neigh_feats
def message(self, x, x_norm=None):
    """Compute attention-weighted messages from source nodes.

    Handles mixed mode, where scores carry an extra batch axis that must be
    transposed so the segment softmax runs over the edge axis.
    """
    neigh_feats = self.get_j(x)
    # Similarity of normalized features across each edge, scaled by beta.
    scores = self.beta * tf.reduce_sum(
        self.get_i(x_norm) * self.get_j(x_norm), axis=-1
    )

    # Mixed mode: put the edge axis first for the segment softmax.
    if len(scores.shape) == 2:
        scores = tf.transpose(scores)
    scores = ops.unsorted_segment_softmax(scores, self.index_i, self.n_nodes)
    # Mixed mode: restore the original axis order.
    if len(scores.shape) == 2:
        scores = tf.transpose(scores)

    return scores[..., None] * neigh_feats
def call(self, inputs):
    """Pool node features into a graph-level vector via learned attention.

    Dispatches on ``self.data_mode``: "single" (one graph), "batch"
    (dense batch of graphs), or "disjoint" (node features plus segment ids).
    """
    if self.data_mode == "disjoint":
        X, I = inputs
        # A column-vector of segment ids is flattened to 1-D.
        if K.ndim(I) == 2:
            I = I[:, 0]
    else:
        X = inputs

    # One raw attention score per node.
    scores = K.squeeze(K.dot(X, self.attn_kernel), -1)

    if self.data_mode == "single":
        weights = K.softmax(scores)
        return K.dot(weights[None, ...], X)

    if self.data_mode == "batch":
        weights = K.softmax(scores)
        return K.batch_dot(weights, X)

    # Disjoint: softmax within each graph's segment, then a segment-wise
    # weighted sum. NOTE(review): tf.math.segment_sum requires I to be
    # sorted — presumably guaranteed by the disjoint loader; confirm.
    weights = ops.unsorted_segment_softmax(scores, I, K.shape(X)[0])
    return tf.math.segment_sum(weights[:, None] * X, I)