Example #1
0
def main(args):
    """Train or evaluate a DCP point-cloud registration model on ModelNet40.

    Seeds every RNG for reproducibility, opens a TensorBoard writer and a text
    log under ``checkpoints/<exp_name>``, builds the train/test loaders and the
    ``DCP`` network, then runs ``test`` when ``args.eval`` is set and
    ``train`` otherwise.

    Args:
        args: Parsed command-line namespace. Read here: ``seed``, ``exp_name``,
            ``dataset``, ``num_points``, ``gaussian_noise``, ``unseen``,
            ``factor``, ``batch_size``, ``test_batch_size``, ``model``,
            ``eval`` and ``model_path``. The whole namespace is also forwarded
            to ``_init_``, ``DCP``, ``train`` and ``test``.

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.model`` is not
            supported (a subclass of ``Exception``, so existing handlers
            still catch it).
    """
    # Reproducibility: deterministic cuDNN kernels and seeded RNGs.
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    # Close the writer on every exit path (the early return when a checkpoint
    # is missing, and any exception), not only after a successful run.
    try:
        _init_(args)

        textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
        textio.cprint(str(args))

        if args.dataset == 'modelnet40':
            train_loader = DataLoader(ModelNet40(num_points=args.num_points,
                                                 partition='train',
                                                 gaussian_noise=args.gaussian_noise,
                                                 unseen=args.unseen,
                                                 factor=args.factor),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True)
            test_loader = DataLoader(ModelNet40(num_points=args.num_points,
                                                partition='test',
                                                gaussian_noise=args.gaussian_noise,
                                                unseen=args.unseen,
                                                factor=args.factor),
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     drop_last=False)
        else:
            raise NotImplementedError("not implemented")

        if args.model == 'dcp':
            net = DCP(args).cuda()
            if args.eval:
                # An empty (or missing) --model_path selects this experiment's
                # best checkpoint. Use truthiness, not `is ''`: string
                # identity is an interpreter implementation detail.
                if not args.model_path:
                    model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
                else:
                    model_path = args.model_path
                    print(model_path)
                if not os.path.exists(model_path):
                    print("can't find pretrained model")
                    return
                # strict=False: checkpoint keys that do not match the model
                # are ignored instead of raising.
                net.load_state_dict(torch.load(model_path), strict=False)
            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
                print("Let's use", torch.cuda.device_count(), "GPUs!")
        else:
            raise NotImplementedError('Not implemented')

        if args.eval:
            test(args, net, test_loader, boardio, textio)
        else:
            train(args, net, train_loader, test_loader, boardio, textio)

        print('FINISH')
    finally:
        boardio.close()
Example #2
0
class DCPEval:
    """Build a DCP network from CLI arguments and register single cloud pairs."""

    def __init__(self, args=None):
        """Construct the network and load its checkpoint.

        Args:
            args: Optional pre-parsed namespace. When omitted, the
                module-level ``parser`` parses ``sys.argv`` (the original
                behaviour), which makes the class usable from other code and
                from tests without touching the command line.
        """
        self.args = parser.parse_args() if args is None else args
        self.net = DCP(self.args)
        self.model_path = self.args.model_path
        # The network is left on the CPU here, so map the checkpoint to the
        # CPU as well; otherwise a checkpoint saved from GPU tensors cannot be
        # deserialized on a machine without CUDA.
        state_dict = torch.load(self.model_path, map_location='cpu')
        # strict=False: checkpoint keys that do not match the model are ignored.
        self.net.load_state_dict(state_dict, strict=False)

    def run(self, src, tgt):
        """Register ``src`` onto ``tgt`` and return ``test_single_registration``'s result."""
        return test_single_registration(self.net, src, tgt)
