def upsample_transposed_convolution(data,
                                    pool_map,
                                    sizes,
                                    kernel_size,
                                    transposed_convolution_op,
                                    name=None):
  # pyformat: disable
  r"""Graph upsampling by transposed convolution.

  Upsamples a graph using a transposed convolution op. The map from input
  vertices to the upsampled graph is specified by the reverse of pool_map. The
  inputs `pool_map` and `sizes` are the same as used for pooling:

  >>> pooled = pool(data, pool_map, sizes)
  >>> upsampled = upsample_transposed_convolution(pooled, pool_map, sizes, ...)

  The shorthands used below are
    `V1`: The number of vertices in the inputs.
    `V2`: The number of vertices in the upsampled output.
    `C`: The number of channels in the inputs.

  Note:
    In the following, A1 to A3 are optional batch dimensions. Only up to three
    batch dimensions are supported due to limitations with TensorFlow's
    dense-sparse multiplication.

  Please see the documentation for `graph_pooling.pool` for a detailed
  interpretation of the inputs `pool_map` and `sizes`.

  Args:
    data: A `float` tensor with shape `[A1, ..., A3, V1, C]`.
    pool_map: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., A3, V1, V2]`. `pool_map` will be interpreted in the same way
      as the `pool_map` argument of `graph_pooling.pool`, namely
      `v_i_map = [..., v_i, :]` are the upsampled vertices corresponding to
      vertex `v_i`. Additionally, for transposed convolution a fixed number of
      entries in each `v_i_map` (equal to `kernel_size`) are expected:
      `|v_i_map| = kernel_size`. When this is not the case, the map is either
      truncated or the last element repeated. Furthermore, upsampled vertex
      indices should not be repeated across maps otherwise the output is
      nondeterministic. Specifically, to avoid nondeterminism we must have
      `intersect([a1, ..., a3, v_i, :],[a1, ..., a3, v_j, :]) = {}, i != j`.
    sizes: An `int` tensor of shape `[A1, ..., A3, 2]` indicating the true
      input sizes in case of padding (`sizes=None` indicates no padding):
      `sizes[A1, ..., A3, 0] <= V1` and `sizes[A1, ..., A3, 1] <= V2`.
    kernel_size: The kernel size for transposed convolution.
    transposed_convolution_op: A callable transposed convolution op with the
      form `y = transposed_convolution_op(x)`, where `x` has shape
      `[1, 1, D1, C]` and `y` must have shape `[1, 1, kernel_size * D1, C]`.
      `transposed_convolution_op` maps each row of `x` to `kernel_size` rows
      in `y`. An example:
      `transposed_convolution_op = tf.keras.layers.Conv2DTranspose(
      filters=C, kernel_size=(1, kernel_size), strides=(1, kernel_size),
      padding='valid', ...)`
    name: A name for this op. Defaults to
      'graph_pooling_upsample_transposed_convolution'.

  Returns:
    Tensor with shape `[A1, ..., A3, V2, C]`.

  Raises:
    TypeError: if the input types are invalid.
    TypeError: if `transposed_convolution_op` is not a callable.
    ValueError: if the input dimensions are invalid.
  """
  # pyformat: enable
  with tf.compat.v1.name_scope(
      name, 'graph_pooling_upsample_transposed_convolution',
      [data, pool_map, sizes]):
    data = tf.convert_to_tensor(value=data)
    pool_map = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=pool_map)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)

    utils.check_valid_graph_unpooling_input(data, pool_map, sizes)
    if not callable(transposed_convolution_op):
      raise TypeError("'transposed_convolution_op' must be callable.")

    if sizes is not None:
      # sizes[..., 0] holds the true input (V1) sizes and sizes[..., 1] the
      # true output (V2) sizes; split them into two tensors of shape [A1..A3].
      sizes_input, sizes_output = tf.split(sizes, 2, axis=-1)
      sizes_input = tf.squeeze(sizes_input, axis=-1)
      sizes_output = tf.squeeze(sizes_output, axis=-1)
    else:
      sizes_input = None
      sizes_output = None

    # Static channel count; only used to build the output shape in the
    # batched, unpadded case below.
    num_features = tf.compat.v1.dimension_value(data.shape[-1])
    batched = data.shape.ndims > 2
    if batched:
      # Collapse the batch into a single 2-D feature matrix and a single
      # block-diagonal pool map so the rest of the op can work in 2-D.
      x_flat, _ = utils.flatten_batch_to_2d(data, sizes_input)
      pool_map_block_diagonal = utils.convert_to_block_diag_2d(pool_map, sizes)
    else:
      x_flat = data
      pool_map_block_diagonal = pool_map

    # Add two leading singleton dims so `x_flat` has the [1, 1, D1, C] layout
    # the transposed convolution op expects.
    x_flat = tf.expand_dims(tf.expand_dims(x_flat, 0), 0)
    x_upsample = transposed_convolution_op(x_flat)

    # Map each upsampled vertex into its correct position based on pool_map.
    # Select 'kernel_size' neighbors for each input vertex. Truncate or repeat
    # as necessary.
    # Row i of `ragged` lists the upsampled-vertex (column) indices stored in
    # row i of the pool map.
    ragged = tf.RaggedTensor.from_value_rowids(
        pool_map_block_diagonal.indices[:, 1],
        pool_map_block_diagonal.indices[:, 0])
    # Take up to the first 'kernel_size' entries.
    ragged_k = ragged[:, :kernel_size]
    # Fill rows with less than 'kernel_size' entries by repeating the last
    # entry.
    # NOTE(review): `last` has one value per non-empty row, so this padding
    # presumably assumes every input vertex has at least one pool-map entry.
    last = ragged_k[:, -1:].flat_values
    num_repeat = kernel_size - ragged_k.row_lengths()
    sum_num_repeat = tf.reduce_sum(input_tensor=num_repeat)
    ones_ragged = tf.RaggedTensor.from_row_lengths(
        tf.ones((sum_num_repeat,), dtype=last.dtype), num_repeat)
    repeat = ones_ragged * tf.expand_dims(last, -1)
    padded = tf.concat([ragged_k, repeat], axis=1)
    # Dense [num_rows, kernel_size] map: entry (i, k) is the upsampled vertex
    # that receives the k-th output produced for input row i.
    pool_map_dense = tf.reshape(padded.flat_values, (-1, kernel_size))

    # Map rows of 'x_upsample' to positions indicated by the
    # indices 'pool_map_dense'.
    # Assumes the op emits the `kernel_size` outputs of input row i at
    # contiguous rows [i * kernel_size, (i + 1) * kernel_size) -- TODO confirm
    # for custom ops; this matches the row-major flattening below.
    up_scatter_indices = tf.expand_dims(tf.reshape(pool_map_dense, (-1,)), -1)
    up_row = tf.reshape(tf.cast(up_scatter_indices, tf.int64), (-1,))
    up_column = tf.range(tf.shape(input=up_row, out_type=tf.dtypes.int64)[0])
    # Sparse 0/1 matrix: entry (r, c) = 1 means upsampled vertex r receives
    # row c of `x_upsample`.
    scatter_indices = tf.concat(
        (tf.expand_dims(up_row, -1), tf.expand_dims(up_column, -1)), axis=1)
    scatter_values = tf.ones_like(up_row, dtype=x_upsample.dtype)
    # NOTE(review): the shape is taken from the largest index used, so
    # upsampled vertices beyond the largest referenced one get no row.
    scatter_shape = tf.reduce_max(input_tensor=scatter_indices, axis=0) + 1
    scatter = tf.SparseTensor(scatter_indices, scatter_values, scatter_shape)
    scatter = tf.sparse.reorder(scatter)
    # Row-normalize so an upsampled vertex hit several times (e.g. by the
    # repeated padding entries) receives the mean of its contributions; rows
    # summing to zero keep 0 instead of dividing by zero.
    row_sum = tf.sparse.reduce_sum(tf.abs(scatter), keepdims=True, axis=-1)
    row_sum = tf.compat.v1.where(tf.equal(row_sum, 0.), row_sum, 1.0 / row_sum)
    scatter = row_sum * scatter
    x_upsample = tf.sparse.sparse_dense_matmul(scatter, x_upsample[0, 0, :, :])

    if batched:
      if sizes_output is not None:
        x_upsample = utils.unflatten_2d_to_batch(x_upsample, sizes_output)
      else:
        # Unpadded batch: reshape to [A1, ..., A3, V2, C] using the pool map's
        # batch dims and its last (V2) dimension.
        output_shape = tf.concat((tf.shape(input=pool_map)[:-2],
                                  tf.shape(input=pool_map)[-1:],
                                  (num_features,)), axis=0)
        x_upsample = tf.reshape(x_upsample, output_shape)

    return x_upsample
