コード例 #1
0
    def test_unsupported_reset_finalized(self):
        """Calling ``reset`` on a finalized non-repeating iterator must fail.

        Two batches of size 2 are drawn from a 4-element dataset, then the
        iterator is finalized. After that, ``reset`` is expected to raise
        ``NotImplementedError``.
        """
        iterator = iterators.MultiprocessIterator(
            [1, 2, 3, 4], 2, repeat=False, **self.options)
        for _ in range(2):
            iterator.next()
        iterator.finalize()
        with self.assertRaises(NotImplementedError):
            iterator.reset()


@testing.parameterize(*testing.product({
    'n_prefetch': [1, 2],
    'shared_mem': [None, 1000000],
    'order_sampler': [
        None,
        lambda order, _: numpy.random.permutation(len(order)),
    ],
    'maxtasksperchild': [None],
}))
class TestMultiprocessIterator(BaseTestMultiprocessIterator,
                               unittest.TestCase):
    """Concrete test case for ``BaseTestMultiprocessIterator``.

    It defines no tests of its own; it only binds the parameter grid above
    (with ``maxtasksperchild`` fixed to ``None``) to the inherited tests.
    """


@testing.parameterize(*testing.product({
    'n_prefetch': [1, 2],
    'shared_mem': [None, 1000000],
    'order_sampler':
    [None, lambda order, _: numpy.random.permutation(len(order))],
    'maxtasksperchild': [1, 10],
}))
コード例 #2
0
from __future__ import division
import copy
import unittest

import numpy
import six

from pytorch_trainer import iterators
from pytorch_trainer import testing


@testing.parameterize(*testing.product({
    'n_threads': [1, 2],
    'order_sampler':
    [None, lambda order, _: numpy.random.permutation(len(order))]
}))
class TestMultithreadIterator(unittest.TestCase):
    def setUp(self):
        # Keyword arguments later splatted into MultithreadIterator
        # (e.g. in test_iterator_repeat). Both values are supplied by the
        # @testing.parameterize grid on this class.
        self.options = dict(
            n_threads=self.n_threads,
            order_sampler=self.order_sampler,
        )

    def test_iterator_repeat(self):
        dataset = [1, 2, 3, 4, 5, 6]
        it = iterators.MultithreadIterator(dataset, 2, **self.options)
        for i in range(3):
            self.assertEqual(it.epoch, i)
            self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
            if i == 0:
                self.assertIsNone(it.previous_epoch_detail)
コード例 #3
0
from __future__ import division

import copy
import itertools
import unittest

from pytorch_trainer import iterators
from pytorch_trainer import testing


@testing.parameterize(*testing.product({
    'n_prefetch': [1, 2],
    'shared_mem': [None, 1000000],
}))
class TestIteratorCompatibility(unittest.TestCase):
    def setUp(self):
        # The worker count is fixed at 2. Prefetch depth and shared-memory
        # size come from the @testing.parameterize grid on this class.
        self.n_processes = 2
        self.options = dict(
            n_processes=self.n_processes,
            n_prefetch=self.n_prefetch,
            shared_mem=self.shared_mem,
        )

    def test_iterator_compatibilty(self):
        dataset = [1, 2, 3, 4, 5, 6]

        iters = (
            lambda: iterators.SerialIterator(dataset, 2),
            lambda: iterators.MultiprocessIterator(dataset, 2, **self.options),
        )
コード例 #4
0
        snapshot = extensions.snapshot_object(self.trainer,
                                              self.filename,
                                              snapshot_on_error=True)
        self.trainer.extend(snapshot)

        self.assertFalse(os.path.exists(self.filename))

        with self.assertRaises(TheOnlyError):
            self.trainer.run()

        self.assertTrue(os.path.exists(self.filename))


@testing.parameterize(*testing.product({
    'fmt':
    ['snapshot_iter_{}', 'snapshot_iter_{}.npz', '{}_snapshot_man_suffix.npz']
}))
class TestFindSnapshot(unittest.TestCase):
    def setUp(self):
        # Give every test its own fresh scratch directory; tearDown deletes it.
        scratch_dir = tempfile.mkdtemp()
        self.path = scratch_dir

    def tearDown(self):
        # Recursively delete the scratch directory created in setUp, along
        # with everything inside it.
        shutil.rmtree(path=self.path)

    def test_find_snapshot_files(self):
        files = (self.fmt.format(i) for i in range(1, 100))
        noise = ('dummy-foobar-iter{}'.format(i) for i in range(10, 304))
        noise2 = ('tmpsnapshot_iter_{}'.format(i) for i in range(10, 304))

        for file in itertools.chain(noise, files, noise2):
            file = os.path.join(self.path, file)
コード例 #5
0
    def test_evaluate(self):
        """Run the evaluator and check what ``self.target`` received.

        Evaluation runs inside a ``Reporter`` scope, with ``self.target``
        registered under the name ``'target'``. Afterwards,
        ``self.target.args`` must have one entry per element of
        ``self.batches``, and each entry must be array-equal to the
        matching batch.
        """
        reporter = pytorch_trainer.Reporter()
        reporter.add_observer('target', self.target)
        with reporter:
            self.evaluator.evaluate()

        # The model gets results of converter.
        self.assertEqual(len(self.target.args), len(self.batches))
        # The lengths were checked above, so zip pairs up every batch. The
        # index in err_msg identifies the mismatching batch on failure.
        for i, (actual, expected) in enumerate(
                zip(self.target.args, self.batches)):
            numpy.testing.assert_array_equal(
                actual, expected, err_msg='batch {}'.format(i))


@testing.parameterize(*testing.product({
    'repeat': [True, False],
    'iterator_class': [
        iterators.SerialIterator,
        iterators.MultiprocessIterator,
    ],
}))
class TestEvaluatorRepeat(unittest.TestCase):
    """Evaluator construction with repeating and non-repeating iterators."""

    def test_user_warning(self):
        """Building an Evaluator on a repeat=True iterator emits UserWarning.

        For repeat=False, only the iterator is constructed and nothing else
        is checked.
        """
        iterator = self.iterator_class(
            torch.ones((4, 6)), 2, repeat=self.repeat)
        if not self.repeat:
            return
        with testing.assert_warns(UserWarning):
            extensions.Evaluator(iterator, {})


class TestEvaluatorProgressBar(unittest.TestCase):
    def setUp(self):
        self.data = [torch.empty(3, 4).uniform_(-1, 1) for _ in range(2)]
コード例 #6
0
from __future__ import division

import math
import unittest

from pytorch_trainer import testing


def _dummy_extension(trainer):
    """Do nothing with *trainer*.

    This is a placeholder extension used to parameterize the ``extensions``
    list passed to the mock-updater trainer.
    """
    return None


@testing.parameterize(*testing.product({
    'stop_trigger': [(5, 'iteration'), (5, 'epoch')],
    'iter_per_epoch': [0.5, 1, 1.5, 5],
    'extensions': [[], [_dummy_extension]]
}))
class TestGetTrainerWithMockUpdater(unittest.TestCase):

    def setUp(self):
        # Build the trainer with testing.get_trainer_with_mock_updater. The
        # stop trigger, iterations per epoch and extension list are all
        # supplied by the @testing.parameterize grid on this class.
        trainer = testing.get_trainer_with_mock_updater(
            self.stop_trigger,
            self.iter_per_epoch,
            extensions=self.extensions,
        )
        self.trainer = trainer

    def test_run(self):
        iteration = [0]

        def check(trainer):
            iteration[0] += 1

            self.assertEqual(trainer.updater.iteration, iteration[0])