def modularity_loss(self, a, s, a_pool):
    """
    Compute a modularity-style loss from an adjacency matrix and assignments.

    :param a: adjacency matrix, either a SparseTensor or a dense Tensor;
    :param s: assignment matrix; a rank-3 `s` triggers the batched branch;
    :param a_pool: matrix whose trace enters the loss (presumably the pooled
        adjacency -- confirm against the caller);
    :return: `-trace(a_pool - normalizer) / 2 / n_edges`.
    """
    if K.is_sparse(a):
        # Number of stored values in the sparse matrix. tf.shape is used
        # instead of len() because len() requires a statically known length
        # and fails on symbolic tensors.
        n_edges = tf.cast(tf.shape(a.values)[0], dtype=s.dtype)
        degrees = tf.sparse.reduce_sum(a, axis=-1)
        degrees = tf.reshape(degrees, (-1, 1))
    else:
        n_edges = tf.cast(
            tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype
        )
        degrees = tf.reduce_sum(a, axis=-1, keepdims=True)

    normalizer_left = tf.matmul(s, degrees, transpose_a=True)
    normalizer_right = tf.matmul(degrees, s, transpose_a=True)

    if K.ndim(s) == 3:
        # Reshape n_edges to (batch, 1, 1) so that it broadcasts against the
        # per-sample normalizer matrices.
        normalizer = (
            ops.modal_dot(normalizer_left, normalizer_right)
            / 2
            / tf.reshape(n_edges, [tf.shape(n_edges)[0]] + [1] * 2)
        )
    else:
        normalizer = (
            ops.modal_dot(normalizer_left, normalizer_right) / 2 / n_edges
        )

    loss = -tf.linalg.trace(a_pool - normalizer) / 2 / n_edges
    return loss
def _propagate_edges(self, inputs, mask=None):
    """
    Runs the propagation step for the edge features.

    :param inputs: All the inputs to the layer, unpacked as
        (node features, (ignored, laplacian, incidence), edge features).
    :param mask: The mask forwarded to the bias/activation step.
    :return: The propagated edge features.
    """
    node_features, (_, laplacian, incidence), edge_features = inputs

    # Project the node features, drop the trailing 1-dimension and put the
    # result on the diagonal of a matrix.
    projected = tf.matmul(node_features, self.node_weights)
    node_diag = tf.linalg.diag(tf.squeeze(projected, axis=[-1]))

    # Multiply by the incidence matrix on both sides (first product uses
    # transpose_a=True).
    left_product = ops.modal_dot(incidence, node_diag, transpose_a=True)
    both_sides = ops.modal_dot(left_product, incidence)

    edge_adjacency = both_sides * laplacian

    propagated = ops.modal_dot(edge_adjacency, edge_features)
    propagated = ops.modal_dot(propagated, self.edge_kernel)
    return self._bias_and_activation(
        propagated, bias_weights=self.edge_bias, mask=mask
    )
