def setUp(self):
    """Build a fresh mock optimizer/trainer pair and derive the run sizes."""
    self.optimizer = OptimizerMock()
    self.trainer = Trainer(ModelMock(), CriterionMock(), self.optimizer,
                           DatasetMock())
    self.num_epochs = 3
    self.dataset_size = len(self.trainer.dataset)
    # Total iterations the full run will perform.
    self.num_iters = self.dataset_size * self.num_epochs
class TestTrainer(TestCase):
    """Exercise Trainer plugin registration, plugin scheduling, optimizer
    stepping, model/criterion call accounting, and gradient accumulation,
    all against lightweight mocks."""

    # Plugin firing schedules under test: each entry is a list of
    # (frequency, unit) pairs, meaning "fire every `frequency` `unit`s".
    intervals = [
        [(1, 'iteration')],
        [(1, 'epoch')],
        [(1, 'batch')],
        [(1, 'update')],
        [(5, 'iteration')],
        [(5, 'epoch')],
        [(5, 'batch')],
        [(5, 'update')],
        [(1, 'iteration'), (1, 'epoch')],
        [(5, 'update'), (1, 'iteration')],
        [(2, 'epoch'), (1, 'batch')],
    ]

    def setUp(self):
        """Build a fresh mock optimizer/trainer pair and derive run sizes."""
        self.optimizer = OptimizerMock()
        self.trainer = Trainer(ModelMock(), CriterionMock(), self.optimizer,
                               DatasetMock())
        self.num_epochs = 3
        self.dataset_size = len(self.trainer.dataset)
        self.num_iters = self.num_epochs * self.dataset_size

    def test_register_plugin(self):
        """Registering a plugin must hand it a back-reference to the trainer."""
        for interval in self.intervals:
            simple_plugin = SimplePlugin(interval)
            self.trainer.register_plugin(simple_plugin)
            self.assertEqual(simple_plugin.trainer, self.trainer)

    def test_optimizer_step(self):
        """One epoch performs exactly one optimizer step per dataset batch."""
        self.trainer.run(epochs=1)
        # One step per batch; use the measured dataset size instead of a
        # hard-coded count so the test stays in sync with DatasetMock.
        self.assertEqual(self.trainer.optimizer.num_steps, self.dataset_size)

    def test_plugin_interval(self):
        """Each (freq, unit) schedule triggers floor(events / freq) calls;
        units absent from the schedule must never fire."""
        # How many times each trigger unit occurs over the full run.
        expected_triggers = {
            'iteration': self.num_iters,
            'epoch': self.num_epochs,
            'batch': self.num_iters,
            'update': self.num_iters,
        }
        for interval in self.intervals:
            self.setUp()  # fresh trainer and counters for every schedule
            simple_plugin = SimplePlugin(interval)
            self.trainer.register_plugin(simple_plugin)
            self.trainer.run(epochs=self.num_epochs)
            for unit, num_triggers in expected_triggers.items():
                # Frequency this schedule requested for `unit`, if any.
                call_every = next(
                    (freq for freq, freq_unit in interval if freq_unit == unit),
                    None)
                if call_every is not None:
                    # Frequencies are positive ints, so floor division is
                    # exactly math.floor(num_triggers / call_every).
                    expected_num_calls = num_triggers // call_every
                else:
                    expected_num_calls = 0
                num_calls = getattr(simple_plugin, 'num_' + unit)
                # NOTE(review): the trailing 0 is `precision` under torch's
                # legacy common.TestCase but `msg` under stdlib unittest --
                # confirm which base class is imported.
                self.assertEqual(num_calls, expected_num_calls, 0)

    def test_model_called(self):
        """Model and criterion are evaluated in lock-step, once per
        optimizer evaluation, within the mock's min/max eval bounds."""
        self.trainer.run(epochs=self.num_epochs)
        num_model_calls = self.trainer.model.num_calls
        num_crit_calls = self.trainer.criterion.num_calls
        self.assertEqual(num_model_calls, num_crit_calls)
        # Bounds are loop-invariant; compute them once.
        lower_bound = OptimizerMock.min_evals * self.num_iters
        upper_bound = OptimizerMock.max_evals * self.num_iters
        for num_calls in (num_model_calls, num_crit_calls):
            self.assertEqual(num_calls, self.trainer.optimizer.num_evals)
            self.assertLessEqual(lower_bound, num_calls)
            self.assertLessEqual(num_calls, upper_bound)

    def test_model_gradient(self):
        """Backprop accumulates a gradient of 2 per optimizer evaluation
        into the model's saved output variable."""
        self.trainer.run(epochs=self.num_epochs)
        output_var = self.trainer.model.output
        expected_grad = torch.ones(1, 1) * 2 * self.optimizer.num_evals
        self.assertEqual(output_var.grad.data, expected_grad)