def test_segment_top_k():
    """Check ``ops.segment_top_k`` on two segments with a keep-ratio of 0.5.

    Segment 0 owns positions 0-3 and segment 1 owns positions 4-5; the
    indices expected back are 1 and 2 (from segment 0) and 5 (from
    segment 1).
    """
    scores = np.array(
        [0.2, 0.5, 0.3, -0.1, -0.2, -0.1],
        dtype=np.float32,
    )
    segment_ids = np.array([0, 0, 0, 0, 1, 1], dtype=np.int64)

    selected = ops.segment_top_k(scores, segment_ids, 0.5).numpy()

    np.testing.assert_equal(selected, [1, 2, 5])
def call(self, inputs):
    """Pool node features and adjacency by keeping the top-scoring nodes.

    ``inputs`` is either ``[X, A, I]`` or ``[X, A]``. With three entries
    ``self.data_mode`` is set to ``'graph'`` and ``I`` is used as the
    segment index of every node; with two entries it is set to
    ``'single'`` and an all-zero int32 ``I`` is created, i.e. every node
    belongs to segment 0.

    Returns a list ``[X_pooled, A_pooled]``, followed by the pooled ``I``
    when the data mode is ``'graph'`` and by the selection mask when
    ``self.return_mask`` is truthy.
    """
    with_segments = len(inputs) == 3
    if not with_segments:
        X, A = inputs
        # No segment ids supplied: put every node in segment 0.
        I = tf.zeros(tf.shape(X)[:1], dtype=tf.int32)
        self.data_mode = 'single'
    else:
        X, A, I = inputs
        self.data_mode = 'graph'

    # A two-dimensional I is reduced to its first column.
    if K.ndim(I) == 2:
        I = I[:, 0]

    sparse_adjacency = K.is_sparse(A)

    # Score the nodes: project X with the layer kernel, then propagate
    # the projection through A.
    scores = filter_dot(A, K.dot(X, self.kernel))
    n_nodes = K.shape(X)[-2]

    # Indices kept in each segment, turned into a length-n_nodes mask that
    # holds ones at the kept positions.
    kept = ops.segment_top_k(scores[:, 0], I, self.ratio, self.top_k_var)
    mask = tf.scatter_nd(
        tf.expand_dims(kept, 1),
        tf.ones_like(kept),
        (n_nodes, ),
    )

    # Gating X by the scores keeps the scoring path differentiable.
    gated = X * self.gating_op(scores)

    # tf.boolean_mask does not accept a negative axis, so the node axis is
    # derived from the static rank of A.
    if len(K.int_shape(A)) == 2:
        node_axis = 0
    else:
        node_axis = 1

    # Reduce X.
    X_pooled = tf.boolean_mask(gated, mask, axis=node_axis)

    # Compute A^2; the right-hand operand is always dense.
    dense_adjacency = tf.sparse.to_dense(A) if sparse_adjacency else A
    squared = K.dot(A, dense_adjacency)

    # Reduce A along the node axis and then along the axis after it.
    A_pooled = tf.boolean_mask(squared, mask, axis=node_axis)
    A_pooled = tf.boolean_mask(A_pooled, mask, axis=node_axis + 1)
    if sparse_adjacency:
        # Return the same representation that was passed in.
        A_pooled = tf.contrib.layers.dense_to_sparse(A_pooled)

    pooled = [X_pooled, A_pooled]

    # Reduce I, only when segment ids were part of the inputs.
    if self.data_mode == 'graph':
        pooled.append(tf.boolean_mask(I[:, None], mask)[:, 0])

    if self.return_mask:
        pooled.append(mask)

    return pooled
def call(self, inputs):
    """Pool node features and adjacency by gathering the top-scoring nodes.

    ``inputs`` is either ``[X, A, I]`` or ``[X, A]``. With three entries
    ``self.data_mode`` is set to ``"disjoint"`` and ``I`` is used as the
    segment index of every node; with two entries it is set to
    ``"single"`` and an all-zero ``I`` is created. ``I`` is cast to int32
    in both cases.

    Returns a list ``[X_pooled, A_pooled]``, followed by the pooled ``I``
    when the data mode is ``"disjoint"`` and by the selection mask when
    ``self.return_mask`` is truthy.
    """
    with_segments = len(inputs) == 3
    if not with_segments:
        X, A = inputs
        # No segment ids supplied: put every node in segment 0.
        I = tf.zeros(tf.shape(X)[:1])
        self.data_mode = "single"
    else:
        X, A, I = inputs
        self.data_mode = "disjoint"

    # A two-dimensional I is reduced to its first column.
    if K.ndim(I) == 2:
        I = I[:, 0]
    I = tf.cast(I, tf.int32)

    sparse_adjacency = K.is_sparse(A)

    # Score the nodes and pick the indices to keep in each segment.
    scores = self.compute_scores(X, A, I)
    n_nodes = K.shape(X)[-2]
    kept = ops.segment_top_k(scores[:, 0], I, self.ratio)
    kept = tf.sort(kept)  # required for ordered SparseTensors
    mask = ops.indices_to_mask(kept, n_nodes)

    # Gating X by the scores keeps the scoring path differentiable.
    gated = X * self.gating_op(scores)

    # The node axis is derived from the static rank of A.
    if len(K.int_shape(A)) == 2:
        node_axis = 0
    else:
        node_axis = 1

    # Reduce X.
    X_pooled = tf.gather(gated, kept, axis=node_axis)

    # Reduce A: the sparse case goes through the dedicated helper, the
    # dense case gathers along the node axis and then the axis after it.
    if not sparse_adjacency:
        A_pooled = tf.gather(A, kept, axis=node_axis)
        A_pooled = tf.gather(A_pooled, kept, axis=node_axis + 1)
    else:
        A_pooled, _unused = ops.gather_sparse_square(A, kept, mask=mask)

    pooled = [X_pooled, A_pooled]

    # Reduce I, only when segment ids were part of the inputs.
    if self.data_mode == "disjoint":
        pooled.append(tf.gather(I, kept))

    if self.return_mask:
        pooled.append(mask)

    return pooled