Example #1
0
    def call(self, inputs):
        """Keep the top-scoring nodes and reduce the graph accordingly.

        Args:
            inputs: a sequence of three items `[X, A, I]` or, otherwise, two
                items `[X, A]`. With three items `self.data_mode` is set to
                `'disjoint'`; otherwise it is set to `'single'` and an
                all-zero `I` with one entry per row of `X` is created.

        Returns:
            A list `[X_pooled, A_pooled]`, followed by the reduced `I` when
            in disjoint mode, followed by the node mask when
            `self.return_mask` is truthy.
        """
        has_batch_index = len(inputs) == 3
        if has_batch_index:
            X, A, I = inputs
        else:
            X, A = inputs
            I = tf.zeros(tf.shape(X)[:1])
        self.data_mode = 'disjoint' if has_batch_index else 'single'

        # A rank-2 I is collapsed to its first column before the int cast
        I = tf.cast(I[:, 0] if K.ndim(I) == 2 else I, tf.int32)

        sparse_input = K.is_sparse(A)

        # Score the nodes and scatter ones at the selected positions
        scores = self.compute_scores(X, A, I)
        n_nodes = K.shape(X)[-2]
        kept = ops.segment_top_k(scores[:, 0], I, self.ratio, self.top_k_var)
        mask = tf.scatter_nd(
            tf.expand_dims(kept, 1), tf.ones_like(kept), (n_nodes,)
        )

        # Gating X by the scores keeps the layer differentiable
        gated = X * self.gating_op(scores)

        # tf.boolean_mask cannot take a negative axis, so pick it from A's rank
        if len(K.int_shape(A)) == 2:
            node_axis = 0
        else:
            node_axis = 1

        X_pooled = tf.boolean_mask(gated, mask, axis=node_axis)

        # A^2, with a dense right-hand operand
        rhs = tf.sparse.to_dense(A) if sparse_input else A
        A_squared = K.dot(A, rhs)

        # Drop the masked-out entries along both node axes of A^2
        A_pooled = tf.boolean_mask(A_squared, mask, axis=node_axis)
        A_pooled = tf.boolean_mask(A_pooled, mask, axis=node_axis + 1)
        if sparse_input:
            A_pooled = ops.dense_to_sparse(A_pooled)

        output = [X_pooled, A_pooled]

        if self.data_mode == 'disjoint':
            # Reduce I with the same mask
            output.append(tf.boolean_mask(I[:, None], mask)[:, 0])

        if self.return_mask:
            output.append(mask)

        return output
Example #2
0
    def call(self, inputs):
        """Sum neighbour features, mix with the node's own, and apply the MLP.

        Args:
            inputs: a sequence whose first item is the node feature tensor
                and whose second item is the propagation matrix (dense or
                sparse). A dense matrix is converted with
                `ops.dense_to_sparse`.

        Returns:
            The result of `self.mlp` applied to
            `(1 + self.eps) * features + sum of neighbour features`.
        """
        features = inputs[0]
        fltr = inputs[1]

        # Enforce sparsity
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        # Propagation
        targets = fltr.indices[:, -2]
        sources = fltr.indices[:, -1]
        messages = tf.gather(features, sources)
        # An unsorted segment sum with an explicit number of segments always
        # yields one row per row of `features`, even when trailing nodes
        # receive no message, and does not require `targets` to be sorted.
        features_neigh = tf.math.unsorted_segment_sum(
            messages, targets, num_segments=tf.shape(features)[0])
        hidden = (1.0 + self.eps) * features + features_neigh

        # MLP
        output = self.mlp(hidden)

        return output
