def test_collate_models(self):
    """
    Test collate_batched_meshes returns items of the correct shapes and types.
    Check that when collate_batched_meshes is passed to Dataloader, batches of
    the correct shapes and types are returned.
    """
    # Number of models sampled for the direct (non-DataLoader) collate check.
    num_samples = 6

    # Load ShapeNetCore without specifying any particular categories.
    shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
    # Randomly retrieve several objects from the dataset.
    rand_idxs = torch.randint(len(shapenet_dataset), (num_samples,))
    rand_objs = [shapenet_dataset[idx] for idx in rand_idxs]

    # Collate the randomly selected objects: one verts entry and one faces
    # entry is expected per sampled model.
    collated_meshes = collate_batched_meshes(rand_objs)
    verts, faces = collated_meshes["verts"], collated_meshes["faces"]
    self.assertEqual(len(verts), num_samples)
    self.assertEqual(len(faces), num_samples)
    # Same per-item shape checks as the R2N2 collate test: the trailing
    # dimension of both verts and faces is expected to be 3.
    self.assertEqual(verts[0].shape[-1], 3)
    self.assertEqual(faces[0].shape[-1], 3)

    # Check the collated mesh has the correct shape. The "mesh" entry is the
    # same key the DataLoader batch below is read from.
    mesh = collated_meshes["mesh"]
    self.assertEqual(mesh.verts_padded().shape[0], num_samples)
    self.assertEqual(mesh.verts_padded().shape[-1], 3)
    self.assertEqual(mesh.faces_padded().shape[0], num_samples)
    self.assertEqual(mesh.faces_padded().shape[-1], 3)

    # Pass the custom collate_fn function to DataLoader and check elements
    # in batch have the correct shape.
    batch_size = 12
    shapenet_core_loader = DataLoader(
        shapenet_dataset,
        batch_size=batch_size,
        collate_fn=collate_batched_meshes,
    )
    # Only the first batch is inspected.
    object_batch = next(iter(shapenet_core_loader))
    self.assertEqual(len(object_batch["synset_id"]), batch_size)
    self.assertEqual(len(object_batch["model_id"]), batch_size)
    self.assertEqual(len(object_batch["label"]), batch_size)
    self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
    self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
def test_collate_models(self):
    """
    Verify the output of collate_batched_meshes on R2N2 samples.

    The function is first called directly on a handful of randomly chosen
    models, then used as the ``collate_fn`` of a DataLoader; in both cases
    the sizes of the returned entries are checked.
    """
    num_samples = 6

    # Train split of the R2N2 dataset.
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

    # Draw random indices, fetch the corresponding models and collate them.
    sample_idxs = torch.randint(len(r2n2_dataset), (num_samples,))
    samples = [r2n2_dataset[idx] for idx in sample_idxs]
    collated = collate_batched_meshes(samples)

    # Per-model verts and faces: one entry per sample, trailing dim of 3.
    verts = collated["verts"]
    faces = collated["faces"]
    self.assertEqual(len(verts), num_samples)
    self.assertEqual(len(faces), num_samples)
    self.assertEqual(verts[0].shape[-1], 3)
    self.assertEqual(faces[0].shape[-1], 3)

    # The batched mesh: leading dim is the sample count, trailing dim is 3.
    batched_mesh = collated["mesh"]
    padded_verts_shape = batched_mesh.verts_padded().shape
    self.assertEqual(padded_verts_shape[0], num_samples)
    self.assertEqual(padded_verts_shape[-1], 3)
    padded_faces_shape = batched_mesh.faces_padded().shape
    self.assertEqual(padded_faces_shape[0], num_samples)
    self.assertEqual(padded_faces_shape[-1], 3)

    # Use the collate function through a DataLoader and inspect the first
    # batch it yields.
    batch_size = 12
    r2n2_loader = DataLoader(
        r2n2_dataset,
        batch_size=batch_size,
        collate_fn=collate_batched_meshes,
    )
    first_batch = next(iter(r2n2_loader))

    # Metadata entries hold one item per model in the batch.
    for key in ("synset_id", "model_id", "label"):
        self.assertEqual(len(first_batch[key]), batch_size)

    # Mesh and image entries are batched along the first dimension.
    loaded_mesh = first_batch["mesh"]
    self.assertEqual(loaded_mesh.verts_padded().shape[0], batch_size)
    self.assertEqual(loaded_mesh.faces_padded().shape[0], batch_size)
    self.assertEqual(first_batch["images"].shape[0], batch_size)