def test_resume_from_multiple_of_epoches(self, mock_evaluator, mock_checkpoint, mock_func):
    """Resuming training exactly on an epoch boundary must work.

    ``step`` is the number of batches already consumed when training
    resumes at ``start_epoch``. Here it is set to 7, which is presumably a
    whole number of epochs for the fixture. Like
    ``test_batch_num_when_resuming``, the test expects ``mock_func`` to be
    called ``steps_per_epoch - step`` times. At an exact epoch boundary
    that is zero.
    """
    trainer = SupervisedTrainer()
    trainer.model = mock.Mock()
    trainer.optimizer = mock.Mock()
    callbacks = CallbackContainer(trainer)

    n_epoches = 1
    start_epoch = 1
    # NOTE(review): assumes len(self.data_iterator) == 7 so that this step
    # sits exactly on an epoch boundary, as the test name says; confirm
    # against the fixture.
    step = 7
    steps_per_epoch = len(self.data_iterator)

    # NOTE(review): the meaning of these positional arguments is not visible
    # here; they mirror test_batch_num_when_resuming.
    trainer.set_local_parameters(123, [], [], [], 1, 1)
    trainer._train_epoches(
        self.data_iterator, n_epoches, start_epoch, step, callbacks,
        self.data_iterator)

    # Same post-condition as test_batch_num_when_resuming. mock_func is
    # presumably the patched per-batch training call (decorators not shown).
    self.assertEqual(steps_per_epoch - step, mock_func.call_count)
def test_batch_num_when_resuming(self, mock_evaluator, mock_checkpoint, mock_func):
    """When resuming mid-epoch, only the batches not yet seen are processed.

    Training resumes at epoch 1 after 3 batches have already been consumed.
    ``mock_func`` (presumably the patched per-batch call; the patch
    decorators are not shown here) must then be called once for each
    remaining batch of the epoch.
    """
    trainer = SupervisedTrainer()
    trainer.model = mock.Mock()
    trainer.optimizer = mock.Mock()
    container = CallbackContainer(trainer)

    epochs_total = 1
    first_epoch = 1
    batches_per_epoch = len(self.data_iterator)
    resume_step = 3

    # NOTE(review): the meaning of these positional arguments is not visible
    # here; they are passed through unchanged.
    trainer.set_local_parameters(123, [], [], [], 1, 1)

    train_data = self.data_iterator
    dev_data = self.data_iterator
    trainer._train_epoches(
        train_data, epochs_total, first_epoch, resume_step, container,
        dev_data)

    remaining_batches = batches_per_epoch - resume_step
    self.assertEqual(remaining_batches, mock_func.call_count)