Example #1
def im_show_1(image=None, title='', rows=1, cols=1, index=1):
    """ Show image (transfer tensor to numpy first) """
    image = normalize_image(image, forward=False)
    image = image.permute(1, 2, 0).cpu().numpy()
    ax = plt.subplot(rows, cols, index)
    ax.set_title(title)
    ax.imshow(image.clip(0, 1))

    return ax
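These snippets assume a project-level normalize_image helper. A minimal stand-in, under the assumption that forward=False undoes the standard ImageNet mean/std normalization, together with a usage sketch:

import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for the project's normalize_image helper;
# the mean/std values are the standard torchvision ones (an assumption,
# not confirmed by these snippets).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize_image(image, forward=True):
    mean, std = IMAGENET_MEAN.to(image.device), IMAGENET_STD.to(image.device)
    return (image - mean) / std if forward else image * std + mean

# Usage: show two normalized (3, H, W) tensors side by side.
for idx, img in enumerate([torch.rand(3, 240, 240), torch.rand(3, 240, 240)], start=1):
    im_show_1(normalize_image(img), title='image {}'.format(idx), rows=1, cols=2, index=idx)
plt.show()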
Example #2
def vis_feature(vis, model, dataloader, use_cuda=True):
    # Visualize the feature maps of one watched image
    h = 40           # spatial size to which each feature map is upsampled
    num = 1024       # number of channels to display
    sample_idx = 4   # index of the batch used for visualization
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)
        theta, feature_A, feature_B, correlation = model(batch)
        if batch_idx == sample_idx:
            break
    watch_feature_A = F.interpolate(feature_A, size=(h, h), mode='bilinear', align_corners=True).transpose(0, 1)[0:num, :, :, :]
    watch_feature_B = F.interpolate(feature_B, size=(h, h), mode='bilinear', align_corners=True).transpose(0, 1)[0:num, :, :, :]

    opts = dict(jpgquality=100, title='source image')
    image_A = normalize_image(batch['source_image'][0], forward=False) * 255.0
    vis.image(image_A, opts=opts)

    nrow = 32
    padding = 3
    opts = dict(jpgquality=100, title='feature map A')
    # vis.images(watch_feature_A * 255.0, nrow=nrow, padding=padding, opts=opts)
    vis.image(torchvision.utils.make_grid(watch_feature_A * 255.0, nrow=nrow, padding=padding), opts=opts)
    # vis.image(watch_feature_A[0], opts=opts)

    opts = dict(jpgquality=100, title='target image')
    image_B = normalize_image(batch['target_image'][0], forward=False) * 255.0
    vis.image(image_B, opts=opts)

    opts = dict(jpgquality=100, title='feature map B')
    vis.image(torchvision.utils.make_grid(watch_feature_B * 255.0, nrow=nrow, padding=padding), opts=opts)
    # vis.images(watch_feature_B * 255.0, nrow=nrow, padding=padding, opts=opts)
    # vis.image(watch_feature_B[0], opts=opts)

    # opts = dict(title='correlation')
    # vis.heatmap(correlation[0, 0, :, :], opts=opts)
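The functions in these examples also assume a live visdom server behind the vis handle and a batch_cuda helper. A minimal sketch of both assumptions:

import torch
import visdom

# Assumed setup: a visdom client (the `vis` handle used throughout),
# talking to a running `python -m visdom.server`.
vis = visdom.Visdom(env='watch')

def batch_cuda(batch):
    # Hypothetical stand-in for the project's batch_cuda helper: move
    # every tensor in the batch dict onto the default CUDA device.
    return {k: v.cuda() if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}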
Example #3
    def draw_image(images, names, flows):
        images[:, :, 0:240, :] = normalize_image(images[:, :, 0:240, :],
                                                 forward=False)
        images *= 255.0
        images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        for i in range(images.shape[0]):
            pos_name = (80, 255)
            if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
                pos_flow = (0, 0)
            else:
                pos_flow = (70, 275)
            cv2.putText(images[i], names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
            cv2.putText(images[i], flows[i], pos_flow, fnt, 0.5, (0, 0, 0), 1)

        images = torch.Tensor(images.astype(np.float32))
        images = images.permute(0, 3, 1, 2)

        return images
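The tensor-to-uint8 round trip in draw_image recurs in every example below. A self-contained sketch of the pattern, assuming the input is already un-normalized to [0, 1]:

import cv2
import numpy as np
import torch

def annotate_tensor(image, text, pos=(80, 255)):
    # CHW float tensor in [0, 1] -> HWC uint8 array for OpenCV text
    # drawing -> back to a CHW float tensor for make_grid.
    array = (image.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    array = np.ascontiguousarray(array)
    cv2.putText(array, text, pos, cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
    return torch.from_numpy(array.astype(np.float32)).permute(2, 0, 1)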
Example #4
def vis_pf(vis,
           dataloader,
           theta,
           theta_weak,
           theta_inver,
           theta_weak_inver,
           results,
           results_weak,
           dataset_name,
           use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    pt = PointTnf(use_cuda=use_cuda)

    watch_images = torch.ones(len(dataloader) * 6, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * 6, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * 6).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # Pixel means for un-normalizing Caffe ResNet and VGG inputs
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        theta_aff_inver = theta_inver['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps_inver = theta_inver['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff_inver = theta_weak_inver['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps_inver = theta_weak_inver['aff_tps'][
            batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)
        warped_weak_aff = affTnf(batch['source_image'], theta_weak_aff)
        warped_weak_aff_tps = tpsTnf(warped_weak_aff, theta_weak_aff_tps)

        watch_images[batch_idx * 6, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 6 + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 6 + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * 6 + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 6 + 4, :, 0:240, :] = warped_weak_aff
        watch_images[batch_idx * 6 + 5, :, 0:240, :] = warped_weak_aff_tps

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points,
                                                im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points,
                                                im_size=target_im_size)

        warped_points_aff_norm = pt.affPointTnf(theta=theta_aff_inver,
                                                points=source_points_norm)
        warped_points_aff = PointsToPixelCoords(P=warped_points_aff_norm,
                                                im_size=target_im_size)
        pck_aff, index_aff, N_pts = pck(target_points, warped_points_aff,
                                        dataset_name)
        warped_points_aff = relocate(warped_points_aff, target_im_size)

        warped_points_aff_tps_norm = pt.tpsPointTnf(theta=theta_aff_tps_inver,
                                                    points=source_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(
            theta=theta_aff_inver, points=warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(
            P=warped_points_aff_tps_norm, im_size=target_im_size)
        pck_aff_tps, index_aff_tps, _ = pck(target_points,
                                            warped_points_aff_tps,
                                            dataset_name)
        warped_points_aff_tps = relocate(warped_points_aff_tps, target_im_size)

        warped_points_weak_aff_norm = pt.affPointTnf(
            theta=theta_weak_aff_inver, points=source_points_norm)
        warped_points_weak_aff = PointsToPixelCoords(
            P=warped_points_weak_aff_norm, im_size=target_im_size)
        pck_weak_aff, index_weak_aff, _ = pck(target_points,
                                              warped_points_weak_aff,
                                              dataset_name)
        warped_points_weak_aff = relocate(warped_points_weak_aff,
                                          target_im_size)

        warped_points_weak_aff_tps_norm = pt.tpsPointTnf(
            theta=theta_weak_aff_tps_inver, points=source_points_norm)
        warped_points_weak_aff_tps_norm = pt.affPointTnf(
            theta=theta_weak_aff_inver, points=warped_points_weak_aff_tps_norm)
        warped_points_weak_aff_tps = PointsToPixelCoords(
            P=warped_points_weak_aff_tps_norm, im_size=target_im_size)
        pck_weak_aff_tps, index_weak_aff_tps, _ = pck(
            target_points, warped_points_weak_aff_tps, dataset_name)
        warped_points_weak_aff_tps = relocate(warped_points_weak_aff_tps,
                                              target_im_size)

        watch_keypoints[batch_idx * 6, :, :N_pts] = relocate(
            batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        1, :, :N_pts] = warped_points_aff[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        2, :, :N_pts] = warped_points_aff_tps[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 3, :, :N_pts] = relocate(
            batch['target_points'], target_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        4, :, :N_pts] = warped_points_weak_aff[:, :, :N_pts]
        watch_keypoints[
            batch_idx * 6 +
            5, :, :N_pts] = warped_points_weak_aff_tps[:, :, :N_pts]

        num_points[batch_idx * 6:batch_idx * 6 + 6] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_aff)
        correct_index.append(index_aff_tps)
        correct_index.append(np.arange(N_pts))
        correct_index.append(index_weak_aff)
        correct_index.append(index_weak_aff_tps)

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('Aff_tps')
        image_names.append('Target')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_aff))
        metrics.append('PCK: {:.2%}'.format(pck_aff_tps))
        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff))
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff_tps))

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % 6 == 1 or (i + 1) % 6 == 4:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0),
                    1)
        if (i + 1) % 6 == 4:
            for j in range(num_points[i]):
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i],
                               (watch_keypoints[i + 3 - (i % 6), 0, j],
                                watch_keypoints[i + 3 - (i % 6), 1, j]),
                               colors[j], cv2.MARKER_DIAMOND, 12, 2,
                               cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3),
              opts=opts)
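pck() is external to this snippet. A minimal sketch consistent with its (score, correct_index, N_pts) return pattern, assuming (1, 2, N) point tensors padded with -1 and a fixed reference length; the real helper likely derives the threshold from dataset_name:

import numpy as np
import torch

def pck_sketch(target_points, warped_points, alpha=0.1, ref_length=240.0):
    # Percentage of correct keypoints: a warped point counts as correct
    # when it lands within alpha * ref_length of its target point.
    valid = target_points[0, 0, :] != -1
    n_pts = int(valid.sum())
    dist = torch.norm(target_points[0, :, :n_pts] - warped_points[0, :, :n_pts], dim=0)
    correct = (dist <= alpha * ref_length).cpu().numpy()
    index = np.flatnonzero(correct)
    score = float(correct.mean()) if n_pts > 0 else 0.0
    return score, index, n_pts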
Example #5
def vis_caltech(vis,
                dataloader,
                theta,
                theta_weak,
                results,
                results_weak,
                title,
                use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    watch_images = torch.ones(len(dataloader) * 6, 3, 280, 240)
    if use_cuda:
        watch_images = watch_images.cuda()
    image_names = list()
    lt_acc = list()
    iou = list()
    fnt = cv2.FONT_HERSHEY_COMPLEX
    # Pixel means for un-normalizing Caffe ResNet and VGG inputs
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)
        warped_weak_aff = affTnf(batch['source_image'], theta_weak_aff)
        warped_weak_aff_tps = tpsTnf(warped_weak_aff, theta_weak_aff_tps)

        watch_images[batch_idx * 6, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 6 + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 6 + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * 6 + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 6 + 4, :, 0:240, :] = warped_weak_aff
        watch_images[batch_idx * 6 + 5, :, 0:240, :] = warped_weak_aff_tps

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('Aff_tps')
        image_names.append('Target')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        lt_acc.append('')
        lt_acc.append('LT-ACC: {:.2f}'.format(
            float(results['aff']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('LT-ACC: {:.2f}'.format(
            float(results['aff_tps']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('')
        lt_acc.append('LT-ACC: {:.2f}'.format(
            float(results_weak['aff']['label_transfer_accuracy'][batch_idx])))
        lt_acc.append('LT-ACC: {:.2f}'.format(
            float(results_weak['aff_tps']['label_transfer_accuracy']
                  [batch_idx])))

        iou.append('')
        iou.append('IoU: {:.2f}'.format(
            float(results['aff']['intersection_over_union'][batch_idx])))
        iou.append('IoU: {:.2f}'.format(
            float(results['aff_tps']['intersection_over_union'][batch_idx])))
        iou.append('')
        iou.append('IoU: {:.2f}'.format(
            float(results_weak['aff']['intersection_over_union'][batch_idx])))
        iou.append('IoU: {:.2f}'.format(
            float(results_weak['aff_tps']['intersection_over_union']
                  [batch_idx])))

    opts = dict(jpgquality=100, title=title)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % 6 == 1 or (i + 1) % 6 == 4:
            pos_lt_ac = (0, 0)
            pos_iou = (0, 0)
        else:
            pos_lt_ac = (10, 275)
            pos_iou = (140, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], lt_acc[i], pos_lt_ac, fnt, 0.5, (0, 0, 0),
                    1)
        cv2.putText(watch_images[i], iou[i], pos_iou, fnt, 0.5, (0, 0, 0), 1)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=5),
              opts=opts)
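The LT-ACC and IoU numbers displayed above come from precomputed results dicts. One plausible definition of both metrics for binary masks, sketched under that assumption:

import torch

def label_transfer_accuracy(warped_mask, target_mask):
    # Fraction of pixels whose binary label agrees after warping the
    # source annotation onto the target.
    return warped_mask.bool().eq(target_mask.bool()).float().mean().item()

def intersection_over_union(warped_mask, target_mask):
    inter = (warped_mask.bool() & target_mask.bool()).sum().float()
    union = (warped_mask.bool() | target_mask.bool()).sum().float()
    return (inter / union).item() if union > 0 else 0.0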
Example #6
def vis_fn(vis, train_loss, val_pck, train_lr, epoch, num_epochs, dataloader, theta, thetai, results,
           geometric_model='tps', use_cuda=True):
    # Visualize watch images
    if geometric_model == 'tps':
        geoTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    elif geometric_model == 'affine':
        geoTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)

    pt = PointTPS(use_cuda=use_cuda)

    group_size = 3
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    theta, thetai = swap(theta, thetai)
    # Pixel means for un-normalizing Caffe ResNet and VGG inputs
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        batch['source_image'], batch['target_image'] = swap(batch['source_image'], batch['target_image'])
        batch['source_im_info'], batch['target_im_info'] = swap(batch['source_im_info'], batch['target_im_info'])
        batch['source_points'], batch['target_points'] = swap(batch['source_points'], batch['target_points'])

        # Theta and thetai
        if geometric_model == 'tps':
            theta_batch = theta['tps'][batch_idx].unsqueeze(0)
            theta_batch_inver = thetai['tps'][batch_idx].unsqueeze(0)
        elif geometric_model == 'affine':
            theta_batch = theta['aff'][batch_idx].unsqueeze(0)
            theta_batch_inver = thetai['aff'][batch_idx].unsqueeze(0)

        # Warped image
        warped_image = geoTnf(batch['source_image'], theta_batch)

        watch_images[batch_idx * group_size, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_image
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points, im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points, im_size=target_im_size)

        if geometric_model == 'tps':
            warped_points_norm = pt.tpsPointTnf(theta=theta_batch_inver, points=source_points_norm)
        elif geometric_model == 'affine':
            warped_points_norm = pt.affPointTnf(theta=theta_batch_inver, points=source_points_norm)

        warped_points = PointsToPixelCoords(P=warped_points_norm, im_size=target_im_size)
        _, index_correct, N_pts = pck(target_points, warped_points)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 1, :, :N_pts] = relocate(warped_points, target_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 2, :, :N_pts] = relocate(batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size + group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_correct)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        if geometric_model == 'tps':
            image_names.append('TPS')
        elif geometric_model == 'affine':
            image_names.append('Affine')
        image_names.append('Target')

        metrics.append('')
        if geometric_model == 'tps':
            metrics.append('PCK: {:.2%}'.format(float(results['tps']['pck'][batch_idx])))
        elif geometric_model == 'affine':
            metrics.append('PCK: {:.2%}'.format(float(results['aff']['pck'][batch_idx])))
        metrics.append('')

    opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0), 1)
        if (i + 1) % group_size == 0:
            for j in range(num_points[i]):
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j],
                               cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j],
                               cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i],
                               (watch_keypoints[i + (group_size - 1) - (i % group_size), 0, j], watch_keypoints[i + (group_size - 1) - (i % group_size), 1, j]),
                               colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3), opts=opts)

    if epoch == num_epochs:
        if geometric_model == 'affine':
            sub_str = 'Affine'
        elif geometric_model == 'tps':
            sub_str = 'TPS'
        epochs = np.arange(1, num_epochs+1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch',
                    ylabel='Loss',
                    title='GM ResNet101 ' + sub_str + ' Training Loss',
                    legend=['Loss'],
                    width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)

        # Visualize val pck
        opts_pck = dict(xlabel='Epoch',
                    ylabel='Val PCK',
                    title='GM ResNet101 ' + sub_str + ' Val PCK',
                    legend=['PCK'],
                    width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)

        # Visualize train lr
        opts_lr = dict(xlabel='Epoch',
                       ylabel='Learning Rate',
                       title='GM ResNet101 ' + sub_str + ' Training Learning Rate',
                       legend=['LR'],
                       width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)
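swap() and relocate() are external helpers. Minimal sketches consistent with their use above; relocate is assumed to rescale pixel keypoints from the original image size onto the fixed 240x240 display canvas, with im_size rows assumed to be (height, width, scale):

def swap(a, b):
    # As used above: simply exchange the two arguments.
    return b, a

def relocate(points, im_size):
    # points: (B, 2, N) pixel coordinates; im_size: (B, 3) rows of
    # (height, width, scale).
    points = points.clone()
    points[:, 0, :] = points[:, 0, :] * 240 / im_size[:, 1].unsqueeze(1)  # x scaled by width
    points[:, 1, :] = points[:, 1, :] * 240 / im_size[:, 0].unsqueeze(1)  # y scaled by height
    return points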
Example #7
def train_fn(epoch,
             model,
             loss_fn,
             loss_cycle_fn,
             loss_jitter_fn,
             lambda_c,
             lambda_j,
             optimizer,
             dataloader,
             triple_generation,
             geometric_model='tps',
             use_cuda=True,
             log_interval=100,
             vis=None):
    """
        Train the model with synthetically training triple:
        {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL.
        1. Train the transformation parameters theta_st from source image to target image;
        2. Train the transformation parameters theta_tr from target image to refer image;
        3. Combine theta_st and theta_st to obtain theta from source image to refer image, and compute loss between
        theta and theta_GT.
    """

    geoTnf = GeometricTnf(geometric_model=geometric_model, use_cuda=use_cuda)
    epoch_loss = 0
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        stride_images = len(dataloader) // 3
        group_size = 9
        watch_images = torch.ones(group_size * 4, 3, 340, 340).cuda()
        watch_theta = torch.zeros(8, 36).cuda()
        fnt = cv2.FONT_HERSHEY_COMPLEX
        stride_loss = len(dataloader) // 105
        iter_loss = np.zeros(106)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240)
        # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        if use_cuda:
            batch = batch_cuda(batch)
        ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT}'''
        batch_triple = triple_generation(batch)
        ''' Train the model '''
        optimizer.zero_grad()
        loss = 0
        # Predict tps parameters between images
        # theta.shape: (batch_size, 18) for tps, theta.shape: (batch_size, 6) for affine
        theta_st, theta_ts, theta_tr, theta_rt = model(
            batch_triple)  # warps in both directions: source<->target and target<->refer
        loss_match = loss_fn(theta_st=theta_st,
                             theta_tr=theta_tr,
                             theta_GT=batch_triple['theta_GT'])
        loss_cycle_st = loss_cycle_fn(theta_AB=theta_st, theta_BA=theta_ts)
        loss_cycle_ts = loss_cycle_fn(theta_AB=theta_ts, theta_BA=theta_st)
        loss_cycle_tr = loss_cycle_fn(theta_AB=theta_tr, theta_BA=theta_rt)
        loss_cycle_rt = loss_cycle_fn(theta_AB=theta_rt, theta_BA=theta_tr)
        loss_jitter = loss_jitter_fn(theta_st=theta_st, theta_tr=theta_tr)
        loss = loss_match + lambda_c * (
            loss_cycle_st + loss_cycle_ts + loss_cycle_tr +
            loss_cycle_rt) / 4 + lambda_j * loss_jitter
        # loss = loss_match + lambda_c * (loss_cycle_st + loss_cycle_ts + loss_cycle_tr + loss_cycle_rt) / 4
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print(
                'Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                .format(epoch, batch_idx + 1,
                        len(dataloader), (batch_idx + 1) / len(dataloader),
                        loss.item(), batch_idx + 1, end - begin))

        if (epoch % 5 == 0 or epoch == 1) and vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images, watch_theta = add_watch(
                    watch_images, watch_theta, batch_triple, geoTnf, theta_st,
                    theta_tr,
                    int((batch_idx + 1) / stride_images) * group_size,
                    int((batch_idx + 1) / stride_images))

            # if batch_idx <= 19:
            #     watch_images, image_names = add_watch(watch_images, watch_theta, batch_st, batch_tr, geoTnf, theta_st, theta_tr, batch_idx * group_size, batch_idx)
            #     if batch_idx == 19:
            #         opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_sr target warped_tr refer warped_sr')
            #         watch_images[:, :, 50:290, 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290], forward=False)
            #         watch_images *= 255.0
            #         watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            #         for i in range(watch_images.shape[0]):
            #             if i % group_size == 0:
            #                 cp_norm = watch_theta[int(i / group_size), :18].view(1, 2, -1)
            #                 watch_images[i] = draw_grid(watch_images[i], cp_norm)
            #
            #                 cp_norm = watch_theta[int(i / group_size), 18:].view(1, 2, -1)
            #                 watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)
            #
            #                 cp_norm = watch_theta[int(i / group_size) + 1, :18].view(1, 2, -1)
            #                 watch_images[i + 3] = draw_grid(watch_images[i + 3], cp_norm)
            #
            #                 cp_norm = watch_theta[int(i / group_size) + 1, 18:].view(1, 2, -1)
            #                 watch_images[i + 4] = draw_grid(watch_images[i + 4], cp_norm)
            #
            #         watch_images = torch.Tensor(watch_images.astype(np.float32))
            #         watch_images = watch_images.permute(0, 3, 1, 2)
            #         vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3), opts=opts)

            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) /
                              stride_loss)] = epoch_loss / (batch_idx + 1)

        # watch_images = normalize_image(batch_tr['target_image'], forward=False) * 255.0
        # vis.images(watch_images, nrow=8, padding=3)

    end = time.time()

    # Visualize watch images & train loss
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        opts = dict(jpgquality=100,
                    title='Epoch ' + str(epoch) +
                    ' source warped_sr target warped_tr refer warped_sr')
        watch_images[:, :, 50:290,
                     50:290] = normalize_image(watch_images[:, :, 50:290,
                                                            50:290],
                                               forward=False)
        watch_images *= 255.0
        watch_images = watch_images.permute(0, 2, 3,
                                            1).cpu().numpy().astype(np.uint8)
        for i in range(watch_images.shape[0]):
            if i % group_size == 0:
                cp_norm = watch_theta[int(i / group_size), :18].view(1, 2, -1)
                watch_images[i] = draw_grid(watch_images[i], cp_norm)

                cp_norm = watch_theta[int(i / group_size), 18:].view(1, 2, -1)
                watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)

                cp_norm = watch_theta[int(i / group_size) + 1, :18].view(
                    1, 2, -1)
                watch_images[i + 3] = draw_grid(watch_images[i + 3], cp_norm)

                cp_norm = watch_theta[int(i / group_size) + 1,
                                      18:].view(1, 2, -1)
                watch_images[i + 4] = draw_grid(watch_images[i + 4], cp_norm)

        watch_images = torch.Tensor(watch_images.astype(np.float32))
        watch_images = watch_images.permute(0, 3, 1, 2)
        vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3),
                  opts=opts)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')',
                         ylabel='Loss',
                         title='GM ResNet101 ' + geometric_model +
                         ' Training Loss in Epoch ' + str(epoch),
                         legend=['Loss'],
                         width=2000)
        vis.line(iter_loss, np.arange(106), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(
        epoch_loss, end - begin))
    return epoch_loss, end - begin
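loss_cycle_fn is passed in from outside. A hedged sketch of one way a cycle-consistency term could be defined for the affine case: compose the forward and backward warps and penalize their distance from the identity.

import torch
import torch.nn.functional as F

def affine_cycle_loss(theta_AB, theta_BA):
    # theta_*: (B, 6) affine parameters. A perfect forward/backward
    # cycle composes to the identity transform.
    def to_3x3(theta):
        bottom = torch.tensor([0.0, 0.0, 1.0], device=theta.device)
        bottom = bottom.view(1, 1, 3).expand(theta.size(0), 1, 3)
        return torch.cat([theta.view(-1, 2, 3), bottom], dim=1)
    composed = torch.bmm(to_3x3(theta_AB), to_3x3(theta_BA))
    identity = torch.eye(3, device=theta_AB.device).expand_as(composed)
    return F.mse_loss(composed, identity)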
Example #8
def train_fn(epoch,
             model,
             loss_fn,
             optimizer,
             dataloader,
             use_cuda=True,
             log_interval=100,
             vis=None):
    """
        Train cosegmentation model:
        {source image, target image} from PF-PASCAL.
        1. Train the co-object masks on source image and target image;
        2. Compute loss.
    """

    epoch_loss = 0
    # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
    if vis is not None:
        stride_images = len(dataloader) // 4
        group_size = 6
        watch_images = torch.ones(group_size * 5, 3, 260, 240).cuda()
        image_names = list()
        fnt = cv2.FONT_HERSHEY_COMPLEX
        # Pixel means for un-normalizing Caffe ResNet and VGG inputs
        # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda()
        stride_loss = len(dataloader) // 92
        iter_loss = np.zeros(93)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        if use_cuda:
            batch = batch_cuda(batch)
        ''' Train the model '''
        optimizer.zero_grad()
        mask_A, mask_B = model(batch)
        loss = loss_fn(mask_A, mask_B, batch['source_image'],
                       batch['target_image'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print(
                'Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                .format(epoch, batch_idx + 1,
                        len(dataloader), (batch_idx + 1) / len(dataloader),
                        loss.item(), batch_idx + 1, end - begin))

        # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        if vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images, image_names = add_watch(
                    watch_images, image_names, batch, mask_A, mask_B,
                    int((batch_idx + 1) / stride_images) * group_size)

            # if batch_idx <= 4:
            #     watch_images, image_names = add_watch(watch_images, image_names, batch, mask_A, mask_B, batch_idx * group_size)
            #     if batch_idx == 4:
            #         opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' image mask')
            #         watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
            #         watch_images *= 255.0
            #         watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            #         for i in range(len(image_names)):
            #             cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1)
            #         watch_images = torch.Tensor(watch_images.astype(np.float32))
            #         watch_images = watch_images.permute(0, 3, 1, 2)
            #         vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts)

            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) /
                              stride_loss)] = epoch_loss / (batch_idx + 1)

    end = time.time()

    # Visualize watch images & train loss
    # if (epoch % 5 == 0 or epoch == 1) and vis is not None:
    if vis is not None:
        opts = dict(jpgquality=100,
                    title='Epoch ' + str(epoch) + ' image mask')
        watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                    0:240, :],
                                                       forward=False)
        watch_images *= 255.0
        watch_images = watch_images.permute(0, 2, 3,
                                            1).cpu().numpy().astype(np.uint8)
        for i in range(watch_images.shape[0]):
            cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5,
                        (0, 0, 0), 1)
        watch_images = torch.Tensor(watch_images.astype(np.float32))
        watch_images = watch_images.permute(0, 3, 1, 2)
        vis.image(torchvision.utils.make_grid(watch_images,
                                              nrow=group_size,
                                              padding=3),
                  opts=opts)
        # Un-normalize for caffe resnet and vgg
        # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
        # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')',
                         ylabel='Loss',
                         title='CoSegmentation Training Loss in Epoch ' +
                         str(epoch),
                         legend=['Loss'],
                         width=2000)
        vis.line(iter_loss, np.arange(93), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(
        epoch_loss, end - begin))
    return epoch_loss, end - begin
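The iter_loss/stride_loss bookkeeping above samples the running-average loss at roughly 92 evenly spaced points per epoch. Isolated, the pattern looks like this (the batch count is an assumption, chosen as a multiple of 92 so the strides stay integral):

import numpy as np

num_batches = 460                     # assumed epoch length
stride_loss = num_batches // 92
iter_loss = np.zeros(93)
running = 0.0
for batch_idx in range(num_batches):
    running += 1.0 / (batch_idx + 1)  # stand-in for loss.item()
    if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
        iter_loss[int((batch_idx + 1) / stride_loss)] = running / (batch_idx + 1)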
Example #9
def vis_fn(vis,
           train_loss,
           val_iou,
           train_lr,
           epoch,
           num_epochs,
           dataloader,
           results,
           masks_A,
           masks_B,
           use_cuda=True):
    # Visualize watch images
    group_size = 6
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    if use_cuda:
        watch_images = watch_images.cuda()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        watch_images[batch_idx * group_size, :,
                     0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :,
                     0:240, :] = torch.mul(batch['source_image'],
                                           batch['source_mask'])
        # watch_images[batch_idx * group_size + 2, :, 0:240, :] = torch.mul(batch['source_image'], masks_A[batch_idx])
        mask_A = masks_A[batch_idx].gt(0.5).float()
        watch_images[batch_idx * group_size + 2, :,
                     0:240, :] = torch.mul(batch['source_image'], mask_A)
        watch_images[batch_idx * group_size + 3, :,
                     0:240, :] = batch['target_image']
        watch_images[batch_idx * group_size + 4, :,
                     0:240, :] = torch.mul(batch['target_image'],
                                           batch['target_mask'])
        # watch_images[batch_idx * group_size + 5, :, 0:240, :] = torch.mul(batch['target_image'], masks_B[batch_idx])
        mask_B = masks_B[batch_idx].gt(0.5).float()
        watch_images[batch_idx * group_size + 5, :,
                     0:240, :] = torch.mul(batch['target_image'], mask_B)

        image_names.append('Source')
        image_names.append('Mask_GT')
        image_names.append('Mask')
        image_names.append('Target')
        image_names.append('Mask_GT')
        image_names.append('Mask')

        metrics.append('')
        metrics.append('')
        metrics.append('IoU: {:.2%}'.format(float(results[batch_idx, 0])))
        metrics.append('')
        metrics.append('')
        metrics.append('IoU: {:.2%}'.format(float(results[batch_idx, 1])))

    opts = dict(jpgquality=100,
                title='Epoch ' + str(epoch) + ' image mask_gt mask')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 3 or (i + 1) % group_size == 0:
            pos_iou = (70, 275)
        else:
            pos_iou = (0, 0)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_iou, fnt, 0.5, (0, 0, 0),
                    1)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3),
              opts=opts)

    if epoch == num_epochs:
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch',
                         ylabel='Loss',
                         title='CoSegmentation Training Loss',
                         legend=['Loss'],
                         width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)

        # Visualize val IoU
        opts_iou = dict(xlabel='Epoch',
                        ylabel='Val IoU',
                        title='CoSegmentation Val IoU',
                        legend=['IoU'],
                        width=2000)
        vis.line(val_iou, epochs, opts=opts_iou)

        # Visualize train lr
        opts_lr = dict(xlabel='Epoch',
                       ylabel='Learning Rate',
                       title='CoSegmentation Training Learning Rate',
                       legend=['LR'],
                       width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)
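The (N, 2) results array consumed above is computed elsewhere. A hedged sketch of how it could be produced, assuming stacked predicted masks and ground-truth masks per image pair:

import torch

def eval_masks(masks_A, masks_B, gts_A, gts_B):
    # Column 0: IoU of the source mask; column 1: IoU of the target
    # mask. Predictions are thresholded at 0.5, as in vis_fn above.
    results = torch.zeros(masks_A.size(0), 2)
    for i in range(masks_A.size(0)):
        for col, (pred, gt) in enumerate([(masks_A[i], gts_A[i]),
                                          (masks_B[i], gts_B[i])]):
            pred = pred.gt(0.5)
            inter = (pred & gt.bool()).sum().float()
            union = (pred | gt.bool()).sum().float()
            results[i, col] = inter / union if union > 0 else 0.0
    return results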
Example #10
def vis_fn(vis,
           train_loss,
           val_pck,
           train_lr,
           epoch,
           num_epochs,
           dataloader,
           theta,
           thetai,
           results,
           geometric_model='tps',
           use_cuda=True):
    geoTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)

    group_size = 3
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # Pixel means for un-normalizing Caffe ResNet and VGG inputs
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and thetai
        theta_batch = theta['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_image = geoTnf(batch['source_image'], theta_batch)

        watch_images[batch_idx * group_size, :, 50:290,
                     50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290,
                     50:290] = warped_image
        watch_images[batch_idx * group_size + 2, :, 50:290,
                     50:290] = batch['target_image']

    opts = dict(jpgquality=100,
                title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 50:290,
                 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290],
                                           forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        if i % group_size == 0:
            cp_norm = theta['tps'][int(i / group_size)][:18].view(1, 2, -1)
            watch_images[i] = draw_grid(watch_images[i], cp_norm)

            cp_norm = theta['tps'][int(i / group_size)][18:].view(1, 2, -1)
            watch_images[i + 1] = draw_grid(watch_images[i + 1], cp_norm)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3),
              opts=opts)

    if epoch == num_epochs:
        if geometric_model == 'affine':
            sub_str = 'Affine'
        elif geometric_model == 'tps':
            sub_str = 'TPS'
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch',
                         ylabel='Loss',
                         title='GM ResNet101 ' + sub_str + ' Training Loss',
                         legend=['Loss'],
                         width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)

        # Visualize val pck
        opts_pck = dict(xlabel='Epoch',
                        ylabel='Val PCK',
                        title='GM ResNet101 ' + sub_str + ' Val PCK',
                        legend=['PCK'],
                        width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)

        # Visualize train lr
        opts_lr = dict(xlabel='Epoch',
                       ylabel='Learning Rate',
                       title='GM ResNet101 ' + sub_str +
                       ' Training Learning Rate',
                       legend=['LR'],
                       width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)
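draw_grid() is external. A hypothetical sketch consistent with its use above, which at minimum must place the 18 TPS control-point parameters (viewed as a (1, 2, 9) tensor in [-1, 1] coordinates) onto the padded 340x340 canvas; the real helper presumably also connects the points into a grid:

import cv2

def draw_grid_sketch(image, cp_norm, pad=50, size=240):
    # image: HWC uint8 canvas; cp_norm: (1, 2, 9) normalized control
    # points. Map [-1, 1] into pixel coordinates inside the padded area.
    cps = (cp_norm[0].cpu().numpy() + 1) / 2 * (size - 1) + pad
    for x, y in zip(cps[0], cps[1]):
        cv2.drawMarker(image, (int(x), int(y)), (0, 0, 255),
                       cv2.MARKER_TILTED_CROSS, 8, 2, cv2.LINE_AA)
    return image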
Example #11
def train_fn_detect(epoch,
                    model,
                    faster_rcnn,
                    aff_theta,
                    loss_fn,
                    optimizer,
                    dataloader,
                    triple_generation,
                    use_cuda=True,
                    log_interval=100,
                    vis=None,
                    show=False):
    """
        Train the model with synthetically training triple:
        {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL.
        1. Train the transformation parameters theta_st from source image to target image;
        2. Train the transformation parameters theta_tr from target image to refer image;
        3. Combine theta_st and theta_st to obtain theta from source image to refer image, and compute loss between
        theta and theta_GT.
    """

    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    epoch_loss = 0
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        stride_images = len(dataloader) // 3
        watch_images = torch.Tensor(24, 3, 240, 240)
        # Pixel means for un-normalizing Caffe ResNet and VGG inputs
        # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda()
        stride_loss = len(dataloader) // 35
        iter_loss = np.zeros(36)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240)
        # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        if use_cuda:
            batch = batch_cuda(batch)
        ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT}'''
        batch_triple = triple_generation(batch)
        ''' Train the model '''
        optimizer.zero_grad()
        # Predict tps parameters between images
        box_info_s = faster_rcnn(
            im_data=batch_triple['source_im'],
            im_info=batch_triple['source_im_info'][:, 3:],
            gt_boxes=batch_triple['source_gt_boxes'],
            num_boxes=batch_triple['source_num_boxes'])[0:3]
        box_info_t = faster_rcnn(
            im_data=batch_triple['target_im'],
            im_info=batch_triple['target_im_info'][:, 3:],
            gt_boxes=batch_triple['target_gt_boxes'],
            num_boxes=batch_triple['target_num_boxes'])[0:3]
        box_info_r = faster_rcnn(
            im_data=batch_triple['refer_im'],
            im_info=batch_triple['refer_im_info'][:, 3:],
            gt_boxes=batch_triple['refer_gt_boxes'],
            num_boxes=batch_triple['refer_num_boxes'])[0:3]

        all_box_s = select_boxes(rois=box_info_s[0],
                                 cls_prob=box_info_s[1],
                                 bbox_pred=box_info_s[2],
                                 im_infos=batch_triple['source_im_info'][:,
                                                                         3:])
        all_box_t = select_boxes(rois=box_info_t[0],
                                 cls_prob=box_info_t[1],
                                 bbox_pred=box_info_t[2],
                                 im_infos=batch_triple['target_im_info'][:,
                                                                         3:])
        all_box_r = select_boxes(rois=box_info_r[0],
                                 cls_prob=box_info_r[1],
                                 bbox_pred=box_info_r[2],
                                 im_infos=batch_triple['refer_im_info'][:,
                                                                        3:])

        box_s, box_t, box_r = select_box(all_box_s, all_box_t, all_box_r)
        theta_st = aff_theta(boxes_s=box_s, boxes_t=box_t)
        theta_tr = aff_theta(boxes_s=box_t, boxes_t=box_r)

        # theta.shape: (batch_size, 18) for tps
        batch_st = {
            'source_image': batch_triple['source_image'],
            'target_image': batch_triple['target_image']
        }
        batch_tr = {
            'source_image': batch_triple['target_image'],
            'target_image': batch_triple['refer_image']
        }
        theta_aff_tps_st, theta_aff_st = model(
            batch_st, theta_st)  # from source image to target image
        theta_aff_tps_tr, theta_aff_tr = model(
            batch_tr, theta_tr)  # from target image to refer image

        # show_images(batch_st=batch_st, batch_tr=batch_tr, box_s=box_s, box_t=box_t, box_r=box_r)
        loss = loss_fn(theta_st=theta_aff_tps_st,
                       theta_tr=theta_aff_tps_tr,
                       theta_GT=batch_triple['theta_GT'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            end = time.time()
            print(
                'Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                .format(epoch, batch_idx + 1,
                        len(dataloader), (batch_idx + 1) / len(dataloader),
                        loss.item(), batch_idx + 1, end - begin))

        if (epoch % 5 == 0 or epoch == 1) and vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images = add_watch(
                    watch_images, batch_st, batch_tr, tpsTnf, theta_aff_tps_st,
                    theta_aff_tps_tr,
                    int((batch_idx + 1) / stride_images) * 6)
            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) /
                              stride_loss)] = epoch_loss / (batch_idx + 1)

        # tmp_images = batch_triple['target_im'].permute(0, 2, 3, 1) + pixel_means
        # tmp_images = tmp_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
        # vis.image(torchvision.utils.make_grid(tmp_images, nrow=1, padding=3))

        if show:
            # if dual:
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st_1)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr_1)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
            #
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
            #
            #     warped_image_aff_tps = tpsTnf(batch_st['source_image'], theta_aff_tps_st)
            #     warped_image_aff_tps_2 = tpsTnf(batch_tr['source_image'], theta_aff_tps_tr)
            #     warped_image_aff_tps_3 = tpsTnf(warped_image_aff_tps, theta_aff_tps_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff_tps.detach(), warped_image_aff_tps_2.detach(), warped_image_aff_tps_3.detach())
            #
            #     warped_image = affTnf(batch_st['source_image'], theta_aff_st_1)
            #     warped_image = affTnf(warped_image, theta_aff_st)
            #     warped_image = tpsTnf(warped_image, theta_aff_tps_st)
            #     warped_image_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
            #     warped_image_2 = affTnf(warped_image_2, theta_aff_tr)
            #     warped_image_2 = tpsTnf(warped_image_2, theta_aff_tps_tr)
            #     warped_image_3 = affTnf(warped_image, theta_aff_tr_1)
            #     warped_image_3 = affTnf(warped_image_3, theta_aff_tr)
            #     warped_image_3 = tpsTnf(warped_image_3, theta_aff_tps_tr)
            #     show_images(batch_st, batch_tr, warped_image.detach(), warped_image_2.detach(), warped_image_3.detach())
            # else:
            #     warped_image_aff = affTnf(batch_st['source_image'], theta_st)
            #     warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_tr)
            #     warped_image_aff_3 = affTnf(warped_image_aff, theta_tr)
            #     show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())

            warped_image_tps = tpsTnf(batch_st['source_image'],
                                      theta_aff_tps_st)
            warped_image_tps_2 = tpsTnf(batch_tr['source_image'],
                                        theta_aff_tps_tr)
            warped_image_tps_3 = tpsTnf(warped_image_tps, theta_aff_tps_tr)
            show_images(batch_triple, warped_image_tps.detach(),
                        warped_image_tps_2.detach(),
                        warped_image_tps_3.detach(), box_s, box_t, box_r)

    end = time.time()

    # Visualize watch images & train loss
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        opts = dict(jpgquality=100,
                    title='Epoch ' + str(epoch) +
                    ' source warped_sr target warped_tr refer warped_sr')
        # Un-normalize for caffe resnet and vgg
        # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
        # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
        watch_images = normalize_image(watch_images, forward=False) * 255.0
        vis.image(torchvision.utils.make_grid(watch_images, nrow=6, padding=3),
                  opts=opts)
        # vis.images(watch_images, nrow=6, padding=3, opts=opts)

        opts_loss = dict(
            xlabel='Iterations (' + str(stride_loss) + ')',
            ylabel='Loss',
            title='GM ResNet101 Detect&Affine&TPS Training Loss in Epoch ' +
            str(epoch),
            legend=['Loss'],
            width=2000)
        vis.line(iter_loss, np.arange(36), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(
        epoch_loss, end - begin))
    return epoch_loss, end - begin
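aff_theta converts a pair of detected boxes into affine parameters. A hedged sketch of one such mapping, assuming (B, 4) boxes as (x1, y1, x2, y2) in pixels and the inverse-warp convention of F.affine_grid (output coordinates are mapped back into the source image):

import torch

def aff_theta_sketch(boxes_s, boxes_t, im_size=240.0):
    def norm(b):  # pixels -> [-1, 1]
        return b / (im_size - 1) * 2 - 1
    s, t = norm(boxes_s), norm(boxes_t)
    theta = torch.zeros(boxes_s.size(0), 6, device=boxes_s.device)
    theta[:, 0] = (s[:, 2] - s[:, 0]) / (t[:, 2] - t[:, 0])  # x scale
    theta[:, 4] = (s[:, 3] - s[:, 1]) / (t[:, 3] - t[:, 1])  # y scale
    theta[:, 2] = (s[:, 0] + s[:, 2]) / 2 - theta[:, 0] * (t[:, 0] + t[:, 2]) / 2
    theta[:, 5] = (s[:, 1] + s[:, 3]) / 2 - theta[:, 4] * (t[:, 1] + t[:, 3]) / 2
    return theta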
Example #12
def vis_control2(vis,
                 dataloader,
                 theta_1,
                 theta_2,
                 dataset_name,
                 use_cuda=True):
    # Visualize watch images
    tpsTnf = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)

    group_size = 5
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :, 50:290,
                     50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290,
                     50:290] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 50:290,
                     50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 3, :, 50:290,
                     50:290] = warped_tps_2
        watch_images[batch_idx * group_size + 4, :, 50:290,
                     50:290] = batch['target_image']

    opts = dict(jpgquality=100, title=dataset_name)
    watch_images[:, :, 50:290,
                 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290],
                                           forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)

    for i in range(watch_images.shape[0]):
        if i % group_size < 4:
            idx = i // group_size
            if i % group_size == 0:
                cp_norm = theta_1['tps'][idx][:18].view(1, 2, -1)
            elif i % group_size == 1:
                cp_norm = theta_1['tps'][idx][18:].view(1, 2, -1)
            elif i % group_size == 2:
                cp_norm = theta_2['tps'][idx][:18].view(1, 2, -1)
            else:
                cp_norm = theta_2['tps'][idx][18:].view(1, 2, -1)
            watch_images[i] = draw_grid(watch_images[i], cp_norm)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5),
              opts=opts)
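
A minimal usage sketch for vis_control2; the environment name, the cache files, and the (N, 36) shape of the stacked TPS parameters (two 3x3 grids of (x, y) control points per pair) are assumptions, not part of the original code:

import visdom

vis = visdom.Visdom(env='control-points')           # assumes a running visdom server
theta_1 = {'tps': torch.load('theta_jitter_a.pt')}  # hypothetical cache, shape (N, 36)
theta_2 = {'tps': torch.load('theta_jitter_b.pt')}
vis_control2(vis, watch_dataloader, theta_1, theta_2, dataset_name='PF-PASCAL')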
Example #13
def vis_pf(vis,
           dataloader,
           theta_1,
           theta_2,
           theta_inver_1,
           theta_inver_2,
           results_1,
           results_2,
           dataset_name,
           use_cuda=True):
    # Visualize watch images
    tpsTnf_1 = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    tpsTnf_2 = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)
    pt_1 = PointTnf(use_cuda=use_cuda)
    pt_2 = PointTPS(use_cuda=use_cuda)

    group_size = 4
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        thetai_tps_1 = theta_inver_1['tps'][batch_idx].unsqueeze(0)
        thetai_tps_2 = theta_inver_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf_1(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf_2(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :,
                     0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = warped_tps_2
        watch_images[batch_idx * group_size + 3, :,
                     0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points,
                                                im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points,
                                                im_size=target_im_size)

        warped_points_tps_norm_1 = pt_1.tpsPointTnf(theta=thetai_tps_1,
                                                    points=source_points_norm)
        warped_points_tps_1 = PointsToPixelCoords(P=warped_points_tps_norm_1,
                                                  im_size=target_im_size)
        pck_tps_1, index_tps_1, N_pts = pck(target_points, warped_points_tps_1,
                                            dataset_name)
        warped_points_tps_1 = relocate(warped_points_tps_1, target_im_size)

        warped_points_tps_norm_2 = pt_2.tpsPointTnf(theta=thetai_tps_2,
                                                    points=source_points_norm)
        warped_points_tps_2 = PointsToPixelCoords(P=warped_points_tps_norm_2,
                                                  im_size=target_im_size)
        pck_tps_2, index_tps_2, _ = pck(target_points, warped_points_tps_2,
                                        dataset_name)
        warped_points_tps_2 = relocate(warped_points_tps_2, target_im_size)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(
            batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size +
                        1, :, :N_pts] = warped_points_tps_1[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size +
                        2, :, :N_pts] = warped_points_tps_2[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 3, :, :N_pts] = relocate(
            batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size +
                   group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_tps_1)
        correct_index.append(index_tps_2)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        image_names.append('TPS')
        image_names.append('TPS_Jitter')
        image_names.append('Target')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_tps_1))
        metrics.append('PCK: {:.2%}'.format(pck_tps_2))
        metrics.append('')

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0),
                    1)
        if (i + 1) % group_size == 0:
            # Target image: draw all ground-truth keypoints as diamonds
            for j in range(num_points[i]):
                cv2.drawMarker(
                    watch_images[i],
                    (int(watch_keypoints[i, 0, j]), int(watch_keypoints[i, 1, j])),
                    colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            # Source/warped images: draw the correctly transferred keypoints as
            # crosses, plus the corresponding target keypoints as diamonds
            for j in correct_index[i]:
                cv2.drawMarker(
                    watch_images[i],
                    (int(watch_keypoints[i, 0, j]), int(watch_keypoints[i, 1, j])),
                    colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                target_idx = i + (group_size - 1) - (i % group_size)
                cv2.drawMarker(watch_images[i],
                               (int(watch_keypoints[target_idx, 0, j]),
                                int(watch_keypoints[target_idx, 1, j])),
                               colors[j], cv2.MARKER_DIAMOND, 12, 2,
                               cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=5),
              opts=opts)
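
For reference, pck() above returns the per-pair score, the indices of correctly transferred keypoints, and the number of annotated points. A minimal sketch of that metric, assuming a threshold of alpha times the reference image side and -1 padding for unused keypoint slots (the project's own pck() may derive its threshold per dataset):

def pck_sketch(target_points, warped_points, alpha=0.1, ref_size=240):
    # target_points, warped_points: (1, 2, 20) tensors; unused slots hold -1
    valid = target_points[0, 0, :] != -1
    dist = torch.norm(warped_points[0, :, valid] - target_points[0, :, valid],
                      dim=0)
    correct = torch.nonzero(dist <= alpha * ref_size).squeeze(1).cpu().numpy()
    n_pts = int(valid.sum().item())
    return len(correct) / n_pts, correct, n_pts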
Example #14
def vis_control(vis,
                dataloader,
                theta_1,
                theta_2,
                dataset_name,
                use_cuda=True):
    # Visualize watch images
    tpsTnf_1 = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    tpsTnf_2 = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)

    group_size = 5
    watch_images = torch.ones(len(dataloader) * group_size, 3, 340, 340)
    if use_cuda:
        watch_images = watch_images.cuda()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf_1(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf_2(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :, 50:290,
                     50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 50:290,
                     50:290] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 50:290,
                     50:290] = batch['source_image']
        watch_images[batch_idx * group_size + 3, :, 50:290,
                     50:290] = warped_tps_2
        watch_images[batch_idx * group_size + 4, :, 50:290,
                     50:290] = batch['target_image']

    opts = dict(jpgquality=100, title=dataset_name)
    watch_images[:, :, 50:290,
                 50:290] = normalize_image(watch_images[:, :, 50:290, 50:290],
                                           forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)

    im_size = torch.Tensor([[240, 240]]).cuda()

    def draw_cp_grid(image, cp_norm):
        # Draw the 3x3 TPS control points and connect them into a grid
        cp = PointsToPixelCoords(P=cp_norm, im_size=im_size)
        cp = (cp.squeeze().cpu().numpy() + 50).astype(np.int32)
        for j in range(9):
            cv2.drawMarker(image, (cp[0, j], cp[1, j]), (0, 0, 255),
                           cv2.MARKER_TILTED_CROSS, 12, 2, cv2.LINE_AA)
        for j in range(2):
            for k in range(3):
                # vertical grid
                cv2.line(image, (cp[0, j + k * 3], cp[1, j + k * 3]),
                         (cp[0, j + k * 3 + 1], cp[1, j + k * 3 + 1]),
                         (0, 0, 255), 2, cv2.LINE_AA)
                # horizontal grid
                cv2.line(image, (cp[0, j * 3 + k], cp[1, j * 3 + k]),
                         (cp[0, j * 3 + k + 3], cp[1, j * 3 + k + 3]),
                         (0, 0, 255), 2, cv2.LINE_AA)

    # Regular 3x3 grid of control points in normalized [-1, 1] coordinates
    regular_cp_norm = torch.Tensor(
        [-1, -1, -1, 0, 0, 0, 1, 1, 1, -1, 0, 1, -1, 0, 1, -1, 0,
         1]).cuda().view(1, 2, -1)

    for i in range(watch_images.shape[0]):
        idx = i // group_size
        if i % group_size == 0:
            draw_cp_grid(watch_images[i], theta_1['tps'][idx].view(1, 2, -1))
        elif i % group_size == 1:
            draw_cp_grid(watch_images[i], regular_cp_norm)
        elif i % group_size == 2:
            draw_cp_grid(watch_images[i], theta_2['tps'][idx][:18].view(1, 2, -1))
        elif i % group_size == 3:
            draw_cp_grid(watch_images[i], theta_2['tps'][idx][18:].view(1, 2, -1))

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5),
              opts=opts)
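
PointsToPixelCoords above maps control points from normalized [-1, 1] coordinates into pixel coordinates. A minimal sketch of that mapping (a hypothetical stand-in; the project's own helper may order im_size or scale slightly differently):

def points_to_pixel_coords_sketch(P, im_size):
    # P: (B, 2, N) points in [-1, 1]; im_size: (B, 2) as (H, W)
    h = im_size[:, 0].unsqueeze(1)
    w = im_size[:, 1].unsqueeze(1)
    P_pix = P.clone()
    P_pix[:, 0, :] = (P[:, 0, :] + 1) * (w - 1) / 2.0  # x
    P_pix[:, 1, :] = (P[:, 1, :] + 1) * (h - 1) / 2.0  # y
    return P_pix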
Example #15
def train_fn_dual(epoch, model, loss_fn, optimizer, dataloader, triple_generation, use_cuda=True, log_interval=100, vis=None, show=False):
    """
        Train the model with a synthetically generated training triple:
        {source image, target image, refer image (warped source image), theta_GT} from PF-PASCAL.
        1. Train the transformation parameters theta_st from the source image to the target image;
        2. Train the transformation parameters theta_tr from the target image to the refer image;
        3. Compose theta_st and theta_tr to obtain theta from the source image to the refer image, and compute the loss
        between theta and theta_GT.
    """

    geoTnf = ComposedGeometricTnf(use_cuda=use_cuda)
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    epoch_loss = 0
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        stride_images = len(dataloader) // 3
        group_size = 6
        watch_images = torch.ones(group_size * 4, 3, 260, 240).cuda()
        # watch_images = torch.ones(group_size * 20, 3, 260, 240).cuda()
        image_names = list()
        fnt = cv2.FONT_HERSHEY_COMPLEX
        # means for normalize of caffe resnet and vgg
        # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32)).cuda()
        stride_loss = len(dataloader) // 105
        iter_loss = np.zeros(106)
    begin = time.time()
    for batch_idx, batch in enumerate(dataloader):
        ''' Move input batch to gpu '''
        # batch['source_image'].shape & batch['target_image'].shape: (batch_size, 3, 240, 240)
        # batch['theta'].shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        if use_cuda:
            batch = batch_cuda(batch)

        ''' Get the training triple {source image, target image, refer image (warped source image), theta_GT}'''
        batch_triple = triple_generation(batch)

        ''' Train the model '''
        optimizer.zero_grad()
        # Predict tps parameters between images
        # theta.shape: (batch_size, 18) for tps
        batch_st = {'source_image': batch_triple['source_image'], 'target_image': batch_triple['target_image']}
        batch_tr = {'source_image': batch_triple['target_image'], 'target_image': batch_triple['refer_image']}
        theta_aff_tps_st, theta_aff_st = model(batch_st)  # from source image to target image
        theta_aff_tps_tr, theta_aff_tr = model(batch_tr)  # from target image to refer image
        # show_images(batch_st=batch_st, batch_tr=batch_tr, box_s=box_s, box_t=box_t, box_r=box_r)
        # loss = loss_fn(theta_st=theta_aff_tps_st, theta_tr=theta_aff_tps_tr, theta_GT=batch_triple['theta_GT'])
        loss = loss_fn(theta_aff_tps_st=theta_aff_tps_st, theta_aff_st=theta_aff_st, theta_aff_tps_tr=theta_aff_tps_tr,
                       theta_aff_tr=theta_aff_tr, theta_aff_GT=batch_triple['theta_aff_GT'], theta_tps_GT=batch_triple['theta_tps_GT'])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if (batch_idx+1) % log_interval == 0:
            end = time.time()
            print('Train epoch: {} [{}/{} ({:.0%})]\t\tCurrent batch loss: {:.6f}\t\tTime cost ({} batches): {:.4f} s'
                  .format(epoch, batch_idx+1, len(dataloader), (batch_idx+1) / len(dataloader), loss.item(), batch_idx + 1, end - begin))

        if (epoch % 5 == 0 or epoch == 1) and vis is not None:
            if (batch_idx + 1) % stride_images == 0 or batch_idx == 0:
                watch_images, image_names = add_watch(watch_images, image_names, batch_st, batch_tr, affTnf, tpsTnf, geoTnf, theta_aff_tps_st, theta_aff_st, theta_aff_tps_tr, theta_aff_tr, int((batch_idx + 1) / stride_images) * group_size)

            # if batch_idx <= 19:
            #     watch_images, image_names = add_watch(watch_images, image_names, batch_st, batch_tr, affTnf, tpsTnf, geoTnf, theta_aff_tps_st, theta_aff_st, theta_aff_tps_tr, theta_aff_tr, batch_idx * group_size, batch_triple)
            #     if batch_idx == 19:
            #         opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_st target warped_tr refer warped_sr')
            #         watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
            #         watch_images *= 255.0
            #         watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
            #         for i in range(len(image_names)):
            #             cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1)
            #         watch_images = torch.Tensor(watch_images.astype(np.float32))
            #         watch_images = watch_images.permute(0, 3, 1, 2)
            #         vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts)

            if (batch_idx + 1) % stride_loss == 0 or batch_idx == 0:
                iter_loss[int((batch_idx + 1) / stride_loss)] = epoch_loss / (batch_idx + 1)

        # tmp_images = normalize_image(batch_tr['target_image'], forward=False) * 255.0
        # vis.images(tmp_images, nrow=8, padding=3)

        # if show:
        #     if dual:
        #         warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st_1)
        #         warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
        #         warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr_1)
        #         show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
        #
        #         warped_image_aff = affTnf(batch_st['source_image'], theta_aff_st)
        #         warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_aff_tr)
        #         warped_image_aff_3 = affTnf(warped_image_aff, theta_aff_tr)
        #         show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())
        #
        #         warped_image_aff_tps = tpsTnf(batch_st['source_image'], theta_aff_tps_st)
        #         warped_image_aff_tps_2 = tpsTnf(batch_tr['source_image'], theta_aff_tps_tr)
        #         warped_image_aff_tps_3 = tpsTnf(warped_image_aff_tps, theta_aff_tps_tr)
        #         show_images(batch_st, batch_tr, warped_image_aff_tps.detach(), warped_image_aff_tps_2.detach(), warped_image_aff_tps_3.detach())
        #
        #         warped_image = affTnf(batch_st['source_image'], theta_aff_st_1)
        #         warped_image = affTnf(warped_image, theta_aff_st)
        #         warped_image = tpsTnf(warped_image, theta_aff_tps_st)
        #         warped_image_2 = affTnf(batch_tr['source_image'], theta_aff_tr_1)
        #         warped_image_2 = affTnf(warped_image_2, theta_aff_tr)
        #         warped_image_2 = tpsTnf(warped_image_2, theta_aff_tps_tr)
        #         warped_image_3 = affTnf(warped_image, theta_aff_tr_1)
        #         warped_image_3 = affTnf(warped_image_3, theta_aff_tr)
        #         warped_image_3 = tpsTnf(warped_image_3, theta_aff_tps_tr)
        #         show_images(batch_st, batch_tr, warped_image.detach(), warped_image_2.detach(), warped_image_3.detach())
        #     else:
        #         warped_image_aff = affTnf(batch_st['source_image'], theta_st)
        #         warped_image_aff_2 = affTnf(batch_tr['source_image'], theta_tr)
        #         warped_image_aff_3 = affTnf(warped_image_aff, theta_tr)
        #         show_images(batch_st, batch_tr, warped_image_aff.detach(), warped_image_aff_2.detach(), warped_image_aff_3.detach())

                # warped_image_tps = tpsTnf(batch_st['source_image'], theta_st)
                # warped_image_tps_2 = tpsTnf(batch_tr['source_image'], theta_tr)
                # warped_image_tps_3 = tpsTnf(warped_image_tps, theta_tr)
                # show_images(batch_st, batch_tr, warped_image_tps.detach(), warped_image_tps_2.detach(), warped_image_tps_3.detach())

    end = time.time()

    # Visualize watch images & train loss
    if (epoch % 5 == 0 or epoch == 1) and vis is not None:
        opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped_st target warped_tr refer warped_sr')
        watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
        watch_images *= 255.0
        watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        for i in range(watch_images.shape[0]):
            cv2.putText(watch_images[i], image_names[i], (80, 255), fnt, 0.5, (0, 0, 0), 1)
        watch_images = torch.Tensor(watch_images.astype(np.float32))
        watch_images = watch_images.permute(0, 3, 1, 2)
        vis.image(torchvision.utils.make_grid(watch_images, nrow=group_size, padding=3), opts=opts)
        # Un-normalize for caffe resnet and vgg
        # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
        # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)

        opts_loss = dict(xlabel='Iterations (' + str(stride_loss) + ')',
                         ylabel='Loss',
                         title='GM ResNet101 AffTPS Training Loss in Epoch ' + str(epoch),
                         legend=['Loss'],
                         width=2000)
        vis.line(iter_loss, np.arange(106), opts=opts_loss)

    epoch_loss /= len(dataloader)
    print('Train set -- Average loss: {:.6f}\t\tTime cost: {:.4f}'.format(epoch_loss, end - begin))
    return epoch_loss, end - begin
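
A minimal training-loop sketch around train_fn_dual; GeometricModel, DualLoss, and TripleGeneration are placeholder names for whatever the surrounding project actually defines:

model = GeometricModel()                    # hypothetical dual affine+TPS model
loss_fn = DualLoss()                        # hypothetical loss over composed thetas
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
triple_generation = TripleGeneration(use_cuda=True)  # hypothetical, see Example #17

best_loss = float('inf')
for epoch in range(1, num_epochs + 1):
    loss, secs = train_fn_dual(epoch, model, loss_fn, optimizer, dataloader,
                               triple_generation, vis=vis, log_interval=100)
    best_loss = min(best_loss, loss)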
Example #16
def vis_fn_detect(vis,
                  model,
                  faster_rcnn,
                  aff_theta,
                  train_loss,
                  val_pck,
                  train_lr,
                  epoch,
                  num_epochs,
                  dataloader,
                  use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    watch_images = torch.Tensor(len(dataloader) * 5, 3, 240, 240)
    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        box_info_s = faster_rcnn(im_data=batch['source_im'],
                                 im_info=batch['source_im_info'][:, 3:],
                                 gt_boxes=batch['source_gt_boxes'],
                                 num_boxes=batch['source_num_boxes'])[0:3]
        box_info_t = faster_rcnn(im_data=batch['target_im'],
                                 im_info=batch['target_im_info'][:, 3:],
                                 gt_boxes=batch['target_gt_boxes'],
                                 num_boxes=batch['target_num_boxes'])[0:3]
        all_box_s = select_boxes(rois=box_info_s[0],
                                 cls_prob=box_info_s[1],
                                 bbox_pred=box_info_s[2],
                                 im_infos=batch['source_im_info'][:, 3:])
        all_box_t = select_boxes(rois=box_info_t[0],
                                 cls_prob=box_info_t[1],
                                 bbox_pred=box_info_t[2],
                                 im_infos=batch['target_im_info'][:, 3:])
        box_s, box_t = select_box_st(all_box_s, all_box_t)
        theta_det = aff_theta(boxes_s=box_s, boxes_t=box_t)
        theta_aff_tps, theta_aff = model(batch, theta_det)

        warped_image_1 = affTnf(batch['source_image'], theta_det)
        warped_image_2 = affTnf(warped_image_1, theta_aff)
        warped_image_3 = tpsTnf(warped_image_2, theta_aff_tps)
        watch_images[batch_idx * 5] = batch['source_image'][0]
        watch_images[batch_idx * 5 + 1] = warped_image_1[0]
        watch_images[batch_idx * 5 + 2] = warped_image_2[0]
        watch_images[batch_idx * 5 + 3] = warped_image_3[0]
        watch_images[batch_idx * 5 + 4] = batch['target_image'][0]

    opts = dict(jpgquality=100,
                title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    watch_images = normalize_image(watch_images, forward=False) * 255.0
    vis.image(torchvision.utils.make_grid(watch_images, nrow=5, padding=5),
              opts=opts)
    # vis.images(watch_images, nrow=5, padding=3, opts=opts)

    if epoch == num_epochs:
        epochs = np.arange(1, num_epochs + 1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch',
                         ylabel='Loss',
                         title='GM ResNet101 Detect&Affine&TPS Training Loss',
                         legend=['Loss'],
                         width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)

        # Visualize val pck
        opts_pck = dict(xlabel='Epoch',
                        ylabel='Val PCK',
                        title='GM ResNet101 Detect&Affine&TPS Val PCK',
                        legend=['PCK'],
                        width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)

        # Visualize train lr
        opts_lr = dict(
            xlabel='Epoch',
            ylabel='Learning Rate',
            title='GM ResNet101 Detect&Affine&TPS Training Learning Rate',
            legend=['LR'],
            width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)
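
A minimal usage sketch for vis_fn_detect at the end of training; faster_rcnn, aff_theta, and the history arrays are placeholders for what the surrounding project defines and fills in during training:

train_loss = np.zeros(num_epochs)   # per-epoch loss history (filled by the training loop)
val_pck = np.zeros(num_epochs)      # per-epoch validation PCK history
train_lr = np.zeros(num_epochs)     # learning-rate history
vis_fn_detect(vis, model, faster_rcnn, aff_theta, train_loss, val_pck,
              train_lr, epoch=num_epochs, num_epochs=num_epochs,
              dataloader=watch_dataloader, use_cuda=True)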
Example #17
    def __call__(self, batch=None):
        """
            Generate a synthetic training triple:
            1. Use the given image pair as the source image and target image;
            2. Pad the source image, and warp the padded image with the given transformation to generate the refer image;
            3. The training triple consists of {source image, target image, refer image, theta_GT};
            4. The input for fasterRCNN consists of {source_im, target_im, refer_im, source_im_info, target_im_info,
             refer_im_info, source_gt_boxes, target_gt_boxes, refer_gt_boxes, source_num_boxes, target_num_boxes,
             refer_num_boxes}.
        """

        # image_batch.shape: (batch_size, 3, H, W)
        # theta_batch.shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        # boxes.shape: (batch_size, 4), 4: (x_min, y_min, x_max, y_max)
        img_A_batch = batch['source_image']
        img_B_batch = batch['target_image']
        theta_aff = batch['theta_GT'][:, :6].contiguous()
        theta_tps = batch['theta_GT'][:, 6:]

        # Generate symmetrically padded image for bigger sampling region to warp the source image
        padded_image_batch = self.symmetricImagePad(
            image_batch=img_A_batch, padding_factor=self.padding_factor)

        img_A_batch.requires_grad = False
        img_B_batch.requires_grad = False
        padded_image_batch.requires_grad = False
        theta_aff.requires_grad = False
        theta_tps.requires_grad = False

        # Get the refer image by warping the padded image with the given transformation
        # warped_image_batch = self.affTnf(image_batch=padded_image_batch, theta_batch=theta_aff, padding_factor=0.5, crop_factor=self.crop_factor)
        # warped_image_batch = self.symmetricImagePad(image_batch=warped_image_batch, padding_factor=self.padding_factor)
        # warped_image_batch = self.tpsTnf(image_batch=warped_image_batch, theta_batch=theta_tps, padding_factor=0.5, crop_factor=self.crop_factor)
        # warped_image_aff = self.affTnf(image_batch=padded_image_batch, theta_batch=theta_aff, padding_factor=self.padding_factor, crop_factor=self.crop_factor)
        # warped_image_tps = self.tpsTnf(image_batch=padded_image_batch, theta_batch=theta_tps, padding_factor=self.padding_factor, crop_factor=self.crop_factor)

        # warped_image_batch = self.affTpsTnf(source_image=padded_image_batch, theta_aff=theta_aff, theta_aff_tps=theta_tps, use_cuda=self.use_cuda)
        warped_image_batch = self.geometricTnf(image_batch=padded_image_batch,
                                               theta_aff=theta_aff,
                                               theta_aff_tps=theta_tps)

        # Get the refer im for extracting rois
        tmp_image_batch = normalize_image(image=warped_image_batch,
                                          forward=False) * 255.0
        tmp_image_batch = tmp_image_batch.cpu().numpy().transpose((0, 2, 3, 1))
        tmp_image_batch = tmp_image_batch[:, :, :, ::-1]  # RGB -> BGR
        warped_im_batch = torch.zeros_like(warped_image_batch,
                                           dtype=torch.float32)
        for i in range(tmp_image_batch.shape[0]):
            warped_im_batch[i] = roi_data(tmp_image_batch[i])[0]
        warped_im_batch = warped_im_batch.cuda()
        warped_im_batch.requires_grad = False

        # img_A_batch.shape, img_B_batch.shape, warped_image_batch.shape: (batch_size, 3, 240, 240)
        # theta_batch.shape-tps: (batch_size, 18)-random or (batch_size, 18, 1, 1)-(pre-set from csv)
        # theta_batch.shape-affine: (batch_size, 2, 3)
        # return {'source_image': img_A_batch, 'target_image': img_B_batch, 'refer_image': warped_image_batch,
        #         'source_im': batch['source_im'], 'target_im': batch['target_im'], 'refer_im': warped_im_batch,
        #         'source_im_info': batch['source_im_info'], 'target_im_info': batch['target_im_info'], 'refer_im_info': batch['source_im_info'],
        #         'source_gt_boxes': batch['source_gt_boxes'], 'target_gt_boxes': batch['target_gt_boxes'], 'refer_gt_boxes': batch['source_gt_boxes'],
        #         'source_num_boxes': batch['source_num_boxes'], 'target_num_boxes': batch['target_num_boxes'], 'refer_num_boxes': batch['source_num_boxes'],
        #         'theta_aff_GT': theta_aff, 'theta_tps_GT': theta_tps}

        return {
            'source_image': img_A_batch,
            'target_image': img_B_batch,
            'refer_image': warped_image_batch,
            'source_im_info': batch['source_im_info'],
            'target_im_info': batch['target_im_info'],
            'refer_im_info': batch['source_im_info'],
            'theta_aff_GT': theta_aff,
            'theta_tps_GT': theta_tps
        }
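
symmetricImagePad above mirrors the image borders so the subsequent warp can sample outside the original frame without leaving empty regions. A minimal sketch of that idea, assuming only torch indexing (a hypothetical stand-in for self.symmetricImagePad):

def symmetric_image_pad_sketch(image_batch, padding_factor=0.5):
    # Mirror-pad a (B, C, H, W) batch by padding_factor * H (top/bottom)
    # and padding_factor * W (left/right).
    b, c, h, w = image_batch.shape
    pad_h, pad_w = int(h * padding_factor), int(w * padding_factor)
    idx_left = torch.arange(pad_w - 1, -1, -1, device=image_batch.device)
    idx_right = torch.arange(w - 1, w - pad_w - 1, -1, device=image_batch.device)
    idx_top = torch.arange(pad_h - 1, -1, -1, device=image_batch.device)
    idx_bottom = torch.arange(h - 1, h - pad_h - 1, -1, device=image_batch.device)
    image_batch = torch.cat((image_batch[:, :, :, idx_left], image_batch,
                             image_batch[:, :, :, idx_right]), dim=3)
    image_batch = torch.cat((image_batch[:, :, idx_top, :], image_batch,
                             image_batch[:, :, idx_bottom, :]), dim=2)
    return image_batch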
Example #18
def vis_tss(vis,
            dataloader,
            theta,
            theta_weak,
            csv_file,
            title,
            use_cuda=True):
    # Visualize watch images
    dataframe = pd.read_csv(csv_file)
    scores_det = dataframe.iloc[:, 5]
    scores_det_aff = dataframe.iloc[:, 6]
    scores_det_aff_tps = dataframe.iloc[:, 7]
    scores_aff = dataframe.iloc[:, 8]
    scores_aff_tps = dataframe.iloc[:, 9]
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    watch_images = torch.ones(len(dataloader) * 8, 3, 280, 240)
    if use_cuda:
        watch_images = watch_images.cuda()
    image_names = list()
    flow = list()
    fnt = cv2.FONT_HERSHEY_COMPLEX
    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        theta_det = theta['det'][batch_idx].unsqueeze(0)
        theta_det_aff = theta['det_aff'][batch_idx].unsqueeze(0)
        theta_det_aff_tps = theta['det_aff_tps'][batch_idx].unsqueeze(0)
        theta_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_det = affTnf(batch['source_image'], theta_det)
        warped_det_aff = affTnf(warped_det, theta_det_aff)
        warped_det_aff_tps = tpsTnf(warped_det_aff, theta_det_aff_tps)
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)

        watch_images[batch_idx * 8, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 8 + 1, :, 0:240, :] = warped_det
        watch_images[batch_idx * 8 + 2, :, 0:240, :] = warped_det_aff
        watch_images[batch_idx * 8 + 3, :, 0:240, :] = warped_det_aff_tps
        watch_images[batch_idx * 8 + 4, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 8 + 6, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 8 + 7, :, 0:240, :] = warped_aff_tps

        image_names.append('Source')
        image_names.append('Det')
        image_names.append('Det_aff')
        image_names.append('Det_aff_tps')
        image_names.append('Target')
        image_names.append('')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        flow.append('')
        flow.append('Flow: {:.3f}'.format(scores_det[batch_idx]))
        flow.append('Flow: {:.3f}'.format(scores_det_aff[batch_idx]))
        flow.append('Flow: {:.3f}'.format(scores_det_aff_tps[batch_idx]))
        flow.append('')
        flow.append('')
        flow.append('Flow: {:.3f}'.format(scores_aff[batch_idx]))
        flow.append('Flow: {:.3f}'.format(scores_aff_tps[batch_idx]))

    opts = dict(jpgquality=100, title=title)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    for i in range(watch_images.shape[0]):
        if (i + 1) % 8 != 6:
            pos_name = (80, 255)
            if (i + 1) % 8 == 1 or (i + 1) % 8 == 5:
                pos_flow = (0, 0)
            else:
                pos_flow = (70, 275)
            cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                        (0, 0, 0), 1)
            cv2.putText(watch_images[i], flow[i], pos_flow, fnt, 0.5,
                        (0, 0, 0), 1)
        else:
            # Leave a blank (white) spacer tile between the two result groups
            watch_images[i] = 255

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=5),
              opts=opts)
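
A minimal usage sketch for vis_tss; the CSV path is hypothetical, but columns 5-9 must hold the per-pair flow scores in the order consumed above (det, det_aff, det_aff_tps, aff, aff_tps):

import visdom

vis = visdom.Visdom(env='tss')              # assumes a running visdom server
vis_tss(vis, watch_dataloader, theta, theta_weak,
        csv_file='results/tss_flow.csv',    # hypothetical results file
        title='TSS flow comparison', use_cuda=True)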