def call(self, inputs):
    """
    Embed the nodes, compute a soft cluster assignment and pool the graph.

    Two auxiliary losses are registered through `self.add_loss`: a link
    prediction loss (norm of `A - S S^T`) and an entropy loss on the rows
    of `S`. `self.reduce_loss` is set as a side effect.

    :param inputs: pair `(X, A)` of node features and adjacency matrix;
    :return: list `[X_pooled, A_pooled]`, with `S` appended when
        `self.return_mask` is set.
    """
    X, A = inputs
    N = K.shape(A)[-1]

    # Check if the layer is operating in mixed or batch mode
    mode = ops.autodetect_mode(X, A)
    self.reduce_loss = mode in (modes.MIXED, modes.BATCH)

    # Get normalized adjacency (with the identity added first)
    if K.is_sparse(A):
        I_ = tf.sparse.eye(N, dtype=A.dtype)
        A_ = tf.sparse.add(A, I_)
    else:
        I_ = tf.eye(N, dtype=A.dtype)
        A_ = A + I_
    fltr = ops.normalize_A(A_)

    # Node embeddings
    Z = K.dot(X, self.kernel_emb)
    Z = ops.modal_dot(fltr, Z)
    if self.activation is not None:
        Z = self.activation(Z)

    # Compute cluster assignment matrix
    S = K.dot(X, self.kernel_pool)
    S = ops.modal_dot(fltr, S)
    S = activations.softmax(S, axis=-1)  # softmax applied row-wise

    # Link prediction loss
    S_gram = ops.modal_dot(S, S, transpose_b=True)
    if mode == modes.MIXED:
        # Only densify when A is actually sparse: tf.sparse.to_dense
        # rejects dense tensors. A leading axis is then added so that A
        # lines up with S_gram.
        if K.is_sparse(A):
            A = tf.sparse.to_dense(A)
        A = A[None, ...]
    if K.is_sparse(A):
        LP_loss = tf.sparse.add(A, -S_gram)
    else:
        LP_loss = A - S_gram
    LP_loss = tf.norm(LP_loss, axis=(-1, -2))
    if self.reduce_loss:
        LP_loss = K.mean(LP_loss)
    self.add_loss(LP_loss)

    # Entropy loss
    entr = tf.negative(
        tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1)
    )
    entr_loss = K.mean(entr, axis=-1)
    if self.reduce_loss:
        entr_loss = K.mean(entr_loss)
    self.add_loss(entr_loss)

    # Pooling
    X_pooled = ops.modal_dot(S, Z, transpose_a=True)
    A_pooled = ops.matmul_at_b_a(S, A)

    output = [X_pooled, A_pooled]
    if self.return_mask:
        output.append(S)

    return output
def call(self, inputs):
    """
    Pool a graph with an MLP-based assignment and two regularization losses.

    A cut loss and an orthogonality loss are registered, in that order,
    through `self.add_loss`.

    :param inputs: `(X, A)` or `(X, A, I)`; a rank-2 `I` is reduced to its
        first column;
    :return: list `[X_pooled, A_pooled]`, followed by `I_pooled` when `I`
        was given and by `S` when `self.return_mask` is set.
    """
    if len(inputs) == 3:
        x, a, i = inputs
        if K.ndim(i) == 2:
            i = i[:, 0]
    else:
        x, a = inputs
        i = None

    # Batch mode is detected from the rank of the node features
    is_batch = K.ndim(x) == 3

    # Cluster assignment matrix
    s = self.mlp(x)

    # Cut regularization: negative ratio of two traces
    a_pooled = ops.matmul_at_b_a(s, a)
    numerator = tf.linalg.trace(a_pooled)
    degree = ops.degree_matrix(a)
    denominator = tf.linalg.trace(ops.matmul_at_b_a(s, degree)) + K.epsilon()
    cut_loss = -(numerator / denominator)
    if is_batch:
        cut_loss = K.mean(cut_loss)
    self.add_loss(cut_loss)

    # Orthogonality regularization: distance between the normalized
    # S^T S and the normalized identity
    sts = ops.modal_dot(s, s, transpose_a=True)
    eye_k = tf.eye(self.k, dtype=sts.dtype)
    sts_normalized = sts / tf.norm(sts, axis=(-1, -2), keepdims=True)
    eye_normalized = eye_k / tf.norm(eye_k)
    ortho_loss = tf.norm(sts_normalized - eye_normalized, axis=(-1, -2))
    if is_batch:
        ortho_loss = K.mean(ortho_loss)
    self.add_loss(ortho_loss)

    # Pooling
    x_pooled = ops.modal_dot(s, x, transpose_a=True)
    zero_diag = tf.zeros(K.shape(a_pooled)[:-1], dtype=a_pooled.dtype)
    a_pooled = tf.linalg.set_diag(a_pooled, zero_diag)  # Remove diagonal
    a_pooled = ops.normalize_A(a_pooled)

    output = [x_pooled, a_pooled]

    if i is not None:
        i_mean = tf.math.segment_mean(i, i)
        i_pooled = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k)
        output.append(i_pooled)

    if self.return_mask:
        output.append(s)

    return output
