def call(self, inputs):
    """Score nodes, keep the top-scoring ones per graph, and pool X, A (and I).

    :param inputs: ``[X, A]`` (single mode) or ``[X, A, I]`` (disjoint mode).
        When ``I`` is omitted every node is assigned to segment 0; a 2-D
        ``I`` is reduced to its first column. ``I`` is cast to int32.
    :return: list ``[X_pooled, A_pooled]``, followed by ``I_pooled`` in
        disjoint mode, followed by ``mask`` when ``self.return_mask`` is set.
        ``A_pooled`` is sparse iff the input ``A`` was sparse.

    Side effect: sets ``self.data_mode`` to ``'disjoint'`` or ``'single'``
    *before* ``self.compute_scores`` is called, so the order of these
    statements matters if ``compute_scores`` reads it.
    """
    if len(inputs) == 3:
        X, A, I = inputs
        self.data_mode = 'disjoint'
    else:
        X, A = inputs
        # Single graph: every node belongs to segment 0.
        I = tf.zeros(tf.shape(X)[:1])
        self.data_mode = 'single'
    if K.ndim(I) == 2:
        I = I[:, 0]
    I = tf.cast(I, tf.int32)

    # Remember the input format so A_pooled can be converted back to sparse.
    A_is_sparse = K.is_sparse(A)

    # Get mask
    y = self.compute_scores(X, A, I)
    N = K.shape(X)[-2]
    # NOTE(review): presumably segment_top_k keeps a self.ratio fraction of
    # the nodes of each segment of I (or uses self.top_k_var) — confirm.
    indices = ops.segment_top_k(y[:, 0], I, self.ratio, self.top_k_var)
    # mask has one entry per node: 1 at the kept indices, 0 elsewhere. Its
    # dtype follows `indices` (an integer type), not bool.
    mask = tf.scatter_nd(tf.expand_dims(indices, 1), tf.ones_like(indices), (N,))

    # Multiply X and y to make layer differentiable
    # (only integer indices come out of the top-k selection path).
    features = X * self.gating_op(y)

    # axis is 0 for a rank-2 A and 1 otherwise (leading batch axis).
    # NOTE(review): y[:, 0] and the (N,) mask assume a single/disjoint
    # layout; confirm the rank-3 (batch) path is actually supported.
    axis = 0 if len(K.int_shape(A)) == 2 else 1  # Cannot use negative axis in tf.boolean_mask
    # Reduce X
    X_pooled = tf.boolean_mask(features, mask, axis=axis)

    # Compute A^2
    # The sparse path densifies A, so memory here is O(N^2).
    # Presumably A^2 is used so kept nodes stay connected through removed
    # 2-hop neighbours — TODO confirm.
    if A_is_sparse:
        A_dense = tf.sparse.to_dense(A)
    else:
        A_dense = A
    A_squared = K.dot(A, A_dense)

    # Reduce A
    # Keep the selected rows, then the selected columns.
    A_pooled = tf.boolean_mask(A_squared, mask, axis=axis)
    A_pooled = tf.boolean_mask(A_pooled, mask, axis=axis + 1)
    if A_is_sparse:
        A_pooled = ops.dense_to_sparse(A_pooled)

    output = [X_pooled, A_pooled]

    # Reduce I
    # I[:, None] makes I 2-D so boolean_mask selects rows; [:, 0] flattens it.
    if self.data_mode == 'disjoint':
        I_pooled = tf.boolean_mask(I[:, None], mask)[:, 0]
        output.append(I_pooled)

    if self.return_mask:
        output.append(mask)

    return output
def call(self, inputs):
    """Apply one GIN propagation step followed by the layer's MLP.

    Computes ``mlp((1 + eps) * x_i + sum_{j -> i} x_j)``. The edges are the
    stored entries of the sparse adjacency. Edge values are not used.

    :param inputs: sequence whose first two items are the node features and
        the adjacency / filter matrix (dense or ``tf.SparseTensor``). A dense
        matrix is converted with ``ops.dense_to_sparse``.
    :return: the MLP output, one row per node.
    """
    features = inputs[0]
    fltr = inputs[1]

    # Enforce sparsity
    if not K.is_sparse(fltr):
        fltr = ops.dense_to_sparse(fltr)

    # Propagation: each edge carries the features of its source (last index
    # column) to its target (second-to-last index column).
    targets = fltr.indices[:, -2]
    sources = fltr.indices[:, -1]
    messages = tf.gather(features, sources)
    # unsorted_segment_sum does not require sorted target ids. It always
    # yields num_segments rows (zeros for nodes without incoming edges), so
    # the sum below stays shape-compatible with `features`.
    features_neigh = tf.math.unsorted_segment_sum(
        messages, targets, num_segments=tf.shape(features)[0])
    hidden = (1.0 + self.eps) * features + features_neigh

    # MLP
    output = self.mlp(hidden)
    return output
