예제 #1
0
def create_kitti_submission(model,
                            iters=24,
                            output_path='kitti_submission',
                            write_png=False):
    """Create submission files for the KITTI leaderboard.

    Runs ``model`` on every pair of the KITTI ``testing`` split and writes one
    flow file per frame into ``output_path`` with
    ``frame_utils.writeFlowKITTI``.

    Args:
        model: flow network, called as
            ``model(image1, image2, iters=iters, test_mode=True)``; the second
            returned value is used as the (padded) flow prediction.
        iters: number of refinement iterations passed to the model.
        output_path: directory receiving the flow files (created if missing).
        write_png: if True, also write a colour visualisation of every flow
            field (``flow_viz.flow_to_image``) into ``output_path + '_png'``.
    """
    import torch  # local import: only needed for no_grad()

    model.eval()
    test_dataset = datasets.KITTI(split='testing', aug_params=None)

    os.makedirs(output_path, exist_ok=True)

    out_path_png = output_path + '_png'
    if write_png:
        os.makedirs(out_path_png, exist_ok=True)

    # Inference only. Without no_grad the prediction requires grad, so
    # .numpy() below would fail, and a graph would be kept for every frame.
    with torch.no_grad():
        for test_id in range(len(test_dataset)):
            image1, image2, (frame_id, ) = test_dataset[test_id]
            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            # Drop the padding, then move channels last for the writers.
            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

            if write_png:
                output_filename_png = os.path.join(out_path_png, frame_id + '.png')
                cv2.imwrite(output_filename_png, flow_viz.flow_to_image(flow))

            output_filename = os.path.join(output_path, frame_id)
            frame_utils.writeFlowKITTI(output_filename, flow)