def gcs(self, inputs, stack, iteration):
    """
    Applies one graph convolution with a skip connection.

    :param inputs: list of input Tensors, namely
        - input node features
        - input node features for the skip connection
        - normalized adjacency matrix;
    :param stack: int, current stack (used to retrieve kernels);
    :param iteration: int, current iteration (used to retrieve kernels);
    :return: output node features.
    """
    x, x_skip, a = inputs

    # With shared weights, every iteration from the first onwards reads the
    # kernels stored at index 1.
    if self.share_weights and iteration >= 1:
        kernel_index = 1
    else:
        kernel_index = iteration
    kernel_1, kernel_2, bias = self.kernels[stack][kernel_index]

    propagated = ops.modal_dot(a, K.dot(x, kernel_1))
    skip_term = self.dropout(K.dot(x_skip, kernel_2))
    result = propagated + skip_term

    if self.use_bias:
        result = K.bias_add(result, bias)
    return self.gcn_activation(result)
def link_prediction_loss(a, s):
    """
    Norm of the difference between `a` and the Gram matrix of `s`.

    :param a: matrix to compare against, either sparse or dense;
    :param s: matrix whose product with its own transpose is subtracted;
    :return: `tf.norm` of `a - s s^T` taken over the last two axes.
    """
    gram = ops.modal_dot(s, s, transpose_b=True)
    residual = tf.sparse.add(a, -gram) if K.is_sparse(a) else a - gram
    return tf.norm(residual, axis=(-1, -2))
def orthogonality_loss(self, s):
    """
    Distance between the normalized `s^T s` and the normalized identity.

    :param s: matrix whose `s^T s` product is compared with the identity
        of size `self.k`;
    :return: norm, over the last two axes, of the difference between the
        two normalized matrices.
    """
    gram = ops.modal_dot(s, s, transpose_a=True)
    identity = tf.eye(self.k, dtype=gram.dtype)

    gram_unit = gram / tf.norm(gram, axis=(-1, -2), keepdims=True)
    identity_unit = identity / tf.norm(identity)

    return tf.norm(gram_unit - identity_unit, axis=(-1, -2))
def balance_loss(self, s):
    """
    Negative trace of the element-wise square root of `s^T s`.

    When `self.normalized_loss` is set, the value is divided by
    `sqrt(n * c)`, where `n` and `c` are the sizes of the last two axes of
    `s`.

    :param s: input matrix (the last two axes are used);
    :return: the (optionally normalized) loss.
    """
    ss = ops.modal_dot(s, s, transpose_a=True)
    loss = -tf.linalg.trace(tf.math.sqrt(ss))

    if self.normalized_loss:
        # tf.cast keeps the sizes as tensors, so this also works when the
        # shape is only known at run time; float() requires a concrete
        # eager value.
        dims = tf.shape(s, out_type=tf.int32)
        n = tf.cast(dims[-2], loss.dtype)
        c = tf.cast(dims[-1], loss.dtype)
        loss = loss / tf.math.sqrt(n * c)

    return loss
def call(self, inputs):
    """
    Run the MLP, then repeatedly mix propagated features with its output.

    Each of the `self.propagations` steps computes
    `(1 - alpha) * (a . z) + alpha * mlp_out`.

    :param inputs: pair `(x, a)`;
    :return: the activation applied to the final propagated features.
    """
    x, a = inputs
    mlp_out = self.mlp(x)

    propagated = mlp_out
    for _ in range(self.propagations):
        neighbour_term = (1 - self.alpha) * ops.modal_dot(a, propagated)
        propagated = neighbour_term + self.alpha * mlp_out

    return self.activation(propagated)
