Example #1
0
 def test_load_data(self, mock_load, mock_download):
     """Constructing dSprites('root', False, None) loads three .npy files.

     The first three calls to the patched loader (``mock_load``; presumably
     a patch of ``np.load`` -- confirm against the decorators) must receive
     the image, latent-value and latent-class file paths under ``root``,
     in that order.
     """
     ds.dSprites('root', False, None)
     # Slice instead of indexing so that fewer than three calls yields an
     # assertion failure showing what was loaded, not an IndexError.
     loaded_paths = [recorded[0][0] for recorded in mock_load.call_args_list[:3]]
     self.assertEqual(
         loaded_paths,
         ['root/imgs.npy',
          'root/latents_values.npy',
          'root/latents_classes.npy'])
Example #2
0
    def test_get_item(self, mock_load, mock_load_data, mock_fromarray):
        """Indexing the dataset hands the sample, scaled by 255, to the
        patched image constructor with ``mode='L'`` and returns a 2-item result.
        """
        import torch
        sprites = ds.dSprites('root', False, None)
        mock_data = torch.rand(5, 1)
        # Presumably scales row 1 so it is distinguishable from the other
        # rows; confirm intent.
        mock_data[1] *= 3
        sprites.data = mock_data
        out = sprites[1]

        # Kept as assertTrue: the argument may be a tensor/array, so ``==`` is
        # elementwise. Each row has shape (1,), so bool() of the result is
        # well-defined.
        self.assertTrue(mock_fromarray.call_args[0][0] == mock_data[1] * 255)
        self.assertEqual(mock_fromarray.call_args[1], {'mode': 'L'})
        self.assertEqual(len(out), 2)
Example #3
0
    def test_image_by_latent(self, mock_load, mock_gi, mock_download):
        """get_img_by_latent looks up the flat index sum(bases * latent).

        ``__getitem__`` is patched (presumably via ``mock_gi``; confirm against
        the decorators). The test checks only the index passed to it.
        """
        import numpy as np
        mock_gi.return_value = (None, )

        # Successive ratios are 3, 6, 40, 32, 32. This is consistent with
        # row-major strides over the latent dimensions.
        bases = np.array([737280, 245760, 40960, 1024, 32, 1])
        latent = [1, 2, 3, 4, 5, 6]
        img_id = (bases * latent).sum()

        sprites = ds.dSprites('root', False, None)
        sprites.get_img_by_latent(latent)
        self.assertEqual(sprites.__getitem__.call_args[0][0], img_id)
Example #4
0
    def test_get_item_transform(self, mock_load, mock_load_data, mock_fromarray):
        """The user transform receives the object built by the patched image
        constructor (``mock_fromarray``) when an item is fetched.
        """
        import torch
        transform = Mock()
        sprites = ds.dSprites('root', False, transform)
        # Sentinel return value lets us check it is forwarded to the transform.
        mock_fromarray.return_value = 'test'
        mock_data = torch.rand(5, 1)
        mock_data[1] *= 3
        sprites.data = mock_data
        sprites[1]  # only the side effect on ``transform`` matters here

        self.assertEqual(transform.call_args[0][0], 'test')
Example #5
0
    def test_download(self, mock_load, mock_load_data, mock_mkdirs, mock_copyfileobj, mock_zip, mock_open, mock_urlopen):
        """Constructing dSprites(root, True, None) triggers the download steps.

        The steps checked are: create ``root`` with ``exist_ok=True``, open
        the GitHub raw URL of the .npz archive, open the destination file under
        ``root``, and stream the URL response into it (presumably via the
        patched ``shutil.copyfileobj``; confirm against the decorators).
        """
        mock_return = Mock()
        # The object produced by ``with urlopen(...) as resp`` is mock_return.
        mock_urlopen.return_value.__enter__.return_value = mock_return

        root = 'root'
        filename = "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
        url = ('https://github.com/deepmind/dsprites-dataset/blob/master/'
               + filename + '?raw=true')

        ds.dSprites(root, True, None)

        self.assertEqual(mock_mkdirs.call_args[0][0], root)
        self.assertEqual(mock_mkdirs.call_args[1], {'exist_ok': True})

        self.assertEqual(mock_urlopen.call_args[0][0], url)
        self.assertEqual(mock_open.call_args[0][0], root + '/' + filename)

        self.assertEqual(mock_copyfileobj.call_args[0][0], mock_return)
Example #6
0
    def test_len(self, mock_load, mock_load_data, mock_download):
        """len(dataset) equals the leading dimension of the loaded data.

        The data comes from the patched ``load_data``; presumably its return
        value becomes the dataset's data (confirm against dSprites).
        """
        import numpy as np
        mock_load_data.return_value = np.zeros((10, 5, 5))
        sprites = ds.dSprites('root', False, None)

        self.assertEqual(len(sprites), 10)
Example #7
0
 def test_latents_bases(self, mock_load, mock_load_data, mock_download):
     """The dataset exposes the expected per-latent index strides."""
     sprites = ds.dSprites('root', False, None)
     self.assertEqual(
         list(sprites.latents_bases),
         [737280, 245760, 40960, 1024, 32, 1])
Example #8
0
 def test_do_download(self, mock_load, mock_load_data, mock_download):
     """Passing True as the second argument calls ``download`` exactly once.

     ``sprites.download`` is presumably the mock installed by the decorator
     behind ``mock_download``; confirm against the decorators.
     """
     sprites = ds.dSprites('root', True, None)
     self.assertEqual(sprites.download.call_count, 1)