示例#1
0
    def test_resumed_trigger_sparse_call(self):
        """Resume a ManualScheduleTrigger from an npz file under sparse calls.

        The trigger is queried only on a random subset of iterations; the
        expected fires of skipped iterations are OR-accumulated and must be
        returned by the next actual call, including across the save/load
        boundary. After ``self.resume`` iterations the trigger is saved and
        a fresh ManualScheduleTrigger loaded from the file handles the rest.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        accumulated = False
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            for expected, finished in zip(self.expected[:self.resume],
                                          self.finished[:self.resume]):
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    self.assertEqual(trigger.finished, finished)
                    accumulated = False
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            serializers.load_npz(f.name, trigger)
            for expected, finished in zip(self.expected[self.resume:],
                                          self.finished[self.resume:]):
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    self.assertEqual(trigger.finished, finished)
                    accumulated = False
 def test_trigger(self):
     """ManualScheduleTrigger results match ``self.expected`` step by step.

     On every step the mock updater advances first, and the trigger is
     queried afterwards.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         iter_per_epoch=self.iter_per_epoch, stop_trigger=None)
     schedule = training.triggers.ManualScheduleTrigger(*self.schedule)
     for should_fire in self.expected:
         mock_trainer.updater.update()
         self.assertEqual(schedule(mock_trainer), should_fire)
 def test_trigger(self):
     """Advance the mock trainer and compare each ManualScheduleTrigger
     result with the matching entry of ``self.expected``.
     """
     stepped = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     manual = training.triggers.ManualScheduleTrigger(*self.schedule)
     for expected_value in self.expected:
         stepped.updater.update()
         result = manual(stepped)
         self.assertEqual(result, expected_value)
示例#4
0
 def test_trigger(self):
     """OnceTrigger results match ``self.expected`` step by step.

     On each step the trigger is queried before the updater advances.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     once = training.triggers.OnceTrigger(self.call_on_resume)
     for should_fire in self.expected:
         fired = once(mock_trainer)
         self.assertEqual(fired, should_fire)
         mock_trainer.updater.update()
    def test_resumed_trigger_sparse_call(self):
        """Resume a ManualScheduleTrigger from an npz file under sparse calls.

        The trigger is queried only on a random subset of iterations; the
        expected fires of skipped iterations are OR-accumulated and must be
        returned by the next actual call, including across the save/load
        boundary. After ``self.resume`` iterations the trigger is saved and
        a fresh ManualScheduleTrigger loaded from the file handles the rest.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        accumulated = False
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            for expected, finished in zip(self.expected[:self.resume],
                                          self.finished[:self.resume]):
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    self.assertEqual(trigger.finished, finished)
                    accumulated = False
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            serializers.load_npz(f.name, trigger)
            for expected, finished in zip(self.expected[self.resume:],
                                          self.finished[self.resume:]):
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    self.assertEqual(trigger.finished, finished)
                    accumulated = False
示例#6
0
    def _test_trigger(self, trigger, key, accuracies, expected,
                      resume=None, save=None):
        """Run a mock trainer and assert on which iterations *trigger* fires.

        Args:
            trigger: Trigger under test; it guards an extension that records
                the current iteration number each time it is invoked.
            key: Observation key that receives ``accuracies[i - 1]`` at
                iteration ``i``.
            accuracies: Per-iteration observation values; its length is the
                number of iterations the trainer runs.
            expected: Iteration numbers the recording extension must see.
            resume: Optional npz path loaded into the trainer before running.
            save: Optional npz path the trainer is written to after running.
        """
        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(len(accuracies), 'iteration'),
            iter_per_epoch=self.iter_per_epoch)
        mock_updater = trainer.updater

        # Give the mock updater a serialize() that round-trips its
        # iteration/epoch counters (presumably what lets `resume` continue
        # from the saved iteration).
        def _serialize_updater(serializer):
            mock_updater.iteration = serializer(
                'iteration', mock_updater.iteration)
            mock_updater.epoch = serializer('epoch', mock_updater.epoch)
            mock_updater.is_new_epoch = serializer(
                'is_new_epoch', mock_updater.is_new_epoch)
        mock_updater.serialize = _serialize_updater

        def publish_accuracy(t):
            t.observation = {key: accuracies[t.updater.iteration - 1]}

        # NOTE(review): priority 2 vs 1 presumably makes the observation
        # available before the recorder's trigger is evaluated — confirm
        # against Trainer's extension ordering.
        trainer.extend(publish_accuracy, name='set_observation',
                       trigger=(1, 'iteration'), priority=2)

        seen_iterations = []

        def record_iteration(t):
            seen_iterations.append(t.updater.iteration)
        trainer.extend(record_iteration, name='record', trigger=trigger,
                       priority=1)

        if resume is not None:
            serializers.load_npz(resume, trainer)

        trainer.run()
        self.assertEqual(seen_iterations, expected)

        if save is not None:
            serializers.save_npz(save, trainer)
