def call(self, inputs, **kwargs):
    """Forward pass: convert edge indices between indexing schemes.

    The tensor representation can be tf.RaggedTensor, tf.Tensor or a list of (values, partition).
    The RaggedTensor has shape (batch, None, F), or (batch, N, F) for equal sized graphs.
    For the disjoint representation (values, partition), the node embeddings are given by a
    flattened value tensor of shape (batch*None, F) and a partition tensor of either "row_length",
    "row_splits" or "value_rowids" that matches the tf.RaggedTensor partition information.
    In that case the partition_type and node_indexing scheme, i.e. "batch", must be known by the layer.
    For edge indices, the last dimension holds indices from outgoing to ingoing node (i,j) as a
    directed edge.

    Args:
        inputs (list): [nodes, edge_index]

            - nodes: Node features of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [M], 2).

    Returns:
        edge_index: Corrected edge indices, converted from ``self.from_indexing`` to
        ``self.to_indexing``. For "values_partition" input this is a list
        ``[indices, edge_partition]``; for "ragged" input it is a tf.RaggedTensor.

    Raises:
        ValueError: If ``self.input_tensor_type`` is neither "values_partition" nor "ragged".
    """
    if self.input_tensor_type == "values_partition":
        [_, part_node], [edge_index, part_edge] = inputs
        indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edge_index, part_node, part_edge,
            partition_type_node=self.partition_type,
            partition_type_edge=self.partition_type,
            from_indexing=self.from_indexing,
            to_indexing=self.to_indexing)
        return [indexlist, part_edge]

    if self.input_tensor_type == "ragged":
        nod, edge_index = inputs
        # Node partition is taken as row_splits, edge partition as value_rowids.
        indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
            edge_index.values, nod.row_splits, edge_index.value_rowids(),
            partition_type_node="row_splits",
            partition_type_edge="value_rowids",
            from_indexing=self.from_indexing,
            to_indexing=self.to_indexing)
        return tf.RaggedTensor.from_row_splits(
            indexlist, edge_index.row_splits, validate=self.ragged_validate)

    # Previously fell through and returned None, which only failed later and obscurely.
    raise ValueError(
        "Unsupported input_tensor_type %s, expected 'values_partition' or 'ragged'."
        % self.input_tensor_type)
def call(self, inputs, **kwargs):
    """Forward pass: for every edge, gather the embedding of the node at index position 1.

    Args:
        inputs (list): [nodes, edge_index]

            - nodes (tf.ragged): Node embeddings of shape (batch, [N], F)
            - edge_index (tf.ragged): Node indices for edges of shape (batch, [M], 2)

    Returns:
        embeddings: Gathered node embeddings that match the number of edges.
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 2)
    ragged_nodes = mapped[0]
    ragged_index = mapped[1]

    # Work on flat values plus partition information.
    node_values = ragged_nodes.values
    node_splits = ragged_nodes.row_splits
    index_values = ragged_index.values
    edge_lengths = ragged_index.row_lengths()

    # Shift the indices into the flattened batch index space.
    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_splits, edge_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # A ragged alternative would be tf.gather(nodes, edge_index[:, :, 1], batch_dims=1).
    gathered = tf.gather(node_values, batch_indices[:, 1], axis=0)
    return self._kgcnn_map_output_ragged([gathered, edge_lengths], "row_length", 1)
def call(self, inputs, **kwargs):
    """Forward pass: compute the Euclidean distance between the two nodes of every edge.

    Args:
        inputs (list): [position, edge_index]

            - position (tf.ragged): Node positions of shape (batch, [N], 3)
            - edge_index (tf.ragged): Edge indices of shape (batch, [M], 2)

    Returns:
        distances: Node distances per edge, of shape (batch, [M], 1).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 2)
    positions, node_splits = mapped[0].values, mapped[0].row_splits
    index_values, edge_lengths = mapped[1].values, mapped[1].row_lengths()

    # Shift the indices into the flattened batch index space.
    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_splits, edge_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    pos_first = tf.gather(positions, batch_indices[:, 0], axis=0)
    pos_second = tf.gather(positions, batch_indices[:, 1], axis=0)
    squared = tf.reduce_sum(tf.math.square(pos_first - pos_second), axis=-1)
    # relu before sqrt is kept as-is; the sum of squares is already non-negative.
    distance = tf.sqrt(tf.nn.relu(squared))
    distance = tf.expand_dims(distance, axis=-1)
    return self._kgcnn_map_output_ragged([distance, edge_lengths], "row_length", 1)
