Example #1
0
    def call(self,
             points,
             queries,
             radii,
             points_row_splits=None,
             queries_row_splits=None):
        """Find, for every query point, all input points within its radius.

        Arguments:

          points: The 3D positions of the input points. *This argument must be
            given as a positional argument!*

          queries: The 3D positions of the query points.

          radii: A radius for each query point.

          points_row_splits: Optional 1D vector with the row splits information
            if points is batched. When omitted, the int64 vector
            [0, num_points] (a single batch item) is used.

          queries_row_splits: Optional 1D vector with the row splits information
            if queries is batched. When omitted, the int64 vector
            [0, num_queries] (a single batch item) is used.

        Returns:
          The result of ops.radius_search, i.e. 3 Tensors in the following
          order

          neighbors_index
            The compact list of indices of the neighbors. The corresponding
            query point can be inferred from the 'neighbors_row_splits' vector.

          neighbors_row_splits
            The exclusive prefix sum of the neighbor count for the query points
            including the total neighbor count as the last element. The size of
            this array is the number of queries + 1.

          neighbors_distance
            Stores the distance to each neighbor if 'return_distances' is True.
            Note that the distances are squared if metric is L2.
            This is a zero length Tensor if 'return_distances' is False.
        """

        def _single_item_splits(values):
            # Row splits [0, number of rows] cast to int64: one batch item
            # that spans every row of `values`.
            return tf.cast(tf.stack([0, tf.shape(values)[0]]), dtype=tf.int64)

        if points_row_splits is None:
            points_row_splits = _single_item_splits(points)
        if queries_row_splits is None:
            queries_row_splits = _single_item_splits(queries)

        # The search behaviour (metric, self-exclusion, distance output and
        # normalization) is configured by the layer's attributes.
        return ops.radius_search(
            points=points,
            queries=queries,
            radii=radii,
            points_row_splits=points_row_splits,
            queries_row_splits=queries_row_splits,
            metric=self.metric,
            ignore_query_point=self.ignore_query_point,
            return_distances=self.return_distances,
            normalize_distances=self.normalize_distances)
