def call(self, inputs):
    if len(inputs) == 3:
        X, A, I = inputs
        if K.ndim(I) == 2:
            I = I[:, 0]
    else:
        X, A = inputs
        I = None

    # Check if the layer is operating in batch mode (X and A have rank 3)
    batch_mode = K.ndim(X) == 3

    # Compute cluster assignment matrix
    S = self.mlp(X)

    # MinCut regularization
    A_pooled = ops.matmul_at_b_a(S, A)
    num = tf.linalg.trace(A_pooled)
    D = ops.degree_matrix(A)
    den = tf.linalg.trace(ops.matmul_at_b_a(S, D)) + K.epsilon()
    cut_loss = -(num / den)
    if batch_mode:
        cut_loss = K.mean(cut_loss)
    self.add_loss(cut_loss)

    # Orthogonality regularization
    SS = ops.modal_dot(S, S, transpose_a=True)
    I_S = tf.eye(self.k, dtype=SS.dtype)
    ortho_loss = tf.norm(
        SS / tf.norm(SS, axis=(-1, -2), keepdims=True) - I_S / tf.norm(I_S),
        axis=(-1, -2),
    )
    if batch_mode:
        ortho_loss = K.mean(ortho_loss)
    self.add_loss(ortho_loss)

    # Pooling
    X_pooled = ops.modal_dot(S, X, transpose_a=True)
    A_pooled = tf.linalg.set_diag(
        A_pooled, tf.zeros(K.shape(A_pooled)[:-1], dtype=A_pooled.dtype)
    )  # Remove diagonal
    A_pooled = ops.normalize_A(A_pooled)

    output = [X_pooled, A_pooled]

    if I is not None:
        I_mean = tf.math.segment_mean(I, I)
        I_pooled = ops.repeat(I_mean, tf.ones_like(I_mean) * self.k)
        output.append(I_pooled)

    if self.return_mask:
        output.append(S)

    return output
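# --- Sketch: the two MinCut regularizers on a dense toy graph ---
# A minimal, self-contained reproduction of the cut and orthogonality
# losses above in plain TensorFlow. It assumes ops.matmul_at_b_a(S, A)
# computes S^T A S and ops.degree_matrix(A) returns diag(A @ 1); every
# name below is illustrative, not part of the layer's API.
import tensorflow as tf

def mincut_regularizers_sketch(num_nodes=6, k=2, eps=1e-7):
    A = tf.random.uniform((num_nodes, num_nodes))
    A = (A + tf.transpose(A)) / 2.0                   # symmetric adjacency
    S = tf.nn.softmax(tf.random.normal((num_nodes, k)), axis=-1)

    A_pooled = tf.transpose(S) @ A @ S                # S^T A S
    D = tf.linalg.diag(tf.reduce_sum(A, axis=-1))     # degree matrix
    cut_loss = -(tf.linalg.trace(A_pooled)
                 / (tf.linalg.trace(tf.transpose(S) @ D @ S) + eps))

    SS = tf.transpose(S) @ S
    I_S = tf.eye(k)
    ortho_loss = tf.norm(SS / tf.norm(SS) - I_S / tf.norm(I_S))
    return cut_loss, ortho_loss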
def call(self, inputs):
    X, A = inputs
    N = K.shape(A)[-1]

    # Check if the layer is operating in mixed or batch mode
    mode = ops.autodetect_mode(X, A)
    self.reduce_loss = mode in (modes.MIXED, modes.BATCH)

    # Get normalized adjacency
    if K.is_sparse(A):
        I_ = tf.sparse.eye(N, dtype=A.dtype)
        A_ = tf.sparse.add(A, I_)
    else:
        I_ = tf.eye(N, dtype=A.dtype)
        A_ = A + I_
    fltr = ops.normalize_A(A_)

    # Node embeddings
    Z = K.dot(X, self.kernel_emb)
    Z = ops.modal_dot(fltr, Z)
    if self.activation is not None:
        Z = self.activation(Z)

    # Compute cluster assignment matrix
    S = K.dot(X, self.kernel_pool)
    S = ops.modal_dot(fltr, S)
    S = activations.softmax(S, axis=-1)  # softmax applied row-wise

    # Link prediction loss
    S_gram = ops.modal_dot(S, S, transpose_b=True)
    if mode == modes.MIXED:
        A = tf.sparse.to_dense(A)[None, ...]
    if K.is_sparse(A):
        LP_loss = tf.sparse.add(A, -S_gram)  # A/tf.norm(A) - S_gram/tf.norm(S_gram)
    else:
        LP_loss = A - S_gram
    LP_loss = tf.norm(LP_loss, axis=(-1, -2))
    if self.reduce_loss:
        LP_loss = K.mean(LP_loss)
    self.add_loss(LP_loss)

    # Entropy loss
    entr = tf.negative(
        tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1)
    )
    entr_loss = K.mean(entr, axis=-1)
    if self.reduce_loss:
        entr_loss = K.mean(entr_loss)
    self.add_loss(entr_loss)

    # Pooling
    X_pooled = ops.modal_dot(S, Z, transpose_a=True)
    A_pooled = ops.matmul_at_b_a(S, A)

    output = [X_pooled, A_pooled]

    if self.return_mask:
        output.append(S)

    return output
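# --- Sketch: DiffPool's auxiliary losses on a dense toy graph ---
# A self-contained check of the link-prediction and entropy terms above,
# written in plain TensorFlow for the dense single-graph case. All names
# here are illustrative; in the layer, A may be sparse or batched.
import tensorflow as tf

def diffpool_losses_sketch(num_nodes=6, k=2, eps=1e-7):
    A = tf.random.uniform((num_nodes, num_nodes))
    A = (A + tf.transpose(A)) / 2.0
    S = tf.nn.softmax(tf.random.normal((num_nodes, k)), axis=-1)

    # ||A - S S^T||_F: penalizes assignments that disagree with the graph
    lp_loss = tf.norm(A - S @ tf.transpose(S))
    # Mean per-node entropy: pushes each row of S toward a one-hot vector
    entr_loss = tf.reduce_mean(
        -tf.reduce_sum(S * tf.math.log(S + eps), axis=-1))
    return lp_loss, entr_loss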
def connect(self, a, s, **kwargs):
    a_pool = ops.matmul_at_b_a(s, a)

    # Post-processing of A
    a_pool = tf.linalg.set_diag(
        a_pool, tf.zeros(K.shape(a_pool)[:-1], dtype=a_pool.dtype)
    )
    a_pool = ops.normalize_A(a_pool)

    return a_pool
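# --- Sketch: the post-processing step with plain TensorFlow ---
# This assumes ops.normalize_A applies the symmetric normalization
# D^{-1/2} A D^{-1/2}; check the library's docs, as that is an assumption
# here. The helper name below is hypothetical.
import tensorflow as tf

def postprocess_sketch(a_pool, eps=1e-12):
    # Zero the diagonal to remove self-loops introduced by S^T A S
    a_pool = tf.linalg.set_diag(
        a_pool, tf.zeros(tf.shape(a_pool)[:-1], dtype=a_pool.dtype))
    # Symmetric degree normalization: D^{-1/2} A D^{-1/2}
    d_inv_sqrt = tf.math.rsqrt(
        tf.maximum(tf.reduce_sum(a_pool, axis=-1), eps))
    return a_pool * d_inv_sqrt[..., None] * d_inv_sqrt[..., None, :]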
def connect(self, a, s, **kwargs):
    a_pool = ops.matmul_at_b_a(s, a)

    # Modularity loss
    mod_loss = self.modularity_loss(a, s, a_pool)
    if K.ndim(a) == 3:
        mod_loss = K.mean(mod_loss)
    self.add_loss(mod_loss)

    return a_pool
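# --- Sketch: a spectral modularity loss ---
# self.modularity_loss is not shown above. A standard formulation (as in
# the DMoN paper, "Graph Clustering with Graph Neural Networks") is
# sketched here for the dense single-graph case; it is an assumption that
# the layer's helper matches this, and the names below are illustrative.
import tensorflow as tf

def modularity_loss_sketch(a, s):
    d = tf.reduce_sum(a, axis=-1, keepdims=True)  # degree vector, shape (N, 1)
    two_m = tf.reduce_sum(d)                      # 2 * total edge weight
    b = a - (d @ tf.transpose(d)) / two_m         # modularity matrix B
    q = tf.linalg.trace(tf.transpose(s) @ b @ s) / two_m
    return -q                                     # minimize -Q to maximize modularity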
def connect(self, a, s, **kwargs):
    a_pool = ops.matmul_at_b_a(s, a)

    # MinCut loss
    cut_loss = self.mincut_loss(a, s, a_pool)
    if K.ndim(a) == 3:
        cut_loss = K.mean(cut_loss)
    self.add_loss(cut_loss)

    # Post-processing of A
    a_pool = tf.linalg.set_diag(
        a_pool, tf.zeros(K.shape(a_pool)[:-1], dtype=a_pool.dtype)
    )
    a_pool = ops.normalize_A(a_pool)

    return a_pool
def connect(self, a, s, **kwargs):
    # Coarsen the adjacency matrix: a_pool = S^T A S
    return ops.matmul_at_b_a(s, a)
def mincut_loss(a, s, a_pool):
    # Numerator: total intra-cluster edge weight, trace(S^T A S)
    num = tf.linalg.trace(a_pool)
    # Denominator: trace(S^T D S); epsilon guards against division by
    # zero on empty graphs, matching the inline version of this loss above
    d = ops.degree_matrix(a)
    den = tf.linalg.trace(ops.matmul_at_b_a(s, d)) + K.epsilon()
    cut_loss = -(num / den)
    return cut_loss
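# --- Sanity check (sketch): the loss attains its minimum of -1 when no
# edges are cut. For two disconnected 3-cliques with a hard assignment of
# each clique to its own cluster, trace(S^T A S) == trace(S^T D S), so the
# loss is exactly -1. Plain TensorFlow, illustrative names only.
import tensorflow as tf

block = tf.ones((3, 3)) - tf.eye(3)             # 3-clique, no self-loops
zeros = tf.zeros((3, 3))
A = tf.concat([tf.concat([block, zeros], axis=1),
               tf.concat([zeros, block], axis=1)], axis=0)
S = tf.constant([[1.0, 0.0]] * 3 + [[0.0, 1.0]] * 3)
D = tf.linalg.diag(tf.reduce_sum(A, axis=-1))

a_pool = tf.transpose(S) @ A @ S
loss = -(tf.linalg.trace(a_pool)
         / tf.linalg.trace(tf.transpose(S) @ D @ S))
print(float(loss))                              # -> -1.0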