def _kgcnn_map_output_ragged(self, inputs, input_partition_type, output_tensor_type=None):
    """Map a ragged-like result to the tensor representation requested for the output.

    Args:
        inputs: Either a ``tf.RaggedTensor`` or a list ``[values, partition]``.
        input_partition_type: Partition type of ``inputs[1]`` if ``inputs`` is a list.
        output_tensor_type: Requested representation. ``None`` selects
            ``self.output_tensor_type``; an ``int`` is used as index into
            ``self._tensor_input_type_found``. After this resolution one of
            "ragged", "RaggedTensor", "disjoint" or "values_partition" is accepted.

    Returns:
        The input itself, the result of ``kgcnn_ops_dyn_cast``, or a list
        ``[values, partition]`` whose partition was changed to ``self.partition_type``.

    Raises:
        TypeError: If ``inputs`` is neither a ``tf.RaggedTensor`` nor a list, or if the
            resolved output type is not one of the accepted keys.
    """
    ragged_keys = ["ragged", "RaggedTensor"]
    value_partition_keys = ["disjoint", "values_partition"]
    x = inputs

    # Resolve the requested output type: layer default or lookup by input position.
    if output_tensor_type is None:
        output_tensor_type = self.output_tensor_type
    elif isinstance(output_tensor_type, int):
        output_tensor_type = self._tensor_input_type_found[output_tensor_type]

    # Build one readable message instead of passing print-style arguments to TypeError,
    # which would render the exception text as a tuple.
    unsupported_output = "Error: {} output type {!r} for ragged-like input is not supported for {}"
    unsupported_input = "Error: {} input type for ragged-like input is not supported for {}"

    if isinstance(x, tf.RaggedTensor):
        if output_tensor_type in ragged_keys:
            return x
        if output_tensor_type in value_partition_keys:
            return kgcnn_ops_dyn_cast(x, input_tensor_type="ragged",
                                      output_tensor_type=output_tensor_type,
                                      partition_type=self.partition_type)
        raise TypeError(unsupported_output.format(self.name, output_tensor_type, x))

    if isinstance(x, list):
        if len(x) != 2:
            # Only warn: the conversion below is still attempted, as before.
            print("Warning:", self.name, "output does not match rank=1 partition scheme for batch dimension.")
        if output_tensor_type in ragged_keys:
            return kgcnn_ops_dyn_cast(x, input_tensor_type="values_partition",
                                      output_tensor_type=output_tensor_type,
                                      partition_type=input_partition_type)
        if output_tensor_type in value_partition_keys:
            # Values stay untouched, only the partition tensor is converted.
            tens_part = kgcnn_ops_change_partition_type(x[1], input_partition_type, self.partition_type)
            return [x[0], tens_part]
        raise TypeError(unsupported_output.format(self.name, output_tensor_type, x))

    raise TypeError(unsupported_input.format(self.name, x))
