class NNEncode():
    ''' Encode points using NN search and Gaussian kernel.

    Each point is turned into a length-K weight vector over a fixed set of
    K cluster centers. Weights are non-zero only at the point's nearest
    centers. They come from a Gaussian kernel on the distance to each center
    and are normalized to sum to one per point.
    '''

    def __init__(self, NN, sigma, km_filepath='', cc=-1):
        '''
        Args:
            NN: number of nearest centers each point is spread over.
            sigma: bandwidth of the Gaussian kernel applied to distances.
            km_filepath: file loaded with ``np.load`` to obtain the centers.
                It is only read when ``cc`` is left at its ``-1`` default.
            cc: array of cluster centers, one center per row (K rows).
                It is used instead of ``km_filepath`` when provided.
        '''
        if check_value(cc, -1):
            self.cc = np.load(km_filepath)
        else:
            self.cc = cc
        self.K = self.cc.shape[0]
        self.NN = int(NN)
        self.sigma = sigma
        # Pass the int-converted NN so a float argument never reaches sklearn.
        self.nbrs = NearestNeighbors(
            n_neighbors=self.NN, algorithm='ball_tree').fit(self.cc)
        # Separate 1-NN index, used for hard assignment (return_sparse=True).
        self.closest_neighbor = NearestNeighbors(
            n_neighbors=1, algorithm='ball_tree').fit(self.cc)
        # Whether an output buffer (pts_enc_flt / p_inds) has been allocated.
        self.alreadyUsed = False

    def encode_points_mtx_nd(self, pts_nd, axis=1, return_sparse=False,
                             sameBlock=True):
        ''' Encode an N-d array of points as kernel weights over the centers.

        Args:
            pts_nd: array whose ``axis`` dimension holds point coordinates.
            axis: dimension of ``pts_nd`` that holds the coordinates.
            return_sparse: if True, assign each point entirely (weight 1) to
                its single nearest center instead of spreading it over the
                ``NN`` nearest centers.
            sameBlock: if True, reuse the flat output buffer from the previous
                call when it has the same number of points. The returned array
                may then share memory with that buffer, and a later call may
                overwrite it.

        Returns:
            The encoding, re-expanded via ``unflatten_2d_array``. Presumably it
            is shaped like ``pts_nd`` with ``axis`` of size K (TODO confirm
            against unflatten_2d_array).
        '''
        pts_flt = flatten_nd_array(pts_nd, axis=axis)
        P = pts_flt.shape[0]

        # Only reuse the pre-allocated buffer when its shape still fits. A call
        # with a different number of points would otherwise reuse a stale
        # buffer and stale row indices.
        if (sameBlock and self.alreadyUsed
                and self.pts_enc_flt.shape == (P, self.K)):
            self.pts_enc_flt[...] = 0  # already pre-allocated
        else:
            self.alreadyUsed = True
            self.pts_enc_flt = np.zeros((P, self.K))
            self.p_inds = np.arange(0, P, dtype='int')[:, na()]

        if return_sparse:
            # Hard assignment: query the dedicated 1-NN index.
            (dists, inds) = self.closest_neighbor.kneighbors(pts_flt)
        else:
            (dists, inds) = self.nbrs.kneighbors(pts_flt)

        # Subtract each row's smallest squared distance before exponentiating.
        # The per-row normalization cancels this shift. It keeps the largest
        # weight at exp(0) = 1, so far-away points don't underflow to 0/0 = NaN.
        sq_dists = dists ** 2
        sq_dists -= sq_dists.min(axis=1, keepdims=True)
        wts = np.exp(-sq_dists / (2 * self.sigma ** 2))
        wts /= np.sum(wts, axis=1, keepdims=True)

        # Scatter each row's weights into the columns of its nearest centers.
        self.pts_enc_flt[self.p_inds, inds] = wts
        return unflatten_2d_array(self.pts_enc_flt, pts_nd, axis=axis)

    def decode_points_mtx_nd(self, pts_enc_nd, axis=1):
        ''' Decode weights back to coordinates as a weighted sum of centers.

        Args:
            pts_enc_nd: encoded array whose ``axis`` dimension holds K weights.
            axis: dimension of ``pts_enc_nd`` that holds the weights.

        Returns:
            The flattened weights (P x K) multiplied by the centers (K rows),
            re-expanded via ``unflatten_2d_array`` along ``axis``.
        '''
        pts_enc_flt = flatten_nd_array(pts_enc_nd, axis=axis)
        pts_dec_flt = np.dot(pts_enc_flt, self.cc)
        return unflatten_2d_array(pts_dec_flt, pts_enc_nd, axis=axis)