Ejemplo n.º 1
0
    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)
Ejemplo n.º 2
0
    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None  # pointer to keep the numpy data alive
        self.__added_data = []  # contained to keep any added numpy data alive
        self.__removed_ids = []  # contains the point ids that have been removed
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)
Ejemplo n.º 3
0
    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)
Ejemplo n.º 4
0
class FLANN(object):
    """
    This class defines a python interface to the FLANN lirary.
    """
    __rn_gen = _rn.RandomState()

    _as_parameter_ = property(lambda self: self.__curindex)

    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)

    def __del__(self):
        self.delete_index()

    ##########################################################################
    # actual workhorse functions

    def nn(self, pts, qpts, num_neighbors=1, **kwargs):
        """
        Returns the num_neighbors nearest points in dataset for each point
        in testset.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % qpts.dtype)

        if pts.dtype != qpts.dtype:
            raise FLANNException('Data and query must have the same type')

        pts = ensure_2d_array(pts, default_flags)
        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = pts.shape
        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if pts.dtype == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors[
            pts.dtype.type](
            pts, npts, dim, qpts, nqpts, result, dists, num_neighbors,
            pointer(self.__flann_parameters))

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def build_index(self, pts, **kwargs):
        """
        This builds and internally stores an index to be used for
        future nearest neighbor matchings.  It erases any previously
        stored indexes, so use multiple instances of this class to
        work with multiple stored indices.  Use nn_index(...) to find
        the nearest neighbors in this index.

        pts is a 2d numpy array or matrix. All the computation is done
        in np.float32 type, but pts may be any type that is convertable
        to np.float32.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        self.__ensureRandomSeed(kwargs)

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None

        speedup = c_float(0)
        self.__curindex = flann.build_index[pts.dtype.type](
            pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type

        params = dict(self.__flann_parameters)
        params['speedup'] = speedup.value

        return params

    def save_index(self, filename):
        """
        This saves the index to a disk file.
        """
        if self.__curindex is not None:
            flann.save_index[self.__curindex_type](
                self.__curindex, c_char_p(to_bytes(filename)))

    def load_index(self, filename, pts):
        """
        Loads an index previously saved to disk.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__curindex_type = None

        self.__curindex = flann.load_index[pts.dtype.type](
            c_char_p(to_bytes(filename)), pts, npts, dim)
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type
        
        
    def used_memory(self):
        """
        Returns the number of bytes consumed by the index.
        """
        return flann.used_memory[self.__curindex_type](self.__curindex)
        
    def add_points(self, pts, rebuild_threshold=2.0):
        """
        Adds points to pre-built index.

        Params:
            pts: 2D numpy array of points.\n
            rebuild_threshold: reallocs index when it grows by factor of \
                `rebuild_threshold`. A smaller value results is more space \
                efficient but less computationally efficient. Must be greater \
                than 1.           
        """
        if not pts.dtype.type in allowed_types:
            raise FLANNException("Cannot handle type: %s"%pts.dtype)
        pts = ensure_2d_array(pts,default_flags) 
        npts, dim = pts.shape
        flann.add_points[self.__curindex_type](self.__curindex, pts, npts, dim, rebuild_threshold)
        self.__curindex_data = np.row_stack((self.__curindex_data,pts))
        
    def remove_point(self, idx):
        """
        Removes a point from a pre-built index.         
        """
        flann.remove_point[self.__curindex_type](self.__curindex, idx)
        self.__curindex_data = np.delete(self.__curindex_data,idx,axis=0)

    def nn_index(self, qpts, num_neighbors=1, **kwargs):
        """
        For each point in querypts, (which may be a single point), it
        returns the num_neighbors nearest points in the index built by
        calling build_index.
        """

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.')

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % qpts.dtype)

        if self.__curindex_type != qpts.dtype.type:
            raise FLANNException('Index and query must have the same type')

        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = self.__curindex_data.shape

        if qpts.size == dim:
            qpts.reshape(1, dim)

        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors_index[
            self.__curindex_type](
            self.__curindex, qpts, nqpts, result, dists, num_neighbors,
            pointer(self.__flann_parameters))

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def nn_radius(self, query, radius, **kwargs):

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.')

        if query.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % query.dtype)

        if self.__curindex_type != query.dtype.type:
            raise FLANNException('Index and query must have the same type')

        npts, dim = self.__curindex_data.shape
        assert query.shape[0] == dim, 'data and query must have the same dims'

        result = np.empty(npts, dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty(npts, dtype=np.float64)
        else:
            dists = np.empty(npts, dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        nn = flann.radius_search[
            self.__curindex_type](
            self.__curindex, query, result, dists, npts, radius,
            pointer(self.__flann_parameters))

        return (result[0:nn], dists[0:nn])

    def delete_index(self, **kwargs):
        """
        Deletes the current index freeing all the momory it uses.
        The memory used by the dataset that was indexed is not freed.
        """

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None

    ##########################################################################
    # Clustering functions

    def kmeans(self, pts, num_clusters, max_iterations=None,
               dtype=None, **kwargs):
        """
        Runs kmeans on pts with num_clusters centroids.  Returns a
        numpy array of size num_clusters x dim.

        If max_iterations is not None, the algorithm terminates after
        the given number of iterations regardless of convergence.  The
        default is to run until convergence.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        if int(num_clusters) != num_clusters or num_clusters < 1:
            raise FLANNException('num_clusters must be an integer >= 1')

        if num_clusters == 1:
            if dtype is None or dtype == pts.dtype:
                return np.mean(pts, 0).reshape(1, pts.shape[1])
            else:
                return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))

        return self.hierarchical_kmeans(pts, int(num_clusters), 1,
                                        max_iterations,
                                        dtype, **kwargs)

    def hierarchical_kmeans(self, pts, branch_size, num_branches,
                            max_iterations=None,
                            dtype=None, **kwargs):
        """
        Clusters the data by using multiple runs of kmeans to
        recursively partition the dataset.  The number of resulting
        clusters is given by (branch_size-1)*num_branches+1.

        This method can be significantly faster when the number of
        desired clusters is quite large (e.g. a hundred or more).
        Higher branch sizes are slower but may give better results.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        # First verify the paremeters are sensible.

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if int(branch_size) != branch_size or branch_size < 2:
            raise FLANNException('branch_size must be an integer >= 2.')

        branch_size = int(branch_size)

        if int(num_branches) != num_branches or num_branches < 1:
            raise FLANNException('num_branches must be an integer >= 1.')

        num_branches = int(num_branches)

        if max_iterations is None:
            max_iterations = -1
        else:
            max_iterations = int(max_iterations)

        # init the arrays and starting values
        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape
        num_clusters = (branch_size - 1) * num_branches + 1

        if pts.dtype.type == np.float64:
            result = np.empty((num_clusters, dim), dtype=np.float64)
        else:
            result = np.empty((num_clusters, dim), dtype=np.float32)

        # set all the parameters appropriately

        self.__ensureRandomSeed(kwargs)

        params = {'iterations': max_iterations,
                  'algorithm': 'kmeans',
                  'branching': branch_size,
                  'random_seed': kwargs['random_seed']}

        self.__flann_parameters.update(params)

        numclusters = flann.compute_cluster_centers[pts.dtype.type](
            pts, npts, dim, num_clusters, result,
            pointer(self.__flann_parameters))
        if numclusters <= 0:
            raise FLANNException('Error occured during clustering procedure.')

        if dtype is None:
            return result
        else:
            return dtype(result)

    ##########################################################################
    # internal bookkeeping functions

    def __ensureRandomSeed(self, kwargs):
        if 'random_seed' not in kwargs:
            kwargs['random_seed'] = self.__rn_gen.randint(2 ** 30)
Ejemplo n.º 5
0
class FLANN(object):
    """
    This class defines a python interface to the FLANN lirary.
    """
    __rn_gen = _rn.RandomState()

    _as_parameter_ = property(lambda self: self.__curindex)

    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None  # pointer to keep the numpy data alive
        self.__added_data = []  # contained to keep any added numpy data alive
        self.__removed_ids = []  # contains the point ids that have been removed
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)

    def __del__(self):
        #print('FLANN OBJECT IS DELETED')
        self.delete_index()

    @property
    def shape(self):
        return self.get_indexed_shape()

    @property
    def __len__(self):
        return self.shape[0]

    def get_indexed_shape(self):
        """ returns the shape of the data being indexed """
        npts, dim = self.__curindex_data.shape
        for _extra in self.__added_data:
            npts += _extra.shape[0]
        npts -= len(self.__removed_ids)
        return npts, dim

    def get_indexed_data(self):
        """
        returns all the data indexed by the FLANN object

        (this returns points that have been removed but still exist in memory)
        """
        return self.__curindex_data, self.__added_data

    def used_memory_dataset(self):
        """
        Returns the amount of memory used by the dataset
        """
        if self.__curindex_data is None:
            return 0
        num_bytes = self.__curindex_data.nbytes
        for _extra in self.__added_data:
            num_bytes += _extra.nbytes
        return num_bytes

    def used_memory(self):
        """
        Returns the amount of memory used by the index

        Returns: int
        """
        if self.__curindex is None:
            return 0
        return flann.used_memory[self.__curindex_type](self.__curindex)

    ##########################################################################
    # actual workhorse functions

    def nn(self, pts, qpts, num_neighbors=1, **kwargs):
        """
        Returns the num_neighbors nearest points in dataset for each point
        in testset.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if pts.dtype != qpts.dtype:
            raise FLANNException('Data and query must have the same type')

        pts = ensure_2d_array(pts, default_flags)
        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = pts.shape
        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if pts.dtype == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors[
            pts.dtype.type](
            pts, npts, dim, qpts, nqpts, result, dists, num_neighbors,
            pointer(self.__flann_parameters))

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def build_index(self, pts, **kwargs):
        """
        This builds and internally stores an index to be used for
        future nearest neighbor matchings.  It erases any previously
        stored indexes, so use multiple instances of this class to
        work with multiple stored indices.  Use nn_index(...) to find
        the nearest neighbors in this index.

        pts is a 2d numpy array or matrix. All the computation is done
        in np.float32 type, but pts may be any type that is convertable
        to np.float32.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        self.__ensureRandomSeed(kwargs)

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None

        speedup = c_float(0)
        self.__curindex = flann.build_index[pts.dtype.type](
            pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type

        params = dict(self.__flann_parameters)
        params['speedup'] = speedup.value

        return params

    def add_points(self, new_pts, rebuild_threshold=2):
        """
        Adds pts to the current index. If the number of added points is more
        than a factor of rebuild_threshold larger than the original number of
        points, the index is rebuilt.
        """
        if new_pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % new_pts.dtype)
        if new_pts.dtype != self.__curindex_type:
            raise FLANNException('New points must have the same type')
        new_pts = ensure_2d_array(new_pts, default_flags)
        rows = new_pts.shape[0]
        flann.add_points[self.__curindex_type](self.__curindex, new_pts, rows, rebuild_threshold)
        self.__added_data.append(new_pts)

    def remove_point(self, id_):
        """
        Removes a point from the index

        Params:
            id = point id to be removed

        Returns: void
        """
        flann.remove_point[self.__curindex_type](self.__curindex, id_)
        self.__removed_ids.append(id_)

    def remove_points(self, id_list):
        """
        Removes multiple points from the index

        Params:
            id_list = point ids to be removed

        Returns: void
        """
        for id_ in id_list:
            flann.remove_point[self.__curindex_type](self.__curindex, id_)

    def save_index(self, filename):
        """
        This saves the index to a disk file.
        """
        if self.__curindex is not None:
            flann.save_index[self.__curindex_type](
                self.__curindex, c_char_p(to_bytes(filename)))

    def load_index(self, filename, pts):
        """
        Loads an index previously saved to disk.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__added_data = []
            self.__curindex_type = None

        self.__curindex = flann.load_index[pts.dtype.type](
            c_char_p(to_bytes(filename)), pts, npts, dim)

        if self.__curindex is None:
            raise FLANNException(
                ('Error loading the FLANN index with filename=%r.'
                 ' C++ may have thrown more detailed errors') % (filename,))

        self.__curindex_data = pts
        self.__added_data = []
        self.__removed_ids = []
        self.__curindex_type = pts.dtype.type

    def nn_index(self, qpts, num_neighbors=1, **kwargs):
        """
        For each point in querypts, (which may be a single point), it
        returns the num_neighbors nearest points in the index built by
        calling build_index.
        """

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.')

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % qpts.dtype)

        if self.__curindex_type != qpts.dtype.type:
            raise FLANNException('Index and query must have the same type')

        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = self.get_indexed_shape()

        if qpts.size == dim:
            qpts.reshape(1, dim)

        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors_index[
            self.__curindex_type](
            self.__curindex, qpts, nqpts, result, dists, num_neighbors,
            pointer(self.__flann_parameters))

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def nn_radius(self, query, radius, **kwargs):

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.')

        if query.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % query.dtype)

        if self.__curindex_type != query.dtype.type:
            raise FLANNException('Index and query must have the same type')

        npts, dim = self.get_indexed_shape()
        assert(query.shape[0] == dim), 'data and query must have the same dims'

        result = np.empty(npts, dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty(npts, dtype=np.float64)
        else:
            dists = np.empty(npts, dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        nn = flann.radius_search[
            self.__curindex_type](
            self.__curindex, query, result, dists, npts, radius,
            pointer(self.__flann_parameters))

        return (result[0:nn], dists[0:nn])

    def delete_index(self, **kwargs):
        """
        Deletes the current index freeing all the momory it uses.
        The memory used by the dataset that was indexed is not freed
        unless there are no other references to those numpy arrays.
        """

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None and flann is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__added_data = []
            self.__removed_ids = []

    ##########################################################################
    # Clustering functions

    def kmeans(self, pts, num_clusters, max_iterations=None,
               dtype=None, **kwargs):
        """
        Runs kmeans on pts with num_clusters centroids.  Returns a
        numpy array of size num_clusters x dim.

        If max_iterations is not None, the algorithm terminates after
        the given number of iterations regardless of convergence.  The
        default is to run until convergence.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        if int(num_clusters) != num_clusters or num_clusters < 1:
            raise FLANNException('num_clusters must be an integer >= 1')

        if num_clusters == 1:
            if dtype is None or dtype == pts.dtype:
                return np.mean(pts, 0).reshape(1, pts.shape[1])
            else:
                return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))

        return self.hierarchical_kmeans(pts, int(num_clusters), 1,
                                        max_iterations,
                                        dtype, **kwargs)

    def hierarchical_kmeans(self, pts, branch_size, num_branches,
                            max_iterations=None,
                            dtype=None, **kwargs):
        """
        Clusters the data by using multiple runs of kmeans to
        recursively partition the dataset.  The number of resulting
        clusters is given by (branch_size-1)*num_branches+1.

        This method can be significantly faster when the number of
        desired clusters is quite large (e.g. a hundred or more).
        Higher branch sizes are slower but may give better results.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        # First verify the paremeters are sensible.

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if int(branch_size) != branch_size or branch_size < 2:
            raise FLANNException('branch_size must be an integer >= 2.')

        branch_size = int(branch_size)

        if int(num_branches) != num_branches or num_branches < 1:
            raise FLANNException('num_branches must be an integer >= 1.')

        num_branches = int(num_branches)

        if max_iterations is None:
            max_iterations = -1
        else:
            max_iterations = int(max_iterations)

        # init the arrays and starting values
        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape
        num_clusters = (branch_size - 1) * num_branches + 1

        if pts.dtype.type == np.float64:
            result = np.empty((num_clusters, dim), dtype=np.float64)
        else:
            result = np.empty((num_clusters, dim), dtype=np.float32)

        # set all the parameters appropriately

        self.__ensureRandomSeed(kwargs)

        params = {'iterations': max_iterations,
                  'algorithm': 'kmeans',
                  'branching': branch_size,
                  'random_seed': kwargs['random_seed']}

        self.__flann_parameters.update(params)

        numclusters = flann.compute_cluster_centers[pts.dtype.type](
            pts, npts, dim, num_clusters, result,
            pointer(self.__flann_parameters))
        if numclusters <= 0:
            raise FLANNException('Error occured during clustering procedure.')

        if dtype is None:
            return result
        else:
            return dtype(result)

    ##########################################################################
    # internal bookkeeping functions

    def __ensureRandomSeed(self, kwargs):
        if 'random_seed' not in kwargs:
            kwargs['random_seed'] = self.__rn_gen.randint(2 ** 30)
Ejemplo n.º 6
0
class FLANN(object):
    """
    This class defines a python interface to the FLANN lirary.

    Example:
        >>> from pyflann import FLANN
        >>> import numpy as np
        >>> dvecs = np.random.rand(1000, 128)
        >>> qvecs = np.random.rand(10, 128)
        >>> flann = FLANN()
        >>> params = flann.build_index(dvecs)
        >>> # xdoctest: +IGNORE_WANT
        >>> qx_to_dx, qx_to_dist = flann.nn_index(qvecs, num_neighbors=5)
        >>> print('qx_to_dist = {!r}'.format(qx_to_dist))
        >>> print('qx_to_dx = {!r}'.format(qx_to_dx))
        qx_to_dist = array([[16.6, 16.74, 18.055, 18.12, 18.4],
                            [16.2, 16.96, 17.204, 18.50, 18.6],
                            [14.2, 15.91, 15.961, 16.76, 16.8],
                            [18.0, 18.09, 18.389, 18.62, 18.8],
                            [16.0, 17.32, 19.177, 19.17, 19.3],
                            [17.1, 17.83, 17.941, 18.32, 18.5],
                            [17.0, 17.80, 17.888, 18.32, 18.5],
                            [17.6, 17.72, 17.852, 17.93, 18.1],
                            [15.6, 15.66, 16.117, 16.49, 17.0],
                            [14.3, 15.88, 16.186, 16.74, 17.2]])
        qx_to_dx = array([[113, 150, 122, 956, 444],
                          [469, 308, 363, 229, 613],
                          [591, 564, 153,   0, 396],
                          [954, 800, 515, 460, 114],
                          [107, 266, 698, 708, 451],
                          [ 38,  86, 404, 579, 517],
                          [942, 631,  10, 561, 230],
                          [152, 691, 582, 153, 840],
                          [700, 543, 582, 971, 826],
                          [  8, 843, 670, 853,  79]], dtype=int32)
        >>> print('flann.shape = {!r}'.format(flann.shape))
        flann.shape = (1000, 128)
        >>> flann.remove_points(np.unique(qx_to_dx.ravel()))
        >>> qx_to_dx2, qx_to_dist2 = flann.nn_index(qvecs, num_neighbors=5)
        >>> assert len(set(qx_to_dx2.ravel()) & set(qx_to_dx.ravel())) == 0
        >>> flann.add_points(dvecs + 1)

    """

    __rn_gen = _rn.RandomState()

    _as_parameter_ = property(lambda self: self.__curindex)

    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None  # pointer to keep the numpy data alive
        self.__added_data = []  # contained to keep any added numpy data alive
        self.__removed_ids = [
        ]  # contains the point ids that have been removed
        self.__curindex_type = None
        self.__added_data = []  # contained to keep any added numpy data alive
        self.__removed_ids = [
        ]  # contains the point ids that have been removed

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)

    def __del__(self):
        # print('FLANN OBJECT IS DELETED')
        self.delete_index()

    @property
    def shape(self):
        return self.get_indexed_shape()

    @property
    def __len__(self):
        return self.shape[0]

    def get_indexed_shape(self):
        """ returns the shape of the data being indexed """
        npts, dim = self.__curindex_data.shape
        for _extra in self.__added_data:
            npts += _extra.shape[0]
        npts -= len(self.__removed_ids)
        return npts, dim

    def get_indexed_data(self):
        """
        returns all the data indexed by the FLANN object

        (this returns points that have been removed but still exist in memory)
        """
        return self.__curindex_data, self.__added_data

    def used_memory_dataset(self):
        """
        Returns the amount of memory used by the dataset
        """
        if self.__curindex_data is None:
            return 0
        num_bytes = self.__curindex_data.nbytes
        for _extra in self.__added_data:
            num_bytes += _extra.nbytes
        return num_bytes

    def used_memory(self):
        """
        Returns the amount of memory used by the index

        Returns: int
        """
        if self.__curindex is None:
            return 0
        return flann.used_memory[self.__curindex_type](self.__curindex)

    ##########################################################################
    # actual workhorse functions

    def nn(self, pts, qpts, num_neighbors=1, **kwargs):
        """
        Returns the num_neighbors nearest points in dataset for each point
        in testset.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if pts.dtype != qpts.dtype:
            raise FLANNException('Data and query must have the same type')

        pts = ensure_2d_array(pts, default_flags)
        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = pts.shape
        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if pts.dtype == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors[pts.dtype.type](
            pts,
            npts,
            dim,
            qpts,
            nqpts,
            result,
            dists,
            num_neighbors,
            pointer(self.__flann_parameters),
        )

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def build_index(self, pts, **kwargs):
        """
        This builds and internally stores an index to be used for
        future nearest neighbor matchings.  It erases any previously
        stored indexes, so use multiple instances of this class to
        work with multiple stored indices.  Use nn_index(...) to find
        the nearest neighbors in this index.

        pts is a 2d numpy array or matrix. All the computation is done
        in np.float32 type, but pts may be any type that is convertable
        to np.float32.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        self.__ensureRandomSeed(kwargs)

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None

        speedup = c_float(0)
        self.__curindex = flann.build_index[pts.dtype.type](
            pts, npts, dim, byref(speedup), pointer(self.__flann_parameters))
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type

        params = dict(self.__flann_parameters)
        params['speedup'] = speedup.value

        return params

    def add_points(self, pts, rebuild_threshold=2):
        """
        Adds points to pre-built index.

        Params:
            pts: 2D numpy array of points.
            rebuild_threshold: reallocs index when it grows by factor of
                `rebuild_threshold`. A smaller value results is more space
                efficient but less computationally efficient. Must be greater
                than 1.
        """
        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)
        if pts.dtype.type != self.__curindex_type:
            raise FLANNException('New points must have the same type')
        pts = ensure_2d_array(pts, default_flags)

        npts = pts.shape[0]
        flann.add_points[self.__curindex_type](self.__curindex, pts, npts,
                                               rebuild_threshold)
        self.__curindex_data = np.row_stack((self.__curindex_data, pts))
        self.__added_data.append(pts)

    def remove_point(self, idx):
        """
        Removes a point from a pre-built index.
        """
        flann.remove_point[self.__curindex_type](self.__curindex, idx)
        # Not sure if this is ok:
        self.__curindex_data = np.delete(self.__curindex_data, idx, axis=0)
        self.__removed_ids.append(idx)

    def remove_points(self, id_list):
        """
        Removes multiple points from the index

        Params:
            id_list = point ids to be removed

        Returns: void
        """
        for id_ in id_list:
            flann.remove_point[self.__curindex_type](self.__curindex, id_)

    def save_index(self, filename):
        """
        This saves the index to a disk file.
        """
        if self.__curindex is not None:
            flann.save_index[self.__curindex_type](self.__curindex,
                                                   c_char_p(
                                                       to_bytes(filename)))

    def load_index(self, filename, pts):
        """
        Loads an index previously saved to disk.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__added_data = []
            self.__curindex_type = None
            self.__added_data = []
            self.__removed_ids = []

        self.__curindex = flann.load_index[pts.dtype.type](c_char_p(
            to_bytes(filename)), pts, npts, dim)

        if self.__curindex is None:
            raise FLANNException(
                ('Error loading the FLANN index with filename=%r.'
                 ' C++ may have thrown more detailed errors') % (filename, ))

        self.__curindex_data = pts
        self.__added_data = []
        self.__removed_ids = []
        self.__curindex_type = pts.dtype.type
        self.__added_data = []
        self.__removed_ids = []
        self.__removed_ids = []

    def nn_index(self, qpts, num_neighbors=1, **kwargs):
        """
        For each point in querypts, (which may be a single point), it
        returns the num_neighbors nearest points in the index built by
        calling build_index.
        """

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.'
            )

        if qpts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % qpts.dtype)

        if self.__curindex_type != qpts.dtype.type:
            raise FLANNException('Index and query must have the same type')

        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = self.get_indexed_shape()

        if qpts.size == dim:
            qpts.reshape(1, dim)

        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, 'data and query must have the same dims'
        assert npts >= num_neighbors, 'more neighbors than there are points'

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors_index[self.__curindex_type](
            self.__curindex,
            qpts,
            nqpts,
            result,
            dists,
            num_neighbors,
            pointer(self.__flann_parameters),
        )

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def nn_radius(self, query, radius, **kwargs):

        if self.__curindex is None:
            raise FLANNException(
                'build_index(...) method not called first or current index deleted.'
            )

        if query.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % query.dtype)

        if self.__curindex_type != query.dtype.type:
            raise FLANNException('Index and query must have the same type')

        npts, dim = self.get_indexed_shape()
        assert query.shape[0] == dim, 'data and query must have the same dims'

        result = np.empty(npts, dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty(npts, dtype=np.float64)
        else:
            dists = np.empty(npts, dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        nn = flann.radius_search[self.__curindex_type](
            self.__curindex,
            query,
            result,
            dists,
            npts,
            radius,
            pointer(self.__flann_parameters),
        )

        return (result[0:nn], dists[0:nn])

    def delete_index(self, **kwargs):
        """
        Deletes the current index freeing all the momory it uses.
        The memory used by the dataset that was indexed is not freed
        unless there are no other references to those numpy arrays.
        """

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None and flann is not None:
            flann.free_index[self.__curindex_type](
                self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__added_data = []
            self.__removed_ids = []

    ##########################################################################
    # Clustering functions

    def kmeans(self,
               pts,
               num_clusters,
               max_iterations=None,
               dtype=None,
               **kwargs):
        """
        Runs kmeans on pts with num_clusters centroids.  Returns a
        numpy array of size num_clusters x dim.

        If max_iterations is not None, the algorithm terminates after
        the given number of iterations regardless of convergence.  The
        default is to run until convergence.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        if int(num_clusters) != num_clusters or num_clusters < 1:
            raise FLANNException('num_clusters must be an integer >= 1')

        if num_clusters == 1:
            if dtype is None or dtype == pts.dtype:
                return np.mean(pts, 0).reshape(1, pts.shape[1])
            else:
                return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))

        return self.hierarchical_kmeans(pts, int(num_clusters), 1,
                                        max_iterations, dtype, **kwargs)

    def hierarchical_kmeans(self,
                            pts,
                            branch_size,
                            num_branches,
                            max_iterations=None,
                            dtype=None,
                            **kwargs):
        """
        Clusters the data by using multiple runs of kmeans to
        recursively partition the dataset.  The number of resulting
        clusters is given by (branch_size-1)*num_branches+1.

        This method can be significantly faster when the number of
        desired clusters is quite large (e.g. a hundred or more).
        Higher branch sizes are slower but may give better results.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        # First verify the paremeters are sensible.

        if pts.dtype.type not in allowed_types:
            raise FLANNException('Cannot handle type: %s' % pts.dtype)

        if int(branch_size) != branch_size or branch_size < 2:
            raise FLANNException('branch_size must be an integer >= 2.')

        branch_size = int(branch_size)

        if int(num_branches) != num_branches or num_branches < 1:
            raise FLANNException('num_branches must be an integer >= 1.')

        num_branches = int(num_branches)

        if max_iterations is None:
            max_iterations = -1
        else:
            max_iterations = int(max_iterations)

        # init the arrays and starting values
        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape
        num_clusters = (branch_size - 1) * num_branches + 1

        if pts.dtype.type == np.float64:
            result = np.empty((num_clusters, dim), dtype=np.float64)
        else:
            result = np.empty((num_clusters, dim), dtype=np.float32)

        # set all the parameters appropriately

        self.__ensureRandomSeed(kwargs)

        params = {
            'iterations': max_iterations,
            'algorithm': 'kmeans',
            'branching': branch_size,
            'random_seed': kwargs['random_seed'],
        }

        self.__flann_parameters.update(params)

        numclusters = flann.compute_cluster_centers[pts.dtype.type](
            pts,
            npts,
            dim,
            num_clusters,
            result,
            pointer(self.__flann_parameters),
        )
        if numclusters <= 0:
            raise FLANNException('Error occured during clustering procedure.')

        if dtype is None:
            return result
        else:
            return dtype(result)

    ##########################################################################
    # internal bookkeeping functions

    def __ensureRandomSeed(self, kwargs):
        if 'random_seed' not in kwargs:
            kwargs['random_seed'] = self.__rn_gen.randint(2**30)
Ejemplo n.º 7
0
class FLANN(object):
    """
    This class defines a python interface to the FLANN lirary.
    """

    __rn_gen = _rn.RandomState()

    _as_parameter_ = property(lambda self: self.__curindex)

    def __init__(self, **kwargs):
        """
        Constructor for the class and returns a class that can bind to
        the flann libraries.  Any keyword arguments passed to __init__
        override the global defaults given.
        """

        self.__rn_gen.seed()

        self.__curindex = None
        self.__curindex_data = None
        self.__curindex_type = None

        self.__flann_parameters = FLANNParameters()
        self.__flann_parameters.update(kwargs)

    def __del__(self):
        self.delete_index()

    ##########################################################################
    # actual workhorse functions

    def nn(self, pts, qpts, num_neighbors=1, **kwargs):
        """
        Returns the num_neighbors nearest points in dataset for each point
        in testset.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % pts.dtype)

        if qpts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % pts.dtype)

        if pts.dtype != qpts.dtype:
            raise FLANNException("Data and query must have the same type")

        pts = ensure_2d_array(pts, default_flags)
        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = pts.shape
        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, "data and query must have the same dims"
        if num_neighbors > npts:
            num_neighbors = npts
        assert npts >= num_neighbors, "more neighbors than there are points"

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if pts.dtype == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors[pts.dtype.type](
            pts, npts, dim, qpts, nqpts, result, dists, num_neighbors, pointer(self.__flann_parameters)
        )

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def build_index(self, pts, **kwargs):
        """
        This builds and internally stores an index to be used for
        future nearest neighbor matchings.  It erases any previously
        stored indexes, so use multiple instances of this class to
        work with multiple stored indices.  Use nn_index(...) to find
        the nearest neighbors in this index.

        pts is a 2d numpy array or matrix. All the computation is done
        in np.float32 type, but pts may be any type that is convertable
        to np.float32.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        self.__ensureRandomSeed(kwargs)

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None

        speedup = c_float(0)
        self.__curindex = flann.build_index[pts.dtype.type](
            pts, npts, dim, byref(speedup), pointer(self.__flann_parameters)
        )
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type

        params = dict(self.__flann_parameters)
        params["speedup"] = speedup.value

        return params

    def save_index(self, filename):
        """
        This saves the index to a disk file.
        """
        if self.__curindex is not None:
            flann.save_index[self.__curindex_type](self.__curindex, c_char_p(to_bytes(filename)))

    def load_index(self, filename, pts):
        """
        Loads an index previously saved to disk.
        """

        if pts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % pts.dtype)

        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None
            self.__curindex_type = None

        self.__curindex = flann.load_index[pts.dtype.type](c_char_p(to_bytes(filename)), pts, npts, dim)
        self.__curindex_data = pts
        self.__curindex_type = pts.dtype.type

    def nn_index(self, qpts, num_neighbors=1, **kwargs):
        """
        For each point in querypts, (which may be a single point), it
        returns the num_neighbors nearest points in the index built by
        calling build_index.
        """

        if self.__curindex is None:
            raise FLANNException("build_index(...) method not called first or current index deleted.")

        if qpts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % qpts.dtype)

        if self.__curindex_type != qpts.dtype.type:
            raise FLANNException("Index and query must have the same type")

        qpts = ensure_2d_array(qpts, default_flags)

        npts, dim = self.__curindex_data.shape

        if qpts.size == dim:
            qpts.reshape(1, dim)

        nqpts = qpts.shape[0]

        assert qpts.shape[1] == dim, "data and query must have the same dims"
        if num_neighbors > npts:
            num_neighbors = npts
        assert npts >= num_neighbors, "more neighbors than there are points"

        result = np.empty((nqpts, num_neighbors), dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float64)
        else:
            dists = np.empty((nqpts, num_neighbors), dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        flann.find_nearest_neighbors_index[self.__curindex_type](
            self.__curindex, qpts, nqpts, result, dists, num_neighbors, pointer(self.__flann_parameters)
        )

        if num_neighbors == 1:
            return (result.reshape(nqpts), dists.reshape(nqpts))
        else:
            return (result, dists)

    def nn_radius(self, query, radius, **kwargs):

        if self.__curindex is None:
            raise FLANNException("build_index(...) method not called first or current index deleted.")

        if query.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % query.dtype)

        if self.__curindex_type != query.dtype.type:
            raise FLANNException("Index and query must have the same type")

        npts, dim = self.__curindex_data.shape
        assert query.shape[0] == dim, "data and query must have the same dims"

        result = np.empty(npts, dtype=index_type)
        if self.__curindex_type == np.float64:
            dists = np.empty(npts, dtype=np.float64)
        else:
            dists = np.empty(npts, dtype=np.float32)

        self.__flann_parameters.update(kwargs)

        nn = flann.radius_search[self.__curindex_type](
            self.__curindex, query, result, dists, npts, radius, pointer(self.__flann_parameters)
        )

        return (result[0:nn], dists[0:nn])

    def delete_index(self, **kwargs):
        """
        Deletes the current index freeing all the momory it uses.
        The memory used by the dataset that was indexed is not freed.
        """

        self.__flann_parameters.update(kwargs)

        if self.__curindex is not None:
            flann.free_index[self.__curindex_type](self.__curindex, pointer(self.__flann_parameters))
            self.__curindex = None
            self.__curindex_data = None

    ##########################################################################
    # Clustering functions

    def kmeans(self, pts, num_clusters, max_iterations=None, dtype=None, **kwargs):
        """
        Runs kmeans on pts with num_clusters centroids.  Returns a
        numpy array of size num_clusters x dim.

        If max_iterations is not None, the algorithm terminates after
        the given number of iterations regardless of convergence.  The
        default is to run until convergence.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        if int(num_clusters) != num_clusters or num_clusters < 1:
            raise FLANNException("num_clusters must be an integer >= 1")

        if num_clusters == 1:
            if dtype is None or dtype == pts.dtype:
                return np.mean(pts, 0).reshape(1, pts.shape[1])
            else:
                return dtype(np.mean(pts, 0).reshape(1, pts.shape[1]))

        return self.hierarchical_kmeans(pts, int(num_clusters), 1, max_iterations, dtype, **kwargs)

    def hierarchical_kmeans(self, pts, branch_size, num_branches, max_iterations=None, dtype=None, **kwargs):
        """
        Clusters the data by using multiple runs of kmeans to
        recursively partition the dataset.  The number of resulting
        clusters is given by (branch_size-1)*num_branches+1.

        This method can be significantly faster when the number of
        desired clusters is quite large (e.g. a hundred or more).
        Higher branch sizes are slower but may give better results.

        If dtype is None (the default), the array returned is the same
        type as pts.  Otherwise, the returned array is of type dtype.

        """

        # First verify the paremeters are sensible.

        if pts.dtype.type not in allowed_types:
            raise FLANNException("Cannot handle type: %s" % pts.dtype)

        if int(branch_size) != branch_size or branch_size < 2:
            raise FLANNException("branch_size must be an integer >= 2.")

        branch_size = int(branch_size)

        if int(num_branches) != num_branches or num_branches < 1:
            raise FLANNException("num_branches must be an integer >= 1.")

        num_branches = int(num_branches)

        if max_iterations is None:
            max_iterations = -1
        else:
            max_iterations = int(max_iterations)

        # init the arrays and starting values
        pts = ensure_2d_array(pts, default_flags)
        npts, dim = pts.shape
        num_clusters = (branch_size - 1) * num_branches + 1

        if pts.dtype.type == np.float64:
            result = np.empty((num_clusters, dim), dtype=np.float64)
        else:
            result = np.empty((num_clusters, dim), dtype=np.float32)

        # set all the parameters appropriately

        self.__ensureRandomSeed(kwargs)

        params = {
            "iterations": max_iterations,
            "algorithm": "kmeans",
            "branching": branch_size,
            "random_seed": kwargs["random_seed"],
        }

        self.__flann_parameters.update(params)

        numclusters = flann.compute_cluster_centers[pts.dtype.type](
            pts, npts, dim, num_clusters, result, pointer(self.__flann_parameters)
        )
        if numclusters <= 0:
            raise FLANNException("Error occured during clustering procedure.")

        if dtype is None:
            return result
        else:
            return dtype(result)

    ##########################################################################
    # internal bookkeeping functions

    def __ensureRandomSeed(self, kwargs):
        if "random_seed" not in kwargs:
            kwargs["random_seed"] = self.__rn_gen.randint(2 ** 30)