Exemple #1
0
    def _test_fitting(self, params, logs, has_valid=True):
        """Verify the epoch logs and the hook calls recorded by ``self.mock_callback``.

        Args:
            params: Mapping providing ``'epochs'`` and ``'steps'``; it is also the
                argument expected in the recorded ``set_params`` call.
            logs: Per-epoch log dicts, one entry per epoch.
            has_valid: When true, ``val_``-prefixed metrics and ``val_loss`` are
                expected in the epoch logs in addition to the training entries.
        """
        self.assertEqual(len(logs), params['epochs'])

        # The loss value is not known in advance, so it is matched with ANY.
        expected_train = {**dict(zip(self.metrics_names, self.metrics_values)), 'loss': ANY}
        expected_epoch = expected_train
        if has_valid:
            expected_valid = {
                'val_' + name: value
                for name, value in zip(self.metrics_names, self.metrics_values)
            }
            expected_valid['val_loss'] = ANY
            expected_epoch = {**expected_train, **expected_valid}

        # Epochs are numbered from 1 in the logs.
        for epoch_number, epoch_log in enumerate(logs, 1):
            self.assertEqual(epoch_log, dict(expected_epoch, epoch=epoch_number))

        expected_calls = [call.on_train_begin({})]
        for epoch_number in range(1, params['epochs'] + 1):
            expected_calls.append(call.on_epoch_begin(epoch_number, {}))
            for batch_number in range(1, params['steps'] + 1):
                expected_calls += [
                    call.on_batch_begin(batch_number, {}),
                    call.on_backward_end(batch_number),
                    call.on_batch_end(batch_number, {'batch': batch_number, 'size': ANY, **expected_train}),
                ]
            expected_calls.append(call.on_epoch_end(epoch_number, {'epoch': epoch_number, **expected_epoch}))
        expected_calls.append(call.on_train_end({}))

        recorded = self.mock_callback.method_calls
        # The first two recorded calls are set_model / set_params, in either order.
        setup_calls = recorded[:2]
        self.assertIn(call.set_model(self.model), setup_calls)
        self.assertIn(call.set_params(params), setup_calls)

        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)
    def test_givenAFasttextAddressParser_whenRetrainWithConfigWithCallbacksNewTags_thenCallbackAreUse(
        self,
    ):
        """Retrain a fasttext parser with new prediction tags and check the callback hooks fire."""
        parser = AddressParser(
            model_type=self.a_fasttext_model_type,
            device=self.a_cpu_device,
            verbose=self.verbose,
        )
        spy_callback = MagicMock(spec=Callback)

        performance_after_training = parser.retrain(
            self.new_prediction_data_container,
            self.a_train_ratio,
            epochs=self.a_single_epoch,
            batch_size=self.a_batch_size,
            num_workers=self.a_number_of_workers,
            learning_rate=self.a_learning_rate,
            callbacks=[spy_callback],
            logging_path=self.a_checkpoints_saving_dir,
            prediction_tags=self.with_new_prediction_tags,
        )

        self.assertIsNotNone(performance_after_training)

        # Both the start-of-training and end-of-training hooks must have been recorded.
        for expected_hook in (call.on_train_begin({}), call.on_train_end({})):
            spy_callback.assert_has_calls([expected_hook])

        # NOTE(review): per unittest.mock semantics this only verifies that the mock
        # object itself was never invoked as a callable; hook invocations made
        # through its attributes are not counted here.
        spy_callback.assert_not_called()
