Example #1
File: demo.py  Project: kmbriedis/RAFT
def demo(args):
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    with torch.no_grad():

        # sintel images
        image1 = load_image('images/sintel_0.png')
        image2 = load_image('images/sintel_1.png')

        flow_predictions = model(image1, image2, iters=args.iters, upsample=False)
        display(image1[0], image2[0], flow_predictions[-1][0])

        # kitti images
        image1 = load_image('images/kitti_0.png')
        image2 = load_image('images/kitti_1.png')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])

        # davis images
        image1 = load_image('images/davis_0.jpg')
        image2 = load_image('images/davis_1.jpg')

        flow_predictions = model(image1, image2, iters=16)
        display(image1[0], image2[0], flow_predictions[-1][0])
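
The demo relies on a load_image helper that is not shown here. A minimal sketch of such a helper, assuming RAFT expects float tensors in [0, 255] with shape (1, 3, H, W) placed on DEVICE:

import numpy as np
import torch
from PIL import Image

DEVICE = 'cuda'

def load_image(imfile):
    # Read an image, keep values in [0, 255], and add a batch dimension.
    img = np.array(Image.open(imfile)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)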
Example #2
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)

        eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1], bz=args.batch_size)
        eppCbck.to(f'cuda:{args.gpu}')

        eppconcluer = eppConcluer()
        eppconcluer.to(f'cuda:{args.gpu}')
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()


    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.eval()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    _, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                   shuffle=(eval_sampler is None), num_workers=4, drop_last=True,
                                   sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
    return
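
This train(gpu, ngpus_per_node, args) entry point follows the torch.multiprocessing.spawn convention, where the process index is passed as the first argument. A hypothetical launcher (the project's actual argument parsing is not shown and may differ):

import torch
import torch.multiprocessing as mp

if __name__ == '__main__':
    args = get_args()  # assumed argparse helper, not part of the snippet above
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # mp.spawn calls train(rank, ngpus_per_node, args) once per GPU.
        mp.spawn(train, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        train(0, ngpus_per_node, args)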
Example #3
def RAFT(pretrained=False, model_name="chairs+things", device=None, **kwargs):
    """
    RAFT model (https://arxiv.org/abs/2003.12039)
    model_name (str): One of 'chairs+things', 'sintel', 'kitti' and 'small'
                      note that for 'small', the architecture is smaller
    """

    model_list = ["chairs+things", "sintel", "kitti", "small"]
    if model_name not in model_list:
        raise ValueError("Model should be one of " + str(model_list))

    model_args = argparse.Namespace(**kwargs)
    model_args.small = "small" in model_name

    model = RAFT_module(model_args)
    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    if device != "cpu":
        model = torch.nn.DataParallel(model, device_ids=[device])
    else:
        model = torch.nn.DataParallel(model)
        model.device_ids = None

    if pretrained:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints", "models_RAFT")
        model_path = os.path.join(model_dir, "models", model_name + ".pth")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
            response = urllib.request.urlopen(models_url, timeout=10)
            z = zipfile.ZipFile(io.BytesIO(response.read()))
            z.extractall(model_dir)
        else:
            time.sleep(10)  # give the models time to be downloaded and unzipped

        map_location = torch.device('cpu') if device == "cpu" else None
        model.load_state_dict(torch.load(model_path,
                                         map_location=map_location))

    model = model.to(device)
    model.eval()
    return model
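
A possible way to call this loader, mirroring the call pattern of Example #1; the dummy inputs below stand in for real frames and assume (N, 3, H, W) tensors in [0, 255] with H and W divisible by 8:

import torch

model = RAFT(pretrained=True, model_name="sintel")
device = next(model.parameters()).device

image1 = torch.randint(0, 255, (1, 3, 368, 768), device=device).float()
image2 = torch.randint(0, 255, (1, 3, 368, 768), device=device).float()

with torch.no_grad():
    flow_predictions = model(image1, image2, iters=20)
    flow = flow_predictions[-1]  # estimate from the final refinement iteration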
Example #4
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()
    aug_params = {
        'crop_size': args.image_size,
        'min_scale': -0.2,
        'max_scale': 0.4,
        'do_flip': False
    }
    train_dataset = VirtualKITTI2(aug_params,
                                  split='training',
                                  root=args.dataset_root,
                                  entries=train_entries)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(split='evaluation',
                                 root=args.dataset_root,
                                 entries=evaluation_entries)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    scaler = GradScaler(enabled=args.mixed_precision)

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger_evaluation = Logger(
            model, scheduler,
            os.path.join(args.logroot, 'evaluation_VRKitti', args.name))

    VAL_FREQ = 500
    add_noise = True
    epoch = 0

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = data_blob

            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)

            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)

            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)

            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(
                    gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid,
                                          args.gamma)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions,
                            valid)

            if total_steps % VAL_FREQ == VAL_FREQ - 1:

                results = validate_VRKitti2(model.module, args, eval_loader,
                                            group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(
                        logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
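
sequence_loss is called in this loop but not defined in the snippet. A sketch in the spirit of the reference RAFT training script, where later refinement iterations receive exponentially larger weights and the endpoint-error metrics are computed on valid pixels only (the per-project variants above may differ):

import torch

MAX_FLOW = 400

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially weighted L1 loss over the iterative flow predictions."""
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # Exclude invalid pixels and extremely large displacements.
    mag = torch.sum(flow_gt ** 2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)

    for i, flow_pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = (flow_pred - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()

    # End-point error of the final prediction, restricted to valid pixels.
    epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
    epe = epe.view(-1)[valid.view(-1)]

    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
        '5px': (epe < 5).float().mean().item(),
    }
    return flow_loss, metrics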
Example #5
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    train_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.inheight, inwidth=args.inwidth, entries=train_entries, maxinsnum=args.maxinsnum,
                                depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root, prediction_root=args.prediction_root, ins_root=args.ins_root,
                                istrain=True, muteaug=False, banremovedup=False, isgarg=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=int(args.num_workers / ngpus_per_node), drop_last=True, sampler=train_sampler)

    eval_dataset = KITTI_eigen(root=args.dataset_root, inheight=args.evalheight, inwidth=args.evalwidth, entries=evaluation_entries, maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root, depthvls_root=args.depthvlsgt_root, prediction_root=args.prediction_root, ins_root=args.ins_root, istrain=False, isgarg=True)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True, num_workers=3, drop_last=True, sampler=eval_sampler)

    print("Training splits contain %d images while test splits contain %d images" % (train_dataset.__len__(), eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model, int(train_dataset.__len__() / 2))

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(logroot)
        logger_evaluation = Logger(os.path.join(args.logroot, 'evaluation_eigen_background', args.name))
        logger.create_summarywriter()
        logger_evaluation.create_summarywriter()

    VAL_FREQ = 5000
    epoch = 0
    minout = 1

    st = time.time()
    should_keep_training = True
    while should_keep_training:
        train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu) / 255.0
            image2 = data_blob['img2'].cuda(gpu) / 255.0
            flowmap = data_blob['flowmap'].cuda(gpu)

            outputs = model(image1, image2)

            selector = (flowmap[:, 0, :, :] != 0)
            flow_loss = sequence_loss(outputs, flowmap, selector, gamma=args.gamma, max_flow=MAX_FLOW)

            metrics = dict()
            metrics['flow_loss'] = flow_loss

            loss = flow_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, step=total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps - total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flowloss: %f" % (total_steps, resths, flow_loss.item()))
                    logger.write_vls(data_blob, outputs, selector.unsqueeze(1), total_steps)

            if total_steps % VAL_FREQ == 1:
                if args.gpu == 0:
                    results = validate_kitti(model.module, args, eval_loader, logger, group, total_steps)
                else:
                    results = validate_kitti(model.module, args, eval_loader, None, group, None)

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    if minout > results['out']:
                        minout = results['out']
                        PATH = os.path.join(logroot, 'minout.pth')
                        torch.save(model.state_dict(), PATH)
                        print("model saved to %s" % PATH)

                model.train()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return
Example #6
            out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    print("Validation KITTI: %f, %f" %
          (np.mean(epe_list), 100 * np.mean(out_list)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--sintel_iters', type=int, default=50)
    parser.add_argument('--kitti_iters', type=int, default=32)

    args = parser.parse_args()

    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to('cuda')
    model.eval()

    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)
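
The fragment above uses epe, mag and val without showing how they are formed. In RAFT-style KITTI validation they are typically computed per image inside the evaluation loop, roughly as follows (a sketch; flow_pr, flow_gt and valid_gt stand for the predicted flow, ground-truth flow and validity mask of one image):

import torch

epe = torch.sum((flow_pr - flow_gt) ** 2, dim=0).sqrt()  # per-pixel endpoint error
mag = torch.sum(flow_gt ** 2, dim=0).sqrt()              # ground-truth flow magnitude

epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5                           # valid-pixel mask

# A pixel is an outlier (KITTI Fl metric) if its error exceeds 3 px
# and 5% of the ground-truth magnitude, as in the fragment above.
out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()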
Example #7
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)

        eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1], bz=args.batch_size)
        eppCbck.to(f'cuda:{args.gpu}')

        eppconcluer = eppConcluer()
        eppconcluer.to(f'cuda:{args.gpu}')
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" % (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_entries, evaluation_entries = read_splits()

    aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
    train_dataset = KITTI_eigen(aug_params, split='training', root=args.dataset_root, entries=train_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                                   shuffle=(train_sampler is None), num_workers=int(args.num_workers / ngpus_per_node), drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                   shuffle=(eval_sampler is None), num_workers=3, drop_last=True,
                                   sampler=eval_sampler)

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot, args.num_steps)
        logger_evaluation = Logger(model, scheduler, os.path.join(args.logroot, 'evaluation_eigen_background', args.name), args.num_steps)

    VAL_FREQ = 5000
    add_noise = False
    epoch = 0

    should_keep_training = True
    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
    while should_keep_training:

        train_sampler.set_epoch(epoch)
        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1']
            image1 = Variable(image1, requires_grad=True)
            image1 = image1.cuda(gpu, non_blocking=True)

            image2 = data_blob['img2']
            image2 = Variable(image2, requires_grad=True)
            image2 = image2.cuda(gpu, non_blocking=True)

            flow = data_blob['flow']
            flow = Variable(flow, requires_grad=True)
            flow = flow.cuda(gpu, non_blocking=True)

            valid = data_blob['valid']
            valid = Variable(valid, requires_grad=True)
            valid = valid.cuda(gpu, non_blocking=True)

            E = data_blob['E']
            E = Variable(E, requires_grad=True)
            E = E.cuda(gpu, non_blocking=True)

            semantic_selector = data_blob['semantic_selector']
            semantic_selector = Variable(semantic_selector, requires_grad=True)
            semantic_selector = semantic_selector.cuda(gpu, non_blocking=True)

            if add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv * torch.randn(*image1.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)
                image2 = (image2 + stdv * torch.randn(*image2.shape).cuda(gpu, non_blocking=True)).clamp(0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            metrics = dict()
            loss_flow, metrics_flow = sequence_flowloss(flow_predictions, flow, valid, args.gamma)
            loss_eppc, metrics_eppc = sequence_eppcloss(eppCbck, flow_predictions, semantic_selector, E, args.gamma)

            metrics.update(metrics_flow)
            metrics.update(metrics_eppc)

            loss = loss_flow + loss_eppc * args.eppcw
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.push(metrics, image1, image2, flow, flow_predictions, valid, data_blob['depth'])

            if total_steps % VAL_FREQ == VAL_FREQ - 1:

                results = validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

                if args.gpu == 0:
                    logger_evaluation.write_dict(results, total_steps)
                    PATH = os.path.join(logroot, '%s.pth' % (str(total_steps + 1).zfill(3)))
                    torch.save(model.state_dict(), PATH)

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
Example #8
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    evaluation_entries = read_splits_mapping()

    eval_dataset = KITTI_eigen_stereo15(root=args.dataset_stereo15_orgned_root,
                                        inheight=args.evalheight,
                                        inwidth=args.evalwidth,
                                        entries=evaluation_entries,
                                        mdPred_root=args.mdPred_root,
                                        maxinsnum=args.maxinsnum,
                                        istrain=True,
                                        isgarg=True,
                                        deepv2dpred_root=args.deepv2dpred_root,
                                        prediction_root=args.prediction_root,
                                        flowPred_root=args.flowPred_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print("Test splits contain %d images" % (eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    # validate_RAFT_flow(args, model, eval_loader, group, usestreodepth=False)
    validate_RAFT_flow_pose(args,
                            model,
                            eval_loader,
                            group,
                            usestreodepth=False,
                            scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='mdPred')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deepv2d_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='stereo_depth')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=True, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False, scale_info_src='deppv2d_pose')
    # validate_RAFT_flow_pose(args, model, eval_loader, group, usestreodepth=False)
    return
Example #9
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args=args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(root=args.dataset_root,
                               inheight=args.evalheight,
                               inwidth=args.evalwidth,
                               entries=evaluation_entries,
                               maxinsnum=args.maxinsnum,
                               depth_root=args.depth_root,
                               depthvls_root=args.depthvlsgt_root,
                               prediction_root=args.prediction_root,
                               ins_root=args.ins_root,
                               istrain=False,
                               isgarg=True)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=1,
                                  pin_memory=True,
                                  num_workers=3,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print("Test splits contain %d images" % (eval_dataset.__len__()))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    validate_kitti(model.module, args, eval_loader, group)
    return
Example #10
def train(gpu, ngpus_per_node, args):
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(f'cuda:{args.gpu}')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu],
            find_unused_parameters=True,
            output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()
    logroot = os.path.join(args.logroot, args.name)
    print("Parameter Count: %d, saving location: %s" %
          (count_parameters(model), logroot))

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.restore_ckpt, map_location=loc)
        model.load_state_dict(checkpoint, strict=False)

    model.train()

    train_entries, evaluation_entries = read_splits()
    train_dataset = VirtualKITTI2(args=args,
                                  root=args.dataset_root,
                                  inheight=args.inheight,
                                  inwidth=args.inwidth,
                                  entries=train_entries,
                                  istrain=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset) if args.distributed else None
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   pin_memory=False,
                                   shuffle=(train_sampler is None),
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   sampler=train_sampler)

    eval_dataset = VirtualKITTI2(args=args,
                                 root=args.dataset_root,
                                 inheight=args.evalheight,
                                 inwidth=args.evalwidth,
                                 entries=evaluation_entries,
                                 istrain=False)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_dataset) if args.distributed else None
    eval_loader = data.DataLoader(eval_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=False,
                                  shuffle=(eval_sampler is None),
                                  num_workers=2,
                                  drop_last=True,
                                  sampler=eval_sampler)

    print(
        "Training split contains %d images, validation split contained %d images"
        % (len(train_entries), len(evaluation_entries)))

    if args.distributed:
        group = dist.new_group([i for i in range(ngpus_per_node)])

    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    VAL_ITERINC = 4

    if args.gpu == 0:
        logger = Logger(model, scheduler, logroot)
        logger.create_summarywriter()

        logger_evaluations = dict()
        for num_iters in range(args.iters, args.iters * 2 + 1, VAL_ITERINC):
            logger_evaluation = Logger(
                model, scheduler,
                os.path.join(
                    args.logroot, 'evaluation_VRKitti',
                    "{}_iternum{}".format(args.name,
                                          str(num_iters).zfill(2))))
            logger_evaluation.create_summarywriter()
            logger_evaluations[num_iters] = logger_evaluation

    VAL_FREQ = 2000
    maxout = 1  # tracks the best (lowest) outlier rate seen so far
    epoch = 0

    st = time.time()
    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()

            image1 = data_blob['img1'].cuda(gpu, non_blocking=True)
            image2 = data_blob['img2'].cuda(gpu, non_blocking=True)
            flow = data_blob['flowmap'].cuda(gpu, non_blocking=True)

            # exclude invalid pixels and extremely large displacements
            mag = torch.sum(flow**2, dim=1).sqrt()
            valid = ((flow[:, 0] != 0) * (flow[:, 1] != 0) *
                     (mag < MAX_FLOW)).unsqueeze(1)

            flow_predictions = model(image1, image2, iters=args.iters)
            loss, metrics = sequence_loss(flow_predictions, flow, valid,
                                          args.gamma)

            metrics = dict()
            metrics['loss_flow'] = loss.float().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            if args.gpu == 0:
                logger.write_dict(metrics, total_steps)
                if total_steps % SUM_FREQ == 0:
                    dr = time.time() - st
                    resths = (args.num_steps -
                              total_steps) * dr / (total_steps + 1) / 60 / 60
                    print("Step: %d, rest hour: %f, flow loss: %f" %
                          (total_steps, resths, loss.item()))
                    logger.write_vls(data_blob, flow_predictions, valid,
                                     total_steps)

            if total_steps % VAL_FREQ == 1:
                for num_iters in range(args.iters, args.iters * 2 + 1,
                                       VAL_ITERINC):

                    if args.gpu == 0 and num_iters == 24:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, logger, total_steps)
                    else:
                        results = validate_VRKitti2(model.module, args,
                                                    eval_loader, num_iters,
                                                    group, None, None)

                    if args.gpu == 0:
                        logger_evaluations[num_iters].write_dict(
                            results, total_steps)

                        if num_iters == 24:
                            if results['out'] < maxout:
                                maxout = results['out']
                                PATH = os.path.join(logroot, 'minout.pth')
                                torch.save(model.state_dict(), PATH)
                                print("model saved to %s" % PATH)

                model.train()
            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break
        epoch = epoch + 1

    if args.gpu == 0:
        logger.close()
        PATH = os.path.join(logroot, 'final.pth')
        torch.save(model.state_dict(), PATH)

    return PATH
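
fetch_optimizer is not defined in any of the examples. The reference RAFT training script builds an AdamW optimizer with a OneCycleLR schedule stepped once per iteration; a sketch along those lines (the argument names lr, wdecay, epsilon and num_steps follow that script, and the variants above may pass extra arguments, as in Example #5):

import torch.optim as optim

def fetch_optimizer(args, model):
    """Create the optimizer and per-step learning-rate scheduler."""
    optimizer = optim.AdamW(model.parameters(), lr=args.lr,
                            weight_decay=args.wdecay, eps=args.epsilon)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100,
                                              pct_start=0.05, cycle_momentum=False,
                                              anneal_strategy='linear')
    return optimizer, scheduler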