Code example #1
    def _kgcnn_map_output_ragged(self, inputs, input_partition_type, output_tensor_type=None):
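        """Map the layer output onto the requested ragged-like tensor representation."""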
        x = inputs
        ragged_keys = ["ragged", "RaggedTensor"]
        value_partition_keys = ["disjoint", "values_partition"]

        if output_tensor_type is None:
            output_tensor_type = self.output_tensor_type
        elif isinstance(output_tensor_type, int):
            output_tensor_type = self._tensor_input_type_found[output_tensor_type]

        if isinstance(x, tf.RaggedTensor):
            if output_tensor_type in ragged_keys:
                return x
            elif output_tensor_type in value_partition_keys:
                return kgcnn_ops_dyn_cast(x, input_tensor_type="ragged",
                                          output_tensor_type=output_tensor_type, partition_type=self.partition_type)
            else:
                raise TypeError("Error:", self.name, "output type for ragged-like input is not supported for", x)

        elif isinstance(x, list):
            if len(x) != 2:
                print("Warning:", self.name, "output does not match rank=1 partition scheme for batch dimension.")
            if output_tensor_type in ragged_keys:
                return kgcnn_ops_dyn_cast(x, input_tensor_type="values_partition",
                                          output_tensor_type=output_tensor_type, partition_type=input_partition_type)
            elif output_tensor_type in value_partition_keys:
                tens_part = kgcnn_ops_change_partition_type(x[1], input_partition_type, self.partition_type)
                return [x[0], tens_part]
            else:
                raise TypeError("Error:", self.name, "output type for ragged-like input is not supported for", x)
        else:
            raise TypeError("Error:", self.name, "input type for ragged-like input is not supported for", x)
Code example #2
File: sparse.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: [nodes, adjacency]

            - nodes: Node features of shape (batch, [N], F)
            - adjacency (tf.sparse): SparseTensor of the adjacency matrix of shape (batch*None, batch*None)

        Returns:
            features: Node features of shape (batch, [N], F), pooled over neighbours via the adjacency matrix.
        """
        adj = inputs[1]
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        node, node_part = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        out = tf.sparse.sparse_dense_matmul(adj, node)

        return kgcnn_ops_dyn_cast([out, node_part],
                                  input_tensor_type="values_partition",
                                  output_tensor_type=found_node_type,
                                  partition_type=self.partition_type)
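
The core operation of this layer is a single sparse-dense matrix multiplication over the disjoint block adjacency. A minimal self-contained sketch with toy values, not the layer itself:

import tensorflow as tf

# Disjoint adjacency of two 2-node graphs, shape (batch*None, batch*None).
adj = tf.sparse.SparseTensor(indices=[[0, 1], [1, 0], [2, 3], [3, 2]],
                             values=[1.0, 1.0, 1.0, 1.0], dense_shape=[4, 4])
node = tf.constant([[1.0], [2.0], [3.0], [4.0]])  # flat node features (batch*None, F)

# Each node receives the feature sum of its neighbours.
out = tf.sparse.sparse_dense_matmul(adj, node)    # [[2.], [1.], [4.], [3.]]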
Code example #3
File: pooling.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edges, edge_index]

            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [N], 2)
    
        Returns:
            features: Pooled feature tensor of pooled edge features for each node.
        """
        found_node_type = kgcnn_ops_get_tensor_type(inputs[0], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(inputs[1], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_index_type = kgcnn_ops_get_tensor_type(inputs[2], input_tensor_type=self.input_tensor_type,
                                                     node_indexing=self.node_indexing)

        nod, node_part = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=found_node_type,
                                            output_tensor_type="values_partition",
                                            partition_type=self.partition_type)
        edge, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=found_edge_type,
                                     output_tensor_type="values_partition",
                                     partition_type=self.partition_type)
        edgeind, edge_part = kgcnn_ops_dyn_cast(inputs[2], input_tensor_type=found_index_type,
                                                output_tensor_type="values_partition",
                                                partition_type=self.partition_type)

        shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(edgeind, node_part, edge_part,
                                                                          partition_type_node=self.partition_type,
                                                                          partition_type_edge=self.partition_type,
                                                                          to_indexing='batch',
                                                                          from_indexing=self.node_indexing)

        nodind = shiftind[:, 0]  # Pick first index, e.g. ingoing node
        dens = edge
        if not self.is_sorted:
            # Sort edge indices by receiving node
            node_order = tf.argsort(nodind, axis=0, direction='ASCENDING', stable=True)
            nodind = tf.gather(nodind, node_order, axis=0)
            dens = tf.gather(dens, node_order, axis=0)
        # Pooling via e.g. segment_sum
        out = kgcnn_ops_segment_operation_by_name(self.pooling_method, dens, nodind)
        if self.has_unconnected:
            out = kgcnn_ops_scatter_segment_tensor_nd(out, nodind, tf.shape(nod))

        return kgcnn_ops_dyn_cast([out, node_part], input_tensor_type="values_partition",
                                  output_tensor_type=found_node_type, partition_type=self.partition_type)
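
The pooling core is a segment operation over the receiving-node index, which requires sorted segment ids. A minimal sketch of the sort-then-segment pattern, independent of the kgcnn helpers; tf.math.unsorted_segment_sum would avoid the sort at the cost of an explicit num_segments:

import tensorflow as tf

edge_feat = tf.constant([[1.0], [10.0], [100.0]])
receiver = tf.constant([1, 0, 1])  # first column of the batch-shifted edge indices

# Stable sort by receiver, then pool all incoming edge features per node.
order = tf.argsort(receiver, axis=0, direction='ASCENDING', stable=True)
pooled = tf.math.segment_sum(tf.gather(edge_feat, order), tf.gather(receiver, order))
# pooled: [[10.], [101.]]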
Code example #4
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edge_index]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [N], 2)

        Returns:
            embeddings: Gathered node embeddings that match the number of edges.
        """
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(
            inputs[1],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)

        # We cast to values here
        node, node_part = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edge_index, edge_part = kgcnn_ops_dyn_cast(
            inputs[1],
            input_tensor_type=found_edge_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edge_index,
            node_part,
            edge_part,
            partition_type_node=self.partition_type,
            partition_type_edge=self.partition_type,
            to_indexing='batch',
            from_indexing=self.node_indexing)
        out = tf.gather(node, indexlist[:, 0], axis=0)

        # For a ragged tensor input one could also try:
        # out = tf.gather(node, edge_index[:, :, 0], batch_dims=1)

        return kgcnn_ops_dyn_cast([out, edge_part],
                                  input_tensor_type="values_partition",
                                  output_tensor_type=found_edge_type,
                                  partition_type=self.partition_type)
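
The gather step on the flat values, and the ragged alternative mentioned in the comment, can be reproduced in isolation; toy shapes, assuming TF 2.x ragged support for tf.gather:

import tensorflow as tf

node = tf.constant([[1.0], [2.0], [3.0]])       # flat node values (batch*None, F)
indexlist = tf.constant([[0, 1], [2, 0]])       # batch-shifted edge indices (i, j)

# One embedding per edge, taken at the first index i.
out = tf.gather(node, indexlist[:, 0], axis=0)  # [[1.], [3.]]

# Ragged variant from the comment above:
nodes_r = tf.ragged.constant([[[1.0], [2.0]], [[3.0]]], ragged_rank=1)
index_r = tf.ragged.constant([[[0, 1]], [[0, 0]]], ragged_rank=1)
out_r = tf.gather(nodes_r, index_r[:, :, 0], batch_dims=1)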
Code example #5
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs: [state, target]

            - state: Graph-specific embedding tensor. This is usually a tensor of shape (batch, F).
            - target: Target to collect state for. This can be node or edge embeddings of shape (batch, [N], F)

        Returns:
            state: Graph state repeated for each node or edge of the target, of shape (batch, [N], F).
        """
        found_state_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type="tensor",
            node_indexing=self.node_indexing)
        found_ref_type = kgcnn_ops_get_tensor_type(
            inputs[1],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        # We cast to values here
        env = kgcnn_ops_dyn_cast(inputs[0],
                                 input_tensor_type=found_state_type,
                                 output_tensor_type="tensor",
                                 partition_type=self.partition_type)
        _, target_part = kgcnn_ops_dyn_cast(
            inputs[1],
            input_tensor_type=found_ref_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        target_len = kgcnn_ops_change_partition_type(target_part,
                                                     self.partition_type,
                                                     "row_length")

        out = tf.repeat(env, target_len, axis=0)

        return kgcnn_ops_dyn_cast([out, target_part],
                                  input_tensor_type="values_partition",
                                  output_tensor_type=found_ref_type,
                                  partition_type=self.partition_type)
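
The broadcast of a per-graph state onto its nodes or edges is a single tf.repeat with the row lengths. In isolation:

import tensorflow as tf

env = tf.constant([[0.5], [1.5]])   # one state per graph, shape (batch, F)
target_len = tf.constant([2, 3])    # "row_length" partition of the target

# Repeat each graph state once per target element, giving shape (batch*None, F).
out = tf.repeat(env, target_len, axis=0)
# [[0.5], [0.5], [1.5], [1.5], [1.5]]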
Code example #6
File: pooling.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs: Edge features or message embeddings of shape (batch, [N], F)
    
        Returns:
            features (tf.tensor): Pooled edge features of shape (batch, F).
        """
        found_edge_type = kgcnn_ops_get_tensor_type(inputs, input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        edge, edge_part = kgcnn_ops_dyn_cast(inputs, input_tensor_type=found_edge_type,
                                             output_tensor_type="values_partition",
                                             partition_type=self.partition_type)

        batchi = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "value_rowids")

        out = kgcnn_ops_segment_operation_by_name(self.pooling_method, edge, batchi)
        # Output already has correct shape and type
        return out
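
Per-graph pooling reduces all edges of one graph to a single row. Since "value_rowids" is ascending by construction, a sorted segment op suffices; a toy sketch:

import tensorflow as tf

edges = tf.ragged.constant([[[1.0], [2.0]], [[3.0]]], ragged_rank=1)

# value_rowids assigns every flat edge to its graph: [0, 0, 1].
batchi = edges.value_rowids()
out = tf.math.segment_sum(edges.values, batchi)  # [[3.], [3.]], shape (batch, F)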
Code example #7
    def call(self, inputs, **kwargs):
        """Forward pass.

        The input can be a tuple of (values, partition) tensors with values of shape (batch*None, F)
        and a partition tensor of type "row_length", "row_splits" or "value_rowids". This usually uses
        disjoint indexing defined by 'node_indexing'. Alternatively, a tuple of (values, mask) tensors of
        shape (batch, N, F) with mask (batch, N), a single RaggedTensor of shape (batch, None, F),
        or a single tensor for equally sized graphs of shape (batch, N, F) can be passed.

        Args:
            inputs: Embeddings to be encoded of shape (batch, [N], F)
        
        Returns:
            q_star (tf.tensor): Pooled tensor of shape (batch,1,2*channels)
        """
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs,
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        x, batch_part = kgcnn_ops_dyn_cast(
            inputs,
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        batch_num = kgcnn_ops_change_partition_type(batch_part,
                                                    self.partition_type,
                                                    "row_length")
        batch_index = kgcnn_ops_change_partition_type(batch_part,
                                                      self.partition_type,
                                                      "value_rowids")

        # Reading into memory is omitted here; it is to be done separately.
        m = x  # (batch*None,feat)

        # Initialize q0 and r0
        qstar = self.qstar0(m, batch_index, batch_num)

        # start loop
        for i in range(0, self.T):
            q = self.lay_lstm(qstar)  # (batch,feat)
            qt = tf.repeat(q, batch_num, axis=0)  # (batch*num,feat)
            et = self.f_et(m, qt)  # (batch*num,)
            # Attention a_t = exp(e_t)/sum(exp(e_t)), with the max subtracted for stability
            at = ksb.exp(et - self.get_scale_per_sample(
                et, batch_index, batch_num))  # (batch*num,)
            norm = self.get_norm(at, batch_index, batch_num)  # (batch*num,)
            at = norm * at  # (batch*num,) x (batch*num,)
            # calculate rt
            at = ksb.expand_dims(at, axis=1)
            rt = m * at  # (batch*num,feat) x (batch*num,1)
            rt = tf.math.segment_sum(rt, batch_index)  # (batch,feat)
            # qstar = [q,r]
            qstar = ksb.concatenate([q, rt], axis=1)  # (batch,2*feat)
            qstar = ksb.expand_dims(qstar, axis=1)  # (batch,1,2*feat)

        return qstar
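
The normalization inside the loop is a per-graph softmax, shifted by the segment maximum for numerical stability; get_scale_per_sample and get_norm are layer-internal helpers. The pattern in plain TensorFlow:

import tensorflow as tf

et = tf.constant([1.0, 2.0, 3.0])     # unnormalized scores, flat over all nodes
batch_index = tf.constant([0, 0, 1])  # "value_rowids" graph assignment

# Subtract the per-graph maximum before exponentiating, then normalize per graph.
scale = tf.gather(tf.math.segment_max(et, batch_index), batch_index)
at = tf.exp(et - scale)
norm = tf.gather(tf.math.segment_sum(at, batch_index), batch_index)
at = at / norm  # softmax within each graph segment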
Code example #8
File: casting.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: Graph tensor information.

        Returns:
            outputs: The same information in the changed tensor representation.
        """
        return kgcnn_ops_dyn_cast(inputs,
                                  input_tensor_type=self.input_tensor_type,
                                  output_tensor_type=self.output_tensor_type,
                                  partition_type=self.partition_type)
Code example #9
File: update.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge,
        i.e. of shape (batch, None, 2).

        Args:
            inputs (list): [trafo, edges]

            - trafo: Transformation matrices for each message, flattened to shape (batch, [N], F*F)
            - edges: Edge embeddings or messages (batch, [N], F)
            
        Returns:
            node_updates: Transformation of messages by matrix multiplication of shape (batch, [N], F)
        """
        found_trafo_type = kgcnn_ops_get_tensor_type(inputs[0], input_tensor_type=self.input_tensor_type,
                                                     node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(inputs[1], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        dens_trafo, trafo_part = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=found_trafo_type,
                                                    output_tensor_type="values_partition",
                                                    partition_type=self.partition_type)
        dens_e, epart = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=found_edge_type,
                                           output_tensor_type="values_partition",
                                           partition_type=self.partition_type)

        dens_m = tf.reshape(dens_trafo,
                            (ks.backend.shape(dens_trafo)[0], self.target_shape, ks.backend.shape(dens_e)[-1]))
        out = tf.keras.backend.batch_dot(dens_m, dens_e)

        return kgcnn_ops_dyn_cast([out, epart], input_tensor_type="values_partition",
                                  output_tensor_type=found_edge_type, partition_type=self.partition_type)
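
tf.keras.backend.batch_dot applies one matrix per message; with matrices of shape (num, F', F) and messages of shape (num, F) it contracts the last axes. A toy check; tf.einsum('efg,eg->ef', ...) would be an equivalent, more explicit alternative:

import tensorflow as tf

# One 2x2 transformation per message, applied to each 2-dim message.
dens_m = tf.constant([[[1.0, 0.0], [0.0, 2.0]]])  # shape (num, F, F)
dens_e = tf.constant([[3.0, 4.0]])                # shape (num, F)

out = tf.keras.backend.batch_dot(dens_m, dens_e)  # [[3., 8.]], one matvec per message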
Code example #10
File: pooling.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, weights]

            - nodes: Node features of shape (batch, [N], F)
            - weights: Node or message weights. Must broadcast with the nodes, e.g. shape (batch, [N], 1).

        Returns:
            nodes (tf.tensor): Pooled node features of shape (batch, F)
        """
        found_node_type = kgcnn_ops_get_tensor_type(inputs[0], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_weight_type = kgcnn_ops_get_tensor_type(inputs[1], input_tensor_type=self.input_tensor_type,
                                                      node_indexing=self.node_indexing)
        nod, node_part = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=found_node_type,
                                            output_tensor_type="values_partition",
                                            partition_type=self.partition_type)
        weights, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=found_weight_type,
                                        output_tensor_type="values_partition",
                                        partition_type=self.partition_type)

        batchi = kgcnn_ops_change_partition_type(node_part, self.partition_type, "value_rowids")

        nod = tf.math.multiply(nod, weights)
        out = kgcnn_ops_segment_operation_by_name(self.pooling_method, nod, batchi)
        # Output should have correct shape
        return out
Code example #11
File: update.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge,
        i.e. of shape (batch, None, 2).

        Args:
            inputs (list): [nodes, updates]

            - nodes (tf.tensor): Node embeddings of shape (batch, [N], F)
            - updates (tf.tensor): Matching node updates of shape (batch, [N], F)

        Returns:
            updated_nodes (tf.tensor): Updated node embeddings of shape (batch, [N], F)
        """
        found_node_type = kgcnn_ops_get_tensor_type(inputs[0], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_updates_type = kgcnn_ops_get_tensor_type(inputs[1], input_tensor_type=self.input_tensor_type,
                                                       node_indexing=self.node_indexing)
        n, npart = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=found_node_type,
                                      output_tensor_type="values_partition",
                                      partition_type=self.partition_type)
        eu, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=found_updates_type,
                                   output_tensor_type="values_partition",
                                   partition_type=self.partition_type)

        out, _ = self.gru_cell(eu, [n], **kwargs)

        return kgcnn_ops_dyn_cast([out, npart], input_tensor_type="values_partition",
                                  output_tensor_type=found_node_type, partition_type=self.partition_type)
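
The GRU update treats the flat node dimension (batch*None) as the recurrent batch axis, with the aggregated updates as cell input and the current embeddings as state. A minimal sketch with arbitrary sizes:

import tensorflow as tf

gru_cell = tf.keras.layers.GRUCell(units=4)
n = tf.random.normal((5, 4))   # current node embeddings, shape (batch*None, F)
eu = tf.random.normal((5, 4))  # aggregated node updates as cell input

out, _ = gru_cell(eu, [n])     # updated node embeddings, shape (batch*None, 4)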
Code example #12
File: sparse.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edges, edge_index]

            - nodes: Node feature tensor of shape (batch, [N], F)
            - edges: Edge feature ragged tensor of shape (batch, [N], 1)
            - edge_index: Ragged edge_indices of shape (batch, [N], 2)

        Returns:
            tf.sparse: Sparse disjoint matrix of shape (batch*None, batch*None)
        """
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(
            inputs[1],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_index_type = kgcnn_ops_get_tensor_type(
            inputs[2],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)

        nod, node_part = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edge, _ = kgcnn_ops_dyn_cast(inputs[1],
                                     input_tensor_type=found_edge_type,
                                     output_tensor_type="values_partition",
                                     partition_type=self.partition_type)
        edge_index, edge_part = kgcnn_ops_dyn_cast(
            inputs[2],
            input_tensor_type=found_index_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        # Cast to length tensor
        node_len = kgcnn_ops_change_partition_type(node_part,
                                                   self.partition_type,
                                                   "row_length")
        edge_len = kgcnn_ops_change_partition_type(edge_part,
                                                   self.partition_type,
                                                   "row_length")

        # batch-wise indexing
        edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edge_index,
            node_len,
            edge_len,
            partition_type_node="row_length",
            partition_type_edge="row_length",
            from_indexing=self.node_indexing,
            to_indexing="batch")
        indexlist = edge_index
        valuelist = edge

        if not self.is_sorted:
            # Sort by the second index (outgoing node j) first,
            batch_order = tf.argsort(indexlist[:, 1],
                                     axis=0,
                                     direction='ASCENDING')
            indexlist = tf.gather(indexlist, batch_order, axis=0)
            valuelist = tf.gather(valuelist, batch_order, axis=0)
            # then stably by the first index (ingoing node i), giving row-major order
            node_order = tf.argsort(indexlist[:, 0],
                                    axis=0,
                                    direction='ASCENDING',
                                    stable=True)
            indexlist = tf.gather(indexlist, node_order, axis=0)
            valuelist = tf.gather(valuelist, node_order, axis=0)

        indexlist = tf.cast(indexlist, dtype=tf.int64)
        dense_shape = tf.concat(
            [tf.shape(nod)[0:1], tf.shape(nod)[0:1]], axis=0)
        dense_shape = tf.cast(dense_shape, dtype=tf.int64)
        out = tf.sparse.SparseTensor(indexlist, valuelist[:, 0], dense_shape)

        return out
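
The manual double argsort above establishes the canonical row-major index order that many tf.sparse ops require. When building a SparseTensor from unsorted entries, tf.sparse.reorder achieves the same; a minimal sketch:

import tensorflow as tf

# Unsorted COO entries of a small disjoint adjacency (4 nodes total).
indices = tf.constant([[2, 3], [0, 1], [3, 2], [1, 0]], dtype=tf.int64)
values = tf.constant([1.0, 1.0, 1.0, 1.0])
adj = tf.sparse.SparseTensor(indices, values, dense_shape=[4, 4])

# Sort to canonical row-major (i, then j) ordering.
adj = tf.sparse.reorder(adj)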
Code example #13
File: attention.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edges, attention, edge_index]

            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [N], F)
            - attention: Attention coefficients of shape (batch, [N], 1)
            - edge_index: Edge indices of shape (batch, [N], 2)

        Returns:
            embeddings: Feature tensor of pooled edge attentions for each node.
        """
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(
            inputs[1],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_att_type = kgcnn_ops_get_tensor_type(
            inputs[2],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_index_type = kgcnn_ops_get_tensor_type(
            inputs[3],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)

        # We cast to values here
        nod, node_part = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edge, _ = kgcnn_ops_dyn_cast(inputs[1],
                                     input_tensor_type=found_edge_type,
                                     output_tensor_type="values_partition",
                                     partition_type=self.partition_type)
        attention, _ = kgcnn_ops_dyn_cast(
            inputs[2],
            input_tensor_type=found_att_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edgeind, edge_part = kgcnn_ops_dyn_cast(
            inputs[3],
            input_tensor_type=found_index_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edgeind,
            node_part,
            edge_part,
            partition_type_node=self.partition_type,
            partition_type_edge=self.partition_type,
            to_indexing='batch',
            from_indexing=self.node_indexing)

        nodind = shiftind[:, 0]  # Pick first index, e.g. ingoing node
        dens = edge
        ats = attention
        if not self.is_sorted:
            # Sort edge indices by receiving node
            node_order = tf.argsort(nodind,
                                    axis=0,
                                    direction='ASCENDING',
                                    stable=True)
            nodind = tf.gather(nodind, node_order, axis=0)
            dens = tf.gather(dens, node_order, axis=0)
            ats = tf.gather(ats, node_order, axis=0)

        # Apply segmented softmax
        ats = segment_softmax(ats, nodind)
        get = dens * ats
        get = tf.math.segment_sum(get, nodind)

        if self.has_unconnected:
            # Fill the pooled tensor to the full node count, since nodes without
            # incoming edges are missing from the segment result.
            get = kgcnn_ops_scatter_segment_tensor_nd(get, nodind,
                                                      tf.shape(nod))

        return kgcnn_ops_dyn_cast([get, node_part],
                                  input_tensor_type="values_partition",
                                  output_tensor_type=found_node_type,
                                  partition_type=self.partition_type)
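
For has_unconnected, note that a sorted segment op already emits zero rows for missing intermediate ids; only ids above the largest receiver are absent. One way to pad to the full node count, sketched without the kgcnn scatter helper:

import tensorflow as tf

edge_feat = tf.constant([[1.0], [2.0], [3.0]])
nodind = tf.constant([0, 0, 2])  # node 1 has no edges; nodes 3 and 4 are unconnected
num_nodes = 5

pooled = tf.math.segment_sum(edge_feat, nodind)  # shape (3, 1), ids 0..max(nodind)
pad = num_nodes - tf.shape(pooled)[0]
out = tf.pad(pooled, [[0, pad], [0, 0]])
# shape (5, 1); unconnected nodes keep zero rows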
Code example #14
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edges, edge_indices]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edges: Edge embeddings of shape (batch, [N], F)
            - edge_indices: Edge index list of shape (batch, [N], 2)
        
        Returns:
            Tuple: [nodes, edges, edge_indices], [map_nodes, map_edges]
            
            - nodes: Pooled node feature tensor
            - edges: Pooled edge feature list
            - edge_indices (tf.tensor): Pooled edge index list
            - map_nodes (tf.tensor): Index map between original and pooled nodes
            - map_edges (tf.tensor): Index map between original and pooled edges
        """
        found_node_type = kgcnn_ops_get_tensor_type(
            inputs[0],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(
            inputs[1],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)
        found_index_type = kgcnn_ops_get_tensor_type(
            inputs[2],
            input_tensor_type=self.input_tensor_type,
            node_indexing=self.node_indexing)

        node, node_part = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_node_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edgefeat, edge_part = kgcnn_ops_dyn_cast(
            inputs[1],
            input_tensor_type=found_edge_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edgeindref, _ = kgcnn_ops_dyn_cast(
            inputs[2],
            input_tensor_type=found_index_type,
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        edgelen = kgcnn_ops_change_partition_type(edge_part,
                                                  self.partition_type,
                                                  "row_length")
        nodelen = kgcnn_ops_change_partition_type(node_part,
                                                  self.partition_type,
                                                  "row_length")

        index_dtype = edgeindref.dtype
        # Get node properties
        nvalue = node
        nrowlength = tf.cast(nodelen, dtype=index_dtype)
        erowlength = tf.cast(edgelen, dtype=index_dtype)
        nids = tf.repeat(tf.range(tf.shape(nrowlength)[0], dtype=index_dtype),
                         nrowlength)

        # Use kernel p to get score
        norm_p = ks.backend.sqrt(
            ks.backend.sum(ks.backend.square(self.kernel_p),
                           axis=-1,
                           keepdims=True))
        nscore = ks.backend.sum(nvalue * self.kernel_p / norm_p, axis=-1)

        # Sort nodes by score, then stably by former node ids,
        # so that the score order is preserved within each graph.
        sort1 = tf.argsort(nscore, direction='ASCENDING', stable=False)
        nids_sorted1 = tf.gather(nids, sort1)
        sort2 = tf.argsort(nids_sorted1, direction='ASCENDING',
                           stable=True)  # Must be stable=true here
        sort12 = tf.gather(
            sort1, sort2)  # index runs from 0 to batch*N, i.e. in flat batch indexing
        nvalue_sorted = tf.gather(nvalue, sort12, axis=0)
        nscore_sorted = tf.gather(nscore, sort12, axis=0)

        # Make Mask
        nremove = tf.cast(tf.math.round(
            self.k * tf.cast(nrowlength, dtype=tf.keras.backend.floatx())),
                          dtype=index_dtype)
        nkeep = nrowlength - nremove
        n_remove_keep = ks.backend.flatten(
            tf.concat([
                ks.backend.expand_dims(nremove, axis=-1),
                ks.backend.expand_dims(nkeep, axis=-1)
            ],
                      axis=-1))
        mask_remove_keep = ks.backend.flatten(
            tf.concat([
                ks.backend.expand_dims(tf.zeros_like(nremove, dtype=tf.bool),
                                       axis=-1),
                ks.backend.expand_dims(tf.ones_like(nkeep, tf.bool), axis=-1)
            ],
                      axis=-1))
        mask = tf.repeat(mask_remove_keep, n_remove_keep)

        # Apply Mask to remove lower score nodes
        pooled_n = nvalue_sorted[mask]
        pooled_score = nscore_sorted[mask]
        pooled_id = nids[mask]  # nids is unchanged by the double sort, since it was already ascending
        pooled_len = nkeep  # shape=(batch,)
        pooled_index = tf.cast(
            sort12[mask],
            dtype=index_dtype)  # the index goes from 0 to N*batch
        removed_index = tf.cast(
            sort12[tf.math.logical_not(mask)],
            dtype=index_dtype)  # the index goes from 0 to N*batch

        # Pass through gate
        gated_n = pooled_n * ks.backend.expand_dims(
            tf.keras.activations.sigmoid(pooled_score), axis=-1)

        # Make index map for new nodes towards old index
        index_new_nodes = tf.range(tf.shape(pooled_index)[0],
                                   dtype=index_dtype)
        old_shape = tf.cast(ks.backend.expand_dims(tf.shape(nvalue)[0]),
                            dtype=index_dtype)
        map_index = tf.scatter_nd(
            ks.backend.expand_dims(pooled_index, axis=-1), index_new_nodes,
            old_shape)

        # Shift index if necessary
        edge_ids = tf.repeat(tf.range(tf.shape(edgelen)[0], dtype=index_dtype),
                             edgelen)

        shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edgeindref,
            nrowlength,
            edge_ids,
            partition_type_node="row_length",
            partition_type_edge="value_rowids",
            from_indexing=self.node_indexing,
            to_indexing="batch")

        shiftind = tf.cast(
            shiftind,
            dtype=index_dtype)  # already shifted by batch offset (subgraphs)

        # Remove edges that were from filtered nodes via mask
        mask_edge = ks.backend.expand_dims(
            shiftind, axis=-1) == ks.backend.expand_dims(
                ks.backend.expand_dims(removed_index, axis=0),
                axis=0)  # creates a large tensor of shape (batch*#edges, 2, #removed)
        mask_edge = tf.math.logical_not(
            ks.backend.any(ks.backend.any(mask_edge, axis=-1), axis=-1))
        clean_shiftind = shiftind[mask_edge]
        clean_edge_ids = edge_ids[mask_edge]
        # clean_edge_len = tf.math.segment_sum(tf.ones_like(clean_edge_ids), clean_edge_ids)
        clean_edge_len = tf.scatter_nd(
            tf.expand_dims(clean_edge_ids, axis=-1),
            tf.ones_like(clean_edge_ids),
            tf.cast(tf.shape(erowlength), dtype=index_dtype))

        # Map edgeindex to new index
        new_edge_index = tf.concat([
            ks.backend.expand_dims(tf.gather(map_index, clean_shiftind[:, 0]),
                                   axis=-1),
            ks.backend.expand_dims(tf.gather(map_index, clean_shiftind[:, 1]),
                                   axis=-1)
        ],
                                   axis=-1)
        batch_order = tf.argsort(new_edge_index[:, 0],
                                 axis=0,
                                 direction='ASCENDING',
                                 stable=True)
        new_edge_index_sorted = tf.gather(new_edge_index, batch_order, axis=0)

        # Remove the batch offset from edge_indices again for indexing type
        out_indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            new_edge_index_sorted,
            pooled_len,
            clean_edge_ids,
            partition_type_node="row_length",
            partition_type_edge="value_rowids",
            from_indexing="batch",
            to_indexing=self.node_indexing)

        # Correct edge features the same way (remove and reorder)
        edge_feat = edgefeat
        clean_edge_feat = edge_feat[mask_edge]
        clean_edge_feat_sorted = tf.gather(clean_edge_feat,
                                           batch_order,
                                           axis=0)

        # Make edge feature map for new edge features
        edge_position_old = tf.range(tf.shape(edgefeat)[0], dtype=index_dtype)
        edge_position_new = edge_position_old[mask_edge]
        edge_position_new = tf.gather(edge_position_new, batch_order, axis=0)

        # Collect output tensors
        out_node = gated_n
        out_edge = clean_edge_feat_sorted
        out_edge_index = out_indexlist

        # Change length to partition required
        out_np = kgcnn_ops_change_partition_type(pooled_len, "row_length",
                                                 self.partition_type)
        out_ep = kgcnn_ops_change_partition_type(clean_edge_len, "row_length",
                                                 self.partition_type)

        # Collect reverse-pooling info
        # Remove the batch offset from the old indices, but with the new lengths
        out_pool = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            pooled_index,
            nrowlength,
            pooled_len,
            partition_type_node="row_length",
            partition_type_edge="row_length",
            from_indexing="batch",
            to_indexing=self.node_indexing,
            axis=0)
        out_pool_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edge_position_new,
            erowlength,
            clean_edge_ids,
            partition_type_node="row_length",
            partition_type_edge="value_rowids",
            from_indexing="batch",
            to_indexing=self.node_indexing,
            axis=0)

        out = [
            kgcnn_ops_dyn_cast([out_node, out_np],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_node_type,
                               partition_type=self.partition_type),
            kgcnn_ops_dyn_cast([out_edge, out_ep],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_edge_type,
                               partition_type=self.partition_type),
            kgcnn_ops_dyn_cast([out_edge_index, out_ep],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_index_type,
                               partition_type=self.partition_type)
        ]

        out_map = [
            kgcnn_ops_dyn_cast([out_pool, out_np],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_node_type,
                               partition_type=self.partition_type),
            kgcnn_ops_dyn_cast([out_pool_edge, out_ep],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_edge_type,
                               partition_type=self.partition_type)
        ]

        return out, out_map
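
The remove/keep mask above interleaves the per-graph counts and expands them with tf.repeat over the score-sorted nodes. The trick in isolation, for two graphs with 3 and 2 nodes where one node is dropped from the first graph:

import tensorflow as tf

nremove = tf.constant([1, 0])  # lowest-score nodes dropped per graph
nkeep = tf.constant([2, 2])    # nodes kept per graph

# Interleave as [remove_0, keep_0, remove_1, keep_1] and expand to a flat node mask.
n_remove_keep = tf.reshape(tf.stack([nremove, nkeep], axis=-1), [-1])  # [1, 2, 0, 2]
mask_remove_keep = tf.reshape(tf.stack([tf.zeros_like(nremove, tf.bool),
                                        tf.ones_like(nkeep, tf.bool)], axis=-1), [-1])
mask = tf.repeat(mask_remove_keep, n_remove_keep)
# [False, True, True, True, True] over the sorted nodes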
Code example #15
    def call(self, inputs, **kwargs):
        """Forward pass.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge,
        i.e. of shape (batch, None, 2).

        Args:
            inputs (list): [node, edge, edge_indices, map_node, map_edge, node_pool, edge_pool, edge_indices_pool]

            - node: Original node tensor
            - edge: Original edge feature tensor
            - edge_indices: Original index tensor
            - map_node: Index map between original and pooled nodes
            - map_edge: Index map between original and pooled edges
            - node_pool: Pooled node tensor
            - edge_pool: Pooled edge feature tensor
            - edge_indices_pool: Pooled index tensor
        
        Returns:
            List: [nodes, edges, edge_indices]
            
            - nodes: Unpooled node feature tensor
            - edges: Unpooled edge feature list
            - edge_indices: Unpooled edge index
        """
        found_input_type = [
            kgcnn_ops_get_tensor_type(inputs[i],
                                      input_tensor_type=self.input_tensor_type,
                                      node_indexing=self.node_indexing)
            for i in range(8)
        ]

        node_old, nodepart_old = kgcnn_ops_dyn_cast(
            inputs[0],
            input_tensor_type=found_input_type[0],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edge_old, edgepart_old = kgcnn_ops_dyn_cast(
            inputs[1],
            input_tensor_type=found_input_type[1],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edgeind_old, _ = kgcnn_ops_dyn_cast(
            inputs[2],
            input_tensor_type=found_input_type[2],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        map_node, _ = kgcnn_ops_dyn_cast(inputs[3],
                                         input_tensor_type=found_input_type[3],
                                         output_tensor_type="values_partition",
                                         partition_type=self.partition_type)
        map_edge, _ = kgcnn_ops_dyn_cast(inputs[4],
                                         input_tensor_type=found_input_type[4],
                                         output_tensor_type="values_partition",
                                         partition_type=self.partition_type)
        node_new, nodepart_new = kgcnn_ops_dyn_cast(
            inputs[5],
            input_tensor_type=found_input_type[5],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edge_new, edgepart_new = kgcnn_ops_dyn_cast(
            inputs[6],
            input_tensor_type=found_input_type[6],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)
        edgeind_new, _ = kgcnn_ops_dyn_cast(
            inputs[7],
            input_tensor_type=found_input_type[7],
            output_tensor_type="values_partition",
            partition_type=self.partition_type)

        nrowlength = kgcnn_ops_change_partition_type(nodepart_old,
                                                     self.partition_type,
                                                     "row_length")
        erowlength = kgcnn_ops_change_partition_type(edgepart_old,
                                                     self.partition_type,
                                                     "row_length")
        pool_node_len = kgcnn_ops_change_partition_type(
            nodepart_new, self.partition_type, "row_length")
        pool_edge_id = kgcnn_ops_change_partition_type(edgepart_new,
                                                       self.partition_type,
                                                       "value_rowids")

        # Correct map index for flatten batch offset
        map_node = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            map_node,
            nrowlength,
            pool_node_len,
            partition_type_node="row_length",
            partition_type_edge="row_length",
            from_indexing=self.node_indexing,
            to_indexing="batch",
            axis=0)
        map_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            map_edge,
            erowlength,
            pool_edge_id,
            partition_type_node="row_length",
            partition_type_edge="value_rowids",
            from_indexing=self.node_indexing,
            to_indexing="batch",
            axis=0)

        index_dtype = map_node.dtype
        node_shape = tf.stack([
            tf.cast(tf.shape(node_old)[0], dtype=index_dtype),
            tf.cast(tf.shape(node_new)[1], dtype=index_dtype)
        ])
        out_node = tf.scatter_nd(ks.backend.expand_dims(map_node, axis=-1),
                                 node_new, node_shape)

        index_dtype = map_edge.dtype
        edge_shape = tf.stack([
            tf.cast(tf.shape(edge_old)[0], dtype=index_dtype),
            tf.cast(tf.shape(edge_new)[1], dtype=index_dtype)
        ])
        out_edge = tf.scatter_nd(ks.backend.expand_dims(map_edge, axis=-1),
                                 edge_new, edge_shape)

        outlist = [
            kgcnn_ops_dyn_cast([out_node, nodepart_old],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_input_type[0],
                               partition_type=self.partition_type),
            kgcnn_ops_dyn_cast([out_edge, edgepart_old],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_input_type[1],
                               partition_type=self.partition_type),
            kgcnn_ops_dyn_cast([edgeind_old, edgepart_old],
                               input_tensor_type="values_partition",
                               output_tensor_type=found_input_type[2],
                               partition_type=self.partition_type)
        ]
        return outlist
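
The unpooling itself is a scatter of the pooled features back into a zero tensor of the original flat shape; removed nodes keep zero rows. In isolation:

import tensorflow as tf

map_node = tf.constant([0, 2])          # pooled position -> original flat position
node_new = tf.constant([[1.0], [2.0]])  # pooled node features
node_shape = tf.constant([3, 1])        # original flat node count and feature dim

out_node = tf.scatter_nd(tf.expand_dims(map_node, axis=-1), node_new, node_shape)
# [[1.], [0.], [2.]]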
Code example #16
File: connect.py Project: thegodone/gcnn_keras
    def call(self, inputs, **kwargs):
        """Forward path.

        The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
        The RaggedTensor has shape (batch, None, F), or (batch, N, F) in the case of equally sized graphs.
        For the disjoint representation (values, partition), the node embeddings are given by
        a flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
        "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case,
        the partition_type and the node_indexing scheme (e.g. "batch") must be known by the layer.
        For edge indices, the last dimension holds indices from outgoing to ingoing node (i, j) as a directed edge.

        Args:
            inputs (list): [nodes, edges, edge_indices]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edges: Adjacency entries of shape (batch, [N], 1)
            - edge_indices: Index list of shape (batch, [N], 2)
            
        Returns:
            list: [edges, edge_indices]

            - edges: Adjacency entries of the n-th adjacency power, of shape (batch, [N], 1)
            - edge_indices: Index list of shape (batch, [N], 2)
        """
        found_node_type = kgcnn_ops_get_tensor_type(inputs[0], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_edge_type = kgcnn_ops_get_tensor_type(inputs[1], input_tensor_type=self.input_tensor_type,
                                                    node_indexing=self.node_indexing)
        found_index_type = kgcnn_ops_get_tensor_type(inputs[2], input_tensor_type=self.input_tensor_type,
                                                     node_indexing=self.node_indexing)

        nod, node_part = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=found_node_type,
                                            output_tensor_type="values_partition",
                                            partition_type=self.partition_type)
        edge, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=found_edge_type,
                                     output_tensor_type="values_partition",
                                     partition_type=self.partition_type)
        edge_index, edge_part = kgcnn_ops_dyn_cast(inputs[2], input_tensor_type=found_index_type,
                                                   output_tensor_type="values_partition",
                                                   partition_type=self.partition_type)

        # Cast to length tensor
        node_len = kgcnn_ops_change_partition_type(node_part, self.partition_type, "row_length")
        edge_len = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "row_length")

        # Convert edge indices to sample-wise (per-graph) indexing
        edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(edge_index,
                                                                            node_len, edge_len,
                                                                            partition_type_node="row_length",
                                                                            partition_type_edge="row_length",
                                                                            from_indexing=self.node_indexing,
                                                                            to_indexing="sample")

        ind_batch = tf.cast(tf.expand_dims(tf.repeat(tf.range(tf.shape(edge_len)[0]), edge_len), axis=-1),
                            dtype=edge_index.dtype)
        ind_all = tf.concat([ind_batch, edge_index], axis=-1)
        ind_all = tf.cast(ind_all, dtype=tf.int64)

        max_index = tf.reduce_max(node_len)
        dense_shape = tf.stack([tf.cast(tf.shape(node_len)[0], dtype=max_index.dtype), max_index, max_index])
        adj = tf.zeros(dense_shape, dtype=edge.dtype)
        ind_flat = tf.range(tf.cast(tf.shape(node_len)[0], dtype=max_index.dtype) * max_index * max_index)

        adj = tf.expand_dims(adj, axis=-1)
        adj = tf.tensor_scatter_nd_update(adj, ind_all, edge[:, 0:1])
        adj = tf.squeeze(adj, axis=-1)

        out0 = adj
        out = adj
        for i in range(self.n - 1):
            out = tf.linalg.matmul(out, out0)

        # debug_result = out

        # sparsify
        mask = out > tf.keras.backend.epsilon()
        mask = tf.reshape(mask, (-1,))
        out = tf.reshape(out, (-1,))

        new_edge = out[mask]
        new_edge = tf.expand_dims(new_edge, axis=-1)
        new_indices = tf.unravel_index(ind_flat[mask], dims=dense_shape)
        new_edge_ids = new_indices[0]
        new_edge_index = tf.concat([tf.expand_dims(new_indices[1], axis=-1), tf.expand_dims(new_indices[2], axis=-1)],
                                   axis=-1)
        new_edge_len = tf.tensor_scatter_nd_add(tf.zeros_like(node_len), tf.expand_dims(new_edge_ids, axis=-1),
                                                tf.ones_like(new_edge_ids))

        # Output partition
        new_edge_part = kgcnn_ops_change_partition_type(new_edge_len, "row_length", self.partition_type)

        # Convert back from sample indexing to the layer's node_indexing
        new_edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(new_edge_index,
                                                                                node_len, new_edge_len,
                                                                                partition_type_node="row_length",
                                                                                partition_type_edge="row_length",
                                                                                from_indexing="sample",
                                                                                to_indexing=self.node_indexing)

        outlist = [kgcnn_ops_dyn_cast([new_edge, new_edge_part], input_tensor_type="values_partition",
                                      output_tensor_type=found_edge_type, partition_type=self.partition_type),
                   kgcnn_ops_dyn_cast([new_edge_index, new_edge_part], input_tensor_type="values_partition",
                                      output_tensor_type=found_index_type, partition_type=self.partition_type)
                   ]

        return outlist
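
The matmul loop computes the n-th power of the (scattered, dense) adjacency, whose (i, j) entry counts weighted walks of length n; entries below keras epsilon are then pruned back into an edge list. The power itself, on a toy path graph:

import tensorflow as tf

# Dense adjacency of the path graph 0-1-2.
adj = tf.constant([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])

n = 2
out = adj
for _ in range(n - 1):
    out = tf.linalg.matmul(out, adj)  # out = adj^n
# adj^2: entry (i, j) = number of length-2 walks, e.g. (0, 0) -> 1, (0, 2) -> 1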