Example #1
def multi_graph_conv_v1(
    x: Union[tf.Tensor, tf.SparseTensor],
    adjacencies: Sequence[tf.SparseTensor],
    kernel: Union[tf.Tensor, tf.Variable],
    sparse_impl: SparseImplementation = SparseImplementation.COO,
    transform_first: Optional[bool] = None,
):
    SparseImplementation.validate(sparse_impl)
    _validate_shapes(x, adjacencies, kernel)
    kernel = tf.convert_to_tensor(kernel)
    transform_first = _transform_first(transform_first, kernel.shape)

    if not is_dense_tensor(x):
        x = to_sparse_impl(x, sparse_impl)

    filters_in, num_adj, filters_out = kernel.shape
    if transform_first:
        # Transform first: apply the kernel to the features, then propagate
        # each of the num_adj slices through its adjacency.
        kernel = tf.reshape(kernel, (filters_in, num_adj * filters_out))
        x = matmul(x, kernel)
        x = tf.reshape(x, (-1, num_adj, filters_out))
        xs = [
            matmul(to_sparse_impl(adj, sparse_impl), x)
            for adj, x in zip(adjacencies, tf.unstack(x, axis=1))
        ]
        return tf.add_n(xs)

    # Transform second: propagate through each adjacency first, then apply
    # the kernel to the stacked results.
    xs = [matmul(to_sparse_impl(adj, sparse_impl), x) for adj in adjacencies]
    x = tf.reshape(tf.stack(xs, axis=-1), (-1, filters_in * num_adj))
    kernel = tf.reshape(kernel, (filters_in * num_adj, filters_out))
    return matmul(x, kernel)
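For reference, multi_graph_conv_v1 sums one graph convolution per adjacency: sum over a of A_a @ (X @ K[:, a, :]); the two branches only change the order of operations. A self-contained sketch of the same computation with plain TF ops (the function and variable names here are illustrative, not from the repo):

import tensorflow as tf

def multi_graph_conv_reference(x, adjacencies, kernel):
    # Same result as multi_graph_conv_v1, ignoring the impl/ordering options:
    # sum over relations of A_a @ (X @ K[:, a, :]).
    return tf.add_n([
        tf.sparse.sparse_dense_matmul(adj, tf.matmul(x, kernel[:, a, :]))
        for a, adj in enumerate(adjacencies)
    ])

x = tf.random.normal((5, 3))                  # (nodes, filters_in)
adjs = [tf.sparse.eye(5) for _ in range(2)]   # num_adj = 2 sparse relations
k = tf.random.normal((3, 2, 4))               # (filters_in, num_adj, filters_out)
out = multi_graph_conv_reference(x, adjs, k)  # -> shape (5, 4)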
Example #2
def graph_conv(
    x: Union[tf.Tensor, tf.SparseTensor],
    adjacency: tf.SparseTensor,
    kernel: Union[tf.Tensor, tf.Variable],
    sparse_impl: SparseImplementation = SparseImplementation.COO,
    transform_first: Optional[bool] = None,
):
    SparseImplementation.validate(sparse_impl)
    kernel = tf.convert_to_tensor(kernel)
    transform_first = _transform_first(transform_first, kernel.shape)
    if not is_dense_tensor(x):
        x = to_sparse_impl(x, sparse_impl)

    adjacency = to_sparse_impl(adjacency, sparse_impl)
    if transform_first:
        return matmul(adjacency, matmul(x, kernel))
    return matmul(matmul(adjacency, x), kernel)
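The transform_first flag trades FLOPs: with A of shape (n, n) holding e edges, X of shape (n, f_in) and K of shape (f_in, f_out), transforming first costs about n*f_in*f_out + e*f_out versus e*f_in + n*f_in*f_out for propagating first, so transforming first wins exactly when f_out < f_in. A guess at the undefined _transform_first helper under that assumption (this body is illustrative, not the repo's actual code):

from typing import Optional

def _transform_first(transform_first: Optional[bool], kernel_shape) -> bool:
    # Hypothetical default: transform features first whenever the kernel
    # narrows them, so the sparse matmul touches the smaller matrix.
    if transform_first is None:
        return kernel_shape[-1] <= kernel_shape[0]
    return transform_first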
Example #3
def multi_graph_conv_v2(
    x: Union[tf.Tensor, tf.SparseTensor],
    adjacencies: Sequence[tf.SparseTensor],
    kernel: Union[tf.Tensor, tf.Variable],
    sparse_impl: SparseImplementation = SparseImplementation.COO,
    transform_first: Optional[bool] = None,
):
    SparseImplementation.validate(sparse_impl)
    _validate_shapes(x, adjacencies, kernel)
    kernel = tf.convert_to_tensor(kernel)
    transform_first = _transform_first(transform_first, kernel.shape)
    nodes_out, nodes_in = tf.unstack(adjacencies[0].dense_shape)
    filters_in, num_adj, filters_out = kernel.shape

    if not is_dense_tensor(x):
        x = to_sparse_impl(x, sparse_impl)

    if transform_first:
        # Transform first: apply the kernel, then fold the num_adj axis into
        # a single wide sparse matmul instead of num_adj separate ones.
        kernel = tf.reshape(kernel, (filters_in, num_adj * filters_out))
        x = matmul(x, kernel)
        x = tf.reshape(x, (nodes_in * num_adj, filters_out))
        adjacency = tf.sparse.reshape(  # pylint: disable=no-value-for-parameter
            sparse_stack(adjacencies, axis=-1), (nodes_out, nodes_in * num_adj)
        )
        return matmul(to_sparse_impl(adjacency, sparse_impl), x)

    # Transform second: propagate first, then apply the kernel.
    adjacency = tf.sparse.concat(sp_inputs=adjacencies, axis=1)  # (nodes_out, num_adj * nodes_in)
    adjacency = tf.sparse.reshape(  # pylint: disable=no-value-for-parameter
        adjacency, (nodes_out * num_adj, nodes_in)
    )
    adjacency = to_sparse_impl(adjacency, sparse_impl)
    x = matmul(adjacency, x)
    x = tf.reshape(x, (nodes_out, num_adj * filters_in))
    kernel = tf.transpose(kernel, (1, 0, 2))  # reorder so the reshape below matches x's (adj, filter) column layout
    kernel = tf.reshape(kernel, (num_adj * filters_in, filters_out))
    return tf.matmul(x, kernel)
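The fused layout is algebraically the same per-relation sum that v1 computes; a self-contained numeric check of the transform-second path with plain TF ops:

import tensorflow as tf

n, f_in, f_out, a = 4, 3, 2, 2
x = tf.random.normal((n, f_in))
adjs = [tf.sparse.from_dense(tf.random.normal((n, n))) for _ in range(a)]
k = tf.random.normal((f_in, a, f_out))

per_relation = tf.add_n([
    tf.sparse.sparse_dense_matmul(adj, x @ k[:, i, :])
    for i, adj in enumerate(adjs)
])

# The "transform second" path: one wide sparse matmul, then one dense matmul.
stacked = tf.sparse.reshape(tf.sparse.concat(sp_inputs=adjs, axis=1), (n * a, n))
h = tf.reshape(tf.sparse.sparse_dense_matmul(stacked, x), (n, a * f_in))
w = tf.reshape(tf.transpose(k, (1, 0, 2)), (a * f_in, f_out))
print(tf.reduce_max(tf.abs(per_relation - h @ w)).numpy())  # ~0 (float error)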
Example #4
def graph_conv(
    adj: tp.Union[tf.Tensor, tf.SparseTensor],
    features: tf.Tensor,
    features0: tf.Tensor,
    kernel: tf.Tensor,
    alpha: float,
    beta: float,
    variant: bool,
) -> tf.Tensor:
    hi = ops.matmul(adj, features)  # propagate: A @ H
    if variant:
        # Variant: concatenate propagated and initial features, so kernel
        # must have twice as many input rows.
        support = tf.concat(  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
            [hi, features0], axis=1
        )
        r = (1 - alpha) * hi + alpha * features
    else:
        # Initial residual: mix propagated features with the layer-0 features.
        support = (1 - alpha) * hi + alpha * features0
        r = support
    output = beta * support @ kernel + (1 - beta) * r
    return output
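This appears to match the GCNII propagation rule (Chen et al., 2020): support = (1 - alpha) * A @ H + alpha * H0 followed by the identity mapping beta * (support @ W) + (1 - beta) * support. A plain-TF restatement of the variant=False branch, sketched here since ops.matmul is a repo helper:

import tensorflow as tf

def gcn2_reference(adj, h, h0, w, alpha, beta):
    hi = tf.sparse.sparse_dense_matmul(adj, h)          # propagate: A @ H
    support = (1 - alpha) * hi + alpha * h0             # initial residual mix
    return beta * (support @ w) + (1 - beta) * support  # identity mapping

out = gcn2_reference(tf.sparse.eye(4), tf.random.normal((4, 8)),
                     tf.random.normal((4, 8)), tf.random.normal((8, 8)),
                     alpha=0.1, beta=0.5)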
Example #5
def call(self, inp: PropagationInput):
    T, X = inp  # propagation operator T and node features X
    out = ops.matmul(T, X)
    # Re-assert X's static shape, which the (possibly sparse) matmul can lose.
    out.set_shape(X.shape)
    return out
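The set_shape call matters because a sparse matmul can drop static shape information when the operator's dense_shape is only known at runtime; re-asserting X's shape keeps downstream shape inference intact (this requires T to be square). A minimal illustration, assuming PropagationInput is simply a (T, X) pair:

import tensorflow as tf

@tf.function(input_signature=[
    tf.SparseTensorSpec(shape=(None, None), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 16), dtype=tf.float32),
])
def propagate(t, x):
    out = tf.sparse.sparse_dense_matmul(t, x)
    out.set_shape(x.shape)  # keep the static channel dim (16) for later layers
    return out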