def getGrid(self, im1, im2, H, scaling_steps=3):
        h1, w1 = int(im1.shape[0] / (2**scaling_steps)), int(
            im1.shape[1] / (2**scaling_steps))
        device = torch.device("cpu")

        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(
            fmap_pos1, scaling_steps=scaling_steps).data.cpu().numpy()

        pos1[[0, 1]] = pos1[[1, 0]]

        ones = np.ones((1, pos1.shape[1]))
        pos1Homo = np.vstack((pos1, ones))
        pos2Homo = np.dot(H, pos1Homo)
        pos2Homo = pos2Homo / pos2Homo[2, :]
        pos2 = pos2Homo[0:2, :]

        pos1[[0, 1]] = pos1[[1, 0]]
        pos2[[0, 1]] = pos2[[1, 0]]
        pos1 = pos1.astype(np.float32)
        pos2 = pos2.astype(np.float32)

        ids = []
        for i in range(pos2.shape[1]):
            x, y = pos2[:, i]

            if (2 < x < (im1.shape[0] - 2) and 2 < y < (im1.shape[1] - 2)):
                ids.append(i)
        pos1 = pos1[:, ids]
        pos2 = pos2[:, ids]

        return pos1, pos2
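The grid warp above is a plain planar homography applied to the dense grid of image 1: swap (row, col) into (x, y), lift to homogeneous coordinates, multiply by H, and divide by the third coordinate. A minimal standalone sketch of that projection step, using NumPy and a hypothetical 3x3 matrix H:

import numpy as np

H = np.eye(3)                             # hypothetical homography (identity here)
pos1 = np.array([[3.5, 3.5, 11.5],        # rows (y)
                 [3.5, 11.5, 19.5]])      # cols (x)

xy1 = pos1[[1, 0]]                        # homographies act on (x, y) order
xy1_h = np.vstack([xy1, np.ones((1, xy1.shape[1]))])
xy2_h = H @ xy1_h
xy2 = xy2_h[:2] / xy2_h[2]                # perspective division
pos2 = xy2[[1, 0]]                        # back to (row, col)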
Example #2
    def forward(self, batch):
        batch = F.relu(batch)

        depth_wise_max = torch.max(batch, dim=1)[0]
        is_depth_wise_max = (batch == depth_wise_max)

        local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
        is_local_max = (batch == local_max)

        detection = torch.min(is_depth_wise_max, is_local_max)
        grid_keypoints = torch.nonzero(detection)[:, 2:]
        keypoints = upscale_positions(grid_keypoints, scaling_steps=3)

        return grid_keypoints, keypoints
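The detector above keeps a location only when its response is both the channel-wise maximum and a 3x3 spatial local maximum. A minimal sketch of those two tests on a dummy feature map (shapes and values are illustrative):

import torch
import torch.nn.functional as F

fmap = torch.rand(1, 8, 16, 16)                          # [B, C, H, W], dummy values
local_max = F.max_pool2d(fmap, 3, stride=1, padding=1)   # 3x3 neighbourhood maximum
is_local_max = (fmap == local_max)                       # True at spatial local maxima
depth_wise_max = fmap.max(dim=1, keepdim=True)[0]        # per-pixel maximum over channels
is_depth_wise_max = (fmap == depth_wise_max)             # True where a channel holds that maximum
detection = is_local_max & is_depth_wise_max             # both conditions must hold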
Example #3
def process_multiscale(image, model, scales=[.5, 1, 2]):
    b, _, h_init, w_init = image.size()
    device = image.device
    assert (b == 1)

    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros(
        [model.dense_feature_extraction.num_channels, 0])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(image,
                                      scale_factor=scale,
                                      mode='bilinear',
                                      align_corners=True)
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(previous_dense_features,
                                            size=[h, w],
                                            mode='bilinear',
                                            align_corners=True)
            del previous_dense_features

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned)
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections

        # Recover displacements.
        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[0, fmap_pos[0, :], fmap_pos[1, :],
                                        fmap_pos[2, :]]
        displacements_j = displacements[1, fmap_pos[0, :], fmap_pos[1, :],
                                        fmap_pos[2, :]]
        del displacements

        mask = torch.min(
            torch.abs(displacements_i) < 0.5,
            torch.abs(displacements_j) < 0.5)
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack(
            [displacements_i[mask], displacements_j[mask]], dim=0)
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device), dense_features[0])
        except EmptyTensorError:
            continue
        fmap_pos = fmap_pos[:, ids]
        fmap_keypoints = fmap_keypoints[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ],
                              dim=0)

        scores = dense_features[0, fmap_pos[0, :], fmap_pos[1, :],
                                fmap_pos[2, :]].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features
    del previous_dense_features, banned

    keypoints = all_keypoints.t().numpy()
    del all_keypoints
    scores = all_scores.numpy()
    del all_scores
    descriptors = all_descriptors.t().numpy()
    del all_descriptors
    return keypoints, scores, descriptors
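A hedged usage sketch for process_multiscale, assuming a D2-Net-style model that exposes the dense_feature_extraction, detection, and localization modules used above, and an already preprocessed image tensor of shape [1, 3, H, W]:

import torch

# image: [1, 3, H, W] float tensor, already normalized for the network (assumption).
with torch.no_grad():
    keypoints, scores, descriptors = process_multiscale(image, model, scales=[.5, 1, 2])
# keypoints: [N, 3] rows of (i, j, 1/scale); scores: [N]; descriptors: [N, C]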
Example #4
File: loss.py (Project: yangsuhui/d2-net)
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch, :, :].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch, :].to(device)  # [9]
        pose1 = batch['pose1'][idx_in_batch, :].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch, :].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch, :, :].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch, :].to(device)
        pose2 = batch['pose2'][idx_in_batch, :].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch, :].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch, :, :, :]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch, :, :].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch, :, :, :]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch, :, :]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)  # [2, _]
        except EmptyTensorError:
            continue
        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)
        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + torch.sum(scores1 * scores2 * F.relu(margin + diff)
                                ) / torch.sum(scores1 * scores2)
        has_grad = True
        n_valid_samples += 1

        if batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(
                batch['image1'][idx_in_batch, :, :, :].cpu().numpy(),
                preprocessing=batch['preprocessing'])
            plt.imshow(im1)
            plt.scatter(pos1_aux[1, :],
                        pos1_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')
            plt.subplot(1, n_sp, 2)
            plt.imshow(
                output['scores1'][idx_in_batch, :, :].data.cpu().numpy(),
                cmap='Reds')
            plt.axis('off')
            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(
                batch['image2'][idx_in_batch, :, :, :].cpu().numpy(),
                preprocessing=batch['preprocessing'])
            plt.imshow(im2)
            plt.scatter(pos2_aux[1, :],
                        pos2_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')
            plt.subplot(1, n_sp, 4)
            plt.imshow(
                output['scores2'][idx_in_batch, :, :].data.cpu().numpy(),
                cmap='Reds')
            plt.axis('off')
            savefig(
                'train_vis/%s/%02d.%02d.%d.png' %
                ('train' if batch['train'] else 'valid', batch['epoch_idx'],
                 batch['batch_idx'] // batch['log_interval'], idx_in_batch),
                dpi=300)
            plt.close()

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
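Each valid pair contributes the score-weighted hard triplet margin term built above: relu(margin + positive_distance - min(negative_distance1, negative_distance2)), weighted by the detection scores and normalized by their sum. A minimal numeric sketch of that reduction on dummy tensors:

import torch
import torch.nn.functional as F

margin = 1.0
n = 5                                            # dummy number of correspondences
scores1, scores2 = torch.rand(n), torch.rand(n)
positive_distance = torch.rand(n) * 2            # matched-descriptor distances, in [0, 2]
negative_distance1 = torch.rand(n) * 2           # hardest negatives mined in image 1
negative_distance2 = torch.rand(n) * 2           # hardest negatives mined in image 2

diff = positive_distance - torch.min(negative_distance1, negative_distance2)
pair_loss = (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
             torch.sum(scores1 * scores2))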
Example #5
    def getGrid(self,
                img1,
                img2,
                minCorr=128,
                scaling_steps=3,
                matcher="FLANN"):
        im1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
        im2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

        surf = cv2.xfeatures2d.SURF_create(100)
        # surf = cv2.xfeatures2d.SIFT_create()

        kp1, des1 = surf.detectAndCompute(im1, None)
        kp2, des2 = surf.detectAndCompute(im2, None)

        if (len(kp1) < minCorr or len(kp2) < minCorr):
            # print("Less correspondences {} {}".format(len(kp1), len(kp2)))
            return [], []

        if (matcher == "BF"):

            bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
            matches = bf.match(des1, des2)
            matches = sorted(matches, key=lambda x: x.distance)

        elif (matcher == "FLANN"):

            FLANN_INDEX_KDTREE = 0
            index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
            search_params = dict(checks=50)
            flann = cv2.FlannBasedMatcher(index_params, search_params)
            matches = flann.knnMatch(des1, des2, k=2)
            good = []
            for m, n in matches:
                if m.distance < 0.7 * n.distance:
                    good.append(m)
            matches = good

        if (len(matches) > 800):
            matches = matches[0:800]
        elif (len(matches) < minCorr):
            return [], []

        # im4 = cv2.drawMatches(im1, kp1, im2, kp2, matches, None, flags=2)
        # cv2.imshow('Image4', im4)
        # cv2.waitKey(0)

        src_pts = np.float32([kp1[m.queryIdx].pt
                              for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt
                              for m in matches]).reshape(-1, 1, 2)
        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if H is None:
            return [], []
        # h1, w1 = int(cropSize/(2**scaling_steps)), int(cropSize/(2**scaling_steps))
        h1, w1 = int(im1.shape[0] / (2**scaling_steps)), int(
            im1.shape[1] / (2**scaling_steps))
        device = torch.device("cpu")

        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(
            fmap_pos1, scaling_steps=scaling_steps).data.cpu().numpy()

        pos1[[0, 1]] = pos1[[1, 0]]

        ones = np.ones((1, pos1.shape[1]))
        pos1Homo = np.vstack((pos1, ones))
        pos2Homo = np.dot(H, pos1Homo)
        pos2Homo = pos2Homo / pos2Homo[2, :]
        pos2 = pos2Homo[0:2, :]

        pos1[[0, 1]] = pos1[[1, 0]]
        pos2[[0, 1]] = pos2[[1, 0]]
        pos1 = pos1.astype(np.float32)
        pos2 = pos2.astype(np.float32)

        ids = []
        for i in range(pos2.shape[1]):
            x, y = pos2[:, i]
            # if(2 < x < (cropSize-2) and 2 < y < (cropSize-2)):
            # if(20 < x < (im1.shape[0]-20) and 20 < y < (im1.shape[1]-20)):
            if (2 < x < (im1.shape[0] - 2) and 2 < y < (im1.shape[1] - 2)):
                ids.append(i)
        pos1 = pos1[:, ids]
        pos2 = pos2[:, ids]

        # for i in range(0, pos1.shape[1], 20):
        # 	im1 = cv2.circle(im1, (pos1[1, i], pos1[0, i]), 1, (0, 0, 255), 2)
        # for i in range(0, pos2.shape[1], 20):
        # 	im2 = cv2.circle(im2, (pos2[1, i], pos2[0, i]), 1, (0, 0, 255), 2)

        # im3 = cv2.hconcat([im1, im2])

        # for i in range(0, pos1.shape[1], 20):
        # 	im3 = cv2.line(im3, (int(pos1[1, i]), int(pos1[0, i])), (int(pos2[1, i]) +  im1.shape[1], int(pos2[0, i])), (0, 255, 0), 1)

        # cv2.imshow('Image', im1)
        # cv2.imshow('Image2', im2)
        # cv2.imshow('Image3', im3)
        # cv2.waitKey(0)

        return pos1, pos2
Example #6
File: loss2.py (Project: kinalmehta/d2-net)
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)

        hOrig, wOrig = int(batch['image1'].shape[2] / 8), int(
            batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue

        # exit(1)

        # H1 = output['H1'][idx_in_batch]
        # H2 = output['H2'][idx_in_batch]

        # try:
        # 	pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        # except IndexError:
        # 	continue

        # ids = idsAlign(pos1, device, h1, w1)

        # img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1, dsize=(400, 400))
        # img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2, dsize=(400, 400))

        # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()

        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        # positive_distance = getPositiveDistance(descriptors1, descriptors2)

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # distance_matrix = getDistanceMatrix(descriptors1, all_descriptors2)

        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        # distance_matrix = getDistanceMatrix(descriptors2, all_descriptors1)

        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       (torch.sum(scores1 * scores2)))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(batch['image1'],
                         batch['image2'],
                         pos1,
                         pos2,
                         batch,
                         idx_in_batch,
                         output,
                         save=True)
            # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / (n_valid_samples)

    return loss
Example #7
    def forward(self, sample, output, plot=False):

        c, h1, w1 = output['dense_features1'].shape
        s1 = output['scores1'].view(-1)

        _, h2, w2 = output['dense_features2'].shape
        s2 = output['scores2'].squeeze(0)

        all_des1 = F.normalize(output['dense_features1'].view(c, -1), dim=0)
        all_des2 = F.normalize(output['dense_features2'], dim=0)

        fmap_pos1 = grid_positions(h1, w1, self.device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=self.scaling_steps)
        des1 = all_des1

        pos1, ids = self.check_boundary(pos1, sample['image1'].shape)
        if ids.shape[0] < pos1.shape[1]:
            fmap_pos1 = fmap_pos1[:, ids]
            des1 = des1[:, ids]
            s1 = s1[ids]

        pos1_homo = torch.cat(
            [pos1, torch.ones(1, pos1.shape[1], device=self.device)], dim=0)
        pos2_homo = torch.matmul(sample['homo12'].squeeze(0), pos1_homo)
        pos2 = pos2_homo[:2, :] / (pos2_homo[2, :] + 1e-8)
        try:
            pos2, ids = self.check_boundary(pos2, sample['image1'].shape)
        except EmptyTensorError:
            raise NoGradientError

        fmap_pos1 = fmap_pos1[:, ids]
        des1 = des1[:, ids]
        s1 = s1[ids]

        fmap_pos2 = torch.round(
            downscale_positions(pos2,
                                scaling_steps=self.scaling_steps)).long()
        des2 = all_des2[:, fmap_pos2[0, :], fmap_pos2[1, :]].view(c, -1)
        all_des2 = all_des2.view(c, -1)
        s2 = s2[fmap_pos2[0, :],
                fmap_pos2[1, :]].view(-1)  # (173) # gather score 2 at the matched positions

        # important
        # (173, 1, 512) @ (173, 512, 1) => (173, 1, 1) => squeeze => (173)
        positive_dist = 2 - 2 * (
            des1.t().unsqueeze(1) @ des2.t().unsqueeze(2)).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, self.device)
        pos_dist = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                             dim=0)[0]
        is_out_of_safe_radius = pos_dist > self.safe_radius

        # (173, 1024) # descriptor distances between image-1 points and all image-2 points
        dist_mat = 2 - 2 * (des1.t() @ all_des2)
        negative_dist2 = torch.min(  # drop neighbours within 4 px (inclusive), then take the hardest negative for descriptor 1
            dist_mat + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]  # (173) hardest negative sample

        # repeat the same computation on image 2 to find the hardest negative for descriptor 2
        all_fmap_pos1 = grid_positions(h1, w1, self.device)
        pos_dist = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                             dim=0)[0]
        is_out_of_safe_radius = pos_dist > self.safe_radius

        dist_mat = 2 - 2 * (des2.t() @ all_des1)
        negative_dist1 = torch.min(dist_mat +
                                   (1 - is_out_of_safe_radius.float()) * 10.,
                                   dim=1)[0]
        # hard loss
        diff = positive_dist - torch.min(negative_dist1, negative_dist2)
        s = s1 * s2
        # F.relu replaces the max() used in the hard loss
        loss = torch.sum(
            s * F.relu(self.margin + diff)) / (torch.sum(s) + 1e-8)
        # the smaller F.relu(margin + diff) is, the larger scores1 * scores2 should be

        loss = loss * self.scale

        return loss
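The hardest-negative search above hides grid cells inside safe_radius by adding a large constant to their distances before the row-wise minimum. A standalone sketch of that masking trick with dummy shapes:

import torch
import torch.nn.functional as F

safe_radius = 4
des1 = F.normalize(torch.rand(512, 7), dim=0)            # matched descriptors, [C, K]
all_des2 = F.normalize(torch.rand(512, 100), dim=0)      # all grid descriptors, [C, HW]
fmap_pos2 = torch.randint(0, 10, (2, 7)).float()         # matched grid coords, [2, K]
all_fmap_pos2 = torch.randint(0, 10, (2, 100)).float()   # all grid coords, [2, HW]

# Chebyshev distance between every match and every grid cell: [K, HW].
pos_dist = torch.max(torch.abs(fmap_pos2.unsqueeze(2) - all_fmap_pos2.unsqueeze(1)), dim=0)[0]
is_out_of_safe_radius = pos_dist > safe_radius

dist_mat = 2 - 2 * (des1.t() @ all_des2)                 # descriptor distances, [K, HW]
# Push in-radius entries above the valid range [0, 2] so the min ignores them.
hardest_negative = torch.min(dist_mat + (1 - is_out_of_safe_radius.float()) * 10., dim=1)[0]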
Example #8
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
    output = model(
        {  # ['dense_features1', 'scores1', 'dense_features2', 'scores2']
            'image1': batch['image1'].to(device),
            'image2': batch['image2'].to(device)
        })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(
            device)  # [h1, w1] (256, 256)
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(
            device)  # [4, 4] extrinsics
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2] top_left_corner

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        # (512, 32, 32)
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size(
        )  # c=512, h1=32, w1=32 # TODO rename c to c1
        scores1 = output['scores1'][idx_in_batch].view(-1)  # 32*32=1024

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()  # TODO assert c2 == c1
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1),
                                       dim=0)  # (512, 32*32=1024)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        # build a 32x32 grid # (2, 1024) coordinates from [0, 0] to [31, 31]
        fmap_pos1 = grid_positions(h1, w1, device)
        # Upscale the grid coordinates: the VGG features shrink the (256, 256) image to (32, 32),
        # so the upscaling loses a lot of localization precision.
        # Shape is [2, 1024]; the 32*32 xy coordinates are upscaled from [3.5, 3.5] to [251.5, 251.5].
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        # Warp the image-1 points to image 2, keeping only the 'good' ones recorded by the indices ids.
        # 'Good' means the depth at the point is positive in both depth maps, the point stays in bounds,
        # and the estimated and sampled depths agree within 0.05.
        try:
            pos1, pos2, ids = warp(  # the count drops from 1024 to 173
                pos1, depth1, intrinsics1, pose1, bbox1, depth2, intrinsics2,
                pose2, bbox2)
        except EmptyTensorError:
            continue
        fmap_pos1 = fmap_pos1[:, ids]  # keep the matching pre-upscale grid coordinates
        descriptors1 = descriptors1[:, ids]  # keep the matching descriptors
        scores1 = scores1[ids]  # keep the matching detection scores

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(  # downscale the warped pos2 back to the (32, 32) scale and round to the nearest integer
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        descriptors2 = F.normalize(  # take the descriptors at fmap_pos2
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0)
        positive_distance = 2 - 2 * (  # distance of each matched descriptor pair, in [0, 2]; smaller means more similar
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze(
        )  # (173, 1, 512) @ (173, 512, 1) => (173, 1, 1) => squeeze => (173)

        all_fmap_pos2 = grid_positions(h2, w2, device)  # build a 32x32 grid
        position_distance = torch.max(
            torch.abs(
                fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
            dim=0
        )[0]  # (173, 1024) # distance from fmap_pos2 to every grid point (max of the per-axis pixel distances)
        # safe_radius=4: points farther than 4 pixels are not considered neighbours
        is_out_of_safe_radius = position_distance > safe_radius
        # (173, 1024) # descriptor distances between image-1 points and all image-2 points
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        negative_distance2 = torch.min(  # drop neighbours within 4 px (inclusive), then take the hardest negative for descriptor 1
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]  # (173) hardest negative sample
        # repeat the same computation on image 2 to find the hardest negative for descriptor 2
        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]
        # hard loss
        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :],
                          fmap_pos2[1, :]]  # (173) # gather score 2 at the matched positions

        loss = loss + (  # F.relu replaces the max() used in the hard loss
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )  # the smaller F.relu(margin + diff) is, the larger scores1 * scores2 should be

        has_grad = True  # reaching this point means the pair produced valid correspondences
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                               preprocessing=batch['preprocessing'])
            plt.imshow(im1)
            plt.scatter(pos1_aux[1, :],
                        pos1_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')

            plt.subplot(1, n_sp, 2)
            plt.imshow(output['scores1'][idx_in_batch].data.cpu().numpy(),
                       cmap='Reds')
            plt.axis('off')

            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(batch['image2'][idx_in_batch].cpu().numpy(),
                               preprocessing=batch['preprocessing'])
            plt.imshow(im2)
            plt.scatter(pos2_aux[1, :],
                        pos2_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')

            plt.subplot(1, n_sp, 4)
            plt.imshow(output['scores2'][idx_in_batch].data.cpu().numpy(),
                       cmap='Reds')
            plt.axis('off')

            savefig(
                'train_vis/%s.%02d.%02d.%d.png' %
                ('train' if batch['train'] else 'valid', batch['epoch_idx'],
                 batch['batch_idx'] // batch['log_interval'], idx_in_batch),
                dpi=300)
            plt.close()

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
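The positive_distance line pairs each image-1 descriptor with its warped counterpart through a batched matrix product. A standalone shape check of that expression (counts follow the comments above):

import torch
import torch.nn.functional as F

k, c = 173, 512
descriptors1 = F.normalize(torch.rand(c, k), dim=0)
descriptors2 = F.normalize(torch.rand(c, k), dim=0)

# (k, 1, c) @ (k, c, 1) -> (k, 1, 1) -> squeeze -> (k,)
cos_sim = (descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)).squeeze()
positive_distance = 2 - 2 * cos_sim                  # in [0, 2]; smaller means more similar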
Example #9
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)

        hOrig, wOrig = int(batch['image1'].shape[2] / 8), int(
            batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        # Sparse keypoint detection (ORB is used; SIFT/SURF variants are commented out below)

        imgNp1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                              preprocessing=batch['preprocessing'])
        imgNp1 = cv2.cvtColor(imgNp1, cv2.COLOR_BGR2RGB)
        # surf = cv2.xfeatures2d.SIFT_create(100)
        #surf = cv2.xfeatures2d.SURF_create(80)
        orb = cv2.ORB_create(nfeatures=100, scoreType=cv2.ORB_FAST_SCORE)

        kp = orb.detect(imgNp1, None)

        keyP = [(kp[i].pt) for i in range(len(kp))]
        keyP = np.asarray(keyP).T
        keyP[[0, 1]] = keyP[[1, 0]]
        keyP = np.floor(keyP) + 0.5

        pos1 = torch.from_numpy(keyP).to(pos1.device).float()

        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue

        ids = idsAlign(pos1, device, h1, w1)

        # cv2.drawKeypoints(imgNp1, kp, imgNp1)
        # cv2.imshow('Keypoints', imgNp1)
        # cv2.waitKey(0)

        # drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=False)

        # exit(1)

        # Top view homography adjustment

        # H1 = output['H1'][idx_in_batch]
        # H2 = output['H2'][idx_in_batch]

        # try:
        # 	pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        # except IndexError:
        # 	continue

        # ids = idsAlign(pos1, device, h1, w1)

        # img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1, dsize=(400, 400))
        # img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2, dsize=(400, 400))

        # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            print(ids.size(0))
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()

        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        # positive_distance = getPositiveDistance(descriptors1, descriptors2)

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # distance_matrix = getDistanceMatrix(descriptors1, all_descriptors2)

        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        # distance_matrix = getDistanceMatrix(descriptors2, all_descriptors1)

        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       (torch.sum(scores1 * scores2)))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            print("Inside plot.")
            drawTraining(batch['image1'],
                         batch['image2'],
                         pos1,
                         pos2,
                         batch,
                         idx_in_batch,
                         output,
                         save=True)
            # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / (n_valid_samples)

    return loss
Example #10
def process_multiscale(image, model, scales=[.25, 0.50, 1.0]):
    b, _, h_init, w_init = image.size()
    device = image.device
    assert(b == 1)

    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros([
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
            del previous_dense_features

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections
        # vis

        """
        fig = plt.figure()

        #plt.subplot(2, 1, 2)
        #plt.imshow(img_out)
        for i in range(25):
            vismap = dense_features[0,i,::,::]
            #
            vismap = vismap.cpu()

            #use sigmod to [0,1]
            vismap= 1.0/(1+np.exp(-1*vismap))

            # to [0,255]
            vismap=np.round(vismap*255)
            vismap=vismap.data.numpy()
            plt.subplot(5, 5, i+1)
            plt.axis('off')
            plt.imshow(vismap)
            filename = '/home/asky/featuremap/CH%d.jpg'% (i)

            #cv2.imwrite(filename,vismap)

        plt.tight_layout()
        fig.show()
        """
        # Recover displacements.
        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        mask = torch.min(
            torch.abs(displacements_i) < 0.5,
            torch.abs(displacements_j) < 0.5
        )
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack([
            displacements_i[mask],
            displacements_j[mask]
        ], dim=0)
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1 :, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            continue
        fmap_pos = fmap_pos[:, ids]
        fmap_keypoints = fmap_keypoints[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ], dim=0)

        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features
    del previous_dense_features, banned

    keypoints = all_keypoints.t().numpy()
    del all_keypoints
    scores = all_scores.numpy()
    del all_scores
    descriptors = all_descriptors.t().numpy()
    del all_descriptors
    return keypoints, scores, descriptors
Example #11
def process_multiscale(image, model, scales=None):
    if scales is None:
        scales = [.5, 1, 2]
    b, _, h_init, w_init = image.size()  # b = 1, h_init = 256, w_init = 256
    device = image.device
    assert (b == 1)

    all_keypoints = torch.zeros([3, 0])  # cyx coordinates; c is the depth (scale) entry
    all_descriptors = torch.zeros([  # (512, 0) 512-dimensional descriptor vectors
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)  # keypoint responses

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(  # rescale the image to the current scale by interpolation
            image,
            scale_factor=scale,
            mode='bilinear',
            align_corners=True)
        _, _, h_level, w_level = current_image.size()  # height and width after interpolation

        # (1, 512, 63, 63) extracted feature size; during training a 256 input is reduced to 32.
        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()  # h = 63, w = 63

        # Sum the feature maps.
        if previous_dense_features is not None:
            # If features were computed at the previous scale, resample them to the current size.
            # With the default scales [0.5, 1., 2.] this interpolation is an upsampling;
            # with the reverse order [2., 1., 0.5] it would be a downsampling.

            # After resampling to the same size, the previous feature map is added to the
            # current one to form the new previous_dense_features.
            dense_features += F.interpolate(previous_dense_features,
                                            size=[h, w],
                                            mode='bilinear',
                                            align_corners=True)
            del previous_dense_features

        # Recover detections.
        # (1, 512, 63, 63)
        # In the multiscale case the keypoints are predicted from the fused feature map;
        # the prediction is a (1, 512, 63, 63) detection map whose values are 0 or 1.
        detections = model.detection(dense_features)
        if banned is not None:
            # banned marks keypoints detected at the previous scale; they are banned to avoid duplicate detections.
            # As with the features, resample the previous banned map to the current size # (1, 1, h, w)
            banned = F.interpolate(banned.float(), size=[h, w]).byte()
            # Zero out positions where a keypoint was already detected at the previous scale.
            detections = torch.min(detections.byte(), (1 - banned))
            banned = torch.max(  # merge the previous detections with the current ones
                torch.max(detections, dim=1)[0].unsqueeze(1), banned)
        else:
            # (1, 1, 63, 63)
            # Since detection values are 0 or 1, a max of 1 means a keypoint is present at that location.
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        # (3, 752) # integer cyx coordinates of the detected keypoints
        fmap_pos = torch.nonzero(detections[0]).t().cpu()
        del detections

        # Recover displacements.
        # The input dense_features is (1, 512, 63, 63); the output is (2, 512, 63, 63),
        # the yx coordinate offset of every point.
        displacements = model.localization(dense_features)[0]
        # take the y-coordinate offsets at the fmap_pos locations
        displacements_i = displacements[0, fmap_pos[0, :], fmap_pos[1, :],
                                        fmap_pos[2, :]]  # (752)
        # take the x-coordinate offsets at the fmap_pos locations
        displacements_j = displacements[1, fmap_pos[0, :], fmap_pos[1, :],
                                        fmap_pos[2, :]]  # (752)
        del displacements

        # Set the mask to 1 for keypoints whose offsets are both below 0.5 in absolute value;
        # if either offset is 0.5 or larger, the mask is 0.
        mask = torch.min(
            torch.abs(displacements_i) < 0.5,
            torch.abs(displacements_j) < 0.5)  # (752) mask.sum() = 718
        # keep only the integer coordinates of keypoints whose offsets are below 0.5
        fmap_pos = fmap_pos[:, mask]  # (3, 718)
        valid_displacements = torch.stack(
            [displacements_i[mask], displacements_j[mask]],
            dim=0).cpu()  # (2, 718)
        del mask, displacements_i, displacements_j

        # fmap_pos[1 :, :].shape => (2, 718)
        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements  # (yx)
        del valid_displacements

        try:
            # (interpolate the features at the keypoint coordinates to obtain the final descriptors)
            # raw_descriptors => (512, 718), ids => 718
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),  # (2, 718) - yx
                dense_features[0]  # dense_features[0].shape => (512, 63, 63)
            )
        except EmptyTensorError:
            continue
        fmap_pos = fmap_pos[:, ids]  # (3, 718) integer coordinates
        fmap_keypoints = fmap_keypoints[:, ids]  # (2, 718) sub-pixel coordinates
        del ids

        keypoints = upscale_positions(fmap_keypoints,
                                      scaling_steps=2)  # undo the 1/4 downsampling of vgg16
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()  # L2-norm
        del raw_descriptors

        keypoints[0, :] *= h_init / h_level  # undo the initial resize
        keypoints[1, :] *= w_init / w_level  # undo the initial resize

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        keypoints = torch.cat(
            [  # append a third coordinate holding 1/scale
                keypoints,
                torch.ones([1, keypoints.size(1)]) * 1 /
                scale,  # account for the current scale
            ],
            dim=0)

        # Take the score from the feature map at the integer keypoint coordinates,
        # then divide by (idx + 1) = 1, 2, 3 for scale index idx = 0, 1, 2.
        scores = dense_features[0, fmap_pos[0, :], fmap_pos[1, :],
                                fmap_pos[2, :]].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints],
                                  dim=1)  # (3, 718)
        all_descriptors = torch.cat([all_descriptors, descriptors],
                                    dim=1)  # (512, 718)
        all_scores = torch.cat([all_scores, scores], dim=0)  # (718)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features
    del previous_dense_features, banned

    keypoints = all_keypoints.t().numpy()  # (718, 3)
    del all_keypoints
    scores = all_scores.numpy()  # (718)
    del all_scores
    descriptors = all_descriptors.t().numpy()  # (718, 512)
    del all_descriptors
    return keypoints, scores, descriptors
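The returned NumPy arrays are ready for downstream matching; a minimal sketch of keeping only the k strongest detections (k is an arbitrary choice here, not part of the function above):

import numpy as np

# keypoints: [N, 3] rows of (i, j, 1/scale), scores: [N], descriptors: [N, C].
k = 1000
order = np.argsort(-scores)[:k]        # indices of the k highest scores
keypoints = keypoints[order]
scores = scores[order]
descriptors = descriptors[order]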