def call(self, inputs, mask=None):
    """Edge-conditioned convolution: per-edge kernels from edge features.

    :param inputs: ``(x, a, e)``. These are the node features, the adjacency
        (dense or sparse) and the edge features fed to the filter network.
    :param mask: optional; when given, ``mask[0]`` multiplies the output.
    :return: the convolved node features after root term, bias and
        activation.
    """
    x, a, e = inputs

    # Parameters
    N = tf.shape(x)[-2]
    F = tf.shape(x)[-1]
    F_ = self.channels

    # Filter network: maps edge features to flattened F x F_ kernels.
    kernel_network = e
    for layer in self.kernel_network_layers:
        kernel_network = layer(kernel_network)

    # Convolution
    mode = ops.autodetect_mode(x, a)
    if mode == modes.BATCH:
        # NOTE(review): this path lays the kernel out as (..., F_, F), while
        # the sparse path below uses (..., F, F_). The same filter-network
        # output is therefore interpreted differently per mode. Confirm this
        # is intended.
        kernel = K.reshape(kernel_network, (-1, N, N, F_, F))
        output = kernel * a[..., None, None]
        output = tf.einsum("abcde,ace->abd", output, x)
    else:
        # Enforce sparse representation
        if not K.is_sparse(a):
            warnings.warn("Casting dense adjacency matrix to SparseTensor. "
                          "This can be an expensive operation.")
            a = ops.dense_to_sparse(a)

        target_shape = (-1, F, F_)
        if mode == modes.MIXED:
            target_shape = (tf.shape(x)[0], ) + target_shape
        kernel = tf.reshape(kernel_network, target_shape)
        # Messages are gathered from the row index and summed into the column
        # index, i.e. propagation follows A transposed.
        # NOTE(review): in MIXED mode `messages` has a leading batch axis;
        # confirm ops.scatter_sum scatters along the edge axis there.
        index_i = a.indices[:, 1]
        index_j = a.indices[:, 0]
        messages = tf.gather(x, index_j, axis=-2)
        messages = tf.einsum("...ab,...abc->...ac", messages, kernel)
        output = ops.scatter_sum(messages, index_i, N)

    if self.root:
        output += K.dot(x, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if mask is not None:
        output *= mask[0]
    output = self.activation(output)

    return output
def call(self, inputs):
    """GraphSAGE-style update: aggregate neighbours, concatenate, project.

    The node's own features are concatenated with an aggregate of its
    neighbours' features and multiplied by ``self.kernel``. Bias and
    activation are applied when configured, and the result is L2-normalised
    along the last axis (after the activation).

    :param inputs: sequence whose first two items are the node features and
        the adjacency (dense or ``tf.SparseTensor``).
    :return: the normalised node embeddings.
    """
    node_feats, adjacency = inputs[0], inputs[1]
    if not K.is_sparse(adjacency):
        adjacency = ops.dense_to_sparse(adjacency)

    edge_index = adjacency.indices
    # Features of each edge's source (last column), grouped by target
    # (second-to-last column).
    # NOTE(review): if aggregate_op is a sorted segment reduction (e.g.
    # tf.math.segment_mean), it returns max(target) + 1 rows. That is fewer
    # than N when the highest-index nodes have no neighbours — confirm.
    source_feats = tf.gather(node_feats, edge_index[:, -1])
    neighbour_summary = self.aggregate_op(source_feats, edge_index[:, -2])

    projected = K.dot(K.concatenate([node_feats, neighbour_summary]),
                      self.kernel)
    if self.use_bias:
        projected = K.bias_add(projected, self.bias)
    activated = (projected if self.activation is None
                 else self.activation(projected))
    return K.l2_normalize(activated, axis=-1)
def call(self, inputs):
    """Apply one GIN propagation step followed by the layer's MLP.

    Computes ``mlp((1 + eps) * x_i + sum_{j -> i} x_j)``. The edges are the
    stored entries of the sparse adjacency. Edge values are not used.

    :param inputs: sequence whose first two items are the node features and
        the adjacency / filter matrix (dense or ``tf.SparseTensor``). A dense
        matrix is converted with ``ops.dense_to_sparse``.
    :return: the MLP output, one row per node.
    """
    features = inputs[0]
    fltr = inputs[1]

    # Enforce sparse representation
    if not K.is_sparse(fltr):
        fltr = ops.dense_to_sparse(fltr)

    # Propagation
    targets = fltr.indices[:, -2]
    sources = fltr.indices[:, -1]
    messages = tf.gather(features, sources)
    # Sum each message into its target node, producing exactly one row per
    # node. The TF primitive is called directly with (data, segment_ids,
    # num_segments) so the data/index argument order is unambiguous.
    aggregated = tf.math.unsorted_segment_sum(
        messages, targets, num_segments=tf.shape(features)[0])
    hidden = (1.0 + self.eps) * features + aggregated

    # MLP
    output = self.mlp(hidden)
    return output
def _call_single(self, inputs):
    """Edge-conditioned convolution for a single graph.

    Each edge's features are mapped by the filter network to an
    ``F x channels`` kernel. The source node's features are multiplied by
    that kernel and the results are summed into the target node.

    :param inputs: ``[X, A, E]``. These are the node features, the adjacency
        (dense or sparse) and the 2-D edge features.
    :return: the updated node features.
    :raises ValueError: if ``E`` is not 2-D.
    """
    X = inputs[0]  # (N, F)
    A = inputs[1]  # (N, N)
    E = inputs[2]  # (n_edges, S)
    # Explicit check: an `assert` would be stripped under `python -O`.
    if K.ndim(E) != 2:
        raise ValueError('In single mode, E must have shape (n_edges, S).')

    # Enforce sparse representation
    if not K.is_sparse(A):
        A = ops.dense_to_sparse(A)

    # Parameters
    N = tf.shape(X)[-2]
    F = K.int_shape(X)[-1]
    F_ = self.channels

    # Filter network
    kernel_network = E
    for layer in self.kernel_network_layers:
        kernel_network = layer(kernel_network)  # (n_edges, F * F_)
    target_shape = (-1, F, F_)
    kernel = tf.reshape(kernel_network, target_shape)

    # Propagation: source = last index column, target = second-to-last.
    index_i = A.indices[:, -2]
    index_j = A.indices[:, -1]
    messages = tf.gather(X, index_j)
    # (n_edges, 1, F) x (n_edges, F, F_) -> (n_edges, F_)
    messages = ops.dot(messages[:, None, :], kernel)[:, 0, :]
    aggregated = ops.scatter_sum(messages, index_i, N)

    # Update
    output = aggregated
    if self.root:
        output += ops.dot(X, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)

    return output
def _call_single(self, inputs):
    """Edge-conditioned convolution for a single graph.

    The filter network turns every edge's feature vector into an
    ``(n_in, channels)`` matrix. Each sender's features are transformed by
    its edge's matrix and summed at the receiver. An optional root term,
    bias and activation follow.

    :param inputs: ``(x, a, e)``. These are the node features, the adjacency
        (dense or sparse) and the 2-D edge features.
    :return: the updated node features.
    :raises ValueError: if ``e`` is not 2-D.
    """
    x, a, e = inputs
    if K.ndim(e) != 2:
        raise ValueError(
            'In single mode, E must have shape (n_edges, n_edge_features).')

    if not K.is_sparse(a):
        a = ops.dense_to_sparse(a)

    n_nodes = tf.shape(x)[-2]
    n_in = K.int_shape(x)[-1]

    # Feed the edge features through every layer of the filter network.
    edge_kernels = e
    for kernel_layer in self.kernel_network_layers:
        edge_kernels = kernel_layer(edge_kernels)
    # One (n_in, channels) matrix per edge.
    edge_kernels = tf.reshape(edge_kernels, (-1, n_in, self.channels))

    receivers = a.indices[:, -2]
    senders = a.indices[:, -1]
    sender_feats = tf.gather(x, senders)[:, None, :]
    edge_messages = ops.dot(sender_feats, edge_kernels)[:, 0, :]
    output = ops.scatter_sum(edge_messages, receivers, n_nodes)

    if self.root:
        output = output + ops.dot(x, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    return output if self.activation is None else self.activation(output)
def call(self, inputs):
    """GraphSAGE-style update with self-loops and an N-aware aggregator.

    Self-loops are added to the edge list. Source features are aggregated
    per target with ``self.aggregate_op(messages, targets, N)``. The result
    is concatenated with the node's own features and projected by
    ``self.kernel``. Optional bias follows, then L2 normalisation, then the
    activation.

    :param inputs: sequence whose first two items are the node features and
        the adjacency (dense or ``tf.SparseTensor``).
    :return: the updated node features.
    """
    node_feats = inputs[0]
    adjacency = inputs[1]
    if not K.is_sparse(adjacency):
        adjacency = ops.dense_to_sparse(adjacency)

    edge_index = adjacency.indices
    # Node count in the same integer dtype as the sparse indices.
    n_nodes = tf.shape(node_feats, out_type=edge_index.dtype)[0]
    edge_index = ops.sparse_add_self_loops(edge_index, n_nodes)
    receivers = edge_index[:, -2]
    senders = edge_index[:, -1]
    neighbour_summary = self.aggregate_op(
        tf.gather(node_feats, senders), receivers, n_nodes)

    combined = ops.dot(K.concatenate([node_feats, neighbour_summary]),
                       self.kernel)
    if self.use_bias:
        combined = K.bias_add(combined, self.bias)
    # NOTE(review): normalisation happens before the activation. The
    # GraphSAGE paper normalises after the nonlinearity — confirm intended.
    combined = K.l2_normalize(combined, axis=-1)
    if self.activation is None:
        return combined
    return self.activation(combined)