Exemple #3
0
    def test_epoch_delay(self):
        """Check that a ``DelayCallback`` with ``epoch_delay=4`` only forwards hooks for later epochs."""
        epoch_delay = 4
        delayed = DelayCallback(self.mock_callback, epoch_delay=epoch_delay)
        train_data = some_data_generator(DelayCallbackTest.batch_size)
        valid_data = some_data_generator(DelayCallbackTest.batch_size)
        self.model.fit_generator(
            train_data,
            valid_data,
            epochs=DelayCallbackTest.epochs,
            steps_per_epoch=DelayCallbackTest.steps_per_epoch,
            validation_steps=DelayCallbackTest.steps_per_epoch,
            callbacks=[delayed],
        )
        params = {'epochs': DelayCallbackTest.epochs, 'steps': DelayCallbackTest.steps_per_epoch}

        # Only epochs after the delay are expected to reach the wrapped callback.
        expected_calls = [call.on_train_begin({})]
        for epoch_number in range(epoch_delay + 1, DelayCallbackTest.epochs + 1):
            expected_calls.append(call.on_epoch_begin(epoch_number, {}))
            for batch_number in range(1, params['steps'] + 1):
                batch_logs = {
                    'batch': batch_number,
                    'size': DelayCallbackTest.batch_size,
                    **self.train_dict
                }
                expected_calls += [
                    call.on_train_batch_begin(batch_number, {}),
                    call.on_backward_end(batch_number),
                    call.on_train_batch_end(batch_number, batch_logs),
                ]
            expected_calls.append(call.on_epoch_end(epoch_number, {'epoch': epoch_number, **self.log_dict}))
        expected_calls.append(call.on_train_end({}))

        recorded = self.mock_callback.method_calls
        # The first two recorded calls are set_model / set_params, in either order.
        setup_calls = recorded[:2]
        self.assertIn(call.set_model(self.model), setup_calls)
        self.assertIn(call.set_params(params), setup_calls)

        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)
Exemple #4
0
    def _test_callbacks_train(self, params, logs, has_valid=True, steps=None):
        # pylint: disable=too-many-arguments
        """Verify the epoch logs and the training hooks recorded by ``self.mock_callback``.

        Args:
            params: Mapping providing ``'epochs'`` and ``'steps'``; it is also the
                argument expected in the recorded ``set_params`` call.
            logs: Per-epoch log dicts, one entry per epoch.
            has_valid: When true, ``val_``-prefixed batch/epoch metrics and
                ``val_loss`` are expected in the epoch logs too.
            steps: Number of training batches per epoch; falls back to
                ``params['steps']`` when ``None``.
        """
        batches_per_epoch = params['steps'] if steps is None else steps
        self.assertEqual(len(logs), params['epochs'])

        # Loss and timing values are not predictable, so they are matched with ANY.
        expected_batch = {
            **dict(zip(self.batch_metrics_names, self.batch_metrics_values)),
            'loss': ANY,
            'time': ANY,
        }
        expected_epoch = {
            **expected_batch,
            **dict(zip(self.epoch_metrics_names, self.epoch_metrics_values)),
        }
        if has_valid:
            expected_valid_batch = {
                'val_' + name: value
                for name, value in zip(self.batch_metrics_names, self.batch_metrics_values)
            }
            expected_valid_batch['val_loss'] = ANY
            expected_valid_epoch = {
                'val_' + name: value
                for name, value in zip(self.epoch_metrics_names, self.epoch_metrics_values)
            }
            expected_epoch.update(expected_valid_batch)
            expected_epoch.update(expected_valid_epoch)

        # Epochs are numbered from 1 in the logs.
        for epoch_number, epoch_log in enumerate(logs, 1):
            self.assertEqual(epoch_log, dict(expected_epoch, epoch=epoch_number))

        expected_calls = [call.on_train_begin({})]
        for epoch_number in range(1, params['epochs'] + 1):
            expected_calls.append(call.on_epoch_begin(epoch_number, {}))
            for batch_number in range(1, batches_per_epoch + 1):
                expected_calls += [
                    call.on_train_batch_begin(batch_number, {}),
                    call.on_backward_end(batch_number),
                    call.on_train_batch_end(batch_number, {'batch': batch_number, 'size': ANY, **expected_batch}),
                ]
            expected_calls.append(call.on_epoch_end(epoch_number, {'epoch': epoch_number, **expected_epoch}))
        expected_calls.append(call.on_train_end({}))

        recorded = self.mock_callback.method_calls
        # The first two recorded calls are set_model / set_params, in either order.
        setup_calls = recorded[:2]
        self.assertIn(call.set_model(self.model), setup_calls)
        self.assertIn(call.set_params(params), setup_calls)

        # The two extra recorded calls are set_model and set_params.
        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)
