def test_resumed_trigger_sparse_call(self):
    """ManualScheduleTrigger resumes correctly when queried sparsely.

    The trigger is consulted only on a random subset of iterations. Between
    queries the expected firings are OR-accumulated, so a query must return
    True if the schedule was hit on any iteration since the previous query.
    After ``self.resume`` iterations the trigger's ``state_dict`` is saved,
    loaded into a fresh trigger, and the remaining iterations are checked
    against that resumed trigger.
    """
    import os  # local: only needed to build the temporary checkpoint path

    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    accumulated = False
    # A temporary directory is cleaned up automatically. The previous
    # NamedTemporaryFile(delete=False) leaked one file per run, and its name
    # cannot be reopened while the file is still open on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'trigger.pth')

        trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
        for expected, finished in zip(self.expected[:self.resume],
                                      self.finished[:self.resume]):
            trainer.updater.update()
            accumulated = accumulated or expected
            if random.randrange(2):
                self.assertEqual(trigger(trainer), accumulated)
                self.assertEqual(trigger.finished, finished)
                accumulated = False
        torch.save(trigger.state_dict(), path)

        # Continue with a brand-new trigger restored from the checkpoint.
        trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
        trigger.load_state_dict(torch.load(path))
        for expected, finished in zip(self.expected[self.resume:],
                                      self.finished[self.resume:]):
            trainer.updater.update()
            accumulated = accumulated or expected
            if random.randrange(2):
                self.assertEqual(trigger(trainer), accumulated)
                self.assertEqual(trigger.finished, finished)
                accumulated = False
def test_trigger(self):
    """IntervalTrigger fires exactly on the expected iterations.

    The trigger is queried once before any update (it must not fire then)
    and once after every subsequent ``update()``.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.IntervalTrigger(*self.interval)
    # Leading False: nothing has run yet, so the first query must not fire.
    outcomes = [False] + self.expected
    for should_fire in outcomes:
        self.assertEqual(trigger(trainer), should_fire)
        trainer.updater.update()
def test_raise_error_if_call_not_implemented(self):
    """Calling an Extension subclass that does not override ``__call__``
    raises NotImplementedError."""

    class MyExtension(training.Extension):
        pass

    extension = MyExtension()
    mock_trainer = testing.get_trainer_with_mock_updater()
    with pytest.raises(NotImplementedError):
        extension(mock_trainer)
def test_trigger(self):
    """OnceTrigger reports ``finished`` and fires as expected per iteration.

    ``finished`` is checked before the trigger is called at each step, and
    the updater is advanced afterwards.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.OnceTrigger(self.call_on_resume)
    for want_fire, want_finished in zip(self.expected, self.finished):
        self.assertEqual(trigger.finished, want_finished)
        fired = trigger(trainer)
        self.assertEqual(fired, want_fire)
        trainer.updater.update()
def test_trigger(self):
    """ManualScheduleTrigger fires and finishes on the scheduled iterations.

    After every ``update()`` the trigger is called once; its return value
    and its ``finished`` flag are compared with the expected values.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
    for want_fire, want_finished in zip(self.expected, self.finished):
        trainer.updater.update()
        fired = trigger(trainer)
        self.assertEqual(fired, want_fire)
        self.assertEqual(trigger.finished, want_finished)
def test_remove_stale_snapshots(self):
    """``snapshot(n_retains=3)`` keeps only the newest three snapshots.

    The mock trainer runs 10 iterations and writes ``snapshot_iter_<n>`` on
    each one. Only iterations 8, 9 and 10 may remain in ``trainer.out``.
    A second trainer pointed at the same directory must then initialize an
    autoloading snapshot extension without raising.
    """
    fmt = 'snapshot_iter_{.updater.iteration}'
    retain = 3
    snapshot = extensions.snapshot(
        filename=fmt, n_retains=retain, autoload=False)
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = self.path
    trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)

    class TimeStampUpdater:
        # Start 100 s in the past and advance 1 s per snapshot. Each file
        # therefore gets a distinct, increasing mtime, even on filesystems
        # whose timestamp resolution is coarse.
        t = time.time() - 100
        name = 'ts_updater'
        # Lower than the snapshot's priority (2): this must run after the
        # snapshot for the current iteration has been written.
        priority = 1

        def __call__(self, _trainer):
            written = os.path.join(_trainer.out, fmt.format(_trainer))
            self.t += 1
            os.utime(written, (self.t, self.t))

    trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
    trainer.run()
    assert 10 == trainer.updater.iteration
    assert trainer._done

    snapshot_glob = os.path.join(trainer.out, "snapshot_iter_*")
    found = sorted(os.path.basename(p) for p in glob.glob(snapshot_glob))
    assert len(found) == retain
    # Only the three newest snapshots (iterations 8, 9, 10) should survive.
    expected = sorted('snapshot_iter_{}'.format(i) for i in range(8, 11))
    assert found == expected

    trainer2 = testing.get_trainer_with_mock_updater()
    trainer2.out = self.path
    assert not trainer2._done
    snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
    # Smoke check: initializing (which autoloads) must not raise.
    snapshot2.initialize(trainer2)
def test_trigger_sparse_call(self):
    """ManualScheduleTrigger stays correct when queried only occasionally.

    The trigger is consulted on a random subset of iterations. A query
    must return True if the schedule was hit on any iteration since the
    previous query, and ``finished`` must match the current iteration.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.ManualScheduleTrigger(*self.schedule)
    pending = False
    for hit, done in zip(self.expected, self.finished):
        trainer.updater.update()
        pending = pending or hit
        if not random.randrange(2):
            continue  # skip querying on this iteration
        self.assertEqual(trigger(trainer), pending)
        self.assertEqual(trigger.finished, done)
        pending = False
