Beispiel #1
0
    def test_resume_from_multiple_of_epoches(self, mock_evaluator, mock_checkpoint, mock_func):
        """Smoke-test ``_train_epoches`` when resuming with ``step = 7``.

        The test makes no assertions: it passes as long as the resumed
        training call completes without raising.

        NOTE(review): the three ``mock_*`` arguments are presumably injected
        by ``mock.patch`` decorators outside this view -- confirm.
        """
        trainer = SupervisedTrainer()
        # Replace the model and optimizer with mocks so the training loop
        # runs against stand-ins.
        trainer.model = mock.Mock()
        trainer.optimizer = mock.Mock()

        callbacks = CallbackContainer(trainer)

        n_epoches = 1
        start_epoch = 1
        # NOTE(review): judging by the test name, 7 is meant to be a multiple
        # of the number of batches per epoch -- confirm against the fixture.
        step = 7
        trainer.set_local_parameters(123, [], [], [], 1, 1)
        # The same iterator is passed as the first and the last argument.
        trainer._train_epoches(
            self.data_iterator, n_epoches, start_epoch, step, callbacks, self.data_iterator)
Beispiel #2
0
    def test_batch_num_when_resuming(self, mock_evaluator, mock_checkpoint, mock_func):
        """Check the number of ``mock_func`` calls when resuming at step 3.

        After running ``_train_epoches`` for one epoch starting from step 3,
        ``mock_func.call_count`` must equal ``len(self.data_iterator) - 3``.

        NOTE(review): ``mock_func`` presumably patches the per-batch training
        routine (the patch decorators are outside this view) -- confirm.
        """
        resume_step = 3

        trainer = SupervisedTrainer()
        trainer.model = mock.Mock()
        trainer.optimizer = mock.Mock()
        callback_container = CallbackContainer(trainer)

        # Measure the iterator length before training touches the iterator.
        batches_in_epoch = len(self.data_iterator)
        expected_calls = batches_in_epoch - resume_step

        trainer.set_local_parameters(123, [], [], [], 1, 1)
        # Positional arguments: iterator, n_epoches=1, start_epoch=1, step,
        # callbacks, and the same iterator again as the final argument.
        trainer._train_epoches(
            self.data_iterator,
            1,
            1,
            resume_step,
            callback_container,
            self.data_iterator,
        )

        self.assertEqual(expected_calls, mock_func.call_count)