def set_abstraction(
        features,
        neighborhood,
        edge_mlp,
        reduction=tf.reduce_max,
        node_mlp=None,
        coords_as_features=True,
):
    """
    Reimplementation of original `pointnet_sa_module` function.

    Args:
        features: [b, n_i?, filters_in] float32 tensor of flattened batched
            point features, or None, in which case the relative coordinates
            of each edge are the only edge features.
        neighborhood: `deepcloud.neigh.Neighborhood` instance with `n_i` inputs
            and `n_o` output points.
        edge_mlp: callable acting on each edge's features.
        reduction: operation to reduce neighborhoods to point features.
        node_mlp: callable acting on each point after reduction.
        coords_as_features: if True, use relative coords in neighborhood as
            well as features in edges. Ignored when `features` is None.

    Returns:
        features: [b, n_o?, filters_o] float32 array, where filters_o is the
            number of output features of `edge_mlp` if `node_mlp` is None else
            the number of output features of `node_mlp`.
    """

    def flat_rel_coords():
        # Per-edge relative coordinates, batched then flattened so they share
        # the leading (edge) dimension with gathered edge features.
        return b.as_batched_model_input(
            neighborhood.rel_coords.flat_values).flat_values

    offset_batched_neighbors = neighborhood.offset_batched_neighbors
    if features is None:
        # Relative coords are already per-edge: they must NOT be gathered by
        # neighbor (point) indices. Transform them directly.
        features = edge_mlp(flat_rel_coords())
    else:
        features = layer_utils.flatten_leading_dims(features)
        neighbor_indices = offset_batched_neighbors.flat_values
        if coords_as_features:
            # Move point features onto edges, append relative coords, then
            # apply the edge MLP to the combined per-edge features.
            features = tf.gather(features, neighbor_indices)
            features = layers.Lambda(tf.concat, arguments=dict(axis=-1))(
                [features, flat_rel_coords()])
            features = edge_mlp(features)
        else:
            # Edge features depend only on the neighbor point, so applying the
            # MLP per point before gathering is cheaper than per edge
            # (more efficient than original implementation).
            features = edge_mlp(features)
            features = tf.gather(features, neighbor_indices)

    # features is now flat over all edges in the batch: [E_total, f].
    features = tf.RaggedTensor.from_nested_row_splits(
        features, offset_batched_neighbors.nested_row_splits)
    # features is now ragged [b, n_o?, k?, f]; reduce over the neighbor axis.
    features = ragged_lambda(reduction, arguments=dict(axis=-2))(features)
    # features is now [b, n_o?, f]
    if node_mlp is not None:
        if isinstance(features, tf.RaggedTensor):
            features = tf.ragged.map_flat_values(node_mlp, features)
        else:
            features = node_mlp(features)
    return features
def query_pairs(coords, radius, name=None, max_neighbors=None):
    """Build a ragged lambda layer around `_query_pairs` and apply it.

    Args:
        coords: coordinates forwarded to `_query_pairs`.
        radius: search radius. A `tf.Tensor` radius is passed as a layer input
            alongside `coords`; any other value is bound as a constant
            `radius` keyword argument.
        name: optional name of the created layer (previously ignored).
        max_neighbors: forwarded to `_query_pairs` as a keyword argument.

    Returns:
        The output of the created layer.
    """
    if isinstance(radius, tf.Tensor):
        # Tensor radii must flow through the layer inputs, not `arguments`.
        kwargs = {}
        args = coords, radius
    else:
        kwargs = dict(radius=radius)
        args = coords
    kwargs['max_neighbors'] = max_neighbors
    neighbors = ragged_lambda(_query_pairs, arguments=kwargs, name=name)(args)
    return neighbors
def _batched_ragged(self, rt): if rt in self._batched_inputs_dict: return self._batched_inputs_dict[rt] size = self._batched_fixed_tensor( ragged_layers.ragged_lambda(lambda x: x.nrows())(rt)) nested_row_lengths = ragged_layers.nested_row_lengths(rt) nested_row_lengths = [ self._batched_tensor(rl) for rl in nested_row_lengths ] nested_row_lengths = [ layer_utils.flatten_leading_dims(rl, 2) for rl in nested_row_lengths ] values = layer_utils.flatten_leading_dims( self._batched_tensor(rt.flat_values), 2) out = ragged_layers.ragged_from_nested_row_lengths( values, [size] + nested_row_lengths) self._batched_inputs_dict[rt] = out return out
def get_relative_coords(in_coords, out_coords, indices, name=None):
    """Apply `_get_relative_coords` via a ragged lambda layer named `name`.

    The layer receives `[in_coords, out_coords, indices]` as its inputs.
    """
    layer = ragged_lambda(_get_relative_coords, name=name)
    layer_inputs = [in_coords, out_coords, indices]
    return layer(layer_inputs)
def row_lengths(row_splits):
    """Compute row lengths from `row_splits`.

    A `tf.RaggedTensor` is handled by a ragged lambda layer wrapping
    `_row_lengths`; any other input is delegated to `diff`.
    """
    if not isinstance(row_splits, tf.RaggedTensor):
        return diff(row_splits)
    return ragged_lambda(_row_lengths)(row_splits)
def reverse_query_pairs(neighbors, size):
    """Apply `_reverse_query_pairs` to `[neighbors, size]` via a ragged lambda layer."""
    layer = ragged_lambda(_reverse_query_pairs)
    return layer([neighbors, size])