import tensorflow as tf
from tensorflow.keras import backend as K


def sparse_bool_mask(x, mask, axis=0):
    """
    Masks the SparseTensor `x` along the given axis with the boolean `mask`,
    re-numbering the kept indices from 0 and shrinking the dense shape.
    """
    # Only necessary if indices may have non-unique elements
    indices = tf.boolean_mask(tf.range(tf.shape(x)[axis]), mask)
    n_indices = tf.size(indices)
    # Get indices for the axis
    idx = x.indices[:, axis]
    # Find where indices match the selection
    eq = tf.equal(tf.expand_dims(idx, 1), tf.cast(indices, tf.int64))  # TODO this has quadratic cost
    # Mask for selected values
    sel = tf.reduce_any(eq, axis=1)
    # Selected values
    values_new = tf.boolean_mask(x.values, sel, axis=0)
    # New index value for selected elements
    n_indices = tf.cast(n_indices, tf.int64)
    idx_new = tf.reduce_sum(tf.cast(eq, tf.int64) * tf.range(n_indices), axis=1)
    idx_new = tf.boolean_mask(idx_new, sel, axis=0)
    # New full indices tensor
    indices_new = tf.boolean_mask(x.indices, sel, axis=0)
    indices_new = tf.concat([indices_new[:, :axis],
                             tf.expand_dims(idx_new, 1),
                             indices_new[:, axis + 1:]], axis=1)
    # New shape
    shape_new = tf.concat(
        [x.dense_shape[:axis], [n_indices], x.dense_shape[axis + 1:]],
        axis=0)
    return tf.SparseTensor(indices_new, values_new, shape_new)
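
# Hedged usage sketch for sparse_bool_mask() above; the tensors are
# illustrative, not part of the library. Keeping rows {0, 2, 3} of a
# 4 x 3 sparse tensor renumbers the surviving rows to 0, 1, 2:
x = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 0], [3, 2]],
                    values=[1., 2., 3., 4.],
                    dense_shape=[4, 3])
mask = tf.constant([True, False, True, True])
out = sparse_bool_mask(x, mask, axis=0)
# After evaluation (e.g. in a tf.Session() in TF 1.x):
# out.indices     -> [[0, 0], [1, 0], [2, 2]]
# out.values      -> [1., 3., 4.]
# out.dense_shape -> [3, 3]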
def call(self, inputs):
    if len(inputs) == 3:
        X, A, I = inputs
        self.data_mode = 'graph'
    else:
        X, A = inputs
        I = tf.zeros(tf.shape(X)[:1], dtype=tf.int32)
        self.data_mode = 'single'
    if K.ndim(I) == 2:
        I = I[:, 0]
    A_is_sparse = K.is_sparse(A)

    # Get mask
    y = K.dot(X, K.l2_normalize(self.kernel))
    N = K.shape(X)[-2]
    indices = ops.top_k(y[:, 0], I, self.ratio, self.top_k_var)  # ops.top_k is the segment-wise top-K defined below
    mask = tf.scatter_nd(tf.expand_dims(indices, 1), tf.ones_like(indices), (N,))
    mask = tf.cast(mask, tf.bool)  # tf.boolean_mask expects a boolean mask

    # Multiply X and y to make layer differentiable
    features = X * self.gating_op(y)

    axis = 0 if len(K.int_shape(A)) == 2 else 1  # Cannot use negative axis in tf.boolean_mask
    # Reduce X
    X_pooled = tf.boolean_mask(features, mask, axis=axis)

    # Compute A^2
    if A_is_sparse:
        A_dense = tf.sparse.to_dense(A)
    else:
        A_dense = A
    A_squared = K.dot(A, A_dense)

    # Reduce A
    A_pooled = tf.boolean_mask(A_squared, mask, axis=axis)
    A_pooled = tf.boolean_mask(A_pooled, mask, axis=axis + 1)
    if A_is_sparse:
        A_pooled = tf.contrib.layers.dense_to_sparse(A_pooled)

    output = [X_pooled, A_pooled]

    # Reduce I
    if self.data_mode == 'graph':
        I_pooled = tf.boolean_mask(I[:, None], mask)[:, 0]
        output.append(I_pooled)

    if self.return_mask:
        output.append(mask)

    return output
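
# Minimal sketch of the mask-building step in call() above (illustrative
# values, not part of the layer): top_k returns the indices of the kept
# nodes, and scatter_nd turns them into a dense 0/1 vector over all N
# nodes, which tf.boolean_mask can then use to drop rows of X and A.
indices = tf.constant([0, 2, 3])
N = 5
mask = tf.scatter_nd(tf.expand_dims(indices, 1), tf.ones_like(indices), (N,))
mask = tf.cast(mask, tf.bool)  # -> [True, False, True, True, False]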
def upsampling_from_mask(inputs):
    """
    Scatters the pooled node features back to the pre-pooling positions
    indicated by the boolean mask, filling unselected positions with zeros.
    """
    X_, A_, I_, M_ = inputs
    # Selection matrix: one row of the identity for each kept node
    S_ = tf.eye(tf.shape(M_)[0])
    S_ = tf.boolean_mask(S_, M_)
    S_t_ = tf.transpose(S_)
    X_out_ = K.dot(S_t_, X_)
    A_out_ = K.dot(K.transpose(K.dot(A_, S_)), S_)
    I_out_ = K.dot(S_t_, K.cast(I_[:, None], tf.float32))[:, 0]
    I_out_ = K.cast(I_out_, tf.int32)
    return [X_out_, A_out_, I_out_]
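
# Hedged example for upsampling_from_mask() (illustrative values). Two
# pooled nodes are scattered back into the two slots selected by the
# mask; the unselected slots are filled with zeros:
M = tf.constant([True, False, True, False])
X_pool = tf.constant([[1.], [2.]])
A_pool = tf.eye(2)
I_pool = tf.zeros((2,), dtype=tf.int32)
X_up, A_up, I_up = upsampling_from_mask([X_pool, A_pool, I_pool, M])
# X_up -> [[1.], [0.], [2.], [0.]]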
def top_k(scores, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in `scores` segment-wise, with
    segments defined by I. K is not fixed, but it is defined as a ratio of
    the number of elements in each segment.
    :param scores: a rank 1 tensor with scores;
    :param I: a rank 1 tensor with segment IDs;
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable without shape validation (e.g.,
    `tf.Variable(0.0, validate_shape=False)`);
    :return: a rank 1 tensor containing the indices to get the top K values
    of each segment in `scores`.
    """
    num_nodes = tf.segment_sum(tf.ones_like(I), I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of the biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Total number of nodes in batch
    to_keep = tf.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, tf.int32)  # Nodes to keep in each graph

    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)

    y_min = tf.reduce_min(scores)
    dense_y = tf.ones((n_graphs * max_n_nodes,))
    # Subtract 1 to ensure that filler values do not get picked
    dense_y = dense_y * tf.cast(y_min - 1, tf.float32)
    # top_k_var is a variable with unknown shape defined elsewhere
    dense_y = tf.assign(top_k_var, dense_y, validate_shape=False)
    dense_y = tf.scatter_update(dense_y, index, scores)
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))

    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1,))

    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs,))
    rep_times = tf.reshape(
        tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1),
        (-1,))
    mask = tf_repeat_1d(to_rep, rep_times)
    mask = tf.cast(mask, tf.bool)  # tf.boolean_mask expects a boolean mask
    perm = tf.boolean_mask(perm, mask)

    return perm
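
# Hedged usage sketch for top_k() above (TF 1.x graph mode; all values
# are illustrative). Two graphs with 3 and 2 nodes and ratio=0.5 keep
# ceil(1.5) = 2 and ceil(1.0) = 1 nodes respectively:
scores = tf.constant([0.1, 0.9, 0.5, 0.3, 0.7])
I = tf.constant([0, 0, 0, 1, 1])
top_k_var = tf.Variable(0.0, validate_shape=False)  # as in the docstring
perm = top_k(scores, I, ratio=0.5, top_k_var=top_k_var)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(perm))  # -> [1 2 4], i.e. scores 0.9, 0.5 and 0.7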
def tf_repeat_1d(x, repeats):
    """
    Repeats each value `x[i]` a number of times `repeats[i]`.
    :param x: a rank 1 tensor;
    :param repeats: a rank 1 tensor;
    :return: a rank 1 tensor, of shape `(sum(repeats), )`.
    """
    x = tf.expand_dims(x, 1)
    max_repeats = tf.reduce_max(repeats)
    tile_repeats = [1, max_repeats]
    arr_tiled = tf.tile(x, tile_repeats)
    mask = tf.less(tf.range(max_repeats), tf.expand_dims(repeats, 1))
    result = tf.reshape(tf.boolean_mask(arr_tiled, mask), [-1])
    return result
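
# Quick illustrative check of tf_repeat_1d() (values are made up):
x = tf.constant([3, 1, 2])
repeats = tf.constant([2, 0, 3])
out = tf_repeat_1d(x, repeats)
# out -> [3, 3, 2, 2, 2]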