コード例 #1
0
def main():
    """Train or evaluate iPCRNet with a PointNet feature extractor.

    Parses options via ``options()``, seeds torch/numpy, sets up TensorBoard
    and text logging under ``checkpoints/<exp_name>``, builds ModelNet40
    registration loaders, optionally restores weights from ``args.resume``
    (a checkpoint dict with ``'epoch'`` and ``'model'`` entries) and/or
    ``args.pretrained`` (a bare state dict), then runs ``test`` when
    ``args.eval`` is set and ``train`` otherwise.

    Raises:
        FileNotFoundError: if ``args.resume`` or ``args.pretrained`` is given
            but does not name an existing file.
    """
    args = options()

    # Reproducibility: deterministic cuDNN kernels and fixed RNG seeds.
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    textio.cprint(str(args))

    trainset = RegistrationData('PCRNet',
                                ModelNet40Data(train=True, download=True))
    testset = RegistrationData('PCRNet',
                               ModelNet40Data(train=False, download=True))
    # Only the training loader shuffles and drops the ragged last batch;
    # evaluation sees every sample in a fixed order.
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.workers)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=False,
                             num_workers=args.workers)

    # Fall back to CPU when CUDA is unavailable, whatever args.device says.
    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    # Create PointNet Model.
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)
    model = model.to(args.device)

    checkpoint = None
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                f'Resume checkpoint not found: {args.resume}')
        # map_location remaps every tensor in the checkpoint onto the chosen
        # device, so a checkpoint written on a GPU can still be resumed on
        # the CPU fallback selected above instead of failing in torch.load.
        checkpoint = torch.load(args.resume, map_location=args.device)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])

    if args.pretrained:
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError(
                f'Pretrained weights not found: {args.pretrained}')
        # Applied after --resume, so these weights overwrite resumed ones.
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    # load_state_dict copies values into the existing parameters, so the
    # model is still on args.device here; no second .to() is needed.

    if args.eval:
        test(args, model, test_loader, textio)
    else:
        train(args, model, train_loader, test_loader, boardio, textio,
              checkpoint)
コード例 #2
0
def main():
    """Evaluate iPCRNet, optionally fronted by a point-cloud sampler.

    Seeds torch/numpy, builds the ModelNet40 registration test loader,
    attaches a SampleNet or FPS sampler (or none) to the model according to
    ``args.sampler``, optionally loads ``args.pretrained`` weights, and runs
    ``test``.

    Raises:
        FileNotFoundError: if ``args.pretrained`` is given but does not name
            an existing file.
    """
    args = options()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    testset = RegistrationData("PCRNet",
                               ModelNet40Data(train=False),
                               is_testing=True)
    test_loader = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    # Fall back to CPU when CUDA is unavailable, whatever args.device says.
    if not torch.cuda.is_available():
        args.device = "cpu"
    args.device = torch.device(args.device)

    # Create PointNet Model.
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)

    # Create sampler; both variants are configured with "bnc" input and
    # output shapes. Any other args.sampler value means no sampling.
    if args.sampler == "samplenet":
        sampler = SampleNet(
            args.num_out_points,
            args.bottleneck_size,
            args.projection_group_size,
            skip_projection=args.skip_projection,
            input_shape="bnc",
            output_shape="bnc",
        )
    elif args.sampler == "fps":
        sampler = FPSSampler(args.num_out_points,
                             permute=True,
                             input_shape="bnc",
                             output_shape="bnc")
    else:
        sampler = None

    # Attach the sampler before the single device move and the weight load.
    # Presumably the sampler is an nn.Module, so this registers it as a
    # submodule that both operations cover — TODO confirm.
    model.sampler = sampler
    model = model.to(args.device)

    if args.pretrained:
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError(
                f"Pretrained weights not found: {args.pretrained}")
        # load_state_dict copies into the already-placed parameters, so the
        # model stays on args.device without another .to() call.
        model.load_state_dict(torch.load(args.pretrained, map_location="cpu"))

    test(args, model, test_loader)
