class NNF():
    """Nearest-neighbour field (NNF) between target and source image patches.

    The search backend is chosen by the module-level ``_NN_algorithm``
    setting. ``"FLANN"``, ``"Sklearn"`` and ``"FAISS"`` build an index over
    (weighted) source patches extracted with ``extract_patches_2d``.
    ``"PatchMatch"`` delegates the whole computation to ``pm.PatchMatch``.
    """

    def __init__(self, image, target_mask, source_mask, patch_size=(11, 11),
                 patch_weight=None, num_neighbors=1):
        """Build the nearest-neighbour index over source patches.

        Args:
            image: array unpacked as ``(H, W, C)``; source patches are
                extracted from it.
            target_mask: mask of pixels whose patches are queried later.
            source_mask: mask of pixels whose patches populate the index.
            patch_size: ``(patch_h, patch_w)``. Assumed odd, so that every
                patch has a centre pixel.
            patch_weight: optional per-pixel weights of shape ``patch_size``.
                Defaults to all ones. Patches are scaled by
                ``sqrt(patch_weight)``, so the squared L2 distance between
                scaled patches equals the weighted squared L2 distance.
            num_neighbors: number of neighbours returned per target patch.

        Raises:
            ValueError: if ``_NN_algorithm`` is not a supported backend, or
                if ``patch_weight`` does not have shape ``patch_size``.
        """
        if _NN_algorithm not in ("FLANN", "Sklearn", "FAISS", "PatchMatch"):
            raise ValueError("Unknown _NN_algorithm: %r" % (_NN_algorithm,))

        _, im_w, im_ch = image.shape

        if patch_weight is None:
            self.patch_weight = np.ones(patch_size, dtype=_im_dtype)
        else:
            # Previously a user-supplied weight was silently dropped and
            # self.patch_weight never assigned.
            self.patch_weight = np.asarray(patch_weight, dtype=_im_dtype)
            if self.patch_weight.shape != tuple(patch_size):
                raise ValueError(
                    "patch_weight must have shape %s, got %s"
                    % (tuple(patch_size), self.patch_weight.shape))
        self.patch_size = patch_size
        self.num_neighb = num_neighbors

        print("Build NNF index: ", end=" ")
        start = time.time()

        if _NN_algorithm != "PatchMatch":
            self.target_ind = op.masked_indices(target_mask)
            # Source positions are kept as patch indices: these are indices
            # into the output of extract_patches_2d, not pixel indices.
            self.source_ind = self._array_to_patch_indices(
                op.masked_indices(source_mask), im_w)

            source_point_cloud = extract_patches_2d(
                image, patch_size=patch_size
            )[self.source_ind].reshape((self.source_ind.size, -1)) \
                * self._feature_weights(im_ch)

            # Need this because of FLANN bug (?) with memory release: the
            # query buffer is allocated once here and reused in
            # calculate_nnf.
            self.target_point_cloud = np.zeros(
                (self.target_ind.size, source_point_cloud.shape[-1]),
                dtype=_im_dtype)

            if _NN_algorithm == "FLANN":
                self.nn = flann.FLANN()
                self.nn.build_index(source_point_cloud, algorithm="kdtree",
                                    trees=1)
            elif _NN_algorithm == "Sklearn":
                self.nn = NearestNeighbors(
                    n_neighbors=num_neighbors, algorithm='kd_tree',
                    metric='minkowski', n_jobs=-1)
                self.nn.fit(X=source_point_cloud)
            elif _NN_algorithm == "FAISS":
                self.nn = faiss.IndexHNSWFlat(source_point_cloud.shape[1], 50)
                self.nn.add(source_point_cloud)
        else:
            self.nn = pm.PatchMatch(target_mask, source_mask,
                                    patch_size=patch_size,
                                    lambdas=np.ones_like(image,
                                                         dtype=_im_dtype))

        print('%f sec' % (time.time() - start))

    def _array_to_patch_indices(self, array_ind, im_w):
        """Map flat pixel indices of patch centres to extract_patches_2d indices."""
        patch_h, patch_w = self.patch_size
        pad_y, pad_x = patch_h // 2, patch_w // 2
        # extract_patches_2d enumerates top-left corners row-major over a grid
        # that is (im_w - patch_w + 1) patches wide.
        ind_y, ind_x = np.divmod(array_ind, im_w)
        return (ind_x - pad_x) + (ind_y - pad_y) * (im_w - patch_w + 1)

    def _patch_to_array_indices(self, patch_ind, im_w):
        """Inverse of ``_array_to_patch_indices``: patch index -> centre pixel."""
        patch_h, patch_w = self.patch_size
        pad_y, pad_x = patch_h // 2, patch_w // 2
        ind_y, ind_x = np.divmod(patch_ind, im_w - patch_w + 1)
        return (ind_x + pad_x) + (ind_y + pad_y) * im_w

    def _feature_weights(self, num_channels):
        """Per-feature scale for a patch flattened as (patch_h, patch_w, C)."""
        return np.repeat(np.sqrt(self.patch_weight), num_channels)

    def calculate_nnf(self, image, init_guess=None):
        """Compute the nearest-neighbour field for ``image``.

        Args:
            image: array unpacked as ``(H, W, C)``; target patches are
                extracted from it.
            init_guess: initial field forwarded to ``find_nnf`` for the
                PatchMatch backend. Ignored by the other backends.

        Returns:
            ``(ind_nnf, dist_nnf)``, each reshaped to ``(-1, num_neighbors)``.
            For the index-based backends there are ``H * W`` rows, one per
            flat pixel index. Rows of target pixels hold the flat pixel
            indices of the matched source patch centres and their distances.
            All other rows are zero.
            NOTE(review): the distance scale may differ between backends
            (e.g. squared vs. plain L2). Confirm before comparing values.
        """
        im_h, im_w, im_ch = image.shape

        print("Query NNF index: ", end=" ")
        start = time.time()

        if _NN_algorithm != "PatchMatch":
            ind_nnf = np.zeros((im_h * im_w, self.num_neighb), dtype='int32')
            dist_nnf = np.zeros((im_h * im_w, self.num_neighb))

            target_patch_ind = self._array_to_patch_indices(self.target_ind,
                                                            im_w)
            # Need this because of FLANN bug (?) with memory release: fill the
            # buffer allocated in __init__ instead of creating a new array.
            np.copyto(self.target_point_cloud,
                      extract_patches_2d(
                          image, patch_size=self.patch_size
                      )[target_patch_ind].reshape((self.target_ind.size, -1)),
                      casting='same_kind', where=True)
            self.target_point_cloud *= self._feature_weights(im_ch)

            # The returned "ind" are row numbers of the source point cloud,
            # not pixel indices.
            if _NN_algorithm == "FLANN":
                ind, dist = self.nn.nn_index(self.target_point_cloud,
                                             self.num_neighb)
            elif _NN_algorithm == "Sklearn":
                dist, ind = self.nn.kneighbors(X=self.target_point_cloud,
                                               return_distance=True)
            else:  # "FAISS" (validated in __init__)
                dist, ind = self.nn.search(self.target_point_cloud,
                                           self.num_neighb)

            ind = ind.reshape((ind.shape[0], self.num_neighb))
            dist = dist.reshape((dist.shape[0], self.num_neighb))

            # point-cloud rows -> source patch indices -> centre pixel indices
            ind = self._patch_to_array_indices(self.source_ind[ind.ravel()],
                                               im_w)
            ind = np.reshape(ind, (-1, self.num_neighb))

            ind_nnf[self.target_ind, :] = ind
            dist_nnf[self.target_ind, :] = dist
        else:
            ind_nnf, dist_nnf = self.nn.find_nnf(image, init_guess=init_guess)

        print('%f sec' % (time.time() - start))

        ind_nnf = ind_nnf.reshape((-1, self.num_neighb))
        dist_nnf = dist_nnf.reshape((-1, self.num_neighb))
        return ind_nnf, dist_nnf