Code example #1 (score: 0)
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
    """Create a submission for the KITTI leaderboard.

    Runs the model over the KITTI *testing* split and writes one KITTI-format
    flow file per frame into ``output_path``.

    Args:
        model: flow-estimation model; called as
            ``model(image1, image2, iters=..., test_mode=True)`` and expected
            to return ``(flow_low, flow_pr)``.
        iters: number of refinement iterations passed to the model.
        output_path: directory for the submission files (created if missing).
    """
    model.eval()
    test_dataset = datasets.KITTI(split='testing', aug_params=None)

    os.makedirs(output_path, exist_ok=True)

    # Inference only: disabling autograd avoids retaining computation graphs.
    with torch.no_grad():
        for test_id in range(len(test_dataset)):
            image1, image2, (frame_id, ) = test_dataset[test_id]
            # Pad both frames to dimensions the network accepts.
            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            # Crop back to the original resolution and move to HWC numpy.
            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

            output_filename = os.path.join(output_path, frame_id)
            frame_utils.writeFlowKITTI(output_filename, flow)
Code example #2 (score: 0)
File: train.py — Project: kmbriedis/RAFT
def fetch_dataloader(args):
    """Create the data loader for the corresponding training set.

    Args:
        args: namespace with at least ``dataset`` (one of 'chairs', 'things',
            'sintel', 'kitti'), ``image_size`` and ``batch_size``.

    Returns:
        A shuffling ``DataLoader`` over the selected training dataset.

    Raises:
        ValueError: if ``args.dataset`` is not a recognized dataset name.
            (Previously an unknown name crashed later with UnboundLocalError.)
    """
    if args.dataset == 'chairs':
        train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)

    elif args.dataset == 'things':
        # FlyingThings3D ships two render passes; train on their union.
        clean_dataset = datasets.SceneFlow(args,
                                           image_size=args.image_size,
                                           dstype='frames_cleanpass')
        final_dataset = datasets.SceneFlow(args,
                                           image_size=args.image_size,
                                           dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'sintel':
        # Same idea for Sintel: combine the clean and final passes.
        clean_dataset = datasets.MpiSintel(args,
                                           image_size=args.image_size,
                                           dstype='clean')
        final_dataset = datasets.MpiSintel(args,
                                           image_size=args.image_size,
                                           dstype='final')
        train_dataset = clean_dataset + final_dataset

    elif args.dataset == 'kitti':
        train_dataset = datasets.KITTI(args,
                                       image_size=args.image_size,
                                       is_val=False)

    else:
        raise ValueError("Unknown dataset: %r" % (args.dataset,))

    gpuargs = {'num_workers': 4, 'drop_last': True}
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              pin_memory=True,
                              shuffle=True,
                              **gpuargs)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
