def call(self, inputs):
    """Keep only the entries chosen by ``ops.segment_top_k`` and return them.

    ``inputs`` is unpacked either as ``(X, A, I)`` when it has three
    elements, or as ``(X, A)`` otherwise; in the latter case an all-zero
    ``I`` is built from the leading dimension of ``X``.

    Side effect: ``self.data_mode`` is overwritten on every call with
    ``"disjoint"`` (three inputs) or ``"single"`` (two inputs).

    Returns:
        A list ``[X_pooled, A_pooled]``, followed by the gathered ``I`` when
        the mode is ``"disjoint"``, followed by the selection mask when
        ``self.return_mask`` is truthy.
    """
    three_inputs = len(inputs) == 3
    if three_inputs:
        x, a, seg = inputs
    else:
        x, a = inputs
        seg = tf.zeros(tf.shape(x)[:1])
    self.data_mode = "disjoint" if three_inputs else "single"

    # A rank-2 I is reduced to its first column before the integer cast.
    seg = tf.cast(seg[:, 0] if K.ndim(seg) == 2 else seg, tf.int32)

    a_is_sparse = K.is_sparse(a)

    # Score the entries, then build the selection indices and mask.
    scores = self.compute_scores(x, a, seg)
    n = K.shape(x)[-2]
    # Sorting is required for ordered SparseTensors.
    selected = tf.sort(ops.segment_top_k(scores[:, 0], seg, self.ratio))
    mask = ops.indices_to_mask(selected, n)

    # Multiplying X by the gated scores keeps the layer differentiable.
    gated = x * self.gating_op(scores)

    # Cannot use negative axis in tf.boolean_mask
    gather_axis = 1 if len(K.int_shape(a)) != 2 else 0

    # Reduce X
    x_pooled = tf.gather(gated, selected, axis=gather_axis)

    # Reduce A: along the gather axis first, then along the next one.
    if a_is_sparse:
        a_pooled, _ = ops.gather_sparse_square(a, selected, mask=mask)
    else:
        a_partial = tf.gather(a, selected, axis=gather_axis)
        a_pooled = tf.gather(a_partial, selected, axis=gather_axis + 1)

    output = [x_pooled, a_pooled]

    # Reduce I (only returned in disjoint mode)
    if self.data_mode == "disjoint":
        output.append(tf.gather(seg, selected))

    if self.return_mask:
        output.append(mask)

    return output
def test_indices_to_mask_rank2():
    """``ops.indices_to_mask`` with (row, col) pairs marks those cells of a 3x3 mask."""
    coords = [[0, 2], [1, 1], [2, 1]]
    mask_shape = [3, 3]
    want = [
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 0],
    ]

    got = ops.indices_to_mask(coords, mask_shape).numpy()

    np.testing.assert_equal(got, want)
def test_indices_to_mask_rank1():
    """``ops.indices_to_mask`` with flat indices marks those positions of a length-6 mask."""
    positions = [1, 3, 4]
    want = [0, 1, 0, 1, 1, 0]

    got = ops.indices_to_mask(positions, 6).numpy()

    np.testing.assert_equal(got, want)