Example #2
0
    def call(self,
             inp_features,
             inp_positions,
             out_positions,
             extents,
             inp_importance=None,
             fixed_radius_search_hash_table=None,
             user_neighbors_index=None,
             user_neighbors_row_splits=None,
             user_neighbors_importance=None):
        """This function computes the output features.

        Arguments:

          inp_features: A 2D tensor which stores a feature vector for each input
            point.

          inp_positions: A 2D tensor with the 3D point positions of each input
            point. The coordinates for each point are a vector with format
            [x,y,z].

          out_positions: A 2D tensor with the 3D point positions of each output
            point. The coordinates for each point are a vector with format
            [x,y,z].

          extents: The extent defines the spatial size of the filter for each
            output point.
            For 'ball to cube' coordinate mappings the extent defines the
            bounding box of the ball.
            The shape of the tensor is either [1] or [num output points].
            For the automatic neighbor search, a rank 0 (scalar) extent uses
            the fixed radius search with radius 0.5 * extents, and a rank 1
            extent uses ops.radius_search with radii 0.5 * extents.

          inp_importance: Optional scalar importance value for each input point.
            If omitted, a zero-length tensor is passed to the convolution.

          fixed_radius_search_hash_table: A precomputed hash table generated with
            build_spatial_hash_table().
            This input can be used to explicitly force the reuse of a hash table
            in special cases and is usually not needed.
            Note that the hash table must have been generated with the same
            'inp_positions' array. Note that this parameter is only used if
            'extents' is a scalar.

          user_neighbors_index: This parameter together with
            'user_neighbors_row_splits' and 'user_neighbors_importance' allows
            overriding the automatic neighbor search. This is the list of
            neighbor indices for each output point. This is a nested list for
            which the start and end of each sublist is defined by
            'user_neighbors_row_splits'. The override is only active if both
            'user_neighbors_index' and 'user_neighbors_row_splits' are given.

          user_neighbors_row_splits: Defines the start and end of each neighbors
            list in 'user_neighbors_index'.

          user_neighbors_importance: Defines a scalar importance value for each
            element in 'user_neighbors_index'. If omitted, a zero-length tensor
            is passed to the convolution.

        Returns: A tensor of shape [num output points, filters] with the output
          features. The optional dense-layer output and bias are added and the
          activation is applied before returning.

        Raises:
          ValueError: If no user-defined neighbors are given and 'extents' has
            a rank other than 0 or 1.
        """

        offset = self.offset

        if inp_importance is None:
            inp_importance = tf.ones((0,), dtype=tf.float32)

        extents = tf.convert_to_tensor(extents)

        # Neighbor distances are only needed to evaluate the window function.
        return_distances = self.window_function is not None

        use_user_neighbors = (user_neighbors_index is not None and
                              user_neighbors_row_splits is not None)

        if use_user_neighbors:

            if user_neighbors_importance is None:
                neighbors_importance = tf.ones((0,), dtype=tf.float32)
            else:
                neighbors_importance = user_neighbors_importance

            neighbors_index = user_neighbors_index
            neighbors_row_splits = user_neighbors_row_splits

        else:
            if extents.shape.rank == 0:
                radius = 0.5 * extents
                self.nns = self.fixed_radius_search(
                    inp_positions,
                    queries=out_positions,
                    radius=radius,
                    hash_table=fixed_radius_search_hash_table)
                if return_distances:
                    # Normalize manually. L2 divides by radius**2, presumably
                    # because the search returns squared L2 distances.
                    if self.radius_search_metric == 'L2':
                        neighbors_distance_normalized = self.nns.neighbors_distance / (
                            radius * radius)
                    else:  # L1
                        neighbors_distance_normalized = self.nns.neighbors_distance / radius

            elif extents.shape.rank == 1:
                radii = 0.5 * extents
                self.nns = ops.radius_search(
                    ignore_query_point=self.radius_search_ignore_query_points,
                    return_distances=return_distances,
                    normalize_distances=return_distances,
                    metric=self.radius_search_metric,
                    points=inp_positions,
                    queries=out_positions,
                    radii=radii)
                if return_distances:
                    # The op was asked to normalize the distances
                    # (normalize_distances=return_distances above), so they
                    # are used as-is. Without this assignment the window
                    # function below would hit an unbound local.
                    neighbors_distance_normalized = self.nns.neighbors_distance

            else:
                raise ValueError("extents rank must be 0 or 1")

            if self.window_function is None:
                neighbors_importance = tf.ones((0,), dtype=tf.float32)
            else:
                neighbors_importance = self.window_function(
                    neighbors_distance_normalized)

            neighbors_index = self.nns.neighbors_index
            neighbors_row_splits = self.nns.neighbors_row_splits

        # for stats and debugging: average number of neighbors per output point
        num_pairs = tf.shape(neighbors_index)[0]
        self._avg_neighbors = tf.dtypes.cast(
            num_pairs, tf.float32) / tf.dtypes.cast(
                tf.shape(out_positions)[0], tf.float32)

        # Append trailing axes until extents has rank 2
        # (scalar -> [1, 1], [n] -> [n, 1]).
        extents_rank2 = extents
        while extents_rank2.shape.rank < 2:
            extents_rank2 = tf.expand_dims(extents_rank2, axis=-1)

        self._conv_values = {
            'filters': self.kernel,
            'out_positions': out_positions,
            'extents': extents_rank2,
            'offset': offset,
            'inp_positions': inp_positions,
            'inp_features': inp_features,
            'inp_importance': inp_importance,
            'neighbors_index': neighbors_index,
            'neighbors_row_splits': neighbors_row_splits,
            'neighbors_importance': neighbors_importance,
            'align_corners': self.align_corners,
            'coordinate_mapping': self.coordinate_mapping,
            'interpolation': self.interpolation,
            'normalize': self.normalize,
        }

        out_features = ops.continuous_conv(**self._conv_values)

        self._conv_output = out_features

        if self.use_dense_layer_for_center:
            self._dense_output = self.dense(inp_features)
            out_features = out_features + self._dense_output

        if self.use_bias:
            out_features += self.bias
        out_features = self.activation(out_features)

        return out_features