def call(self,
         points,
         queries,
         radii,
         points_row_splits=None,
         queries_row_splits=None):
    """Find, for every query point, the points lying within that query's radius.

    Arguments:
      points: The 3D positions of the input points.
        *This argument must be given as a positional argument!*
      queries: The 3D positions of the query points.
      radii: A radius for each query point.
      points_row_splits: Optional 1D vector with the row splits information
        if points is batched. When omitted, the int64 vector
        [0, num_points] is used, i.e. a single batch item.
      queries_row_splits: Optional 1D vector with the row splits information
        if queries is batched. When omitted, the int64 vector
        [0, num_queries] is used, i.e. a single batch item.

    Returns:
      The result of ops.radius_search: 3 Tensors in the following order

      neighbors_index
        The compact list of indices of the neighbors. The query point each
        entry belongs to can be inferred from the 'neighbors_row_splits'
        vector.

      neighbors_row_splits
        The exclusive prefix sum of the neighbor count for the query points
        including the total neighbor count as the last element. The size of
        this array is the number of queries + 1.

      neighbors_distance
        Stores the distance to each neighbor if 'return_distances' is True.
        Note that the distances are squared if metric is L2.
        This is a zero length Tensor if 'return_distances' is False.
    """

    def single_batch_splits(positions):
        # [0, N] as int64: one batch item that covers all N rows.
        return tf.cast(tf.stack([0, tf.shape(positions)[0]]), dtype=tf.int64)

    if points_row_splits is None:
        points_row_splits = single_batch_splits(points)
    if queries_row_splits is None:
        queries_row_splits = single_batch_splits(queries)

    # Search configuration is taken from the layer's attributes.
    search_options = dict(
        ignore_query_point=self.ignore_query_point,
        return_distances=self.return_distances,
        normalize_distances=self.normalize_distances,
        metric=self.metric,
    )
    return ops.radius_search(points=points,
                             queries=queries,
                             radii=radii,
                             points_row_splits=points_row_splits,
                             queries_row_splits=queries_row_splits,
                             **search_options)