def setUp(self):
    """Build a MultistepShift over a mocked optimizer and a mock trainer.

    The optimizer is a MagicMock with a single param group whose ``'x'``
    entry is the value the extension shifts. ``self.expect`` is expanded so
    that each expected value appears once per trigger interval.
    """
    optimizer = mock.MagicMock()
    optimizer.param_groups = [{'x': None}]
    self.optimizer = optimizer
    self.extension = extensions.MultistepShift(
        'x', self.gamma, self.step_value, self.init, optimizer)
    self.interval = 1
    per_iteration = []
    for value in self.expect:
        per_iteration.extend([value] * self.interval)
    self.expect = per_iteration
    self.trigger = util.get_trigger((self.interval, 'iteration'))
    trainer = testing.get_trainer_with_mock_updater(self.trigger)
    trainer.updater.get_optimizer.return_value = optimizer
    self.trainer = trainer
def setUp(self):
    """Build a LinearShift over a mocked optimizer and a mock trainer.

    The trainer's stop trigger fires every ``self.interval`` (2) iterations,
    and its updater hands back the mocked optimizer.
    """
    optimizer = mock.MagicMock()
    optimizer.param_groups = [{'x': None}]
    self.optimizer = optimizer
    self.extension = extensions.LinearShift(
        'x', self.value_range, self.time_range, optimizer)
    self.interval = 2
    self.trigger = training.get_trigger((self.interval, 'iteration'))
    trainer = testing.get_trainer_with_mock_updater(self.trigger)
    trainer.updater.get_optimizer.return_value = optimizer
    self.trainer = trainer
def test_trigger_sparse_call(self):
    """IntervalTrigger stays correct when queried only occasionally.

    Expected firings are OR-accumulated between randomly chosen queries.
    The first step happens before any update, where nothing may fire.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.IntervalTrigger(*self.interval)
    pending = False
    outcomes = [False] + self.expected
    for hit in outcomes:
        pending = pending or hit
        query_now = random.randrange(2)
        if query_now:
            self.assertEqual(trigger(trainer), pending)
            pending = False
        trainer.updater.update()
def setUp(self):
    """Build a PolynomialShift over a mocked optimizer and a mock trainer.

    ``self.expect`` is expanded so that each expected value is repeated
    ``self.interval`` (4) times, once per iteration.
    """
    optimizer = mock.MagicMock()
    optimizer.param_groups = [{'x': None}]
    self.optimizer = optimizer
    self.extension = extensions.PolynomialShift(
        'x', self.rate, self.max_count, self.init, self.target, optimizer)
    self.interval = 4
    per_iteration = []
    for value in self.expect:
        per_iteration.extend([value] * self.interval)
    self.expect = per_iteration
    self.trigger = util.get_trigger((self.interval, 'iteration'))
    trainer = testing.get_trainer_with_mock_updater(self.trigger)
    trainer.updater.get_optimizer.return_value = optimizer
    self.trainer = trainer
def test_resumed_trigger(self):
    """IntervalTrigger continues its firing pattern after save/load.

    The trigger's ``state_dict`` is saved after ``self.resume`` iterations
    and loaded into a fresh trigger. That trigger must produce the remaining
    expected values.
    """
    import os  # local: only needed to build the temporary checkpoint path

    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    # TemporaryDirectory removes the checkpoint afterwards.
    # NamedTemporaryFile(delete=False) leaked it, and its name cannot be
    # reopened while the file is still open on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'trigger.pth')

        trigger = training.triggers.IntervalTrigger(*self.interval)
        for expected in self.expected[:self.resume]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
        torch.save(trigger.state_dict(), path)

        trigger = training.triggers.IntervalTrigger(*self.interval)
        trigger.load_state_dict(torch.load(path))
        for expected in self.expected[self.resume:]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
def test_trigger_sparse_call(self):
    """OnceTrigger stays correct when queried only occasionally.

    Between randomly chosen queries, expected firings are OR-accumulated and
    expected ``finished`` values are AND-accumulated. At a query,
    ``finished`` is checked before the trigger is called. The updater
    advances once per step regardless.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    trigger = training.triggers.OnceTrigger(self.call_on_resume)
    fired_since_check = False
    finished_since_check = True
    for want_fire, want_finished in zip(self.expected, self.finished):
        fired_since_check = fired_since_check or want_fire
        finished_since_check = finished_since_check and want_finished
        query_now = random.randrange(2)
        if query_now:
            self.assertEqual(trigger.finished, finished_since_check)
            self.assertEqual(trigger(trainer), fired_since_check)
            fired_since_check, finished_since_check = False, True
        trainer.updater.update()