def edge_convolution_template(data,
                              neighbors,
                              sizes,
                              edge_function,
                              reduction,
                              edge_function_kwargs,
                              name=None):
  # pyformat: disable
  r"""A template for edge convolutions.

  This function implements a general edge convolution for graphs of the form
  \\(y_i = \sum_{j \in \mathcal{N}(i)} w_{ij} f(x_i, x_j)\\), where
  \\(\mathcal{N}(i)\\) is the set of vertices in the neighborhood of vertex
  \\(i\\), \\(x_i \in \mathbb{R}^C\\) are the features at vertex \\(i\\),
  \\(w_{ij} \in \mathbb{R}\\) is the weight for the edge between vertex \\(i\\)
  and vertex \\(j\\), and finally
  \\(f(x_i, x_j): \mathbb{R}^{C} \times \mathbb{R}^{C} \to \mathbb{R}^{D}\\)
  is a user-supplied function.

  This template also implements the same general edge convolution described
  above with a max-reduction instead of a weighted sum.

  An example of how this template can be used is for Laplacian smoothing,
  which is defined as
  $$y_i = \frac{1}{|\mathcal{N(i)}|} \sum_{j \in \mathcal{N(i)}} x_j$$.
  `edge_convolution_template` can be used to perform Laplacian smoothing by
  setting $$w_{ij} = \frac{1}{|\mathcal{N(i)}|}$$,
  `edge_function=lambda x, y: y`, and `reduction='weighted'`.

  The shorthands used below are
    `V`: The number of vertices.
    `C`: The number of channels in the input data.
    `D`: The number of channels produced by `edge_function`.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, V, C]`.
    neighbors: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, V, V]` representing vertex neighborhoods. The
      neighborhood of a vertex defines the support region for convolution. The
      value at `neighbors[A1, ..., An, i, j]` corresponds to the weight
      \\(w_{ij}\\) above. Each vertex must have at least one neighbor.
    sizes: An `int` tensor of shape `[A1, ..., An]` indicating the true input
      sizes in case of padding (`sizes=None` indicates no padding). Note that
      `sizes[A1, ..., An] <= V`. If `data` and `neighbors` are 2-D, `sizes`
      will be ignored. As an example, consider an input consisting of three
      graphs G0, G1, and G2 with V0, V1, and V2 vertices respectively. The
      padded input would have the shapes `[3, V, C]`, and `[3, V, V]` for
      `data` and `neighbors` respectively, where `V = max([V0, V1, V2])`. The
      true sizes of each graph will be specified by `sizes=[V0, V1, V2]` and
      `data[i, :Vi, :]` and `neighbors[i, :Vi, :Vi]` will be the vertex and
      neighborhood data of graph Gi. The `SparseTensor` `neighbors` should
      have no nonzero entries in the padded regions.
    edge_function: A callable that takes at least two arguments of vertex
      features and returns a tensor of vertex features.
      `Y = f(X1, X2, **kwargs)`, where `X1` and `X2` have shape `[V3, C]` and
      `Y` must have shape `[V3, D], D >= 1`.
    reduction: Either 'weighted' or 'max'. Specifies the reduction over the
      neighborhood. For 'weighted', the reduction is a weighted sum as shown
      in the equation above. For 'max' the reduction is a max over features in
      which case the weights $$w_{ij}$$ are ignored.
    edge_function_kwargs: A dict containing any additional keyword arguments
      to be passed to `edge_function`.
    name: A name for this op. Defaults to
      `graph_convolution_edge_convolution_template`.

  Returns:
    Tensor with shape `[A1, ..., An, V, D]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid, or if `reduction` is not
      'weighted' or 'max'.
  """
  # pyformat: enable
  # Validate `reduction` before building any ops or calling `edge_function`,
  # which may have side effects (e.g. creating variables).
  if reduction not in ("weighted", "max"):
    raise ValueError("The reduction method must be 'weighted' or 'max'")

  with tf.compat.v1.name_scope(name,
                               "graph_convolution_edge_convolution_template",
                               [data, neighbors, sizes]):
    data = tf.convert_to_tensor(value=data)
    neighbors = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=neighbors)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)

    data_ndims = data.shape.ndims
    utils.check_valid_graph_convolution_input(data, neighbors, sizes)

    # Flatten the batch dimensions and remove any vertex padding.
    if data_ndims > 2:
      if sizes is not None:
        sizes_square = tf.stack((sizes, sizes), axis=-1)
      else:
        sizes_square = None
      x_flat, unflatten = utils.flatten_batch_to_2d(data, sizes)
      adjacency = utils.convert_to_block_diag_2d(neighbors, sizes_square)
    else:
      x_flat = data
      adjacency = neighbors

    if reduction == "max":
      # `tf.math.segment_max` requires sorted segment ids. Put the edges in
      # canonical row-major order so the source-vertex indices are sorted even
      # when the caller passed a non-canonical SparseTensor.
      adjacency = tf.sparse.reorder(adjacency)

    # Each stored entry (i, j) of the adjacency is one edge: i is the vertex
    # being updated and j is its neighbor.
    adjacency_ind_0 = adjacency.indices[:, 0]
    adjacency_ind_1 = adjacency.indices[:, 1]
    vertex_features = tf.gather(x_flat, adjacency_ind_0)
    neighbor_features = tf.gather(x_flat, adjacency_ind_1)
    edge_features = edge_function(vertex_features, neighbor_features,
                                  **edge_function_kwargs)

    if reduction == "weighted":
      features = utils.partition_sums_2d(edge_features, adjacency_ind_0,
                                         adjacency.values)
    else:
      features = tf.math.segment_max(
          data=edge_features, segment_ids=adjacency_ind_0)
      # segment_max loses the static shape; restore [V, D] where known.
      features.set_shape(features.shape.merge_with(
          (tf.compat.v1.dimension_value(x_flat.shape[0]),
           tf.compat.v1.dimension_value(edge_features.shape[-1]))))

    if data_ndims > 2:
      features = unflatten(features)
    return features
