Ejemplo n.º 1
0
def export_descriptor(config, output_dir, args):
    """
    Run a pretrained front-end on image pairs and save keypoints,
    descriptors and (optionally) matches for every pair.

    For each sample of the test loader, ``sample["image"]`` and
    ``sample["warped_image"]`` are passed through the model and the results
    are written to ``<save_path>/../predictions/<count>.npz``.

    Saved prediction keys (N1 / N2 = keypoint counts of image / warped image):
        'image': np, presumably (H, W) after squeezing — depends on loader
        'prob' (keypoints): np (N1, 3)   (transpose of the model's (3, N))
        'desc': np (N1, 256)
        'warped_image': np, presumably (H, W)
        'warped_prob' (keypoints): np (N2, 3)
        'warped_desc': np (N2, 256)
        'homography': np (3, 3)
        'matches': np [N3, 4]   (only when outputMatches is True)

    :param config: configuration dict; reads the "model", "data" and
        "front_end_model" entries and is dumped to ``config.yml``.
    :param output_dir: directory that receives ``config.yml``; also passed
        to ``get_save_path`` to locate the prediction folder.
    :param args: command-line namespace; ``args.command`` names the
        TensorBoard writer task.
    """
    from utils.loader import get_save_path
    from utils.var_dim import squeezeToNumpy

    # basic settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logging.info("train on device: %s", device)
    with open(os.path.join(output_dir, "config.yml"), "w",
              encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False)
    # Not used below; kept (and referenced) for its side effects, e.g.
    # creating the TensorBoard log directory for this run.
    writer = SummaryWriter(getWriterPath(task=args.command, date=True))
    save_path = get_save_path(output_dir)
    save_output = save_path / "../predictions"
    os.makedirs(save_output, exist_ok=True)

    ## parameters
    outputMatches = True
    subpixel = config["model"]["subpixel"]["enable"]
    patch_size = config["model"]["subpixel"]["patch_size"]

    # data loading
    from utils.loader import dataLoader_test as dataLoader
    task = config["data"]["dataset"]
    data = dataLoader(config, dataset=task)
    test_loader = data["test_loader"]
    from utils.print_tool import datasize
    datasize(test_loader, config, tag="test")

    # model loading
    from utils.loader import get_module
    Val_model_heatmap = get_module("", config["front_end_model"])
    ## load pretrained
    val_agent = Val_model_heatmap(config["model"], device=device)
    val_agent.loadModel()

    ## tracker
    tracker = PointTracker(max_length=2, nn_thresh=val_agent.nn_thresh)

    # Defined once (not per iteration); closes over subpixel / patch_size.
    def get_pts_desc_from_agent(val_agent, img, device="cpu"):
        """
        Run the model on ``img`` and return the sparse keypoints and
        descriptors of the first batch item.

        Returns a dict with
            "pts": numpy (3, N)
            "desc": numpy (256, N)
        """
        val_agent.run(img.to(device))  # heatmap: numpy [batch, 1, H, W]
        pts = val_agent.heatmap_to_pts()
        if subpixel:
            pts = val_agent.soft_argmax_points(pts, patch_size=patch_size)
        desc_sparse = val_agent.desc_to_sparseDesc()
        return {"pts": pts[0], "desc": desc_sparse[0]}

    count = 0
    # Iterate the loader directly so tqdm can show the total length.
    for sample in tqdm(test_loader):
        img_0, img_1 = sample["image"], sample["warped_image"]

        # first image: keypoints/descriptors, no matches yet
        outs = get_pts_desc_from_agent(val_agent, img_0, device=device)
        pts, desc = outs["pts"], outs["desc"]  # pts: np [3, N]
        if outputMatches:
            tracker.update(pts, desc)

        pred = {
            "image": squeezeToNumpy(img_0),
            "prob": pts.transpose(),
            "desc": desc.transpose(),
        }

        # second image: output matches against the first one
        outs = get_pts_desc_from_agent(val_agent, img_1, device=device)
        pts, desc = outs["pts"], outs["desc"]
        if outputMatches:
            tracker.update(pts, desc)

        pred.update(
            {
                "warped_image": squeezeToNumpy(img_1),
                "warped_prob": pts.transpose(),
                "warped_desc": desc.transpose(),
                "homography": squeezeToNumpy(sample["homography"]),
            }
        )

        if outputMatches:
            matches = tracker.get_matches().transpose()
            logging.debug("matches: %s", matches.shape)
            pred["matches"] = matches
        logging.debug("pts: %s, desc: %s", pts.shape, desc.shape)

        # clean last descriptor before the next pair
        tracker.clear_desc()

        path = Path(save_output, "{}.npz".format(count))
        np.savez_compressed(path, **pred)
        count += 1
    print("output pairs: ", count)
