Exemplo n.º 1
0
def test_Combination():
    """Check that ShapeNet_Combination merges per-object data from several datasets.

    Builds voxel, SDF-point and point/normal datasets over the same ShapeNet
    'can' training split (80%), combines them, and checks the shape of every
    combined field both per object and after batching with a DataLoader.

    The cache directories under tests/datasets/cache are removed afterwards,
    even when an assertion fails, so stale cache files cannot leak into other
    tests that inspect the cache.
    """
    dataset_params = {
        'root': SHAPENET_ROOT,
        'categories': ['can'],
        'train': True,
        'split': .8,
    }
    try:
        voxels = shapenet.ShapeNet_Voxels(**dataset_params,
                                          cache_dir=CACHE_DIR,
                                          resolutions=[32])
        sdf_points = shapenet.ShapeNet_SDF_Points(**dataset_params,
                                                  cache_dir=CACHE_DIR,
                                                  smoothing_iterations=3,
                                                  num_points=500,
                                                  occ=False,
                                                  sample_box=True)
        points = shapenet.ShapeNet_Points(**dataset_params,
                                          cache_dir=CACHE_DIR,
                                          resolution=100,
                                          smoothing_iterations=3,
                                          num_points=500,
                                          surface=False,
                                          normals=True)

        dataset = shapenet.ShapeNet_Combination([voxels, sdf_points, points])

        # Compare full shape tuples. A set comparison ignores dimension order
        # and repeated sizes: set([32, 32, 32]) == {32} accepts any all-32 shape.
        for obj in dataset:
            data = obj['data']
            assert tuple(data['sdf_points'].shape) == (500, 3)
            assert tuple(data['sdf_distances'].shape) == (500,)
            assert tuple(data['32'].shape) == (32, 32, 32)
            assert tuple(data['points'].shape) == (500, 3)
            assert tuple(data['normals'].shape) == (500, 3)

        # drop_last keeps every batch at exactly 2 items even when the split
        # has an odd number of objects; each object was already checked above.
        train_loader = DataLoader(dataset,
                                  batch_size=2,
                                  shuffle=True,
                                  num_workers=8,
                                  drop_last=True)
        for batch in train_loader:
            data = batch['data']
            assert tuple(data['sdf_points'].shape) == (2, 500, 3)
            assert tuple(data['sdf_distances'].shape) == (2, 500)
            assert tuple(data['32'].shape) == (2, 32, 32, 32)
            assert tuple(data['points'].shape) == (2, 500, 3)
            assert tuple(data['normals'].shape) == (2, 500, 3)
    finally:
        # Always clean up, even on failure. ignore_errors avoids masking the
        # original exception when a cache was never created.
        for cache_name in ('sdf_points', 'points', 'voxels', 'surface_meshes'):
            shutil.rmtree(f'tests/datasets/cache/{cache_name}',
                          ignore_errors=True)
Exemplo n.º 2
0
def test_Voxels():
    """Check ShapeNet_Voxels split sizes, on-disk caching and voxel grid shape.

    Uses the 'can' category with a 0.7 split. The test expects 75 training
    objects, 75 cached .npz files under the dataset's cache_dir, a (32, 32, 32)
    grid per object at resolution 32, and 33 test objects.

    The voxel cache directory is removed afterwards, even when an assertion
    fails, so a later run's cache-file count is not polluted by stale files.
    """
    try:
        voxels = shapenet.ShapeNet_Voxels(root=SHAPENET_ROOT,
                                          cache_dir=CACHE_DIR,
                                          categories=['can'],
                                          train=True,
                                          split=.7,
                                          resolutions=[32])
        assert len(voxels) == 75
        assert voxels.cache_dir.exists()
        assert len(list(voxels.cache_dir.rglob('*.npz'))) == 75
        for obj in voxels:
            # Exact tuple comparison: set() would accept any all-32 shape,
            # e.g. (32,) or (32, 32).
            assert tuple(obj['data']['32'].shape) == (32, 32, 32)

        voxels = shapenet.ShapeNet_Voxels(root=SHAPENET_ROOT,
                                          cache_dir=CACHE_DIR,
                                          categories=['can'],
                                          train=False,
                                          split=.7,
                                          resolutions=[32])
        assert len(voxels) == 33
    finally:
        # ignore_errors avoids masking the original exception when the cache
        # was never created.
        shutil.rmtree('tests/datasets/cache/voxels', ignore_errors=True)
Exemplo n.º 3
0
parser.add_argument('--no-vis',
                    action='store_true',
                    help='Disable visualization of each model.')
args = parser.parse_args()

# Data: the held-out split (train=False, split=.7) of the requested
# categories, pairing each object's triangle surface mesh with its voxels at
# resolution 32.
mesh_set = shapenet.ShapeNet_Surface_Meshes(root=args.shapenet_root,
                                            cache_dir=args.cache_dir,
                                            categories=args.categories,
                                            resolution=32,
                                            train=False,
                                            split=.7,
                                            mode='Tri')
voxel_set = shapenet.ShapeNet_Voxels(root=args.shapenet_root,
                                     cache_dir=args.cache_dir,
                                     categories=args.categories,
                                     train=False,
                                     resolutions=[32],
                                     split=.7)
valid_set = shapenet.ShapeNet_Combination([mesh_set, voxel_set])

# Model: mesh encoder feeding a voxel decoder. The shared argument 30 is
# presumably the latent code size (TODO confirm against the model code).
encoder = MeshEncoder(30).to(args.device)
decoder = VoxelDecoder(30).to(args.device)

# Load encoder/decoder weights from log/<expid>/AutoEncoder/best.ckpt.
logdir = f'log/{args.expid}/AutoEncoder'
# map_location remaps stored tensors onto args.device, so a checkpoint saved
# on a GPU still loads on a CPU-only machine or a different GPU. Note that
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(os.path.join(logdir, 'best.ckpt'),
                        map_location=args.device)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])

# Accumulators, presumably updated by the evaluation loop that follows.
loss_epoch = 0.
num_batches = 0
num_items = 0