Beispiel #1
0
def test_gather_sparse_square():
    """Check ``ops.gather_sparse_square`` against a dense row-then-column gather."""
    sparse_matrix = random_sparse((5, 5), 15, seed=0)
    keep = np.array([1, 3, 4], dtype=np.int64)

    # Reference result: densify, then select the kept rows and the kept columns.
    expected = tf.sparse.to_dense(sparse_matrix)
    expected = tf.gather(expected, keep, axis=0)
    expected = tf.gather(expected, keep, axis=1)

    gathered, _ = ops.gather_sparse_square(sparse_matrix, keep)
    np.testing.assert_equal(
        tf.sparse.to_dense(gathered).numpy(), expected.numpy()
    )
Beispiel #2
0
    def call(self, inputs):
        """Select nodes by score and pool features, adjacency and segment ids.

        Node scores come from ``self.compute_scores``. Their first column is
        passed with the segment ids and ``self.ratio`` to
        ``ops.segment_top_k``, which chooses the nodes to keep.

        Args:
            inputs: ``[X, A]``, which sets ``self.data_mode`` to ``"single"``,
                or ``[X, A, I]``, which sets it to ``"disjoint"``. ``I`` holds
                per-node segment ids and may be given as a column (rank 2),
                in which case its first column is used.

        Returns:
            ``[X_pooled, A_pooled]``. In disjoint mode the pooled segment ids
            are appended. If ``self.return_mask`` is set, the mask built by
            ``ops.indices_to_mask`` is appended last.
        """
        disjoint = len(inputs) == 3
        if disjoint:
            X, A, I = inputs
        else:
            X, A = inputs
            # One graph only: every node is assigned segment id 0.
            I = tf.zeros(tf.shape(X)[:1])
        self.data_mode = "disjoint" if disjoint else "single"

        # Segment ids given as a column are flattened to their first column.
        if K.ndim(I) == 2:
            I = I[:, 0]
        I = tf.cast(I, tf.int32)

        # Score the nodes and pick which ones survive in each segment.
        scores = self.compute_scores(X, A, I)
        n_nodes = K.shape(X)[-2]
        # Sorting keeps the gathered SparseTensor entries in canonical order.
        keep = tf.sort(ops.segment_top_k(scores[:, 0], I, self.ratio))
        mask = ops.indices_to_mask(keep, n_nodes)

        # Scaling features by the gated scores keeps the scores in the
        # computation, which is what makes the layer differentiable.
        gated = X * self.gating_op(scores)

        # Node axis: 0 when A is rank 2, otherwise 1. It is kept non-negative
        # because A's second node axis is addressed as ``node_axis + 1``.
        node_axis = 0 if len(K.int_shape(A)) == 2 else 1
        X_pooled = tf.gather(gated, keep, axis=node_axis)

        if K.is_sparse(A):
            A_pooled, _ = ops.gather_sparse_square(A, keep, mask=mask)
        else:
            rows_kept = tf.gather(A, keep, axis=node_axis)
            A_pooled = tf.gather(rows_kept, keep, axis=node_axis + 1)

        pooled = [X_pooled, A_pooled]
        if self.data_mode == "disjoint":
            pooled.append(tf.gather(I, keep))
        if self.return_mask:
            pooled.append(mask)
        return pooled