def call(self, inputs, **kwargs):
    """Forward pass: pool weighted edge features onto nodes.

    Args:
        inputs (list): of [node, edges, edge_index, weights]

            - nodes (tf.ragged): Node features of shape (batch, [N], F)
            - edges (tf.ragged): Edge or message features of shape (batch, [M], F)
            - edge_index (tf.ragged): Edge indices of shape (batch, [M], 2)
            - weights (tf.ragged): Edge or message weights. Must broadcast to edges or messages,
              e.g. (batch, [M], 1)

    Returns:
        features: Pooled edge features for each node, of shape (batch, [N], F).
        If ``self.normalize_by_weights`` is set, the result is divided by the per-node sum of
        weights; a zero denominator yields 0 via ``divide_no_nan``.
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 4)
    node_values, node_splits = mapped[0].values, mapped[0].row_splits
    edge_values = mapped[1].values
    index_values, edge_lengths = mapped[2].values, mapped[2].row_lengths()
    weight_values = mapped[3].values

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_splits, edge_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # Pool onto the node at index position 0.
    target = batch_indices[:, 0]
    weighted = edge_values * weight_values
    weights = weight_values

    if not self.is_sorted:
        # Segment operations need sorted segment ids; stable to keep edge order per node.
        order = tf.argsort(target, axis=0, direction='ASCENDING', stable=True)
        target, weighted, weights = (tf.gather(t, order, axis=0) for t in (target, weighted, weights))

    pooled = kgcnn_ops_segment_operation_by_name(self.pooling_method, weighted, target)
    if self.normalize_by_weights:
        pooled = tf.math.divide_no_nan(pooled, tf.math.segment_sum(weights, target))
    if self.has_unconnected:
        # Fill in nodes that received no edge, so the row count matches the node count.
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, target, tf.shape(node_values))
    return self._kgcnn_map_output_ragged([pooled, node_splits], "row_splits", 0)
def call(self, inputs, **kwargs):
    """Forward pass: for every edge, gather the embedding of the node at index position 0.

    Inputs may be tf.RaggedTensor, tf.Tensor or a list of (values, partition). A RaggedTensor has
    shape (batch, None, F), or (batch, N, F) for equal sized graphs. In the disjoint
    (values, partition) form, values are flattened to (batch*None, F) and the partition is one of
    "row_length", "row_splits" or "value_rowids"; the layer must then know partition_type and
    node_indexing (e.g. "batch"). The last dimension of the edge indices holds (i,j) for a directed
    edge from outgoing to ingoing node.

    Args:
        inputs (list): [nodes, edge_index]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edge_index: Edge indices of shape (batch, [M], 2)

    Returns:
        embeddings: Gathered node embeddings that match the number of edges, in the same tensor
        type as ``edge_index``.
    """
    node_input, index_input = inputs[0], inputs[1]
    node_type, index_type = (
        kgcnn_ops_get_tensor_type(tensor, input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for tensor in (node_input, index_input))

    # Bring both inputs into the (values, partition) form.
    node_values, node_partition = kgcnn_ops_dyn_cast(
        node_input, input_tensor_type=node_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    index_values, index_partition = kgcnn_ops_dyn_cast(
        index_input, input_tensor_type=index_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_partition, index_partition,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # A ragged alternative would be tf.gather(nodes, edge_index[:, :, 0], batch_dims=1).
    gathered = tf.gather(node_values, batch_indices[:, 0], axis=0)
    return kgcnn_ops_dyn_cast(
        [gathered, index_partition], input_tensor_type="values_partition",
        output_tensor_type=index_type, partition_type=self.partition_type)
def call(self, inputs, **kwargs):
    """Forward pass: pool edge features onto the node at index position 0.

    Inputs may be tf.RaggedTensor, tf.Tensor or a list of (values, partition). A RaggedTensor has
    shape (batch, None, F), or (batch, N, F) for equal sized graphs. In the disjoint
    (values, partition) form, values are flattened to (batch*None, F) and the partition is one of
    "row_length", "row_splits" or "value_rowids"; the layer must then know partition_type and
    node_indexing (e.g. "batch"). The last dimension of the edge indices holds (i,j) for a directed
    edge from outgoing to ingoing node.

    Args:
        inputs (list): of [node, edges, edge_index]

            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [M], F)
            - edge_index: Edge indices of shape (batch, [M], 2)

    Returns:
        features: Pooled edge features for each node, in the same tensor type as ``nodes``.
    """
    node_type, edge_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for pos in range(3)]

    node_values, node_partition = kgcnn_ops_dyn_cast(
        inputs[0], input_tensor_type=node_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    edge_values, _ = kgcnn_ops_dyn_cast(
        inputs[1], input_tensor_type=edge_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    index_values, index_partition = kgcnn_ops_dyn_cast(
        inputs[2], input_tensor_type=index_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_partition, index_partition,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # The original code comments index position 0 as the ingoing node.
    target = batch_indices[:, 0]
    messages = edge_values
    if not self.is_sorted:
        # Segment ids must be sorted; stable to keep edge order within each node.
        order = tf.argsort(target, axis=0, direction='ASCENDING', stable=True)
        target = tf.gather(target, order, axis=0)
        messages = tf.gather(messages, order, axis=0)

    pooled = kgcnn_ops_segment_operation_by_name(self.pooling_method, messages, target)
    if self.has_unconnected:
        # Fill in nodes that received no edge.
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, target, tf.shape(node_values))
    return kgcnn_ops_dyn_cast(
        [pooled, node_partition], input_tensor_type="values_partition",
        output_tensor_type=node_type, partition_type=self.partition_type)
def call(self, inputs, **kwargs):
    """Forward pass: compute angles between pairs of edges.

    Args:
        inputs (list): [position, edge_index, angle_index]

            - position (tf.ragged): Node positions of shape (batch, [N], 3)
            - edge_index (tf.ragged): Node indices of shape (batch, [M], 2) referring to nodes.
            - angle_index (tf.ragged): Edge indices of shape (batch, [K], 2) referring to edges.

    Returns:
        angles: Angles (via atan2 of cross-product norm and dot product) between the edge vectors
        selected by ``angle_index``. Shape is (batch, [K], 1).
    """
    dyn_inputs = self._kgcnn_map_input_ragged(inputs, 3)
    node, node_part = dyn_inputs[0].values, dyn_inputs[0].row_splits
    # Note: edge_part holds row *lengths*, not row splits.
    edge_index, edge_part = dyn_inputs[1].values, dyn_inputs[1].row_lengths()
    angle_index, angle_part = dyn_inputs[2].values, dyn_inputs[2].row_lengths()

    # Node indices of each edge, shifted to the flattened batch.
    indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edge_index, node_part, edge_part,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)
    # Edge indices of each angle, shifted by the edge counts of preceding graphs.
    # edge_part comes from row_lengths(), so it must be declared as "row_length".
    # It was previously mislabelled "row_splits", which gave wrong offsets.
    indexlist2 = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        angle_index, edge_part, angle_part,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    xi = tf.gather(node, indexlist[:, 0], axis=0)
    xj = tf.gather(node, indexlist[:, 1], axis=0)
    vs = xj - xi  # one vector per edge
    v1 = tf.gather(vs, indexlist2[:, 0], axis=0)
    v2 = tf.gather(vs, indexlist2[:, 1], axis=0)
    x = tf.reduce_sum(v1 * v2, axis=-1)
    y = tf.norm(tf.linalg.cross(v1, v2), axis=-1)
    angle = tf.expand_dims(tf.math.atan2(y, x), axis=-1)
    return self._kgcnn_map_output_ragged([angle, angle_part], "row_length", 2)
def call(self, inputs, **kwargs):
    """Forward pass: pool edge features onto nodes, weighted by a per-node softmax of attention.

    Args:
        inputs: [node, edges, attention, edge_indices]

            - nodes (tf.ragged): Node features of shape (batch, [N], F)
            - edges (tf.ragged): Edge or message features of shape (batch, [M], F)
            - attention (tf.ragged): Attention coefficients of shape (batch, [M], 1)
            - edge_index (tf.ragged): Edge indices of shape (batch, [M], 2)

    Returns:
        embeddings: Pooled, attention-weighted edge features for each node, of shape
        (batch, [N], F).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 4)
    node_values, node_lengths = mapped[0].values, mapped[0].row_lengths()
    edge_values = mapped[1].values
    attention_values = mapped[2].values
    index_values, edge_lengths = mapped[3].values, mapped[3].row_lengths()

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_lengths, edge_lengths,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # The original code comments index position 0 as the ingoing node.
    target = batch_indices[:, 0]
    messages, scores = edge_values, attention_values
    if not self.is_sorted:
        order = tf.argsort(target, axis=0, direction='ASCENDING', stable=True)
        target = tf.gather(target, order, axis=0)
        messages = tf.gather(messages, order, axis=0)
        scores = tf.gather(scores, order, axis=0)

    # Normalize attention within each target node's segment, then sum the weighted messages.
    scores = segment_softmax(scores, target)
    pooled = tf.math.segment_sum(messages * scores, target)
    if self.has_unconnected:
        # Segment output stops at the largest connected node id; pad to the full node count.
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, target, tf.shape(node_values))
    return self._kgcnn_map_output_ragged([pooled, node_lengths], "row_length", 0)
