Esempio n. 1
0
 def visit_ragged_components(self, obj):
     """Build a `RaggedArray` from the visited components of `obj`.

     Args:
         obj: object exposing `flat_values` and `nested_row_splits`.

     Returns:
         `RaggedArray` created with `from_nested_row_splits` when `obj` has
         more than one level of row splits, otherwise with `from_row_splits`
         using the single row-splits entry.
     """
     from more_keras.ragged.np_impl import RaggedArray
     flat_values = self.visit(obj.flat_values)
     if len(obj.nested_row_splits) > 1:
         # The nested result must be returned; without the `return` the
         # multi-level case fell through to the single-level constructor.
         return RaggedArray.from_nested_row_splits(
             flat_values, self.visit(obj.nested_row_splits))
     return RaggedArray.from_row_splits(
         flat_values, self.visit(obj.nested_row_splits[0]))
Esempio n. 2
0
    def query_ball_point(self, x, r, max_neighbors=None, approx_neighbors=None):
        """
        Find points in the tree within `r` of `x` using only `self.query`.

        Note scipy and sklearn both implement their own versions of
        these which ignore max_neighbors and approx_neighbors arguments.

        Args:
            x: [n2, m] float.
            r: float, radius of ball search.
            max_neighbors: int, maximum number of neighbors to consider. If
                `None`, uses `approx_neighbors` and recursive strategy.
            approx_neighbors: int, approximate number of neighbors to consider
                in recursive strategy. Ignored if `max_neighbors` is given.

        Returns:
            [n2, k?] RaggedArray of indices into data.

        Raises:
            ValueError: if both `max_neighbors` and `approx_neighbors` are
                `None`.
        """
        if max_neighbors is None:
            if approx_neighbors is None:
                # Trailing space keeps the two adjacent literals from fusing
                # into "...provided for<ClassName>.query_ball_point".
                raise ValueError(
                    '`max_neighbors` or `approx_neighbors` must be provided for '
                    '{}.query_ball_point'.format(self.__class__.__name__))
            return self.query_ball_point_recursive(x, r, approx_neighbors)

        indices = self.query(x,
                             max_neighbors,
                             distance_upper_bound=r,
                             return_distance=False)
        # `self.valid` supplies the mask selecting which entries of `indices`
        # are kept in the ragged result.
        return RaggedArray.from_mask(indices, self.valid(indices))
Esempio n. 3
0
def truncate(neighbors, limit):
    """Take only the first `limit` entries of each row of `neighbors`.

    Args:
        neighbors: ragged structure exposing `row_lengths` and `flat_values`.
        limit: int, maximum number of entries kept per row.

    Returns:
        RaggedArray whose i-th row holds the leading
        `min(row_lengths[i], limit)` entries of the i-th input row.
    """
    row_lengths = neighbors.row_lengths
    kept = np.minimum(row_lengths, limit)
    dropped = np.maximum(row_lengths - limit, 0)
    # Builtin `bool` is used because the `np.bool` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    keep_drop = np.array([True, False], dtype=bool)
    # Per row: `kept` True flags followed by `dropped` False flags, laid out
    # in the same order as `flat_values`.
    repeats = np.stack((kept, dropped), axis=1).flatten()
    mask = np.repeat(np.tile(keep_drop, len(row_lengths)), repeats)
    return RaggedArray.from_row_lengths(neighbors.flat_values[mask], kept)
Esempio n. 4
0
    def query_ball_point_recursive(self, x, r, approx_neighbors):
        """
        Ball-point query implemented on top of `self.query` alone.

        Delegates to `self._query_ball_point_recursive`, which starts from
        k=approx_neighbors and keeps doubling the neighbor count until at
        least one returned entry is flagged invalid, then packs the result
        using the validity mask.

        Returns:
            RaggedArray of indices
        """
        result = self._query_ball_point_recursive(x, r, approx_neighbors)
        neighbor_indices, is_valid = result
        return RaggedArray.from_mask(neighbor_indices, is_valid)
Esempio n. 5
0
    def test_truncate(self):
        """Check `core.truncate` against a slow per-row slicing reference."""
        r = np.random.RandomState(123)  # pylint: disable=no-member
        upper = 10
        row_lengths = (r.uniform(size=(10, )) * upper).astype(np.int64)
        # Builtin `bool` is used because the `np.bool` alias was removed in
        # NumPy 1.24.
        data = np.zeros((np.sum(row_lengths), ), dtype=bool)
        ragged = RaggedArray.from_row_lengths(data, row_lengths)

        def slow_truncate(ragged, limit):
            # Reference implementation: slice every row independently.
            return RaggedArray.from_ragged_lists(
                tuple(rl[:limit] for rl in ragged.ragged_lists))

        limit = upper // 2
        actual = core.truncate(ragged, limit)
        expected = slow_truncate(ragged, limit)
        np.testing.assert_equal(actual.flat_values, expected.flat_values)
        np.testing.assert_equal(actual.row_splits, expected.row_splits)
Esempio n. 6
0
 def slow_truncate(ragged, limit):
     """Reference truncation: slice each row of `ragged` to `limit` entries."""
     clipped_rows = tuple(row[:limit] for row in ragged.ragged_lists)
     return RaggedArray.from_ragged_lists(clipped_rows)