Exemple #5
0
    def test_givenABPEmbAddressParser_whenRetrainWithConfigWithCallbacks_thenCallbackAreUse(
            self):
        """Retrain a BPEmb parser with a callback and check the train begin/end hooks fire."""
        parser = AddressParser(
            model_type=self.a_bpemb_model_type,
            device=self.a_torch_device,
            verbose=self.verbose,
        )
        spy_callback = MagicMock(spec=Callback)

        performance_after_training = parser.retrain(
            self.training_container,
            self.a_train_ratio,
            epochs=self.a_single_epoch,
            batch_size=self.a_batch_size,
            num_workers=self.a_number_of_workers,
            learning_rate=self.a_learning_rate,
            callbacks=[spy_callback],
            logging_path=self.a_checkpoints_saving_dir,
        )

        self.assertIsNotNone(performance_after_training)

        # Both the start-of-training and end-of-training hooks must have been recorded.
        for expected_hook in (call.on_train_begin({}), call.on_train_end({})):
            spy_callback.assert_has_calls([expected_hook])

        # NOTE(review): per unittest.mock semantics this only verifies that the mock
        # object itself was never invoked as a callable; hook invocations made
        # through its attributes are not counted here.
        spy_callback.assert_not_called()
Exemple #6
0
def test_trainer_callback_system(torch_save):
    """Check the exact, ordered hooks a callback receives from ``Trainer``.

    Three stages are verified in turn: trainer construction, ``fit`` and ``test``.
    """
    model = BoringModel()
    callback_mock = MagicMock()

    trainer_options = {
        'callbacks': [callback_mock],
        'max_epochs': 1,
        'limit_val_batches': 1,
        'limit_train_batches': 3,
        'limit_test_batches': 2,
        'progress_bar_refresh_rate': 0,
    }

    # Nothing may have touched the callback before a trainer exists.
    callback_mock.assert_not_called()

    trainer = Trainer(**trainer_options)

    # Building the trainer must trigger exactly the two init hooks.
    init_hooks = [call.on_init_start(trainer), call.on_init_end(trainer)]
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == init_hooks

    trainer.fit(model)

    # One validation pass over a single batch; expected for the sanity check and
    # again after the training batches.
    validation_hooks = [
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
    ]
    # The same six hooks are expected for each of the three training batches.
    train_batch_hooks = [
        hook
        for batch_idx in range(3)
        for hook in (
            call.on_batch_start(trainer, model),
            call.on_train_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_after_backward(trainer, model),
            call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
            call.on_batch_end(trainer, model),
            call.on_train_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
        )
    ]
    expected_fit_hooks = [
        *init_hooks,
        call.setup(trainer, model, 'fit'),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        *validation_hooks,
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        *train_batch_hooks,
        *validation_hooks,
        call.on_save_checkpoint(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]
    assert callback_mock.method_calls == expected_fit_hooks

    # Start from a clean call history with a fresh trainer for the test stage.
    callback_mock.reset_mock()
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    # A start/end pair is expected for each of the two test batches.
    test_batch_hooks = [
        hook
        for batch_idx in range(2)
        for hook in (
            call.on_test_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_test_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
        )
    ]
    expected_test_hooks = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'test'),
        call.on_fit_start(trainer, model),
        call.on_test_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
        *test_batch_hooks,
        call.on_test_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
        call.teardown(trainer, model, 'test'),
    ]
    assert callback_mock.method_calls == expected_test_hooks
