def _call_single(self, X, A):
    """Graph-attention propagation for a single graph.

    Self-loops are added to the sparse adjacency ``A``. Node features are
    projected once for all heads, and per-edge attention logits are formed
    from the target-side and source-side scores. The logits go through a
    LeakyReLU (slope 0.2), a softmax over each target's incoming edges, and
    dropout. They then weight the projected source features, which are
    summed at the target nodes.

    Returns a tuple ``(output, attn_coef)``, where ``attn_coef`` carries a
    trailing singleton axis.
    """
    n_heads = self.attn_heads
    n_channels = self.channels

    # One matrix product projects every head at once; the attention kernels
    # are transposed so they broadcast against (nodes, heads, channels).
    flat_kernel = tf.reshape(self.kernel, (-1, n_heads * n_channels))
    kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
    kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

    # Edge list with self-loops: second-to-last column is the target,
    # last column is the source.
    edges = A.indices
    N = tf.shape(X, out_type=edges.dtype)[0]
    edges = ops.sparse_add_self_loops(edges, N)
    targets = edges[:, -2]
    sources = edges[:, -1]

    h = tf.reshape(ops.dot(X, flat_kernel), (-1, n_heads, n_channels))

    score_self = tf.gather(tf.reduce_sum(h * kernel_self, -1), targets)
    score_neighs = tf.gather(tf.reduce_sum(h * kernel_neighs, -1), sources)
    attn_coef = tf.nn.leaky_relu(score_self + score_neighs, alpha=0.2)
    # Normalise over the edges that share a target node.
    attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
    attn_coef = self.dropout(attn_coef)[..., None]

    output = ops.scatter_sum(attn_coef * tf.gather(h, sources), targets, N)
    return output, attn_coef