コード例 #3
0
def main():
    """Run iPCRNet evaluation on the ModelNet40 registration test split.

    The device named by ``args.device`` is used only when CUDA is present;
    otherwise the CPU is used. Weights are loaded from ``args.pretrained``
    when that option is set, and the model is then handed to ``test``.
    """
    args = options()

    loader = DataLoader(
        RegistrationData('PCRNet', ModelNet40Data(train=False)),
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    # Honor the requested device only when CUDA is actually available.
    args.device = torch.device(
        args.device if torch.cuda.is_available() else 'cpu')

    # PointNet feature extractor wrapped by the iterative PCRNet.
    net = iPCRNet(feature_model=PointNet(emb_dims=args.emb_dims))
    net = net.to(args.device)

    weights_path = args.pretrained
    if weights_path:
        assert os.path.isfile(weights_path)
        net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    net.to(args.device)

    test(args, net, loader)
コード例 #4
0
def main():
    """Train or evaluate iPCRNet, optionally with a SampleNet front end.

    Builds options from ``sputils.get_parser()``, seeds torch/numpy, sets up
    TensorBoard and text logging under ``checkpoints/<exp_name>``, and builds
    ModelNet registration loaders from ``args.datafolder``. The PCRNet and the
    SampleNet sampler are each frozen (``requires_grad_(False)`` + ``eval()``)
    unless ``args.train_pcrnet`` / ``args.train_samplenet`` is set. Weights are
    optionally restored from ``args.resume`` and/or ``args.pretrained``, then
    the function runs ``test`` when ``args.eval`` is set and ``train``
    otherwise.

    Raises:
        FileNotFoundError: if ``args.resume`` or ``args.pretrained`` is given
            but does not name an existing file.
        RuntimeError: if the pretrained state dict contains keys the model
            does not have, or lacks any model key outside the sampler.
    """
    parser = sputils.get_parser()
    args = options(parser)

    # Reproducibility: deterministic cuDNN kernels and fixed RNG seeds.
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir="checkpoints/" + args.exp_name)
    _init_(args)

    textio = IOStream("checkpoints/" + args.exp_name + "/run.log")
    textio.cprint(str(args))

    # Local ModelNet data (no download) read from args.datafolder, with
    # PointcloudToTensor followed by OnUnitCube applied to every sample.
    transforms = torchvision.transforms.Compose(
        [PointcloudToTensor(), OnUnitCube()])
    traindata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=True,
        download=False,
        folder=args.datafolder,
    )
    testdata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=False,
        download=False,
        folder=args.datafolder,
    )

    trainset = RegistrationData("PCRNet", traindata)
    testset = RegistrationData("PCRNet", testdata)

    train_loader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    test_loader = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    # Fall back to CPU when CUDA is unavailable, whatever args.device says.
    if not torch.cuda.is_available():
        args.device = "cpu"
    args.device = torch.device(args.device)

    # Create PointNet Model.
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)

    # Freeze the registration network unless it is trained as well. This runs
    # before the sampler is attached, so it does not affect the sampler.
    if args.train_pcrnet:
        model.requires_grad_(True)
        model.train()
    else:
        model.requires_grad_(False)
        model.eval()

    # Create sampler
    if args.sampler == "samplenet":
        sampler = SampleNet(
            args.num_out_points,
            args.bottleneck_size,
            args.projection_group_size,
            skip_projection=args.skip_projection,
            input_shape="bnc",
            output_shape="bnc",
        )
        if args.train_samplenet:
            sampler.requires_grad_(True)
            sampler.train()
        else:
            sampler.requires_grad_(False)
            sampler.eval()
    else:
        sampler = None

    model.sampler = sampler
    model = model.to(args.device)

    checkpoint = None
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                f"Resume checkpoint not found: {args.resume}")
        # map_location remaps every tensor in the checkpoint onto the chosen
        # device, so a GPU-written checkpoint loads on the CPU fallback.
        checkpoint = torch.load(args.resume, map_location=args.device)
        args.start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["model"])

    if args.pretrained:
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError(
                f"Pretrained weights not found: {args.pretrained}")
        # Non-strict load: a checkpoint may lack the sampler's parameters
        # (keys starting with "sampler"), which is tolerated below. Any other
        # mismatch is an error.
        missing_keys, unexpected_keys = model.load_state_dict(
            torch.load(args.pretrained, map_location="cpu"), strict=False)

        if missing_keys:
            print(f"Found missing keys in checkpoint: {missing_keys}")
        if unexpected_keys:
            # Unexpected keys exist in the checkpoint but not in the model.
            raise RuntimeError(
                f"Found unexpected keys in checkpoint: {unexpected_keys}")

        filtered_missing_keys = [
            x for x in missing_keys if not x.startswith("sampler")
        ]
        if filtered_missing_keys:
            raise RuntimeError(
                f"Found missing keys in checkpoint: {filtered_missing_keys}")
    # load_state_dict copies into the already-placed parameters, so the model
    # is still on args.device; no further .to() is needed.

    if args.eval:
        test(args, model, test_loader, textio)
    else:
        train(args, model, train_loader, test_loader, boardio, textio,
              checkpoint)