def call(self, inputs, **kwargs):
    """Forward pass.

    Args:
        inputs: [nodes, adjacency]

            - nodes: Node features of shape (batch, [N], F)
            - adjacency (tf.sparse): SparseTensor of the adjacency matrix of shape (batch*None,batch*None)

    Returns:
        features: Product of the adjacency matrix with the flattened node values, cast back to
        the tensor representation that was detected for the node input.
    """
    adjacency = inputs[1]
    node_type = kgcnn_ops_get_tensor_type(
        inputs[0],
        input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing,
    )
    # Flatten nodes to (values, partition) so they can be multiplied with the adjacency.
    node_values, node_partition = kgcnn_ops_dyn_cast(
        inputs[0],
        input_tensor_type=node_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type,
    )
    aggregated = tf.sparse.sparse_dense_matmul(adjacency, node_values)
    return kgcnn_ops_dyn_cast(
        [aggregated, node_partition],
        input_tensor_type="values_partition",
        output_tensor_type=node_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs (list): of [node, edges, edge_index]

            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [N], 2)

    Returns:
        features: Pooled feature tensor of pooled edge features for each node.
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type, edge_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], **type_kwargs) for pos in range(3)
    ]

    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    edge_values, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=edge_type, **as_values)
    raw_index, edge_partition = kgcnn_ops_dyn_cast(inputs[2], input_tensor_type=index_type, **as_values)

    # Express the edge indices in 'batch' indexing for the flattened node values.
    batch_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        raw_index, node_partition, edge_partition,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing,
    )

    # First index column selects the node each edge is pooled into (e.g. ingoing).
    segment_ids = batch_index[:, 0]
    messages = edge_values
    if not self.is_sorted:
        # Bring the edges into ascending order of their pooling node.
        order = tf.argsort(segment_ids, axis=0, direction='ASCENDING', stable=True)
        segment_ids = tf.gather(segment_ids, order, axis=0)
        messages = tf.gather(messages, order, axis=0)

    # Segment-wise pooling, e.g. segment_sum, chosen by name.
    pooled = kgcnn_ops_segment_operation_by_name(self.pooling_method, messages, segment_ids)

    if self.has_unconnected:
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, segment_ids, tf.shape(node_values))

    return kgcnn_ops_dyn_cast(
        [pooled, node_partition],
        input_tensor_type="values_partition",
        output_tensor_type=node_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs (list): [nodes, edge_index]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [N], 2)

    Returns:
        embeddings: Gathered node embeddings that match the number of edges.
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type = kgcnn_ops_get_tensor_type(inputs[0], **type_kwargs)
    index_type = kgcnn_ops_get_tensor_type(inputs[1], **type_kwargs)

    # Work on flattened values plus partition information.
    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    raw_index, edge_partition = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=index_type, **as_values)

    batch_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        raw_index, node_partition, edge_partition,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing,
    )

    # Select one node embedding per edge via the first index column.
    # NOTE: for ragged input a batched gather (batch_dims=1) on the ragged tensors
    # would be a possible alternative to this flattened gather.
    gathered = tf.gather(node_values, batch_index[:, 0], axis=0)

    # The result has one entry per edge, hence the edge partition and the index tensor type.
    return kgcnn_ops_dyn_cast(
        [gathered, edge_partition],
        input_tensor_type="values_partition",
        output_tensor_type=index_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs: [state, target]

            - state: Graph specific embedding tensor. This is usually tensor of shape (batch, F.)
            - target: Target to collect state for. This can be node or edge embeddings of shape (batch, [N], F)

    Returns:
        state: Graph embedding with repeated single state for each graph.
    """
    state_in, target_in = inputs[0], inputs[1]

    state_type = kgcnn_ops_get_tensor_type(
        state_in, input_tensor_type="tensor", node_indexing=self.node_indexing)
    target_type = kgcnn_ops_get_tensor_type(
        target_in, input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)

    # The state is handled as a plain tensor; of the target only its partition is needed.
    state = kgcnn_ops_dyn_cast(
        state_in,
        input_tensor_type=state_type,
        output_tensor_type="tensor",
        partition_type=self.partition_type,
    )
    _, target_partition = kgcnn_ops_dyn_cast(
        target_in,
        input_tensor_type=target_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type,
    )

    # Repeat each graph's state once per target entry of that graph.
    repeats = kgcnn_ops_change_partition_type(target_partition, self.partition_type, "row_length")
    repeated_state = tf.repeat(state, repeats, axis=0)

    return kgcnn_ops_dyn_cast(
        [repeated_state, target_partition],
        input_tensor_type="values_partition",
        output_tensor_type=target_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs: Edge features or message embeddings of shape (batch, [N], F)

    Returns:
        tf.tensor: Pooled edges feature list of shape (batch,F).
    """
    edge_type = kgcnn_ops_get_tensor_type(
        inputs,
        input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing,
    )
    edge_values, edge_partition = kgcnn_ops_dyn_cast(
        inputs,
        input_tensor_type=edge_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type,
    )

    # One segment id per edge, taken from the partition in "value_rowids" form.
    graph_ids = kgcnn_ops_change_partition_type(edge_partition, self.partition_type, "value_rowids")

    # The pooled result is returned directly, without casting back.
    return kgcnn_ops_segment_operation_by_name(self.pooling_method, edge_values, graph_ids)
def call(self, inputs, **kwargs):
    """Forward pass.

    The input can be a tuple of (values, partition) tensors, with values of shape (batch*None,F)
    and a partition tensor of the type "row_length", "row_splits" or "value_rowids". This usually
    uses disjoint indexing defined by 'node_indexing'. It can also be a tuple of (values, mask)
    tensors of shape (batch, N, F) and mask (batch, N), a single RaggedTensor of shape
    (batch,None,F), or a single tensor for equally sized graphs (batch,N,F).

    Args:
        inputs: Embeddings to be encoded of shape (batch, [N], F)

    Returns:
        q_star (tf.tensor): Pooled tensor of shape (batch,1,2*channels)
    """
    input_type = kgcnn_ops_get_tensor_type(
        inputs,
        input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing,
    )
    memory, partition = kgcnn_ops_dyn_cast(
        inputs,
        input_tensor_type=input_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type,
    )  # memory: (batch*None,feat)
    lengths = kgcnn_ops_change_partition_type(partition, self.partition_type, "row_length")
    segment_ids = kgcnn_ops_change_partition_type(partition, self.partition_type, "value_rowids")

    # Initial query/readout state.
    q_star = self.qstar0(memory, segment_ids, lengths)

    for _ in range(0, self.T):
        query = self.lay_lstm(q_star)  # (batch,feat)
        query_per_item = tf.repeat(query, lengths, axis=0)  # (batch*num,feat)
        energy = self.f_et(memory, query_per_item)  # (batch*num,)

        # Attention weights: exp of the rescaled energies, normalised per sample.
        scale = self.get_scale_per_sample(energy, segment_ids, lengths)
        attention = ksb.exp(energy - scale)  # (batch*num,)
        norm = self.get_norm(attention, segment_ids, lengths)  # (batch*num,)
        attention = norm * attention

        # Attention-weighted readout of the memory.
        attention = ksb.expand_dims(attention, axis=1)
        readout = tf.math.segment_sum(memory * attention, segment_ids)  # (batch,feat)

        # New state is [query, readout] with an extra axis of size one.
        q_star = ksb.concatenate([query, readout], axis=1)  # (batch,2*feat)
        q_star = ksb.expand_dims(q_star, axis=1)  # (batch,1,2*feat)

    return q_star
def call(self, inputs, **kwargs):
    """Forward pass.

    Args:
        inputs: Graph tensor-information.

    Returns:
        outputs: Changed tensor-information.
    """
    # Cast from the layer's configured input type to its configured output type.
    outputs = kgcnn_ops_dyn_cast(
        inputs,
        input_tensor_type=self.input_tensor_type,
        output_tensor_type=self.output_tensor_type,
        partition_type=self.partition_type,
    )
    return outputs
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to
    ingoing node, i.e. (batch, None, 2).

    Args:
        inputs (list): of [trafo, edges]

            - trafo: Transformation by matrix multiplication for each message.
              Must be reshaped to (batch, [N], FxF).
            - edges: Edge embeddings or messages (batch, [N], F)

    Returns:
        node_updates: Transformation of messages by matrix multiplication of shape (batch, [N], F)
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    trafo_type = kgcnn_ops_get_tensor_type(inputs[0], **type_kwargs)
    edge_type = kgcnn_ops_get_tensor_type(inputs[1], **type_kwargs)

    trafo_values, _ = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=trafo_type, **as_values)
    edge_values, edge_partition = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=edge_type, **as_values)

    # Reshape the flat transformation into one matrix per message.
    matrix_shape = (
        ks.backend.shape(trafo_values)[0],
        self.target_shape,
        ks.backend.shape(edge_values)[-1],
    )
    matrices = tf.reshape(trafo_values, matrix_shape)

    # Apply each matrix to its message.
    transformed = tf.keras.backend.batch_dot(matrices, edge_values)

    return kgcnn_ops_dyn_cast(
        [transformed, edge_partition],
        input_tensor_type="values_partition",
        output_tensor_type=edge_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs (list): of [node, weights]

            - nodes: Node features of shape (batch, [N], F)
            - weights: Edge or message weights. Must broadcast to nodes.

    Returns:
        nodes (tf.tensor): Pooled node features of shape (batch,F)
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type = kgcnn_ops_get_tensor_type(inputs[0], **type_kwargs)
    weight_type = kgcnn_ops_get_tensor_type(inputs[1], **type_kwargs)

    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    weight_values, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=weight_type, **as_values)

    graph_ids = kgcnn_ops_change_partition_type(node_partition, self.partition_type, "value_rowids")

    # Weight the nodes, then pool them per graph; the result is returned as is.
    weighted_nodes = tf.math.multiply(node_values, weight_values)
    return kgcnn_ops_segment_operation_by_name(self.pooling_method, weighted_nodes, graph_ids)
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to
    ingoing node, i.e. (batch, None, 2).

    Args:
        inputs (list): of [nodes, updates]

            - nodes (tf.tensor): Node embeddings of shape (batch, [N], F)
            - updates (tf.tensor): Matching node updates of shape (batch, [N], F)

    Returns:
        updated_nodes (tf.tensor): Updated nodes, cast back to the tensor representation
        that was detected for the node input.
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type = kgcnn_ops_get_tensor_type(inputs[0], **type_kwargs)
    update_type = kgcnn_ops_get_tensor_type(inputs[1], **type_kwargs)

    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    update_values, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=update_type, **as_values)

    # The updates are the cell input, the current node values are its state.
    new_nodes, _ = self.gru_cell(update_values, [node_values], **kwargs)

    return kgcnn_ops_dyn_cast(
        [new_nodes, node_partition],
        input_tensor_type="values_partition",
        output_tensor_type=node_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs (list): of [nodes, edges, edge_index]

            - nodes: Node feature tensor of shape (batch, [N], F)
            - edges: Edge feature ragged tensor of shape (batch, [N], 1)
            - edge_index: Ragged edge_indices of shape (batch, [N], 2)

    Returns:
        tf.sparse: Sparse disjoint matrix of shape (batch*None,batch*None)
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type, edge_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], **type_kwargs) for pos in range(3)
    ]

    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    edge_values, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=edge_type, **as_values)
    raw_index, edge_partition = kgcnn_ops_dyn_cast(inputs[2], input_tensor_type=index_type, **as_values)

    # Partitions as length tensors.
    node_lengths = kgcnn_ops_change_partition_type(node_partition, self.partition_type, "row_length")
    edge_lengths = kgcnn_ops_change_partition_type(edge_partition, self.partition_type, "row_length")

    # Indices in 'batch' indexing address rows/columns of the disjoint matrix.
    positions = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        raw_index, node_lengths, edge_lengths,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="batch",
    )
    entries = edge_values

    if not self.is_sorted:
        # First order by the second index column ...
        by_second = tf.argsort(positions[:, 1], axis=0, direction='ASCENDING')
        positions = tf.gather(positions, by_second, axis=0)
        entries = tf.gather(entries, by_second, axis=0)
        # ... then stably by the first, which keeps the previous order within equal rows.
        by_first = tf.argsort(positions[:, 0], axis=0, direction='ASCENDING', stable=True)
        positions = tf.gather(positions, by_first, axis=0)
        entries = tf.gather(entries, by_first, axis=0)

    positions = tf.cast(positions, dtype=tf.int64)

    # Square matrix with one row and one column per flattened node.
    matrix_shape = tf.concat([tf.shape(node_values)[0:1], tf.shape(node_values)[0:1]], axis=0)
    matrix_shape = tf.cast(matrix_shape, dtype=tf.int64)

    # Only the first feature column of the edges is used as matrix entry.
    return tf.sparse.SparseTensor(positions, entries[:, 0], matrix_shape)