def call(self, inputs):
    """
    Sum of `self.K` polynomial terms, each multiplied by its own kernel.

    The terms follow the recurrence `T_k = 2 * (a . T_{k-1}) - T_{k-2}`,
    starting from `T_0 = x` and `T_1 = a . x`.

    :param inputs: pair `(x, a)`;
    :return: activation of the accumulated output (plus bias if enabled).
    """
    x, a = inputs

    previous = x
    output = KB.dot(previous, self.kernel[0])

    if self.K > 1:
        current = ops.modal_dot(a, x)
        output += KB.dot(current, self.kernel[1])

    # For self.K <= 2 the range is empty, so `current` is never read when
    # it has not been defined.
    for order in range(2, self.K):
        following = 2 * ops.modal_dot(a, current) - previous
        output += KB.dot(following, self.kernel[order])
        previous, current = current, following

    if self.use_bias:
        output = KB.bias_add(output, self.bias)
    return self.activation(output)
def call(self, inputs):
    """
    Multiply the features by the kernel, propagate through `a`, activate.

    :param inputs: pair `(x, a)`;
    :return: activation of `a . (x . kernel)`, with the bias added first
        when `self.use_bias` is set.
    """
    x, a = inputs
    propagated = ops.modal_dot(a, K.dot(x, self.kernel))
    if self.use_bias:
        propagated = K.bias_add(propagated, self.bias)
    return self.activation(propagated)
def call(self, inputs):
    """
    Propagated features plus a skip term computed with a second kernel.

    :param inputs: pair `(x, a)`;
    :return: `a . (x . kernel_1) + x . kernel_2`, with optional bias and
        optional activation applied.
    """
    x, a = inputs

    propagated = ops.modal_dot(a, K.dot(x, self.kernel_1))
    skip_term = K.dot(x, self.kernel_2)
    result = propagated + skip_term

    if self.use_bias:
        result = K.bias_add(result, self.bias)
    if self.activation is None:
        return result
    return self.activation(result)
def call(self, inputs, mask=None):
    """
    Multiply by the kernel, propagate through `a`, mask and activate.

    :param inputs: pair `(x, a)`;
    :param mask: optional sequence whose first element multiplies the
        output before the activation;
    :return: the activated output.
    """
    x, a = inputs

    result = ops.modal_dot(a, K.dot(x, self.kernel))
    if self.use_bias:
        result = K.bias_add(result, self.bias)
    if mask is not None:
        result = result * mask[0]

    return self.activation(result)
def call(self, inputs, mask=None):
    """
    Run the MLP, mix propagated features with its output, then mask.

    Each of the `self.propagations` steps computes
    `(1 - alpha) * (a . z) + alpha * mlp_out`.

    :param inputs: pair `(x, a)`;
    :param mask: optional sequence whose first element multiplies the
        output before the activation;
    :return: the activated output.
    """
    x, a = inputs
    mlp_out = self.mlp(x)

    state = mlp_out
    for _ in range(self.propagations):
        neighbour_term = (1 - self.alpha) * ops.modal_dot(a, state)
        state = neighbour_term + self.alpha * mlp_out

    if mask is not None:
        state = state * mask[0]
    return self.activation(state)
def call(self, inputs, **kwargs):
    """
    Select leader nodes and pool the graph around them.

    A node is kept as a leader when its score (norm of the Laplacian
    applied to the features) is greater than or equal to the score of
    every column node it is paired with in `a.indices`.

    :param inputs: layer inputs, unpacked by `self.get_inputs` into
        `(x, a, i)`; `a` must expose an `indices` attribute;
    :return: the result of `self.pool` called with the leader mask.
    """
    x, a, i = self.get_inputs(inputs)

    # Per-node score
    score = tf.norm(ops.modal_dot(laplacian(a), x), axis=-1, keepdims=1)

    # Compare the score at both ends of every stored entry
    rows = a.indices[:, 0]
    cols = a.indices[:, 1]
    row_wins = tf.cast(
        tf.gather(score, rows) >= tf.gather(score, cols), tf.int32
    )

    # Product of the per-entry flags, grouped by row index
    leader_mask = ops.scatter_prod(row_wins[:, 0], rows, self.n_nodes)
    leader_mask = tf.cast(leader_mask, tf.bool)

    return self.pool(x, a, i, leader_mask=leader_mask)
