def test_datasets_url_is_used(self, urlretrieve):
    """Check that download_file builds its URL from ``mnist.datasets_url``.

    Temporarily points ``mnist.datasets_url`` at a dummy host. It then
    asserts that ``urlretrieve`` was called once, with that host plus the
    file name as the URL and a path in the system temporary directory as
    the target. ``urlretrieve`` is presumably a mock injected by a patch
    decorator outside this view; confirm.
    """
    original_url = mnist.datasets_url
    mnist.datasets_url = 'http://aaa.com/'
    # Restore the module-level setting even if download_file raises or the
    # assertion fails. Otherwise later tests would see the dummy URL.
    try:
        mnist.download_file('mnist_datasets_url.gz')
        fname = os.path.join(tempfile.gettempdir(), 'mnist_datasets_url.gz')
        urlretrieve.assert_called_once_with(
            'http://aaa.com/mnist_datasets_url.gz', fname)
    finally:
        mnist.datasets_url = original_url
def test_test_images_has_right_size(self):
    """Check the size of the freshly downloaded test-image archive.

    The expected size is the image-file header plus TEST_SAMPLES images.
    Each image has IMAGE_SIZE pixels of PIXEL_BYTES bytes each. It is
    compared against ``self._gzip_file_size``, which presumably reports
    the uncompressed size; confirm.
    """
    path = mnist.download_file('t10k-images-idx3-ubyte.gz', force=True)
    pixel_payload = TEST_SAMPLES * IMAGE_SIZE * PIXEL_BYTES
    self.assertEqual(HEADER_IMAGES_SIZE + pixel_payload,
                     self._gzip_file_size(path))
def test_temporary_dir_is_used(self, urlretrieve):
    """Check that download_file saves into the directory from ``mnist.temporary_dir()``.

    Temporarily replaces ``mnist.temporary_dir`` with a stub that returns a
    fixed directory. It then asserts two things:

    - ``urlretrieve`` was called with that directory plus the file name as
      the target.
    - The same path is returned.

    ``urlretrieve`` is presumably a mock injected by a patch decorator
    outside this view; confirm.
    """
    original_temp_dir = mnist.temporary_dir
    mnist.temporary_dir = lambda: '/another/tmp/dir/'
    # Put back the real function even when an assertion fails, so the stub
    # does not leak into other tests.
    try:
        fname = mnist.download_file('test')
        urlretrieve.assert_called_once_with(mnist.datasets_url + 'test',
                                            '/another/tmp/dir/test')
        self.assertEqual(fname, '/another/tmp/dir/test')
    finally:
        mnist.temporary_dir = original_temp_dir
def test_file_is_downloaded_when_exists_and_force_is_true(self, urlretrieve):
    """Check that ``force=True`` triggers exactly one download into the temp dir.

    The expected URL is ``mnist.datasets_url + 'test'``. The expected
    target is ``'test'`` inside ``tempfile.gettempdir()``.

    NOTE(review): nothing in this method creates ``'test'`` beforehand.
    The "when exists" part of the name relies on setup outside this view;
    confirm.
    """
    mnist.download_file('test', force=True)
    source_url = mnist.datasets_url + 'test'
    target_path = os.path.join(tempfile.gettempdir(), 'test')
    urlretrieve.assert_called_once_with(source_url, target_path)
def test_file_is_not_downloaded_when_force_is_false(self, urlretrieve):
    """Check that ``force=False`` leaves ``urlretrieve`` uncalled.

    Uses ``self.downloaded_fname``, which is presumably a file prepared in
    setUp; confirm.
    """
    mnist.download_file(self.downloaded_fname, force=False)
    was_fetched = urlretrieve.called
    self.assertFalse(was_fetched)
def test_file_is_downloaded_to_target_dir(self, urlretrieve):
    """Check that an explicit ``target_dir`` decides where the file is saved.

    Also checks that the resulting path is what download_file returns.
    """
    target_dir = '/tmp/mnist_test/'
    expected_path = target_dir + 'test'
    returned_path = mnist.download_file('test', target_dir=target_dir)
    urlretrieve.assert_called_once_with(mnist.datasets_url + 'test',
                                        expected_path)
    self.assertEqual(returned_path, expected_path)
def test_test_labels_has_right_size(self):
    """Check the size of the freshly downloaded test-label archive.

    The expected size is the label-file header plus TEST_SAMPLES labels of
    LABEL_BYTES bytes each. It is compared against
    ``self._gzip_file_size``, which presumably reports the uncompressed
    size; confirm.
    """
    path = mnist.download_file('t10k-labels-idx1-ubyte.gz', force=True)
    label_payload = TEST_SAMPLES * LABEL_BYTES
    self.assertEqual(HEADER_LABELS_SIZE + label_payload,
                     self._gzip_file_size(path))