def call(self, inputs, **kwargs):
    """Forward pass.

    Inputs may be given as tf.RaggedTensor, tf.Tensor or as a list of (values, partition).
    A RaggedTensor has shape (batch, None, F), or (batch, N, F) for equally sized graphs.
    In the disjoint (values, partition) form the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits" or "value_rowids",
    matching the tf.RaggedTensor partition information. The layer then has to know the
    partition_type and the node_indexing scheme, i.e. "batch".
    Edge indices hold (i, j) in the last dimension as a directed edge from outgoing to ingoing node.

    Args:
        inputs: [node, edges, attention, edge_indices]

            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [N], F)
            - attention: Attention coefficients of shape (batch, [N], 1)
            - edge_index: Edge indices of shape (batch, [N], 2)

    Returns:
        embeddings: Feature tensor of pooled edge attentions for each node.
    """
    type_kwargs = dict(input_tensor_type=self.input_tensor_type, node_indexing=self.node_indexing)
    as_values = dict(output_tensor_type="values_partition", partition_type=self.partition_type)

    node_type, edge_type, attention_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], **type_kwargs) for pos in range(4)
    ]

    # Work on flattened values plus partition information.
    node_values, node_partition = kgcnn_ops_dyn_cast(inputs[0], input_tensor_type=node_type, **as_values)
    edge_values, _ = kgcnn_ops_dyn_cast(inputs[1], input_tensor_type=edge_type, **as_values)
    attention_values, _ = kgcnn_ops_dyn_cast(inputs[2], input_tensor_type=attention_type, **as_values)
    raw_index, edge_partition = kgcnn_ops_dyn_cast(inputs[3], input_tensor_type=index_type, **as_values)

    batch_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        raw_index, node_partition, edge_partition,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing,
    )

    # First index column selects the node each edge is pooled into (e.g. ingoing).
    segment_ids = batch_index[:, 0]
    messages = edge_values
    coefficients = attention_values
    if not self.is_sorted:
        # Bring edges and attention into ascending order of their pooling node.
        order = tf.argsort(segment_ids, axis=0, direction='ASCENDING', stable=True)
        segment_ids = tf.gather(segment_ids, order, axis=0)
        messages = tf.gather(messages, order, axis=0)
        coefficients = tf.gather(coefficients, order, axis=0)

    # Softmax of the attention within each segment, then attention-weighted sum.
    coefficients = segment_softmax(coefficients, segment_ids)
    pooled = tf.math.segment_sum(messages * coefficients, segment_ids)

    if self.has_unconnected:
        # The pooled tensor has to be filled up, since the node with the largest index
        # may not appear in the edges. Not needed if every node is connected.
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, segment_ids, tf.shape(node_values))

    return kgcnn_ops_dyn_cast(
        [pooled, node_partition],
        input_tensor_type="values_partition",
        output_tensor_type=node_type,
        partition_type=self.partition_type,
    )
