Example #1
    def batch_extract_features(self, desc, heatmap_nms_batch, residual):
        # extract pts, residuals for pts, descriptors
        """
        return: -- type: tensorFloat
          pts: tensor [batch, N, 2] (no grad)  (x, y)
          pts_offset: tensor [batch, N, 2] (grad) (x, y)
          pts_desc: tensor [batch, N, 256] (grad)
        """
        batch_size = heatmap_nms_batch.shape[0]

        pts_int, pts_offset, pts_desc = [], [], []
        pts_idx = heatmap_nms_batch[...].nonzero()  # [N, 4(batch, 0, y, x)]
        for i in range(batch_size):
            mask_b = (pts_idx[:, 0] == i)  # first column == batch
            pts_int_b = pts_idx[mask_b][:, 2:].float()  # (y, x) as FloatTensor
            pts_int_b = pts_int_b[:, [1, 0]]  # reorder to tensor [N, 2] as (x, y)
            res_b = residual[mask_b]
            pts_b = pts_int_b + res_b  # sub-pixel points; offsets carry the gradient
            # extract desc
            pts_desc_b = self.sample_desc_from_points(desc[i].unsqueeze(0),
                                                      pts_b).squeeze(0)
            # print("pts_desc_b: ", pts_desc_b.shape)
            # get random shuffle
            from utils.utils import crop_or_pad_choice
            choice = crop_or_pad_choice(pts_int_b.shape[0],
                                        out_num_points=self.out_num_points,
                                        shuffle=True)
            choice = torch.tensor(choice)
            pts_int.append(pts_int_b[choice])
            pts_offset.append(res_b[choice])
            pts_desc.append(pts_desc_b[choice])

        pts_int = torch.stack(pts_int, dim=0)
        pts_offset = torch.stack(pts_offset, dim=0)
        pts_desc = torch.stack(pts_desc, dim=0)
        return {
            'pts_int': pts_int,
            'pts_offset': pts_offset,
            'pts_desc': pts_desc
        }
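
A minimal calling sketch for the extractor above. The shapes and the `model` object are illustrative assumptions: the method is expected to live on a module that defines `out_num_points` and `sample_desc_from_points`, and `residual` is assumed to hold one (x, y) offset per detected point across the whole batch.

import torch

batch, Hc, Wc, D = 2, 30, 40, 256
heatmap_nms = (torch.rand(batch, 1, Hc, Wc) > 0.99).float()  # sparse NMS keypoint mask
residual = torch.zeros(int(heatmap_nms.sum().item()), 2)     # one (x, y) offset per point
desc = torch.rand(batch, D, Hc, Wc)                          # dense descriptor map

out = model.batch_extract_features(desc, heatmap_nms, residual)
print(out['pts_int'].shape, out['pts_offset'].shape, out['pts_desc'].shape)
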
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           dist='cos',
                           method='1d',
                           **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        # flatten (x, y) coordinates to 1-D row-major indices over width W
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred,
                       image_b_pred,
                       matches_a,
                       matches_b,
                       dist='cos',
                       method='1d'):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                matches_a, matches_b, dist=dist, method=method)
        return match_loss

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred,
                           image_b_pred,
                           non_matches_a,
                           non_matches_b,
                           dist='cos'):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                               non_matches_a.long().squeeze(),
                                                               non_matches_b.long().squeeze(),
                                                               M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    # print("img_shape: ", img_shape)
    # img_shape_cpu = (Hc.to('cpu'), Wc.to('cpu'))

    # image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # torch [batch_size, H*W, D]
    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(
            0, 1)  # torch [D, Hc, Wc] --> [Hc*Wc, D]
        descriptors = descriptors.unsqueeze(0)  # torch [1, Hc*Wc, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # torch [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped)  # torch [1, Hc*Wc, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')

    homographies_H = scale_homography_torch(
        homographies, img_shape,
        shift=(-1, -1))  # rescale homography from image size to descriptor map size

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')

    uv_b_matches.round_()  # snap warped coordinates to integer cell locations
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]

    # crop or pad the match list to a fixed length
    shuffle = True
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    if method == '2d':
        match_loss = get_match_loss(descriptors,
                                    descriptors_warped,
                                    matches_a.to(device),
                                    matches_b.to(device),
                                    dist=dist,
                                    method='2d')
    else:
        match_loss = get_match_loss(image_a_pred,
                                    image_b_pred,
                                    matches_a.long().to(device),
                                    matches_b.long().to(device),
                                    dist=dist)

    # non matches

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred,
                                        image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device),
                                        dist=dist)

    loss = lamda_d * match_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss
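
A minimal calling sketch with illustrative shapes. It assumes the module-level helpers this function relies on (get_coor_cells, scale_homography_torch, warp_coor_cells_with_homographies, create_non_matches, correspondence_finder, PixelwiseContrastiveLoss) are importable from the surrounding repo; the loss takes unbatched [D, Hc, Wc] descriptor maps and a batch of 3x3 homographies.

import torch

D, Hc, Wc = 256, 30, 40
descriptors = torch.rand(D, Hc, Wc, requires_grad=True)
descriptors_warped = torch.rand(D, Hc, Wc, requires_grad=True)
homographies = torch.eye(3).unsqueeze(0)  # identity warp, batch of 1

loss, match_term, non_match_term = descriptor_loss_sparse(
    descriptors, descriptors_warped, homographies,
    num_matching_attempts=100, num_masked_non_matches_per_match=10,
    dist='cos', method='1d')
loss.backward()
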
Example #3
def descriptor_loss_sparse_reliability(descriptors,
                                       descriptors_warped,
                                       reliability,
                                       reliability_warp,
                                       homographies,
                                       aplosser,
                                       mask_valid=None,
                                       cell_size=8,
                                       device='cpu',
                                       descriptor_dist=4,
                                       lamda_d=250,
                                       lamda_r=1,
                                       num_matching_attempts=1000,
                                       num_masked_non_matches_per_match=10,
                                       dist='cos',
                                       method='1d',
                                       sos=True,
                                       reli_base=0.5,
                                       **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        else:
            return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        # flatten (x, y) coordinates to 1-D row-major indices over width W
        if uv:
            return points[..., 0] + points[..., 1] * W
        else:
            return points[..., 0] * W + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred,
                       image_b_pred,
                       matches_a,
                       matches_b,
                       dist='cos',
                       method='1d',
                       sos=True):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred,
                matches_a, matches_b, dist=dist, method=method, sos=sos)

        # descriptors are returned so the reliability (AP) loss can reuse them
        return match_loss, matches_a_descriptors, matches_b_descriptors

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred,
                           image_b_pred,
                           non_matches_a,
                           non_matches_b,
                           dist='cos'):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                               non_matches_a.long().squeeze(),
                                                               non_matches_b.long().squeeze(),
                                                               M=0.2, invert=True, dist=dist)
        non_match_loss = non_match_loss.sum() / (num_hard_negatives + 1)
        return non_match_loss


    def reliability_loss(descriptors_a,
                         descriptors_b,
                         image_b_pred,
                         reliability_a,
                         reliability_b,
                         uv_a,
                         uv_b,
                         img_b_shape,
                         aplosser,
                         method='1d',
                         device='cpu',
                         reli_base=0.5):
        uv_b = uv_b.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=1,
            img_b_mask=None)
        _, uv_b_non_matches_tuple = create_non_matches(uv_to_tuple(uv_a),
                                                       uv_b_non_matches_tuple,
                                                       1)
        uv_b_non = tuple_to_uv(uv_b_non_matches_tuple).squeeze().transpose(
            1, 0)  # negative sample coordinates

        # generate negative descriptors
        if method == '2d':

            def sampleDescriptors(image_a_pred, matches_a, mode, norm=False):
                image_a_pred = image_a_pred.unsqueeze(0)  # torch [1, D, H, W]
                matches_a.unsqueeze_(0).unsqueeze_(2)
                matches_a_descriptors = torch.nn.functional.grid_sample(
                    image_a_pred, matches_a, mode=mode, align_corners=True)
                matches_a_descriptors = matches_a_descriptors.squeeze(
                ).transpose(0, 1)

                # print("image_a_pred: ", image_a_pred.shape)
                # print("matches_a: ", matches_a.shape)
                # print("matches_a: ", matches_a)
                # print("matches_a_descriptors: ", matches_a_descriptors)
                if norm:
                    dn = torch.norm(matches_a_descriptors, p=2,
                                    dim=1)  # Compute the norm of b_descriptors
                    matches_a_descriptors = matches_a_descriptors.div(
                        torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
                return matches_a_descriptors

            matches_b_non = normPts(
                uv_b_non,
                torch.tensor([img_b_shape[1], img_b_shape[0]]).float())
            descriptors_b_non = sampleDescriptors(image_b_pred,
                                                  matches_b_non.to(device),
                                                  mode='bilinear',
                                                  norm=False)
        else:
            # 1-D indices of the sampled negative coordinates
            matches_b_non = uv_to_1d(uv_b_non, img_b_shape[1])
            descriptors_b_non = torch.index_select(
                image_b_pred, 1,
                matches_b_non.long().to(device)).squeeze(0)  # [N, D]

        qconf = reliability_a[:, uv_a[:, 1].long(), uv_a[:, 0].long()] + \
                reliability_b[:, uv_b[:, 1].long(), uv_b[:, 0].long()]
        qconf /= 2

        pscores = (descriptors_a * descriptors_b).sum(-1)[:, None]
        nscores = (descriptors_a * descriptors_b_non).sum(-1)[:, None]
        scores = torch.cat((pscores, nscores), dim=1)

        gt = torch.zeros_like(scores, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1

        ap = aplosser(scores, gt)
        ap_loss = 1 - ap * qconf - (1 - qconf) * reli_base

        return ap_loss.mean()
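
    # Note on the AP term above: with per-pixel reliability q (qconf) and ranking
    # quality AP from `aplosser`, the cost 1 - AP*q - (1 - q)*reli_base follows the
    # R2D2-style reliability loss, so pixels flagged as unreliable fall back to the
    # constant base cost instead of being forced to rank well.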

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    # print("img_shape: ", img_shape)
    # img_shape_cpu = (Hc.to('cpu'), Wc.to('cpu'))

    # image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)  # torch [batch_size, H*W, D]
    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(
            0, 1)  # torch [D, Hc, Wc] --> [Hc*Wc, D]
        descriptors = descriptors.unsqueeze(0)  # torch [1, Hc*Wc, D]
        return descriptors

    image_a_pred = descriptor_reshape(descriptors)  # torch [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped)  # torch [1, Hc*Wc, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')

    homographies_H = scale_homography_torch(
        homographies, img_shape, shift=(-1, -1)
    )  # original scale is for image size, now downscale it to descriptor tensor size

    # print("experiment inverse homographies")
    # homographies_H = torch.stack([torch.inverse(H) for H in homographies_H])
    # print("homographies_H: ", homographies_H.shape)
    # homographies_H = torch.inverse(homographies_H)

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')

    uv_b_matches.round_()  # snap warped coordinates to integer cell locations
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]  # uv_a, uv_b_matches are the gt

    # crop or pad the match list to a fixed length
    shuffle = True
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    ## reliability weighting is applied via the AP loss below

    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    if method == '2d':
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            descriptors,
            descriptors_warped,
            matches_a.to(device),
            matches_b.to(device),
            dist=dist,
            method='2d',
            sos=sos)
    else:
        match_loss, matches_a_descriptors, matches_b_descriptors = get_match_loss(
            image_a_pred,
            image_b_pred,
            matches_a.long().to(device),
            matches_b.long().to(device),
            dist=dist,
            sos=sos)

    # reliability (AP) loss
    matches_a_descriptors = matches_a_descriptors.squeeze()
    matches_b_descriptors = matches_b_descriptors.squeeze()
    if method == '2d':
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors,
                                   descriptors_warped,
                                   reliability,
                                   reliability_warp,
                                   uv_a,
                                   uv_b_matches,
                                   img_shape,
                                   aplosser,
                                   method=method,
                                   device=device,
                                   reli_base=reli_base)
    else:
        ap_loss = reliability_loss(matches_a_descriptors,
                                   matches_b_descriptors,
                                   image_b_pred,
                                   reliability,
                                   reliability_warp,
                                   uv_a,
                                   uv_b_matches,
                                   img_shape,
                                   aplosser,
                                   method=method,
                                   device=device,
                                   reli_base=reli_base)

    # non matches

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred,
                                        image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device),
                                        dist=dist)

    loss = lamda_d * match_loss + lamda_r * ap_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss, lamda_r * ap_loss
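
A minimal calling sketch for the reliability variant, again with illustrative shapes and the same assumed helpers. The repo presumably supplies a real average-precision criterion (in the spirit of R2D2's APLoss); the `aplosser` below is a toy stand-in that just maps (scores, gt) to one value per row so the sketch runs end to end.

import torch

# toy stand-in for an AP criterion: per row, the softmax mass on the positive column
def aplosser(scores, gt):
    probs = torch.softmax(scores, dim=1)
    return (probs * gt.float()).sum(dim=1)

D, Hc, Wc = 128, 30, 40
descriptors = torch.rand(D, Hc, Wc, requires_grad=True)
descriptors_warped = torch.rand(D, Hc, Wc, requires_grad=True)
reliability = torch.rand(1, Hc, Wc)       # per-pixel confidence, image a
reliability_warp = torch.rand(1, Hc, Wc)  # per-pixel confidence, warped image
homographies = torch.eye(3).unsqueeze(0)  # identity warp, batch of 1

loss, match_t, non_match_t, ap_t = descriptor_loss_sparse_reliability(
    descriptors, descriptors_warped, reliability, reliability_warp,
    homographies, aplosser, num_matching_attempts=100)
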
Example #4
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """
    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, H):
        return uv_tuple[0] * H + uv_tuple[1]

    def uv_to_1d(points, H):
        # flatten (u, v) coordinates to 1-D indices as u * H + v
        return points[..., 0] * H + points[..., 1]

    ## calculate matches loss
    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b):
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            PixelwiseContrastiveLoss.match_loss(image_a_pred, image_b_pred, matches_a.long(), matches_b.long())
        return match_loss

    def get_non_matches_corr(img_b_shape,
                             uv_a,
                             uv_b_matches,
                             num_masked_non_matches_per_match=10):
        ## sample non matches
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_matches_tuple = uv_to_tuple(uv_b_matches)
        uv_b_non_matches_tuple = correspondence_finder.create_non_correspondences(
            uv_b_matches_tuple,
            img_b_shape,
            num_non_matches_per_match=num_masked_non_matches_per_match,
            img_b_mask=None)

        uv_a_tuple, uv_b_non_matches_tuple = \
            create_non_matches(uv_to_tuple(uv_a), uv_b_non_matches_tuple, num_masked_non_matches_per_match)
        return uv_a_tuple, uv_b_non_matches_tuple

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b):
        ## non matches loss
        non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors = \
                        PixelwiseContrastiveLoss.non_match_descriptor_loss(image_a_pred, image_b_pred,
                                                   non_matches_a.long().squeeze(), non_matches_b.long().squeeze())
        return non_match_loss

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]

    image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(
        1, 2)  # torch [1, Hc*Wc, D]
    image_b_pred = descriptors_warped.view(1, -1, Hc * Wc).transpose(
        1, 2)  # torch [1, Hc*Wc, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True)

    homographies_H = scale_homography_torch(
        homographies, (Hc, Wc),
        shift=(-1, -1))  # rescale homography from image size to descriptor map size

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H,
                                                     uv=True)
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop to the same length
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = uv_to_1d(uv_a, Hc)
    matches_b = uv_to_1d(uv_b_matches, Hc)

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    match_loss = get_match_loss(image_a_pred, image_b_pred, matches_a,
                                matches_b)

    # non matches
    img_b_shape = (Hc, Wc)

    # get non matches correspondence
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_b_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Hc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Hc)

    # print("non_matches_a: ", non_matches_a)
    # print("non_matches_b: ", non_matches_b)

    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a, non_matches_b)
    non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + non_match_loss
    return loss
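
Note that this variant flattens (u, v) coordinates as u * Hc + v, while Example #1 uses u + v * Wc; each example is internally consistent, but the two conventions index the descriptor grid differently. A quick illustrative check of the two formulas:

import torch

Hc, Wc = 30, 40
uv = torch.tensor([[3, 5]])                   # one (u, v) point
idx_row_major = uv[..., 0] + uv[..., 1] * Wc  # Example #1: u + v*W  -> 203
idx_col_major = uv[..., 0] * Hc + uv[..., 1]  # this example: u*H + v -> 95
print(idx_row_major.item(), idx_col_major.item())
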
Example #5
def quadruplet_descriptor_loss_sparse(descriptors,
                                      descriptors_warped,
                                      homographies,
                                      mask_valid=None,
                                      cell_size=8,
                                      device='cpu',
                                      descriptor_dist=4,
                                      lamda_d=250,
                                      num_matching_attempts=1000,
                                      num_masked_non_matches_per_match=10,
                                      dist='cos',
                                      method='1d',
                                      **config):
    """
    consider batches of descriptors
    :param descriptors:
        Output from descriptor head
        tensor [descriptors, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [descriptors, Hc, Wc]
    """

    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    Dim, Hc, Wc = descriptors.shape  # descriptor dim, map height, map width
    img_shape = (Hc, Wc)


    image_a_pred = descriptor_reshape(descriptors, Hc, Wc)  # torch [1, Hc*Wc, D]
    image_b_pred = descriptor_reshape(descriptors_warped, Hc, Wc)  # torch [1, Hc*Wc, D]

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')

    homographies_H = scale_homography_torch(
        homographies, img_shape,
        shift=(-1, -1))  # rescale homography from image size to descriptor map size

    # print("experiment inverse homographies")
    # homographies_H = torch.stack([torch.inverse(H) for H in homographies_H])
    # print("homographies_H: ", homographies_H.shape)
    # homographies_H = torch.inverse(homographies_H)

    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H.to('cpu'),
                                                     uv=True,
                                                     device='cpu')

    uv_b_matches.round_()  # snap warped coordinates to integer cell locations
    uv_b_matches = uv_b_matches.squeeze(0)

    # filter out-of-range points

    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]).to(device='cpu'),
                                       return_mask=True)
    # print ("pos mask sum: ", mask.sum())
    uv_a = uv_a[mask]

    # crop or pad the match list to a fixed length
    shuffle = True
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=shuffle)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
    matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())

    # print("matches_a: ", matches_a.shape)
    # print("matches_b: ", matches_b.shape)
    # print("matches_b max: ", matches_b.max())

    # non matches
    # get non matches correspondence
    uv_a_tuple, uv_b_matches_tuple, uv_b_non_matches_tuple = get_first_term_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)

    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    long_matches_b = tuple_to_1d(uv_b_matches_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    non_matches_a_descriptors = torch.index_select(
        image_a_pred, 1,
        non_matches_a.to(device).long().squeeze()).squeeze()
    non_matches_b_descriptors = torch.index_select(
        image_b_pred, 1,
        non_matches_b.to(device).long().squeeze()).squeeze()
    long_matches_b_descriptors = torch.index_select(
        image_b_pred, 1,
        long_matches_b.to(device).long().squeeze()).squeeze()

    all_neg_pairs = (non_matches_a_descriptors -
                     non_matches_b_descriptors).pow(2).sum(dim=-1)
    all_pos_pairs = (non_matches_a_descriptors -
                     long_matches_b_descriptors).pow(2).sum(dim=-1)

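    # alpha below is an adaptive margin (the mean gap between negative- and
    # positive-pair squared distances); first_term then hinges each
    # positive/negative pair on that margin, as in quadruplet losses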
    alpha = (all_neg_pairs.sum() - all_pos_pairs.sum()) / (
        num_matching_attempts * num_masked_non_matches_per_match)

    first_term = torch.clamp(all_pos_pairs - all_neg_pairs + alpha,
                             min=0).sum() / (num_matching_attempts *
                                             num_masked_non_matches_per_match)

    pos_matches_a = non_matches_a_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))
    pos_matches_b = long_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))

    neg_matches_b = non_matches_b_descriptors.reshape(
        (num_matching_attempts, num_masked_non_matches_per_match, 1,
         Dim)).repeat(1, 1, num_masked_non_matches_per_match, 1).reshape(
             (-1, Dim))
    neg_matches_b_permute = non_matches_b_descriptors.reshape(
        (num_matching_attempts, 1, num_masked_non_matches_per_match,
         Dim)).repeat(1, num_masked_non_matches_per_match, 1, 1).reshape(
             (-1, Dim))

    pos_matches = (pos_matches_a - pos_matches_b).pow(2).sum(dim=-1)
    neg_matches = (neg_matches_b - neg_matches_b_permute).pow(2).sum(dim=-1)

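    # (1 - I) mask repeated per match: zeroes the i == j terms so each hinge
    # compares a positive-pair distance against the distance between two
    # distinct negatives sampled for the same match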
    match_mask = (torch.ones(num_masked_non_matches_per_match,
                             num_masked_non_matches_per_match) -
                  torch.eye(num_masked_non_matches_per_match)).repeat(
                      num_matching_attempts, 1, 1).flatten().to(device)

    second_term = (torch.clamp(
        (pos_matches - neg_matches) * match_mask + 0.5 * alpha, min=0) /
                   (num_masked_non_matches_per_match *
                    (num_masked_non_matches_per_match - 1) *
                    num_matching_attempts)).sum()

    loss = lamda_d * first_term + second_term
    return loss, lamda_d * first_term, second_term
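
A minimal calling sketch, with illustrative shapes; it assumes the module-level helpers used above (descriptor_reshape, get_coor_cells, scale_homography_torch, warp_coor_cells_with_homographies, get_first_term_corr, tuple_to_1d, filter_points, crop_or_pad_choice, normPts) are importable from the surrounding repo.

import torch

D, Hc, Wc = 128, 30, 40
descriptors = torch.rand(D, Hc, Wc, requires_grad=True)
descriptors_warped = torch.rand(D, Hc, Wc, requires_grad=True)
homographies = torch.eye(3).unsqueeze(0)  # identity warp, batch of 1

loss, first_term, second_term = quadruplet_descriptor_loss_sparse(
    descriptors, descriptors_warped, homographies,
    num_matching_attempts=100, num_masked_non_matches_per_match=10)
loss.backward()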