def test_pseudo_batching_mapping(self):
    """Check pseudo batch dim handling for a mapping of tensors.

    Adding the pseudo batch dimension must keep the mapping type, its
    length and its keys while raising each value's ``ndim`` by one;
    removing it again must bring every value back to its original
    ``ndim``.
    """
    transformer = SampleTransformer(self.dset)
    mapping = {"a": torch.rand(2, 3, 4), "b": torch.rand(3, 4, 5)}

    def check_against_mapping(result, extra_dims):
        # Container type, size and keys must match the input mapping.
        self.assertIsInstance(result, type(mapping))
        self.assertEqual(len(mapping), len(result))
        for key in result:
            with self.subTest(key=key):
                self.assertIn(key, mapping)
        # Each value must differ from the original by ``extra_dims`` dims.
        for k, v in mapping.items():
            with self.subTest(k=k, v=v):
                self.assertEqual(v.ndim + extra_dims, result[k].ndim)

    batched = transformer._change_pseudo_batch_dim(mapping, add=True)
    check_against_mapping(batched, 1)

    unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
    check_against_mapping(unbatched, 0)
def test_pseudo_batching_str(self):
    """Check that a string is returned unchanged for both ``add`` modes."""
    transformer = SampleTransformer(self.dset)
    text = "abc"

    for add in (True, False):
        with self.subTest(add=add):
            result = transformer._change_pseudo_batch_dim(text, add=add)
            self.assertEqual(text, result)
def test_pseudo_batching_float(self):
    """Check that a float is returned unchanged for both ``add`` modes."""
    transformer = SampleTransformer(self.dset)
    value = 42.0

    for add in (True, False):
        with self.subTest(add=add):
            result = transformer._change_pseudo_batch_dim(value, add=add)
            self.assertEqual(value, result)
def test_pseudo_batching_array(self):
    """Check the add/remove round trip on a numpy array.

    Adding must prepend a dimension of size one; removing it must
    restore both the original shape and the original values.
    """
    transformer = SampleTransformer(self.dset)
    original = np.random.rand(2, 3, 4)

    with_batch_dim = transformer._change_pseudo_batch_dim(original, add=True)
    self.assertTupleEqual(with_batch_dim.shape, (1, 2, 3, 4))

    restored = transformer._change_pseudo_batch_dim(
        with_batch_dim, add=False)
    self.assertTupleEqual(restored.shape, original.shape)
    values_match = np.allclose(restored, original)
    self.assertTrue(values_match)
def test_pseudo_batching_tensor(self):
    """Check the add/remove round trip on a torch tensor.

    Adding must prepend a dimension of size one; removing it must
    restore both the original shape and the original values.
    """
    transformer = SampleTransformer(self.dset)
    original = torch.rand(2, 3, 4)

    with_batch_dim = transformer._change_pseudo_batch_dim(original, add=True)
    self.assertTupleEqual(with_batch_dim.shape, (1, 2, 3, 4))

    restored = transformer._change_pseudo_batch_dim(
        with_batch_dim, add=False)
    self.assertTupleEqual(restored.shape, original.shape)
    values_match = torch.allclose(restored, original)
    self.assertTrue(values_match)
def test_pseudo_batch_dim_custom_obj(self):
    """Check that an arbitrary custom object passes through untouched.

    ``Foo`` defines no ``__eq__``, so the equality assertions below only
    hold if the very same object is handed back for both ``add`` modes.
    """

    class Foo:
        # Plain class attribute. The previous ``self.bar = 5.0`` resolved
        # ``self`` to the enclosing test case and set the attribute on
        # the TestCase instance instead of on ``Foo``.
        bar = 5.0

    transformer = SampleTransformer(self.dset)
    foo = Foo()

    batched = transformer._change_pseudo_batch_dim(foo, add=True)
    unbatched = transformer._change_pseudo_batch_dim(batched, add=False)

    self.assertEqual(foo, batched)
    self.assertEqual(foo, unbatched)
    self.assertEqual(batched, unbatched)
def test_pseudo_batch_dim_named_tuple(self):
    """Check pseudo batch dim handling for a namedtuple holding a tensor.

    The namedtuple type must survive both directions; the contained
    tensor gains a leading dimension of size one when adding and matches
    the original values again after removing.
    """
    from collections import namedtuple

    Foo = namedtuple('Foo', 'bar')
    transformer = SampleTransformer(self.dset)
    foo = Foo(torch.tensor([2, 3, 4]))

    batched = transformer._change_pseudo_batch_dim(foo, add=True)
    self.assertIsInstance(batched, Foo)
    expected_shape = (1,) + tuple(foo.bar.shape)
    self.assertTupleEqual(batched.bar.shape, expected_shape)

    unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
    self.assertIsInstance(unbatched, Foo)
    values_match = torch.allclose(unbatched.bar, foo.bar)
    self.assertTrue(values_match)
def test_pseudo_batch_dim_sequence(self):
    """Check pseudo batch dim handling for a list of tensors.

    The sequence length must be preserved in both directions; every
    element gains one dimension when adding and matches the original
    values again after removing.
    """
    transformer = SampleTransformer(self.dset)
    input_sequence = [torch.tensor([2, 3, 4]), torch.tensor([3, 4, 5])]

    batched = transformer._change_pseudo_batch_dim(input_sequence, add=True)
    self.assertEqual(len(batched), len(input_sequence))
    for idx, original in enumerate(input_sequence):
        with self.subTest(idx=idx):
            self.assertEqual(original.ndim + 1, batched[idx].ndim)

    unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
    self.assertEqual(len(unbatched), len(input_sequence))
    for idx, original in enumerate(input_sequence):
        with self.subTest(idx=idx):
            values_match = torch.allclose(unbatched[idx], original)
            self.assertTrue(values_match)
def test_trafo_no_pseudo_batchdim(self):
    """Check transformed samples against the raw dataset samples.

    The transform asserts that every value it receives has shape
    ``(1, 28, 28)`` and returns its input unchanged, so the ``"data"``
    entry of each transformed sample must be numerically close to the
    dataset's own.
    """

    def trafo(**data):
        # Shape check on everything handed to the transform.
        for value in data.values():
            self.assertTupleEqual(value.shape, (1, 28, 28))
        return data

    transformer = SampleTransformer(self.dset, trafo)

    num_samples = len(self.dset)
    for i in range(num_samples):
        with self.subTest(idx=i):
            transformed = transformer[i]["data"]
            orig = self.dset[i]["data"]
            values_match = np.allclose(transformed, orig)
            self.assertTrue(values_match)
def test_no_trafo(self):
    """Check that samples pass through unchanged when no transform is set.

    Every index of the dataset is compared, not just the first one: the
    ``"data"`` entry returned by the transformer must be numerically
    close to the dataset's own entry at the same index.
    """
    transformer = SampleTransformer(self.dset, None)

    for i in range(len(self.dset)):
        with self.subTest(idx=i):
            # Index with ``i``; the earlier version always used index 0,
            # so only the first sample was ever verified.
            transformed = transformer[i]["data"]
            orig = self.dset[i]["data"]
            self.assertTrue(np.allclose(transformed, orig))