def call(self, inputs, **kwargs):
    """Forward pass: build a sparse disjoint adjacency matrix from ragged edges.

    Args:
        inputs (list): [nodes, edges, edge_index]

            - nodes (tf.ragged): Node feature tensor of shape (batch, [N], F)
            - edges (tf.ragged): Edge feature ragged tensor of shape (batch, [M], 1);
              only channel 0 is used as the matrix entry.
            - edge_index (tf.ragged): Ragged edge_indices of shape (batch, [M], 2)

    Returns:
        tf.sparse.SparseTensor: Disjoint matrix of shape (batch*None, batch*None).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 3)
    node_values, node_lengths = mapped[0].values, mapped[0].row_lengths()
    edge_values = mapped[1].values
    index_values, edge_lengths = mapped[2].values, mapped[2].row_lengths()

    # Shift to batch-wise (disjoint) indexing.
    indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_lengths, edge_lengths,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="batch")
    entries = edge_values

    if not self.is_sorted:
        # Two-pass sort yields row-major order: sort by column, then stable sort by row.
        by_column = tf.argsort(indices[:, 1], axis=0, direction='ASCENDING')
        indices = tf.gather(indices, by_column, axis=0)
        entries = tf.gather(entries, by_column, axis=0)
        by_row = tf.argsort(indices[:, 0], axis=0, direction='ASCENDING', stable=True)
        indices = tf.gather(indices, by_row, axis=0)
        entries = tf.gather(entries, by_row, axis=0)

    indices = tf.cast(indices, dtype=tf.int64)
    num_nodes = tf.shape(node_values)[0:1]
    dense_shape = tf.cast(tf.concat([num_nodes, num_nodes], axis=0), dtype=tf.int64)
    return tf.sparse.SparseTensor(indices, entries[:, 0], dense_shape)
def call(self, inputs, **kwargs):
    """Forward pass: pool edge features onto nodes with an LSTM.

    Args:
        inputs (list): [nodes, edges, edge_index]

            - nodes (tf.ragged): Node features of shape (batch, [N], F)
            - edges (tf.ragged): Edge or message features of shape (batch, [M], F)
            - edge_index (tf.ragged): Edge indices of shape (batch, [M], 2)

    Returns:
        features: Feature tensor of pooled edge features for each node, of shape (batch, [N], F).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 3)
    node_values, node_splits = mapped[0].values, mapped[0].row_splits
    edge_values = mapped[1].values
    index_values, edge_lengths = mapped[2].values, mapped[2].row_lengths()

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_splits, edge_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    # The original code comments index position 0 as the ingoing node.
    target = batch_indices[:, 0]
    messages = edge_values
    if not self.is_sorted:
        order = tf.argsort(target, axis=0, direction='ASCENDING', stable=True)
        target = tf.gather(target, order, axis=0)
        messages = tf.gather(messages, order, axis=0)

    # Group each node's messages into one ragged row and run the LSTM over every row.
    grouped = tf.RaggedTensor.from_value_rowids(messages, target)
    pooled = self.lstm_unit(grouped)
    if self.has_unconnected:
        # Rows exist only up to the largest connected node id; pad to the full node count.
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, target, tf.shape(node_values))
    return self._kgcnn_map_output_ragged([pooled, node_splits], "row_splits", 0)
