def test_unsupported_reset_finalized(self): dataset = [1, 2, 3, 4] it = iterators.MultiprocessIterator(dataset, 2, repeat=False, **self.options) it.next() it.next() it.finalize() self.assertRaises(NotImplementedError, it.reset) @testing.parameterize(*testing.product({ 'n_prefetch': [1, 2], 'shared_mem': [None, 1000000], 'order_sampler': [None, lambda order, _: numpy.random.permutation(len(order))], 'maxtasksperchild': [None], })) class TestMultiprocessIterator(BaseTestMultiprocessIterator, unittest.TestCase): pass @testing.parameterize(*testing.product({ 'n_prefetch': [1, 2], 'shared_mem': [None, 1000000], 'order_sampler': [None, lambda order, _: numpy.random.permutation(len(order))], 'maxtasksperchild': [1, 10], }))
from __future__ import division import copy import unittest import numpy import six from pytorch_trainer import iterators from pytorch_trainer import testing @testing.parameterize(*testing.product({ 'n_threads': [1, 2], 'order_sampler': [None, lambda order, _: numpy.random.permutation(len(order))] })) class TestMultithreadIterator(unittest.TestCase): def setUp(self): self.options = { 'n_threads': self.n_threads, 'order_sampler': self.order_sampler } def test_iterator_repeat(self): dataset = [1, 2, 3, 4, 5, 6] it = iterators.MultithreadIterator(dataset, 2, **self.options) for i in range(3): self.assertEqual(it.epoch, i) self.assertAlmostEqual(it.epoch_detail, i + 0 / 6) if i == 0: self.assertIsNone(it.previous_epoch_detail)
from __future__ import division import copy import itertools import unittest from pytorch_trainer import iterators from pytorch_trainer import testing @testing.parameterize(*testing.product({ 'n_prefetch': [1, 2], 'shared_mem': [None, 1000000], })) class TestIteratorCompatibility(unittest.TestCase): def setUp(self): self.n_processes = 2 self.options = { 'n_processes': self.n_processes, 'n_prefetch': self.n_prefetch, 'shared_mem': self.shared_mem } def test_iterator_compatibilty(self): dataset = [1, 2, 3, 4, 5, 6] iters = ( lambda: iterators.SerialIterator(dataset, 2), lambda: iterators.MultiprocessIterator(dataset, 2, **self.options), )
snapshot = extensions.snapshot_object(self.trainer, self.filename, snapshot_on_error=True) self.trainer.extend(snapshot) self.assertFalse(os.path.exists(self.filename)) with self.assertRaises(TheOnlyError): self.trainer.run() self.assertTrue(os.path.exists(self.filename)) @testing.parameterize(*testing.product({ 'fmt': ['snapshot_iter_{}', 'snapshot_iter_{}.npz', '{}_snapshot_man_suffix.npz'] })) class TestFindSnapshot(unittest.TestCase): def setUp(self): self.path = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.path) def test_find_snapshot_files(self): files = (self.fmt.format(i) for i in range(1, 100)) noise = ('dummy-foobar-iter{}'.format(i) for i in range(10, 304)) noise2 = ('tmpsnapshot_iter_{}'.format(i) for i in range(10, 304)) for file in itertools.chain(noise, files, noise2): file = os.path.join(self.path, file)
def test_evaluate(self):
    """Run one evaluation and compare what the target received to batches.

    ``self.target`` exposes the arguments it was called with as ``args``
    (set up outside this view); after ``evaluate()`` they must match
    ``self.batches`` one-to-one, in order.
    """
    reporter = pytorch_trainer.Reporter()
    reporter.add_observer('target', self.target)
    with reporter:
        self.evaluator.evaluate()

    # The model gets results of converter.
    self.assertEqual(len(self.target.args), len(self.batches))
    # Lengths were checked above, so zip() cannot silently drop items.
    for received, expected in zip(self.target.args, self.batches):
        numpy.testing.assert_array_equal(received, expected)


@testing.parameterize(*testing.product({
    'repeat': [True, False],
    'iterator_class': [iterators.SerialIterator,
                       iterators.MultiprocessIterator]
}))
class TestEvaluatorRepeat(unittest.TestCase):
    """Check the Evaluator's warning about repeating iterators."""

    def test_user_warning(self):
        """Evaluator warns for a repeating iterator, and only then.

        With ``repeat=True`` constructing the Evaluator must emit a
        ``UserWarning``; with ``repeat=False`` it must emit none.
        """
        import warnings

        dataset = torch.ones((4, 6))
        iterator = self.iterator_class(dataset, 2, repeat=self.repeat)
        try:
            if self.repeat:
                with testing.assert_warns(UserWarning):
                    extensions.Evaluator(iterator, {})
            else:
                with warnings.catch_warnings(record=True) as caught:
                    # Record every warning, even ones already shown once
                    # elsewhere in the test run.
                    warnings.simplefilter('always')
                    extensions.Evaluator(iterator, {})
                user_warnings = [str(w.message) for w in caught
                                 if issubclass(w.category, UserWarning)]
                self.assertEqual(user_warnings, [])
        finally:
            # Release whatever the iterator holds (MultiprocessIterator may
            # own worker resources) even if an assertion above fails.
            # NOTE(review): assumes every iterator class here exposes
            # finalize() via the base Iterator API — confirm.
            iterator.finalize()


class TestEvaluatorProgressBar(unittest.TestCase):

    def setUp(self):
        self.data = [torch.empty(3, 4).uniform_(-1, 1) for _ in range(2)]
from __future__ import division import math import unittest from pytorch_trainer import testing def _dummy_extension(trainer): pass @testing.parameterize(*testing.product({ 'stop_trigger': [(5, 'iteration'), (5, 'epoch')], 'iter_per_epoch': [0.5, 1, 1.5, 5], 'extensions': [[], [_dummy_extension]] })) class TestGetTrainerWithMockUpdater(unittest.TestCase): def setUp(self): self.trainer = testing.get_trainer_with_mock_updater( self.stop_trigger, self.iter_per_epoch, extensions=self.extensions) def test_run(self): iteration = [0] def check(trainer): iteration[0] += 1 self.assertEqual(trainer.updater.iteration, iteration[0])