Example #1
    def test_pseudo_batching_mapping(self):
        transformer = SampleTransformer(self.dset)
        mapping = {"a": torch.rand(2, 3, 4), "b": torch.rand(3, 4, 5)}
        batched = transformer._change_pseudo_batch_dim(mapping, add=True)

        self.assertIsInstance(batched, type(mapping))
        self.assertEqual(len(mapping), len(batched))
        for key in batched.keys():
            with self.subTest(key=key):
                self.assertIn(key, mapping)

        for k, v in mapping.items():
            with self.subTest(k=k, v=v):
                self.assertEqual(v.ndim + 1, batched[k].ndim)

        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
        self.assertIsInstance(unbatched, type(mapping))
        self.assertEqual(len(mapping), len(unbatched))
        for key in unbatched.keys():
            with self.subTest(key=key):
                self.assertIn(key, mapping)

        for k, v in mapping.items():
            with self.subTest(k=k, v=v):
                self.assertEqual(v.ndim, unbatched[k].ndim)
Example #2
    def test_pseudo_batching_str(self):
        transformer = SampleTransformer(self.dset)
        input_str = "abc"

        for add in [True, False]:
            with self.subTest(add=add):
                self.assertEqual(
                    input_str,
                    transformer._change_pseudo_batch_dim(input_str, add=add))
Example #3
    def test_pseudo_batching_float(self):
        transformer = SampleTransformer(self.dset)
        input_float = 42.0

        for add in [True, False]:
            with self.subTest(add=add):
                self.assertEqual(
                    input_float,
                    transformer._change_pseudo_batch_dim(input_float, add=add))
Example #4
    def test_pseudo_batching_array(self):
        transformer = SampleTransformer(self.dset)

        input_array = np.random.rand(2, 3, 4)
        batched = transformer._change_pseudo_batch_dim(input_array, add=True)
        self.assertTupleEqual(batched.shape, (1, 2, 3, 4))

        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
        self.assertTupleEqual(unbatched.shape, input_array.shape)
        self.assertTrue(np.allclose(unbatched, input_array))
Example #5
    def test_pseudo_batching_tensor(self):
        transformer = SampleTransformer(self.dset)

        input_tensor = torch.rand(2, 3, 4)
        batched = transformer._change_pseudo_batch_dim(input_tensor, add=True)
        self.assertTupleEqual(batched.shape, (1, 2, 3, 4))

        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
        self.assertTupleEqual(unbatched.shape, input_tensor.shape)
        self.assertTrue(torch.allclose(unbatched, input_tensor))
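
The array and tensor examples above pin down the core behaviour: adding the pseudo batch dimension prepends a leading axis of size 1, and removing it strips that axis again. A minimal sketch of that core, assuming a standalone helper named _toggle_leading_dim (hypothetical, not part of the tested API):

import numpy as np
import torch


def _toggle_leading_dim(value, add):
    """Hypothetical helper: add or strip a leading axis of size 1."""
    if isinstance(value, torch.Tensor):
        # (2, 3, 4) -> (1, 2, 3, 4) when adding; squeeze(0) undoes it
        return value.unsqueeze(0) if add else value.squeeze(0)
    if isinstance(value, np.ndarray):
        return value[None] if add else np.squeeze(value, axis=0)
    # everything else is returned untouched
    return value

Note that squeeze(0) only drops the leading axis when it has size 1, which is exactly the round trip the two tests assert.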
Example #6
    def test_pseudo_batch_dim_custom_obj(self):
        # objects of an unknown type should be passed through unchanged
        class Foo(object):
            def __init__(self):
                self.bar = 5.0

        transformer = SampleTransformer(self.dset)
        foo = Foo()
        batched = transformer._change_pseudo_batch_dim(foo, add=True)
        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)

        self.assertEqual(foo, batched)
        self.assertEqual(foo, unbatched)
        self.assertEqual(batched, unbatched)
Example #7
    def test_pseudo_batch_dim_named_tuple(self):
        from collections import namedtuple
        Foo = namedtuple('Foo', 'bar')
        transformer = SampleTransformer(self.dset)

        foo = Foo(torch.tensor([2, 3, 4]))

        batched = transformer._change_pseudo_batch_dim(foo, add=True)
        self.assertIsInstance(batched, Foo)
        self.assertTupleEqual(batched.bar.shape,
                              tuple([1] + list(foo.bar.shape)))

        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
        self.assertIsInstance(unbatched, Foo)
        self.assertTrue(torch.allclose(unbatched.bar, foo.bar))
Example #8
    def test_pseudo_batch_dim_sequence(self):
        transformer = SampleTransformer(self.dset)

        input_sequence = [torch.tensor([2, 3, 4]), torch.tensor([3, 4, 5])]
        batched = transformer._change_pseudo_batch_dim(input_sequence,
                                                       add=True)
        self.assertEqual(len(batched), len(input_sequence))

        for idx in range(len(input_sequence)):
            with self.subTest(idx=idx):
                self.assertEqual(input_sequence[idx].ndim + 1,
                                 batched[idx].ndim)

        unbatched = transformer._change_pseudo_batch_dim(batched, add=False)
        self.assertEqual(len(unbatched), len(input_sequence))

        for idx in range(len(input_sequence)):
            with self.subTest(idx=idx):
                self.assertTrue(
                    torch.allclose(unbatched[idx], input_sequence[idx]))
Example #9
    def test_trafo_no_pseudo_batchdim(self):
        def trafo(**data):
            for k, v in data.items():
                self.assertTupleEqual(v.shape, (1, 28, 28))

            return data

        transformer = SampleTransformer(self.dset, trafo)
        for i in range(len(self.dset)):
            with self.subTest(idx=i):
                transformed = transformer[i]["data"]
                orig = self.dset[i]["data"]

                self.assertTrue(np.allclose(transformed, orig))
Example #10
    def test_no_trafo(self):
        transformer = SampleTransformer(self.dset, None)
        for i in range(len(self.dset)):
            with self.subTest(idx=i):
                self.assertTrue(
                    np.allclose(transformer[i]["data"], self.dset[i]["data"]))
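
Taken together, the examples describe a recursive dispatch over container types: mappings, sequences and namedtuples are handled element-wise and rebuilt with their original type, tensors and arrays gain or lose the leading axis, and strings, scalars and unknown objects pass through untouched. The following is an illustrative sketch of that contract, written against the behaviour asserted above rather than against the library's actual _change_pseudo_batch_dim implementation:

from collections.abc import Mapping, Sequence

import numpy as np
import torch


def change_pseudo_batch_dim(data, add):
    """Sketch of the contract the tests above exercise (not library code)."""
    if isinstance(data, torch.Tensor):
        return data.unsqueeze(0) if add else data.squeeze(0)
    if isinstance(data, np.ndarray):
        return data[None] if add else np.squeeze(data, axis=0)
    if isinstance(data, Mapping):
        # rebuild the mapping type with every value converted recursively
        return type(data)(
            {k: change_pseudo_batch_dim(v, add) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuples keep their concrete type (Example #7)
        return type(data)(*(change_pseudo_batch_dim(v, add) for v in data))
    if isinstance(data, Sequence) and not isinstance(data, str):
        return type(data)(change_pseudo_batch_dim(v, add) for v in data)
    # strings, numbers and custom objects are returned unchanged
    return data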