예제 #2
0
def validate_kitti(args, model, iters=32):
    """Evaluate a trained model on KITTI (train) data used for validation.

    Args:
        args: options forwarded to ``datasets.KITTI``.
        model: flow network. When it has a ``.module`` attribute (e.g. a
            ``torch.nn.DataParallel`` wrapper) that inner network is called;
            otherwise ``model`` itself is called.
        iters: number of refinement iterations passed to the network.

    Returns:
        dict with ``'kitti-epe'`` (mean over images of the per-image mean
        end-point error on valid pixels) and ``'kitti-f1'`` (percentage of
        valid pixels whose error is > 3 px and > 5% of the ground-truth flow
        magnitude). Both values are also printed.
    """
    model.eval()
    # Accept both a wrapped (DataParallel-style) and a bare network.
    net = getattr(model, 'module', model)
    val_dataset = datasets.KITTI(args, do_augument=False, is_val=True, do_pad=True)

    epe_list, out_list = [], []
    with torch.no_grad():
        for i in range(len(val_dataset)):
            image1, image2, flow_gt, valid_gt = val_dataset[i]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()
            flow_gt = flow_gt.cuda()
            valid_gt = valid_gt.cuda()

            # Use the last prediction, first (only) batch element.
            flow_predictions = net(image1, image2, iters=iters)
            flow_pr = flow_predictions[-1][0]

            # Sum over dim 0 -- presumably the 2 flow channels of a (2, H, W)
            # tensor; TODO confirm against the dataset.
            epe = torch.sum((flow_pr - flow_gt)**2, dim=0).sqrt()
            mag = torch.sum(flow_gt**2, dim=0).sqrt()

            epe = epe.view(-1)
            mag = mag.view(-1)
            val = valid_gt.view(-1) >= 0.5

            # KITTI outlier: error > 3 px AND > 5% of the true magnitude.
            out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe = np.mean(np.array(epe_list))
    f1 = 100 * np.mean(np.concatenate(out_list))

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
예제 #3
0
파일: train.py 프로젝트: Pavelrst/RAFT
def fetch_dataloader(args):
    """Build the shuffled training DataLoader selected by ``args.dataset``.

    Recognised names:
      * ``'chairs'`` -- FlyingChairs;
      * ``'things'`` -- SceneFlow clean pass + final pass;
      * ``'sintel'`` -- MPI-Sintel clean + final passes (908 pairs each);
      * ``'kitti'``  -- KITTI training data (``is_val=False``).
    Any other name raises NotImplementedError. Batches that would be
    incomplete are dropped.
    """
    kind = args.dataset
    if kind not in ('chairs', 'things', 'sintel', 'kitti'):
        raise NotImplementedError

    root = args.dataset_root
    size = args.image_size

    if kind == 'chairs':
        train_dataset = datasets.FlyingChairs(args, root=root, image_size=size)
    elif kind == 'things':
        clean, final = (
            datasets.SceneFlow(args, root=root, image_size=size, dstype=pass_name)
            for pass_name in ('frames_cleanpass', 'frames_finalpass'))
        train_dataset = clean + final
    elif kind == 'sintel':
        clean, final = (
            datasets.MpiSintel_Train(args, root=root, image_size=size,
                                     dstype=pass_name)
            for pass_name in ('clean', 'final'))
        assert len(clean) == 908 and len(final) == 908
        train_dataset = clean + final
    else:  # 'kitti'
        train_dataset = datasets.KITTI(args, root=root, image_size=size,
                                       is_val=False)

    loader = DataLoader(train_dataset,
                        batch_size=args.batch_size,
                        pin_memory=True,
                        shuffle=True,
                        num_workers=args.num_of_workers,
                        drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return loader
예제 #4
0
def validate_kitti(model, iters=24):
    """Perform validation using the KITTI-2015 (train) split.

    Args:
        model: flow network, called as
            ``model(image1, image2, iters=iters, test_mode=True)``; the second
            returned value is used as the (padded) flow prediction.
        iters: number of refinement iterations passed to the model.

    Returns:
        dict with ``'kitti-epe'`` (mean over images of the per-image mean
        end-point error on valid pixels) and ``'kitti-f1'`` (percentage of
        valid pixels whose error is > 3 px and > 5% of the ground-truth flow
        magnitude). Both values are also printed.
    """
    model.eval()
    val_dataset = datasets.KITTI(split='training')

    out_list, epe_list = [], []
    # Evaluation only: without no_grad every forward pass would build and
    # retain an autograd graph, wasting GPU memory.
    with torch.no_grad():
        for val_id in range(len(val_dataset)):
            image1, image2, flow_gt, valid_gt = val_dataset[val_id]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()

            padder = InputPadder(image1.shape, mode='kitti')
            image1, image2 = padder.pad(image1, image2)

            _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
            # Remove the padding and compare on the CPU, where flow_gt lives.
            flow = padder.unpad(flow_pr[0]).cpu()

            # Sum over dim 0 -- presumably the 2 flow channels of a (2, H, W)
            # tensor; TODO confirm against the dataset.
            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
            mag = torch.sum(flow_gt**2, dim=0).sqrt()

            epe = epe.view(-1)
            mag = mag.view(-1)
            val = valid_gt.view(-1) >= 0.5

            # KITTI outlier: error > 3 px AND > 5% of the true magnitude.
            out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe = np.mean(np.array(epe_list))
    f1 = 100 * np.mean(np.concatenate(out_list))

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}
예제 #5
0
def _build_datasets(args, logger):
    """Create ``(train_dataset, test_dataset)`` for ``args.dataset``.

    Raises NotImplementedError for an unsupported dataset name.
    """
    if args.dataset == 'flying_chairs':
        if cfg.SIMPLE_AUG:
            train_dataset = datasets.FlyingChairs_SimpleAug(args,
                                                            root=args.data)
            test_dataset = datasets.FlyingChairs_SimpleAug(args,
                                                           root=args.data,
                                                           mode='val')
        else:
            train_dataset = datasets.FlyingChairs(args,
                                                  image_size=cfg.CROP_SIZE,
                                                  root=args.data)
            test_dataset = datasets.FlyingChairs(args,
                                                 root=args.data,
                                                 mode='val',
                                                 do_augument=False)
    elif args.dataset == 'flying_things':
        train_dataset = datasets.SceneFlow(args,
                                           image_size=cfg.CROP_SIZE,
                                           root=args.data,
                                           dstype='frames_cleanpass',
                                           mode='train')
        test_dataset = datasets.SceneFlow(args,
                                          image_size=cfg.CROP_SIZE,
                                          root=args.data,
                                          dstype='frames_cleanpass',
                                          mode='val',
                                          do_augument=False)
    elif args.dataset in ('mpi_sintel_clean', 'mpi_sintel_final'):
        # Train on both passes; validate on the pass named by the option.
        clean_dataset = datasets.MpiSintel(args,
                                           image_size=cfg.CROP_SIZE,
                                           root=args.data,
                                           dstype='clean')
        final_dataset = datasets.MpiSintel(args,
                                           image_size=cfg.CROP_SIZE,
                                           root=args.data,
                                           dstype='final')
        train_dataset = torch.utils.data.ConcatDataset([clean_dataset,
                                                        final_dataset])
        test_dstype = 'final' if args.dataset == 'mpi_sintel_final' else 'clean'
        test_dataset = datasets.MpiSintel(args,
                                          do_augument=False,
                                          image_size=None,
                                          root=args.data,
                                          dstype=test_dstype)
    elif args.dataset == 'KITTI':
        train_dataset = datasets.KITTI(args,
                                       image_size=cfg.CROP_SIZE,
                                       root=args.data,
                                       is_val=False,
                                       logger=logger)
        # Optionally add KITTI-2012 training data.
        if args.data_kitti12 is not None:
            train_dataset12 = datasets.KITTI12(args,
                                               image_size=cfg.CROP_SIZE,
                                               root=args.data_kitti12,
                                               is_val=False,
                                               logger=logger)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset,
                                                            train_dataset12])
        test_dataset = datasets.KITTI(args,
                                      root=args.data,
                                      do_augument=False,
                                      is_val=True,
                                      do_pad=False)
    else:
        raise NotImplementedError
    return train_dataset, test_dataset


