class NNEncode():
    ''' Encode points using NN search and Gaussian kernel.

    Each point is soft-assigned to its ``NN`` nearest cluster centers in
    ``self.cc`` with weights proportional to ``exp(-d**2 / (2 * sigma**2))``,
    normalized so each point's weights sum to 1.
    '''
    def __init__(self, NN, sigma, km_filepath='', cc=-1):
        ''' Build the nearest-neighbor indices over the cluster centers.

        Parameters
        ----------
        NN : int-like
            Number of nearest cluster centers used for the soft encoding
            (converted with ``int``).
        sigma : float
            Bandwidth of the Gaussian kernel applied to neighbor distances.
        km_filepath : str
            Path passed to ``np.load`` to obtain the cluster centers when
            ``cc`` is left at its sentinel value ``-1``.
        cc : array-like
            Cluster centers, one per row; ``self.K`` is ``cc.shape[0]``.
        '''
        if (check_value(cc, -1)):
            self.cc = np.load(km_filepath)
        else:
            self.cc = cc
        self.K = self.cc.shape[0]
        self.NN = int(NN)
        self.sigma = sigma
        # Use the int-converted count: NearestNeighbors rejects float values.
        self.nbrs = NearestNeighbors(n_neighbors=self.NN,
                                     algorithm='ball_tree').fit(self.cc)
        # Separate 1-NN index used for hard (one-hot) assignment.
        self.closest_neighbor = NearestNeighbors(n_neighbors=1,
                                                 algorithm='ball_tree').fit(
                                                     self.cc)

        self.alreadyUsed = False

    def encode_points_mtx_nd(self,
                             pts_nd,
                             axis=1,
                             return_sparse=False,
                             sameBlock=True):
        ''' Encode ``pts_nd`` as weights over the cluster centers.

        Parameters
        ----------
        pts_nd : ndarray
            Points to encode; coordinates lie along ``axis``.
        axis : int
            Axis of ``pts_nd`` holding the point coordinates.
        return_sparse : bool
            If True, hard-assign each point to its single closest center
            (one-hot rows). Note the result is still a dense array.
        sameBlock : bool
            If True, reuse the encoding buffer allocated by a previous call
            when the number of points is unchanged, instead of reallocating.

        Returns
        -------
        ndarray
            Encoding with ``self.K`` channels, reshaped by
            ``unflatten_2d_array`` to match ``pts_nd`` along ``axis``.
        '''
        pts_flt = flatten_nd_array(pts_nd, axis=axis)
        P = pts_flt.shape[0]
        # Reuse the cached buffer only while its shape still matches; a change
        # in P would otherwise break the scatter and reshape below.
        if (sameBlock and self.alreadyUsed
                and self.pts_enc_flt.shape == (P, self.K)):
            self.pts_enc_flt[...] = 0  # already pre-allocated
        else:
            self.alreadyUsed = True
            self.pts_enc_flt = np.zeros((P, self.K))
            self.p_inds = np.arange(0, P, dtype='int')[:, na()]

        if return_sparse:
            # Hard assignment: a single neighbor per point, weight becomes 1.
            (dists, inds) = self.closest_neighbor.kneighbors(pts_flt)
        else:
            (dists, inds) = self.nbrs.kneighbors(pts_flt)

        # Gaussian kernel on squared distances. Subtracting each row's minimum
        # before exp() leaves the normalized weights unchanged but keeps the
        # largest weight at exp(0) = 1, so a row can never underflow to all
        # zeros (which would produce 0/0 = NaN after normalization).
        sq_dists = dists**2
        sq_dists = sq_dists - sq_dists.min(axis=1, keepdims=True)
        wts = np.exp(-sq_dists / (2 * self.sigma**2))
        wts = wts / np.sum(wts, axis=1, keepdims=True)

        # Scatter each point's weights into the columns of its neighbors.
        self.pts_enc_flt[self.p_inds, inds] = wts
        # NOTE(review): if unflatten_2d_array returns a view of
        # self.pts_enc_flt, a later call with sameBlock=True will overwrite
        # this result in place — confirm, and copy if callers keep results.
        pts_enc_nd = unflatten_2d_array(self.pts_enc_flt, pts_nd, axis=axis)

        return pts_enc_nd

    def decode_points_mtx_nd(self, pts_enc_nd, axis=1):
        ''' Map encodings back to point space as weighted sums of centers.

        Parameters
        ----------
        pts_enc_nd : ndarray
            Encoding with ``self.K`` channels along ``axis``.
        axis : int
            Axis of ``pts_enc_nd`` holding the encoding channels.

        Returns
        -------
        ndarray
            The flattened encoding multiplied by ``self.cc`` (``np.dot``),
            reshaped back by ``unflatten_2d_array``.
        '''
        pts_enc_flt = flatten_nd_array(pts_enc_nd, axis=axis)
        pts_dec_flt = np.dot(pts_enc_flt, self.cc)
        pts_dec_nd = unflatten_2d_array(pts_dec_flt, pts_enc_nd, axis=axis)
        return pts_dec_nd