def call(self, inputs):
    """Edge-conditioned convolution (batch / non-single modes).

    ``inputs`` is indexed as ``[X, A, E]``. Per the original comments:
    X is ``(batch_size, N, F)``, A is ``(batch_size, N, N)``, and E is
    ``(n_edges, S)`` or ``(batch_size, N, N, S)``. Single-mode inputs are
    delegated to ``self._call_single``.
    """
    X, A, E = inputs[0], inputs[1], inputs[2]

    mode = ops.autodetect_mode(A, X)
    if mode == modes.SINGLE:
        return self._call_single(inputs)

    n_nodes = K.shape(X)[-2]
    n_in = K.int_shape(X)[-1]
    n_out = self.channels

    # Edge features -> per-edge kernels via the filter-generating network.
    kernel_network = E
    for kernel_layer in self.kernel_network_layers:
        kernel_network = kernel_layer(kernel_network)

    if mode == modes.BATCH:
        target_shape = (-1, n_nodes, n_nodes, n_out, n_in)
    else:
        target_shape = (n_nodes, n_nodes, n_out, n_in)
    kernel = K.reshape(kernel_network, target_shape)

    # Mask kernels by the adjacency, then contract over source nodes (i) and
    # input features (f).
    weighted = kernel * A[..., None, None]
    output = tf.einsum('abicf,aif->abc', weighted, X)

    if self.root:
        output += ops.dot(X, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def call(self, inputs):
    """Edge-conditioned convolution (batch / non-single modes).

    ``inputs`` unpacks to ``(x, a, e)``. Single-mode inputs are delegated to
    ``self._call_single``. Otherwise the edge features ``e`` are turned into
    per-edge kernels by ``self.kernel_network_layers``, masked by ``a``, and
    contracted with ``x``.
    """
    x, a, e = inputs

    mode = ops.autodetect_mode(a, x)
    if mode == modes.SINGLE:
        return self._call_single(inputs)

    n_nodes = K.shape(x)[-2]
    n_in = K.int_shape(x)[-1]
    n_out = self.channels

    kernel_network = e
    for kernel_layer in self.kernel_network_layers:
        kernel_network = kernel_layer(kernel_network)

    if mode == modes.BATCH:
        target_shape = (-1, n_nodes, n_nodes, n_out, n_in)
    else:
        target_shape = (n_nodes, n_nodes, n_out, n_in)
    kernel = K.reshape(kernel_network, target_shape)

    # Zero out kernels of non-edges, then sum over sources (i) and inputs (f).
    weighted = kernel * a[..., None, None]
    output = tf.einsum('abicf,aif->abc', weighted, x)

    if self.root:
        output += ops.dot(x, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def call(self, inputs):
    """Graph convolution: project ``x`` by the kernel, then propagate with ``a``.

    Adds the bias when ``self.use_bias`` is set, and always applies
    ``self.activation``.
    """
    x, a = inputs
    hidden = ops.filter_dot(a, ops.dot(x, self.kernel))
    if self.use_bias:
        hidden = K.bias_add(hidden, self.bias)
    return self.activation(hidden)
def call(self, inputs):
    """Chebyshev convolution of order ``self.K``.

    Uses the recurrence T_0 = x, T_1 = a·x, T_k = 2·a·T_{k-1} - T_{k-2}.
    Each term is projected by ``self.kernel[k]`` and the projections are
    summed. Adds the bias when ``self.use_bias`` is set, and always applies
    ``self.activation``.
    """
    x, a = inputs
    t_prev = x
    output = ops.dot(t_prev, self.kernel[0])

    if self.K > 1:
        t_curr = ops.filter_dot(a, x)
        output += ops.dot(t_curr, self.kernel[1])
        # Higher-order terms only exist when K > 2.
        for k in range(2, self.K):
            t_next = 2 * ops.filter_dot(a, t_curr) - t_prev
            output += ops.dot(t_next, self.kernel[k])
            t_prev, t_curr = t_curr, t_next

    if self.use_bias:
        output = K.bias_add(output, self.bias)
    return self.activation(output)
def call(self, inputs):
    """Convolve node features with the filter stored on the layer (``self.fltr``).

    ``inputs`` is the feature tensor itself. It is projected by
    ``self.kernel`` and then multiplied by ``self.fltr`` via
    ``ops.mixed_mode_dot``.
    """
    projected = ops.dot(inputs, self.kernel)
    output = ops.mixed_mode_dot(self.fltr, projected)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def _call_single(self, inputs):
    """Edge-conditioned convolution for a single graph.

    **Arguments**

    - `inputs`: sequence indexed as `[X, A, E]`, with X `(N, F)`, A `(N, N)`
      (dense or sparse), and E `(n_edges, S)`. Row i of E is assumed to
      belong to the i-th entry of `A.indices` (TODO confirm with callers).

    **Returns**

    - Node features of shape `(N, self.channels)` after the optional root
      term, bias and activation.

    **Raises**

    - `ValueError`: if `E` is not rank 2.
    """
    X = inputs[0]  # (N, F)
    A = inputs[1]  # (N, N)
    E = inputs[2]  # (n_edges, S)
    # Explicit check rather than `assert`, so validation survives `python -O`.
    if K.ndim(E) != 2:
        raise ValueError('In single mode, E must have shape (n_edges, S).')

    # Enforce sparse representation
    if not K.is_sparse(A):
        A = ops.dense_to_sparse(A)

    # Parameters
    N = tf.shape(X)[-2]
    F = K.int_shape(X)[-1]
    F_ = self.channels

    # Filter network: one flattened (F * F_) kernel per edge.
    kernel_network = E
    for kernel_layer in self.kernel_network_layers:
        kernel_network = kernel_layer(kernel_network)
    kernel = tf.reshape(kernel_network, (-1, F, F_))

    # Propagation: transform each source node's features by its edge kernel,
    # then sum the resulting messages at the target nodes.
    index_i = A.indices[:, -2]
    index_j = A.indices[:, -1]
    messages = tf.gather(X, index_j)
    messages = ops.dot(messages[:, None, :], kernel)[:, 0, :]
    aggregated = ops.scatter_sum(messages, index_i, N)

    # Update
    output = aggregated
    if self.root:
        output += ops.dot(X, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def call(self, inputs):
    """Graph convolution with a filter given as the second input.

    ``inputs`` is indexed as ``[features, fltr]``. Features are projected by
    ``self.kernel`` and then propagated with ``ops.filter_dot``.
    """
    features, fltr = inputs[0], inputs[1]
    hidden = ops.filter_dot(fltr, ops.dot(features, self.kernel))
    if self.use_bias:
        hidden = K.bias_add(hidden, self.bias)
    if self.activation is not None:
        hidden = self.activation(hidden)
    return hidden
def call(self, inputs):
    """Chebyshev convolution of order ``self.K`` with a given Laplacian-like filter.

    ``inputs`` is indexed as ``[features, laplacian]``. It uses the
    recurrence T_0 = x, T_1 = L·x, T_k = 2·L·T_{k-1} - T_{k-2}, projecting
    each term by ``self.kernel[k]`` and summing the projections.
    """
    features, laplacian = inputs[0], inputs[1]

    t_prev = features
    output = ops.dot(t_prev, self.kernel[0])
    if self.K > 1:
        t_curr = ops.filter_dot(laplacian, features)
        output += ops.dot(t_curr, self.kernel[1])
        # Higher-order terms only exist when K > 2.
        for k in range(2, self.K):
            t_next = 2 * ops.filter_dot(laplacian, t_curr) - t_prev
            output += ops.dot(t_next, self.kernel[k])
            t_prev, t_curr = t_curr, t_next

    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def _call_single(self, inputs):
    """Edge-conditioned convolution for a single graph.

    ``inputs`` unpacks to ``(x, a, e)``. ``a`` is converted to a sparse
    tensor if it is dense. Each row of ``e`` is mapped by
    ``self.kernel_network_layers`` to an ``(F, channels)`` kernel that
    transforms the features of the corresponding source node, and the
    messages are summed at the target nodes.

    Raises ``ValueError`` if ``e`` is not rank 2.
    """
    x, a, e = inputs
    if K.ndim(e) != 2:
        raise ValueError('In single mode, E must have shape '
                         '(n_edges, n_edge_features).')

    if not K.is_sparse(a):
        a = ops.dense_to_sparse(a)

    n_nodes = tf.shape(x)[-2]
    n_in = K.int_shape(x)[-1]
    n_out = self.channels

    kernel_network = e
    for kernel_layer in self.kernel_network_layers:
        kernel_network = kernel_layer(kernel_network)
    # One (n_in, n_out) kernel per edge.
    edge_kernels = tf.reshape(kernel_network, (-1, n_in, n_out))

    targets = a.indices[:, -2]
    sources = a.indices[:, -1]
    source_feats = tf.gather(x, sources)
    messages = ops.dot(source_feats[:, None, :], edge_kernels)[:, 0, :]
    output = ops.scatter_sum(messages, targets, n_nodes)

    if self.root:
        output += ops.dot(x, self.root_kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    if self.activation is not None:
        output = self.activation(output)
    return output
def call(self, inputs, **kwargs):
    """Transform node features, then message-pass them over ``a``.

    Pipeline: kernel projection, optional bias, optional batch norm,
    dropout, activation, then ``self.propagate``. Extra ``kwargs`` are
    accepted and ignored.
    """
    x, a, _ = self.get_inputs(inputs)

    # TODO: a = add_self_loops(a)
    hidden = dot(x, self.kernel)
    if self.use_bias:
        hidden = tf.nn.bias_add(hidden, self.bias)
    if self.use_batch_norm:
        hidden = self.batch_norm(hidden)
    hidden = self.activation(self.dropout(hidden))

    return self.propagate(hidden, a)
def call(self, inputs):
    """Aggregate neighbours (with self-loops), concatenate with ``x``, project, L2-normalise.

    The optional activation is applied after normalisation.
    """
    x, a, _ = self.get_inputs(inputs)
    neighbourhood = self.propagate(x, ops.add_self_loops(a))

    combined = ops.dot(K.concatenate([x, neighbourhood]), self.kernel)
    if self.use_bias:
        combined = K.bias_add(combined, self.bias)
    combined = K.l2_normalize(combined, axis=-1)
    if self.activation is not None:
        combined = self.activation(combined)
    return combined
def k_hop_sparse_subgraph(a, node_idx, k, transformer=None):
    """
    Computes the subgraph containing all the neighbors of `node_idx` up to the
    k-th order. If `a` is not the binary adjacency matrix, a `transformer`
    should be passed.

    The nodes reachable through powers `a^1 .. a^(k-1)` of the adjacency,
    plus `node_idx` itself, are collected. Every edge of `a` that touches one
    of these nodes (as row or column) is kept.

    **Arguments**

    - `a`: sparse `(n_nodes, n_nodes)` graph tensor; it is cast to float32
      if needed;
    - `node_idx`: center node;
    - `k`: order of neighbor;
    - `transformer`: one of the functions from the `spektral.transforms`
      module, needed to convert the binary adjacency matrix into the correct
      format for the model. When given, `a` is first converted back to
      binary with `binary_adj_converter`.

    **Returns**

    - Without `transformer`: a sparse tensor whose stored values are the sign
      of the computed subgraph entries.
    - With `transformer`: the subgraph converted to a sparse matrix, passed
      through `transformer`, and converted back to a sparse tensor.
    """
    if a.dtype != tf.float32:
        a = tf.cast(a, tf.float32)
    if transformer:
        a = binary_adj_converter(a)

    a_power = tf.sparse.eye(a.shape[0])
    # Row indicator (1, n_nodes) of nodes reached so far, starting at the center.
    reached = np.zeros(a.shape[0]).astype("float32").reshape(1, -1)
    reached[0, node_idx] = 1
    for _ in range(k - 1):
        a_power = dot(a_power, a)
        center_row = tf.sparse.slice(
            a_power, start=[node_idx, 0], size=[1, a_power.shape[0]]
        )
        reached += tf.sparse.to_dense(center_row)

    # Keep edges whose source row or target column is among the reached nodes.
    row_masked = a * tf.reshape(reached, (-1, 1))
    col_masked = a * reached
    subgraph = tf.sparse.add(row_masked, col_masked)
    subgraph = tf.sparse.retain(subgraph, tf.not_equal(subgraph.values, 0))
    subgraph = tf.sign(subgraph)

    if not transformer:
        return subgraph
    subgraph = sp_tensor_to_sp_matrix(subgraph)
    return sp_matrix_to_sp_tensor(transformer(subgraph))
def call(self, inputs):
    """Neighbour aggregation over a sparse filter, concatenated with node features.

    ``inputs`` is indexed as ``[features, fltr]``. ``fltr`` is made sparse if
    needed and self-loops are added to its indices. Source features are
    gathered per edge and aggregated at the targets with
    ``self.aggregate_op``. The result is concatenated with the features,
    projected, L2-normalised and optionally activated.
    """
    features, fltr = inputs[0], inputs[1]

    if not K.is_sparse(fltr):
        fltr = ops.dense_to_sparse(fltr)

    edge_index = fltr.indices
    n_nodes = tf.shape(features, out_type=edge_index.dtype)[0]
    edge_index = ops.sparse_add_self_loops(edge_index, n_nodes)
    targets = edge_index[:, -2]
    sources = edge_index[:, -1]

    neighbourhood = self.aggregate_op(
        tf.gather(features, sources), targets, n_nodes
    )

    output = ops.dot(K.concatenate([features, neighbourhood]), self.kernel)
    if self.use_bias:
        output = K.bias_add(output, self.bias)
    output = K.l2_normalize(output, axis=-1)
    if self.activation is not None:
        output = self.activation(output)
    return output