示例#7
0
    def test_raise_error_if_call_not_implemented(self):
        """Calling an Extension subclass that defines no __call__ must
        raise NotImplementedError."""
        class MyExtension(training.Extension):
            pass

        extension = MyExtension()
        mock_trainer = testing.get_trainer_with_mock_updater()
        with pytest.raises(NotImplementedError):
            extension(mock_trainer)
 def test_trigger(self):
     """IntervalTrigger matches ``self.expected`` after a leading False.

     The trigger is queried before each update, so the leading ``False``
     covers the query made before any iteration has run.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     interval = training.triggers.IntervalTrigger(*self.interval)
     timeline = [False] + self.expected
     for should_fire in timeline:
         self.assertEqual(interval(mock_trainer), should_fire)
         mock_trainer.updater.update()
示例#9
0
 def test_trigger(self):
     """OnceTrigger's ``finished`` flag and return value follow expectations.

     On each step ``finished`` is checked first, then the trigger call
     result, and only then does the updater advance.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     once = training.triggers.OnceTrigger(self.call_on_resume)
     steps = zip(self.expected, self.finished)
     for should_fire, should_be_finished in steps:
         self.assertEqual(once.finished, should_be_finished)
         self.assertEqual(once(mock_trainer), should_fire)
         mock_trainer.updater.update()
示例#10
0
 def test_trigger(self):
     """Query an IntervalTrigger before every update and compare the
     results with ``[False] + self.expected``."""
     stepped = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     trig = training.triggers.IntervalTrigger(*self.interval)
     # Nothing has run yet on the first query, so it must not fire.
     expectations = [False] + self.expected
     for expected_value in expectations:
         result = trig(stepped)
         self.assertEqual(result, expected_value)
         stepped.updater.update()
示例#11
0
    def test_raise_error_if_call_not_implemented(self):
        """An Extension subclass without its own __call__ raises
        NotImplementedError when invoked on a trainer."""
        class MyExtension(training.Extension):
            pass

        instance = MyExtension()
        trainer_under_test = testing.get_trainer_with_mock_updater()
        with pytest.raises(NotImplementedError):
            instance(trainer_under_test)
示例#12
0
    def test_elapsed_time_serialization(self):
        """Elapsed time restored from a snapshot keeps growing.

        ``self.trainer`` is run, its state is copied into a fresh
        20-iteration trainer with ``testing.save_and_load_npz``, and after
        running, the resumed trainer must report more elapsed time than
        was saved.
        """
        self.trainer.run()
        time_at_save = self.trainer.elapsed_time

        resumed = testing.get_trainer_with_mock_updater((20, 'iteration'))
        testing.save_and_load_npz(self.trainer, resumed)
        resumed.run()

        self.assertGreater(resumed.elapsed_time, time_at_save)
