def test_load_data(self, mock_load, mock_download):
    """Check that construction loads the expected array files from ``root``.

    ``mock_load`` stands in for the array loader. Its first three calls must
    target the image, latent-value and latent-class files, in that order.
    """
    # The instance itself is not inspected; only the loader calls made while
    # constructing it are checked.
    ds.dSprites('root', False, None)
    loaded_paths = [
        recorded[0][0] for recorded in mock_load.call_args_list[:3]
    ]
    self.assertEqual(
        loaded_paths,
        [
            'root/imgs.npy',
            'root/latents_values.npy',
            'root/latents_classes.npy',
        ],
    )
def test_get_item(self, mock_load, mock_load_data, mock_fromarray):
    """Check that indexing scales the stored sample by 255 and builds an image.

    ``sprites.data`` is replaced by a small random tensor. ``mock_fromarray``
    captures the array handed to the image constructor, which must be called
    with ``mode='L'``. The indexed result must have two elements.
    """
    import torch

    sprites = ds.dSprites('root', False, None)
    mock_data = torch.rand(5, 1)
    # Rescale the probed row (presumably so it differs from the others and
    # can exceed 1.0 -- TODO confirm intent).
    mock_data[1] *= 3
    sprites.data = mock_data
    out = sprites[1]

    args, kwargs = mock_fromarray.call_args
    # Each row holds a single element, so the element-wise comparison yields a
    # one-element result whose truth value is well defined.
    self.assertTrue(args[0] == mock_data[1] * 255)
    self.assertEqual(kwargs, {'mode': 'L'})
    self.assertEqual(len(out), 2)
def test_image_by_latent(self, mock_load, mock_gi, mock_download):
    """Check that ``get_img_by_latent`` maps a latent vector to a flat index.

    The expected index is the weighted sum of the latent vector with the
    per-factor multipliers. ``mock_gi`` replaces item access, so the index
    passed to it can be inspected.
    """
    import numpy as np

    # get_img_by_latent presumably indexes the item-access result, so return
    # a one-element tuple -- TODO confirm.
    mock_gi.return_value = (None,)
    # Per-factor multipliers, outermost factor first.
    bases = np.array([737280, 245760, 40960, 1024, 32, 1])
    latent = [1, 2, 3, 4, 5, 6]
    expected_index = int(np.dot(bases, latent))

    sprites = ds.dSprites('root', False, None)
    sprites.get_img_by_latent(latent)
    self.assertEqual(sprites.__getitem__.call_args[0][0], expected_index)
def test_get_item_transform(self, mock_load, mock_load_data, mock_fromarray):
    """Check that indexing passes the constructed image through ``transform``.

    ``mock_fromarray`` returns a sentinel string. The transform given to the
    constructor must receive that sentinel as its first positional argument.
    """
    import torch

    transform = Mock()
    sprites = ds.dSprites('root', False, transform)
    mock_fromarray.return_value = 'test'  # sentinel standing in for the image
    mock_data = torch.rand(5, 1)
    mock_data[1] *= 3
    sprites.data = mock_data
    _ = sprites[1]  # result unused; only the transform call is inspected
    self.assertEqual(transform.call_args[0][0], 'test')
def test_download(self, mock_load, mock_load_data, mock_mkdirs,
                  mock_copyfileobj, mock_zip, mock_open, mock_urlopen):
    """Check that constructing with download enabled fetches the archive.

    Asserts that the target directory is created with ``exist_ok=True``,
    the archive URL is opened, the destination file under ``root`` is
    opened, and the object returned by the URL context manager is the
    source passed to ``copyfileobj``.
    """
    root = 'root'
    filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    url = ('https://github.com/deepmind/dsprites-dataset/blob/master/'
           f'{filename}?raw=true')

    response = Mock()
    # urlopen is used as a context manager; the entered value is the response.
    mock_urlopen.return_value.__enter__.return_value = response

    ds.dSprites(root, True, None)

    self.assertEqual(mock_mkdirs.call_args[0][0], root)
    self.assertEqual(mock_mkdirs.call_args[1], {'exist_ok': True})
    self.assertEqual(mock_urlopen.call_args[0][0], url)
    self.assertEqual(mock_open.call_args[0][0], f'{root}/{filename}')
    self.assertEqual(mock_copyfileobj.call_args[0][0], response)
def test_len(self, mock_load, mock_load_data, mock_download):
    """Check that ``len()`` matches the leading dimension of the loaded data.

    The mocked loader returns an array with 10 entries along axis 0.
    """
    import numpy as np

    mock_load_data.return_value = np.zeros((10, 5, 5))
    sprites = ds.dSprites('root', False, None)
    self.assertEqual(len(sprites), 10)
def test_latents_bases(self, mock_load, mock_load_data, mock_download):
    """Check the per-factor index multipliers exposed as ``latents_bases``."""
    expected_bases = [737280, 245760, 40960, 1024, 32, 1]
    sprites = ds.dSprites('root', False, None)
    self.assertEqual(list(sprites.latents_bases), expected_bases)
def test_do_download(self, mock_load, mock_load_data, mock_download):
    """Check that passing ``True`` as the second argument calls ``download``.

    ``download`` is expected to be invoked exactly once during construction.
    """
    sprites = ds.dSprites('root', True, None)
    self.assertEqual(sprites.download.call_count, 1)