def feature_steered_convolution(data,
                                neighbors,
                                sizes,
                                var_u,
                                var_v,
                                var_c,
                                var_w,
                                var_b,
                                name=None):
  # pyformat: disable
  """Implements the Feature Steered graph convolution.

  FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis
  Nitika Verma, Edmond Boyer, Jakob Verbeek
  CVPR 2018
  https://arxiv.org/abs/1706.05206

  The shorthands used below are
    `V`: The number of vertices.
    `C`: The number of channels in the input data.
    `D`: The number of channels in the output after convolution.
    `W`: The number of weight matrices used in the convolution.
    The input variables (`var_u`, `var_v`, `var_c`, `var_w`, `var_b`)
    correspond to the variables with the same names in the paper cited above.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, V, C]`.
    neighbors: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, V, V]` representing vertex neighborhoods. The
      neighborhood of a vertex defines the support region for convolution.
      For a mesh, a common choice for the neighborhood of vertex i would be the
      vertices in the K-ring of i (including i itself). Each vertex must have
      at least one neighbor. For a faithful implementation of the FeaStNet
      convolution, neighbors should be a row-normalized weight matrix
      corresponding to the graph adjacency matrix with self-edges:
      `neighbors[A1, ..., An, i, j] > 0` if vertex j is a neighbor of i, and
      `neighbors[A1, ..., An, i, i] > 0` for all i, and
      `sum(neighbors, axis=-1)[A1, ..., An, i] == 1.0 for all i`. These
      requirements are relaxed in this implementation.
    sizes: An `int` tensor of shape `[A1, ..., An]` indicating the true input
      sizes in case of padding (`sizes=None` indicates no padding). Note that
      `sizes[A1, ..., An] <= V`. If `data` and `neighbors` are 2-D, `sizes`
      will be ignored. An example usage of `sizes`: consider an input
      consisting of three graphs G0, G1, and G2 with V0, V1, and V2 vertices
      respectively. The padded input would have the following shapes:
      `data.shape = [3, V, C]` and `neighbors.shape = [3, V, V]`, where
      `V = max([V0, V1, V2])`. The true sizes of each graph will be specified
      by `sizes=[V0, V1, V2]`, and `data[i, :Vi, :]` and
      `neighbors[i, :Vi, :Vi]` will be the vertex and neighborhood data of
      graph Gi. The `SparseTensor` `neighbors` should have no nonzero entries
      in the padded regions.
    var_u: A 2-D tensor with shape `[C, W]`.
    var_v: A 2-D tensor with shape `[C, W]`.
    var_c: A 1-D tensor with shape `[W]`.
    var_w: A 3-D tensor with shape `[W, C, D]`.
    var_b: A 1-D tensor with shape `[D]`.
    name: A name for this op. Defaults to
      `graph_convolution_feature_steered_convolution`.

  Returns:
    Tensor with shape `[A1, ..., An, V, D]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid.
  """
  # pyformat: enable
  with tf.compat.v1.name_scope(
      name, "graph_convolution_feature_steered_convolution",
      [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
    data = tf.convert_to_tensor(value=data)
    neighbors = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=neighbors)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)
    var_u = tf.convert_to_tensor(value=var_u)
    var_v = tf.convert_to_tensor(value=var_v)
    var_c = tf.convert_to_tensor(value=var_c)
    var_w = tf.convert_to_tensor(value=var_w)
    var_b = tf.convert_to_tensor(value=var_b)

    data_ndims = data.shape.ndims
    utils.check_valid_graph_convolution_input(data, neighbors, sizes)
    # C must agree across data, var_u, var_v and var_w.
    shape.compare_dimensions(
        tensors=(data, var_u, var_v, var_w),
        tensor_names=("data", "var_u", "var_v", "var_w"),
        axes=(-1, 0, 0, 1))
    # W must agree across var_u, var_v, var_c and var_w.
    shape.compare_dimensions(
        tensors=(var_u, var_v, var_c, var_w),
        tensor_names=("var_u", "var_v", "var_c", "var_w"),
        axes=(1, 1, 0, 0))
    # D must agree between var_w and var_b.
    shape.compare_dimensions(
        tensors=(var_w, var_b), tensor_names=("var_w", "var_b"), axes=-1)

    # Flatten the batch dimensions and remove any vertex padding.
    if data_ndims > 2:
      if sizes is not None:
        sizes_square = tf.stack((sizes, sizes), axis=-1)
      else:
        sizes_square = None
      x_flat, unflatten = utils.flatten_batch_to_2d(data, sizes)
      adjacency = utils.convert_to_block_diag_2d(neighbors, sizes_square)
    else:
      x_flat = data
      adjacency = neighbors

    x_u = tf.matmul(x_flat, var_u)
    x_v = tf.matmul(x_flat, var_v)
    # Each stored entry (i, j) of the adjacency is one edge i <- j.
    adjacency_ind_0 = adjacency.indices[:, 0]
    adjacency_ind_1 = adjacency.indices[:, 1]
    x_u_rep = tf.gather(x_u, adjacency_ind_0)
    x_v_sep = tf.gather(x_v, adjacency_ind_1)
    # Per-edge logits over the W weight matrices, normalized with a softmax.
    # `tf.nn.softmax` is mathematically identical to exp(l) / sum(exp(l)) but
    # subtracts the max first, so large logits cannot overflow to inf and
    # turn the weights into NaN.
    logits_q = x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1))
    weights_q = tf.nn.softmax(logits_q, axis=-1)

    y_i_m = []
    x_sep = tf.gather(x_flat, adjacency_ind_1)
    q_m_list = tf.unstack(weights_q, axis=-1)
    w_m_list = tf.unstack(var_w, axis=0)

    for q_m, w_m in zip(q_m_list, w_m_list):
      # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`.
      q_m = tf.expand_dims(q_m, axis=-1)
      p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0,
                                      adjacency.values)
      y_i_m.append(tf.matmul(p_sum, w_m))

    y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1])
    if data_ndims > 2:
      y_out = unflatten(y_out)

    return y_out