Ejemplo n.º 2
0
    def __getitem__(self, index):
        '''
        Load sample ``index`` and build the network input dict.

        Image pipeline:
        - The image is read with OpenCV and resized to ``self.sizer`` (H, W).
        - It is converted to a single-channel float32 array divided by 255.
        - It is optionally photometrically augmented.

        Depending on the config, the image is then:
        - replaced by a stack of homography-adapted copies;
        - warped by a random homography (only when ``self.labels`` is set,
          since the labels are warped alongside);
        - paired with an independently warped copy (also only when
          ``self.labels`` is set).

        :param index: index into ``self.samples``.
        :return:
            dict containing the formatted sample entries plus
            'image' (tensor (1, H, W) by default; the squeezed output of
            ``inv_warp_image_batch`` when homography adaptation is enabled),
            'valid_mask', 'name', 'scene_name' and, depending on config,
            label / homography / warped-pair entries.
        '''
        from numpy.linalg import inv

        from datasets.data_tools import np_to_tensor
        from datasets.data_tools import warpLabels
        # Imported unconditionally: a function-level import binds a *local*
        # name for the whole method, so importing it only inside the
        # warped-pair branch made the labels-gaussian step at the end raise
        # UnboundLocalError whenever 'warped_pair' was disabled.
        from utils.var_dim import squeezeToNumpy

        def _read_image(path):
            """Read ``path`` as a resized, grayscale float32 image."""
            input_image = cv2.imread(path)
            if input_image is None:
                # cv2.imread returns None instead of raising on failure.
                raise FileNotFoundError(
                    "could not read image: {}".format(path))
            input_image = cv2.resize(input_image,
                                     (self.sizer[1], self.sizer[0]),
                                     interpolation=cv2.INTER_AREA)
            # NOTE(review): cv2.imread yields BGR, for which COLOR_BGR2GRAY
            # is the matching code. Kept unchanged so the preprocessing (and
            # anything trained on it) stays the same — confirm.
            input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
            return input_image.astype('float32') / 255.0

        def imgPhotometric(img):
            """
            Apply the configured photometric augmentations.

            :param img:
                numpy (H, W)
            :return:
                augmented image (a channel axis is added before augmenting)
            """
            augmentation = self.ImgAugTransform(**self.config['augmentation'])
            img = img[:, :, np.newaxis]
            img = augmentation(img)
            cusAug = self.customizedTransform()
            img = cusAug(img, **self.config['augmentation'])
            return img

        def points_to_2D(pnts, H, W):
            """Rasterize (x, y) points into an (H, W) binary label map."""
            labels = np.zeros((H, W))
            pnts = pnts.astype(int)
            labels[pnts[:, 1], pnts[:, 0]] = 1
            return labels

        to_floatTensor = lambda x: torch.tensor(x).type(torch.FloatTensor)

        sample = self.format_sample(self.samples[index])
        input = dict(sample)

        # image
        img_o = _read_image(sample['image'])
        H, W = img_o.shape[0], img_o.shape[1]
        img_aug = img_o.copy()
        if (self.enable_photo_train == True
                and self.action == 'train') or (self.enable_photo_val
                                                and self.action == 'val'):
            img_aug = imgPhotometric(img_o)  # numpy array (H, W, 1)
        img_aug = torch.tensor(img_aug, dtype=torch.float32).view(-1, H, W)

        valid_mask = self.compute_valid_mask(torch.tensor([H, W]),
                                             inv_homography=torch.eye(3))
        input.update({'image': img_aug, 'valid_mask': valid_mask})

        if self.config['homography_adaptation']['enable']:
            homoAdapt_iter = self.config['homography_adaptation']['num']
            homographies = np.stack([
                self.sample_homography(np.array([2, 2]),
                                       shift=-1,
                                       **self.config['homography_adaptation']
                                       ['homographies']['params'])
                for _ in range(homoAdapt_iter)
            ])
            ##### use inverse from the sample homography
            homographies = np.stack(
                [inv(homography) for homography in homographies])
            # the first entry is the identity, i.e. the unwarped image
            homographies[0, :, :] = np.identity(3)

            homographies = torch.tensor(homographies, dtype=torch.float32)
            inv_homographies = torch.stack([
                torch.inverse(homographies[i, :, :])
                for i in range(homoAdapt_iter)
            ])

            # images
            warped_img = self.inv_warp_image_batch(
                img_aug.squeeze().repeat(homoAdapt_iter, 1, 1, 1),
                inv_homographies,
                mode='bilinear').unsqueeze(0)
            warped_img = warped_img.squeeze()
            # masks
            valid_mask = self.compute_valid_mask(
                torch.tensor([H, W]),
                inv_homography=inv_homographies,
                erosion_radius=self.config['augmentation']['homographic']
                ['valid_border_margin'])
            input.update({
                'image': warped_img,
                'valid_mask': valid_mask,
                'image_2D': img_aug
            })
            input.update({
                'homographies': homographies,
                'inv_homographies': inv_homographies
            })

        # labels
        if self.labels:
            pnts = np.load(sample['points'])['pts']
            labels = points_to_2D(pnts, H, W)
            labels_2D = to_floatTensor(labels[np.newaxis, :, :])
            input.update({'labels_2D': labels_2D})

            ## residual
            labels_res = torch.zeros((2, H, W)).type(torch.FloatTensor)
            input.update({'labels_res': labels_res})

            if (self.enable_homo_train == True
                    and self.action == 'train') or (self.enable_homo_val
                                                    and self.action == 'val'):
                homography = self.sample_homography(
                    np.array([2, 2]),
                    shift=-1,
                    **self.config['augmentation']['homographic']['params'])

                ##### use inverse from the sample homography
                homography = inv(homography)
                ######

                inv_homography = inv(homography)
                inv_homography = torch.tensor(inv_homography).to(torch.float32)
                homography = torch.tensor(homography).to(torch.float32)
                warped_img = self.inv_warp_image(img_aug.squeeze(),
                                                 inv_homography,
                                                 mode='bilinear').unsqueeze(0)

                warped_set = warpLabels(pnts, H, W, homography)
                warped_labels = warped_set['labels']
                valid_mask = self.compute_valid_mask(
                    torch.tensor([H, W]),
                    inv_homography=inv_homography,
                    erosion_radius=self.config['augmentation']['homographic']
                    ['valid_border_margin'])

                input.update({
                    'image': warped_img,
                    'labels_2D': warped_labels,
                    'valid_mask': valid_mask
                })

            if self.config['warped_pair']['enable']:
                homography = self.sample_homography(
                    np.array([2, 2]),
                    shift=-1,
                    **self.config['warped_pair']['params'])

                ##### use inverse from the sample homography
                homography = np.linalg.inv(homography)
                #####
                inv_homography = np.linalg.inv(homography)

                homography = torch.tensor(homography).type(torch.FloatTensor)
                inv_homography = torch.tensor(inv_homography).type(
                    torch.FloatTensor)

                # warp the original (non-augmented) image; photometric
                # augmentation is applied to the warped result if enabled
                warped_img = torch.tensor(img_o, dtype=torch.float32)
                warped_img = self.inv_warp_image(warped_img.squeeze(),
                                                 inv_homography,
                                                 mode='bilinear').unsqueeze(0)
                if (self.enable_photo_train == True and self.action
                        == 'train') or (self.enable_photo_val
                                        and self.action == 'val'):
                    warped_img = imgPhotometric(
                        warped_img.numpy().squeeze())  # numpy array (H, W, 1)
                    warped_img = torch.tensor(warped_img, dtype=torch.float32)
                warped_img = warped_img.view(-1, H, W)

                warped_set = warpLabels(pnts, H, W, homography, bilinear=True)
                warped_labels = warped_set['labels']
                warped_res = warped_set['res']
                warped_res = warped_res.transpose(1, 2).transpose(0, 1)
                if self.gaussian_label:
                    warped_labels_bi = warped_set['labels_bi']
                    warped_labels_gaussian = self.gaussian_blur(
                        squeezeToNumpy(warped_labels_bi))
                    warped_labels_gaussian = np_to_tensor(
                        warped_labels_gaussian, H, W)
                    input['warped_labels_gaussian'] = warped_labels_gaussian
                    input.update({'warped_labels_bi': warped_labels_bi})

                input.update({
                    'warped_img': warped_img,
                    'warped_labels': warped_labels,
                    'warped_res': warped_res
                })

                valid_mask = self.compute_valid_mask(
                    torch.tensor([H, W]),
                    inv_homography=inv_homography,
                    erosion_radius=self.config['warped_pair']
                    ['valid_border_margin'])  # can set to other value
                input.update({'warped_valid_mask': valid_mask})
                input.update({
                    'homographies': homography,
                    'inv_homographies': inv_homography
                })

            if self.gaussian_label:
                # NOTE(review): built from the unwarped labels_2D even when
                # the homographic branch above replaced input['labels_2D']
                # with warped labels — confirm this is intended.
                labels_gaussian = self.gaussian_blur(squeezeToNumpy(labels_2D))
                labels_gaussian = np_to_tensor(labels_gaussian, H, W)
                input['labels_2D_gaussian'] = labels_gaussian

        # dummy scene name
        input.update({'name': sample['name'], 'scene_name': "./"})
        return input
    def __getitem__(self, index):
        """
        Load sample ``index`` with its keypoint labels and augmentations.

        Loading and augmentation:
        - The image and its keypoints are loaded. Keypoints are stored on
          disk as (y, x) and converted to (x, y) here.
        - The image is optionally photometrically augmented.
        - The image is optionally warped by a random homography, with the
          labels warped to match.
        - When ``config["warped_pair"]["enable"]`` is set, a second,
          independently warped copy is added under the ``warped_*`` keys.

        :param index: index into ``self.samples``.
        :return: dict with
            labels_2D: tensor(1, H, W)
            image: presumably tensor(1, H, W) — the exact layout depends on
                ``self.transform`` (numpy with a trailing channel axis when
                no transform is set)
            valid_mask, labels_res (tensor (2, H, W)), and depending on
            config: labels_2D_gaussian, warped_img, warped_labels,
            warped_res, warped_valid_mask, homographies, inv_homographies,
            warped_labels_gaussian, warped_labels_bi, and pts (when
            ``self.getPts`` is set).
        """
        def imgPhotometric(img):
            """
            Apply the configured photometric augmentations.

            :param img:
                numpy (H, W)
            :return:
                augmented image (a channel axis is added before augmenting)
            """
            augmentation = self.ImgAugTransform(**self.config["augmentation"])
            img = img[:, :, np.newaxis]
            img = augmentation(img)
            cusAug = self.customizedTransform()
            img = cusAug(img, **self.config["augmentation"])
            return img

        def get_labels(pnts, H, W):
            """Rasterize (x, y) points into an (H, W) binary label map."""
            labels = torch.zeros(H, W)
            # clamp to the last row/column so rounding cannot overflow
            pnts_int = torch.min(pnts.round().long(),
                                 torch.tensor([[W - 1, H - 1]]).long())
            labels[pnts_int[:, 1], pnts_int[:, 0]] = 1
            return labels

        def get_label_res(H, W, pnts):
            """Return the (2, H, W) sub-pixel residual map of (x, y) points."""
            labels_res = torch.zeros(H, W, 2)
            # Same upper clamp as get_labels: a point that rounds onto the
            # far border must not index past the last row/column.
            pnts_int = torch.min(pnts.round().long(),
                                 torch.tensor([[W - 1, H - 1]]).long())
            labels_res[pnts_int[:, 1], pnts_int[:, 0], :] = pnts - pnts.round()
            return labels_res.transpose(1, 2).transpose(0, 1)

        from datasets.data_tools import np_to_tensor
        from utils.utils import filter_points
        from utils.var_dim import squeezeToNumpy

        entry = self.samples[index]
        img = load_as_float(entry["image"])
        H, W = img.shape[0], img.shape[1]
        self.H = H
        self.W = W
        pnts = np.load(entry["points"])  # (y, x)
        pnts = torch.tensor(pnts).float()
        pnts = torch.stack((pnts[:, 1], pnts[:, 0]), dim=1)  # (x, y)
        pnts = filter_points(pnts, torch.tensor([W, H]))
        sample = {}

        labels_2D = get_labels(pnts, H, W)
        sample.update({"labels_2D": labels_2D.unsqueeze(0)})

        photometric = self.config["augmentation"]["photometric"]
        if ((photometric["enable_train"] and self.action == "training")
                or (photometric["enable_val"]
                    and self.action == "validation")):
            img = imgPhotometric(img)

        homographic = self.config["augmentation"]["homographic"]
        homographic_enabled = (
            (homographic["enable_train"] and self.action == "training")
            or (homographic["enable_val"] and self.action == "validation"))

        if not homographic_enabled:
            img = img[:, :, np.newaxis]
            if self.transform is not None:
                img = self.transform(img)
            sample["image"] = img
            valid_mask = self.compute_valid_mask(torch.tensor([H, W]),
                                                 inv_homography=torch.eye(3))
            sample.update({"valid_mask": valid_mask})
            labels_res = get_label_res(H, W, pnts)
            pnts_post = pnts
        else:
            from utils.utils import homography_scaling_torch as homography_scaling
            from numpy.linalg import inv

            homography = self.sample_homography(
                np.array([2, 2]),
                shift=-1,
                **homographic["params"],
            )

            ##### use inverse from the sample homography
            homography = inv(homography)
            ######

            homography = torch.tensor(homography).float()
            inv_homography = homography.inverse()
            img = torch.from_numpy(img)
            warped_img = self.inv_warp_image(img.squeeze(),
                                             inv_homography,
                                             mode="bilinear")
            warped_img = warped_img.squeeze().numpy()[:, :, np.newaxis]

            warped_pnts = self.warp_points(
                pnts, homography_scaling(homography, H, W))
            warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))

            if self.transform is not None:
                warped_img = self.transform(warped_img)
            sample["image"] = warped_img

            valid_mask = self.compute_valid_mask(
                torch.tensor([H, W]),
                inv_homography=inv_homography,
                erosion_radius=homographic["valid_border_margin"],
            )  # can set to other value
            sample.update({"valid_mask": valid_mask})

            labels_2D = get_labels(warped_pnts, H, W)
            sample.update({"labels_2D": labels_2D.unsqueeze(0)})

            labels_res = get_label_res(H, W, warped_pnts)
            pnts_post = warped_pnts

        if self.gaussian_label:
            from datasets.data_tools import get_labels_bi

            labels_2D_bi = get_labels_bi(pnts_post, H, W)
            labels_gaussian = self.gaussian_blur(squeezeToNumpy(labels_2D_bi))
            sample["labels_2D_gaussian"] = np_to_tensor(labels_gaussian, H, W)

        # add residual
        sample.update({"labels_res": labels_res})

        ### code for warped image
        if self.config["warped_pair"]["enable"]:
            from datasets.data_tools import warpLabels

            homography = self.sample_homography(
                np.array([2, 2]),
                shift=-1,
                **self.config["warped_pair"]["params"])

            ##### use inverse from the sample homography
            homography = np.linalg.inv(homography)
            #####
            inv_homography = np.linalg.inv(homography)

            homography = torch.tensor(homography).type(torch.FloatTensor)
            inv_homography = torch.tensor(inv_homography).type(
                torch.FloatTensor)

            # ``img`` is a tensor only when the homographic branch ran (or
            # self.transform returned one); otherwise it is a numpy array, on
            # which ``.type(torch.FloatTensor)`` failed. as_tensor handles
            # both and returns a float32 tensor unchanged.
            warped_img = torch.as_tensor(img, dtype=torch.float32)
            warped_img = self.inv_warp_image(warped_img.squeeze(),
                                             inv_homography,
                                             mode="bilinear").unsqueeze(0)
            # NOTE(review): elsewhere in this method self.action is compared
            # to "training"/"validation"; with "train"/"val" here this
            # photometric step may never run — confirm the intended values.
            if (self.enable_photo_train == True
                    and self.action == "train") or (self.enable_photo_val
                                                    and self.action == "val"):
                warped_img = imgPhotometric(
                    warped_img.numpy().squeeze())  # numpy array (H, W, 1)
                warped_img = torch.tensor(warped_img, dtype=torch.float32)
            warped_img = warped_img.view(-1, H, W)

            warped_set = warpLabels(pnts, H, W, homography, bilinear=True)
            warped_labels = warped_set["labels"]
            warped_res = warped_set["res"]
            warped_res = warped_res.transpose(1, 2).transpose(0, 1)
            if self.gaussian_label:
                warped_labels_bi = warped_set["labels_bi"]
                warped_labels_gaussian = self.gaussian_blur(
                    squeezeToNumpy(warped_labels_bi))
                warped_labels_gaussian = np_to_tensor(warped_labels_gaussian,
                                                      H, W)
                sample["warped_labels_gaussian"] = warped_labels_gaussian
                sample.update({"warped_labels_bi": warped_labels_bi})

            sample.update({
                "warped_img": warped_img,
                "warped_labels": warped_labels,
                "warped_res": warped_res,
            })

            valid_mask = self.compute_valid_mask(
                torch.tensor([H, W]),
                inv_homography=inv_homography,
                erosion_radius=self.config["warped_pair"]
                ["valid_border_margin"],
            )  # can set to other value
            sample.update({"warped_valid_mask": valid_mask})
            sample.update({
                "homographies": homography,
                "inv_homographies": inv_homography
            })

        if self.getPts:
            sample.update({"pts": pnts})

        return sample