def call(self,
         inp_features,
         inp_positions,
         out_positions,
         extents,
         inp_importance=None,
         fixed_radius_search_hash_table=None,
         user_neighbors_index=None,
         user_neighbors_row_splits=None,
         user_neighbors_importance=None):
    """This function computes the output features.

    Arguments:
      inp_features: A 2D tensor which stores a feature vector for each input
        point.

      inp_positions: A 2D tensor with the 3D point positions of each input
        point. The coordinates for each point is a vector with format
        [x,y,z].

      out_positions: A 2D tensor with the 3D point positions of each output
        point. The coordinates for each point is a vector with format
        [x,y,z].

      extents: The extent defines the spatial size of the filter for each
        output point. For 'ball to cube' coordinate mappings the extent
        defines the bounding box of the ball. Either a scalar (rank 0), which
        uses one shared radius and the fixed radius search, or a 1D tensor
        with one extent per output point, which uses a per-query radius
        search.

      inp_importance: Optional scalar importance value for each input point.
        If None, a zero-length tensor is passed to the convolution op.

      fixed_radius_search_hash_table: A precomputed hash table generated with
        build_spatial_hash_table(). This input can be used to explicitly force
        the reuse of a hash table in special cases and is usually not needed.
        Note that the hash table must have been generated with the same
        'inp_positions' array. This parameter is only used if 'extents' is a
        scalar.

      user_neighbors_index: This parameter together with
        'user_neighbors_row_splits' and 'user_neighbors_importance' allows to
        override the automatic neighbor search. This is the list of neighbor
        indices for each output point. This is a nested list for which the
        start and end of each sublist is defined by
        'user_neighbors_row_splits'. The override is only applied when both
        'user_neighbors_index' and 'user_neighbors_row_splits' are given.

      user_neighbors_row_splits: Defines the start and end of each neighbors
        list in 'user_neighbors_index'.

      user_neighbors_importance: Optional scalar importance value for each
        element in 'user_neighbors_index'. If None, a zero-length tensor is
        passed to the convolution op.

    Returns:
      A tensor of shape [num output points, filters] with the output
      features.

    Raises:
      ValueError: If the neighbor search is not overridden and the static
        rank of 'extents' is neither 0 nor 1 (this includes unknown rank).
    """
    offset = self.offset

    if inp_importance is None:
        inp_importance = tf.ones((0,), dtype=tf.float32)

    extents = tf.convert_to_tensor(extents)

    # Distances are only needed to evaluate the window function.
    return_distances = self.window_function is not None

    if (user_neighbors_index is not None and
            user_neighbors_row_splits is not None):
        # The caller supplied the neighborhoods; skip the neighbor search.
        if user_neighbors_importance is None:
            neighbors_importance = tf.ones((0,), dtype=tf.float32)
        else:
            neighbors_importance = user_neighbors_importance
        neighbors_index = user_neighbors_index
        neighbors_row_splits = user_neighbors_row_splits
    else:
        if extents.shape.rank == 0:
            # Scalar extent: one radius shared by all output points.
            radius = 0.5 * extents
            self.nns = self.fixed_radius_search(
                inp_positions,
                queries=out_positions,
                radius=radius,
                hash_table=fixed_radius_search_hash_table)
            if return_distances:
                # Normalize manually: by radius**2 for L2 (distances are
                # assumed to be squared) and by radius for L1.
                if self.radius_search_metric == 'L2':
                    neighbors_distance_normalized = (
                        self.nns.neighbors_distance / (radius * radius))
                else:  # L1
                    neighbors_distance_normalized = (
                        self.nns.neighbors_distance / radius)
        elif extents.shape.rank == 1:
            # One extent per output point: per-query radii.
            radii = 0.5 * extents
            # Inputs are not batched here, so each row-splits vector is the
            # int64 [0, N], i.e. a single batch item.
            points_row_splits = tf.cast(
                tf.stack([0, tf.shape(inp_positions)[0]]), dtype=tf.int64)
            queries_row_splits = tf.cast(
                tf.stack([0, tf.shape(out_positions)[0]]), dtype=tf.int64)
            self.nns = ops.radius_search(
                ignore_query_point=self.radius_search_ignore_query_points,
                return_distances=return_distances,
                normalize_distances=return_distances,
                metric=self.radius_search_metric,
                points=inp_positions,
                queries=out_positions,
                radii=radii,
                points_row_splits=points_row_splits,
                queries_row_splits=queries_row_splits)
            if return_distances:
                # The op was asked to normalize (normalize_distances above),
                # so its distances are used as-is. Without this assignment the
                # window function below would read an undefined name.
                neighbors_distance_normalized = self.nns.neighbors_distance
        else:
            raise ValueError("extents rank must be 0 or 1")

        if self.window_function is None:
            neighbors_importance = tf.ones((0,), dtype=tf.float32)
        else:
            neighbors_importance = self.window_function(
                neighbors_distance_normalized)

        neighbors_index = self.nns.neighbors_index
        neighbors_row_splits = self.nns.neighbors_row_splits

    # For stats and debugging: average number of neighbor pairs per output
    # point.
    num_pairs = tf.shape(neighbors_index)[0]
    self._avg_neighbors = tf.dtypes.cast(
        num_pairs, tf.float32) / tf.dtypes.cast(
            tf.shape(out_positions)[0], tf.float32)

    # The convolution op is given extents with at least rank 2; append
    # trailing axes as needed.
    extents_rank2 = extents
    while extents_rank2.shape.rank < 2:
        extents_rank2 = tf.expand_dims(extents_rank2, axis=-1)

    self._conv_values = {
        'filters': self.kernel,
        'out_positions': out_positions,
        'extents': extents_rank2,
        'offset': offset,
        'inp_positions': inp_positions,
        'inp_features': inp_features,
        'inp_importance': inp_importance,
        'neighbors_index': neighbors_index,
        'neighbors_row_splits': neighbors_row_splits,
        'neighbors_importance': neighbors_importance,
        'align_corners': self.align_corners,
        'coordinate_mapping': self.coordinate_mapping,
        'interpolation': self.interpolation,
        'normalize': self.normalize,
    }

    out_features = ops.continuous_conv(**self._conv_values)

    self._conv_output = out_features

    if self.use_dense_layer_for_center:
        # Add a per-point dense projection of the input features.
        self._dense_output = self.dense(inp_features)
        out_features = out_features + self._dense_output

    if self.use_bias:
        out_features += self.bias
    out_features = self.activation(out_features)

    return out_features