def call(self, inputs):
    """
    Forward pass of the pooling layer.

    :param inputs: list of 2 or 3 tensors: node features ``X``, adjacency
        ``A``, and optionally the graph-index vector ``I`` (disjoint mode).
        A rank-2 ``I`` is flattened to its first column.
    :return: list ``[X_pooled, A_pooled]``, plus the pooled ``I`` when it
        was given, plus the assignment matrix ``S`` when ``return_mask``
        is set.

    Side effects: adds the MinCut loss and the orthogonality loss to the
    layer via ``self.add_loss``.
    """
    if len(inputs) == 3:
        X, A, I = inputs
        # I may arrive as a column vector; use its first column as the ids.
        if K.ndim(I) == 2:
            I = I[:, 0]
    else:
        X, A = inputs
        I = None

    # Check if the layer is operating in batch mode (X and A have rank 3)
    batch_mode = K.ndim(X) == 3

    # Compute cluster assignment matrix (soft assignment of nodes to
    # self.k clusters, produced by the layer's MLP)
    S = self.mlp(X)

    # MinCut regularization: -Tr(S^T A S) / Tr(S^T D S)
    A_pooled = ops.matmul_at_b_a(S, A)
    num = tf.linalg.trace(A_pooled)
    D = ops.degree_matrix(A)
    # K.epsilon() guards against division by zero (e.g. edgeless graphs)
    den = tf.linalg.trace(ops.matmul_at_b_a(S, D)) + K.epsilon()
    cut_loss = -(num / den)
    if batch_mode:
        # One loss value per graph in the batch; reduce to a scalar
        cut_loss = K.mean(cut_loss)
    self.add_loss(cut_loss)

    # Orthogonality regularization: push normalized S^T S towards I_k
    SS = ops.modal_dot(S, S, transpose_a=True)
    I_S = tf.eye(self.k, dtype=SS.dtype)
    ortho_loss = tf.norm(
        SS / tf.norm(SS, axis=(-1, -2), keepdims=True) - I_S / tf.norm(I_S),
        axis=(-1, -2),
    )
    if batch_mode:
        ortho_loss = K.mean(ortho_loss)
    self.add_loss(ortho_loss)

    # Pooling: X' = S^T X; A' was computed above as S^T A S
    X_pooled = ops.modal_dot(S, X, transpose_a=True)
    A_pooled = tf.linalg.set_diag(
        A_pooled, tf.zeros(K.shape(A_pooled)[:-1], dtype=A_pooled.dtype)
    )  # Remove diagonal
    A_pooled = ops.normalize_A(A_pooled)

    output = [X_pooled, A_pooled]

    if I is not None:
        # Each input graph yields self.k pooled nodes, all tagged with
        # that graph's index (segment_mean of identical ids is the id).
        I_mean = tf.math.segment_mean(I, I)
        I_pooled = ops.repeat(I_mean, tf.ones_like(I_mean) * self.k)
        output.append(I_pooled)

    if self.return_mask:
        output.append(S)

    return output
def segment_top_k(x, i, ratio):
    """
    Return the flat indices of the top-K values of x within each segment.

    K is not fixed: a segment with m elements keeps ceil(ratio * m) of them.

    :param x: a rank 1 Tensor of values;
    :param i: a rank 1 Tensor with segment IDs for x;
    :param ratio: float, fraction of elements to keep for each segment;
    :return: a rank 1 Tensor with the indices of the top-K values of each
        segment in x.
    """
    i = tf.cast(i, tf.int32)
    num_elems = tf.shape(i)[0]
    seg_sizes = tf.math.segment_sum(tf.ones_like(i), i)
    num_segments = tf.shape(seg_sizes)[0]
    max_size = tf.reduce_max(seg_sizes)

    # Starting offset of every segment in the flat input.
    offsets = tf.concat(
        (tf.zeros(1, dtype=seg_sizes.dtype), tf.cumsum(seg_sizes)[:-1]), 0
    )

    # Scatter x into a dense (num_segments, max_size) grid, padding with a
    # very negative value so padded slots never appear among the top K.
    target = (tf.range(num_elems) - tf.gather(offsets, i)) + (i * max_size)
    padded = tf.zeros(num_segments * max_size, dtype=x.dtype) - 1e20
    padded = tf.tensor_scatter_nd_update(padded, target[:, None], x)
    padded = tf.reshape(padded, (num_segments, max_size))

    # Sort every row descending and map row-local positions back to flat
    # indices into x.
    order = tf.argsort(padded, direction="DESCENDING") + offsets[:, None]
    order = tf.reshape(order, (-1,))

    # Keep only the first ceil(ratio * size) sorted entries of each row.
    k = tf.cast(tf.math.ceil(ratio * tf.cast(seg_sizes, tf.float32)), i.dtype)
    local_ranks = tf.ragged.range(k).flat_values
    row_starts = ops.repeat(tf.range(num_segments) * max_size, k)
    return tf.gather(order, local_ranks + row_starts)
