def test_collate_models(self):
        """
        Test collate_batched_meshes on models loaded from ShapeNetCore.

        Two usages are exercised:

        1. Calling collate_batched_meshes directly on a list of randomly
           selected models, checking the per-model "verts"/"faces" entries
           and the batched "mesh" entry.
        2. Passing collate_batched_meshes as the collate_fn of a DataLoader,
           checking that every entry of the first batch holds batch_size
           items.
        """
        num_models = 6
        batch_size = 12

        # Load ShapeNetCore without specifying any particular categories.
        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
        # Randomly retrieve several objects from the dataset.
        rand_idxs = torch.randint(len(shapenet_dataset), (num_models,))
        rand_objs = [shapenet_dataset[idx] for idx in rand_idxs]

        # Collate the randomly selected objects.
        collated_meshes = collate_batched_meshes(rand_objs)

        # One verts/faces entry per selected model, each with a trailing
        # dimension of 3.
        verts, faces = collated_meshes["verts"], collated_meshes["faces"]
        self.assertEqual(len(verts), num_models)
        self.assertEqual(len(faces), num_models)
        self.assertEqual(verts[0].shape[-1], 3)
        self.assertEqual(faces[0].shape[-1], 3)

        # The collated mesh batches all selected models together.
        mesh = collated_meshes["mesh"]
        self.assertEqual(mesh.verts_padded().shape[0], num_models)
        self.assertEqual(mesh.verts_padded().shape[-1], 3)
        self.assertEqual(mesh.faces_padded().shape[0], num_models)
        self.assertEqual(mesh.faces_padded().shape[-1], 3)

        # Pass the custom collate_fn function to DataLoader and check elements
        # in batch have the correct shape.
        shapenet_core_loader = DataLoader(
            shapenet_dataset,
            batch_size=batch_size,
            collate_fn=collate_batched_meshes,
        )
        it = iter(shapenet_core_loader)
        object_batch = next(it)
        self.assertEqual(len(object_batch["synset_id"]), batch_size)
        self.assertEqual(len(object_batch["model_id"]), batch_size)
        self.assertEqual(len(object_batch["label"]), batch_size)
        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
Beispiel #2
0
    def test_collate_models(self):
        """
        Exercise collate_batched_meshes with samples from the R2N2 train split.

        First the function is called directly on a handful of randomly drawn
        samples and the sizes of the resulting "verts", "faces" and "mesh"
        entries are checked. Then it is used as the collate_fn of a
        DataLoader and the first batch is checked to hold batch_size items
        in each of its entries.
        """
        num_samples = 6
        batch_size = 12

        # Train split of the dataset.
        dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # Draw random samples and collate them directly.
        sample_idxs = torch.randint(len(dataset), (num_samples,))
        samples = [dataset[i] for i in sample_idxs]
        collated = collate_batched_meshes(samples)

        # Per-sample verts and faces: one entry each, trailing dimension 3.
        verts_list = collated["verts"]
        faces_list = collated["faces"]
        self.assertEqual(len(verts_list), num_samples)
        self.assertEqual(len(faces_list), num_samples)
        self.assertEqual(verts_list[0].shape[-1], 3)
        self.assertEqual(faces_list[0].shape[-1], 3)

        # Batched mesh: leading dimension is the number of samples.
        batched_mesh = collated["mesh"]
        padded_verts = batched_mesh.verts_padded()
        self.assertEqual(padded_verts.shape[0], num_samples)
        self.assertEqual(padded_verts.shape[-1], 3)
        padded_faces = batched_mesh.faces_padded()
        self.assertEqual(padded_faces.shape[0], num_samples)
        self.assertEqual(padded_faces.shape[-1], 3)

        # Use the function as a DataLoader collate_fn and inspect the first
        # batch it produces.
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_batched_meshes,
        )
        first_batch = next(iter(loader))
        for key in ("synset_id", "model_id", "label"):
            self.assertEqual(len(first_batch[key]), batch_size)
        loaded_mesh = first_batch["mesh"]
        self.assertEqual(loaded_mesh.verts_padded().shape[0], batch_size)
        self.assertEqual(loaded_mesh.faces_padded().shape[0], batch_size)
        self.assertEqual(first_batch["images"].shape[0], batch_size)