def test_resumed_trigger_backward_compat(self):
    """IntervalTrigger tolerates a checkpoint that carries no trigger state.

    The checkpoint saved is a dummy dict, mimicking an old version that did
    not store trigger state. Loading it must emit a UserWarning, and the
    fresh trigger must still produce the remaining expected values.
    """
    import os  # local: only needed to build the temporary checkpoint path

    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    # TemporaryDirectory removes the checkpoint afterwards.
    # NamedTemporaryFile(delete=False) leaked it, and its name cannot be
    # reopened while the file is still open on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'trigger.pth')

        trigger = training.triggers.IntervalTrigger(*self.interval)
        for expected in self.expected[:self.resume]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
        # old version does not save anything
        torch.save(dict(dummy=0), path)

        trigger = training.triggers.IntervalTrigger(*self.interval)
        with testing.assert_warns(UserWarning):
            trigger.load_state_dict(torch.load(path))
        for expected in self.expected[self.resume:]:
            trainer.updater.update()
            self.assertEqual(trigger(trainer), expected)
def test_resume(self):
    """A fresh LinearShift restored from a trained trainer's checkpoint
    sets its optimizer's ``'x'`` to the same float value as the original."""
    self.trainer.extend(self.extension)
    self.trainer.run()

    restored_optimizer = mock.Mock()
    restored_optimizer.param_groups = [{'x': None}]
    restored_extension = extensions.LinearShift(
        'x', self.value_range, self.time_range, restored_optimizer)
    restored_trainer = testing.get_trainer_with_mock_updater(
        (5, 'iteration'))
    restored_trainer.extend(restored_extension)
    testing.save_and_load_pth(self.trainer, restored_trainer)
    restored_extension.initialize(restored_trainer)

    restored_value = restored_optimizer.param_groups[0]['x']
    self.assertEqual(restored_value, self.optimizer.param_groups[0]['x'])
    self.assertIsInstance(restored_value, float)
def test_resume(self):
    """A fresh InverseShift restored from a trained trainer's checkpoint
    sets its optimizer's ``'x'`` to the same float value as the original."""
    self.trainer.extend(self.extension)
    self.trainer.run()

    restored_optimizer = mock.Mock()
    restored_optimizer.param_groups = [{'x': None}]
    restored_extension = extensions.InverseShift(
        'x', self.gamma, self.power, self.init, self.target,
        restored_optimizer)
    restored_trainer = testing.get_trainer_with_mock_updater(
        (3, 'iteration'))
    restored_trainer.extend(restored_extension)
    testing.save_and_load_pth(self.trainer, restored_trainer)
    restored_extension.initialize(restored_trainer)

    restored_value = restored_optimizer.param_groups[0]['x']
    self.assertEqual(restored_value, self.optimizer.param_groups[0]['x'])
    self.assertIsInstance(restored_value, float)
def _setup(self, stream=None, delete_flush=False):
    """Wire a PrintReport (over a spec'd LogReport mock) into a mock trainer.

    Args:
        stream: Output stream for PrintReport. When None, a MagicMock is
            used instead.
        delete_flush: Only used when ``stream`` is None. If true, ``flush``
            is deleted from the mock stream, so accessing it raises
            AttributeError.
    """
    self.logreport = mock.MagicMock(spec=extensions.LogReport(
        ['epoch'], trigger=(1, 'iteration'), log_name=None))
    if stream is not None:
        self.stream = stream
    else:
        self.stream = mock.MagicMock()
        if delete_flush:
            del self.stream.flush
    self.report = extensions.PrintReport(
        ['epoch'], log_report=self.logreport, out=self.stream)

    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=(1, 'iteration'))
    self.trainer = trainer
    trainer.extend(self.logreport)
    trainer.extend(self.report)
    self.logreport.log = [{'epoch': 0}]