def unflatten_2d_to_batch(flat):
  """Restores `flat` to the batched layout of `data_init`.

  `data_init` and `sizes` are read from the enclosing scope; only the unflatten
  callable returned by `utils.flatten_batch_to_2d` is used.
  """
  _unused_flattened, restore_batch = utils.flatten_batch_to_2d(
      data_init, sizes=sizes)
  return restore_batch(flat)
def flatten_batch_to_2d(data):
  """Returns only the flattened 2-D view of `data` (`sizes` from enclosing scope)."""
  flat_data, _unused_unflatten = utils.flatten_batch_to_2d(data, sizes=sizes)
  return flat_data
def test_flatten_batch_to_2d_exception_raised_types(self):
  """Check the exception when input is not an integer."""
  # `assertRaisesRegexp` is a deprecated alias that was removed in Python 3.12;
  # `assertRaisesRegex` is the supported spelling on all Python 3 versions.
  with self.assertRaisesRegex(TypeError,
                              "'sizes' must have an integer type."):
    utils.flatten_batch_to_2d(np.ones((3, 4, 3)), np.ones((3,)))
def pool(data, pool_map, sizes, algorithm='max', name=None):
  # pyformat: disable
  """Implements graph pooling.

  The features at each output vertex are computed by pooling over a subset of
  vertices in the input graph. This pooling window is specified by the input
  `pool_map`.

  The shorthands used below are
    `V1`: The number of vertices in the input data.
    `V2`: The number of vertices in the pooled output data.
    `C`: The number of channels in the data.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, V1, C]`.
    pool_map: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, V2, V1]`. The features for an output vertex `v2` will be
      computed by pooling over the corresponding input vertices specified by
      the entries in `pool_map[A1, ..., An, v2, :]`.
    sizes: An `int` tensor of shape `[A1, ..., An, 2]` indicating the true
      input sizes in case of padding (`sizes=None` indicates no padding).
      `sizes[A1, ..., An, 0] <= V2` specifies the padding in the (pooled)
      output, and `sizes[A1, ..., An, 1] <= V1` specifies the padding in the
      input.
    algorithm: The pooling function, must be either 'max' or 'weighted'.
      Default is 'max'. For 'max' pooling, the output features are the maximum
      over the input vertices (in this case only the indices of the
      `SparseTensor` `pool_map` are used, the values are ignored). For
      'weighted', the output features are a weighted sum of the input
      vertices, the weights specified by the values of `pool_map`.
    name: A name for this op. Defaults to 'graph_pooling_pool'.

  Returns:
    Tensor with shape `[A1, ..., An, V2, C]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid.
    ValueError: if `algorithm` is invalid.
  """
  # pyformat: enable
  # Reject an unknown algorithm before converting inputs or building any ops.
  if algorithm not in ('weighted', 'max'):
    raise ValueError('The pooling method must be "weighted" or "max"')

  with tf.compat.v1.name_scope(name, 'graph_pooling_pool',
                               [data, pool_map, sizes]):
    data = tf.convert_to_tensor(value=data)
    pool_map = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=pool_map)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)

    utils.check_valid_graph_pooling_input(data, pool_map, sizes)

    if sizes is not None:
      # sizes[..., 0] is the true output (V2) size, sizes[..., 1] the true
      # input (V1) size.
      sizes_output, sizes_input = tf.split(sizes, 2, axis=-1)
      sizes_output = tf.squeeze(sizes_output, axis=-1)
      sizes_input = tf.squeeze(sizes_input, axis=-1)
    else:
      sizes_output = None
      sizes_input = None

    batched = data.shape.ndims > 2
    if batched:
      # Work on a single 2-D feature matrix and a block-diagonal pool map.
      x_flat, _ = utils.flatten_batch_to_2d(data, sizes_input)
      pool_map_block_diagonal = utils.convert_to_block_diag_2d(pool_map, sizes)
    else:
      x_flat = data
      pool_map_block_diagonal = pool_map

    if algorithm == 'weighted':
      pooled = tf.sparse.sparse_dense_matmul(pool_map_block_diagonal, x_flat)
    else:
      # `tf.math.segment_max` requires sorted segment ids, so put the pool map
      # into canonical row-major order first (a caller-built SparseTensor is
      # not guaranteed to be ordered).
      pool_map_block_diagonal = tf.sparse.reorder(pool_map_block_diagonal)
      pool_groups = tf.gather(x_flat, pool_map_block_diagonal.indices[:, 1])
      pooled = tf.math.segment_max(
          data=pool_groups, segment_ids=pool_map_block_diagonal.indices[:, 0])

    if batched:
      if sizes_output is not None:
        pooled = utils.unflatten_2d_to_batch(pooled, sizes_output)
      else:
        # Unpadded batch: reshape to [A1, ..., An, V2, C] from the pool map's
        # leading dims; the channel dim is inferred.
        output_shape = tf.concat((tf.shape(input=pool_map)[:-1], (-1,)),
                                 axis=0)
        pooled = tf.reshape(pooled, output_shape)

    return pooled
