def display(self, img1, img2):
        plt.figure()

        im1 = imshow_image(img1[0].cpu().numpy(), preprocessing='caffe')

        im2 = imshow_image(img2[0].cpu().numpy(), preprocessing='caffe')

        plt.subplot(1, 2, 1)
        plt.imshow(im1)
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(im2)
        plt.axis('off')

        plt.show()

        exit(1)  # debug helper: stop the run after showing the first pair
Example #2
def draw_matches(img1, cv_kpts1, img2, cv_kpts2, match_ids, match_color=(0, 255, 0), pt_color=(0, 0, 255)):
    """Draw keypoint matches between two images and save the visualization."""
    print(f'matches: {len(match_ids[0])}, {len(match_ids[1])}')
    img1 = imshow_image(img1.cpu().numpy(), 'torch')
    cv_kpts1 = cv_kpts1.cpu().numpy().T
    img2 = imshow_image(img2.cpu().numpy(), 'torch')
    cv_kpts2 = cv_kpts2.cpu().numpy().T
    good_matches = []
    for id1, id2 in zip(*match_ids):
        match = cv2.DMatch()
        match.queryIdx = id1
        match.trainIdx = id2
        good_matches.append(match)
    mask = np.ones((len(good_matches), ))
    if isinstance(cv_kpts1, np.ndarray) and isinstance(cv_kpts2, np.ndarray):
        cv_kpts1 = [cv2.KeyPoint(cv_kpts1[i][0], cv_kpts1[i][1], 1) for i in range(cv_kpts1.shape[0])]
        cv_kpts2 = [cv2.KeyPoint(cv_kpts2[i][0], cv_kpts2[i][1], 1) for i in range(cv_kpts2.shape[0])]
    display = cv2.drawMatches(img1, cv_kpts1, img2, cv_kpts2, good_matches,
                              None, matchColor=match_color,
                              singlePointColor=pt_color,
                              matchesMask=mask.ravel().tolist(),
                              flags=4)  # 4 = draw rich keypoints
    cv2.imwrite('match_vis_tmp_' + str(time.time()) + '.png', display)
    return display
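
A minimal usage sketch for draw_matches follows. The tensor layouts ([3, H, W] 'torch'-preprocessed images, [2, N] keypoint coordinates) and the availability of imshow_image are assumptions inferred from the calls inside the function, not confirmed by the source; the identity matching is purely illustrative.

import torch

# Hypothetical inputs; layouts inferred from the .cpu().numpy()/.T calls in
# draw_matches itself (cv2, numpy, and time are assumed imported for the
# function above, as in the rest of these examples).
img1 = torch.rand(3, 240, 320)   # assumed normalized CHW image
img2 = torch.rand(3, 240, 320)
kpts1 = torch.rand(2, 50) * 200  # [2, N] coordinates; transposed inside
kpts2 = torch.rand(2, 50) * 200
match_ids = (list(range(50)), list(range(50)))  # identity matching, demo only

vis = draw_matches(img1, kpts1, img2, kpts2, match_ids)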
Example #3
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch, :, :].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch, :].to(device)  # [9]
        pose1 = batch['pose1'][idx_in_batch, :].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch, :].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch, :, :].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch, :].to(device)
        pose2 = batch['pose2'][idx_in_batch, :].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch, :].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch, :, :, :]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch, :, :].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch, :, :, :]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch, :, :]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)  # [2, _]
        except EmptyTensorError:
            continue
        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)
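        # For L2-normalized descriptors, squared Euclidean distance equals
        # 2 - 2 * cosine similarity, which is what is computed below.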
        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
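        # Adding a large constant to in-radius entries makes them ineligible
        # as negatives, so the min picks the hardest negative outside the
        # safe radius.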
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        # Weighted average of the margin terms, using the product of the two
        # detection scores as per-correspondence weights.
        loss = loss + torch.sum(scores1 * scores2 * F.relu(margin + diff)
                                ) / torch.sum(scores1 * scores2)
        has_grad = True
        n_valid_samples += 1

        if batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(
                batch['image1'][idx_in_batch, :, :, :].cpu().numpy(),
                preprocessing=batch['preprocessing'])
            plt.imshow(im1)
            plt.scatter(pos1_aux[1, :],
                        pos1_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')
            plt.subplot(1, n_sp, 2)
            plt.imshow(
                output['scores1'][idx_in_batch, :, :].data.cpu().numpy(),
                cmap='Reds')
            plt.axis('off')
            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(
                batch['image2'][idx_in_batch, :, :, :].cpu().numpy(),
                preprocessing=batch['preprocessing'])
            plt.imshow(im2)
            plt.scatter(pos2_aux[1, :],
                        pos2_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')
            plt.subplot(1, n_sp, 4)
            plt.imshow(
                output['scores2'][idx_in_batch, :, :].data.cpu().numpy(),
                cmap='Reds')
            plt.axis('off')
            savefig(
                'train_vis/%s/%02d.%02d.%d.png' %
                ('train' if batch['train'] else 'valid', batch['epoch_idx'],
                 batch['batch_idx'] // batch['log_interval'], idx_in_batch),
                dpi=300)
            plt.close()

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
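
A minimal sketch of a training step that drives this loss_function; the DataLoader, optimizer, and the extra batch fields ('train', 'epoch_idx', 'batch_idx', 'log_interval') are assumptions inferred from the keys the loss reads, not a confirmed API.

import torch

def train_epoch(model, loader, optimizer, device, epoch_idx, log_interval=100):
    model.train()
    for batch_idx, batch in enumerate(loader):
        # Bookkeeping fields the loss expects on top of the dataset's keys.
        batch['train'] = True
        batch['epoch_idx'] = epoch_idx
        batch['batch_idx'] = batch_idx
        batch['log_interval'] = log_interval
        try:
            loss = loss_function(model, batch, device)
        except NoGradientError:
            # No image pair in the batch produced valid correspondences.
            continue
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()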
Example #4
def drawTraining(image1,
                 image2,
                 pos1,
                 pos2,
                 batch,
                 idx_in_batch,
                 output,
                 save=False):
    pos1_aux = pos1.cpu().numpy()
    pos2_aux = pos2.cpu().numpy()

    k = pos1_aux.shape[1]
    col = np.random.rand(k, 3)
    n_sp = 4
    plt.figure()
    plt.subplot(1, n_sp, 1)
    im1 = imshow_image(image1[0].cpu().numpy(),
                       preprocessing=batch['preprocessing'])
    plt.imshow(im1)
    plt.scatter(pos1_aux[1, :],
                pos1_aux[0, :],
                s=0.25**2,
                c=col,
                marker=',',
                alpha=0.5)
    plt.axis('off')
    plt.subplot(1, n_sp, 2)
    plt.imshow(output['scores1'][idx_in_batch].data.cpu().numpy(), cmap='Reds')
    plt.axis('off')
    plt.subplot(1, n_sp, 3)
    im2 = imshow_image(image2[0].cpu().numpy(),
                       preprocessing=batch['preprocessing'])
    plt.imshow(im2)
    plt.scatter(pos2_aux[1, :],
                pos2_aux[0, :],
                s=0.25**2,
                c=col,
                marker=',',
                alpha=0.5)
    plt.axis('off')
    plt.subplot(1, n_sp, 4)
    plt.imshow(output['scores2'][idx_in_batch].data.cpu().numpy(), cmap='Reds')
    plt.axis('off')

    if save:
        savefig('train_vis/%s.%02d.%02d.%d.png' %
                ('train' if batch['train'] else 'valid', batch['epoch_idx'],
                 batch['batch_idx'] // batch['log_interval'], idx_in_batch),
                dpi=300)
    else:
        plt.show()

    plt.close()

    im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
    im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)

    for i in range(0, pos1_aux.shape[1], 5):
        im1 = cv2.circle(im1, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), 1,
                         (0, 0, 255), 2)
    for i in range(0, pos2_aux.shape[1], 5):
        im2 = cv2.circle(im2, (int(pos2_aux[1, i]), int(pos2_aux[0, i])), 1,
                         (0, 0, 255), 2)

    im3 = cv2.hconcat([im1, im2])

    for i in range(0, pos1_aux.shape[1], 5):
        im3 = cv2.line(
            im3, (int(pos1_aux[1, i]), int(pos1_aux[0, i])),
            (int(pos2_aux[1, i]) + im1.shape[1], int(pos2_aux[0, i])),
            (0, 255, 0), 1)

    if save:
        cv2.imwrite(
            'train_vis/%s.%02d.%02d.%d.png' %
            ('train_corr' if batch['train'] else 'valid', batch['epoch_idx'],
             batch['batch_idx'] // batch['log_interval'], idx_in_batch), im3)
    else:
        cv2.imshow('Image', im3)
        cv2.waitKey(0)
Example #5
def loss_function_PT(model,
                     batch,
                     device,
                     margin=1,
                     safe_radius=4,
                     scaling_steps=3,
                     plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        fmap_pos1 = grid_positions(h1, w1, device)

        # hOrig, wOrig = int(batch['image1'].shape[2]/8), int(batch['image1'].shape[3]/8)
        # fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        #pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        # get correspondences
        img1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                            preprocessing=batch['preprocessing'])
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
        img2 = imshow_image(batch['image2'][idx_in_batch].cpu().numpy(),
                            preprocessing=batch['preprocessing'])
        img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

        pos1, pos2 = getCorr(img1, img2)
        if len(pos1) == 0 or len(pos2) == 0:
            continue

        pos1 = torch.from_numpy(pos1.astype(np.float32)).to(device)
        pos2 = torch.from_numpy(pos2.astype(np.float32)).to(device)
        img1 = torch.from_numpy(img1.astype(np.float32)).to(device)
        img2 = torch.from_numpy(img2.astype(np.float32)).to(device)

        # print('p1', pos1.size())
        # print('p2',pos2.size())
        ids = idsAlign(pos1, device, h1, w1)
        #print('ids', ids)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available

        if ids.size(0) < 128:
            print('skipping pair: only %d correspondences' % ids.size(0))
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()

        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        #positive_distance = getPositiveDistance(descriptors1, descriptors2)

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        #distance_matrix = getDistanceMatrix(descriptors1, all_descriptors2)

        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        #negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        #distance_matrix = getDistanceMatrix(descriptors2, all_descriptors1)

        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        #negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)
        print('positive_distance_min', torch.min(positive_distance))
        print('negative_distance1_min', torch.min(negative_distance1))
        print('positive_distance_max', torch.max(positive_distance))
        print('negative_distance1_max', torch.max(negative_distance1))
        print('positive_distance_mean', torch.mean(positive_distance))
        print('negative_distance1_mean', torch.mean(negative_distance1))

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       (torch.sum(scores1 * scores2)))

        print('scores1_min', torch.min(scores1))
        print('scores1_max', torch.max(scores1))
        print('scores1_mean', torch.mean(scores1))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(batch['image1'],
                         batch['image2'],
                         pos1,
                         pos2,
                         batch,
                         idx_in_batch,
                         output,
                         save=True)
            #drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError
    print('scores1', scores1)
    print('scores2', scores2)
    loss = loss / n_valid_samples

    return loss
Example #6
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
    output = model(
        {  # ['dense_features1', 'scores1', 'dense_features2', 'scores2']
            'image1': batch['image1'].to(device),
            'image2': batch['image2'].to(device)
        })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(
            device)  # [h1, w1] (256, 256)
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(
            device)  # [4, 4] extrinsics
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2] top_left_corner

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        # (512, 32, 32)
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size(
        )  # c=512, h1=32, w1=32 # TODO rename c to c1
        scores1 = output['scores1'][idx_in_batch].view(-1)  # 32*32=1024

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()  # TODO assert c2 == c1
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1),
                                       dim=0)  # (512, 32*32=1024)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        # Build a 32*32 grid of coordinates: shape (2, 1024), from [0, 0] to [31, 31]
        fmap_pos1 = grid_positions(h1, w1, device)
        # Upscale the grid coordinates back to image resolution: the VGG
        # backbone shrinks the (256, 256) image to (32, 32), so upscaling
        # loses a lot of localization precision. The shape stays [2, 1024];
        # the 32*32 xy coordinates map from [3.5, 3.5] up to [251.5, 251.5].
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        # Warp image-1 points onto image 2, keeping only the 'good' points,
        # recorded by the indices ids. 'Good' means the depth at a point is
        # positive in both depth maps, the warped point stays in bounds, and
        # the estimated vs. sampled depth agree within 0.05.
        try:
            pos1, pos2, ids = warp(  # e.g. the count drops from 1024 to 173
                pos1, depth1, intrinsics1, pose1, bbox1, depth2, intrinsics2,
                pose2, bbox2)
        except EmptyTensorError:
            continue
        fmap_pos1 = fmap_pos1[:, ids]  # matching pre-upscaling grid coordinates
        descriptors1 = descriptors1[:, ids]  # matching descriptors
        scores1 = scores1[ids]  # matching detection scores

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        # Downscale the warped pos2 back to the (32, 32) feature-map scale
        # and round to the nearest integer cell.
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        # Gather the descriptors at fmap_pos2.
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0)
        # Distance between paired descriptors; range [0, 2], smaller means
        # more similar.
        # (173, 1, 512) @ (173, 512, 1) => (173, 1, 1) => squeeze => (173)
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)  # the 32*32 grid again
        # (173, 1024) Chebyshev distance from each fmap_pos2 to every grid
        # cell (the larger of the x and y pixel distances).
        position_distance = torch.max(
            torch.abs(
                fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
            dim=0)[0]
        # safe_radius=4: anything farther than 4 pixels no longer counts as
        # a neighbor of the point.
        is_out_of_safe_radius = position_distance > safe_radius
        # (173, 1024) descriptor distances between image-1 points and all
        # image-2 grid cells.
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # Exclude neighbors (within 4 pixels, inclusive), then take the
        # hardest negative sample for each descriptor 1: shape (173).
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]
        # Repeat the same computation over image 1's grid to find the
        # hardest negative sample for each descriptor 2.
        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]
        # hard loss
        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :],
                          fmap_pos2[1, :]]  # (173) gather scores 2

        # F.relu replaces the max(0, .) of the hard margin loss. The score
        # weighting pushes scores1 * scores2 to be large exactly where
        # F.relu(margin + diff) is small.
        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2))

        has_grad = True  # reaching this point means the sample passed all checks
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                               preprocessing=batch['preprocessing'])
            plt.imshow(im1)
            plt.scatter(pos1_aux[1, :],
                        pos1_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')

            plt.subplot(1, n_sp, 2)
            plt.imshow(output['scores1'][idx_in_batch].data.cpu().numpy(),
                       cmap='Reds')
            plt.axis('off')

            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(batch['image2'][idx_in_batch].cpu().numpy(),
                               preprocessing=batch['preprocessing'])
            plt.imshow(im2)
            plt.scatter(pos2_aux[1, :],
                        pos2_aux[0, :],
                        s=0.25**2,
                        c=col,
                        marker=',',
                        alpha=0.5)
            plt.axis('off')

            plt.subplot(1, n_sp, 4)
            plt.imshow(output['scores2'][idx_in_batch].data.cpu().numpy(),
                       cmap='Reds')
            plt.axis('off')

            savefig(
                'train_vis/%s.%02d.%02d.%d.png' %
                ('train' if batch['train'] else 'valid', batch['epoch_idx'],
                 batch['batch_idx'] // batch['log_interval'], idx_in_batch),
                dpi=300)
            plt.close()

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
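
The [3.5, 3.5] to [251.5, 251.5] mapping quoted in the comments above is consistent with upscale_positions doubling and shifting the coordinates once per scaling step. A plausible sketch of that helper, an assumption here since the source does not show it:

import torch

def upscale_positions_sketch(pos, scaling_steps=0):
    # Each step maps a cell index to the center of the corresponding 2x2
    # block one level up: 0 -> 0.5 -> 1.5 -> 3.5 for scaling_steps=3, and
    # 31 -> 251.5, matching the ranges quoted in the comments.
    for _ in range(scaling_steps):
        pos = pos * 2 + 0.5
    return pos

print(upscale_positions_sketch(torch.tensor([0., 31.]), scaling_steps=3))
# tensor([  3.5000, 251.5000])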
Example #7
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)

        hOrig, wOrig = int(batch['image1'].shape[2] / 8), int(
            batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        # This grid-based pos1 only supplies the target device below; it is
        # replaced by the ORB keypoint positions.
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        # SIFT Feature Detection

        imgNp1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                              preprocessing=batch['preprocessing'])
        imgNp1 = cv2.cvtColor(imgNp1, cv2.COLOR_BGR2RGB)
        # surf = cv2.xfeatures2d.SIFT_create(100)
        # surf = cv2.xfeatures2d.SURF_create(80)
        orb = cv2.ORB_create(nfeatures=100, scoreType=cv2.ORB_FAST_SCORE)

        kp = orb.detect(imgNp1, None)

        # Collect ORB keypoint locations as a [2, N] array, swap (x, y) into
        # (row, col) order, and snap each point to its pixel center.
        keyP = np.asarray([p.pt for p in kp]).T
        keyP[[0, 1]] = keyP[[1, 0]]
        keyP = np.floor(keyP) + 0.5

        pos1 = torch.from_numpy(keyP).to(pos1.device).float()

        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue

        ids = idsAlign(pos1, device, h1, w1)

        # cv2.drawKeypoints(imgNp1, kp, imgNp1)
        # cv2.imshow('Keypoints', imgNp1)
        # cv2.waitKey(0)

        # drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=False)

        # exit(1)

        # Top view homography adjustment

        # H1 = output['H1'][idx_in_batch]
        # H2 = output['H2'][idx_in_batch]

        # try:
        # 	pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        # except IndexError:
        # 	continue

        # ids = idsAlign(pos1, device, h1, w1)

        # img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1, dsize=(400, 400))
        # img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2, dsize=(400, 400))

        # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            print('skipping pair: only %d correspondences' % ids.size(0))
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()

        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        # positive_distance = getPositiveDistance(descriptors1, descriptors2)

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # distance_matrix = getDistanceMatrix(descriptors1, all_descriptors2)

        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        # distance_matrix = getDistanceMatrix(descriptors2, all_descriptors1)

        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       (torch.sum(scores1 * scores2)))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            print("Inside plot.")
            drawTraining(batch['image1'],
                         batch['image2'],
                         pos1,
                         pos2,
                         batch,
                         idx_in_batch,
                         output,
                         save=True)
            # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
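
Two of the examples above call an idsAlign helper whose source is not shown. Given that its result indexes the flattened h1*w1 descriptor grid (descriptors1[:, ids]), a plausible sketch follows; the name, the scaling_steps parameter, and the whole implementation are assumptions, not the project's actual code.

import torch

def idsAlign_sketch(pos1, device, h1, w1, scaling_steps=3):
    # Hypothetical: map image-scale positions to flattened feature-map
    # indices (row * w1 + col), clamped to the grid, so the result can
    # index descriptors of shape [c, h1 * w1].
    fmap_pos = pos1.clone()
    for _ in range(scaling_steps):  # inverse of the upscaling sketch above
        fmap_pos = (fmap_pos - 0.5) / 2
    fmap_pos = torch.round(fmap_pos).long()
    fmap_pos[0] = fmap_pos[0].clamp(0, h1 - 1)
    fmap_pos[1] = fmap_pos[1].clamp(0, w1 - 1)
    return (fmap_pos[0] * w1 + fmap_pos[1]).to(device)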