示例#13
0
    def setUp(self):
        """Build a LinearShift on attribute 'x' plus a mock trainer.

        The trainer is created with a 2-iteration trigger as its stop
        trigger, and its updater's ``get_optimizer`` returns the same mock
        optimizer the extension was given.
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.LinearShift(
            'x', self.value_range, self.time_range, optimizer)

        self.interval = 2
        self.trigger = training.get_trigger((self.interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = trainer
        trainer.updater.get_optimizer.return_value = optimizer
示例#14
0
    def setUp(self):
        """Build an ExponentialShift on 'x' and a 4-iteration mock trainer.

        ``self.expect`` is expanded so each expected value appears
        ``self.interval`` times in a row.
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.ExponentialShift(
            'x', self.rate, self.init, self.target, optimizer)

        self.interval = 4
        repeated = []
        for value in self.expect:
            repeated.extend([value] * self.interval)
        self.expect = repeated
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = optimizer
示例#15
0
    def test_make_shift(self):
        """``mod5_shift`` sets optimizer.x to the current iteration mod 5.

        The optimizer starts at ``x = -1``; after ``initialize`` and after
        each update + extension call, ``x`` must equal ``step % 5``.
        """
        trainer = testing.get_trainer_with_mock_updater(
            iter_per_epoch=10, extensions=[mod5_shift])
        get_optimizer = trainer.updater.get_optimizer
        get_optimizer.return_value = mock.MagicMock()
        get_optimizer().x = -1

        mod5_shift.initialize(trainer)
        for step in range(100):
            current = get_optimizer().x
            self.assertEqual(current, step % 5)
            trainer.updater.update()
            mod5_shift(trainer)
 def test_trigger_sparse_call(self):
     """ManualScheduleTrigger reports fires that happened between calls.

     The trigger is queried only on a random subset of steps, after the
     updater advances. Expected fires of skipped steps are OR-accumulated
     and must be returned by the next call, which resets the accumulator.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     schedule = training.triggers.ManualScheduleTrigger(*self.schedule)
     pending = False
     for should_fire in self.expected:
         mock_trainer.updater.update()
         pending = pending or should_fire
         if not random.randrange(2):
             continue  # skip querying the trigger on this step
         self.assertEqual(schedule(mock_trainer), pending)
         pending = False
示例#17
0
    def setUp(self):
        """Build a MultistepShift on 'x' and a mock trainer stepping by 1.

        ``self.expect`` is expanded so each expected value appears
        ``self.interval`` times in a row.
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.MultistepShift(
            'x', self.gamma, self.step_value, self.init, optimizer)

        self.interval = 1
        repeated = []
        for value in self.expect:
            repeated.extend([value] * self.interval)
        self.expect = repeated
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = optimizer
示例#18
0
    def setUp(self):
        """Create a LinearShift over 'x' and a mock trainer that stops after
        ``self.interval`` (2) iterations and serves the same mock optimizer.
        """
        mocked_optimizer = mock.MagicMock()
        self.optimizer = mocked_optimizer
        self.extension = extensions.LinearShift(
            'x', self.value_range, self.time_range, mocked_optimizer)

        self.interval = 2
        stop_after = (self.interval, 'iteration')
        self.trigger = training.get_trigger(stop_after)

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        updater = self.trainer.updater
        updater.get_optimizer.return_value = mocked_optimizer
示例#19
0
    def setUp(self):
        """Build a Multistep extension on 'x' and a 1-iteration mock trainer.

        Each value in ``self.expect`` is repeated ``self.interval`` times.
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.Multistep(
            'x', self.base_lr, self.gamma, self.step_value, optimizer)

        self.interval = 1
        expanded = []
        for value in self.expect:
            for _ in range(self.interval):
                expanded.append(value)
        self.expect = expanded
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = optimizer
 def test_trigger_sparse_call(self):
     """Query a ManualScheduleTrigger on random steps only.

     Missed fires since the previous query are OR-accumulated; the next
     query must return that accumulated value, after which it resets.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     manual = training.triggers.ManualScheduleTrigger(*self.schedule)
     missed = False
     for should_fire in self.expected:
         mock_trainer.updater.update()
         missed = missed or should_fire
         query_now = random.randrange(2)
         if query_now:
             self.assertEqual(manual(mock_trainer), missed)
             missed = False