def call(self, inputs, **kwargs):
    """Forward pass.

    Scores every node by its projection onto the normalised ``self.kernel_p``, removes the
    lowest-scoring nodes of each graph (the number removed per graph is
    ``round(self.k * number_of_nodes)``), gates the kept nodes with the sigmoid of their
    score, and drops every edge that touches a removed node.

    The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
    The RaggedTensor has shape (batch, None, F) or in case of equal sized graphs (batch, N, F).
    For disjoint representation (values, partition), the node embeddings are given by
    a flatten value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
    "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information. In this case
    the partition_type and node_indexing scheme, i.e. "batch", must be known by the layer.
    For edge indices, the last dimension holds indices from outgoing to ingoing node (i,j) as a directed edge.

    Args:
        inputs (list): of [nodes, edges, edge_indices]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edges: Edge embeddings of shape (batch, [N], F)
            - edge_indices: Edge index list of shape (batch, [N], 2)

    Returns:
        Tuple: [nodes, edges, edge_indices], [map_nodes, map_edges]

            - nodes: Pooled node feature tensor
            - edges: Pooled edge feature list
            - edge_indices (tf.tensor): Pooled edge index list
            - map_nodes (tf.tensor): Index map between original and pooled nodes
            - map_edges (tf.tensor): Index map between original and pooled edges
    """
    found_node_type = kgcnn_ops_get_tensor_type(
        inputs[0], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)
    found_edge_type = kgcnn_ops_get_tensor_type(
        inputs[1], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)
    found_index_type = kgcnn_ops_get_tensor_type(
        inputs[2], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)

    # Cast all three inputs to flattened values plus partition.
    node, node_part = kgcnn_ops_dyn_cast(
        inputs[0], input_tensor_type=found_node_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)
    edgefeat, edge_part = kgcnn_ops_dyn_cast(
        inputs[1], input_tensor_type=found_edge_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)
    edgeindref, _ = kgcnn_ops_dyn_cast(
        inputs[2], input_tensor_type=found_index_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)

    edgelen = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "row_length")
    nodelen = kgcnn_ops_change_partition_type(node_part, self.partition_type, "row_length")

    # All index arithmetic below is done in the dtype of the edge indices.
    index_dtype = edgeindref.dtype
    # Get node properties
    nvalue = node
    nrowlength = tf.cast(nodelen, dtype=index_dtype)
    erowlength = tf.cast(edgelen, dtype=index_dtype)
    # Graph id for every flattened node.
    nids = tf.repeat(tf.range(tf.shape(nrowlength)[0], dtype=index_dtype), nrowlength)

    # Use kernel p to get score
    norm_p = ks.backend.sqrt(
        ks.backend.sum(ks.backend.square(self.kernel_p), axis=-1, keepdims=True))
    nscore = ks.backend.sum(nvalue * self.kernel_p / norm_p, axis=-1)

    # Sort nodes according to score
    # Then sort after former node ids -> stable = True keeps previous order
    sort1 = tf.argsort(nscore, direction='ASCENDING', stable=False)
    nids_sorted1 = tf.gather(nids, sort1)
    sort2 = tf.argsort(nids_sorted1, direction='ASCENDING', stable=True)  # Must be stable=true here
    sort12 = tf.gather(
        sort1, sort2)  # index goes from 0 to batch*N, no in batch indexing
    nvalue_sorted = tf.gather(nvalue, sort12, axis=0)
    nscore_sorted = tf.gather(nscore, sort12, axis=0)

    # Make Mask
    # Per graph: first `nremove` entries False (removed), then `nkeep` entries True (kept).
    nremove = tf.cast(tf.math.round(
        self.k * tf.cast(nrowlength, dtype=tf.keras.backend.floatx())), dtype=index_dtype)
    nkeep = nrowlength - nremove
    n_remove_keep = ks.backend.flatten(
        tf.concat([
            ks.backend.expand_dims(nremove, axis=-1),
            ks.backend.expand_dims(nkeep, axis=-1)
        ], axis=-1))
    mask_remove_keep = ks.backend.flatten(
        tf.concat([
            ks.backend.expand_dims(tf.zeros_like(nremove, dtype=tf.bool), axis=-1),
            ks.backend.expand_dims(tf.ones_like(nkeep, tf.bool), axis=-1)
        ], axis=-1))
    mask = tf.repeat(mask_remove_keep, n_remove_keep)

    # Apply Mask to remove lower score nodes
    pooled_n = nvalue_sorted[mask]
    pooled_score = nscore_sorted[mask]
    # NOTE(review): pooled_id is not used further down in this method.
    pooled_id = nids[mask]  # nids should not have changed by final sorting
    pooled_len = nkeep  # shape=(batch,)
    pooled_index = tf.cast(
        sort12[mask], dtype=index_dtype)  # the index goes from 0 to N*batch
    removed_index = tf.cast(
        sort12[tf.math.logical_not(mask)], dtype=index_dtype)  # the index goes from 0 to N*batch

    # Pass through gate
    gated_n = pooled_n * ks.backend.expand_dims(
        tf.keras.activations.sigmoid(pooled_score), axis=-1)

    # Make index map for new nodes towards old index
    # map_index[old position] = new position; positions of removed nodes stay at the
    # scatter_nd default.
    index_new_nodes = tf.range(tf.shape(pooled_index)[0], dtype=index_dtype)
    old_shape = tf.cast(ks.backend.expand_dims(tf.shape(nvalue)[0]), dtype=index_dtype)
    map_index = tf.scatter_nd(
        ks.backend.expand_dims(pooled_index, axis=-1), index_new_nodes, old_shape)

    # Shift index if necessary
    # Graph id for every flattened edge.
    edge_ids = tf.repeat(tf.range(tf.shape(edgelen)[0], dtype=index_dtype), edgelen)

    shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edgeindref, nrowlength, edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing=self.node_indexing,
        to_indexing="batch")

    shiftind = tf.cast(
        shiftind, dtype=index_dtype)  # already shifted by batch offset (subgraphs)

    # Remove edges that were from filtered nodes via mask
    mask_edge = ks.backend.expand_dims(
        shiftind, axis=-1) == ks.backend.expand_dims(
        ks.backend.expand_dims(removed_index, axis=0),
        axis=0)  # this creates large tensor (batch*#edges,2,remove)
    # An edge is kept only if neither of its two indices matches a removed node.
    mask_edge = tf.math.logical_not(
        ks.backend.any(ks.backend.any(mask_edge, axis=-1), axis=-1))
    clean_shiftind = shiftind[mask_edge]
    clean_edge_ids = edge_ids[mask_edge]
    # Count the remaining edges per graph. scatter_nd is used with the full batch
    # shape as output shape, so every graph gets an entry.
    clean_edge_len = tf.scatter_nd(
        tf.expand_dims(clean_edge_ids, axis=-1), tf.ones_like(clean_edge_ids),
        tf.cast(tf.shape(erowlength), dtype=index_dtype))

    # Map edgeindex to new index
    new_edge_index = tf.concat([
        ks.backend.expand_dims(tf.gather(map_index, clean_shiftind[:, 0]), axis=-1),
        ks.backend.expand_dims(tf.gather(map_index, clean_shiftind[:, 1]), axis=-1)
    ], axis=-1)
    # Re-sort the remaining edges by their (new) first index.
    batch_order = tf.argsort(new_edge_index[:, 0], axis=0, direction='ASCENDING', stable=True)
    new_edge_index_sorted = tf.gather(new_edge_index, batch_order, axis=0)

    # Remove the batch offset from edge_indices again for indexing type
    out_indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        new_edge_index_sorted, pooled_len, clean_edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing="batch",
        to_indexing=self.node_indexing)

    # Correct edge features the same way (remove and reorder)
    edge_feat = edgefeat
    clean_edge_feat = edge_feat[mask_edge]
    clean_edge_feat_sorted = tf.gather(clean_edge_feat, batch_order, axis=0)

    # Make edge feature map for new edge features
    edge_position_old = tf.range(tf.shape(edgefeat)[0], dtype=index_dtype)
    edge_position_new = edge_position_old[mask_edge]
    edge_position_new = tf.gather(edge_position_new, batch_order, axis=0)

    # Collect output tensors
    out_node = gated_n
    out_edge = clean_edge_feat_sorted
    out_edge_index = out_indexlist

    # Change length to partition required
    out_np = kgcnn_ops_change_partition_type(pooled_len, "row_length", self.partition_type)
    out_ep = kgcnn_ops_change_partition_type(clean_edge_len, "row_length", self.partition_type)

    # Collect reverse pooling info
    # Remove batch offset for old indices -> but with new length
    out_pool = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        pooled_index, nrowlength, pooled_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing="batch",
        to_indexing=self.node_indexing,
        axis=0)
    out_pool_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edge_position_new, erowlength, clean_edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing="batch",
        to_indexing=self.node_indexing,
        axis=0)

    # Pooled graph, cast back to the representation found for each input.
    out = [
        kgcnn_ops_dyn_cast([out_node, out_np], input_tensor_type="values_partition",
                           output_tensor_type=found_node_type,
                           partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([out_edge, out_ep], input_tensor_type="values_partition",
                           output_tensor_type=found_edge_type,
                           partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([out_edge_index, out_ep], input_tensor_type="values_partition",
                           output_tensor_type=found_index_type,
                           partition_type=self.partition_type)
    ]
    # Index maps between pooled and original nodes/edges.
    out_map = [
        kgcnn_ops_dyn_cast([out_pool, out_np], input_tensor_type="values_partition",
                           output_tensor_type=found_node_type,
                           partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([out_pool_edge, out_ep], input_tensor_type="values_partition",
                           output_tensor_type=found_edge_type,
                           partition_type=self.partition_type)
    ]
    return out, out_map
def call(self, inputs, **kwargs):
    """Scatter pooled node and edge features back onto the original graph.

    Every input may be a tf.RaggedTensor, a tf.Tensor or a list of
    (values, partition). A RaggedTensor has shape (batch, None, F), or
    (batch, N, F) for equally sized graphs. In the disjoint representation
    (values, partition) the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits"
    or "value_rowids", matching the tf.RaggedTensor partition information.
    For that representation the layer must know the partition_type and the
    node_indexing scheme, e.g. "batch". Edge indices hold (i, j) pairs in
    their last dimension, i.e. shape (batch, None, 2).

    Args:
        inputs (list): [node, edge, edge_indices, map_node, map_edge,
            node_pool, edge_pool, edge_indices_pool]

            - node: Original node tensor
            - edge: Original edge feature tensor
            - edge_indices: Original index tensor
            - map_node: Index map between original and pooled nodes
            - map_edge: Index map between original and pooled edges
            - node_pool: Pooled node tensor
            - edge_pool: Pooled edge feature tensor
            - edge_indices_pool: Pooled index tensor

    Returns:
        list: [nodes, edges, edge_indices]

            - nodes: Unpooled node feature tensor
            - edges: Unpooled edge feature tensor
            - edge_indices: The original edge index tensor, passed through

        Each output is cast back to the tensor type detected for the
        matching original input (node, edge, edge_indices).
    """
    # Detect the tensor representation of all eight inputs up front.
    tensor_types = [
        kgcnn_ops_get_tensor_type(inputs[position],
                                  input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for position in range(8)
    ]

    def _as_values_partition(position):
        # Bring one input into the disjoint [values, partition] form.
        return kgcnn_ops_dyn_cast(inputs[position],
                                  input_tensor_type=tensor_types[position],
                                  output_tensor_type="values_partition",
                                  partition_type=self.partition_type)

    def _scatter_back(index_map, pooled_values, original_values):
        # Target has as many rows as the original tensor and as many
        # feature columns as the pooled tensor; tf.scatter_nd fills rows
        # that the map does not address with zeros.
        map_dtype = index_map.dtype
        target_shape = tf.stack([
            tf.cast(tf.shape(original_values)[0], dtype=map_dtype),
            tf.cast(tf.shape(pooled_values)[1], dtype=map_dtype),
        ])
        scatter_index = ks.backend.expand_dims(index_map, axis=-1)
        return tf.scatter_nd(scatter_index, pooled_values, target_shape)

    orig_nodes, orig_node_part = _as_values_partition(0)
    orig_edges, orig_edge_part = _as_values_partition(1)
    orig_edge_index, _ = _as_values_partition(2)
    node_map, _ = _as_values_partition(3)
    edge_map, _ = _as_values_partition(4)
    pooled_nodes, pooled_node_part = _as_values_partition(5)
    pooled_edges, pooled_edge_part = _as_values_partition(6)
    # The pooled edge indices are cast like every other input, but their
    # values are not needed for unpooling.
    _, _ = _as_values_partition(7)

    # Partition information in the form required by the re-indexing helper.
    orig_node_len = kgcnn_ops_change_partition_type(orig_node_part, self.partition_type, "row_length")
    orig_edge_len = kgcnn_ops_change_partition_type(orig_edge_part, self.partition_type, "row_length")
    pooled_node_len = kgcnn_ops_change_partition_type(pooled_node_part, self.partition_type, "row_length")
    pooled_edge_ids = kgcnn_ops_change_partition_type(pooled_edge_part, self.partition_type, "value_rowids")

    # Convert both maps to "batch" indexing so that they address rows of
    # the flattened value tensors.
    node_map = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        node_map, orig_node_len, pooled_node_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="batch",
        axis=0)
    edge_map = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edge_map, orig_edge_len, pooled_edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing=self.node_indexing,
        to_indexing="batch",
        axis=0)

    unpooled_nodes = _scatter_back(node_map, pooled_nodes, orig_nodes)
    unpooled_edges = _scatter_back(edge_map, pooled_edges, orig_edges)

    # Cast everything back to the representation of the original inputs.
    nodes_out = kgcnn_ops_dyn_cast([unpooled_nodes, orig_node_part],
                                   input_tensor_type="values_partition",
                                   output_tensor_type=tensor_types[0],
                                   partition_type=self.partition_type)
    edges_out = kgcnn_ops_dyn_cast([unpooled_edges, orig_edge_part],
                                   input_tensor_type="values_partition",
                                   output_tensor_type=tensor_types[1],
                                   partition_type=self.partition_type)
    edge_index_out = kgcnn_ops_dyn_cast([orig_edge_index, orig_edge_part],
                                        input_tensor_type="values_partition",
                                        output_tensor_type=tensor_types[2],
                                        partition_type=self.partition_type)
    return [nodes_out, edges_out, edge_index_out]
def call(self, inputs, **kwargs):
    """Forward pass: raise the adjacency matrix to the power ``self.n``.

    Every input may be a tf.RaggedTensor, a tf.Tensor or a list of
    (values, partition). A RaggedTensor has shape (batch, None, F), or
    (batch, N, F) for equally sized graphs. In the disjoint representation
    (values, partition) the values are a flattened tensor of shape
    (batch*None, F) and the partition is one of "row_length", "row_splits"
    or "value_rowids", matching the tf.RaggedTensor partition information.
    For that representation the layer must know the partition_type and the
    node_indexing scheme, e.g. "batch". Edge indices hold (i, j) pairs in
    their last dimension.

    The edges are written into a dense adjacency tensor per graph, which is
    multiplied with itself ``self.n - 1`` times. Entries of the result that
    are greater than the Keras epsilon become the new edges.

    Args:
        inputs (list): [nodes, edges, edge_indices]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edges: Adjacency entries of shape (batch, [N], 1)
            - edge_indices: Index list of shape (batch, [N], 2)

    Returns:
        list: [edges, edge_indices]

            - edges: Adjacency entries of shape (batch, [N], 1)
            - edge_indices: Index list of shape (batch, [N], 2)
    """
    node_type = kgcnn_ops_get_tensor_type(inputs[0],
                                          input_tensor_type=self.input_tensor_type,
                                          node_indexing=self.node_indexing)
    edge_type = kgcnn_ops_get_tensor_type(inputs[1],
                                          input_tensor_type=self.input_tensor_type,
                                          node_indexing=self.node_indexing)
    index_type = kgcnn_ops_get_tensor_type(inputs[2],
                                           input_tensor_type=self.input_tensor_type,
                                           node_indexing=self.node_indexing)

    # Only the node partition is needed; the node values are not used.
    _, node_part = kgcnn_ops_dyn_cast(inputs[0],
                                      input_tensor_type=node_type,
                                      output_tensor_type="values_partition",
                                      partition_type=self.partition_type)
    edge_values, _ = kgcnn_ops_dyn_cast(inputs[1],
                                        input_tensor_type=edge_type,
                                        output_tensor_type="values_partition",
                                        partition_type=self.partition_type)
    pair_index, edge_part = kgcnn_ops_dyn_cast(inputs[2],
                                               input_tensor_type=index_type,
                                               output_tensor_type="values_partition",
                                               partition_type=self.partition_type)

    # Number of nodes and edges per graph.
    node_len = kgcnn_ops_change_partition_type(node_part, self.partition_type, "row_length")
    edge_len = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "row_length")

    # Per-graph ("sample") indexing, so indices address rows and columns
    # of each graph's own adjacency matrix.
    pair_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        pair_index, node_len, edge_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="sample")

    # Prefix every (i, j) pair with the id of the graph it belongs to.
    graph_of_edge = tf.repeat(tf.range(tf.shape(edge_len)[0]), edge_len)
    graph_of_edge = tf.cast(tf.expand_dims(graph_of_edge, axis=-1), dtype=pair_index.dtype)
    scatter_index = tf.cast(tf.concat([graph_of_edge, pair_index], axis=-1), dtype=tf.int64)

    # Dense shape (batch, max_nodes, max_nodes) and the flat position of
    # every entry in it.
    max_nodes = tf.reduce_max(node_len)
    num_graphs = tf.cast(tf.shape(node_len)[0], dtype=max_nodes.dtype)
    dense_shape = tf.stack([num_graphs, max_nodes, max_nodes])
    flat_position = tf.range(num_graphs * max_nodes * max_nodes)

    # Write the first edge feature column into the dense adjacency tensor.
    # The trailing axis is added so the (E, 1) updates match, then removed.
    adjacency = tf.zeros(dense_shape, dtype=edge_values.dtype)
    adjacency = tf.squeeze(
        tf.tensor_scatter_nd_update(tf.expand_dims(adjacency, axis=-1),
                                    scatter_index,
                                    edge_values[:, 0:1]),
        axis=-1)

    # Repeated batched matrix product with the original adjacency.
    powered = adjacency
    for _ in range(self.n - 1):
        powered = tf.linalg.matmul(powered, adjacency)

    # Sparsify: keep only the entries above the Keras epsilon.
    keep = tf.reshape(powered > tf.keras.backend.epsilon(), (-1,))
    flat_values = tf.reshape(powered, (-1,))
    new_edge = tf.expand_dims(flat_values[keep], axis=-1)

    # Recover (graph, row, column) of every kept entry.
    unraveled = tf.unravel_index(flat_position[keep], dims=dense_shape)
    new_graph_ids = unraveled[0]
    row_index = tf.expand_dims(unraveled[1], axis=-1)
    col_index = tf.expand_dims(unraveled[2], axis=-1)
    new_edge_index = tf.concat([row_index, col_index], axis=-1)

    # Count the new edges per graph.
    new_edge_len = tf.tensor_scatter_nd_add(tf.zeros_like(node_len),
                                            tf.expand_dims(new_graph_ids, axis=-1),
                                            tf.ones_like(new_graph_ids))

    # Partition in the layer's partition type.
    new_edge_part = kgcnn_ops_change_partition_type(new_edge_len, "row_length", self.partition_type)

    # Back to the indexing scheme of the layer.
    new_edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        new_edge_index, node_len, new_edge_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing="sample",
        to_indexing=self.node_indexing)

    edges_out = kgcnn_ops_dyn_cast([new_edge, new_edge_part],
                                   input_tensor_type="values_partition",
                                   output_tensor_type=edge_type,
                                   partition_type=self.partition_type)
    edge_index_out = kgcnn_ops_dyn_cast([new_edge_index, new_edge_part],
                                        input_tensor_type="values_partition",
                                        output_tensor_type=index_type,
                                        partition_type=self.partition_type)
    return [edges_out, edge_index_out]