def test_resumed_trigger_sparse_call(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    accumulated = False
    with tempfile.NamedTemporaryFile(delete=False) as f:
        trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
        for expected, finished in zip(self.expected[:self.resume],
                                      self.finished[:self.resume]):
            trainer.updater.update()
            # Query the trigger only on a random subset of iterations;
            # firings expected on skipped iterations must accumulate.
            accumulated = accumulated or expected
            if random.randrange(2):
                self.assertEqual(trigger(trainer), accumulated)
                self.assertEqual(trigger.finished, finished)
                accumulated = False
        torch.save(trigger.state_dict(), f.name)

        # Resume with a fresh trigger instance loaded from the saved state.
        trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
        trigger.load_state_dict(torch.load(f.name))
        for expected, finished in zip(self.expected[self.resume:],
                                      self.finished[self.resume:]):
            trainer.updater.update()
            accumulated = accumulated or expected
            if random.randrange(2):
                self.assertEqual(trigger(trainer), accumulated)
                self.assertEqual(trigger.finished, finished)
                accumulated = False
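For reference, a minimal sketch of the fixture values these parameterized tests rely on; the concrete schedule, expectation lists, and epoch length below are illustrative assumptions, not the suite's actual parameters.

# Illustrative fixture values (assumptions for demonstration only).
iter_per_epoch = 5
schedule = ([2, 4], 'iteration')               # ManualScheduleTrigger(points, unit)
expected = [False, True, False, True, False]   # per-iteration firings
finished = [False, False, False, True, True]   # finished once the last point fires
resume = 3                                     # iterations run before save/load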
Example #2
def test_trigger(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.IntervalTrigger(*self.interval)
    # before the first iteration, trigger should be False
    for expected in [False] + self.expected:
        self.assertEqual(trigger(trainer), expected)
        trainer.updater.update()
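As a concrete illustration of the (period, unit) pair unpacked into IntervalTrigger above, the values below are assumptions for demonstration:

# Minimal sketch of IntervalTrigger semantics (illustrative values).
trainer = testing.get_trainer_with_mock_updater(
    stop_trigger=None, iter_per_epoch=5)
trigger = training.triggers.IntervalTrigger(2, 'iteration')
for iteration in range(1, 7):
    trainer.updater.update()
    print(iteration, trigger(trainer))   # True at iterations 2, 4 and 6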
Example #3
def test_raise_error_if_call_not_implemented(self):
    class MyExtension(training.Extension):
        pass

    ext = MyExtension()
    trainer = testing.get_trainer_with_mock_updater()
    with pytest.raises(NotImplementedError):
        ext(trainer)
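By contrast, a minimal extension that does implement __call__ runs without error; the subclass body here is an illustrative assumption:

# Sketch: an Extension subclass that implements __call__.
class PrintIteration(training.Extension):
    def __call__(self, trainer):
        print('iteration:', trainer.updater.iteration)

ext = PrintIteration()
trainer = testing.get_trainer_with_mock_updater()
ext(trainer)   # no NotImplementedError is raised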
Example #4
def test_trigger(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.OnceTrigger(self.call_on_resume)
    for expected, finished in zip(self.expected, self.finished):
        self.assertEqual(trigger.finished, finished)
        self.assertEqual(trigger(trainer), expected)
        trainer.updater.update()
Example #5
def test_trigger(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
    for expected, finished in zip(self.expected, self.finished):
        trainer.updater.update()
        self.assertEqual(trigger(trainer), expected)
        self.assertEqual(trigger.finished, finished)
Example #6
def test_remove_stale_snapshots(self):
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3
    snapshot = extensions.snapshot(filename=fmt,
                                   n_retains=retain,
                                   autoload=False)

    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

    class TimeStampUpdater:
        t = time.time() - 100
        name = 'ts_updater'
        priority = 1  # This must run after the snapshot is taken

        def __call__(self, _trainer):
            filename = os.path.join(_trainer.out, fmt.format(_trainer))
            self.t += 1
            # For filesystems with low timestamp precision
            os.utime(filename, (self.t, self.t))

    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    pattern = os.path.join(trainer.out, "snapshot_iter_*")
    found = [os.path.basename(path) for path in glob.glob(pattern)]
    assert retain == len(found)
    found.sort()
    # snapshot_iter_{8, 9, 10} are expected to remain
    expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
    expected.sort()
    assert expected == found

    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Just making sure no error occurs
    snapshot2.initialize(trainer2)
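The closing lines above also touch autoload; a minimal sketch of that resume path, under the same assumptions about the snapshot API, is:

# Sketch: resuming from the newest retained snapshot via autoload.
snapshot = extensions.snapshot(
    filename='snapshot_iter_{.updater.iteration}',
    autoload=True)                # load the latest matching file, if any
trainer = testing.get_trainer_with_mock_updater()
trainer.out = '/tmp/snapshots'    # illustrative output directory
snapshot.initialize(trainer)      # restores trainer state when a file exists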
Example #7
def test_trigger_sparse_call(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
    accumulated = False
    for expected, finished in zip(self.expected, self.finished):
        trainer.updater.update()
        accumulated = accumulated or expected
        if random.randrange(2):
            self.assertEqual(trigger(trainer), accumulated)
            self.assertEqual(trigger.finished, finished)
            accumulated = False
Example #8
def setUp(self):
    self.optimizer = mock.MagicMock()
    self.optimizer.param_groups = [{'x': None}]
    self.extension = extensions.MultistepShift(
        'x', self.gamma, self.step_value, self.init, self.optimizer)

    self.interval = 1
    self.expect = [e for e in self.expect for _ in range(self.interval)]
    self.trigger = util.get_trigger((self.interval, 'iteration'))

    self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
    self.trainer.updater.get_optimizer.return_value = self.optimizer
Example #9
def setUp(self):
    self.optimizer = mock.MagicMock()
    self.optimizer.param_groups = [{'x': None}]
    self.extension = extensions.LinearShift('x', self.value_range,
                                            self.time_range,
                                            self.optimizer)

    self.interval = 2
    self.trigger = training.get_trigger((self.interval, 'iteration'))

    self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
    self.trainer.updater.get_optimizer.return_value = self.optimizer
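For context, a sketch of what LinearShift computes; the concrete numbers are assumptions for illustration:

# Sketch: LinearShift interpolates the optimizer attribute 'x' linearly,
# here (assumed) from 1.0 down to 0.1 between iterations 10 and 100.
optimizer = mock.MagicMock()
optimizer.param_groups = [{'x': None}]
extension = extensions.LinearShift('x', (1.0, 0.1), (10, 100), optimizer)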
Example #10
def test_trigger_sparse_call(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.IntervalTrigger(*self.interval)
    accumulated = False
    # before the first iteration, trigger should be False
    for expected in [False] + self.expected:
        accumulated = accumulated or expected
        if random.randrange(2):
            self.assertEqual(trigger(trainer), accumulated)
            accumulated = False
        trainer.updater.update()
Example #11
def setUp(self):
    self.optimizer = mock.MagicMock()
    self.optimizer.param_groups = [{'x': None}]
    self.extension = extensions.PolynomialShift('x', self.rate,
                                                self.max_count, self.init,
                                                self.target,
                                                self.optimizer)

    self.interval = 4
    self.expect = [e for e in self.expect for _ in range(self.interval)]
    self.trigger = util.get_trigger((self.interval, 'iteration'))

    self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
    self.trainer.updater.get_optimizer.return_value = self.optimizer
Example #12
def test_resumed_trigger(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    with tempfile.NamedTemporaryFile(delete=False) as f:
        trigger = training.triggers.IntervalTrigger(*self.interval)
        for expected in self.expected[:self.resume]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
        torch.save(trigger.state_dict(), f.name)

        trigger = training.triggers.IntervalTrigger(*self.interval)
        trigger.load_state_dict(torch.load(f.name))
        for expected in self.expected[self.resume:]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
Example #13
def test_trigger_sparse_call(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.OnceTrigger(self.call_on_resume)
    accumulated = False
    accumulated_finished = True
    for expected, finished in zip(self.expected, self.finished):
        accumulated = accumulated or expected
        accumulated_finished = accumulated_finished and finished
        if random.randrange(2):
            self.assertEqual(trigger.finished, accumulated_finished)
            self.assertEqual(trigger(trainer), accumulated)
            accumulated = False
            accumulated_finished = True
        trainer.updater.update()
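A short sketch of the OnceTrigger semantics exercised above; the call_on_resume value is an illustrative assumption:

# Sketch: OnceTrigger fires exactly once, then reports finished.
trigger = training.triggers.OnceTrigger(False)   # call_on_resume=False
trainer = testing.get_trainer_with_mock_updater()
trainer.updater.update()
print(trigger.finished)   # False: it has not fired yet
print(trigger(trainer))   # True: the single firing
print(trigger.finished)   # True from now on
print(trigger(trainer))   # False on every later call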
Example #14
def test_resumed_trigger_backward_compat(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    with tempfile.NamedTemporaryFile(delete=False) as f:
        trigger = training.triggers.IntervalTrigger(*self.interval)
        for expected in self.expected[:self.resume]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
        # old version does not save anything
        torch.save(dict(dummy=0), f.name)

        trigger = training.triggers.IntervalTrigger(*self.interval)
        with testing.assert_warns(UserWarning):
            trigger.load_state_dict(torch.load(f.name))
        for expected in self.expected[self.resume:]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
Example #15
def test_resume(self):
    new_optimizer = mock.Mock()
    new_optimizer.param_groups = [{'x': None}]
    new_extension = extensions.LinearShift('x', self.value_range,
                                           self.time_range, new_optimizer)

    self.trainer.extend(self.extension)
    self.trainer.run()

    new_trainer = testing.get_trainer_with_mock_updater((5, 'iteration'))
    new_trainer.extend(new_extension)
    testing.save_and_load_pth(self.trainer, new_trainer)

    new_extension.initialize(new_trainer)
    self.assertEqual(new_optimizer.param_groups[0]['x'],
                     self.optimizer.param_groups[0]['x'])
    self.assertIsInstance(new_optimizer.param_groups[0]['x'], float)
Example #16
def test_resume(self):
    new_optimizer = mock.Mock()
    new_optimizer.param_groups = [{'x': None}]
    new_extension = extensions.InverseShift('x', self.gamma, self.power,
                                            self.init, self.target,
                                            new_optimizer)

    self.trainer.extend(self.extension)
    self.trainer.run()

    new_trainer = testing.get_trainer_with_mock_updater((3, 'iteration'))
    new_trainer.extend(new_extension)
    testing.save_and_load_pth(self.trainer, new_trainer)

    new_extension.initialize(new_trainer)
    self.assertEqual(new_optimizer.param_groups[0]['x'],
                     self.optimizer.param_groups[0]['x'])
    self.assertIsInstance(new_optimizer.param_groups[0]['x'], float)
Example #17
def _setup(self, stream=None, delete_flush=False):
    self.logreport = mock.MagicMock(spec=extensions.LogReport(
        ['epoch'], trigger=(1, 'iteration'), log_name=None))
    if stream is None:
        self.stream = mock.MagicMock()
        if delete_flush:
            del self.stream.flush
    else:
        self.stream = stream
    self.report = extensions.PrintReport(['epoch'],
                                         log_report=self.logreport,
                                         out=self.stream)

    self.trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=(1, 'iteration'))
    self.trainer.extend(self.logreport)
    self.trainer.extend(self.report)
    self.logreport.log = [{'epoch': 0}]
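A minimal non-mock sketch of the LogReport to PrintReport pipeline this fixture wires up, using the same implicit imports as the surrounding examples (entry names are illustrative):

# Sketch: PrintReport renders entries aggregated by a LogReport.
log_report = extensions.LogReport(['epoch'], trigger=(1, 'iteration'))
report = extensions.PrintReport(['epoch'], log_report=log_report)
trainer = testing.get_trainer_with_mock_updater(stop_trigger=(1, 'iteration'))
trainer.extend(log_report)
trainer.extend(report)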
Example #18
def test_resumed_trigger(self):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    with tempfile.NamedTemporaryFile(delete=False) as f:
        trigger = training.triggers.OnceTrigger(self.call_on_resume)
        for expected, finished in zip(self.resumed_expected[:self.resume],
                                      self.resumed_finished[:self.resume]):
            trainer.updater.update()
            self.assertEqual(trigger.finished, finished)
            self.assertEqual(trigger(trainer), expected)
        torch.save(trigger.state_dict(), f.name)

        trigger = training.triggers.OnceTrigger(self.call_on_resume)
        trigger.load_state_dict(torch.load(f.name))
        for expected, finished in zip(self.resumed_expected[self.resume:],
                                      self.resumed_finished[self.resume:]):
            trainer.updater.update()
            self.assertEqual(trigger.finished, finished)
            self.assertEqual(trigger(trainer), expected)
Example #19
def _test_trigger(self, trigger, key, accuracies, expected,
                  resume=None, save=None):
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=(len(accuracies), 'iteration'),
        iter_per_epoch=self.iter_per_epoch)
    updater = trainer.updater

    # The mock updater has no state_dict of its own, so stub in a minimal
    # implementation to make the save/resume cycle work.
    def _state_dict_updater():
        return {
            'iteration': updater.iteration,
            'epoch': updater.epoch,
            'is_new_epoch': updater.is_new_epoch
        }
    trainer.updater.state_dict = _state_dict_updater

    def _load_state_dict_updater(state_dict):
        updater.iteration = state_dict['iteration']
        updater.epoch = state_dict['epoch']
        updater.is_new_epoch = state_dict['is_new_epoch']
    trainer.updater.load_state_dict = _load_state_dict_updater

    # Feed one accuracy value per iteration as the trainer observation.
    def set_observation(t):
        t.observation = {key: accuracies[t.updater.iteration - 1]}
    trainer.extend(set_observation, name='set_observation',
                   trigger=(1, 'iteration'), priority=2)

    # Record the iterations at which the trigger under test fires.
    invoked_iterations = []

    def record(t):
        invoked_iterations.append(t.updater.iteration)
    trainer.extend(record, name='record', trigger=trigger, priority=1)

    if resume is not None:
        trainer.load_state_dict(torch.load(resume))

    trainer.run()
    self.assertEqual(invoked_iterations, expected)

    if save is not None:
        torch.save(trainer.state_dict(), save)
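An illustrative call to this helper; the trigger type, observation key, and values are assumptions:

# Sketch: exercising the helper with a best-value style trigger.
self._test_trigger(
    training.triggers.MaxValueTrigger('validation/accuracy',
                                      trigger=(1, 'iteration')),
    'validation/accuracy',
    accuracies=[0.1, 0.5, 0.3, 0.7],
    expected=[1, 2, 4])   # fires whenever a new maximum is observed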
Example #20
def _create_mock_trainer(self, iterations):
    trainer = testing.get_trainer_with_mock_updater(
        (iterations, 'iteration'))
    trainer.updater.update_core = lambda: time.sleep(0.001)
    return trainer
Example #21
def setUp(self):
    self.trainer = testing.get_trainer_with_mock_updater(
        self.stop_trigger, self.iter_per_epoch,
        extensions=self.extensions)
Example #22
def setUp(self):
    self.trainer = testing.get_trainer_with_mock_updater()
    self.trainer.out = '.'
    self.filename = 'myfile-deadbeef.dat'
Example #23
def setUp(self):
    self.trainer = testing.get_trainer_with_mock_updater()
    self.trainer.out = '.'
    self.trainer._done = True