Example #3
0
def inference(args, name, max_points=20000, initial_voxel_size=0.001,
              voxel_step=0.0002):
    """Register the point-cloud pair ``name`` with DCP, then refine with ICP.

    The pair is loaded with ``prepare_dataset(VOXEL_SIZE, 'data/', name)``.
    Both clouds are voxel-downsampled, growing the voxel size step by step
    until neither has more than ``max_points`` points, and fed to the network
    through ``test_single_registration``. The predicted transform is scored
    with ``o3d.registration.evaluate_registration`` (distance threshold 0.1)
    and then handed to ``refine_registration`` together with the FPFH
    features returned by ``prepare_dataset``.

    Args:
        args: Namespace forwarded to ``DCP``; ``args.model_path`` names the
            checkpoint to load.
        name: Base name of the pair inside ``data/``.
        max_points: Point budget per cloud for the network input.
        initial_voxel_size: First voxel size tried.
        voxel_step: Amount added to the voxel size after each attempt that is
            still over budget.
    """
    net = DCP(args).cuda()
    net.load_state_dict(torch.load(args.model_path), strict=False)

    source, target, source_down, target_down, source_fpfh, target_fpfh = \
        prepare_dataset(VOXEL_SIZE, 'data/', name)

    # Every attempt downsamples the *full-resolution* clouds. Re-sampling the
    # previous attempt's (already averaged) output would compound the
    # averaging and progressively distort the geometry.
    voxel_size = initial_voxel_size
    num_points = np.asarray(source.points).shape[0]
    while True:
        source_vox = source.voxel_down_sample(voxel_size=voxel_size)
        target_vox = target.voxel_down_sample(voxel_size=voxel_size)
        down_points = np.asarray(source_vox.points).shape[0]
        target_points = np.asarray(target_vox.points).shape[0]
        print('[{}] Number of points: {} -> {}'.format(name, num_points,
                                                       down_points))
        # Both clouds go through the network, so both must fit the budget.
        if max(down_points, target_points) > max_points:
            voxel_size += voxel_step
        else:
            break

    source = np.asarray(source_vox.points)
    target = np.asarray(target_vox.points)

    print(source.shape)

    # CUDA work is asynchronous; synchronize on both sides of the call so the
    # wall-clock measurement actually covers the GPU computation.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    result = test_single_registration(net, source, target)
    torch.cuda.synchronize()
    total_time = time.perf_counter() - start_time
    # NOTE(review): assumes result[0] / result[1] are batched rotation /
    # translation tensors and the first batch element is wanted — confirm
    # against test_single_registration / get_transformation.
    transformation = get_transformation(result[0][0].cpu().detach().numpy(),
                                        result[1][0].cpu().detach().numpy())

    source = numpy_to_point_cloud(source.astype(np.float64))
    target = numpy_to_point_cloud(target.astype(np.float64))
    result_dcp = o3d.registration.evaluate_registration(
        source, target, 0.1, transformation=transformation)

    print('[{}] Result DCP: {} ({:.2f})'.format(name, result_dcp, total_time))

    result_icp = refine_registration(source_down, target_down, source_fpfh,
                                     target_fpfh, VOXEL_SIZE, result_dcp,
                                     'point')
    print('[{}] Result ICP: {} ({:.2f})'.format(name, result_icp, 0.))