Example #3
0
    def call(self, inputs, mask=None):
        """Apply an edge-conditioned convolution.

        Args:
            inputs: a sequence `(x, a, e)`: node features, adjacency matrix
                and edge features. `e` is passed through
                `self.kernel_network_layers` to produce the per-edge kernels.
            mask: optional; when given, the output is multiplied by
                `mask[0]` before the activation.

        Returns:
            The activated output tensor.
        """
        x, a, e = inputs

        # Parameters
        N = tf.shape(x)[-2]
        F = tf.shape(x)[-1]
        F_ = self.channels

        # Filter network
        kernel_network = e
        for layer in self.kernel_network_layers:
            kernel_network = layer(kernel_network)

        # Convolution
        mode = ops.autodetect_mode(x, a)
        if mode == modes.BATCH:
            # One (F_, F) kernel per node pair, weighted by the adjacency entry
            kernel = K.reshape(kernel_network, (-1, N, N, F_, F))
            output = kernel * a[..., None, None]
            output = tf.einsum("abcde,ace->abd", output, x)
        else:
            # Enforce sparse representation
            if not K.is_sparse(a):
                # The two literals are concatenated, so the separating space
                # must be part of one of them.
                warnings.warn("Casting dense adjacency matrix to SparseTensor. "
                              "This can be an expensive operation.")
                a = ops.dense_to_sparse(a)

            target_shape = (-1, F, F_)
            if mode == modes.MIXED:
                # Mixed mode keeps the leading dimension of x on the kernels
                target_shape = (tf.shape(x)[0], ) + target_shape
            kernel = tf.reshape(kernel_network, target_shape)
            # Messages are gathered from column 0 of the sparse indices and
            # summed into column 1.
            index_i = a.indices[:, 1]
            index_j = a.indices[:, 0]
            messages = tf.gather(x, index_j, axis=-2)
            messages = tf.einsum("...ab,...abc->...ac", messages, kernel)
            output = ops.scatter_sum(messages, index_i, N)

        if self.root:
            output += K.dot(x, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if mask is not None:
            output *= mask[0]
        output = self.activation(output)

        return output
Example #4
0
    def call(self, inputs):
        """Aggregate neighbour features, concatenate with own, and project.

        Args:
            inputs: a sequence whose first item is the node feature tensor
                and whose second item is the propagation matrix (dense or
                sparse).

        Returns:
            The L2-normalised (last axis) output after the kernel product,
            the optional bias and the optional activation.
        """
        features, fltr = inputs[0], inputs[1]

        # Work on a sparse matrix so the edge list is available via .indices
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        # Gather by the last index column, aggregate by the second-to-last.
        # NOTE(review): assumes aggregate_op yields one row per row of
        # `features` so the concatenation lines up; confirm.
        gathered = tf.gather(features, fltr.indices[:, -1])
        features_neigh = self.aggregate_op(gathered, fltr.indices[:, -2])

        combined = K.concatenate([features, features_neigh])
        output = K.dot(combined, self.kernel)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return K.l2_normalize(output, axis=-1)
Example #5
0
    def call(self, inputs):
        """Sum neighbour features, mix with the node's own, and apply the MLP.

        Args:
            inputs: a sequence whose first item is the node feature tensor
                and whose second item is the propagation matrix (dense or
                sparse).

        Returns:
            The result of `self.mlp` applied to
            `(1 + self.eps) * features + aggregated`.
        """
        features, fltr = inputs[0], inputs[1]

        # A dense matrix is converted so the edge list is available
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        edge_index = fltr.indices
        n_nodes = tf.shape(features)[0]

        # Gather by the last index column, sum by the second-to-last.
        # NOTE(review): scatter_sum is called as (targets, messages, N=...);
        # confirm this matches the argument order of the helper in use.
        messages = tf.gather(features, edge_index[:, -1])
        aggregated = ops.scatter_sum(edge_index[:, -2], messages, N=n_nodes)

        return self.mlp((1.0 + self.eps) * features + aggregated)
Example #6
0
    def _call_single(self, inputs):
        """Edge-conditioned convolution for a single graph.

        Args:
            inputs: a sequence whose first three items are `X` (N, F),
                `A` (N, N) and `E` (n_edges, S).

        Returns:
            The aggregated messages, plus the root term when `self.root` is
            set, with the optional bias and activation applied.

        Raises:
            AssertionError: if `E` is not rank 2 (note that assertions are
                skipped when Python runs with `-O`).
        """
        X = inputs[0]  # (N, F)
        A = inputs[1]  # (N, N)
        E = inputs[2]  # (n_edges, S)
        assert K.ndim(E) == 2, \
            'In single mode, E must have shape (n_edges, S).'

        # The edge list is read from the sparse indices below
        if not K.is_sparse(A):
            A = ops.dense_to_sparse(A)

        n_nodes = tf.shape(X)[-2]
        in_channels = K.int_shape(X)[-1]
        out_channels = self.channels

        # Run the edge features through the filter network, one layer at a
        # time, then view the result as one matrix per edge.
        edge_kernels = E
        for net_layer in self.kernel_network_layers:
            edge_kernels = net_layer(edge_kernels)  # (n_edges, F * F_)
        edge_kernels = tf.reshape(
            edge_kernels, (-1, in_channels, out_channels))

        # Gather by the last index column, sum by the second-to-last
        edge_index = A.indices
        receivers = edge_index[:, -2]
        senders = edge_index[:, -1]
        gathered = tf.gather(X, senders)
        # Add a unit axis for the per-edge product, then drop it again
        messages = ops.dot(gathered[:, None, :], edge_kernels)[:, 0, :]
        output = ops.scatter_sum(messages, receivers, n_nodes)

        # Update
        if self.root:
            output += ops.dot(X, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is None:
            return output
        return self.activation(output)
Example #7
0
    def _call_single(self, inputs):
        """Edge-conditioned convolution for a single graph.

        Args:
            inputs: a sequence of exactly three items `(x, a, e)`: node
                features, adjacency matrix and edge features.

        Returns:
            The aggregated messages, plus the root term when `self.root` is
            set, with the optional bias and activation applied.

        Raises:
            ValueError: if `e` is not rank 2.
        """
        x, a, e = inputs
        if K.ndim(e) != 2:
            raise ValueError(
                'In single mode, E must have shape (n_edges, n_edge_features).'
            )

        # The edge list is read from the sparse indices below
        if not K.is_sparse(a):
            a = ops.dense_to_sparse(a)

        n_nodes = tf.shape(x)[-2]
        in_channels = K.int_shape(x)[-1]
        out_channels = self.channels

        # Run the edge features through the filter network, one layer at a
        # time, then view the result as one matrix per edge.
        edge_kernels = e
        for net_layer in self.kernel_network_layers:
            edge_kernels = net_layer(edge_kernels)  # (n_edges, F * F_)
        edge_kernels = tf.reshape(
            edge_kernels, (-1, in_channels, out_channels))

        # Gather by the last index column, sum by the second-to-last
        edge_index = a.indices
        receivers = edge_index[:, -2]
        senders = edge_index[:, -1]
        gathered = tf.gather(x, senders)
        # Add a unit axis for the per-edge product, then drop it again
        messages = ops.dot(gathered[:, None, :], edge_kernels)[:, 0, :]
        output = ops.scatter_sum(messages, receivers, n_nodes)

        # Update
        if self.root:
            output += ops.dot(x, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is None:
            return output
        return self.activation(output)
Example #8
0
    def call(self, inputs):
        """Aggregate neighbour features (with self-loops) and project.

        Args:
            inputs: a sequence whose first item is the node feature tensor
                and whose second item is the propagation matrix (dense or
                sparse).

        Returns:
            The output after the kernel product, the optional bias, L2
            normalisation along the last axis and, last, the optional
            activation.
        """
        features, fltr = inputs[0], inputs[1]

        # A dense matrix is converted so the edge list is available
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        # The node count is requested in the dtype of the sparse indices
        # before being handed to the self-loop helper.
        raw_index = fltr.indices
        n_nodes = tf.shape(features, out_type=raw_index.dtype)[0]
        edge_index = ops.sparse_add_self_loops(raw_index, n_nodes)

        # Gather by the last index column, aggregate by the second-to-last
        receivers = edge_index[:, -2]
        senders = edge_index[:, -1]
        gathered = tf.gather(features, senders)
        aggregated = self.aggregate_op(gathered, receivers, n_nodes)

        combined = K.concatenate([features, aggregated])
        output = ops.dot(combined, self.kernel)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        # Normalisation happens before the activation in this layer
        output = K.l2_normalize(output, axis=-1)
        if self.activation is None:
            return output
        return self.activation(output)