Exemple #7
0
def test_trainer_callback_hook_system_fit(_, tmpdir):
    """Check the exact, ordered hooks a callback receives during ``Trainer.fit``."""
    model = BoringModel()
    callback_mock = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=3,
        progress_bar_refresh_rate=0,
    )

    # Building the trainer must trigger exactly the two init hooks.
    init_hooks = [call.on_init_start(trainer), call.on_init_end(trainer)]
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == init_hooks

    trainer.fit(model)

    # One validation pass over a single batch; expected for the sanity check and
    # again after the training epoch.
    validation_hooks = [
        call.on_validation_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
    ]
    # The same six hooks are expected for each of the three training batches.
    train_batch_hooks = [
        hook
        for batch_idx in range(3)
        for hook in (
            call.on_batch_start(trainer, model),
            call.on_train_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
            call.on_after_backward(trainer, model),
            call.on_train_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
            call.on_batch_end(trainer, model),
        )
    ]
    expected_fit_hooks = [
        *init_hooks,
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'fit'),
        call.on_configure_sharded_model(trainer, model),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        *validation_hooks,
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        *train_batch_hooks,
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_epoch_end(trainer, model),
        *validation_hooks,
        # should take ANY but we are inspecting signature for BC
        call.on_save_checkpoint(trainer, model),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]
    assert callback_mock.method_calls == expected_fit_hooks
Exemple #8
0
    def _get_callback_expected_on_calls_when_training(self,
                                                      params,
                                                      logs,
                                                      has_valid=True,
                                                      steps=None,
                                                      valid_steps=10):
        # pylint: disable=too-many-arguments
        """Check the epoch logs and build the hook calls expected for a training run.

        Args:
            params: Mapping providing ``'epochs'`` (and ``'steps'`` when ``steps``
                is ``None``).
            logs: Per-epoch log dicts; entry ``epoch - 1`` is the payload expected
                in the ``on_epoch_end`` call of that epoch.
            has_valid: When true, validation hooks are expected after the training
                batches of every epoch and ``val_`` metrics in the epoch logs.
            steps: Number of training batches per epoch; falls back to
                ``params['steps']`` when ``None``.
            valid_steps: Number of validation batches per epoch.

        Returns:
            The ordered list of ``call`` objects, from ``on_train_begin`` to
            ``on_train_end``.
        """
        train_steps = params['steps'] if steps is None else steps

        # Loss and timing values are not predictable, so they are matched with ANY.
        expected_batch = {
            **dict(zip(self.batch_metrics_names, self.batch_metrics_values)),
            'time': ANY,
            'loss': ANY,
        }
        expected_epoch = {
            **expected_batch,
            **dict(zip(self.epoch_metrics_names, self.epoch_metrics_values)),
        }
        if has_valid:
            expected_valid_batch = {
                'val_' + name: value
                for name, value in zip(self.batch_metrics_names, self.batch_metrics_values)
            }
            expected_valid_batch['val_loss'] = ANY
            expected_valid_epoch = {
                'val_' + name: value
                for name, value in zip(self.epoch_metrics_names, self.epoch_metrics_values)
            }
            expected_epoch.update(expected_valid_batch)
            expected_epoch.update(expected_valid_epoch)

        # Epochs are numbered from 1 in the logs.
        for epoch_number, epoch_log in enumerate(logs, 1):
            self.assertEqual(epoch_log, dict(expected_epoch, epoch=epoch_number))

        expected_calls = [call.on_train_begin({})]
        for epoch_number in range(1, params['epochs'] + 1):
            expected_calls.append(call.on_epoch_begin(epoch_number, {}))

            for batch_number in range(1, train_steps + 1):
                train_batch_logs = {'batch': batch_number, 'size': ANY, **expected_batch}
                expected_calls += [
                    call.on_train_batch_begin(batch_number, {}),
                    call.on_backward_end(batch_number),
                    call.on_train_batch_end(batch_number, train_batch_logs),
                ]

            if has_valid:
                expected_calls.append(call.on_valid_begin({}))
                for batch_number in range(1, valid_steps + 1):
                    valid_batch_logs = {
                        'batch': batch_number,
                        'size': ANY,
                        'time': ANY,
                        **expected_valid_batch
                    }
                    expected_calls += [
                        call.on_valid_batch_begin(batch_number, {}),
                        call.on_valid_batch_end(batch_number, valid_batch_logs),
                    ]
                valid_end_logs = {'time': ANY, **expected_valid_batch, **expected_valid_epoch}
                expected_calls.append(call.on_valid_end(valid_end_logs))

            # The epoch-end payload is the log entry itself, not the expected pattern.
            expected_calls.append(call.on_epoch_end(epoch_number, logs[epoch_number - 1]))

        expected_calls.append(call.on_train_end({}))
        return expected_calls