def call(self, inputs):
    """
    Forward pass of the pooling layer.

    Note that I is useless, because the layer cannot be used in graph
    batch mode.

    :param inputs: list of 2 or 3 tensors: node features ``X``, adjacency
        ``A`` (dense or sparse), and optionally the graph-index vector ``I``.
    :return: list ``[X_pooled, A_pooled]``, plus the pooled ``I`` when it
        was given, plus the assignment matrix ``S`` when ``return_mask``
        is set.

    Side effects: adds the link-prediction loss and the entropy loss via
    ``self.add_loss``; sets ``self.reduce_loss`` and possibly
    ``self.mixed_mode``.
    """
    if len(inputs) == 3:
        X, A, I = inputs
    else:
        X, A = inputs
        I = None
    N = K.shape(A)[-1]
    # Check if the layer is operating in batch mode (X and A have rank 3)
    mode = ops.autodetect_mode(A, X)
    self.reduce_loss = mode in (ops._modes['M'], ops._modes['B'])

    # Get normalized adjacency with self-loops added
    if K.is_sparse(A):
        I_ = tf.sparse.eye(N, dtype=A.dtype)
        A_ = tf.sparse.add(A, I_)
    else:
        I_ = tf.eye(N, dtype=A.dtype)
        A_ = A + I_
    fltr = ops.normalize_A(A_)

    # Node embeddings
    Z = K.dot(X, self.kernel_emb)
    Z = ops.filter_dot(fltr, Z)
    if self.activation is not None:
        Z = self.activation(Z)

    # Compute cluster assignment matrix
    S = K.dot(X, self.kernel_pool)
    S = ops.filter_dot(fltr, S)
    S = activations.softmax(S, axis=-1)  # softmax applied row-wise

    # Link prediction loss: ||A - S S^T||
    S_gram = ops.matmul_A_BT(S, S)
    if K.is_sparse(A):
        LP_loss = tf.sparse.add(A, -S_gram)  # A/tf.norm(A) - S_gram/tf.norm(S_gram)
    else:
        LP_loss = A - S_gram
    LP_loss = tf.norm(LP_loss, axis=(-1, -2))
    if self.reduce_loss:
        LP_loss = K.mean(LP_loss)
    self.add_loss(LP_loss)

    # Entropy loss: encourages sharp (near one-hot) cluster assignments
    entr = tf.negative(tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1))
    entr_loss = K.mean(entr, axis=-1)
    if self.reduce_loss:
        entr_loss = K.mean(entr_loss)
    self.add_loss(entr_loss)

    # Pooling: X' = S^T Z, A' = S^T A S
    X_pooled = ops.matmul_AT_B(S, Z)
    A_pooled = ops.matmul_AT_B_A(S, A)

    if K.ndim(A_pooled) == 3:
        self.mixed_mode = True

    output = [X_pooled, A_pooled]

    if I is not None:
        # Fixed: tf.segment_mean is a removed TF1 alias; use
        # tf.math.segment_mean (consistent with the other layers).
        I_mean = tf.math.segment_mean(I, I)
        I_pooled = ops.repeat(I_mean, tf.ones_like(I_mean) * self.k)
        output.append(I_pooled)

    if self.return_mask:
        output.append(S)

    return output
