Example #1
class KeypointNetwithIOLoss(torch.nn.Module):
    """
    Model class encapsulating the KeypointNet and the IONet.

    Parameters
    ----------
    keypoint_loss_weight: float
        Keypoint loss weight.
    descriptor_loss_weight: float
        Descriptor loss weight.
    score_loss_weight: float
        Score loss weight.
    keypoint_net_learning_rate: float
        Keypoint net learning rate.
    with_io : bool
        Use IONet.
    use_color : bool
        Use color or grayscale images.
    do_upsample: bool
        Upsample dense descriptor map.
    do_cross: bool
        Predict keypoints outside cell borders.
    with_drop : bool
        Use dropout.
    descriptor_loss: bool
        Use descriptor loss.
    keypoint_net_type: str
        Type of keypoint network ('KeypointNet' or 'KeypointResnet').
    kwargs : dict
        Extra parameters.
    """
    def __init__(
        self, keypoint_loss_weight=1.0, descriptor_loss_weight=2.0, score_loss_weight=1.0, 
        keypoint_net_learning_rate=0.001, with_io=True, use_color=True, do_upsample=True, 
        do_cross=True, descriptor_loss=True, with_drop=True, keypoint_net_type='KeypointNet', **kwargs):

        super().__init__()

        self.keypoint_loss_weight = keypoint_loss_weight
        self.descriptor_loss_weight = descriptor_loss_weight
        self.score_loss_weight = score_loss_weight
        self.keypoint_net_learning_rate = keypoint_net_learning_rate
        self.optim_params = []

        self.cell = 8 # Size of each output cell. Keep this fixed.
        self.border_remove = 4 # Remove points this close to the border.
        self.top_k2 = 300
        self.relax_field = 4

        self.use_color = use_color
        self.descriptor_loss = descriptor_loss

        # Initialize KeypointNet
        if keypoint_net_type == 'KeypointNet':
            self.keypoint_net = KeypointNet(use_color=use_color, do_upsample=do_upsample, with_drop=with_drop, do_cross=do_cross)
        elif keypoint_net_type == 'KeypointResnet':
            self.keypoint_net = KeypointResnet(with_drop=with_drop)
        else:
            raise NotImplementedError('Keypoint net type not supported: {}'.format(keypoint_net_type))
        self.keypoint_net = self.keypoint_net.cuda()
        self.add_optimizer_params('KeypointNet', self.keypoint_net.parameters(), keypoint_net_learning_rate)

        self.with_io = with_io
        self.io_net = None
        if self.with_io:
            self.io_net = InlierNet(blocks=4)
            self.io_net = self.io_net.cuda()
            self.add_optimizer_params('InlierNet', self.io_net.parameters(),  keypoint_net_learning_rate)

        self.train_metrics = {}
        self.vis = {}
        if torch.cuda.current_device() == 0:
            print('KeypointNetwithIOLoss:: with io {} with descriptor loss {}'.format(self.with_io, self.descriptor_loss))

    def add_optimizer_params(self, name, params, lr):
        self.optim_params.append(
            {'name': name, 'lr': lr, 'original_lr': lr,
             'params': filter(lambda p: p.requires_grad, params)})

    def forward(self, data, debug=False):
        """
        Processes a batch.

        Parameters
        ----------
        data : dict
            Input batch.
        debug : bool
            True to compute debug data (stored in self.vis).

        Returns
        -------
        loss_2d : torch.Tensor
            Total 2D loss (weighted keypoint, descriptor and score terms).
        recall_2d : float
            Descriptor matching recall computed by build_descriptor_loss.
        """

        loss_2d = 0
        recall_2d = 0

        if self.training:

            B, _, H, W = data['image'].shape
            device = data['image'].device

            inlier_cnt = 0

            input_img = data['image']
            input_img_aug = data['image_aug']
            homography = data['homography']

            input_img = to_color_normalized(input_img.clone())
            input_img_aug = to_color_normalized(input_img_aug.clone())

            # Get network outputs
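            # score: (B, 1, Hc, Wc); uv_pred: (B, 2, Hc, Wc); feat: (B, 256, Hc, Wc)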
            source_score, source_uv_pred, source_feat = self.keypoint_net(input_img_aug)
            target_score, target_uv_pred, target_feat = self.keypoint_net(input_img)
            _, _, Hc, Wc = target_score.shape

            # Normalize uv coordinates
            # TODO: Have a function for the norm and de-norm of 2d coordinate.
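            # Note: grid_sample expects coordinates normalized to [-1, 1], x first; a helper sketch for the norm/de-norm steps follows this example.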
            target_uv_norm = target_uv_pred.clone()
            target_uv_norm[:,0] = (target_uv_norm[:,0] / (float(W-1)/2.)) - 1.
            target_uv_norm[:,1] = (target_uv_norm[:,1] / (float(H-1)/2.)) - 1.
            target_uv_norm = target_uv_norm.permute(0, 2, 3, 1)

            source_uv_norm = source_uv_pred.clone()
            source_uv_norm[:,0] = (source_uv_norm[:,0] / (float(W-1)/2.)) - 1.
            source_uv_norm[:,1] = (source_uv_norm[:,1] / (float(H-1)/2.)) - 1.
            source_uv_norm = source_uv_norm.permute(0, 2, 3, 1)

            source_uv_warped_norm = warp_homography_batch(source_uv_norm, homography)
            source_uv_warped = source_uv_warped_norm.clone()

            source_uv_warped[:,:,:,0] = (source_uv_warped[:,:,:,0] + 1) * (float(W-1)/2.)
            source_uv_warped[:,:,:,1] = (source_uv_warped[:,:,:,1] + 1) * (float(H-1)/2.)
            source_uv_warped = source_uv_warped.permute(0, 3, 1, 2)

            target_uv_resampled = torch.nn.functional.grid_sample(target_uv_pred, source_uv_warped_norm, mode='nearest', align_corners=True)

            target_uv_resampled_norm = target_uv_resampled.clone()
            target_uv_resampled_norm[:,0] = (target_uv_resampled_norm[:,0] / (float(W-1)/2.)) - 1.
            target_uv_resampled_norm[:,1] = (target_uv_resampled_norm[:,1] / (float(H-1)/2.)) - 1.
            target_uv_resampled_norm = target_uv_resampled_norm.permute(0, 2, 3, 1)

            # Border mask
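            # Ignore the outermost ring of cells of the coarse Hc x Wc grid.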
            border_mask_ori = torch.ones(B,Hc,Wc)
            border_mask_ori[:,0] = 0
            border_mask_ori[:,Hc-1] = 0
            border_mask_ori[:,:,0] = 0
            border_mask_ori[:,:,Wc-1] = 0
            border_mask_ori = border_mask_ori.gt(1e-3).to(device)

            # Out-of-border (OOB) mask. Not necessary in our case, since it is already prevented by the HA procedure. Kept here for future usage.
            oob_mask2 = source_uv_warped_norm[:,:,:,0].lt(1) & source_uv_warped_norm[:,:,:,0].gt(-1) & source_uv_warped_norm[:,:,:,1].lt(1) & source_uv_warped_norm[:,:,:,1].gt(-1)
            border_mask = border_mask_ori & oob_mask2

            # Pairwise L2 distances between warped source keypoints and all target keypoint predictions;
            # each warped source point is associated with its nearest target prediction.
            d_uv_mat_abs = torch.abs(source_uv_warped.view(B,2,-1).unsqueeze(3) - target_uv_pred.view(B,2,-1).unsqueeze(2))
            d_uv_l2_mat = torch.norm(d_uv_mat_abs, p=2, dim=1)
            d_uv_l2_min, d_uv_l2_min_index = d_uv_l2_mat.min(dim=2)

            dist_norm_valid_mask = d_uv_l2_min.lt(4) & border_mask.view(B,Hc*Wc)

            # Keypoint loss
            loc_loss = d_uv_l2_min[dist_norm_valid_mask].mean()
            loss_2d += self.keypoint_loss_weight * loc_loss.mean()

            # Desc Head Loss: per-pixel triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
            if self.descriptor_loss:
                metric_loss, recall_2d = build_descriptor_loss(source_feat, target_feat, source_uv_norm.detach(), source_uv_warped_norm.detach(), source_uv_warped, keypoint_mask=border_mask, relax_field=self.relax_field)
                loss_2d += self.descriptor_loss_weight * metric_loss * 2
            else:
                _, recall_2d = build_descriptor_loss(source_feat, target_feat, source_uv_norm, source_uv_warped_norm, source_uv_warped, keypoint_mask=border_mask, relax_field=self.relax_field, eval_only=True)

            # Score Head Loss: push scores up for pairs with below-average localisation error and down for pairs with above-average error.
            target_score_associated = target_score.view(B,Hc*Wc).gather(1, d_uv_l2_min_index).view(B,Hc,Wc).unsqueeze(1)
            dist_norm_valid_mask = dist_norm_valid_mask.view(B,Hc,Wc).unsqueeze(1) & border_mask.unsqueeze(1)
            d_uv_l2_min = d_uv_l2_min.view(B,Hc,Wc).unsqueeze(1)
            loc_err = d_uv_l2_min[dist_norm_valid_mask]

            usp_loss = (target_score_associated[dist_norm_valid_mask] + source_score[dist_norm_valid_mask]) * (loc_err - loc_err.mean())
            loss_2d += self.score_loss_weight * usp_loss.mean()

            target_score_resampled = torch.nn.functional.grid_sample(target_score, source_uv_warped_norm.detach(), mode='bilinear', align_corners=True)

            loss_2d += self.score_loss_weight * torch.nn.functional.mse_loss(target_score_resampled[border_mask.unsqueeze(1)],
                                                                                source_score[border_mask.unsqueeze(1)]).mean() * 2
            if self.with_io:
                # Compute IO loss
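                # Select the top_k2 keypoints per image, associate them by nearest descriptor distance,
                # and train the IONet to predict whether each association is an inlier (+1) or outlier (-1).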
                top_k_score1, top_k_indice1 = source_score.view(B,Hc*Wc).topk(self.top_k2, dim=1, largest=False)
                top_k_mask1 = torch.zeros(B, Hc * Wc).to(device)
                top_k_mask1.scatter_(1, top_k_indice1, value=1)
                top_k_mask1 = top_k_mask1.gt(1e-3).view(B,Hc,Wc)

                top_k_score2, top_k_indice2 = target_score.view(B,Hc*Wc).topk(self.top_k2, dim=1, largest=False)
                top_k_mask2 = torch.zeros(B, Hc * Wc).to(device)
                top_k_mask2.scatter_(1, top_k_indice2, value=1)
                top_k_mask2 = top_k_mask2.gt(1e-3).view(B,Hc,Wc)

                source_uv_norm_topk = source_uv_norm[top_k_mask1].view(B, self.top_k2, 2)
                target_uv_norm_topk = target_uv_norm[top_k_mask2].view(B, self.top_k2, 2)
                source_uv_warped_norm_topk = source_uv_warped_norm[top_k_mask1].view(B, self.top_k2, 2)

                source_feat_topk = torch.nn.functional.grid_sample(source_feat, source_uv_norm_topk.unsqueeze(1), align_corners=True).squeeze()
                target_feat_topk = torch.nn.functional.grid_sample(target_feat, target_uv_norm_topk.unsqueeze(1), align_corners=True).squeeze()

                source_feat_topk = source_feat_topk.div(torch.norm(source_feat_topk, p=2, dim=1).unsqueeze(1))
                target_feat_topk = target_feat_topk.div(torch.norm(target_feat_topk, p=2, dim=1).unsqueeze(1))

                # Descriptor distance matrix: for unit-norm descriptors, ||a - b||_2 = sqrt(2 - 2 * a.b).
                dmat = torch.bmm(source_feat_topk.permute(0,2,1), target_feat_topk)
                dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
                dmat_soft_min = torch.sum(dmat* dmat.mul(-1).softmax(dim=2), dim=2)
                dmat_min, dmat_min_indice = torch.min(dmat, dim=2)

                target_uv_norm_topk_associated = target_uv_norm_topk.gather(1, dmat_min_indice.unsqueeze(2).repeat(1,1,2))
                point_pair = torch.cat([source_uv_norm_topk, target_uv_norm_topk_associated, dmat_min.unsqueeze(2)], 2)

                inlier_pred = self.io_net(point_pair.permute(0,2,1).unsqueeze(3)).squeeze()

                target_uv_norm_topk_associated_raw = target_uv_norm_topk_associated.clone()
                target_uv_norm_topk_associated_raw[:,:,0] = (target_uv_norm_topk_associated_raw[:,:,0] + 1) * (float(W-1)/2.)
                target_uv_norm_topk_associated_raw[:,:,1] = (target_uv_norm_topk_associated_raw[:,:,1] + 1) * (float(H-1)/2.)

                source_uv_warped_norm_topk_raw = source_uv_warped_norm_topk.clone()
                source_uv_warped_norm_topk_raw[:,:,0] = (source_uv_warped_norm_topk_raw[:,:,0] + 1) * (float(W-1)/2.)
                source_uv_warped_norm_topk_raw[:,:,1] = (source_uv_warped_norm_topk_raw[:,:,1] + 1) * (float(H-1)/2.)


                matching_score = torch.norm(target_uv_norm_topk_associated_raw - source_uv_warped_norm_topk_raw, p=2, dim=2)
                # Ground truth for the IONet: associations within 4 pixels are inliers (+1), others outliers (-1).
                inlier_mask = matching_score.lt(4)
                inlier_gt = 2 * inlier_mask.float() - 1

                if inlier_mask.sum() > 10:

                    io_loss = torch.nn.functional.mse_loss(inlier_pred, inlier_gt)
                    loss_2d += self.keypoint_loss_weight * io_loss


            if debug and torch.cuda.current_device() == 0:
                # Generate visualization data
                vis_ori = (input_img[0].permute(1, 2, 0).detach().cpu().clone().squeeze() )
                vis_ori -= vis_ori.min()
                vis_ori /= vis_ori.max()
                vis_ori = (vis_ori* 255).numpy().astype(np.uint8)

                if self.use_color is False:
                    vis_ori = cv2.cvtColor(vis_ori, cv2.COLOR_GRAY2BGR)

                _, top_k = target_score.view(B,-1).topk(self.top_k2, dim=1) #JT: Target frame keypoints
                vis_ori = draw_keypoints(vis_ori, target_uv_pred.view(B,2,-1)[:,:,top_k[0].squeeze()],(0,0,255))

                _, top_k = source_score.view(B,-1).topk(self.top_k2, dim=1) #JT: Warped Source frame keypoints
                vis_ori = draw_keypoints(vis_ori, source_uv_warped.view(B,2,-1)[:,:,top_k[0].squeeze()],(255,0,255))

                cm = get_cmap('plasma')
                heatmap = target_score[0].detach().cpu().clone().numpy().squeeze()
                heatmap -= heatmap.min()
                heatmap /= heatmap.max()
                heatmap = cv2.resize(heatmap, (W, H))
                heatmap = cm(heatmap)[:, :, :3]

                self.vis['img_ori'] = np.clip(vis_ori, 0, 255) / 255.
                self.vis['heatmap'] = np.clip(heatmap * 255, 0, 255) / 255.

        return loss_2d, recall_2d
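The per-coordinate normalization to [-1, 1] and the corresponding de-normalization are repeated several times in `forward` (see the TODO above). Below is a minimal, self-contained sketch of how those two steps could be factored into helpers; the function names are illustrative and not part of the original code, but the math mirrors the snippets above.

import torch


def normalize_uv(uv, H, W):
    """Map pixel coordinates (B, 2, Hc, Wc) to [-1, 1] and return (B, Hc, Wc, 2) as expected by grid_sample."""
    uv_norm = uv.clone()
    uv_norm[:, 0] = uv_norm[:, 0] / (float(W - 1) / 2.) - 1.  # x axis
    uv_norm[:, 1] = uv_norm[:, 1] / (float(H - 1) / 2.) - 1.  # y axis
    return uv_norm.permute(0, 2, 3, 1)


def denormalize_uv(uv_norm, H, W):
    """Map normalized coordinates (B, Hc, Wc, 2) back to pixels, returning (B, 2, Hc, Wc)."""
    uv = uv_norm.clone()
    uv[..., 0] = (uv[..., 0] + 1) * (float(W - 1) / 2.)
    uv[..., 1] = (uv[..., 1] + 1) * (float(H - 1) / 2.)
    return uv.permute(0, 3, 1, 2)


# Round-trip check on random pixel coordinates.
B, H, W, Hc, Wc = 2, 240, 320, 30, 40
uv = torch.rand(B, 2, Hc, Wc) * torch.tensor([W - 1., H - 1.]).view(1, 2, 1, 1)
assert torch.allclose(denormalize_uv(normalize_uv(uv, H, W), H, W), uv, atol=1e-4)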
Example #2
class KeypointNetwithIOLoss(torch.nn.Module):
    """
    Model class encapsulating the KeypointNet and the IONet.

    Parameters
    ----------
    keypoint_loss_weight: float
        Keypoint loss weight.
    descriptor_loss_weight: float
        Descriptor loss weight.
    score_loss_weight: float
        Score loss weight.
    keypoint_net_learning_rate: float
        Keypoint net learning rate.
    with_io : bool
        Use IONet.
    use_color : bool
        Use color or grayscale images.
    do_upsample: bool
        Upsample dense descriptor map.
    do_cross: bool
        Predict keypoints outside cell borders.
    with_drop : bool
        Use dropout.
    descriptor_loss: bool
        Use descriptor loss.
    training_mode: str
        Warping mode used for supervision ('HA', 'HA_wo_sp', 'scene', 'cam', 'con', or one of the '+HA' variants).
    keypoint_net_type: str
        Type of keypoint network ('KeypointNet' or 'KeypointResnet').
    pretrained_model: str
        Optional path to a pretrained KeypointNet checkpoint.
    kwargs : dict
        Extra parameters.
    """
    def __init__(self,
                 training_mode,
                 keypoint_loss_weight=1.0,
                 descriptor_loss_weight=2.0,
                 score_loss_weight=1.0,
                 keypoint_net_learning_rate=0.001,
                 with_io=True,
                 use_color=True,
                 do_upsample=True,
                 do_cross=True,
                 descriptor_loss=True,
                 with_drop=True,
                 keypoint_net_type='KeypointNet',
                 pretrained_model=None,
                 **kwargs):

        super().__init__()

        self.keypoint_loss_weight = keypoint_loss_weight
        self.descriptor_loss_weight = descriptor_loss_weight
        self.score_loss_weight = score_loss_weight
        self.keypoint_net_learning_rate = keypoint_net_learning_rate
        self.optim_params = []

        self.cell = 8  # Size of each output cell. Keep this fixed.
        self.border_remove = 4  # Remove points this close to the border.
        self.top_k2 = 300
        self.relax_field = 4

        self.use_color = use_color
        self.descriptor_loss = descriptor_loss
        self.training_mode = training_mode

        # Initialize KeypointNet
        if pretrained_model is None:
            if keypoint_net_type == 'KeypointNet':
                self.keypoint_net = KeypointNet(use_color=use_color,
                                                do_upsample=do_upsample,
                                                with_drop=with_drop,
                                                do_cross=do_cross)
            elif keypoint_net_type == 'KeypointResnet':
                self.keypoint_net = KeypointResnet(with_drop=with_drop)
            else:
                raise NotImplementedError(
                    'Keypoint net type not supported: {}'.format(
                        keypoint_net_type))
        else:
            checkpoint = torch.load(pretrained_model)
            model_args = checkpoint['config']['model']['params']
            if 'keypoint_net_type' in model_args:
                net_type = model_args['keypoint_net_type']
            else:
                net_type = 'KeypointNet'  # default when no type is specified
            if net_type == 'KeypointNet':
                print('keypointNet')
                self.keypoint_net = KeypointNet(
                    use_color=model_args['use_color'],
                    do_upsample=model_args['do_upsample'],
                    do_cross=model_args['do_cross'])
            else:
                print('keypointresnet')
                self.keypoint_net = KeypointResnet()
            self.keypoint_net.load_state_dict(checkpoint['state_dict'])
            print('Loaded KeypointNet from {}'.format(pretrained_model))
            print('KeypointNet params {}'.format(model_args))

        self.keypoint_net = self.keypoint_net.cuda()
        self.add_optimizer_params('KeypointNet',
                                  self.keypoint_net.parameters(),
                                  keypoint_net_learning_rate)

        self.with_io = with_io
        self.io_net = None
        if self.with_io:
            self.io_net = InlierNet(blocks=4)
            self.io_net = self.io_net.cuda()
            self.add_optimizer_params('InlierNet', self.io_net.parameters(),
                                      keypoint_net_learning_rate)

        self.train_metrics = {}
        self.vis = {}
        if torch.cuda.current_device() == 0:
            print('KeypointNetwithIOLoss:: with io {} with descriptor loss {}'.
                  format(self.with_io, self.descriptor_loss))

    def add_optimizer_params(self, name, params, lr):
        self.optim_params.append({
            'name': name,
            'lr': lr,
            'original_lr': lr,
            'params': filter(lambda p: p.requires_grad, params)
        })

    def forward(self, data, debug=False):
        """
        Processes a batch.

        Parameters
        ----------
        data : dict
            Input batch.
        debug : bool
            True to compute debug data (stored in self.vis).

        Returns
        -------
        loss_2d : torch.Tensor
            Total 2D loss (weighted keypoint, descriptor and score terms).
        recall_2d : float
            Descriptor matching recall computed by build_descriptor_loss.
        """

        loss_2d = 0
        recall_2d = 0

        if self.training:

            B, _, H, W = data['image'].shape
            device = data['image'].device

            reprojection = Reprojection(width=1024, height=768, verbose=False)

            inlier_cnt = 0

            input_img_0 = data['image']
            input_img_aug_0 = data['image_aug']
            # metainfo: [source_position_map, target_position_map, source_reflectance_map, target_R_CW, target_t_CW]
            metainfo = data['metainfo']
            source_frame = data['source_frame']
            target_frame = data['target_frame']
            scenter2tcenter = data['scenter2tcenter']

            input_img = to_color_normalized(input_img_0.clone())
            input_img_aug = to_color_normalized(input_img_aug_0.clone())

            # Get network outputs
            # score: (B, 1, H_out, W_out)
            # uv_pred: (B, 2, H_out, W_out)
            # feat: (B, 256, H_out, W_out)
            source_score, source_uv_pred, source_feat = self.keypoint_net(
                input_img_aug)
            target_score, target_uv_pred, target_feat = self.keypoint_net(
                input_img)
            _, _, Hc, Wc = target_score.shape

            # Normalize uv coordinates
            # TODO: Have a function for the norm and de-norm of 2d coordinate.
            target_uv_norm = target_uv_pred.clone()
            target_uv_norm[:, 0] = (target_uv_norm[:, 0] /
                                    (float(W - 1) / 2.)) - 1.
            target_uv_norm[:, 1] = (target_uv_norm[:, 1] /
                                    (float(H - 1) / 2.)) - 1.
            target_uv_norm = target_uv_norm.permute(0, 2, 3, 1)

            source_uv_norm = source_uv_pred.clone()
            source_uv_norm[:, 0] = (source_uv_norm[:, 0] /
                                    (float(W - 1) / 2.)) - 1.
            source_uv_norm[:, 1] = (source_uv_norm[:, 1] /
                                    (float(H - 1) / 2.)) - 1.
            source_uv_norm = source_uv_norm.permute(0, 2, 3, 1)

            if self.training_mode in ('scene', 'cam', 'con', 'scene+HA', 'cam+HA', 'con+HA'):
                # get source_uv with frame transformation, then normalize
                # source_uv_pred dim: (B, 2, H_out, W_out)
                source_uv_warped, inliers, _, _ = warp_frame2frame_batch(
                    source_uv_pred,
                    metainfo,
                    source_frame,
                    target_frame,
                    scenter2tcenter,
                    projection=reprojection)
                source_uv_warped = source_uv_warped.float()
                # normalization
                source_uv_warped_norm = source_uv_warped.clone()
                source_uv_warped_norm[:, 0] = (source_uv_warped_norm[:, 0] /
                                               (float(W - 1) / 2.)) - 1.
                source_uv_warped_norm[:, 1] = (source_uv_warped_norm[:, 1] /
                                               (float(H - 1) / 2.)) - 1.
                source_uv_warped_norm = source_uv_warped_norm.permute(
                    0, 2, 3, 1)
            if self.training_mode in ('HA', 'HA_wo_sp'):
                homography = data['homography']
                source_uv_warped_norm = warp_homography_batch(
                    source_uv_norm, homography)
                source_uv_warped = source_uv_warped_norm.clone()
                source_uv_warped[:, :, :, 0] = (source_uv_warped[:, :, :, 0] + 1) * (float(W - 1) / 2.)
                source_uv_warped[:, :, :, 1] = (source_uv_warped[:, :, :, 1] + 1) * (float(H - 1) / 2.)  # (B, H, W, C)
                source_uv_warped = source_uv_warped.permute(0, 3, 1, 2)  # (B, C, H, W)
            if self.training_mode in ('scene+HA', 'cam+HA', 'con+HA'):
                homography = data['homography']
                source_uv_warped_norm = warp_homography_batch(
                    source_uv_warped_norm, homography)
                source_uv_warped = source_uv_warped_norm.clone()
                source_uv_warped[:, :, :, 0] = (source_uv_warped[:, :, :, 0] + 1) * (float(W - 1) / 2.)
                source_uv_warped[:, :, :, 1] = (source_uv_warped[:, :, :, 1] + 1) * (float(H - 1) / 2.)  # (B, H, W, C)
                source_uv_warped = source_uv_warped.permute(0, 3, 1, 2)  # (B, C, H, W)

            target_uv_resampled = torch.nn.functional.grid_sample(
                target_uv_pred,
                source_uv_warped_norm.float(),
                mode='nearest',
                align_corners=True)

            target_uv_resampled_norm = target_uv_resampled.clone()
            target_uv_resampled_norm[:, 0] = (target_uv_resampled_norm[:, 0] /
                                              (float(W - 1) / 2.)) - 1.
            target_uv_resampled_norm[:, 1] = (target_uv_resampled_norm[:, 1] /
                                              (float(H - 1) / 2.)) - 1.
            target_uv_resampled_norm = target_uv_resampled_norm.permute(
                0, 2, 3, 1)

            # Border mask
            border_mask_ori = torch.ones(B, Hc, Wc)
            border_mask_ori[:, 0] = 0
            border_mask_ori[:, Hc - 1] = 0
            border_mask_ori[:, :, 0] = 0
            border_mask_ori[:, :, Wc - 1] = 0
            border_mask_ori = border_mask_ori.gt(1e-3).to(device)

            # Out-of-border (OOB) mask. Not necessary in our case, since it is already prevented by the HA procedure. Kept here for future usage.
            oob_mask2 = (source_uv_warped_norm[:, :, :, 0].lt(1)
                         & source_uv_warped_norm[:, :, :, 0].gt(-1)
                         & source_uv_warped_norm[:, :, :, 1].lt(1)
                         & source_uv_warped_norm[:, :, :, 1].gt(-1))
            border_mask = border_mask_ori & oob_mask2
            if self.training_mode not in ('HA', 'HA_wo_sp'):
                # Treat outliers from the Hypersim projection as out-of-border points.
                inliers = inliers.squeeze()
                border_mask = border_mask & inliers

            d_uv_mat_abs = torch.abs(
                source_uv_warped.view(B, 2, -1).unsqueeze(3) -
                target_uv_pred.view(B, 2, -1).unsqueeze(2))
            d_uv_l2_mat = torch.norm(d_uv_mat_abs, p=2, dim=1)
            d_uv_l2_min, d_uv_l2_min_index = d_uv_l2_mat.min(dim=2)

            dist_norm_valid_mask = d_uv_l2_min.lt(4) & border_mask.view(
                B, Hc * Wc)

            # Keypoint loss
            loc_loss = d_uv_l2_min[dist_norm_valid_mask].mean()
            loss_2d += self.keypoint_loss_weight * loc_loss.mean()

            # Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf.
            if self.descriptor_loss:
                metric_loss, recall_2d = build_descriptor_loss(
                    source_feat,
                    target_feat,
                    source_uv_norm.detach(),
                    source_uv_warped_norm.detach(),
                    source_uv_warped,
                    keypoint_mask=border_mask,
                    relax_field=self.relax_field)
                loss_2d += self.descriptor_loss_weight * metric_loss * 2
            else:
                _, recall_2d = build_descriptor_loss(
                    source_feat,
                    target_feat,
                    source_uv_norm,
                    source_uv_warped_norm,
                    source_uv_warped,
                    keypoint_mask=border_mask,
                    relax_field=self.relax_field,
                    eval_only=True)

            # Score Head Loss
            target_score_associated = target_score.view(B, Hc * Wc).gather(
                1, d_uv_l2_min_index).view(B, Hc, Wc).unsqueeze(1)
            dist_norm_valid_mask = dist_norm_valid_mask.view(
                B, Hc, Wc).unsqueeze(1) & border_mask.unsqueeze(1)
            d_uv_l2_min = d_uv_l2_min.view(B, Hc, Wc).unsqueeze(1)
            loc_err = d_uv_l2_min[dist_norm_valid_mask]

            usp_loss = (target_score_associated[dist_norm_valid_mask] +
                        source_score[dist_norm_valid_mask]) * (loc_err -
                                                               loc_err.mean())
            loss_2d += self.score_loss_weight * usp_loss.mean()

            target_score_resampled = torch.nn.functional.grid_sample(
                target_score,
                source_uv_warped_norm.detach(),
                mode='bilinear',
                align_corners=True)

            loss_2d += self.score_loss_weight * torch.nn.functional.mse_loss(
                target_score_resampled[border_mask.unsqueeze(1)],
                source_score[border_mask.unsqueeze(1)]).mean() * 2
            if self.with_io:
                # Compute IO loss
                top_k_score1, top_k_indice1 = source_score.view(
                    B, Hc * Wc).topk(self.top_k2, dim=1, largest=False)
                top_k_mask1 = torch.zeros(B, Hc * Wc).to(device)
                top_k_mask1.scatter_(1, top_k_indice1, value=1)
                top_k_mask1 = top_k_mask1.gt(1e-3).view(B, Hc, Wc)

                top_k_score2, top_k_indice2 = target_score.view(
                    B, Hc * Wc).topk(self.top_k2, dim=1, largest=False)
                top_k_mask2 = torch.zeros(B, Hc * Wc).to(device)
                top_k_mask2.scatter_(1, top_k_indice2, value=1)
                top_k_mask2 = top_k_mask2.gt(1e-3).view(B, Hc, Wc)

                source_uv_norm_topk = source_uv_norm[top_k_mask1].view(
                    B, self.top_k2, 2)
                target_uv_norm_topk = target_uv_norm[top_k_mask2].view(
                    B, self.top_k2, 2)
                source_uv_warped_norm_topk = source_uv_warped_norm[
                    top_k_mask1].view(B, self.top_k2, 2)

                source_feat_topk = torch.nn.functional.grid_sample(
                    source_feat,
                    source_uv_norm_topk.unsqueeze(1),
                    align_corners=True).squeeze()
                target_feat_topk = torch.nn.functional.grid_sample(
                    target_feat,
                    target_uv_norm_topk.unsqueeze(1),
                    align_corners=True).squeeze()

                source_feat_topk = source_feat_topk.div(
                    torch.norm(source_feat_topk, p=2, dim=1).unsqueeze(1))
                target_feat_topk = target_feat_topk.div(
                    torch.norm(target_feat_topk, p=2, dim=1).unsqueeze(1))

                dmat = torch.bmm(source_feat_topk.permute(0, 2, 1),
                                 target_feat_topk)
                dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
                dmat_soft_min = torch.sum(dmat * dmat.mul(-1).softmax(dim=2),
                                          dim=2)
                dmat_min, dmat_min_indice = torch.min(dmat, dim=2)

                target_uv_norm_topk_associated = target_uv_norm_topk.gather(
                    1,
                    dmat_min_indice.unsqueeze(2).repeat(1, 1, 2))
                point_pair = torch.cat([
                    source_uv_norm_topk, target_uv_norm_topk_associated,
                    dmat_min.unsqueeze(2)
                ], 2)

                inlier_pred = self.io_net(
                    point_pair.permute(0, 2, 1).unsqueeze(3)).squeeze()

                target_uv_norm_topk_associated_raw = target_uv_norm_topk_associated.clone()
                target_uv_norm_topk_associated_raw[:, :, 0] = (target_uv_norm_topk_associated_raw[:, :, 0] + 1) * (float(W - 1) / 2.)
                target_uv_norm_topk_associated_raw[:, :, 1] = (target_uv_norm_topk_associated_raw[:, :, 1] + 1) * (float(H - 1) / 2.)

                source_uv_warped_norm_topk_raw = source_uv_warped_norm_topk.clone()
                source_uv_warped_norm_topk_raw[:, :, 0] = (source_uv_warped_norm_topk_raw[:, :, 0] + 1) * (float(W - 1) / 2.)
                source_uv_warped_norm_topk_raw[:, :, 1] = (source_uv_warped_norm_topk_raw[:, :, 1] + 1) * (float(H - 1) / 2.)

                matching_score = torch.norm(
                    target_uv_norm_topk_associated_raw -
                    source_uv_warped_norm_topk_raw,
                    p=2,
                    dim=2)
                inlier_mask = matching_score.lt(4)
                inlier_gt = 2 * inlier_mask.float() - 1

                if inlier_mask.sum() > 10:

                    io_loss = torch.nn.functional.mse_loss(
                        inlier_pred, inlier_gt)
                    loss_2d += self.keypoint_loss_weight * io_loss

            if debug and torch.cuda.current_device() == 0:
                # Generate visualization data
                vis_ori = (input_img[0].permute(
                    1, 2, 0).detach().cpu().clone().squeeze())
                vis_ori -= vis_ori.min()
                vis_ori /= vis_ori.max()
                vis_ori = (vis_ori * 255).numpy().astype(np.uint8)

                vis_tar = (input_img[0].permute(
                    1, 2, 0).detach().cpu().clone().squeeze())
                vis_tar -= vis_tar.min()
                vis_tar /= vis_tar.max()
                vis_tar = (vis_tar * 255).numpy().astype(np.uint8)

                vis_src = (input_img_aug[0].permute(
                    1, 2, 0).detach().cpu().clone().squeeze())
                vis_src -= vis_src.min()
                vis_src /= vis_src.max()
                vis_src = (vis_src * 255).numpy().astype(np.uint8)

                if self.use_color is False:
                    vis_ori = cv2.cvtColor(vis_ori, cv2.COLOR_GRAY2BGR)

                _, top_k = target_score.view(B, -1).topk(
                    self.top_k2, dim=1)  #JT: Target frame keypoints
                vis_ori = draw_keypoints(
                    vis_ori,
                    target_uv_pred.view(B, 2, -1)[:, :, top_k[0].squeeze()],
                    (0, 0, 255))

                _, top_k = source_score.view(B, -1).topk(
                    self.top_k2, dim=1)  #JT: Warped Source frame keypoints
                vis_ori = draw_keypoints(
                    vis_ori,
                    source_uv_warped.view(B, 2, -1)[:, :, top_k[0].squeeze()],
                    (255, 0, 255))

                cm = get_cmap('plasma')
                heatmap = target_score[0].detach().cpu().clone().numpy().squeeze()
                heatmap -= heatmap.min()
                heatmap /= heatmap.max()
                heatmap = cv2.resize(heatmap, (W, H))
                heatmap = cm(heatmap)[:, :, :3]

                self.vis['img_ori'] = np.clip(vis_ori, 0, 255) / 255.
                self.vis['heatmap'] = np.clip(heatmap * 255, 0, 255) / 255.
                # Visualization of projection. Uncomment to activate; this does not work if frame-to-frame warping is combined with HA.
                # self.vis['img_src'] = np.clip(vis_src, 0, 255) / 255.
                # self.vis['img_tar'] = np.clip(vis_tar, 0, 255) / 255.

                # import itertools
                # W_a = [*range(0, 512, 4)]
                # H_a = [*range(0, 384, 4)]
                # px_source = torch.tensor(list(itertools.product(W_a, H_a))).T.float().cuda()
                # px_source = px_source.view(2, len(H_a), len(W_a))
                # px_source = torch.unsqueeze(px_source, 0)
                # # source_uv_warped, inliers = warp_frame2frame_batch(
                # #     px_source, torch.unsqueeze(metainfo[0], 0), torch.unsqueeze(source_frame[0], 0),
                # #     torch.unsqueeze(target_frame[0], 0), torch.unsqueeze(scenter2tcenter[0], 0), projection=reprojection)
                # source_uv_warped, inliers, source_name, target_name = warp_frame2frame_batch(
                #     px_source, metainfo, source_frame,
                #     target_frame, scenter2tcenter, projection=reprojection)
                # source_uv_warped = source_uv_warped
                # inliers = inliers.squeeze()
                # outliers = ~inliers
                # source_uv_inliers = px_source[:,:,inliers]
                # source_uv_outliers = px_source[:,:,outliers]
                # source_uv_warped_inliers = source_uv_warped[:,:,inliers]
                # source_uv_warped_outliers = source_uv_warped[:,:,outliers]

                # vis_src_masked = draw_keypoints(vis_src, source_uv_outliers,(255,0,255))
                # vis_src_masked = draw_keypoints(vis_src_masked, source_uv_inliers,(0,0,255))

                # vis_tar_masked = draw_keypoints(vis_tar, source_uv_warped_outliers,(255,0,255))
                # vis_tar_masked = draw_keypoints(vis_tar_masked, source_uv_warped_inliers,(0,0,255))

                # self.vis['img_src_masked'] = np.clip(vis_src_masked, 0, 255) / 255.
                # self.vis['img_tar_masked'] = np.clip(vis_tar_masked, 0, 255) / 255.
                # self.vis['source_name'] = source_name[0]
                # self.vis['target_name'] = target_name[0]

        return loss_2d, recall_2d
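`add_optimizer_params` only accumulates parameter groups in `self.optim_params`; nothing in this class builds an optimizer from them, so presumably the training script does. The sketch below shows that pattern in isolation, using small stand-in modules instead of the real KeypointNet/IONet (the module choices and learning rate are assumptions for illustration, not part of the original code).

import torch

optim_params = []


def add_optimizer_params(name, params, lr):
    # Same structure as the method above; a list is used instead of filter()
    # so the group can be inspected more than once.
    optim_params.append(
        {'name': name, 'lr': lr, 'original_lr': lr,
         'params': [p for p in params if p.requires_grad]})


# Stand-ins for keypoint_net / io_net (hypothetical, for illustration only).
keypoint_net = torch.nn.Conv2d(3, 8, kernel_size=3)
io_net = torch.nn.Linear(5, 1)

add_optimizer_params('KeypointNet', keypoint_net.parameters(), 0.001)
add_optimizer_params('InlierNet', io_net.parameters(), 0.001)

# torch optimizers accept per-group dicts directly; extra keys such as 'name'
# and 'original_lr' are preserved in optimizer.param_groups, which is handy
# for per-group learning-rate schedules and logging.
optimizer = torch.optim.Adam(optim_params)
for group in optimizer.param_groups:
    print(group['name'], group['lr'])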