示例#21
0
    def test_remove_stale_snapshots(self):
        """Only the newest ``n_retains`` snapshot files are kept on disk.

        A helper extension (priority 1, while the snapshot extension uses
        priority 2) stamps each new snapshot file with an mtime one second
        later than the previous one, starting about 100 seconds in the
        past. After the 10-iteration run only ``snapshot_iter_8`` to
        ``snapshot_iter_10`` may remain. Finally, an autoloading snapshot is
        initialized on a fresh trainer to check that no error occurs.
        """
        fmt = 'snapshot_iter_{.updater.iteration}'
        retain = 3
        snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
                                       autoload=False)

        trainer = testing.get_trainer_with_mock_updater()
        trainer.out = self.path
        trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

        class TimeStampUpdater():
            t = time.time() - 100
            name = 'ts_updater'
            priority = 1  # Must run after the snapshot has been written.

            def __call__(self, _trainer):
                snapshot_path = os.path.join(_trainer.out,
                                             fmt.format(_trainer))
                self.t += 1
                # Set explicit, strictly increasing timestamps for
                # filesystems with low timestamp precision.
                os.utime(snapshot_path, (self.t, self.t))

        trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
        trainer.run()
        assert 10 == trainer.updater.iteration
        assert trainer._done

        pattern = os.path.join(trainer.out, "snapshot_iter_*")
        found = sorted(os.path.basename(path) for path in glob.glob(pattern))
        assert retain == len(found)
        # snapshot_iter_(8, 9, 10) expected
        expected = sorted('snapshot_iter_{}'.format(i) for i in range(8, 11))
        assert expected == found

        fresh_trainer = testing.get_trainer_with_mock_updater()
        fresh_trainer.out = self.path
        assert not fresh_trainer._done
        autoload_snapshot = extensions.snapshot(filename=fmt, autoload=True)
        # Just making sure no error occurs
        autoload_snapshot.initialize(fresh_trainer)
 def test_trigger_sparse_call(self):
     """IntervalTrigger under sparse, randomly skipped queries.

     A leading ``False`` covers the state before the first iteration. On
     each step the expected fire is OR-accumulated. When the trigger is
     actually queried (chosen at random), it must return the accumulated
     value, after which the accumulator resets. The updater advances last.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     interval = training.triggers.IntervalTrigger(*self.interval)
     pending = False
     # before the first iteration, trigger should be False
     timeline = [False] + self.expected
     for should_fire in timeline:
         pending = pending or should_fire
         if random.randrange(2):
             self.assertEqual(interval(mock_trainer), pending)
             pending = False
         mock_trainer.updater.update()
示例#23
0
 def test_trigger_sparse_call(self):
     """Query an IntervalTrigger on random steps only.

     Fires missed since the previous query are OR-accumulated and must be
     returned by the next query; the updater advances after every step.
     """
     stepped = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     trig = training.triggers.IntervalTrigger(*self.interval)
     missed = False
     # Nothing has run yet on the first step, so it must not fire.
     expectations = [False] + self.expected
     for expected_value in expectations:
         missed = missed or expected_value
         query_now = random.randrange(2)
         if query_now:
             self.assertEqual(trig(stepped), missed)
             missed = False
         stepped.updater.update()
示例#24
0
def test_snapshot():
    """Write a snapshot through a custom writer and load it back.

    The snapshot of the mock trainer must appear as ``snapshot_iter_0`` in
    the writer's directory, and ``load_snapshot`` with
    ``fail_on_no_file=True`` must then run on that directory without error.
    """
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True

    with tempfile.TemporaryDirectory() as td:
        writer = SimpleWriter(td)
        snapshot = extensions.snapshot(writer=writer)
        snapshot(trainer)
        assert 'snapshot_iter_0' in os.listdir(td)

        # Use the same ``testing`` module as above rather than reaching it
        # through ``chainer.testing``.
        trainer2 = testing.get_trainer_with_mock_updater()
        load_snapshot(trainer2, td, fail_on_no_file=True)
    def setUp(self):
        """Attach a VariableStatisticsPlot over a random (1, 2, 3) variable.

        The mock trainer stops after two iterations. The plot extension is
        triggered every iteration and configured with
        ``filename=self.filename``.
        """
        plot_trigger = (1, 'iteration')
        self.filename = 'variable_statistics_plot_test.png'

        self.trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(2, 'iteration'))

        data = numpy.random.rand(1, 2, 3)
        target = chainer.variable.Variable(data)
        self.extension = extensions.VariableStatisticsPlot(
            target, trigger=plot_trigger, filename=self.filename)
        self.trainer.extend(self.extension, plot_trigger)
    def setUp(self):
        """Attach a VariableStatisticsPlot over a random (1, 2, 3) variable.

        The mock trainer stops after two iterations. The plot extension is
        triggered every iteration and configured with
        ``file_name=self.file_name``.
        """
        plot_trigger = (1, 'iteration')
        self.file_name = 'variable_statistics_plot_test.png'

        self.trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(2, 'iteration'))

        data = numpy.random.rand(1, 2, 3)
        target = chainer.variable.Variable(data)
        self.extension = extensions.VariableStatisticsPlot(
            target, trigger=plot_trigger, file_name=self.file_name)
        self.trainer.extend(self.extension, plot_trigger)
示例#27
0
    def test_resume(self):
        """A restored ExponentialShift re-applies the value reached so far.

        After running the original trainer, its state is copied into a
        fresh 3-iteration trainer via ``testing.save_and_load_npz``.
        Initializing the new extension must set ``x`` on its optimizer to
        the original optimizer's ``x``.
        """
        restored_optimizer = mock.Mock()
        restored_extension = extensions.ExponentialShift(
            'x', self.rate, self.init, self.target, restored_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        restored_trainer.extend(restored_extension)
        testing.save_and_load_npz(self.trainer, restored_trainer)

        restored_extension.initialize(restored_trainer)
        self.assertEqual(restored_optimizer.x, self.optimizer.x)
示例#28
0
    def test_resume(self):
        """A restored LinearShift re-applies the value reached so far.

        The original trainer's state is copied into a fresh 5-iteration
        trainer via ``testing.save_and_load_npz``. Initializing the new
        extension must set ``x`` on its optimizer to the original
        optimizer's ``x``.
        """
        restored_optimizer = mock.Mock()
        restored_extension = extensions.LinearShift(
            'x', self.value_range, self.time_range, restored_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_trainer = testing.get_trainer_with_mock_updater(
            (5, 'iteration'))
        restored_trainer.extend(restored_extension)
        testing.save_and_load_npz(self.trainer, restored_trainer)

        restored_extension.initialize(restored_trainer)
        self.assertEqual(restored_optimizer.x, self.optimizer.x)
示例#29
0
    def test_resume(self):
        """A restored ExponentialShift re-applies the same value as a float.

        The original trainer's state is copied into a fresh 3-iteration
        trainer via ``testing.save_and_load_npz``. Initializing the new
        extension must set ``x`` to the original optimizer's ``x``, and the
        restored value must be a ``float`` instance.
        """
        restored_optimizer = mock.Mock()
        restored_extension = extensions.ExponentialShift(
            'x', self.rate, self.init, self.target, restored_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        restored_trainer.extend(restored_extension)
        testing.save_and_load_npz(self.trainer, restored_trainer)

        restored_extension.initialize(restored_trainer)
        restored_value = restored_optimizer.x
        self.assertEqual(restored_value, self.optimizer.x)
        self.assertIsInstance(restored_value, float)
    def test_resume(self):
        """A restored InverseShift re-applies the same value as a float.

        The original trainer's state is copied into a fresh 3-iteration
        trainer via ``testing.save_and_load_npz``. Initializing the new
        extension must set ``x`` to the original optimizer's ``x``, and the
        restored value must be a ``float`` instance.
        """
        restored_optimizer = mock.Mock()
        restored_extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            restored_optimizer)

        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        restored_trainer.extend(restored_extension)
        testing.save_and_load_npz(self.trainer, restored_trainer)

        restored_extension.initialize(restored_trainer)
        restored_value = restored_optimizer.x
        self.assertEqual(restored_value, self.optimizer.x)
        self.assertIsInstance(restored_value, float)
    def test_resumed_trigger(self):
        """An IntervalTrigger restored from an npz file keeps its state.

        The trigger is checked against ``self.expected`` for the first
        ``self.resume`` iterations and saved with
        ``serializers.save_npz``. A fresh IntervalTrigger is then loaded
        from the file and checked against the remaining expectations.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.IntervalTrigger(*self.interval)
            for expected in self.expected[:self.resume]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.IntervalTrigger(*self.interval)
            serializers.load_npz(f.name, trigger)
            for expected in self.expected[self.resume:]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
    def test_resumed_trigger(self):
        """A ManualScheduleTrigger restored from an npz file keeps its state.

        The trigger is checked against ``self.expected`` for the first
        ``self.resume`` iterations and saved with
        ``serializers.save_npz``. A fresh ManualScheduleTrigger is then
        loaded from the file and checked against the rest.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            for expected in self.expected[:self.resume]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            serializers.load_npz(f.name, trigger)
            for expected in self.expected[self.resume:]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
示例#33
0
 def test_trigger_sparse_call(self):
     """OnceTrigger under sparse calls reports accumulated state.

     Expected fires of skipped steps are OR-accumulated and ``finished``
     flags are AND-accumulated. Both are checked, then reset, on each call
     actually made (chosen at random), and the updater advances last.
     """
     mock_trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
     once = training.triggers.OnceTrigger(self.call_on_resume)
     pending_fire = False
     pending_finished = True
     for should_fire, is_finished in zip(self.expected, self.finished):
         pending_fire = pending_fire or should_fire
         pending_finished = pending_finished and is_finished
         if random.randrange(2):
             self.assertEqual(once.finished, pending_finished)
             self.assertEqual(once(mock_trainer), pending_fire)
             pending_fire, pending_finished = False, True
         mock_trainer.updater.update()
示例#34
0
    def test_resumed_trigger_backward_compat(self):
        """Loading an npz file without IntervalTrigger state still works.

        An old-version snapshot is emulated by an npz file holding only a
        ``dummy`` entry. After loading it, a fresh IntervalTrigger must
        still match ``self.expected`` for the remaining iterations.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.IntervalTrigger(*self.interval)
            for expected in self.expected[:self.resume]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
            # old version does not save anything
            np.savez(f, dummy=0)

            trigger = training.triggers.IntervalTrigger(*self.interval)
            serializers.load_npz(f.name, trigger)
            for expected in self.expected[self.resume:]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
    def test_resumed_trigger_backward_compat(self):
        """Loading an npz file without IntervalTrigger state still works.

        An old-version snapshot is emulated by an npz file holding only a
        ``dummy`` entry. After loading it, a fresh IntervalTrigger must
        still match ``self.expected`` for the remaining iterations.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.IntervalTrigger(*self.interval)
            for expected in self.expected[:self.resume]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
            # old version does not save anything
            np.savez(f, dummy=0)

            trigger = training.triggers.IntervalTrigger(*self.interval)
            serializers.load_npz(f.name, trigger)
            for expected in self.expected[self.resume:]:
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
示例#36
0
    def _setup(self, stream=None, delete_flush=False):
        """Wire a PrintReport to a mocked LogReport on a one-step trainer.

        Args:
            stream: Output stream for the report; a ``MagicMock`` is used
                when omitted.
            delete_flush: When a mock stream is created, delete its
                ``flush`` attribute (ignored if *stream* is given).
        """
        self.logreport = mock.MagicMock(spec=extensions.LogReport(
            ['epoch'], trigger=(1, 'iteration'), log_name=None))
        if stream is not None:
            self.stream = stream
        else:
            self.stream = mock.MagicMock()
            if delete_flush:
                del self.stream.flush
        self.report = extensions.PrintReport(
            ['epoch'], log_report=self.logreport, out=self.stream)

        self.trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(1, 'iteration'))
        for extension in (self.logreport, self.report):
            self.trainer.extend(extension)
        self.logreport.log = [{'epoch': 0}]
    def _setup(self, stream=None, delete_flush=False):
        """Create a mocked LogReport, a PrintReport reading from it, and a
        mock trainer that stops after one iteration with both extensions.

        Args:
            stream: Stream the PrintReport writes to; defaults to a fresh
                ``MagicMock``.
            delete_flush: If true and a mock stream is created, its
                ``flush`` attribute is removed.
        """
        log_spec = extensions.LogReport(
            ['epoch'], trigger=(1, 'iteration'), log_name=None)
        self.logreport = mock.MagicMock(spec=log_spec)

        if stream is None:
            output = mock.MagicMock()
            if delete_flush:
                del output.flush
        else:
            output = stream
        self.stream = output
        self.report = extensions.PrintReport(
            ['epoch'], log_report=self.logreport, out=output)

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(1, 'iteration'))
        self.trainer = trainer
        trainer.extend(self.logreport)
        trainer.extend(self.report)
        self.logreport.log = [{'epoch': 0}]
示例#38
0
    def _test_trigger(self,
                      trigger,
                      key,
                      accuracies,
                      expected,
                      resume=None,
                      save=None):
        """Drive a mock trainer and assert the iterations *trigger* fires on.

        Args:
            trigger: The trigger under test; it guards a recording extension.
            key: Observation key that receives ``accuracies[i - 1]`` at
                iteration ``i``.
            accuracies: Per-iteration observations; its length is the number
                of iterations the trainer runs.
            expected: Iteration numbers on which the recorder must be called.
            resume: If given, npz path loaded into the trainer before running.
            save: If given, npz path the trainer is saved to after running.
        """
        n_iterations = len(accuracies)
        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=(n_iterations, 'iteration'),
            iter_per_epoch=self.iter_per_epoch)
        updater = trainer.updater

        def _serialize_updater(serializer):
            # Serialize the mock updater's iteration/epoch counters in order.
            for attr in ('iteration', 'epoch', 'is_new_epoch'):
                value = serializer(attr, getattr(updater, attr))
                setattr(updater, attr, value)

        updater.serialize = _serialize_updater

        def observe(t):
            t.observation = {key: accuracies[t.updater.iteration - 1]}

        trainer.extend(observe,
                       name='set_observation',
                       trigger=(1, 'iteration'),
                       priority=2)

        fired_at = []

        def remember(t):
            fired_at.append(t.updater.iteration)

        trainer.extend(remember, name='record', trigger=trigger, priority=1)

        if resume is not None:
            serializers.load_npz(resume, trainer)

        trainer.run()
        self.assertEqual(fired_at, expected)

        if save is not None:
            serializers.save_npz(save, trainer)
示例#39
0
    def test_resumed_trigger(self):
        """A OnceTrigger restored from an npz file keeps its state.

        On each step the updater advances, then ``finished`` is compared
        before the trigger is called, and then the call result is compared.
        Both use ``self.resumed_expected`` / ``self.resumed_finished``. The
        trigger is saved after ``self.resume`` steps, and a fresh
        OnceTrigger loaded from the file handles the remaining steps.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.OnceTrigger(self.call_on_resume)
            for expected, finished in zip(self.resumed_expected[:self.resume],
                                          self.resumed_finished[:self.resume]):
                trainer.updater.update()
                self.assertEqual(trigger.finished, finished)
                self.assertEqual(trigger(trainer), expected)
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.OnceTrigger(self.call_on_resume)
            serializers.load_npz(f.name, trigger)
            for expected, finished in zip(self.resumed_expected[self.resume:],
                                          self.resumed_finished[self.resume:]):
                trainer.updater.update()
                self.assertEqual(trigger.finished, finished)
                self.assertEqual(trigger(trainer), expected)