def call(self, inputs):
    """
    Forward pass of the pooling layer.

    Note that I is useless, because the layer cannot be used in graph
    batch mode.

    :param inputs: list of 2 or 3 tensors: node features ``X``, adjacency
        ``A``, and optionally the graph-index vector ``I``.
    :return: list ``[X_pooled, A_pooled]``, plus the pooled ``I`` when it
        was given, plus the assignment matrix ``S`` when ``return_mask``
        is set.

    Side effects: adds the MinCut loss and the orthogonality loss via
    ``self.add_loss``.
    """
    if len(inputs) == 3:
        X, A, I = inputs
    else:
        X, A = inputs
        I = None

    # Check if the layer is operating in batch mode (X and A have rank 3)
    batch_mode = K.ndim(A) == 3

    # Optionally compute hidden layer
    if self.h is None:
        Hid = X
    else:
        Hid = K.dot(X, self.kernel_in)
        if self.use_bias:
            Hid = K.bias_add(Hid, self.bias_in)
        if self.activation is not None:
            Hid = self.activation(Hid)

    # Compute cluster assignment matrix
    S = K.dot(Hid, self.kernel_out)
    if self.use_bias:
        S = K.bias_add(S, self.bias_out)
    S = activations.softmax(S, axis=-1)  # Apply softmax to get cluster assignments

    # MinCut regularization: -Tr(S^T A S) / Tr(S^T D S)
    A_pooled = ops.matmul_AT_B_A(S, A)
    # Fixed: tf.trace is a removed TF1 alias; use tf.linalg.trace
    num = tf.linalg.trace(A_pooled)
    D = ops.degree_matrix(A)
    # Fixed: added K.epsilon() to guard against division by zero
    den = tf.linalg.trace(ops.matmul_AT_B_A(S, D)) + K.epsilon()
    cut_loss = -(num / den)
    if batch_mode:
        cut_loss = K.mean(cut_loss)
    self.add_loss(cut_loss)

    # Orthogonality regularization: push normalized S^T S towards I_k
    SS = ops.matmul_AT_B(S, S)
    I_S = tf.eye(self.k, dtype=SS.dtype)
    ortho_loss = tf.norm(
        SS / tf.norm(SS, axis=(-1, -2)) - I_S / tf.norm(I_S), axis=(-1, -2)
    )
    if batch_mode:
        # Fixed: previously reduced cut_loss again instead of ortho_loss
        # (copy-paste bug), so the orthogonality loss kept per-graph shape.
        ortho_loss = K.mean(ortho_loss)
    self.add_loss(ortho_loss)

    # Pooling
    X_pooled = ops.matmul_AT_B(S, X)
    # Fixed: match dtypes explicitly so set_diag cannot fail when A_pooled
    # is not float32.
    A_pooled = tf.linalg.set_diag(
        A_pooled, tf.zeros(K.shape(A_pooled)[:-1], dtype=A_pooled.dtype)
    )  # Remove diagonal
    A_pooled = ops.normalize_A(A_pooled)

    output = [X_pooled, A_pooled]

    if I is not None:
        # Fixed: tf.segment_mean is a removed TF1 alias; use
        # tf.math.segment_mean.
        I_mean = tf.math.segment_mean(I, I)
        I_pooled = ops.repeat(I_mean, tf.ones_like(I_mean) * self.k)
        output.append(I_pooled)

    if self.return_mask:
        output.append(S)

    return output
def reduce_index(self, i, s, **kwargs):
    """
    Pool the graph-index vector: every input graph contributes ``self.k``
    output nodes, each tagged with that graph's index.

    :param i: rank 1 Tensor of segment (graph) IDs;
    :param s: unused here (kept for the common reduce interface);
    :return: rank 1 Tensor with each graph ID repeated ``self.k`` times.
    """
    # segment_mean over identical IDs yields one entry per graph.
    graph_ids = tf.math.segment_mean(i, i)
    return ops.repeat(graph_ids, tf.ones_like(graph_ids) * self.k)