def call(self, inputs, **kwargs):
    """Forward pass: pool edge features onto the node at index position 0.

    Args:
        inputs (list): of [node, edges, edge_index]

            - nodes (tf.ragged): Node features of shape (batch, [N], F)
            - edges (tf.ragged): Edge or message features of shape (batch, [M], F)
            - edge_index (tf.ragged): Edge indices of shape (batch, [M], 2)

    Returns:
        features: Pooled feature tensor of edge features for each node, (batch, [N], F).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 3)
    node_values, node_splits = mapped[0].values, mapped[0].row_splits
    edge_values = mapped[1].values
    index_values, edge_lengths = mapped[2].values, mapped[2].row_lengths()

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_splits, edge_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    target = batch_indices[:, 0]
    messages = edge_values
    if not self.is_sorted:
        # Segment ids must be sorted; stable to keep edge order within each node.
        order = tf.argsort(target, axis=0, direction='ASCENDING', stable=True)
        target = tf.gather(target, order, axis=0)
        messages = tf.gather(messages, order, axis=0)

    pooled = kgcnn_ops_segment_operation_by_name(self.pooling_method, messages, target)
    if self.has_unconnected:
        pooled = kgcnn_ops_scatter_segment_tensor_nd(pooled, target, tf.shape(node_values))
    return self._kgcnn_map_output_ragged([pooled, node_splits], "row_splits", 0)
def call(self, inputs, **kwargs):
    """Forward pass: expand distances and angles into a spherical Bessel / harmonics basis.

    Args:
        inputs: [distance, angles, angle_index]

            - distance (tf.ragged): Edge distance of shape (batch, [M], 1)
            - angles (tf.ragged): Angle list of shape (batch, [K], 1)
            - angle_index (tf.ragged): Indices referring to edges of shape (batch, [K], 2)

    Returns:
        angles: Expanded angle/distance basis. Shape is (batch, [K], #Radial * #Spherical).
    """
    mapped = self._kgcnn_map_input_ragged(inputs, 3)
    distances, edge_splits = mapped[0].values, mapped[0].row_splits
    angles, angle_splits = mapped[1].values, mapped[1].row_splits
    angle_index, angle_index_lengths = mapped[2].values, mapped[2].row_lengths()

    batch_indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        angle_index, edge_splits, angle_index_lengths,
        partition_type_node="row_splits",
        partition_type_edge="row_length",
        to_indexing='batch',
        from_indexing=self.node_indexing)

    scaled = distances[:, 0] * self.inv_cutoff
    # Radial part per edge, ordered with n (spherical) as the outer and k (radial) as the inner loop.
    radial = tf.stack([
        self.bessel_norm[n, k] * tf_spherical_bessel_jn(scaled * self.bessel_n_zeros[n][k], n)
        for n in range(self.num_spherical)
        for k in range(self.num_radial)
    ], axis=1)
    radial = self.envelope(scaled)[:, None] * radial
    # Pick the radial features of the edge at index position 1 of each angle.
    radial = tf.gather(radial, batch_indices[:, 1])

    circular = tf.stack(
        [tf_spherical_harmonics_yl(angles[:, 0], n) for n in range(self.num_spherical)], axis=1)
    # Repeat each spherical order num_radial times to align with the radial layout.
    circular = tf.repeat(circular, self.num_radial, axis=1)

    return self._kgcnn_map_output_ragged([radial * circular, angle_splits], "row_splits", 0)
def call(self, inputs, **kwargs):
    """Forward path: compute the n-th power (n = ``self.n``) of each graph's adjacency matrix.

    Inputs may be tf.RaggedTensor, tf.Tensor or a list of (values, partition). A RaggedTensor has
    shape (batch, None, F), or (batch, N, F) for equal sized graphs. In the disjoint
    (values, partition) form, values are flattened to (batch*None, F) and the partition is one of
    "row_length", "row_splits" or "value_rowids"; the layer must then know partition_type and
    node_indexing (e.g. "batch"). The last dimension of the edge indices holds (i,j) for a directed
    edge from outgoing to ingoing node.

    Args:
        inputs (list): [nodes, edges, edge_indices]

            - nodes: Node embeddings of shape (batch, [N], F)
            - edges: Adjacency entries of shape (batch, [M], 1)
            - edge_indices: Index list of shape (batch, [M], 2)

    Returns:
        list: [edges, edge_indices]

            - edges: Entries of the matrix power greater than keras epsilon, (batch, [M'], 1)
            - edge_indices: Index list of those entries, (batch, [M'], 2)
    """
    node_type, edge_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for pos in range(3)]
    _, node_part = kgcnn_ops_dyn_cast(
        inputs[0], input_tensor_type=node_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    edge_values, _ = kgcnn_ops_dyn_cast(
        inputs[1], input_tensor_type=edge_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    index_values, edge_part = kgcnn_ops_dyn_cast(
        inputs[2], input_tensor_type=index_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)

    # Work with plain row lengths from here on.
    node_len = kgcnn_ops_change_partition_type(node_part, self.partition_type, "row_length")
    edge_len = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "row_length")
    # Per-graph ("sample") indices address rows/columns of one padded matrix per graph.
    sample_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_len, edge_len,
        partition_type_node="row_length", partition_type_edge="row_length",
        from_indexing=self.node_indexing, to_indexing="sample")

    # Prepend the graph id of each edge to get (graph, i, j) scatter indices.
    graph_of_edge = tf.repeat(tf.range(tf.shape(edge_len)[0]), edge_len)
    graph_of_edge = tf.cast(tf.expand_dims(graph_of_edge, axis=-1), dtype=sample_index.dtype)
    scatter_index = tf.cast(tf.concat([graph_of_edge, sample_index], axis=-1), dtype=tf.int64)

    # Dense, zero-padded adjacency of shape (batch, max_nodes, max_nodes).
    max_nodes = tf.reduce_max(node_len)
    num_graphs = tf.cast(tf.shape(node_len)[0], dtype=max_nodes.dtype)
    dense_shape = tf.stack([num_graphs, max_nodes, max_nodes])
    flat_position = tf.range(num_graphs * max_nodes * max_nodes)
    adjacency = tf.expand_dims(tf.zeros(dense_shape, dtype=edge_values.dtype), axis=-1)
    adjacency = tf.tensor_scatter_nd_update(adjacency, scatter_index, edge_values[:, 0:1])
    adjacency = tf.squeeze(adjacency, axis=-1)

    power = adjacency
    for _ in range(self.n - 1):
        power = tf.linalg.matmul(power, adjacency)

    # Sparsify: keep entries above keras epsilon, in flat row-major order.
    keep = tf.reshape(power > tf.keras.backend.epsilon(), (-1,))
    new_edge = tf.expand_dims(tf.reshape(power, (-1,))[keep], axis=-1)
    unravelled = tf.unravel_index(flat_position[keep], dims=dense_shape)
    graph_id = unravelled[0]
    new_edge_index = tf.concat(
        [tf.expand_dims(unravelled[1], axis=-1), tf.expand_dims(unravelled[2], axis=-1)], axis=-1)
    # Count surviving entries per graph to obtain the new edge row lengths.
    new_edge_len = tf.tensor_scatter_nd_add(
        tf.zeros_like(node_len), tf.expand_dims(graph_id, axis=-1), tf.ones_like(graph_id))
    new_edge_part = kgcnn_ops_change_partition_type(new_edge_len, "row_length", self.partition_type)

    # Back to the layer's indexing scheme.
    new_edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        new_edge_index, node_len, new_edge_len,
        partition_type_node="row_length", partition_type_edge="row_length",
        from_indexing="sample", to_indexing=self.node_indexing)

    return [
        kgcnn_ops_dyn_cast([new_edge, new_edge_part], input_tensor_type="values_partition",
                           output_tensor_type=edge_type, partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([new_edge_index, new_edge_part], input_tensor_type="values_partition",
                           output_tensor_type=index_type, partition_type=self.partition_type),
    ]
def call(self, inputs, **kwargs):
    """Forward pass: scatter pooled node/edge features back to the original graph.

    Inputs may be tf.RaggedTensor, tf.Tensor or a list of (values, partition). A RaggedTensor has
    shape (batch, None, F), or (batch, N, F) for equal sized graphs. In the disjoint
    (values, partition) form, values are flattened to (batch*None, F) and the partition is one of
    "row_length", "row_splits" or "value_rowids"; the layer must then know partition_type and
    node_indexing (e.g. "batch"). Edge indices are (batch, None, 2) with (i,j) from outgoing to
    ingoing node.

    Args:
        inputs (list): [node, edge, edge_indices, map_node, map_edge, node_pool, edge_pool,
            edge_indices_pool]

            - node: Original node tensor
            - edge: Original edge feature tensor
            - edge_indices: Original index tensor
            - map_node: Index map between original and pooled nodes
            - map_edge: Index map between original and pooled edges
            - node_pool: Pooled node tensor
            - edge_pool: Pooled edge feature tensor
            - edge_indices_pool: Pooled index tensor (cast but otherwise unused)

    Returns:
        List: [nodes, edges, edge_indices]

            - nodes: Unpooled node feature tensor
            - edges: Unpooled edge feature list
            - edge_indices: Original edge index, passed through
    """
    found_input_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for pos in range(8)]
    # Cast every input to a (values, partition) pair, in input order.
    casted = [
        kgcnn_ops_dyn_cast(inputs[pos], input_tensor_type=found_input_type[pos],
                           output_tensor_type="values_partition",
                           partition_type=self.partition_type)
        for pos in range(8)]
    (node_old, node_part_old), (edge_old, edge_part_old), (edge_index_old, _) = casted[:3]
    (map_node, _), (map_edge, _) = casted[3:5]
    (node_new, node_part_new), (edge_new, edge_part_new), _ = casted[5:]

    node_len_old = kgcnn_ops_change_partition_type(node_part_old, self.partition_type, "row_length")
    edge_len_old = kgcnn_ops_change_partition_type(edge_part_old, self.partition_type, "row_length")
    node_len_pool = kgcnn_ops_change_partition_type(node_part_new, self.partition_type, "row_length")
    edge_rowid_pool = kgcnn_ops_change_partition_type(edge_part_new, self.partition_type, "value_rowids")

    # Add the flattened-batch offset to the maps so they address the flat original tensors.
    map_node = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        map_node, node_len_old, node_len_pool,
        partition_type_node="row_length", partition_type_edge="row_length",
        from_indexing=self.node_indexing, to_indexing="batch", axis=0)
    map_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        map_edge, edge_len_old, edge_rowid_pool,
        partition_type_node="row_length", partition_type_edge="value_rowids",
        from_indexing=self.node_indexing, to_indexing="batch", axis=0)

    def scatter_back(index_map, pooled_values, original_values):
        # Place pooled rows at their mapped original positions; unreferenced rows stay zero.
        target_shape = tf.stack([
            tf.cast(tf.shape(original_values)[0], dtype=index_map.dtype),
            tf.cast(tf.shape(pooled_values)[1], dtype=index_map.dtype)])
        return tf.scatter_nd(ks.backend.expand_dims(index_map, axis=-1), pooled_values, target_shape)

    out_node = scatter_back(map_node, node_new, node_old)
    out_edge = scatter_back(map_edge, edge_new, edge_old)

    return [
        kgcnn_ops_dyn_cast([out_node, node_part_old], input_tensor_type="values_partition",
                           output_tensor_type=found_input_type[0], partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([out_edge, edge_part_old], input_tensor_type="values_partition",
                           output_tensor_type=found_input_type[1], partition_type=self.partition_type),
        kgcnn_ops_dyn_cast([edge_index_old, edge_part_old], input_tensor_type="values_partition",
                           output_tensor_type=found_input_type[2], partition_type=self.partition_type),
    ]
def call(self, inputs, **kwargs):
    """Forward pass: build a sparse disjoint adjacency matrix from edges and edge indices.

    Inputs may be tf.RaggedTensor, tf.Tensor or a list of (values, partition). A RaggedTensor has
    shape (batch, None, F), or (batch, N, F) for equal sized graphs. In the disjoint
    (values, partition) form, values are flattened to (batch*None, F) and the partition is one of
    "row_length", "row_splits" or "value_rowids"; the layer must then know partition_type and
    node_indexing (e.g. "batch"). The last dimension of the edge indices holds (i,j) for a directed
    edge from outgoing to ingoing node.

    Args:
        inputs (list): [nodes, edges, edge_index]

            - nodes: Node feature tensor of shape (batch, [N], F)
            - edges: Edge feature tensor of shape (batch, [M], 1); only channel 0 is used.
            - edge_index: Edge indices of shape (batch, [M], 2)

    Returns:
        tf.sparse.SparseTensor: Disjoint matrix of shape (batch*None, batch*None).
    """
    node_type, edge_type, index_type = [
        kgcnn_ops_get_tensor_type(inputs[pos], input_tensor_type=self.input_tensor_type,
                                  node_indexing=self.node_indexing)
        for pos in range(3)]
    node_values, node_part = kgcnn_ops_dyn_cast(
        inputs[0], input_tensor_type=node_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    edge_values, _ = kgcnn_ops_dyn_cast(
        inputs[1], input_tensor_type=edge_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)
    index_values, edge_part = kgcnn_ops_dyn_cast(
        inputs[2], input_tensor_type=index_type,
        output_tensor_type="values_partition", partition_type=self.partition_type)

    node_len = kgcnn_ops_change_partition_type(node_part, self.partition_type, "row_length")
    edge_len = kgcnn_ops_change_partition_type(edge_part, self.partition_type, "row_length")

    # Shift to batch-wise (disjoint) indexing.
    indices = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        index_values, node_len, edge_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="batch")
    entries = edge_values

    if not self.is_sorted:
        # Two-pass sort yields row-major order: sort by column, then stable sort by row.
        by_column = tf.argsort(indices[:, 1], axis=0, direction='ASCENDING')
        indices = tf.gather(indices, by_column, axis=0)
        entries = tf.gather(entries, by_column, axis=0)
        by_row = tf.argsort(indices[:, 0], axis=0, direction='ASCENDING', stable=True)
        indices = tf.gather(indices, by_row, axis=0)
        entries = tf.gather(entries, by_row, axis=0)

    indices = tf.cast(indices, dtype=tf.int64)
    num_nodes = tf.shape(node_values)[0:1]
    dense_shape = tf.cast(tf.concat([num_nodes, num_nodes], axis=0), dtype=tf.int64)
    return tf.sparse.SparseTensor(indices, entries[:, 0], dense_shape)
def call(self, inputs, **kwargs):
    """Forward pass: replace each graph's adjacency by its ``self.n``-th power.

    The per-graph adjacency matrices are densified into a zero-padded
    tensor of shape (batch, max_N, max_N) and multiplied with themselves
    ``self.n - 1`` times. Entries greater than
    ``tf.keras.backend.epsilon()`` are then turned back into an edge list.

    Args:
        inputs (list): [nodes, edges, edge_indices]
            - nodes (tf.ragged): Node embeddings of shape (batch, [N], F)
            - edges (tf.ragged): Adjacency entries of shape (batch, [M], 1)
            - edge_indices (tf.ragged): Index list of shape (batch, [M], 2)

    Returns:
        list: [edges, edge_indices]
            - edges (tf.ragged): Entries of the matrix power of shape
              (batch, [M'], 1)
            - edge_indices (tf.ragged): Index list of shape (batch, [M'], 2)
    """
    dyn_inputs = self._kgcnn_map_input_ragged(inputs, 3)
    node_len = dyn_inputs[0].row_lengths()
    edge = dyn_inputs[1].values
    edge_index = dyn_inputs[2].values
    edge_len = dyn_inputs[2].row_lengths()

    # Per-graph ("sample") indexing so indices address the dense matrices.
    edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edge_index, node_len, edge_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="sample")

    # Prepend the graph id to each (i, j) pair -> (graph, i, j).
    ind_batch = tf.cast(
        tf.expand_dims(
            tf.repeat(tf.range(tf.shape(edge_len)[0]), edge_len), axis=-1),
        dtype=edge_index.dtype)
    ind_all = tf.cast(tf.concat([ind_batch, edge_index], axis=-1),
                      dtype=tf.int64)

    max_index = tf.reduce_max(node_len)
    batch_size = tf.cast(tf.shape(node_len)[0], dtype=max_index.dtype)
    dense_shape = tf.stack([batch_size, max_index, max_index])

    # Dense zero-padded adjacency; the temporary trailing axis lets each
    # edge scatter its (1,)-shaped value slice edge[:, 0:1].
    adj = tf.zeros(dense_shape, dtype=edge.dtype)
    adj = tf.expand_dims(adj, axis=-1)
    adj = tf.tensor_scatter_nd_update(adj, ind_all, edge[:, 0:1])
    adj = tf.squeeze(adj, axis=-1)

    out = adj
    for _ in range(self.n - 1):
        out = tf.linalg.matmul(out, adj)

    # Sparsify. tf.where returns the coordinates of True entries in
    # row-major order (graph, row, column), i.e. the same order a flat
    # boolean mask would give, without materialising a batch*N*N index range.
    keep = tf.where(out > tf.keras.backend.epsilon())  # (K, 3), int64
    new_edge = tf.expand_dims(tf.gather_nd(out, keep), axis=-1)
    keep = tf.cast(keep, dtype=max_index.dtype)
    new_edge_ids = keep[:, 0]
    new_edge_index = keep[:, 1:]

    # Number of surviving entries per graph.
    new_edge_len = tf.tensor_scatter_nd_add(
        tf.zeros_like(node_len),
        tf.expand_dims(new_edge_ids, axis=-1),
        tf.ones_like(new_edge_ids))

    # Back to the layer's configured indexing scheme.
    new_edge_index = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        new_edge_index, node_len, new_edge_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing="sample",
        to_indexing=self.node_indexing)

    return [
        self._kgcnn_map_output_ragged([new_edge, new_edge_len],
                                      "row_length", 1),
        self._kgcnn_map_output_ragged([new_edge_index, new_edge_len],
                                      "row_length", 2),
    ]
def call(self, inputs, **kwargs):
    """Forward pass: attention-weighted pooling of edge features onto nodes.

    Each input may be a tf.RaggedTensor, a tf.Tensor or a list of
    (values, partition). A RaggedTensor has shape (batch, None, F), or
    (batch, N, F) for equally sized graphs. In the disjoint
    (values, partition) form the values are flattened to (batch*None, F)
    and the partition is one of "row_length", "row_splits" or
    "value_rowids"; the layer must then know its ``partition_type`` and
    ``node_indexing`` scheme (e.g. "batch").

    Edges are grouped by the node given by the FIRST index of each index
    pair. The attention coefficients are softmax-normalised within each
    group, and the edge features, weighted by those coefficients, are
    summed per node.

    Args:
        inputs (list): [nodes, edges, attention, edge_indices]
            - nodes: Node features of shape (batch, [N], F)
            - edges: Edge or message features of shape (batch, [M], F')
            - attention: Attention coefficients of shape (batch, [M], 1)
            - edge_indices: Edge indices of shape (batch, [M], 2)

    Returns:
        Pooled features per node, of shape (batch, [N], F'), in the same
        tensor representation as the node input.
    """
    found_node_type = kgcnn_ops_get_tensor_type(
        inputs[0], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)
    found_edge_type = kgcnn_ops_get_tensor_type(
        inputs[1], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)
    found_att_type = kgcnn_ops_get_tensor_type(
        inputs[2], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)
    found_index_type = kgcnn_ops_get_tensor_type(
        inputs[3], input_tensor_type=self.input_tensor_type,
        node_indexing=self.node_indexing)

    # Flatten everything to (values, partition).
    nod, node_part = kgcnn_ops_dyn_cast(
        inputs[0], input_tensor_type=found_node_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)
    edge, _ = kgcnn_ops_dyn_cast(
        inputs[1], input_tensor_type=found_edge_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)
    attention, _ = kgcnn_ops_dyn_cast(
        inputs[2], input_tensor_type=found_att_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)
    edgeind, edge_part = kgcnn_ops_dyn_cast(
        inputs[3], input_tensor_type=found_index_type,
        output_tensor_type="values_partition",
        partition_type=self.partition_type)

    # Shift indices so they address the flattened (disjoint) node list.
    shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edgeind, node_part, edge_part,
        partition_type_node=self.partition_type,
        partition_type_edge=self.partition_type,
        to_indexing='batch',
        from_indexing=self.node_indexing)

    nodind = shiftind[:, 0]  # pool onto the node of the first index
    dens = edge
    ats = attention
    if not self.is_sorted:
        # tf.math.segment_sum requires sorted segment ids.
        node_order = tf.argsort(nodind, axis=0, direction='ASCENDING',
                                stable=True)
        nodind = tf.gather(nodind, node_order, axis=0)
        dens = tf.gather(dens, node_order, axis=0)
        ats = tf.gather(ats, node_order, axis=0)

    # Softmax of the attention coefficients within each node's edge group.
    ats = segment_softmax(ats, nodind)
    get = tf.math.segment_sum(dens * ats, nodind)

    if self.has_unconnected:
        # segment_sum only produces rows up to the largest id present, so
        # nodes without edges may be missing. Pad to one row per node. The
        # feature width comes from the pooled edge features, which need not
        # match the node feature width.
        target_shape = tf.concat([tf.shape(nod)[:1], tf.shape(get)[1:]],
                                 axis=0)
        get = kgcnn_ops_scatter_segment_tensor_nd(get, nodind, target_shape)

    return kgcnn_ops_dyn_cast([get, node_part],
                              input_tensor_type="values_partition",
                              output_tensor_type=found_node_type,
                              partition_type=self.partition_type)
def call(self, inputs, **kwargs):
    """Forward pass: scatter pooled nodes/edges back to their original slots.

    Rows of the pooled node and edge tensors are written to the positions
    given by the node/edge index maps. Positions that were removed by
    pooling are filled with zeros. The original edge indices are returned
    unchanged.

    Args:
        inputs (list): [node, edge, edge_indices, map_node, map_edge,
            node_pool, edge_pool, edge_indices_pool]
            - node (tf.ragged): Original node tensor
            - edge (tf.ragged): Original edge feature tensor
            - edge_indices (tf.ragged): Original index tensor
            - map_node (tf.ragged): Index map between original and pooled nodes
            - map_edge (tf.ragged): Index map between original and pooled edges
            - node_pool (tf.ragged): Pooled node tensor
            - edge_pool (tf.ragged): Pooled edge feature tensor
            - edge_indices_pool (tf.ragged): Pooled index tensor (not used)

    Returns:
        list: [nodes, edges, edge_indices]
            - nodes (tf.ragged): Unpooled node feature tensor
            - edges (tf.ragged): Unpooled edge feature tensor
            - edge_indices (tf.ragged): Original edge indices
    """
    dyn_inputs = self._kgcnn_map_input_ragged(inputs, 8)
    node_old, nrowlength = dyn_inputs[0].values, dyn_inputs[0].row_lengths()
    edge_old, erowlength = dyn_inputs[1].values, dyn_inputs[1].row_lengths()
    edgeind_old = dyn_inputs[2].values
    map_node = dyn_inputs[3].values
    map_edge = dyn_inputs[4].values
    node_new, pool_node_len = dyn_inputs[5].values, dyn_inputs[5].row_lengths()
    edge_new, pool_edge_id = dyn_inputs[6].values, dyn_inputs[6].value_rowids()
    # dyn_inputs[7] (pooled edge indices) is not needed.

    # Add the per-graph offset so the maps address the flattened tensors.
    map_node = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        map_node, nrowlength, pool_node_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing=self.node_indexing,
        to_indexing="batch",
        axis=0)
    map_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        map_edge, erowlength, pool_edge_id,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing=self.node_indexing,
        to_indexing="batch",
        axis=0)

    def scatter_back(index_map, pooled, original):
        # Result has one row per original entry and the pooled tensor's
        # trailing dims (any rank); unmapped rows stay zero.
        shape = tf.cast(
            tf.concat([tf.shape(original)[:1], tf.shape(pooled)[1:]], axis=0),
            dtype=index_map.dtype)
        return tf.scatter_nd(ks.backend.expand_dims(index_map, axis=-1),
                             pooled, shape)

    out_node = scatter_back(map_node, node_new, node_old)
    out_edge = scatter_back(map_edge, edge_new, edge_old)

    return [
        self._kgcnn_map_output_ragged([out_node, nrowlength], "row_length", 0),
        self._kgcnn_map_output_ragged([out_edge, erowlength], "row_length", 1),
        self._kgcnn_map_output_ragged([edgeind_old, erowlength],
                                      "row_length", 2),
    ]
def call(self, inputs, **kwargs):
    """Forward pass: TopK pooling of the nodes in every graph.

    Each node is scored by its projection onto ``self.kernel_p``, divided by
    the norm of ``self.kernel_p``. In every graph the ``round(self.k * N)``
    lowest-scoring nodes are removed. The surviving nodes are ordered by
    ascending score and multiplied by ``sigmoid(score)``. Edges touching a
    removed node are dropped, and the remaining edge indices are re-mapped
    onto the pooled nodes.

    Args:
        inputs (list): [nodes, edges, edge_indices]
            - nodes (tf.ragged): Node embeddings of shape (batch, [N], F)
            - edges (tf.ragged): Edge embeddings of shape (batch, [M], F)
            - edge_indices (tf.ragged): Edge index list of shape (batch, [M], 2)

    Returns:
        Tuple: [nodes, edges, edge_indices], [map_nodes, map_edges]
            - nodes (tf.ragged): Pooled node feature tensor
            - edges (tf.ragged): Pooled edge feature list
            - edge_indices (tf.ragged): Pooled edge index list
            - map_nodes (tf.ragged): Original index of each pooled node
            - map_edges (tf.ragged): Original index of each pooled edge
    """
    dyn_inputs = self._kgcnn_map_input_ragged(inputs, 3)
    node, nodelen = dyn_inputs[0].values, dyn_inputs[0].row_lengths()
    edgefeat, edgelen = dyn_inputs[1].values, dyn_inputs[1].row_lengths()
    edgeindref = dyn_inputs[2].values
    index_dtype = edgeindref.dtype

    nvalue = node
    nrowlength = tf.cast(nodelen, dtype=index_dtype)
    erowlength = tf.cast(edgelen, dtype=index_dtype)
    # Graph id of every flattened node.
    nids = tf.repeat(tf.range(tf.shape(nrowlength)[0], dtype=index_dtype),
                     nrowlength)

    # Score = projection of each node onto kernel_p / ||kernel_p||.
    norm_p = ks.backend.sqrt(
        ks.backend.sum(ks.backend.square(self.kernel_p), axis=-1,
                       keepdims=True))
    nscore = ks.backend.sum(nvalue * self.kernel_p / norm_p, axis=-1)

    # Sort by score, then stably by graph id: nodes stay grouped per graph
    # and are ordered by ascending score inside each graph.
    sort1 = tf.argsort(nscore, direction='ASCENDING', stable=False)
    nids_sorted1 = tf.gather(nids, sort1)
    sort2 = tf.argsort(nids_sorted1, direction='ASCENDING',
                       stable=True)  # must be stable
    sort12 = tf.gather(sort1, sort2)  # flat (batch*N) positions
    nvalue_sorted = tf.gather(nvalue, sort12, axis=0)
    nscore_sorted = tf.gather(nscore, sort12, axis=0)

    # Per graph: drop the first nremove (lowest-score) nodes, keep nkeep.
    nremove = tf.cast(
        tf.math.round(self.k * tf.cast(nrowlength,
                                       dtype=tf.keras.backend.floatx())),
        dtype=index_dtype)
    nkeep = nrowlength - nremove
    n_remove_keep = tf.reshape(tf.stack([nremove, nkeep], axis=-1), [-1])
    mask_remove_keep = tf.reshape(
        tf.stack([tf.zeros_like(nremove, dtype=tf.bool),
                  tf.ones_like(nkeep, dtype=tf.bool)], axis=-1), [-1])
    mask = tf.repeat(mask_remove_keep, n_remove_keep)

    pooled_n = nvalue_sorted[mask]
    pooled_score = nscore_sorted[mask]
    pooled_len = nkeep  # shape=(batch,)
    pooled_index = tf.cast(sort12[mask], dtype=index_dtype)
    removed_index = tf.cast(sort12[tf.math.logical_not(mask)],
                            dtype=index_dtype)

    # Gate the kept nodes by their score.
    gated_n = pooled_n * ks.backend.expand_dims(
        tf.keras.activations.sigmoid(pooled_score), axis=-1)

    # map_index[old flat index] -> new flat index (0 for removed nodes).
    index_new_nodes = tf.range(tf.shape(pooled_index)[0], dtype=index_dtype)
    old_shape = tf.cast(ks.backend.expand_dims(tf.shape(nvalue)[0]),
                        dtype=index_dtype)
    map_index = tf.scatter_nd(ks.backend.expand_dims(pooled_index, axis=-1),
                              index_new_nodes, old_shape)

    # Edge indices with the per-graph offset applied.
    edge_ids = tf.repeat(tf.range(tf.shape(edgelen)[0], dtype=index_dtype),
                         edgelen)
    shiftind = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edgeindref, nrowlength, edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing=self.node_indexing,
        to_indexing="batch")
    shiftind = tf.cast(shiftind, dtype=index_dtype)

    # Drop edges touching a removed node. Flag removed nodes once (O(N)) and
    # look the flags up per edge endpoint (O(M)), instead of comparing every
    # edge against every removed node (an (M, 2, R) tensor).
    is_removed = tf.scatter_nd(
        ks.backend.expand_dims(removed_index, axis=-1),
        tf.ones_like(removed_index), old_shape) > 0
    mask_edge = tf.math.logical_not(
        tf.reduce_any(tf.gather(is_removed, shiftind), axis=-1))
    clean_shiftind = shiftind[mask_edge]
    clean_edge_ids = edge_ids[mask_edge]
    clean_edge_len = tf.scatter_nd(
        tf.expand_dims(clean_edge_ids, axis=-1),
        tf.ones_like(clean_edge_ids),
        tf.cast(tf.shape(erowlength), dtype=index_dtype))

    # Re-map both endpoints to the pooled node indices and order by the first.
    new_edge_index = tf.gather(map_index, clean_shiftind)
    batch_order = tf.argsort(new_edge_index[:, 0], axis=0,
                             direction='ASCENDING', stable=True)
    new_edge_index_sorted = tf.gather(new_edge_index, batch_order, axis=0)

    # Remove the batch offset again for the configured indexing scheme.
    out_indexlist = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        new_edge_index_sorted, pooled_len, clean_edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing="batch",
        to_indexing=self.node_indexing)

    # Apply the same filtering and reordering to the edge features.
    clean_edge_feat_sorted = tf.gather(edgefeat[mask_edge], batch_order,
                                       axis=0)

    # Original (flat) position of every kept edge, in output order.
    edge_position_old = tf.range(tf.shape(edgefeat)[0], dtype=index_dtype)
    edge_position_new = tf.gather(edge_position_old[mask_edge], batch_order,
                                  axis=0)

    # Reverse-pooling maps without batch offset, partitioned like the outputs.
    out_pool = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        pooled_index, nrowlength, pooled_len,
        partition_type_node="row_length",
        partition_type_edge="row_length",
        from_indexing="batch",
        to_indexing=self.node_indexing,
        axis=0)
    out_pool_edge = kgcnn_ops_change_edge_tensor_indexing_by_row_partition(
        edge_position_new, erowlength, clean_edge_ids,
        partition_type_node="row_length",
        partition_type_edge="value_rowids",
        from_indexing="batch",
        to_indexing=self.node_indexing,
        axis=0)

    out = [
        self._kgcnn_map_output_ragged([gated_n, pooled_len], "row_length", 0),
        self._kgcnn_map_output_ragged([clean_edge_feat_sorted, clean_edge_len],
                                      "row_length", 1),
        self._kgcnn_map_output_ragged([out_indexlist, clean_edge_len],
                                      "row_length", 2),
    ]
    out_map = [
        self._kgcnn_map_output_ragged([out_pool, pooled_len], "row_length", 0),
        self._kgcnn_map_output_ragged([out_pool_edge, clean_edge_len],
                                      "row_length", 1),
    ]
    return out, out_map