示例#1
0
文件: demo.py 项目: kmbriedis/RAFT
def demo(args):
    """Run the model on three bundled image pairs and display each result.

    The network is built from ``args``, wrapped in ``DataParallel``, loaded
    from the checkpoint at ``args.model``, moved to ``DEVICE`` and switched
    to eval mode. Each image pair is then pushed through the model with
    gradients disabled, and the last entry of the returned predictions is
    handed to ``display`` together with the two input frames.

    Args:
        args: Namespace forwarded to ``RAFT``; ``args.model`` (checkpoint
            path) and ``args.iters`` (iterations for the sintel pair) are
            read here.
    """
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    # (first frame, second frame, keyword arguments for the forward pass)
    samples = (
        # sintel images: caller-chosen iteration count, upsampling disabled
        ('images/sintel_0.png', 'images/sintel_1.png',
         dict(iters=args.iters, upsample=False)),
        # kitti images
        ('images/kitti_0.png', 'images/kitti_1.png', dict(iters=16)),
        # davis images
        ('images/davis_0.jpg', 'images/davis_1.jpg', dict(iters=16)),
    )

    with torch.no_grad():
        for first_path, second_path, forward_kwargs in samples:
            image1 = load_image(first_path)
            image2 = load_image(second_path)

            flow_predictions = model(image1, image2, **forward_kwargs)
            display(image1[0], image2[0], flow_predictions[-1][0])
示例#2
0
def train(gpu, ngpus_per_node, args):
    """Build a RAFT model on one GPU and print its KITTI validation result.

    Despite the name, no optimisation step happens here: the model is
    created, optionally restored from ``args.restore_ckpt``, put in eval
    mode and passed to ``validate_kitti``, whose return value is printed.

    Args:
        gpu: CUDA device index for this process; also used as the process
            rank when ``args.distributed`` is set.
        ngpus_per_node: Used as the world size, as the divisor for
            ``args.batch_size`` and as the size of the process group in
            distributed mode.
        args: Namespace of options. Reads ``distributed``, ``dist_backend``,
            ``dist_url``, ``batch_size``, ``image_size``, ``restore_ckpt``,
            ``stage``, ``dataset_root``, ``semantics_root`` and
            ``depth_root``. Mutated in place: ``args.gpu`` is set, and
            ``args.batch_size`` is divided by ``ngpus_per_node`` in
            distributed mode.
    """
    print("Using GPU %d for training" % gpu)
    args.gpu = gpu
    device = f'cuda:{args.gpu}'

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node, rank=args.gpu)

    model = RAFT(args)
    if args.distributed:
        torch.cuda.set_device(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(module=model)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, output_device=args.gpu)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # These helpers are consumed unconditionally by validate_kitti below, so
    # they are created in both modes. Previously they only existed on the
    # distributed path and a non-distributed run failed with NameError.
    eppCbck = eppConstrainer_background(height=args.image_size[0], width=args.image_size[1], bz=args.batch_size)
    eppCbck.to(device)

    eppconcluer = eppConcluer()
    eppconcluer.to(device)

    if args.restore_ckpt is not None:
        print("=> loading checkpoint '{}'".format(args.restore_ckpt))
        checkpoint = torch.load(args.restore_ckpt, map_location=device)
        # strict=False: keys missing from or unexpected in the checkpoint
        # are tolerated rather than raising.
        model.load_state_dict(checkpoint, strict=False)

    model.eval()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    _, evaluation_entries = read_splits()

    eval_dataset = KITTI_eigen(split='evaluation', root=args.dataset_root, entries=evaluation_entries, semantics_root=args.semantics_root, depth_root=args.depth_root)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset) if args.distributed else None
    # shuffle and sampler are mutually exclusive in DataLoader, so shuffling
    # is only requested when no sampler is supplied.
    eval_loader = data.DataLoader(eval_dataset, batch_size=1, pin_memory=True,
                                   shuffle=(eval_sampler is None), num_workers=4, drop_last=True,
                                   sampler=eval_sampler)

    # NOTE(review): group is None outside distributed mode; confirm that
    # validate_kitti skips its collective ops in that case.
    group = dist.new_group(list(range(ngpus_per_node))) if args.distributed else None

    print(validate_kitti(model.module, args, eval_loader, eppCbck, eppconcluer, group))
示例#3
0
def RAFT(pretrained=False, model_name="chairs+things", device=None, **kwargs):
    """
    RAFT model (https://arxiv.org/abs/2003.12039)

    The returned module is wrapped in ``torch.nn.DataParallel``, moved to
    ``device`` and set to eval mode.

    pretrained (bool): when True, load weights from
                       ``<torch_home>/checkpoints/models_RAFT/models/<model_name>.pth``,
                       downloading and unzipping the archive at ``models_url``
                       first if the ``models_RAFT`` directory does not exist
    model_name (str): One of 'chairs+things', 'sintel', 'kitti' and 'small'
                      note that for 'small', the architecture is smaller
    device: target device; defaults to the current CUDA device when CUDA is
            available, otherwise "cpu"
    **kwargs: collected into the ``argparse.Namespace`` passed to
              ``RAFT_module``

    Raises ValueError if ``model_name`` is not one of the supported names.
    """

    model_list = ["chairs+things", "sintel", "kitti", "small"]
    if model_name not in model_list:
        raise ValueError("Model should be one of " + str(model_list))

    model_args = argparse.Namespace(**kwargs)
    model_args.small = "small" in model_name

    model = RAFT_module(model_args)
    if device is None:
        device = torch.cuda.current_device() if torch.cuda.is_available(
        ) else "cpu"

    # Compare through str() so torch.device("cpu") is recognised as CPU too;
    # a torch.device never compares equal to the plain string "cpu".
    on_cpu = str(device) == "cpu"
    if not on_cpu:
        model = torch.nn.DataParallel(model, device_ids=[device])
    else:
        model = torch.nn.DataParallel(model)
        model.device_ids = None

    if pretrained:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints", "models_RAFT")
        model_path = os.path.join(model_dir, "models", model_name + ".pth")
        if not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
            # Context managers make sure the HTTP connection and the archive
            # handle are released even if reading or extraction fails.
            with urllib.request.urlopen(models_url, timeout=10) as response:
                archive_bytes = response.read()
            with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
                archive.extractall(model_dir)
        else:
            time.sleep(
                10
            )  # Give the time for the models to be downloaded and unzipped

        map_location = torch.device('cpu') if on_cpu else None
        model.load_state_dict(torch.load(model_path,
                                         map_location=map_location))

    model = model.to(device)
    model.eval()
    return model
示例#4
0
            out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    print("Validation KITTI: %f, %f" %
          (np.mean(epe_list), 100 * np.mean(out_list)))


if __name__ == '__main__':
    # Command-line options: checkpoint path, model-size switch, and the
    # iteration count passed to each validation routine.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ('--model', dict(help="restore checkpoint")),
        ('--small', dict(action='store_true', help='use small model')),
        ('--sintel_iters', dict(type=int, default=50)),
        ('--kitti_iters', dict(type=int, default=32)),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # The network is wrapped in DataParallel before the weights are loaded,
    # so the state dict is applied to the wrapper.
    model = torch.nn.DataParallel(RAFT(args))
    state_dict = torch.load(args.model)
    model.load_state_dict(state_dict)

    model.to('cuda')
    model.eval()

    # Run both benchmarks with their own iteration counts.
    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)