def edge_convolution_template(data,
                              neighbors,
                              sizes,
                              edge_function,
                              edge_function_kwargs,
                              name=None):
  # pyformat: disable
  r"""A template for edge convolutions.

  Computes, for every vertex, a weighted sum of a user-supplied edge function
  evaluated over its neighborhood:

  $$ y_i = \sum_{j \in \mathcal{N}(i)} w_{ij} f(x_i, x_j) $$

  where $$\mathcal{N}(i)$$ is the set of vertices in the neighborhood of
  vertex $$i$$, $$x_i \in \mathbb{R}^C$$ are the features at vertex $$i$$,
  $$w_{ij} \in \mathbb{R}$$ is the weight for the edge between vertex $$i$$
  and vertex $$j$$, and
  $$f(x_i, x_j): \mathbb{R}^{C} \times \mathbb{R}^{C} \to \mathbb{R}^{D}$$ is
  the user-supplied function.

  The shorthands used below are
    `V`: The number of vertices.
    `C`: The number of channels in the input data.
    `D`: The number of channels produced by `edge_function`.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, V, C]`.
    neighbors: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, V, V]` describing vertex neighborhoods (the support
      region of the convolution). The value stored at
      `neighbors[A1, ..., An, i, j]` is the weight $$w_{ij}$$ above. Every
      vertex must have at least one neighbor.
    sizes: An `int` tensor of shape `[A1, ..., An]` giving the true number of
      vertices of each padded graph (`sizes=None` means no padding), with
      `sizes[A1, ..., An] <= V`. Ignored when `data` and `neighbors` are 2-D.
      For example, three graphs G0, G1, G2 with V0, V1, V2 vertices padded to
      `V = max([V0, V1, V2])` give `data` of shape `[3, V, C]` and `neighbors`
      of shape `[3, V, V]`; with `sizes=[V0, V1, V2]`, `data[i, :Vi, :]` and
      `neighbors[i, :Vi, :Vi]` hold graph Gi. `neighbors` should have no
      nonzero entries in the padded regions.
    edge_function: A callable `Y = f(X1, X2, **kwargs)` taking at least two
      tensors of vertex features `X1`, `X2` of shape `[V3, C]` and returning
      `Y` of shape `[V3, D]`, `D >= 1`.
    edge_function_kwargs: A dict of extra keyword arguments forwarded to
      `edge_function`.
    name: A name for this op. Defaults to
      `graph_convolution_edge_convolution_template`.

  Returns:
    Tensor with shape `[A1, ..., An, V, D]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid.
  """
  # pyformat: enable
  with tf.compat.v1.name_scope(
      name, "graph_convolution_edge_convolution_template",
      [data, neighbors, sizes]):
    data = tf.convert_to_tensor(value=data)
    neighbors = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=neighbors)
    if sizes is not None:
      sizes = tf.convert_to_tensor(value=sizes)

    rank = data.shape.ndims
    utils.check_valid_graph_convolution_input(data, neighbors, sizes)

    is_batched = rank > 2
    if not is_batched:
      features_2d, adjacency_2d, restore_batch = data, neighbors, None
    else:
      # Merge the batch into one 2-D problem, dropping padded vertices; the
      # adjacency becomes a single block-diagonal matrix.
      block_sizes = (None if sizes is None else
                     tf.stack((sizes, sizes), axis=-1))
      features_2d, restore_batch = utils.flatten_batch_to_2d(data, sizes)
      adjacency_2d = utils.convert_to_block_diag_2d(neighbors, block_sizes)

    # Every stored entry (i, j) is one edge: i receives, j contributes.
    sources = adjacency_2d.indices[:, 0]
    targets = adjacency_2d.indices[:, 1]
    edge_values = edge_function(
        tf.gather(features_2d, sources),
        tf.gather(features_2d, targets),
        **edge_function_kwargs)
    summed = utils.partition_sums_2d(edge_values, sources,
                                     adjacency_2d.values)

    return restore_batch(summed) if is_batched else summed