def test_resumed_trigger(self):
    """OnceTrigger keeps its ``finished``/firing state across save/load.

    After ``self.resume`` iterations the trigger's ``state_dict`` is saved
    and loaded into a fresh OnceTrigger. The remaining iterations are
    checked against ``self.resumed_expected`` / ``self.resumed_finished``.
    """
    import os  # local: only needed to build the temporary checkpoint path

    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=None, iter_per_epoch=self.iter_per_epoch)
    # TemporaryDirectory removes the checkpoint afterwards.
    # NamedTemporaryFile(delete=False) leaked it, and its name cannot be
    # reopened while the file is still open on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'trigger.pth')

        trigger = training.triggers.OnceTrigger(self.call_on_resume)
        for expected, finished in zip(self.resumed_expected[:self.resume],
                                      self.resumed_finished[:self.resume]):
            trainer.updater.update()
            self.assertEqual(trigger.finished, finished)
            self.assertEqual(trigger(trainer), expected)
        torch.save(trigger.state_dict(), path)

        trigger = training.triggers.OnceTrigger(self.call_on_resume)
        trigger.load_state_dict(torch.load(path))
        for expected, finished in zip(self.resumed_expected[self.resume:],
                                      self.resumed_finished[self.resume:]):
            trainer.updater.update()
            self.assertEqual(trigger.finished, finished)
            self.assertEqual(trigger(trainer), expected)
def _test_trigger(self, trigger, key, accuracies, expected, resume=None,
                  save=None):
    """Run a mock trainer and check on which iterations ``trigger`` fires.

    Args:
        trigger: Trigger under test, attached to a recording extension.
        key: Observation key that receives each iteration's value.
        accuracies: Value reported at each iteration. Its length also sets
            the stop trigger (that many iterations).
        expected: Iterations on which ``trigger`` is expected to fire.
        resume: Optional checkpoint path loaded into the trainer before
            running.
        save: Optional path the trainer's state is written to afterwards.
    """
    trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=(len(accuracies), 'iteration'),
        iter_per_epoch=self.iter_per_epoch)
    mock_updater = trainer.updater

    # Give the mock updater a state_dict/load_state_dict pair that round-trips
    # its iteration counters. Presumably the mock does not persist them on
    # its own; verify against testing.get_trainer_with_mock_updater.
    counter_names = ('iteration', 'epoch', 'is_new_epoch')

    def dump_counters():
        return {name: getattr(mock_updater, name) for name in counter_names}

    def restore_counters(state_dict):
        for name in counter_names:
            setattr(mock_updater, name, state_dict[name])

    mock_updater.state_dict = dump_counters
    mock_updater.load_state_dict = restore_counters

    def report_value(t):
        # Index by iteration - 1: the iteration counter read here is
        # used as 1-based.
        t.observation = {key: accuracies[t.updater.iteration - 1]}

    trainer.extend(report_value, name='set_observation',
                   trigger=(1, 'iteration'), priority=2)

    fired_at = []

    def remember_iteration(t):
        fired_at.append(t.updater.iteration)

    trainer.extend(remember_iteration, name='record', trigger=trigger,
                   priority=1)

    if resume is not None:
        trainer.load_state_dict(torch.load(resume))

    trainer.run()
    self.assertEqual(fired_at, expected)

    if save is not None:
        torch.save(trainer.state_dict(), save)
def _create_mock_trainer(self, iterations):
    """Return a mock trainer that stops after ``iterations`` iterations.

    Each update sleeps 1 ms, so iterations take a small but non-zero time.
    """
    trainer = testing.get_trainer_with_mock_updater(
        (iterations, 'iteration'))

    def _sleepy_update():
        time.sleep(0.001)

    trainer.updater.update_core = _sleepy_update
    return trainer
def setUp(self):
    """Create a mock trainer from the parameterized stop trigger,
    iterations-per-epoch and extensions."""
    self.trainer = testing.get_trainer_with_mock_updater(
        stop_trigger=self.stop_trigger,
        iter_per_epoch=self.iter_per_epoch,
        extensions=self.extensions,
    )
def setUp(self):
    """Create a default mock trainer writing to the current directory and
    fix the file name used by the tests."""
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    self.trainer = trainer
    self.filename = 'myfile-deadbeef.dat'
def setUp(self):
    """Create a default mock trainer that writes to the current directory
    and is already flagged as done."""
    trainer = testing.get_trainer_with_mock_updater()
    trainer.out = '.'
    trainer._done = True
    self.trainer = trainer