Example #1
def warpLabels(pnts, H, W, homography, bilinear=False):
    from utils.utils import homography_scaling_torch as homography_scaling
    from utils.utils import filter_points
    from utils.utils import warp_points
    if isinstance(pnts, torch.Tensor):
        pnts = pnts.long()
    else:
        pnts = torch.tensor(pnts).long()
    warped_pnts = warp_points(torch.stack((pnts[:, 0], pnts[:, 1]), dim=1),
                              homography_scaling(homography, H,
                                                 W))  # check the (x, y)
    outs = {}
    # warped_pnts
    # print("extrapolate_points!!")

    # ext_points = True
    if bilinear:
        warped_labels_bi = get_labels_bi(warped_pnts, H, W)
        outs['labels_bi'] = warped_labels_bi

    warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))
    warped_labels = scatter_points(warped_pnts, H, W, res_ext=1)

    warped_labels_res = torch.zeros(H, W, 2)
    # quan is a module-level rounding helper (presumably x.round().long(), as in the
    # dataset code further down) used to index the residual map at integer pixels.
    warped_labels_res[
        quan(warped_pnts)[:, 1],
        quan(warped_pnts)[:, 0], :] = warped_pnts - warped_pnts.round()
    # print("res sum: ", (warped_pnts - warped_pnts.round()).sum())
    outs.update({
        'labels': warped_labels,
        'res': warped_labels_res,
        'warped_pnts': warped_pnts
    })
    return outs
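
Every example on this page calls filter_points from utils.utils to drop points that fall outside the image bounds, but the helper itself is not shown. The following is a minimal sketch of what it plausibly does, inferred only from how it is called here (points in (x, y) order, bounds passed as [W, H], an optional return_mask flag); the repository's actual implementation may differ.

import torch

def filter_points(points, shape, return_mask=False):
    # Sketch (assumption): keep points with 0 <= x <= W - 1 and 0 <= y <= H - 1.
    # points: tensor (N, 2) in (x, y) order; shape: tensor [W, H].
    points = points.float()
    shape = shape.float()
    in_range = (points >= 0) & (points <= shape - 1)  # per-coordinate bounds check
    mask = in_range.all(dim=-1)                       # both coordinates must be valid
    if return_mask:
        return points[mask], mask
    return points[mask]
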
Example #2
    def __getitem__(self, index):
        """
        :param index:
        :return:
            labels_2D: tensor(1, H, W)
            image: tensor(1, H, W)
        """

        def imgPhotometric(img):
            """
            :param img:
                numpy (H, W)
            :return:
            """
            augmentation = self.ImgAugTransform(**self.config["augmentation"])
            img = img[:, :, np.newaxis]
            img = augmentation(img)
            cusAug = self.customizedTransform()
            img = cusAug(img, **self.config["augmentation"])
            return img


        from datasets.data_tools import np_to_tensor
        from utils.utils import filter_points
        from utils.var_dim import squeezeToNumpy

        sample = {}
        H, W = self.config["generation"]["image_size"]
        imgs, pnts, evts = synthetic_dataset.generate_random_shape((H, W), 5, None)
        idx = np.random.randint(1000)


        # Only take the last set of points
        pnts = torch.tensor(pnts[-1]).float()
        pnts = torch.stack((pnts[:, 1], pnts[:, 0]), dim=1)  # (x, y)
        pnts = filter_points(pnts, torch.tensor([W, H]))  # points are (x, y), so bounds are [W, H]
        pnts_long = pnts.round().long()


        labels = torch.zeros(H, W)
        labels[pnts_long[:, 1], pnts_long[:, 0]] = 1  # index as [y, x] for (x, y) points
        valid_mask = self.compute_valid_mask(torch.tensor([H, W]), inv_homography=torch.eye(3))


        for i, evt in enumerate(evts):
            evts[i, 0] = imgPhotometric(evt[0, :, :]).squeeze()
            evts[i, 1] = imgPhotometric(evt[1, :, :]).squeeze()
        evts = torch.from_numpy(evts.astype(np.float32))


        # sample.update({"images": imgs})
        sample.update({"valid_mask": valid_mask})
        sample.update({"labels_2D": labels.unsqueeze(0)})
        # sample.update({"points": pnts})
        sample.update({"events": evts})


        return sample
Example #3
def get_labels_bi(warped_pnts, H, W):
    from utils.utils import filter_points
    pnts_ext, res_ext = extrapolate_points(warped_pnts)
    # quan = lambda x: x.long()
    pnts_ext, mask = filter_points(pnts_ext,
                                   torch.tensor([W, H]),
                                   return_mask=True)
    res_ext = res_ext[mask]
    warped_labels_bi = scatter_points(pnts_ext, H, W, res_ext=res_ext)
    return warped_labels_bi
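
Examples #1 and #3 also rely on a module-level scatter_points helper that is not included on this page. The sketch below is an assumption about its behavior, writing either a constant 1 or per-point weights (res_ext) into a (1, H, W) label map at the rounded (x, y) locations; the repository's actual helper may differ.

import torch

def scatter_points(pnts, H, W, res_ext=1):
    # Hypothetical sketch: pnts is a tensor (N, 2) in (x, y) order,
    # res_ext is a scalar (e.g. 1) or a tensor (N,) of bilinear weights.
    labels = torch.zeros(H, W)
    pnts_int = pnts.round().long()
    vals = torch.as_tensor(res_ext, dtype=labels.dtype)
    labels[pnts_int[:, 1], pnts_int[:, 0]] = vals  # index as [y, x]
    return labels.view(-1, H, W)
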
Example #4
def warpLabels(pnts, homography, H, W):
    """
    input:
        pnts: numpy
        homography: numpy
    output:
        warped_pnts: numpy
    """
    import torch
    from utils.utils import warp_points
    from utils.utils import filter_points
    pnts = torch.tensor(pnts).long()
    homography = torch.tensor(homography, dtype=torch.float32)
    warped_pnts = warp_points(torch.stack((pnts[:, 0], pnts[:, 1]), dim=1),
                              homography)  # check the (x, y)
    warped_pnts = filter_points(warped_pnts, torch.tensor([W, H])).round().long()
    return warped_pnts.numpy()
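
A small usage sketch for the warpLabels variant above, assuming the repository's utils package is on the import path; the image size and identity homography are illustrative values only, not taken from the original code.

import numpy as np

# Illustrative call (identity homography, made-up size).
H, W = 120, 160
pnts = np.array([[10., 20.], [50., 80.], [200., 40.]])  # third point lies outside a W x H image
homography = np.eye(3)
warped_pnts = warpLabels(pnts, homography, H, W)
print(warped_pnts.shape)  # out-of-range points are dropped by filter_points
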
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           dist='cos',
                           method='1d',
                           **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        # assert points.dim == 2
        #     print("points: ", points[0])
        #     print("H: ", H)
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred,
                       image_b_pred,
                       matches_a,
                       matches_b,
                       dist='cos',
                       method='1d'):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                matches_a, matches_b, dist=dist, method=method)
        return match_loss

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        ## create_non_correspondences
        #     print("img_b_shape ", img_b_shape)
        #     print("uv_b_matches ", uv_b_matches.shape)
        # print("uv_a: ", uv_to_tuple(uv_a))
        # print("uv_b_non_matches: ", uv_b_non_matches)
        #     print("uv_b_non_matches: ", tensorUv2tuple(uv_b_non_matches))
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred,
                           image_b_pred,
                           non_matches_a,
                           non_matches_b,
                           dist='cos'):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                               non_matches_a.long().squeeze(),
                                                               non_matches_b.long().squeeze(),
                                                               M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts
    # ##### print configs
    # print("num_masked_non_matches_per_match: ", num_masked_non_matches_per_match)
    # print("num_matching_attempts: ", num_matching_attempts)
    # dist = 'cos'
    # print("method: ", method)

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    # print("img_shape: ", img_shape)
    # img_shape_cpu = (Hc.to('cpu'), Wc.to('cpu'))

    # image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # torch [batch_size, H*W, D]
    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(
            0, 1)  # torch [D, H, W] --> [H*W, d]
        descriptors = descriptors.unsqueeze(0)  # torch [1, H*W, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # torch [1, H*W, D]
    # print("image_a_pred: ", image_a_pred.shape)
    image_b_pred = descriptor_reshape(
        descriptors_warped)  # torch [batch_size, H*W, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # print("uv_a: ", uv_a[0])

    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))

    # print("experiment inverse homographies")
    # homographies_H = torch.stack([torch.inverse(H) for H in homographies_H])
    # print("homographies_H: ", homographies_H.shape)
    # homographies_H = torch.inverse(homographies_H)

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')
    #
    # print("uv_b_matches before round: ", uv_b_matches[0])

    uv_b_matches.round_()
    # print("uv_b_matches after round: ", uv_b_matches[0])
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    # choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]

    # crop to the same length
    shuffle = True
    if not shuffle: print("shuffle: ", shuffle)
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    if method == '2d':
        match_loss = get_match_loss(descriptors,
                                    descriptors_warped,
                                    matches_a.to(device),
                                    matches_b.to(device),
                                    dist=dist,
                                    method='2d')
    else:
        match_loss = get_match_loss(image_a_pred,
                                    image_b_pred,
                                    matches_a.long().to(device),
                                    matches_b.long().to(device),
                                    dist=dist)

    # non matches

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred,
                                        image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device),
                                        dist=dist)
    # non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss
Example #6
def descriptor_loss_sparse_reliability(descriptors,
                                       descriptors_warped,
                                       reliability,
                                       reliability_warp,
                                       homographies,
                                       aplosser,
                                       mask_valid=None,
                                       cell_size=8,
                                       device='cpu',
                                       descriptor_dist=4,
                                       lamda_d=250,
                                       lamda_r=1,
                                       num_matching_attempts=1000,
                                       num_masked_non_matches_per_match=10,
                                       dist='cos',
                                       method='1d',
                                       sos=True,
                                       reli_base=0.5,
                                       **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        # assert points.dim == 2
        #     print("points: ", points[0])
        #     print("H: ", H)
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred,
                       image_b_pred,
                       matches_a,
                       matches_b,
                       dist='cos',
                       method='1d',
                       sos=True):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                matches_a, matches_b, dist=dist, method=method, sos=sos)

        # Add SOS Loss
        return match_loss, matches_a_descriptors, matches_b_descriptors

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        ## create_non_correspondences
        #     print("img_b_shape ", img_b_shape)
        #     print("uv_b_matches ", uv_b_matches.shape)
        # print("uv_a: ", uv_to_tuple(uv_a))
        # print("uv_b_non_matches: ", uv_b_non_matches)
        #     print("uv_b_non_matches: ", tensorUv2tuple(uv_b_non_matches))
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred,
                           image_b_pred,
                           non_matches_a,
                           non_matches_b,
                           dist='cos'):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                               non_matches_a.long().squeeze(),
                                                               non_matches_b.long().squeeze(),
                                                               M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss

    # def reliability_loss(descriptors_a, descriptors_b, gt, uv_a, uv_b):
    #     dist = descriptors_a.mm(descriptors_b.T).pow(2)
    #     dist = dist / dist.max()
    #     acc = (dist * gt).sum(-1) # to be tested
    #     # max = dist.max()
    #     # min = dist.min()
    #     # dist_diagonal = dist.diagonal().pow(2)
    #     reli_a = reliability[:, uv_a[:, 1].long(), uv_a[:, 0].long()].transpose(1, 0)
    #     reli_b = reliability_warp[:, uv_b[:, 1].long(), uv_b[:, 0].long()]
    #     # reli = reli_a.mm(reli_b).pow(0.5)
    #     reli = (reli_a + reli_b) / 2
    #     ap_loss = 1 - acc * reli - (1 - reli) * reli_base
    #     # ap_loss = 1 - dist_diagonal * reli
    #     ap_loss = ap_loss.mean()
    #     return ap_loss

    def reliability_loss(descriptors_a,
                         descriptors_b,
                         image_b_pred,
                         reliability_a,
                         reliability_b,
                         uv_a,
                         uv_b,
                         img_b_shape,
                         aplosser,
                         method='1d',
                         device='cpu',
                         reli_base=0.5):
        uv_b = uv_b.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=1,
            img_b_mask=None)
        _, uv_b_non_matches_tuple = create_non_matches(uv_to_tuple(uv_a),
                                                       uv_b_non_matches_tuple,
                                                       1)
        uv_b_non = tuple_to_uv(uv_b_non_matches_tuple).squeeze().transpose(
            1, 0)  # negative sample coordinates

        # generate negative descriptors
        if method == '2d':

            def sampleDescriptors(image_a_pred, matches_a, mode, norm=False):
                image_a_pred = image_a_pred.unsqueeze(0)  # torch [1, D, H, W]
                matches_a.unsqueeze_(0).unsqueeze_(2)
                matches_a_descriptors = torch.nn.functional.grid_sample(
                    image_a_pred, matches_a, mode=mode, align_corners=True)
                matches_a_descriptors = matches_a_descriptors.squeeze(
                ).transpose(0, 1)

                # print("image_a_pred: ", image_a_pred.shape)
                # print("matches_a: ", matches_a.shape)
                # print("matches_a: ", matches_a)
                # print("matches_a_descriptors: ", matches_a_descriptors)
                if norm:
                    dn = torch.norm(matches_a_descriptors, p=2,
                                    dim=1)  # Compute the norm of b_descriptors
                    matches_a_descriptors = matches_a_descriptors.div(
                        torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
                return matches_a_descriptors

            matches_b_non = normPts(
                uv_b_non,
                torch.tensor([img_b_shape[1], img_b_shape[0]]).float())
            descriptors_b_non = sampleDescriptors(image_b_pred,
                                                  matches_b_non.to(device),
                                                  mode='bilinear',
                                                  norm=False)
        else:
            matches_b_non = uv_to_1d(uv_b_non, img_b_shape[1])  # 1d indices of the sampled negatives
            descriptors_b_non = torch.index_select(
                image_b_pred, 1,
                matches_b_non.long().to(device))

        qconf = reliability_a[:, uv_a[:, 1].long(), uv_a[:, 0].long()] + \
                reliability_b[:, uv_b[:, 1].long(), uv_b[:, 0].long()]
        qconf /= 2

        pscores = (descriptors_a * descriptors_b).sum(-1)[:, None]
        nscores = (descriptors_a * descriptors_b_non).sum(-1)[:, None]
        scores = torch.cat((pscores, nscores), dim=1)

        gt = torch.zeros_like(scores, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        ap = aplosser(scores, gt)
        ap_loss = 1 - ap * qconf - (1 - qconf) * reli_base

        return ap_loss.mean()

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts
    # ##### print configs
    # print("num_masked_non_matches_per_match: ", num_masked_non_matches_per_match)
    # print("num_matching_attempts: ", num_matching_attempts)
    # dist = 'cos'
    # print("method: ", method)

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    # print("img_shape: ", img_shape)
    # img_shape_cpu = (Hc.to('cpu'), Wc.to('cpu'))

    # image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # torch [batch_size, H*W, D]
    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(
            0, 1)  # torch [D, H, W] --> [H*W, d]
        descriptors = descriptors.unsqueeze(0)  # torch [1, H*W, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # torch [1, H*W, D]
    # print("image_a_pred: ", image_a_pred.shape)
    image_b_pred = descriptor_reshape(
        descriptors_warped)  # torch [batch_size, H*W, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # print("uv_a: ", uv_a[0])

    homographies_H = scale_homography_torch(
        homographies, img_shape, shift=(-1, -1)
    )  # original scale is for image size, now downscale it to descriptor tensor size

    # print("experiment inverse homographies")
    # homographies_H = torch.stack([torch.inverse(H) for H in homographies_H])
    # print("homographies_H: ", homographies_H.shape)
    # homographies_H = torch.inverse(homographies_H)

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')
    #
    # print("uv_b_matches before round: ", uv_b_matches[0])

    uv_b_matches.round_()
    # print("uv_b_matches after round: ", uv_b_matches[0])
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    # choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]  # uv_a, uv_b_matches are the gt

    # crop to the same length
    shuffle = True
    if not shuffle: print("shuffle: ", shuffle)
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    ## add reliability,

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    if method == '2d':
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            descriptors,
            descriptors_warped,
            matches_a.to(device),
            matches_b.to(device),
            dist=dist,
            method='2d',
            sos=sos)
    else:
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            image_a_pred,
            image_b_pred,
            matches_a.long().to(device),
            matches_b.long().to(device),
            dist=dist,
            sos=sos)

    # reliability loss
    matches_a_descriptors, matches_b_descriptors = matches_a_descriptors.squeeze(
    ), matches_b_descriptors.squeeze()
    # gt = torch.eye(matches_a_descriptors.shape[0]).to(device)
    if method == '2d':
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors,
                                   descriptors_warped,
                                   reliability,
                                   reliability_warp,
                                   uv_a,
                                   uv_b_matches,
                                   img_shape,
                                   aplosser,
                                   method=method,
                                   device=device,
                                   reli_base=reli_base)
    else:
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors,
                                   image_b_pred,
                                   reliability,
                                   reliability_warp,
                                   uv_a,
                                   uv_b_matches,
                                   img_shape,
                                   aplosser,
                                   method=method,
                                   device=device,
                                   reli_base=reli_base)

    # non matches

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred,
                                        image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device),
                                        dist=dist)
    # non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + lamda_r * ap_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss, lamda_r * ap_loss
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[:, 0], uv_tuple[:, 1]])

    def tuple_to_1d(uv_tuple, H):
        return uv_tuple[0] * H + uv_tuple[1]

    def uv_to_1d(points, H):
        # assert points.dim == 2
        #     print("points: ", points[0])
        #     print("H: ", H)
        return points[..., 0] * H + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred, matches_a.long(), matches_b.long())
        return match_loss

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        ## create_non_correspondences
        #     print("img_b_shape ", img_b_shape)
        #     print("uv_b_matches ", uv_b_matches.shape)
        # print("uv_a: ", uv_to_tuple(uv_a))
        # print("uv_b_non_matches: ", uv_b_non_matches)
        #     print("uv_b_non_matches: ", tensorUv2tuple(uv_b_non_matches))
        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
                        PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                   non_matches_a.long().squeeze(), non_matches_b.long().squeeze())
        return non_match_loss

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]

    image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(
        1, 2)  # torch [batch_size, H*W, D]
    # print("image_a_pred: ", image_a_pred.shape)
    image_b_pred = descriptors_warped.view(1, -1, Hc * Wc).transpose(
        1, 2)  # torch [batch_size, H*W, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True)
    # print("uv_a: ", uv_a.shape)

    homographies_H = scale_homography_torch(homographies,
                                            (Hc, Wc),  # rescale the homography to the descriptor map size
                                            shift=(-1, -1))

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H,
                                                     uv=True)
    uv_b_matches = uv_b_matches.squeeze(0)
    # print("uv_b_matches: ", uv_b_matches.shape)

    # filtering out of range points
    # choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop to the same length
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = uv_to_1d(uv_a, Hc)
    matches_b = uv_to_1d(uv_b_matches, Hc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    match_loss = get_match_loss(image_a_pred, image_b_pred, matches_a,
                                matches_b)

    # non matches
    img_b_shape = (Hc, Wc)

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_b_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Hc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Hc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a, non_matches_b)
    non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + non_match_loss
    return loss
import torch


def get_coor_cells(Hc, Wc, cell_size, device='cpu', uv=False):
    # NOTE: this snippet originally started mid-function; the head of get_coor_cells is
    # reconstructed here from how it is called below so the fragment can run on its own.
    # It enumerates one (row, col) coordinate per cell of the Hc x Wc descriptor grid
    # (cell_size is kept only for interface compatibility).
    coor_cells = torch.stack(torch.meshgrid(torch.arange(Hc), torch.arange(Wc)),
                             dim=2)
    coor_cells = coor_cells.type(torch.FloatTensor).to(device)
    coor_cells = coor_cells.view(-1, 2)
    if uv:
        coor_cells = torch.stack((coor_cells[:, 1], coor_cells[:, 0]),
                                 dim=1)  # (y, x) to (x, y)

    return coor_cells


# NOTE: Hc, Wc, cell_size and device come from the surrounding script in the original
# code; placeholder values are used here so the fragment runs standalone.
Hc, Wc, cell_size, device = 30, 40, 8, 'cpu'

coor_cells = get_coor_cells(Hc,
                            Wc,
                            cell_size=cell_size,
                            device=device,
                            uv=True)
print("coor_cells: ", coor_cells)
print("coor_cells: ", coor_cells.shape)

from utils.utils import filter_points
filtered_points, mask = filter_points(coor_cells,
                                      torch.tensor([Wc, Hc]),
                                      return_mask=True)


def warp_coor_cells_with_homographies(coor_cells, homographies, uv=False, device='cpu'):
    from utils.utils import warp_points
    # warped_coor_cells = warp_points(coor_cells.view([-1, 2]), homographies, device)
    #     warped_coor_cells = normPts(coor_cells.view([-1, 2]), shape)
    warped_coor_cells = coor_cells
    if not uv:
        warped_coor_cells = torch.stack(
            (warped_coor_cells[:, 1], warped_coor_cells[:, 0]),
            dim=1)  # (y, x) to (x, y)

    warped_coor_cells = warp_points(warped_coor_cells, homographies, device)

    return warped_coor_cells
Example #9
def quadruplet_descriptor_loss_sparse(descriptors,
                                      descriptors_warped,
                                      homographies,
                                      mask_valid=None,
                                      cell_size=8,
                                      device='cpu',
                                      descriptor_dist=4,
                                      lamda_d=250,
                                      num_matching_attempts=1000,
                                      num_masked_non_matches_per_match=10,
                                      dist='cos',
                                      method='1d',
                                      **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts
    # ##### print configs
    # print("num_masked_non_matches_per_match: ", num_masked_non_matches_per_match)
    # print("num_matching_attempts: ", num_matching_attempts)
    # dist = 'cos'
    # print("method: ", method)

    Hc, Wc, Dim = descriptors.shape[1], descriptors.shape[
        2], descriptors.shape[0]
    img_shape = (Hc, Wc)
    # print("img_shape: ", img_shape)
    # img_shape_cpu = (Hc.to('cpu'), Wc.to('cpu'))

    # image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # torch [batch_size, H*W, D]

    image_a_pred = descriptor_reshape(descriptors, Hc, Wc)  # torch [1, H*W, D]
    # print("image_a_pred: ", image_a_pred.shape)
    image_b_pred = descriptor_reshape(descriptors_warped, Hc,
                                      Wc)  # torch [batch_size, H*W, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # print("uv_a: ", uv_a[0])

    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))

    # print("experiment inverse homographies")
    # homographies_H = torch.stack([torch.inverse(H) for H in homographies_H])
    # print("homographies_H: ", homographies_H.shape)
    # homographies_H = torch.inverse(homographies_H)

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')
    #
    # print("uv_b_matches before round: ", uv_b_matches[0])

    uv_b_matches.round_()
    # print("uv_b_matches after round: ", uv_b_matches[0])
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    # choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]

    # crop to the same length
    shuffle = True
    if not shuffle: print("shuffle: ", shuffle)
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
    matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    # non matches
    # get non matches correspondence
    uv_a_tuple, uv_b_matches_tuple, uv_b_non_matches_tuple = get_first_term_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    long_matches_b = tuple_to_1d(uv_b_matches_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_matches_a_descriptors = torch.index_select(
        image_a_pred, 1,
        non_matches_a.to(device).long().squeeze()).squeeze()
    non_matches_b_descriptors = torch.index_select(
        image_b_pred, 1,
        non_matches_b.to(device).long().squeeze()).squeeze()
    long_matches_b_descriptors = torch.index_select(
        image_b_pred, 1,
        long_matches_b.to(device).long().squeeze()).squeeze()

    all_neg_pairs = (non_matches_a_descriptors -
                     non_matches_b_descriptors).pow(2).sum(dim=-1)
    all_pos_pairs = (non_matches_a_descriptors -
                     long_matches_b_descriptors).pow(2).sum(dim=-1)

    alpha = (all_neg_pairs.sum() - all_pos_pairs.sum()) / (
        num_matching_attempts * num_masked_non_matches_per_match)

    first_term = torch.clamp(all_pos_pairs - all_neg_pairs + alpha,
                             min=0).sum() / (num_matching_attempts *
                                             num_masked_non_matches_per_match)

    pos_matches_a = non_matches_a_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))
    pos_matches_b = long_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))

    neg_matches_b = non_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match, 1,
         Dim)).repeat(1, 1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))
    neg_matches_b_permute = non_matches_b_descriptors.reshape(
        (num_matching_attempts, 1, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1, 1).reshape(
             (-1, Dim))

    pos_matches = (pos_matches_a - pos_matches_b).pow(2).sum(dim=-1)
    neg_matches = (neg_matches_b - neg_matches_b_permute).pow(2).sum(dim=-1)

    match_mask = (torch.ones(num_masked_non_matches_per_match,
                             num_masked_non_matches_per_match) -
                  torch.eye(num_masked_non_matches_per_match)).repeat(
                      num_matching_attempts, 1, 1).flatten().to(device)

    second_term = (torch.clamp(
        (pos_matches - neg_matches) * match_mask + 0.5 * alpha, min=0) /
                   (num_masked_non_matches_per_match *
                    (num_masked_non_matches_per_match - 1) *
                    num_matching_attempts)).sum()

    # non_match_loss = non_match_loss.mean()

    loss = lamda_d * first_term + second_term
    return loss, lamda_d * first_term, second_term
    def __getitem__(self, index):
        """
        :param index:
        :return:
            labels_2D: tensor(1, H, W)
            image: tensor(1, H, W)
        """
        def checkSat(img, name=""):
            if img.max() > 1:
                print(name, img.max())
            elif img.min() < 0:
                print(name, img.min())

        def imgPhotometric(img):
            """

            :param img:
                numpy (H, W)
            :return:
            """
            augmentation = self.ImgAugTransform(**self.config["augmentation"])
            img = img[:, :, np.newaxis]
            img = augmentation(img)
            cusAug = self.customizedTransform()
            img = cusAug(img, **self.config["augmentation"])
            return img

        def get_labels(pnts, H, W):
            labels = torch.zeros(H, W)
            # print('--2', pnts, pnts.size())
            # pnts_int = torch.min(pnts.round().long(), torch.tensor([[H-1, W-1]]).long())
            pnts_int = torch.min(pnts.round().long(),
                                 torch.tensor([[W - 1, H - 1]]).long())
            # print('--3', pnts_int, pnts_int.size())
            labels[pnts_int[:, 1], pnts_int[:, 0]] = 1
            return labels

        def get_label_res(H, W, pnts):
            quan = lambda x: x.round().long()
            labels_res = torch.zeros(H, W, 2)
            # pnts_int = torch.min(pnts.round().long(), torch.tensor([[H-1, W-1]]).long())

            labels_res[quan(pnts)[:, 1],
                       quan(pnts)[:, 0], :] = pnts - pnts.round()
            # print("pnts max: ", quan(pnts).max(dim=0))
            # print("labels_res: ", labels_res.shape)
            labels_res = labels_res.transpose(1, 2).transpose(0, 1)
            return labels_res

        from datasets.data_tools import np_to_tensor
        from utils.utils import filter_points
        from utils.var_dim import squeezeToNumpy

        sample = self.samples[index]
        img = load_as_float(sample["image"])
        H, W = img.shape[0], img.shape[1]
        self.H = H
        self.W = W
        pnts = np.load(sample["points"])  # (y, x)
        pnts = torch.tensor(pnts).float()
        pnts = torch.stack((pnts[:, 1], pnts[:, 0]), dim=1)  # (x, y)
        pnts = filter_points(pnts, torch.tensor([W, H]))
        sample = {}

        # print('pnts: ', pnts[:5])
        # print('--1', pnts)
        labels_2D = get_labels(pnts, H, W)
        sample.update({"labels_2D": labels_2D.unsqueeze(0)})

        # assert Hc == round(Hc) and Wc == round(Wc), "Input image size not fit in the block size"
        if (self.config["augmentation"]["photometric"]["enable_train"]
                and self.action == "training") or (
                    self.config["augmentation"]["photometric"]["enable_val"]
                    and self.action == "validation"):
            # print('>>> Photometric aug enabled for %s.'%self.action)
            # augmentation = self.ImgAugTransform(**self.config["augmentation"])
            img = imgPhotometric(img)
        else:
            # print('>>> Photometric aug disabled for %s.'%self.action)
            pass

        if not ((self.config["augmentation"]["homographic"]["enable_train"]
                 and self.action == "training") or
                (self.config["augmentation"]["homographic"]["enable_val"]
                 and self.action == "validation")):
            # print('<<< Homograpy aug disabled for %s.'%self.action)
            img = img[:, :, np.newaxis]
            # labels = labels.view(-1,H,W)
            if self.transform is not None:
                img = self.transform(img)
            sample["image"] = img
            # sample = {'image': img, 'labels_2D': labels}
            valid_mask = self.compute_valid_mask(torch.tensor([H, W]),
                                                 inv_homography=torch.eye(3))
            sample.update({"valid_mask": valid_mask})
            labels_res = get_label_res(H, W, pnts)
            pnts_post = pnts
            # pnts_for_gaussian = pnts
        else:
            # print('>>> Homograpy aug enabled for %s.'%self.action)
            # img_warp = img
            from utils.utils import homography_scaling_torch as homography_scaling
            from numpy.linalg import inv

            homography = self.sample_homography(
                np.array([2, 2]),
                shift=-1,
                **self.config["augmentation"]["homographic"]["params"],
            )

            ##### use inverse from the sample homography
            homography = inv(homography)
            ######

            homography = torch.tensor(homography).float()
            inv_homography = homography.inverse()
            img = torch.from_numpy(img)
            warped_img = self.inv_warp_image(img.squeeze(),
                                             inv_homography,
                                             mode="bilinear")
            warped_img = warped_img.squeeze().numpy()
            warped_img = warped_img[:, :, np.newaxis]

            # labels = torch.from_numpy(labels)
            # warped_labels = self.inv_warp_image(labels.squeeze(), inv_homography, mode='nearest').unsqueeze(0)
            warped_pnts = self.warp_points(
                pnts, homography_scaling(homography, H, W))
            warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))
            # pnts = warped_pnts[:, [1, 0]]
            # pnts_for_gaussian = warped_pnts
            # warped_labels = torch.zeros(H, W)
            # warped_labels[warped_pnts[:, 1], warped_pnts[:, 0]] = 1
            # warped_labels = warped_labels.view(-1, H, W)

            if self.transform is not None:
                warped_img = self.transform(warped_img)
            # sample = {'image': warped_img, 'labels_2D': warped_labels}
            sample["image"] = warped_img

            valid_mask = self.compute_valid_mask(
                torch.tensor([H, W]),
                inv_homography=inv_homography,
                erosion_radius=self.config["augmentation"]["homographic"]
                ["valid_border_margin"],
            )  # can set to other value
            sample.update({"valid_mask": valid_mask})

            labels_2D = get_labels(warped_pnts, H, W)
            sample.update({"labels_2D": labels_2D.unsqueeze(0)})

            labels_res = get_label_res(H, W, warped_pnts)
            pnts_post = warped_pnts

        if self.gaussian_label:
            # warped_labels_gaussian = get_labels_gaussian(pnts)
            from datasets.data_tools import get_labels_bi

            labels_2D_bi = get_labels_bi(pnts_post, H, W)

            labels_gaussian = self.gaussian_blur(squeezeToNumpy(labels_2D_bi))
            labels_gaussian = np_to_tensor(labels_gaussian, H, W)
            sample["labels_2D_gaussian"] = labels_gaussian

            # add residua

        sample.update({"labels_res": labels_res})

        ### code for warped image
        if self.config["warped_pair"]["enable"]:
            from datasets.data_tools import warpLabels

            homography = self.sample_homography(
                np.array([2, 2]),
                shift=-1,
                **self.config["warped_pair"]["params"])

            ##### use inverse from the sample homography
            homography = np.linalg.inv(homography)
            #####
            inv_homography = np.linalg.inv(homography)

            homography = torch.tensor(homography).type(torch.FloatTensor)
            inv_homography = torch.tensor(inv_homography).type(
                torch.FloatTensor)

            # photometric augmentation from original image

            # warp original image
            warped_img = img.type(torch.FloatTensor)
            warped_img = self.inv_warp_image(warped_img.squeeze(),
                                             inv_homography,
                                             mode="bilinear").unsqueeze(0)
            if (self.enable_photo_train
                    and self.action == "train") or (self.enable_photo_val
                                                    and self.action == "val"):
                warped_img = imgPhotometric(
                    warped_img.numpy().squeeze())  # numpy array (H, W, 1)
                warped_img = torch.tensor(warped_img, dtype=torch.float32)
                pass
            warped_img = warped_img.view(-1, H, W)

            # warped_labels = warpLabels(pnts, H, W, homography)
            warped_set = warpLabels(pnts, H, W, homography, bilinear=True)
            warped_labels = warped_set["labels"]
            warped_res = warped_set["res"]
            warped_res = warped_res.transpose(1, 2).transpose(0, 1)
            # print("warped_res: ", warped_res.shape)
            if self.gaussian_label:
                # print("do gaussian labels!")
                # warped_labels_gaussian = get_labels_gaussian(warped_set['warped_pnts'].numpy())
                # warped_labels_bi = self.inv_warp_image(labels_2D.squeeze(), inv_homography, mode='nearest').unsqueeze(0) # bilinear, nearest
                warped_labels_bi = warped_set["labels_bi"]
                warped_labels_gaussian = self.gaussian_blur(
                    squeezeToNumpy(warped_labels_bi))
                warped_labels_gaussian = np_to_tensor(warped_labels_gaussian,
                                                      H, W)
                sample["warped_labels_gaussian"] = warped_labels_gaussian
                sample.update({"warped_labels_bi": warped_labels_bi})

            sample.update({
                "warped_img": warped_img,
                "warped_labels": warped_labels,
                "warped_res": warped_res,
            })

            # print('erosion_radius', self.config['warped_pair']['valid_border_margin'])
            valid_mask = self.compute_valid_mask(
                torch.tensor([H, W]),
                inv_homography=inv_homography,
                erosion_radius=self.config["warped_pair"]
                ["valid_border_margin"],
            )  # can set to other value
            sample.update({"warped_valid_mask": valid_mask})
            sample.update({
                "homographies": homography,
                "inv_homographies": inv_homography
            })

        # labels = self.labels2Dto3D(self.cell_size, labels)
        # labels = torch.from_numpy(labels[np.newaxis,:,:])
        # input.update({'labels': labels})

        ### code for warped image

        # if self.config['gaussian_label']['enable']:
        #     heatmaps = np.zeros((H, W))
        #     # for center in pnts_int.numpy():
        #     for center in pnts[:, [1, 0]].numpy():
        #         # print("put points: ", center)
        #         heatmaps = self.putGaussianMaps(center, heatmaps)
        #     # import matplotlib.pyplot as plt
        #     # plt.figure(figsize=(5, 10))
        #     # plt.subplot(211)
        #     # plt.imshow(heatmaps)
        #     # plt.colorbar()
        #     # plt.subplot(212)
        #     # plt.imshow(np.squeeze(warped_labels.numpy()))
        #     # plt.show()
        #     # import time
        #     # time.sleep(500)
        #     # results = self.pool.map(self.putGaussianMaps_par, warped_pnts.numpy())

        #     warped_labels_gaussian = torch.from_numpy(heatmaps).view(-1, H, W)
        #     warped_labels_gaussian[warped_labels_gaussian>1.] = 1.

        #     sample['labels_2D_gaussian'] = warped_labels_gaussian

        if self.getPts:
            sample.update({"pts": pnts})

        return sample