Ejemplo n.º 1
0
    def call(self, inputs):
        """Pool `X`, `A` (and `I`, when given) down to a selected set of indices.

        `inputs` is either `[X, A]` or `[X, A, I]`. With three elements the
        layer records `self.data_mode = "disjoint"`; otherwise it records
        `"single"` and substitutes an all-zero `I` with one entry per row
        of `X`.

        Returns a list `[X_pooled, A_pooled]`, followed by `I_pooled` in
        disjoint mode and by the selection mask when `self.return_mask`
        is truthy.
        """
        if len(inputs) != 3:
            x, a = inputs
            seg = tf.zeros(tf.shape(x)[:1])
            self.data_mode = "single"
        else:
            x, a, seg = inputs
            self.data_mode = "disjoint"

        # A rank-2 `I` is collapsed to its first column before the int cast.
        if K.ndim(seg) == 2:
            seg = seg[:, 0]
        seg = tf.cast(seg, tf.int32)

        # Score the inputs, then pick the indices to keep and build the mask.
        scores = self.compute_scores(x, a, seg)
        size = K.shape(x)[-2]
        keep = ops.segment_top_k(scores[:, 0], seg, self.ratio)
        # Sorted indices are required for ordered SparseTensors.
        keep = tf.sort(keep)
        mask = ops.indices_to_mask(keep, size)

        # Gating the features with the scores keeps the layer differentiable.
        gated = x * self.gating_op(scores)

        # Explicit non-negative gather axis: 0 when A has rank 2, else 1.
        if len(K.int_shape(a)) == 2:
            axis = 0
        else:
            axis = 1

        x_pooled = tf.gather(gated, keep, axis=axis)

        if K.is_sparse(a):
            a_pooled, _ = ops.gather_sparse_square(a, keep, mask=mask)
        else:
            # Dense case: gather along `axis`, then along the following axis.
            a_pooled = tf.gather(
                tf.gather(a, keep, axis=axis), keep, axis=axis + 1
            )

        output = [x_pooled, a_pooled]

        if self.data_mode == "disjoint":
            output.append(tf.gather(seg, keep))

        if self.return_mask:
            output.append(mask)

        return output
Ejemplo n.º 2
0
def test_indices_to_mask_rank2():
    """Check `ops.indices_to_mask` with 2-D index pairs and a `[3, 3]` shape."""
    positions = [[0, 2], [1, 1], [2, 1]]
    want = [
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 0],
    ]
    got = ops.indices_to_mask(positions, [3, 3]).numpy()
    np.testing.assert_equal(got, want)
Ejemplo n.º 3
0
def test_indices_to_mask_rank1():
    """Check `ops.indices_to_mask` with flat indices and a length of 6."""
    want = [0, 1, 0, 1, 1, 0]
    got = ops.indices_to_mask([1, 3, 4], 6).numpy()
    np.testing.assert_equal(got, want)