Ejemplo n.º 1
0
    def test_celeba_hq_use_numpy_loader(self, mock_gi, mock_len, mock_loader):
        """CelebA_HQ constructed with ``True`` as its second argument uses the
        patched loader and restricts extensions to ``['npy']``.

        The ``mock_*`` parameters are presumably injected by ``mock.patch``
        decorators outside this view (``mock_gi`` / ``mock_len`` look like
        patched ``__getitem__`` / ``__len__``) -- TODO confirm.
        """
        mock_len.return_value = 10
        mock_gi.return_value = Mock()

        celebahq = ds.CelebA_HQ('root', True, Mock())
        # Exercise indexing; the returned item itself is not checked here.
        _ = celebahq[6]

        # assertEqual performs the same `==` check as before but reports both
        # operands on failure instead of "False is not true".
        self.assertEqual(celebahq.loader, mock_loader)
        self.assertEqual(celebahq.extensions, ['npy'])
Ejemplo n.º 2
0
    def test_celeba_hq_use_default_loader(self, mock_loader, mock_gi, mock_len):
        """CelebA_HQ constructed with ``False`` as its second argument uses the
        patched loader and torchvision's ``IMG_EXTENSIONS``.

        The ``mock_*`` parameters are presumably injected by ``mock.patch``
        decorators outside this view; their order here must match that
        decorator stack -- TODO confirm.
        """
        import torchvision
        mock_len.return_value = 10

        celebahq = ds.CelebA_HQ('root', False, Mock())
        # Exercise indexing; the returned item itself is not checked here.
        _ = celebahq[6]

        # assertEqual performs the same `==` check but gives a useful diff on failure.
        self.assertEqual(celebahq.loader, mock_loader)
        self.assertEqual(celebahq.extensions,
                         torchvision.datasets.folder.IMG_EXTENSIONS)
Ejemplo n.º 3
0
    def test_celeba_hq_get_item(self, mock_gi, mock_len):
        """Indexing CelebA_HQ forwards the index to the patched ``mock_gi`` and
        returns its result unchanged.

        ``mock_gi`` / ``mock_len`` are presumably patched ``__getitem__`` /
        ``__len__`` supplied by decorators outside this view -- TODO confirm.
        """
        return_val = 'test'
        mock_len.return_value = 10
        mock_gi.return_value = return_val

        celebahq = ds.CelebA_HQ('root', False, Mock())
        out = celebahq[6]

        # call_args[0] is the positional-args tuple of the most recent call.
        self.assertEqual(mock_gi.call_args[0][0], 6)
        self.assertEqual(out, return_val)
Ejemplo n.º 4
0
    def test_celeba_hq_numpy_loader(self, mock_fromarray, mock_load, mock_len):
        """In numpy mode (second argument ``True``), indexing loads the sample
        path at that index and passes the loaded data to ``mock_fromarray``.

        ``mock_load`` / ``mock_fromarray`` are presumably patched
        ``numpy.load`` / ``PIL.Image.fromarray`` supplied by decorators outside
        this view -- TODO confirm.
        """
        samples = ['a.npy', 'b.npy', 'c.npy']
        mock_len.return_value = 3
        mock_output = Mock()
        # transpose() returns the same mock, so whatever transform the loader
        # applies, the object handed to fromarray is still mock_output.
        mock_output.transpose.return_value = mock_output
        mock_load.return_value = (mock_output, )

        celebahq = ds.CelebA_HQ('root', True, Mock())
        celebahq.samples = samples
        _ = celebahq[1]

        # assertEqual performs the same `==` check but reports both operands on failure.
        self.assertEqual(mock_load.call_args[0][0], samples[1])
        self.assertEqual(mock_fromarray.call_args[0][0], mock_output)