def _checkpoint_state(epoch, model, optimizer, best_epe):
    """Return the dict handed to ``save_checkpoint`` after ``epoch``."""
    return {
        'epoch': epoch + 1,
        'arch': 'dicl_wrapper',
        'state_dict': model.module.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'best_EPE': best_epe
    }


def main():
    """Parse CLI options, then evaluate or train the 'dicl_wrapper' model.

    Sets the module-level globals ``args``, ``best_EPE`` and ``save_path``.
    With ``args.evaluate`` it runs one validation pass and returns.
    Otherwise it trains from ``args.start_epoch`` to ``args.epochs``:
      * validates every ``args.eval_freq`` epochs (unless ``args.no_eval``);
      * saves ``checkpoint_best.pth.tar`` whenever the validation EPE
        improves;
      * saves ``checkpoint_<epoch>.pth.tar`` every ``max(eval_freq, 5)``
        epochs.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()

    # Load config file; its TAG must equal the config file's base name.
    if args.cfg is not None:
        cfg_from_file(args.cfg)
        assert cfg.TAG == os.path.splitext(os.path.basename(
            args.cfg))[0], 'TAG name should be file name'

    # Build save_path, which can be specified by out_dir and exp_dir
    save_path = '{},{}epochs{},b{},lr{}'.format(
        'dicl_wrapper', args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)

    save_path = os.path.join(args.exp_dir, save_path)
    if args.out_dir is not None:
        outpath = os.path.join(args.out_dir, args.dataset)
    else:
        outpath = args.dataset
    save_path = os.path.join(outpath, save_path)

    os.makedirs(outpath, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    # Create logger
    logger = create_logger(os.path.join(save_path, 'log.txt'))
    logger.info('**********************Start logging**********************')
    logger.info('=> will save everything to {}'.format(save_path))

    # Print settings together with their names so the log is readable.
    for key, value in vars(args).items():
        logger.info('{}: {}'.format(key, value))
    save_config_to_file(cfg, logger=logger)
    logger.info(args.pretrained)

    # Set random seeds. The model is built before it is moved to the GPU,
    # so the CPU generator must be seeded too for reproducible weights.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    eval_writer = SummaryWriter(os.path.join(save_path, 'eval'))

    ########################## DATALOADER ##########################
    logger.info("=> fetching img pairs in '{}'".format(args.data))
    train_dataset, test_dataset = _build_datasets(args, logger)

    logger.info('Training with %d image pairs' % len(train_dataset))
    logger.info('Testing with %d image pairs' % len(test_dataset))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               pin_memory=True,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               drop_last=cfg.DROP_LAST)

    # KITTI images have different sizes, so they cannot be batched together.
    val_batch_size = 1 if 'KITTI' in args.dataset else args.batch_size
    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=val_batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model
    if args.pretrained:
        logger.info("=> using pre-trained model '{}'".format(args.pretrained))
        pretrained_dict = torch.load(args.pretrained)

    model = models.__dict__['dicl_wrapper'](None)

    assert (args.solver in ['adam', 'sgd'])
    logger.info('=> setting {} solver'.format(args.solver))

    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=cfg.WEIGHT_DECAY,
                                     betas=(cfg.MOMENTUM, cfg.BETA))
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    weight_decay=cfg.WEIGHT_DECAY,
                                    momentum=cfg.MOMENTUM)

    if args.pretrained:
        # Checkpoints written by this script nest the weights under
        # 'state_dict'; otherwise treat the file as a bare state dict.
        weights = pretrained_dict.get('state_dict', pretrained_dict)
        model.load_state_dict(weights, strict=False)

        if args.reuse_optim:
            if 'optimizer_state' in pretrained_dict:
                try:
                    optimizer.load_state_dict(pretrained_dict['optimizer_state'])
                except (KeyError, ValueError) as err:
                    # e.g. parameter groups that do not match this optimizer
                    logger.info('could not load optimizer state: {}'.format(err))
            else:
                logger.info('do not have optimizer state')
        del pretrained_dict, weights
        torch.cuda.empty_cache()

    model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model = model.cuda()

    # `output_writers` is not created in this function. Use a module-level
    # definition if the file provides one; otherwise pass None, as the
    # evaluation-only path below does.
    output_writers = globals().get('output_writers')

    # Evaluation
    if args.evaluate:
        with torch.no_grad():
            best_EPE = validate(val_loader,
                                model,
                                0,
                                None,
                                eval_writer,
                                logger=logger)
        return

    # Learning rate schedule
    milestones = [int(m) for m in args.milestones]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.5)

    # Periodic checkpoints at most every 5 epochs, to limit the number of
    # checkpoint files written.
    save_freq = max(args.eval_freq, 5)

    ###################################### Training  ######################################
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train_loss = train(train_loader,
                           model,
                           optimizer,
                           epoch,
                           train_writer,
                           logger=logger)
        scheduler.step()

        train_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        train_writer.add_scalar('avg_loss', train_loss, epoch)

        if epoch % args.eval_freq == 0 and not args.no_eval:
            with torch.no_grad():
                EPE = validate(val_loader,
                               model,
                               epoch,
                               output_writers,
                               eval_writer,
                               logger=logger)
            eval_writer.add_scalar('mean_EPE', EPE, epoch)

            # A negative best_EPE means nothing has been evaluated yet. The
            # first evaluation must also produce a best checkpoint.
            if best_EPE < 0 or EPE < best_EPE:
                best_EPE = EPE
                save_checkpoint(_checkpoint_state(epoch, model, optimizer, EPE),
                                False,
                                filename='checkpoint_best.pth.tar')
            logger.info('Epoch: [{0}] Best EPE: {1}'.format(epoch, best_EPE))

        if epoch % save_freq == 0:
            save_checkpoint(_checkpoint_state(epoch, model, optimizer, best_EPE),
                            False,
                            filename='checkpoint_' + str(epoch) + '.pth.tar')