示例#40
0
    def test_resumed_trigger_backward_compat(self):
        """Loading an npz file without ManualScheduleTrigger state warns.

        An old-version snapshot is emulated by an npz file holding only a
        ``dummy`` entry. Loading it must emit a ``UserWarning``, and the
        fresh trigger must still match ``self.expected`` /
        ``self.finished`` for the remaining iterations.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            for expected, finished in zip(self.expected[:self.resume],
                                          self.finished[:self.resume]):
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
                self.assertEqual(trigger.finished, finished)
            # old version does not save anything
            np.savez(f, dummy=0)

            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            with testing.assert_warns(UserWarning):
                serializers.load_npz(f.name, trigger)
            for expected, finished in zip(self.expected[self.resume:],
                                          self.finished[self.resume:]):
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
                self.assertEqual(trigger.finished, finished)
    def test_resumed_trigger_backward_compat(self):
        """Loading an npz file without ManualScheduleTrigger state warns.

        An old-version snapshot is emulated by an npz file holding only a
        ``dummy`` entry. Loading it must emit a ``UserWarning``, and the
        fresh trigger must still match ``self.expected`` /
        ``self.finished`` for the remaining iterations.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            for expected, finished in zip(self.expected[:self.resume],
                                          self.finished[:self.resume]):
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
                self.assertEqual(trigger.finished, finished)
            # old version does not save anything
            np.savez(f, dummy=0)

            trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
            with testing.assert_warns(UserWarning):
                serializers.load_npz(f.name, trigger)
            for expected, finished in zip(self.expected[self.resume:],
                                          self.finished[self.resume:]):
                trainer.updater.update()
                self.assertEqual(trigger(trainer), expected)
                self.assertEqual(trigger.finished, finished)
