class NNF:
    """Nearest-neighbour field (NNF) from target pixels to source patches.

    Each pixel selected by ``target_mask`` is matched to the ``num_neighbors``
    closest patches centred on pixels selected by ``source_mask``. The search
    backend is chosen by the module-level ``_NN_algorithm`` setting:
    ``"FLANN"``, ``"Sklearn"`` and ``"FAISS"`` build a k-NN index over
    flattened, weighted source patches; ``"PatchMatch"`` delegates to
    ``pm.PatchMatch``.

    NOTE(review): for the index backends, masked pixels are assumed to lie at
    least ``patch_size // 2`` away from the image border; otherwise the
    computed patch indices fall outside the patch array (negative ones wrap
    silently under NumPy indexing) — confirm that callers guarantee this.
    """

    # Values of ``_NN_algorithm`` this class knows how to handle.
    _NN_BACKENDS = ("FLANN", "Sklearn", "FAISS", "PatchMatch")

    def __init__(self,
                 image,
                 target_mask,
                 source_mask,
                 patch_size=(11, 11),
                 patch_weight=None,
                 num_neighbors=1):
        """Build the nearest-neighbour index for ``image``.

        Args:
            image: Array of shape ``(height, width, channels)``.
            target_mask: Mask of the pixels to be matched; passed to
                ``op.masked_indices`` (or to ``pm.PatchMatch``). Presumably a
                boolean ``(height, width)`` array — confirm with callers.
            source_mask: Mask of the pixels whose patches may be matched;
                handled the same way as ``target_mask``.
            patch_size: ``(patch_height, patch_width)`` of compared patches.
            patch_weight: Optional per-pixel weights of shape ``patch_size``.
                Defaults to all ones (unweighted distance).
            num_neighbors: Number of nearest source patches to find for each
                target pixel.

        Raises:
            ValueError: If ``_NN_algorithm`` is not a supported backend.
        """
        # Fail early instead of leaving ``self.nn`` unset.
        if _NN_algorithm not in self._NN_BACKENDS:
            raise self._unsupported_backend()

        _, im_w, _ = image.shape

        if patch_weight is None:
            self.patch_weight = np.ones(patch_size, dtype=_im_dtype)
        else:
            # Keep the caller's weights, in the same dtype as the default.
            self.patch_weight = np.asarray(patch_weight, dtype=_im_dtype)

        self.patch_size = patch_size
        self.num_neighb = num_neighbors

        print("Build NNF index: ", end=" ")
        start = time.time()

        if _NN_algorithm == "PatchMatch":
            self.nn = pm.PatchMatch(target_mask,
                                    source_mask,
                                    patch_size=patch_size,
                                    lambdas=np.ones_like(image,
                                                         dtype=_im_dtype))
        else:
            # ``source_ind`` holds patch indices (rows of the
            # ``extract_patches_2d`` output); ``target_ind`` keeps flat array
            # indices, which calculate_nnf converts on each query.
            self.source_ind = self._array_to_patch_index(
                op.masked_indices(source_mask), im_w, patch_size)
            self.target_ind = op.masked_indices(target_mask)

            source_point_cloud = (self._patch_rows(image, self.source_ind)
                                  * self._feature_weights(image.shape[-1]))

            # Preallocated once and refilled in place by calculate_nnf; the
            # original author notes this works around a suspected FLANN bug
            # with memory release.
            self.target_point_cloud = np.zeros(
                (self.target_ind.size, source_point_cloud.shape[-1]),
                dtype=_im_dtype)

            self._build_index(source_point_cloud)

        print('%f sec' % (time.time() - start))

    def calculate_nnf(self, image, init_guess=None):
        """Compute the nearest-neighbour field for ``image``.

        Args:
            image: Array of shape ``(height, width, channels)``; presumably
                the same size as the image given to ``__init__`` (the stored
                indices assume it) — confirm with callers.
            init_guess: Initial field forwarded to ``find_nnf`` for the
                PatchMatch backend; ignored by the other backends.

        Returns:
            ``(ind_nnf, dist_nnf)``, each reshaped to
            ``(-1, num_neighbors)``. For the index backends there is one row
            per pixel (row-major): ``ind_nnf`` (int32) holds flat array
            indices of the matched source patch centres and ``dist_nnf`` the
            distances reported by the backend; rows of non-target pixels are
            zero.
        """
        print("Query NNF index: ", end=" ")
        start = time.time()

        if _NN_algorithm == "PatchMatch":
            ind_nnf, dist_nnf = self.nn.find_nnf(image, init_guess=init_guess)
        else:
            ind_nnf, dist_nnf = self._query_index(image)

        print('%f sec' % (time.time() - start))

        ind_nnf = ind_nnf.reshape((-1, self.num_neighb))
        dist_nnf = dist_nnf.reshape((-1, self.num_neighb))

        return ind_nnf, dist_nnf

    def _query_index(self, image):
        """Query the FLANN / Sklearn / FAISS index for every target pixel.

        Returns full-image ``(im_h * im_w, num_neighb)`` index and distance
        arrays, filled only at the rows of ``self.target_ind``.
        """
        im_h, im_w, im_ch = image.shape
        ind_nnf = np.zeros((im_h * im_w, self.num_neighb), dtype='int32')
        dist_nnf = np.zeros((im_h * im_w, self.num_neighb))

        target_patch_ind = self._array_to_patch_index(
            self.target_ind, im_w, self.patch_size)

        # Refill the preallocated buffer in place (see __init__), then apply
        # the patch weights in the buffer's own dtype.
        np.copyto(self.target_point_cloud,
                  self._patch_rows(image, target_patch_ind),
                  casting='same_kind')
        self.target_point_cloud *= self._feature_weights(im_ch)

        # The returned indices are row positions in ``self.source_ind``.
        # NOTE(review): distance conventions may differ between backends
        # (e.g. squared vs. plain Euclidean) — verify before comparing them.
        if _NN_algorithm == "FLANN":
            ind, dist = self.nn.nn_index(self.target_point_cloud,
                                         self.num_neighb)
        elif _NN_algorithm == "Sklearn":
            dist, ind = self.nn.kneighbors(X=self.target_point_cloud,
                                           return_distance=True)
        elif _NN_algorithm == "FAISS":
            dist, ind = self.nn.search(self.target_point_cloud,
                                       self.num_neighb)
        else:
            raise self._unsupported_backend()

        # Backends may return 1-D arrays when a single neighbour is requested.
        ind = np.reshape(ind, (-1, self.num_neighb))
        dist = np.reshape(dist, (-1, self.num_neighb))

        # Source row -> source patch index -> array index of the patch centre.
        ind = self._patch_to_array_index(self.source_ind[ind], im_w,
                                         self.patch_size)

        ind_nnf[self.target_ind] = ind
        dist_nnf[self.target_ind] = dist
        return ind_nnf, dist_nnf

    def _build_index(self, source_point_cloud):
        """Build the FLANN / Sklearn / FAISS index over the source rows."""
        if _NN_algorithm == "FLANN":
            self.nn = flann.FLANN()
            self.nn.build_index(source_point_cloud,
                                algorithm="kdtree",
                                trees=1)
        elif _NN_algorithm == "Sklearn":
            # Patch weights are already folded into the features (see
            # _feature_weights), so no weighted metric is needed here.
            self.nn = NearestNeighbors(
                n_neighbors=self.num_neighb,
                algorithm='kd_tree',
                metric='minkowski',
                n_jobs=-1)
            self.nn.fit(X=source_point_cloud)
        elif _NN_algorithm == "FAISS":
            self.nn = faiss.IndexHNSWFlat(source_point_cloud.shape[1], 50)
            self.nn.add(source_point_cloud)
        else:
            raise self._unsupported_backend()

    def _patch_rows(self, image, patch_ind):
        """Return patches ``patch_ind`` of ``image``, one flattened per row."""
        patches = extract_patches_2d(image, patch_size=self.patch_size)
        return patches[patch_ind].reshape((patch_ind.size, -1))

    def _feature_weights(self, im_ch):
        """Return per-feature scale factors ``sqrt(patch_weight)``.

        A flattened patch lists each pixel's ``im_ch`` channel values next to
        each other, so every pixel weight is repeated ``im_ch`` times.
        Scaling features by the square root of their weight makes the plain
        squared Euclidean distance between rows equal the weighted one.
        """
        return np.repeat(np.sqrt(self.patch_weight), im_ch)

    @classmethod
    def _unsupported_backend(cls):
        """Build the error raised for an unknown ``_NN_algorithm`` value."""
        return ValueError("Unsupported _NN_algorithm %r; expected one of: %s"
                          % (_NN_algorithm, ", ".join(cls._NN_BACKENDS)))

    @staticmethod
    def _array_to_patch_index(array_ind, im_w, patch_size):
        """Map flat pixel indices of patch centres to flat patch indices.

        ``extract_patches_2d`` enumerates patches row-major by their top-left
        corner, giving ``im_w - patch_w + 1`` patches per image row. Pixel
        ``(y, x)`` is the centre of the patch whose top-left corner is
        ``(y - patch_h // 2, x - patch_w // 2)``.
        """
        pad_y, pad_x = patch_size[0] // 2, patch_size[1] // 2
        n_cols = im_w - patch_size[1] + 1
        ind_y, ind_x = np.divmod(array_ind, im_w)
        return (ind_x - pad_x) + (ind_y - pad_y) * n_cols

    @staticmethod
    def _patch_to_array_index(patch_ind, im_w, patch_size):
        """Inverse of ``_array_to_patch_index``: patch index -> centre pixel."""
        pad_y, pad_x = patch_size[0] // 2, patch_size[1] // 2
        n_cols = im_w - patch_size[1] + 1
        ind_y, ind_x = np.divmod(patch_ind, n_cols)
        return (ind_x + pad_x) + (ind_y + pad_y) * im_w