Example #4
0
def main():
    """Parse CLI options, then train or evaluate a DCP registration model.

    Builds the argument parser, seeds every RNG for reproducibility, opens a
    TensorBoard writer and a text log under ``checkpoints/<exp_name>``, builds
    ModelNet40 train/test loaders and a ``DCP`` network (printing its
    parameter count), and then runs either ``test`` (``--eval``, timed with
    CUDA events) or ``train``.

    Raises:
        NotImplementedError: if the selected dataset or model is unsupported
            (a subclass of ``Exception``, so existing handlers still catch it).
        SystemExit: via argparse when an option value is invalid, including a
            boolean option given an unrecognised token.
    """

    def _str2bool(value):
        """Convert a command-line token to ``bool``.

        ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``,
        so ``--cycle False`` would silently *enable* the option.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name',
                        type=str,
                        default='exp',
                        metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model',
                        type=str,
                        default='dcp',
                        metavar='N',
                        choices=['dcp'],
                        help='Model to use, [dcp]')
    parser.add_argument('--emb_nn',
                        type=str,
                        default='pointnet',
                        metavar='N',
                        choices=['pointnet', 'dgcnn'],
                        help='Embedding nn to use, [pointnet, dgcnn]')
    parser.add_argument(
        '--pointer',
        type=str,
        default='transformer',
        metavar='N',
        choices=['identity', 'transformer'],
        help='Attention-based pointer generator to use, [identity, transformer]'
    )
    parser.add_argument('--head',
                        type=str,
                        default='svd',
                        metavar='N',
                        choices=[
                            'mlp',
                            'svd',
                        ],
                        help='Head to use, [mlp, svd]')
    parser.add_argument('--emb_dims',
                        type=int,
                        default=512,
                        metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--n_blocks',
                        type=int,
                        default=1,
                        metavar='N',
                        help='Num of blocks of encoder&decoder')
    parser.add_argument('--n_heads',
                        type=int,
                        default=4,
                        metavar='N',
                        help='Num of heads in multiheadedattention')
    parser.add_argument('--ff_dims',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of dimensions of fc in transformer')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.0,
                        metavar='N',
                        help='Dropout ratio in transformer')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=10,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs',
                        type=int,
                        default=250,
                        metavar='N',
                        help='number of episode to train ')
    parser.add_argument('--use_sgd',
                        action='store_true',
                        default=False,
                        help='Use SGD')
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1234,
                        metavar='S',
                        help='random seed (default: 1234)')
    parser.add_argument('--eval',
                        action='store_true',
                        default=False,
                        help='evaluate the model')
    parser.add_argument('--cycle',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Whether to use cycle consistency')
    parser.add_argument('--gaussian_noise',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Wheter to add gaussian noise')
    parser.add_argument('--unseen',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Wheter to test on unseen category')
    parser.add_argument('--num_points',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of points to use')
    parser.add_argument('--dataset',
                        type=str,
                        default='modelnet40',
                        choices=['modelnet40'],
                        metavar='N',
                        help='dataset to use')
    parser.add_argument('--factor',
                        type=float,
                        default=4,
                        metavar='N',
                        help='Divided factor for rotations')
    parser.add_argument('--model_path',
                        type=str,
                        default='pretrained/dcp_v1.t7',
                        metavar='N',
                        help='Pretrained model path')

    args = parser.parse_args()
    # Reproducibility: deterministic cuDNN kernels and seeded RNGs.
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    # Close the writer on every exit path, including the early return below.
    try:
        _init_(args)

        textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
        textio.cprint(str(args))

        if args.dataset == 'modelnet40':
            train_loader = DataLoader(ModelNet40(num_points=args.num_points,
                                                 partition='train',
                                                 gaussian_noise=args.gaussian_noise,
                                                 unseen=args.unseen,
                                                 factor=args.factor),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True)
            test_loader = DataLoader(ModelNet40(num_points=args.num_points,
                                                partition='test',
                                                gaussian_noise=args.gaussian_noise,
                                                unseen=args.unseen,
                                                factor=args.factor),
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     drop_last=False)
        else:
            raise NotImplementedError("not implemented")

        if args.model == 'dcp':
            net = DCP(args).cuda()
            print("Model Parameters")
            count_parameters(net)
            if args.eval:
                # An empty --model_path selects this experiment's best
                # checkpoint. Truthiness, not `is ''` (identity of strings is
                # an implementation detail).
                if not args.model_path:
                    model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
                else:
                    model_path = args.model_path
                    print(model_path)
                if not os.path.exists(model_path):
                    print("can't find pretrained model")
                    return
                # strict=False: non-matching checkpoint keys are ignored.
                net.load_state_dict(torch.load(model_path), strict=False)
            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
                print("Let's use", torch.cuda.device_count(), "GPUs!")
        else:
            raise NotImplementedError('Not implemented')

        # CUDA events time the GPU work; elapsed_time() reports milliseconds.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        if args.eval:
            start.record()
            test(args, net, test_loader, boardio, textio)
            end.record()
            # Events are recorded asynchronously; wait before reading them.
            torch.cuda.synchronize()
            print("Time to test: ", start.elapsed_time(end))
        else:
            train(args, net, train_loader, test_loader, boardio, textio)

        print('FINISH')
    finally:
        boardio.close()
Example #5
0
def main():
    """Parse CLI options, then train or evaluate a DCP registration model.

    Builds the argument parser, seeds every RNG for deterministic training,
    opens a TensorBoard writer and a text log under
    ``checkpoints/<exp_name>``, builds ModelNet40 train/test loaders (with
    optional partial / cut-plane data), constructs a ``DCP`` network,
    optionally loads evaluation or pretrained weights, and then runs
    ``test_bunny`` (``--eval --one_cloud True``), ``test`` (``--eval``) or
    ``train``.

    Raises:
        NotImplementedError: if the selected dataset or model is unsupported
            (a subclass of ``Exception``, so existing handlers still catch it).
        SystemExit: via argparse when an option value is invalid, including a
            boolean option given an unrecognised token.
    """

    def _str2bool(value):
        """Convert a command-line token to ``bool``.

        ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``,
        so e.g. ``--same_pointclouds False`` could never disable the option.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name',
                        type=str,
                        default='exp',
                        metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model',
                        type=str,
                        default='dcp',
                        metavar='N',
                        choices=['dcp'],
                        help='Model to use, [dcp]')
    parser.add_argument('--emb_nn',
                        type=str,
                        default='dgcnn',
                        metavar='N',
                        choices=['pointnet', 'dgcnn'],
                        help='Embedding nn to use, [pointnet, dgcnn]')
    parser.add_argument(
        '--pointer',
        type=str,
        default='transformer',
        metavar='N',
        choices=['identity', 'transformer'],
        help='Attention-based pointer generator to use, [identity, transformer]'
    )
    parser.add_argument('--head',
                        type=str,
                        default='svd',
                        metavar='N',
                        choices=[
                            'mlp',
                            'svd',
                        ],
                        help='Head to use, [mlp, svd]')
    parser.add_argument('--emb_dims',
                        type=int,
                        default=512,
                        metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--n_blocks',
                        type=int,
                        default=1,
                        metavar='N',
                        help='Num of blocks of encoder&decoder')
    parser.add_argument('--n_heads',
                        type=int,
                        default=16,
                        metavar='N',
                        help='Num of heads in multiheadedattention')
    parser.add_argument('--ff_dims',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of dimensions of fc in transformer')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.0,
                        metavar='N',
                        help='Dropout ratio in transformer')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=10,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs',
                        type=int,
                        default=250,
                        metavar='N',
                        help='number of episode to train ')
    parser.add_argument('--use_sgd',
                        action='store_true',
                        default=False,
                        help='Use SGD')
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1234,
                        metavar='S',
                        help='random seed (default: 1234)')
    parser.add_argument('--eval',
                        action='store_true',
                        default=False,
                        help='evaluate the model')
    parser.add_argument('--cycle',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Whether to use cycle consistency')
    parser.add_argument('--gaussian_noise',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Wheter to add gaussian noise')
    parser.add_argument('--unseen',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='Wheter to test on unseen category')
    parser.add_argument('--num_points',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of points to use')
    # NOTE: 'threedmatch' is accepted here but the loader section below only
    # implements 'modelnet40'; selecting it raises NotImplementedError.
    parser.add_argument('--dataset',
                        type=str,
                        default='modelnet40',
                        choices=['modelnet40', 'threedmatch'],
                        metavar='N',
                        help='dataset to use')
    parser.add_argument('--factor',
                        type=float,
                        default=4,
                        metavar='N',
                        help='Divided factor for rotations')
    parser.add_argument('--model_path',
                        type=str,
                        default='',
                        metavar='N',
                        help='Pretrained model path')
    parser.add_argument('--betas',
                        type=float,
                        default=(0.9, 0.999),
                        metavar='N',
                        nargs='+',
                        help='Betas for adam')
    parser.add_argument('--same_pointclouds',
                        type=_str2bool,
                        default=True,
                        metavar='N',
                        help='R*src + t should be exactly same as target')
    parser.add_argument('--debug',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='saves variables in folder variables_storage')
    parser.add_argument('--num_itr_test',
                        type=int,
                        default=1,
                        metavar='N',
                        help='Num of net() during testing')
    parser.add_argument(
        '--loss',
        type=str,
        default='cross_entropy_corr',
        metavar='N',
        choices=['cross_entropy_corr', 'mse_transf'],
        help='loss function: choose one of [mse_transf or cross_entropy_corr]')
    parser.add_argument('--cut_plane',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='generates partial data')
    parser.add_argument('--one_cloud',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='test for one unseen cloud')
    parser.add_argument(
        '--partial',
        type=float,
        default=0.0,
        metavar='N',
        help='partial = 0.1 ==> (num_points*partial) will be removed')
    parser.add_argument('--pretrained',
                        type=_str2bool,
                        default=False,
                        metavar='N',
                        help='load pretrained model')

    args = parser.parse_args()

    # for deterministic training
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    # Close the writer on every exit path, including the early returns below.
    try:
        _init_(args)

        textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
        textio.cprint(str(args))

        # dataloading: at most 32 workers, but never more than the CPU count
        # (os.cpu_count() may return None, hence the fallback to 1).
        num_workers = min(32, os.cpu_count() or 1)
        if args.dataset == 'modelnet40':
            train_dataset = ModelNet40(num_points=args.num_points,
                                       partition='train',
                                       gaussian_noise=args.gaussian_noise,
                                       unseen=args.unseen,
                                       factor=args.factor,
                                       same_pointclouds=args.same_pointclouds,
                                       partial=args.partial,
                                       cut_plane=args.cut_plane)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=num_workers)

            test_dataset = ModelNet40(num_points=args.num_points,
                                      partition='test',
                                      gaussian_noise=args.gaussian_noise,
                                      unseen=args.unseen,
                                      factor=args.factor,
                                      same_pointclouds=args.same_pointclouds,
                                      partial=args.partial,
                                      cut_plane=args.cut_plane)
            test_loader = DataLoader(test_dataset,
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=num_workers)
        else:
            raise NotImplementedError("not implemented")

        # model loading
        if args.model == 'dcp':
            net = DCP(args).cuda()
            if args.eval:
                # An empty --model_path selects this experiment's best
                # checkpoint. Truthiness, not `is ''` (identity of strings is
                # an implementation detail).
                if not args.model_path:
                    model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
                else:
                    model_path = args.model_path
                    print("Model loaded from ", model_path)
                if not os.path.exists(model_path):
                    print("can't find pretrained model")
                    return
                # strict=False: non-matching checkpoint keys are ignored.
                net.load_state_dict(torch.load(model_path), strict=False)
            if args.pretrained:
                # Previously an empty --model_path printed this hint and then
                # fell through to an unbound `model_path` (UnboundLocalError).
                # Abort instead, like the eval branch does.
                if not args.model_path:
                    print(
                        'Please specify path to pretrained weights \n For Ex: checkpoints/partial_global_512_identical/models/model.best.t7'
                    )
                    return
                model_path = args.model_path
                if not os.path.exists(model_path):
                    print("can't find pretrained model")
                    return
                print("Using pretrained weights stored at:\n{}".format(model_path))
                net.load_state_dict(torch.load(model_path), strict=False)

            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
                print("Let's use", torch.cuda.device_count(), "GPUs!")
        else:
            raise NotImplementedError('Not implemented')

        # training and evaluation
        if args.eval:
            if args.one_cloud:  # testing on a single point cloud
                print("one_cloud")
                test_bunny(args, net)
            else:
                test(args, net, test_loader, boardio, textio)
        else:
            train(args, net, train_loader, test_loader, boardio, textio)

        print('FINISH')
    finally:
        boardio.close()
Example #6
0
def main():
    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name',
                        type=str,
                        default='exp',
                        metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model',
                        type=str,
                        default='dcp',
                        metavar='N',
                        choices=['dcp', 'dcflow', 'unsupervised_dcflow'],
                        help='Model to use, [dcp, dcflow]')
    parser.add_argument('--emb_nn',
                        type=str,
                        default='pointnet',
                        metavar='N',
                        choices=['pointnet', 'dgcnn'],
                        help='Embedding nn to use, [pointnet, dgcnn]')
    parser.add_argument(
        '--pointer',
        type=str,
        default='transformer',
        metavar='N',
        choices=['identity', 'transformer'],
        help='Attention-based pointer generator to use, [identity, transformer]'
    )
    parser.add_argument('--head',
                        type=str,
                        default='svd',
                        metavar='N',
                        choices=['mlp', 'svd', 'pointnet'],
                        help='Head to use, [mlp, svd, pointnet]')
    parser.add_argument('--emb_dims',
                        type=int,
                        default=512,
                        metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--n_blocks',
                        type=int,
                        default=1,
                        metavar='N',
                        help='Num of blocks of encoder&decoder')
    parser.add_argument('--n_heads',
                        type=int,
                        default=4,
                        metavar='N',
                        help='Num of heads in multiheadedattention')
    parser.add_argument('--ff_dims',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of dimensions of fc in transformer')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.0,
                        metavar='N',
                        help='Dropout ratio in transformer')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=10,
                        metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs',
                        type=int,
                        default=250,
                        metavar='N',
                        help='number of episode to train ')
    parser.add_argument('--use_sgd',
                        action='store_true',
                        default=False,
                        help='Use SGD')
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1234,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval',
                        action='store_true',
                        default=False,
                        help='evaluate the model')
    parser.add_argument('--cycle',
                        type=bool,
                        default=False,
                        metavar='N',
                        help='Whether to use cycle consistency')
    parser.add_argument('--gaussian_noise',
                        type=bool,
                        default=False,
                        metavar='N',
                        help='Wheter to add gaussian noise')
    parser.add_argument('--unseen',
                        type=bool,
                        default=False,
                        metavar='N',
                        help='Wheter to test on unseen category')
    parser.add_argument('--num_points',
                        type=int,
                        default=1024,
                        metavar='N',
                        help='Num of points to use')
    parser.add_argument('--dataset',
                        type=str,
                        default='modelnet40',
                        choices=[
                            'modelnet40', 'kitti2015reg', 'kitti2015flow',
                            'flyingthings3dflow'
                        ],
                        metavar='N',
                        help='dataset to use')
    parser.add_argument('--factor',
                        type=float,
                        default=4,
                        metavar='N',
                        help='Divided factor for rotations')
    parser.add_argument('--model_path',
                        type=str,
                        default='',
                        metavar='N',
                        help='Pretrained model path')
    parser.add_argument('--k',
                        type=int,
                        default=20,
                        metavar='N',
                        help='Num of nearest neighbors to use')
    parser.add_argument('--display_scene_flow',
                        action='store_true',
                        default=False,
                        help='view the scene flow at testing')
    parser.add_argument(
        '--onlytrain',
        action='store_true',
        default=False,
        help='Only performs training when --eval is not passed')
    parser.add_argument(
        '--resume_training',
        action='store_true',
        default=False,
        help='Resume training from model_path or best checkpoint')
    parser.add_argument('--eval_full',
                        action='store_true',
                        default=False,
                        help='Eval on entire pointcloud')

    args = parser.parse_args()
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    textio.cprint(str(args))

    if args.dataset == 'modelnet40':
        train_loader = DataLoader(ModelNet40(
            num_points=args.num_points,
            partition='train',
            gaussian_noise=args.gaussian_noise,
            unseen=args.unseen,
            factor=args.factor),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        test_loader = DataLoader(ModelNet40(num_points=args.num_points,
                                            partition='test',
                                            gaussian_noise=args.gaussian_noise,
                                            unseen=args.unseen,
                                            factor=args.factor),
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 drop_last=False)
    elif args.dataset == 'kitti2015reg':
        train_loader = DataLoader(Kitti2015Reg(
            num_points=args.num_points,
            partition='train',
            gaussian_noise=args.gaussian_noise,
            unseen=args.unseen,
            factor=args.factor),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        test_loader = DataLoader(Kitti2015Reg(
            num_points=args.num_points,
            partition='test',
            gaussian_noise=args.gaussian_noise,
            unseen=args.unseen,
            factor=args.factor),
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 drop_last=False)
    elif args.dataset == 'kitti2015flow' or args.dataset == 'flyingthings3dflow':
        train_loader = DataLoader(SceneFlow(dataset_name=args.dataset,
                                            num_points=args.num_points,
                                            partition='train',
                                            gaussian_noise=args.gaussian_noise,
                                            unseen=args.unseen,
                                            factor=args.factor),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        if args.onlytrain:
            test_loader = None
        else:
            if args.eval_full:
                partition = 'full'
            else:
                partition = 'test'
            test_loader = DataLoader(SceneFlow(
                dataset_name=args.dataset,
                num_points=args.num_points,
                partition=partition,
                gaussian_noise=args.gaussian_noise,
                unseen=args.unseen,
                factor=args.factor),
                                     batch_size=args.test_batch_size,
                                     shuffle=False,
                                     drop_last=False)
    else:
        raise Exception("not implemented")

    if args.model == 'dcp' and args.dataset != 'kitti2015flow':
        net = DCP(args).cuda()
        if args.eval:
            if args.model_path is '':
                model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
            else:
                model_path = args.model_path
                print(model_path)
            if not os.path.exists(model_path):
                print("can't find pretrained model")
                return
            net.load_state_dict(torch.load(model_path), strict=False)
    elif args.model == 'dcflow':
        net = DCFlow(args).cuda()
        if args.eval or args.resume_training:
            if args.model_path is '':
                model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
            else:
                model_path = args.model_path
                print(model_path)
            if not os.path.exists(model_path):
                print("can't find pretrained/checkpoint model")
                return
            net.load_state_dict(torch.load(model_path), strict=False)
    elif args.model == 'unsupervised_dcflow':
        net = UnsupervisedDCFlow(args).cuda()
        if args.eval or args.resume_training:
            if args.model_path is '':
                model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
            else:
                model_path = args.model_path
                print(model_path)
            if not os.path.exists(model_path):
                print("can't find pretrained/checkpoint model")
                return
            net.load_state_dict(torch.load(model_path), strict=False)
    else:
        raise Exception('Not implemented')

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    if args.eval:
        if args.model == 'dcp':
            from registration import test
            test(args, net, test_loader, boardio, textio)
        elif args.model == 'dcflow':
            from scene_flow import test_flow
            test_flow(args, net, test_loader, boardio, textio)
        elif args.model == 'unsupervised_dcflow':
            from unsupervised_scene_flow import test_flow
            test_flow(args, net, test_loader, boardio, textio)
    else:
        if args.model == 'dcp':
            from registration import train
            train(args, net, train_loader, test_loader, boardio, textio)
        elif args.model == 'dcflow':
            from scene_flow import train_flow
            train_flow(args, net, train_loader, test_loader, boardio, textio)
        elif args.model == 'unsupervised_dcflow':
            from unsupervised_scene_flow import train_flow
            train_flow(args, net, train_loader, test_loader, boardio, textio)

    print('FINISH')
    boardio.close()
Exemple #7
0
 def __init__(self):
     self.args = parser.parse_args()
     self.net = DCP(self.args)
     self.model_path = self.args.model_path
     self.net.load_state_dict(torch.load(self.model_path), strict=False)