示例#42
0
    def test_resumed_trigger_sparse_call(self):
        """Resume a OnceTrigger from an npz file under sparse calls.

        The trigger is driven through ``self.expected``, saved, and
        restored into a fresh OnceTrigger, which is then driven through
        ``self.expected_resume``. Calls happen on a random subset of steps,
        and expected fires of skipped steps are OR-accumulated, including
        across the save/load boundary.
        """
        import os

        trainer = testing.get_trainer_with_mock_updater(
            stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
        accumulated = False
        # delete=False keeps the file on disk after it is closed; remove it
        # explicitly when the test ends so it is not leaked.
        f = tempfile.NamedTemporaryFile(delete=False)
        self.addCleanup(os.remove, f.name)
        with f:
            trigger = training.triggers.OnceTrigger(self.call_on_resume)
            for expected in self.expected:
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    accumulated = False
            serializers.save_npz(f.name, trigger)

            trigger = training.triggers.OnceTrigger(self.call_on_resume)
            serializers.load_npz(f.name, trigger)
            for expected in self.expected_resume:
                trainer.updater.update()
                accumulated = accumulated or expected
                if random.randrange(2):
                    self.assertEqual(trigger(trainer), accumulated)
                    accumulated = False
示例#43
0
 def _create_mock_trainer(self, iterations):
     """Return a mock trainer that stops after *iterations* iterations.

     The mock updater's ``update_core`` is replaced with a 1 ms sleep,
     presumably so each update takes a measurable amount of time.
     """
     def sleepy_update_core():
         time.sleep(0.001)

     mock_trainer = testing.get_trainer_with_mock_updater(
         (iterations, 'iteration'))
     mock_trainer.updater.update_core = sleepy_update_core
     return mock_trainer
示例#44
0
 def setUp(self):
     """Create the mock trainer from ``self.stop_trigger`` and
     ``self.iter_per_epoch``."""
     self.trainer = testing.get_trainer_with_mock_updater(
         stop_trigger=self.stop_trigger,
         iter_per_epoch=self.iter_per_epoch)
示例#45
0
 def setUp(self):
     """Prepare a mock trainer with output dir '.' and ``_done`` set."""
     trainer = testing.get_trainer_with_mock_updater()
     trainer.out = '.'
     trainer._done = True
     self.trainer = trainer
示例#46
0
 def setUp(self):
     """Create a mock trainer whose stop trigger is (5, 'iteration')."""
     stop_at = (5, 'iteration')
     self.trainer = testing.get_trainer_with_mock_updater(stop_at)
示例#47
0
 def setUp(self):
     """Mock trainer writing to the current directory, flagged as done."""
     mock_trainer = testing.get_trainer_with_mock_updater()
     mock_trainer.out = '.'
     mock_trainer._done = True
     self.trainer = mock_trainer
示例#48
0
 def setUp(self):
     """Create a mock trainer writing to '.' and a fixed test file name."""
     mock_trainer = testing.get_trainer_with_mock_updater()
     mock_trainer.out = '.'
     self.trainer = mock_trainer
     self.filename = 'myfile-deadbeef.dat'
示例#49
0
 def setUp(self):
     """Set up a mock trainer (output dir '.') and the file name used by
     the tests."""
     self.trainer = testing.get_trainer_with_mock_updater()
     self.trainer.out = '.'
     self.filename = 'myfile-deadbeef.dat'
 def _create_mock_trainer(self, iterations):
     """Build a mock trainer stopping at *iterations* iterations whose
     updater's ``update_core`` sleeps for 1 ms per call."""
     def delayed_update_core():
         time.sleep(0.001)

     stop_trigger = (iterations, 'iteration')
     built = testing.get_trainer_with_mock_updater(stop_trigger)
     built.updater.update_core = delayed_update_core
     return built