Esempio n. 7
0
 def rejection_sample(flat_indices, row_splits):
     """Run `core.rejection_sample_precomputed` on flat ragged components.

     Wraps `flat_indices` / `row_splits` in a `RaggedArray` and returns the
     sampler's output converted to an int64 numpy array.
     """
     from more_keras.ragged.np_impl import RaggedArray
     neighborhoods = RaggedArray.from_row_splits(flat_indices, row_splits)
     sampled = core.rejection_sample_precomputed(neighborhoods)
     return np.array(sampled, dtype=np.int64)
Esempio n. 8
0
def reverse_query_ball(ragged_array, size=None, data=None):
    """
    Get `query_ball_tree` for reversed in/out trees.

    Also returns data associated with the reverse, or relevant indices.

    Example usage:
    ```python
    radius = 0.1
    na = 50
    nb = 40
    r = np.random.RandomState(123)

    in_coords = r.uniform(size=(na, 3))
    out_coords = r.uniform(size=(nb, 3))

    in_tree = tree_utils.KDTree(in_coords)
    out_tree = tree_utils.KDTree(out_coords)
    arr = tree_utils.query_ball_tree(in_tree, out_tree, radius)

    rel_coords = np.repeat(out_coords, arr.row_lengths, axis=0) - \
        in_coords[arr.flat_values]
    rel_dists = np.linalg.norm(rel_coords, axis=-1)

    rev_arr, rev_indices = tree_utils.reverse_query_ball(arr, na)
    rel_coords_inv = rel_coords[rev_indices]

    _, rel_dists_inv = tree_utils.reverse_query_ball(
        arr, na, rel_dists)
    np.testing.assert_allclose(
        np.linalg.norm(rel_coords_inv, axis=-1), rel_dists_inv)

    naive_arr_rev = tree_utils.query_ball_tree(out_tree, in_tree, radius)

    np.testing.assert_equal(
        naive_arr_rev.flat_values, rev_arr.flat_values)
    np.testing.assert_equal(
        naive_arr_rev.row_splits, rev_arr.row_splits)

    naive_rel_coords_inv = np.repeat(
        in_coords, naive_arr_rev.row_lengths, axis=0) -\
        out_coords[naive_arr_rev.flat_values]
    naive_rel_dists_inv = np.linalg.norm(naive_rel_coords_inv, axis=-1)

    np.testing.assert_allclose(rel_coords_inv, -naive_rel_coords_inv)
    np.testing.assert_allclose(rel_dists_inv, naive_rel_dists_inv)
    ```

    Args:
        ragged_array: RaggedArray instance, presumably from `query_ball_tree`
            or `query_pairs`. Note if you used `query_pairs`, the returned
            `ragged_out` will be the same as `ragged_array` input (though the
            indices may still be useful)
        size: optional int, number of rows the returned `ragged_out` should
            have. When larger than the row count inferred from the largest
            index in `ragged_array.flat_values`, empty rows are appended. If
            `None`, the inferred row count is used.
        data: optional 1D values, one per entry of
            `ragged_array.flat_values`, reordered to match `ragged_out`. If
            `None`, `np.arange(ragged_array.size)` is used, so the returned
            values are indices into the original flat ordering.

    Returns:
        ragged_out: RaggedArray corresponding to the opposite tree search for
            which `ragged_array` used.
        data: can be used to transform data calculated using input
            ragged_array. See above example

    Raises:
        ValueError: if `size` is smaller than the number of rows required by
            the indices in `ragged_array`.
    """
    if data is None:
        data = np.arange(ragged_array.size, dtype=np.int64)
    # take advantage of fast scipy.sparse implementations
    mat = sp.csr_matrix(
        (data, ragged_array.flat_values, ragged_array.row_splits))
    trans = mat.transpose().tocsr()
    trans.sort_indices()
    row_splits = trans.indptr
    if size is not None:
        # The sparse shape is inferred from the largest index present, so
        # trailing rows without any entry are missing from `indptr`.
        num_rows = row_splits.size - 1
        diff = size - num_rows
        if diff < 0:
            raise ValueError(
                '`size` ({}) is smaller than the number of rows implied by '
                '`ragged_array` indices ({})'.format(size, num_rows))
        if diff != 0:
            # Repeating the final split yields `diff` empty trailing rows.
            tail = row_splits[-1] * np.ones((diff,), dtype=row_splits.dtype)
            row_splits = np.concatenate([row_splits, tail], axis=0)
    ragged_out = RaggedArray.from_row_splits(trans.indices, row_splits)
    return ragged_out, trans.data
Esempio n. 9
0
def _maybe_clip(ragged_lists, max_neighbors, default_max_neighbors):
    """Optionally clip each row, then pack the rows into an int64 RaggedArray.

    The clip length is `max_neighbors`, falling back to
    `default_max_neighbors` when it is `None`; if both are `None` the rows
    are passed through unchanged.
    """
    if max_neighbors is None:
        limit = default_max_neighbors
    else:
        limit = max_neighbors

    if limit is None:
        rows = ragged_lists
    else:
        rows = [row[:limit] for row in ragged_lists]
    return RaggedArray.from_ragged_lists(rows, dtype=np.int64)
Esempio n. 10
0
def sort_ragged(ra):
    """Return a RaggedArray with every row of `ra` sorted.

    Each element of `ra.ragged_lists` is sorted with its own `.sort()`
    method and the rows are repacked using `ra.dtype`.
    """
    rows = ra.ragged_lists
    # NOTE(review): `.sort()` works in place; if `ragged_lists` exposes views
    # of `ra`'s storage this also reorders `ra` itself - confirm intended.
    for row in rows:
        row.sort()
    sorted_ragged = RaggedArray.from_ragged_lists(rows, dtype=ra.dtype)
    return sorted_ragged