def select(self, x, a, i, fltr=None, mask=None):
    """
    Compute the soft assignment matrix and register the auxiliary losses.

    The link prediction loss and the entropy loss are added, in that
    order, through `self.add_loss`; both are averaged first when `x` has
    rank 3.

    :param x: node features;
    :param a: adjacency matrix, passed to the link prediction loss;
    :param i: not used by this method;
    :param fltr: filter applied to `x . kernel_pool`;
    :param mask: optional sequence whose first element multiplies the
        assignments;
    :return: the assignment matrix.
    """
    assignments = ops.modal_dot(fltr, K.dot(x, self.kernel_pool))
    assignments = activations.softmax(assignments, axis=-1)
    if mask is not None:
        assignments = assignments * mask[0]

    # Auxiliary losses
    lp_loss = self.link_prediction_loss(a, assignments)
    entr_loss = self.entropy_loss(assignments)
    if K.ndim(x) == 3:
        lp_loss = K.mean(lp_loss)
        entr_loss = K.mean(entr_loss)
    self.add_loss(lp_loss)
    self.add_loss(entr_loss)

    return assignments
def call(self, inputs):
    """
    Score the nodes with a one-step graph filter and pool on that score.

    :param inputs: layer inputs, unpacked by `self.get_inputs` into
        `(x, a, i)`;
    :return: the output of `self.pool`, with the score `y` appended when
        `self.return_score` is set.
    """
    x, a, i = self.get_inputs(inputs)

    # Graph filter: normalized adjacency with the identity added
    if K.is_sparse(a):
        identity = tf.sparse.eye(self.n_nodes, dtype=a.dtype)
        a_with_identity = tf.sparse.add(a, identity)
    else:
        identity = tf.eye(self.n_nodes, dtype=a.dtype)
        a_with_identity = a + identity
    fltr = ops.normalize_A(a_with_identity)

    y = ops.modal_dot(fltr, K.dot(x, self.kernel))
    pooled = self.pool(x, a, i, y=y)

    if self.return_score:
        pooled.append(y)
    return pooled
def reduce(self, x, s, fltr=None):
    """
    Embed the nodes and reduce them with the assignment matrix.

    :param x: node features;
    :param s: assignment matrix, applied with `transpose_a=True`;
    :param fltr: filter applied to `x . kernel_emb`;
    :return: `s^T . activation(fltr . (x . kernel_emb))`.
    """
    embedded = K.dot(x, self.kernel_emb)
    embedded = ops.modal_dot(fltr, embedded)
    embedded = self.activation(embedded)
    pooled = ops.modal_dot(s, embedded, transpose_a=True)
    return pooled
def reduce(self, x, s, **kwargs):
    """
    Reduce the node features with the assignment matrix.

    :param x: node features;
    :param s: assignment matrix, applied with `transpose_a=True`;
    :param kwargs: accepted but ignored;
    :return: `s^T . x`.
    """
    pooled = ops.modal_dot(s, x, transpose_a=True)
    return pooled
def call(self, inputs, **kwargs):
    """
    Apply the MLP to the eps-weighted features plus the propagated features.

    :param inputs: pair `(x, a)`;
    :param kwargs: accepted but ignored;
    :return: `mlp((one + eps) * x + a . x)`.
    """
    x, a = inputs
    self_term = (self.one + self.eps) * x
    neighbour_term = ops.modal_dot(a, x)
    return self.mlp(self_term + neighbour_term)