Code example #1
    def _call_single(self, X, A):
        # Reshape kernels for efficient message-passing
        kernel = tf.reshape(self.kernel, (-1, self.attn_heads * self.channels))
        attn_kernel_self = ops.transpose(self.attn_kernel_self, (2, 1, 0))
        attn_kernel_neighs = ops.transpose(self.attn_kernel_neighs, (2, 1, 0))

        # Prepare message-passing
        indices = A.indices
        N = tf.shape(X, out_type=indices.dtype)[0]
        indices = ops.sparse_add_self_loops(indices, N)
        targets, sources = indices[:, -2], indices[:, -1]

        # Update node features
        X = ops.dot(X, kernel)
        X = tf.reshape(X, (-1, self.attn_heads, self.channels))

        # Compute attention
        attn_for_self = tf.reduce_sum(X * attn_kernel_self, -1)
        attn_for_self = tf.gather(attn_for_self, targets)
        attn_for_neighs = tf.reduce_sum(X * attn_kernel_neighs, -1)
        attn_for_neighs = tf.gather(attn_for_neighs, sources)

        attn_coef = attn_for_self + attn_for_neighs
        attn_coef = tf.nn.leaky_relu(attn_coef, alpha=0.2)
        attn_coef = ops.unsorted_segment_softmax(attn_coef, targets, N)
        attn_coef = self.dropout(attn_coef)
        attn_coef = attn_coef[..., None]

        # Update representation
        output = attn_coef * tf.gather(X, sources)
        output = ops.scatter_sum(output, targets, N)

        return output, attn_coef
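
For context, here is a minimal usage sketch of the attention layer this private method belongs to. The layer name GATConv and its constructor arguments are assumptions based on spektral's public API; only the single-mode internals appear in the snippet:

import tensorflow as tf
from spektral.layers import GATConv  # assumed public layer around _call_single

N, F = 5, 8
x = tf.random.normal((N, F))        # node features, shape (N, F)
adj = tf.constant([[0., 1., 0., 0., 0.],   # toy 5-node path graph
                   [1., 0., 1., 0., 0.],
                   [0., 1., 0., 1., 0.],
                   [0., 0., 1., 0., 1.],
                   [0., 0., 0., 1., 0.]])
a = tf.sparse.from_dense(adj)
layer = GATConv(channels=16, attn_heads=4)
out = layer([x, a])  # a sparse adjacency selects the single-mode path above
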
Code example #2
    def call(self, inputs):
        X = inputs[0]  # (batch_size, N, F)
        A = inputs[1]  # (batch_size, N, N)
        E = inputs[2]  # (n_edges, S) or (batch_size, N, N, S)

        mode = ops.autodetect_mode(A, X)
        if mode == modes.SINGLE:
            return self._call_single(inputs)

        # Parameters
        N = K.shape(X)[-2]
        F = K.int_shape(X)[-1]
        F_ = self.channels

        # Filter network
        kernel_network = E
        for layer in self.kernel_network_layers:
            kernel_network = layer(kernel_network)

        # Convolution
        target_shape = (-1, N, N, F_, F) if mode == modes.BATCH else (N, N, F_, F)
        kernel = K.reshape(kernel_network, target_shape)
        output = kernel * A[..., None, None]
        output = tf.einsum('abicf,aif->abc', output, X)

        if self.root:
            output += ops.dot(X, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)

        return output
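
Note on the contraction: in tf.einsum('abicf,aif->abc', ...) the sum runs over the neighbor index i and the input-feature index f, so for each target node b the layer accumulates the edge-specific kernel applied to each neighbor's features. The preceding kernel * A[..., None, None] masks out the kernels of non-edges.
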
Code example #3
File: ecc_conv.py  Project: zdqf/spektral
    def call(self, inputs):
        x, a, e = inputs

        mode = ops.autodetect_mode(a, x)
        if mode == modes.SINGLE:
            return self._call_single(inputs)

        # Parameters
        N = K.shape(x)[-2]
        F = K.int_shape(x)[-1]
        F_ = self.channels

        # Filter network
        kernel_network = e
        for layer in self.kernel_network_layers:
            kernel_network = layer(kernel_network)

        # Convolution
        target_shape = (-1, N, N, F_, F) if mode == modes.BATCH else (N, N, F_, F)
        kernel = K.reshape(kernel_network, target_shape)
        output = kernel * a[..., None, None]
        output = tf.einsum('abicf,aif->abc', output, x)

        if self.root:
            output += ops.dot(x, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)

        return output
Code example #4
File: gcn_conv.py  Project: zdqf/spektral
    def call(self, inputs):
        x, a = inputs

        output = ops.dot(x, self.kernel)
        output = ops.filter_dot(a, output)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        output = self.activation(output)

        return output
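
The filter_dot step expects a to be a pre-normalized GCN filter. Below is a minimal preprocessing sketch, assuming spektral's gcn_filter utility; the toy adjacency matrix is made up for illustration:

import numpy as np
from spektral.utils import gcn_filter  # assumed import path

adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=np.float32)
a = gcn_filter(adj)  # D^-1/2 (A + I) D^-1/2, fed to the layer as `a`
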
Code example #5
    def call(self, inputs):
        x, a = inputs

        T_0 = x
        output = ops.dot(T_0, self.kernel[0])

        if self.K > 1:
            T_1 = ops.filter_dot(a, x)
            output += ops.dot(T_1, self.kernel[1])

        for k in range(2, self.K):
            T_2 = 2 * ops.filter_dot(a, T_1) - T_0
            output += ops.dot(T_2, self.kernel[k])
            T_0, T_1 = T_1, T_2

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        output = self.activation(output)

        return output
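
The loop implements the Chebyshev recurrence T_k = 2 * a_hat @ T_{k-1} - T_{k-2}, with base cases T_0 = x and T_1 = a_hat @ x (a_hat being the rescaled Laplacian passed in as a), so each additional order costs one sparse filter application rather than an explicit matrix power.
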
Code example #6
    def call(self, inputs):
        features = inputs

        # Convolution
        output = ops.dot(features, self.kernel)
        output = ops.mixed_mode_dot(self.fltr, output)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return output
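
Unlike the previous examples, this variant stores the precomputed filter on the layer itself (self.fltr) and uses mixed_mode_dot, which multiplies a single shared (N, N) filter with a batch of node-feature matrices of shape (batch, N, F).
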
Code example #7
    def _call_single(self, inputs):
        X = inputs[0]  # (N, F)
        A = inputs[1]  # (N, N)
        E = inputs[2]  # (n_edges, S)
        assert K.ndim(E) == 2, 'In single mode, E must have shape (n_edges, S).'

        # Enforce sparse representation
        if not K.is_sparse(A):
            A = ops.dense_to_sparse(A)

        # Parameters
        N = tf.shape(X)[-2]
        F = K.int_shape(X)[-1]
        F_ = self.channels

        # Filter network
        kernel_network = E
        for layer in self.kernel_network_layers:
            kernel_network = layer(kernel_network)  # (n_edges, F * F_)
        target_shape = (-1, F, F_)
        kernel = tf.reshape(kernel_network, target_shape)

        # Propagation
        index_i = A.indices[:, -2]
        index_j = A.indices[:, -1]
        messages = tf.gather(X, index_j)
        messages = ops.dot(messages[:, None, :], kernel)[:, 0, :]
        aggregated = ops.scatter_sum(messages, index_i, N)

        # Update
        output = aggregated
        if self.root:
            output += ops.dot(X, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)

        return output
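
Design note: this is the single-mode path of the edge-conditioned convolution (ECC) of Simonovsky & Komodakis (2017). The filter network turns each edge's feature vector into its own (F, F_) weight matrix, so every gathered message is transformed by a kernel specific to its edge before scatter_sum aggregates it into the target node.
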
Code example #8
File: graph_conv.py  Project: yaniv256/spektral
    def call(self, inputs):
        features = inputs[0]
        fltr = inputs[1]

        # Convolution
        output = ops.dot(features, self.kernel)
        output = ops.filter_dot(fltr, output)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return output
Code example #9
    def call(self, inputs):
        features = inputs[0]
        laplacian = inputs[1]

        # Convolution
        T_0 = features
        output = ops.dot(T_0, self.kernel[0])

        if self.K > 1:
            T_1 = ops.filter_dot(laplacian, features)
            output += ops.dot(T_1, self.kernel[1])

        for k in range(2, self.K):
            T_2 = 2 * ops.filter_dot(laplacian, T_1) - T_0
            output += ops.dot(T_2, self.kernel[k])
            T_0, T_1 = T_1, T_2

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return output
Code example #10
File: ecc_conv.py  Project: zdqf/spektral
    def _call_single(self, inputs):
        x, a, e = inputs
        if K.ndim(e) != 2:
            raise ValueError('In single mode, E must have shape '
                             '(n_edges, n_edge_features).')

        # Enforce sparse representation
        if not K.is_sparse(a):
            a = ops.dense_to_sparse(a)

        # Parameters
        N = tf.shape(x)[-2]
        F = K.int_shape(x)[-1]
        F_ = self.channels

        # Filter network
        kernel_network = e
        for layer in self.kernel_network_layers:
            kernel_network = layer(kernel_network)  # (n_edges, F * F_)
        target_shape = (-1, F, F_)
        kernel = tf.reshape(kernel_network, target_shape)

        # Propagation
        index_i = a.indices[:, -2]
        index_j = a.indices[:, -1]
        messages = tf.gather(x, index_j)
        messages = ops.dot(messages[:, None, :], kernel)[:, 0, :]
        aggregated = ops.scatter_sum(messages, index_i, N)

        # Update
        output = aggregated
        if self.root:
            output += ops.dot(x, self.root_kernel)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
        if self.activation is not None:
            output = self.activation(output)

        return output
Code example #11
    def call(self, inputs, **kwargs):
        x, a, _ = self.get_inputs(inputs)

        # TODO: a = add_self_loops(a)

        x = dot(x, self.kernel)
        if self.use_bias:
            x = tf.nn.bias_add(x, self.bias)
        if self.use_batch_norm:
            x = self.batch_norm(x)
        x = self.dropout(x)
        x = self.activation(x)

        return self.propagate(x, a)
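
In this version the layer only transforms and regularizes the node features; neighborhood aggregation is delegated to the inherited propagate method, and adding self-loops is still an open TODO.
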
Code example #12
    def call(self, inputs):
        x, a, _ = self.get_inputs(inputs)
        a = ops.add_self_loops(a)

        aggregated = self.propagate(x, a)
        output = K.concatenate([x, aggregated])
        output = ops.dot(output, self.kernel)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        output = K.l2_normalize(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)

        return output
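
The concatenate-project-normalize update is the GraphSAGE rule of Hamilton et al. (2017): the node's own features are kept alongside the aggregated neighborhood, and the final L2 normalization keeps embedding norms comparable across nodes.
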
Code example #13
# Assumed context for this snippet: `dot`, `binary_adj_converter`,
# `sp_tensor_to_sp_matrix`, and `sp_matrix_to_sp_tensor` are sparse-conversion
# helpers defined in the surrounding spektral module.
import numpy as np
import tensorflow as tf


def k_hop_sparse_subgraph(a, node_idx, k, transformer=None):
    """
    Computes the subgraph containing all the neighbors of `node_idx` up to the k-th order.
    If `a` is not the binary adjacency matrix, a `transformer` should be passed.
    **Arguments**
    - `a`: sparse `(n_nodes, n_nodes)` graph tensor;
    - `node_idx`: center node;
    - `k`: order of the neighborhood;
    - `transformer`: one of the functions from the `spektral.transforms` module,
       needed to convert the binary adjacency matrix into the correct format for the model;
    """
    if a.dtype != tf.float32:
        a = tf.cast(a, tf.float32)

    if transformer:
        a = binary_adj_converter(a)

    power_a = tf.sparse.eye(a.shape[0])
    k_neighs = np.zeros(a.shape[0]).astype("float32").reshape(1, -1)
    k_neighs[0, node_idx] = 1

    for _ in range(k - 1):
        power_a = dot(power_a, a)
        temp = tf.sparse.slice(power_a,
                               start=[node_idx, 0],
                               size=[1, power_a.shape[0]])
        k_neighs += tf.sparse.to_dense(temp)

    comp_graph = tf.sparse.add(a * tf.reshape(k_neighs, (-1, 1)), a * k_neighs)
    is_nonzero = tf.not_equal(comp_graph.values, 0)
    comp_graph = tf.sparse.retain(comp_graph, is_nonzero)
    comp_graph = tf.sign(comp_graph)

    if transformer:
        comp_graph = sp_tensor_to_sp_matrix(comp_graph)
        comp_graph = transformer(comp_graph)
        return sp_matrix_to_sp_tensor(comp_graph)
    else:
        return comp_graph
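
A minimal usage sketch for the function above. The toy path graph is made up for illustration; no transformer is passed, so a is treated as a binary adjacency matrix:

import tensorflow as tf

# 4-node path graph: 0 - 1 - 2 - 3
a = tf.sparse.from_dense(tf.constant([[0., 1., 0., 0.],
                                      [1., 0., 1., 0.],
                                      [0., 1., 0., 1.],
                                      [0., 0., 1., 0.]]))
sub = k_hop_sparse_subgraph(a, node_idx=0, k=2)  # keeps edges touching {0, 1}
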
Code example #14
File: graphsage_conv.py  Project: wxiangs/spektral
    def call(self, inputs):
        features = inputs[0]
        fltr = inputs[1]

        # Enforce sparse representation
        if not K.is_sparse(fltr):
            fltr = ops.dense_to_sparse(fltr)

        # Propagation
        indices = fltr.indices
        N = tf.shape(features, out_type=indices.dtype)[0]
        indices = ops.sparse_add_self_loops(indices, N)
        targets, sources = indices[:, -2], indices[:, -1]
        messages = tf.gather(features, sources)
        aggregated = self.aggregate_op(messages, targets, N)
        output = K.concatenate([features, aggregated])
        output = ops.dot(output, self.kernel)

        if self.use_bias:
            output = K.bias_add(output, self.bias)
        output = K.l2_normalize(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output