def test_datasets_url_is_used(self, urlretrieve):
     """Check that ``download_file`` builds its URL from ``mnist.datasets_url``.

     Temporarily overrides the module-level ``datasets_url``. Then verifies
     that the mocked ``urlretrieve`` is called exactly once with the
     overridden base URL plus the file name, and with a destination inside
     ``tempfile.gettempdir()``.
     """
     original_url = mnist.datasets_url
     mnist.datasets_url = 'http://aaa.com/'
     # Restore in ``finally`` so a failing assertion cannot leak the fake URL
     # into other tests that share the ``mnist`` module state.
     try:
         mnist.download_file('mnist_datasets_url.gz')
         fname = os.path.join(tempfile.gettempdir(), 'mnist_datasets_url.gz')
         urlretrieve.assert_called_once_with(
             'http://aaa.com/mnist_datasets_url.gz', fname)
     finally:
         mnist.datasets_url = original_url
 def test_test_images_has_right_size(self):
     """Verify the size of the downloaded MNIST test-image archive.

     Forces a fresh download of ``t10k-images-idx3-ubyte.gz``. Compares the
     size reported by ``self._gzip_file_size`` with the images header size
     plus ``TEST_SAMPLES`` images of ``IMAGE_SIZE`` pixels, each pixel taking
     ``PIXEL_BYTES`` bytes.
     """
     path = mnist.download_file('t10k-images-idx3-ubyte.gz', force=True)
     pixel_payload = TEST_SAMPLES * IMAGE_SIZE * PIXEL_BYTES
     expected_size = HEADER_IMAGES_SIZE + pixel_payload
     self.assertEqual(expected_size, self._gzip_file_size(path))
 def test_temporary_dir_is_used(self, urlretrieve):
     """Check that ``download_file`` saves into ``mnist.temporary_dir()``.

     Swaps ``mnist.temporary_dir`` for a stub returning a fixed directory.
     Then asserts that the mocked ``urlretrieve`` targets that directory and
     that ``download_file`` returns the resulting path.
     """
     original_temp_dir = mnist.temporary_dir
     mnist.temporary_dir = lambda: '/another/tmp/dir/'
     # Restore in ``finally`` so a failing assertion cannot leave the stub
     # installed for later tests that use the ``mnist`` module.
     try:
         fname = mnist.download_file('test')
         urlretrieve.assert_called_once_with(mnist.datasets_url + 'test',
                                             '/another/tmp/dir/test')
         self.assertEqual(fname, '/another/tmp/dir/test')
     finally:
         mnist.temporary_dir = original_temp_dir
# Example #4
 def test_file_is_downloaded_when_exists_and_force_is_true(self, urlretrieve):
     """Check that ``force=True`` always triggers a download.

     Expects the mocked ``urlretrieve`` to be called once with the datasets
     URL plus the file name, and a destination in ``tempfile.gettempdir()``.

     NOTE(review): the name implies the target file already exists, but
     nothing visible here creates ``test`` first. Confirm that setUp does.
     """
     name = 'test'
     mnist.download_file(name, force=True)
     expected_url = mnist.datasets_url + name
     expected_path = os.path.join(tempfile.gettempdir(), name)
     urlretrieve.assert_called_once_with(expected_url, expected_path)
# Example #5
 def test_file_is_not_downloaded_when_force_is_false(self, urlretrieve):
     """Check that ``force=False`` does not re-fetch a previously downloaded file.

     ``self.downloaded_fname`` comes from the test fixture. Presumably it
     names a file that already exists; confirm in setUp. The mocked
     ``urlretrieve`` must never be invoked.
     """
     mnist.download_file(self.downloaded_fname, force=False)
     was_called = urlretrieve.called
     self.assertFalse(was_called)
# Example #6
 def test_file_is_downloaded_to_target_dir(self, urlretrieve):
     """Check that an explicit ``target_dir`` controls the save location.

     Expects the mocked ``urlretrieve`` to write into ``target_dir``, and
     ``download_file`` to return that destination path.
     """
     target_dir = '/tmp/mnist_test/'
     expected_path = target_dir + 'test'
     result = mnist.download_file('test', target_dir=target_dir)
     urlretrieve.assert_called_once_with(mnist.datasets_url + 'test',
                                         expected_path)
     self.assertEqual(result, expected_path)
# Example #7
 def test_test_labels_has_right_size(self):
     """Verify the size of the downloaded MNIST test-label archive.

     Forces a fresh download of ``t10k-labels-idx1-ubyte.gz``. Compares the
     size reported by ``self._gzip_file_size`` with the labels header size
     plus ``TEST_SAMPLES`` labels of ``LABEL_BYTES`` bytes each.
     """
     path = mnist.download_file('t10k-labels-idx1-ubyte.gz', force=True)
     label_payload = TEST_SAMPLES * LABEL_BYTES
     expected_size = HEADER_LABELS_SIZE + label_payload
     self.assertEqual(expected_size, self._gzip_file_size(path))