Code example #3 (score: 0)
File: evaluate.py — Project: kmbriedis/RAFT
def validate_kitti(args, model, iters=32):
    """Evaluate a trained model on the KITTI (train) split.

    Args:
        args: namespace forwarded to the dataset constructor.
        model: wrapped flow model (e.g. DataParallel); predictions are taken
            from ``model.module``.
        iters: number of refinement iterations passed to the model.

    Returns:
        dict with ``'kitti-epe'`` (mean end-point error) and ``'kitti-f1'``
        (outlier percentage) — added for consistency with the other
        ``validate_kitti*`` helpers, which all return this dict.
    """
    model.eval()
    val_dataset = datasets.KITTI(args,
                                 do_augument=False,
                                 is_val=True,
                                 do_pad=True)

    with torch.no_grad():
        epe_list, out_list = [], []
        for i in range(len(val_dataset)):
            image1, image2, flow_gt, valid_gt = val_dataset[i]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()
            flow_gt = flow_gt.cuda()
            valid_gt = valid_gt.cuda()

            # Use the last (most refined) prediction of the single batch item.
            flow_predictions = model.module(image1, image2, iters=iters)
            flow_pr = flow_predictions[-1][0]

            # Per-pixel end-point error and ground-truth flow magnitude.
            epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
            mag = torch.sum(flow_gt**2, dim=0).sqrt()

            epe = epe.view(-1)
            mag = mag.view(-1)
            # Only pixels with valid ground truth count toward the metrics.
            val = valid_gt.view(-1) >= 0.5

            # KITTI F1 outlier definition: EPE > 3px AND relative error > 5%.
            out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe = np.mean(np.array(epe_list))
    f1 = 100 * np.mean(np.concatenate(out_list))

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
Code example #4 (score: 0)
def validate_kitti(model, args, iters=24):
    """Perform validation using the KITTI-2015 (train) split.

    Args:
        model: flow model called as
            ``model(image1, image2, iters=..., test_mode=True)``.
        args: namespace whose ``dataset`` attribute is the KITTI root path.
        iters: number of refinement iterations passed to the model.

    Returns:
        dict with ``'kitti-epe'`` (mean end-point error) and ``'kitti-f1'``
        (outlier percentage).
    """
    from tqdm import tqdm
    model.eval()
    val_dataset = datasets.KITTI(split='training', root=args.dataset)

    out_list, epe_list = [], []
    # Inference only: no_grad avoids retaining autograd graphs (the sibling
    # validator in this collection already does this).
    with torch.no_grad():
        for val_id in tqdm(range(len(val_dataset))):
            image1, image2, flow_gt, valid_gt = val_dataset[val_id]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()

            # Pad to network-friendly dimensions, unpad the prediction after.
            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1, image2)

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            flow = padder.unpad(flow_pr[0]).cpu()

            # Per-pixel end-point error and ground-truth magnitude.
            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
            mag = torch.sum(flow_gt**2, dim=0).sqrt()

            epe = epe.view(-1)
            mag = mag.view(-1)
            val = valid_gt.view(-1) >= 0.5

            # KITTI F1 outlier: EPE > 3px AND relative error > 5%.
            out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    f1 = 100 * np.mean(out_list)

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
Code example #5 (score: 0)
def validate_kitti_colorjitter(model, args, iters=24):
    """Validate on the KITTI-2015 (train) split with a fixed color jitter
    applied to both input frames before inference.

    Returns a dict with 'kitti-epe' and 'kitti-f1'.
    """
    from torchvision.transforms import ColorJitter
    from tqdm import tqdm

    model.eval()
    dataset = datasets.KITTI(split='training', root=args.dataset)

    jitterparam = 0.86
    augmenter = ColorJitter(brightness=jitterparam,
                            contrast=jitterparam,
                            saturation=jitterparam,
                            hue=jitterparam / 3.14)

    def _apply_jitter(frame_a, frame_b):
        # Re-seed on every call so the augmentation is deterministic per pair.
        torch.manual_seed(1234)
        np.random.seed(1234)
        arr_a = frame_a.permute([1, 2, 0]).numpy().astype(np.uint8)
        arr_b = frame_b.permute([1, 2, 0]).numpy().astype(np.uint8)
        # Stack vertically so one jitter call treats both frames identically.
        stacked = np.concatenate([arr_a, arr_b], axis=0)
        stacked = np.array(augmenter(Image.fromarray(stacked)), dtype=np.uint8)
        arr_a, arr_b = np.split(stacked, 2, axis=0)
        out_a = torch.from_numpy(arr_a).permute([2, 0, 1]).float()
        out_b = torch.from_numpy(arr_b).permute([2, 0, 1]).float()
        return out_a, out_b

    epe_records, outlier_records = [], []
    for idx in tqdm(range(len(dataset))):
        frame1, frame2, flow_gt, valid_gt = dataset[idx]
        frame1, frame2 = _apply_jitter(frame1, frame2)

        frame1 = frame1[None].cuda()
        frame2 = frame2[None].cuda()

        padder = InputPadder(frame1.shape, mode='kitti')
        frame1, frame2 = padder.pad(frame1, frame2)

        flow_low, flow_pr = model(frame1, frame2, iters=iters, test_mode=True)
        flow = padder.unpad(flow_pr[0]).cpu()

        # Per-pixel end-point error and ground-truth magnitude, flattened.
        epe_map = torch.sum((flow - flow_gt)**2, dim=0).sqrt().view(-1)
        mag_map = torch.sum(flow_gt**2, dim=0).sqrt().view(-1)
        keep = valid_gt.view(-1) >= 0.5

        print("Index: %d, valnum: %d" % (idx, torch.sum(valid_gt).item()))

        # KITTI F1 outlier: EPE > 3px AND relative error > 5%.
        outliers = ((epe_map > 3.0) & ((epe_map / mag_map) > 0.05)).float()
        epe_records.append(epe_map[keep].mean().item())
        outlier_records.append(outliers[keep].cpu().numpy())

    epe = np.mean(np.array(epe_records))
    f1 = 100 * np.mean(np.concatenate(outlier_records))

    print("jitterparam:%f, Validation KITTI: %f, %f" % (jitterparam, epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
Code example #6 (score: 0)
def validate_kitti_customized(model, iters=24):
    """ Peform validation using the KITTI-2015 (train) split """
    # NOTE(review): this variant interleaves validation with heavy per-sample
    # debugging/visualization (epipole voting, warped-image reconstruction,
    # matplotlib figures). The dataset root is hardcoded to one machine.
    model.eval()
    val_dataset = datasets.KITTI(
        split='training',
        root='/home/shengjie/Documents/Data/Kitti/kitti_stereo/stereo15')

    out_list, epe_list = [], []
    for val_id in range(len(val_dataset)):
        image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        # Pad to network-friendly dimensions; predictions are unpadded below.
        padder = InputPadder(image1.shape, mode='kitti')
        image1, image2 = padder.pad(image1, image2)

        flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow = padder.unpad(flow_pr[0]).cpu()

        # flow is already on CPU here; flowT/flownp are redundant aliases kept
        # for the debug code below.
        flowT = flow.cpu()
        flownp = flowT.numpy()

        # Unpadded input frames for visualization.
        image1_vls = padder.unpad(image1[0]).cpu()
        image2_vls = padder.unpad(image2[0]).cpu()

        image1_vlsnp = image1_vls.permute([1, 2,
                                           0]).cpu().numpy().astype(np.uint8)
        image2_vlsnp = image2_vls.permute([1, 2,
                                           0]).cpu().numpy().astype(np.uint8)
        flow_gt_vls_np = flow_gt.cpu().numpy()
        valid_gt_vls_np = valid_gt.cpu().numpy()

        # Pixel grid, then the grid displaced by the predicted flow —
        # i.e. where each pixel of image1 lands in image2.
        _, h, w = flowT.shape
        xx, yy = np.meshgrid(range(w), range(h), indexing='xy')
        resampledxx = xx + flowT[0].cpu().numpy()
        resampledyy = yy + flowT[1].cpu().numpy()

        # Presumably estimates/visualizes the epipole from the flow field;
        # defined elsewhere in the project — TODO confirm semantics.
        epipole_vote(xx, yy, flownp, image1_vlsnp, image2_vlsnp,
                     flow_gt_vls_np, valid_gt_vls_np)

        # Normalize displaced coordinates to [-1, 1] for grid_sample, then
        # warp image2 back toward image1 as a reconstruction check.
        resampledxx = ((resampledxx / (w - 1)) - 0.5) * 2
        resampledyy = ((resampledyy / (h - 1)) - 0.5) * 2
        resamplegrid = torch.stack(
            [torch.from_numpy(resampledxx),
             torch.from_numpy(resampledyy)],
            dim=2).unsqueeze(0).float()
        image1_recon_vls = torch.nn.functional.grid_sample(
            input=image2_vls.unsqueeze(0),
            grid=resamplegrid,
            mode='bilinear',
            padding_mode='reflection')

        # rndx = np.random.randint(0, w)
        # rndy = np.random.randint(0, h)
        # Fixed probe pixel; its flow-displaced location is marked in image2.
        rndx = 215
        rndy = 278
        tarx = rndx + flownp[0, int(rndy), int(rndx)]
        tary = rndy + flownp[1, int(rndy), int(rndx)]

        # NOTE(review): figures are created every iteration and never shown or
        # closed here — matplotlib will accumulate open figures.
        plt.figure()
        plt.imshow(image1.squeeze().permute([1, 2, 0
                                             ]).cpu().numpy().astype(np.uint8))
        plt.scatter(rndx, rndy, 1, 'r')

        plt.figure()
        plt.imshow(image2.squeeze().permute([1, 2, 0
                                             ]).cpu().numpy().astype(np.uint8))
        plt.scatter(tarx, tary, 1, 'r')

        plt.figure()
        plt.imshow(image1_recon_vls.squeeze().permute(
            [1, 2, 0]).cpu().numpy().astype(np.uint8))

        # Pops up a color-coded flow visualization via the OS image viewer.
        import PIL.Image as Image
        from core.utils.flow_viz import flow_to_image
        flowimg = flow_to_image(flow.permute([1, 2, 0]).cpu().numpy())
        Image.fromarray(flowimg).show()

        # Per-pixel end-point error and ground-truth magnitude.
        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
        mag = torch.sum(flow_gt**2, dim=0).sqrt()

        epe = epe.view(-1)
        mag = mag.view(-1)
        # Only pixels with valid ground truth count toward the metrics.
        val = valid_gt.view(-1) >= 0.5

        # KITTI F1 outlier: EPE > 3px AND relative error > 5%.
        out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
        epe_list.append(epe[val].mean().item())
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    epe = np.mean(epe_list)
    f1 = 100 * np.mean(out_list)

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}