Beispiel #1
0
def test_trainer_callback_hook_system_test(tmpdir):
    """Check the exact sequence of callback hooks fired by ``Trainer.test``.

    A ``MagicMock`` is registered as the sole callback, the trainer runs
    ``test`` over two batches, and the mock's recorded ``method_calls``
    must equal the expected hook sequence, order included.
    """
    boring_model = BoringModel()
    hook_spy = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[hook_spy],
        max_epochs=1,
        limit_test_batches=2,
        progress_bar_refresh_rate=0,
    )

    trainer.test(boring_model)

    # Two test batches (limit_test_batches=2), each bracketed by start/end hooks.
    batch_hooks = []
    for batch_idx in (0, 1):
        batch_hooks.append(
            call.on_test_batch_start(trainer, boring_model, ANY, batch_idx, 0))
        batch_hooks.append(
            call.on_test_batch_end(trainer, boring_model, ANY, ANY, batch_idx, 0))

    expected_hooks = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, boring_model),
        call.setup(trainer, boring_model, 'test'),
        call.on_configure_sharded_model(trainer, boring_model),
        call.on_test_start(trainer, boring_model),
        call.on_epoch_start(trainer, boring_model),
        call.on_test_epoch_start(trainer, boring_model),
        *batch_hooks,
        call.on_test_epoch_end(trainer, boring_model),
        call.on_epoch_end(trainer, boring_model),
        call.on_test_end(trainer, boring_model),
        call.teardown(trainer, boring_model, 'test'),
    ]
    assert hook_spy.method_calls == expected_hooks
Beispiel #2
0
    def test_givenABPEmbAddressParser_whenTestWithConfigWithCallbacks_thenCallbackAreUse(
            self):
        """Testing a trained parser (``self.a_bpemb_model_type``) notifies callbacks.

        The callback must receive ``on_test_begin({})``, and then
        ``on_test_end`` with logs whose loss/accuracy match the values
        returned by ``AddressParser.test``.
        """
        parser = AddressParser(model_type=self.a_bpemb_model_type,
                               device=self.a_torch_device,
                               verbose=self.verbose)
        self.training(parser)

        spy_callback = MagicMock()
        results = parser.test(self.test_container,
                              batch_size=self.a_batch_size,
                              num_workers=self.a_number_of_workers,
                              callbacks=[spy_callback],
                              logging_path=self.a_checkpoints_saving_dir)

        self.assertIsNotNone(results)

        spy_callback.assert_has_calls([call.on_test_begin({})])

        expected_end_logs = {
            "time": ANY,
            "test_loss": results["test_loss"],
            "test_accuracy": results["test_accuracy"],
        }
        spy_callback.assert_has_calls([call.on_test_end(expected_end_logs)])
        # assert_not_called only counts calls to the mock object itself, not to
        # its child methods, so it does not contradict the checks above: the
        # callback object must never be invoked directly.
        spy_callback.assert_not_called()
Beispiel #3
0
    def _test_callbacks_test(self, params, result_log):
        test_batch_dict = dict(zip(self.batch_metrics_names,
                                   self.batch_metrics_values),
                               loss=ANY,
                               time=ANY)

        call_list = []
        call_list.append(call.on_test_begin({}))
        for batch in range(1, params['batch'] + 1):
            call_list.append(call.on_test_batch_begin(batch, {}))
            call_list.append(
                call.on_test_batch_end(batch, {
                    'batch': batch,
                    'size': ANY,
                    **test_batch_dict
                }))
        call_list.append(call.on_test_end(result_log))

        method_calls = self.mock_callback.method_calls
        self.assertEqual(call.set_model(self.model),
                         method_calls[0])  # skip set_model

        self.assertEqual(len(method_calls),
                         len(call_list) + 1)  # for set_model
        self.assertEqual(method_calls[1:], call_list)
Beispiel #4
0
    def _get_callback_expected_on_calls_when_testing(self, params):
        test_batch_dict = {"time": ANY, "test_loss": ANY}
        test_batch_dict.update({
            "test_" + metric_name: metric
            for metric_name, metric in zip(self.batch_metrics_names,
                                           self.batch_metrics_values)
        })

        call_list = []
        call_list.append(call.on_test_begin({}))
        for batch in range(1, params['steps'] + 1):
            call_list.append(call.on_test_batch_begin(batch, {}))
            call_list.append(
                call.on_test_batch_end(batch, {
                    'batch': batch,
                    'size': ANY,
                    **test_batch_dict
                }))

        test_batch_dict.update({
            "test_" + metric_name: metric
            for metric_name, metric in zip(self.epoch_metrics_names,
                                           self.epoch_metrics_values)
        })
        call_list.append(
            call.on_test_end({
                "time": ANY,
                "test_loss": ANY,
                **test_batch_dict
            }))
        return call_list
Beispiel #5
0
    def test_givenAFasttextAddressParser_whenTestWithConfigWithCallbacks_thenCallbackAreUse(
        self,
    ):
        """Testing a trained parser (``self.a_fasttext_model_type``) notifies callbacks.

        The callback must receive ``on_test_begin({})``, and then
        ``on_test_end`` with logs whose loss/accuracy match the values
        returned by ``AddressParser.test``.
        """
        parser = AddressParser(
            model_type=self.a_fasttext_model_type,
            device=self.a_cpu_device,
            verbose=self.verbose,
        )
        self.training(parser, self.training_container, self.a_number_of_workers)

        spy_callback = MagicMock()
        results = parser.test(
            self.test_container,
            batch_size=self.a_batch_size,
            num_workers=self.a_number_of_workers,
            callbacks=[spy_callback],
        )

        self.assertIsNotNone(results)

        spy_callback.assert_has_calls([call.on_test_begin({})])

        expected_end_logs = {
            "time": ANY,
            "test_loss": results["test_loss"],
            "test_accuracy": results["test_accuracy"],
        }
        spy_callback.assert_has_calls([call.on_test_end(expected_end_logs)])
        # assert_not_called only counts calls to the mock object itself, not to
        # its child methods: the callback object must never be invoked directly.
        spy_callback.assert_not_called()
Beispiel #6
0
    def _test_callbacks_test(self, params):
        test_batch_dict = {"time": ANY, "test_loss": ANY}
        test_batch_dict.update({
            "test_" + metric_name: metric
            for metric_name, metric in zip(self.batch_metrics_names,
                                           self.batch_metrics_values)
        })

        call_list = []
        call_list.append(call.on_test_begin({}))
        for batch in range(1, params['steps'] + 1):
            call_list.append(call.on_test_batch_begin(batch, {}))
            call_list.append(
                call.on_test_batch_end(batch, {
                    'batch': batch,
                    'size': ANY,
                    **test_batch_dict
                }))

        test_batch_dict.update({
            "test_" + metric_name: metric
            for metric_name, metric in zip(self.epoch_metrics_names,
                                           self.epoch_metrics_values)
        })
        call_list.append(
            call.on_test_end({
                "time": ANY,
                "test_loss": ANY,
                **test_batch_dict
            }))

        method_calls = self.mock_callback.method_calls
        self.assertEqual(call.set_model(self.model),
                         method_calls[0])  # skip set_model and set param call
        self.assertEqual(call.set_params(params), method_calls[1])

        self.assertEqual(len(method_calls),
                         len(call_list) + 2)  # for set_model and set param
        self.assertEqual(method_calls[2:], call_list)
Beispiel #7
0
def test_trainer_callback_system(torch_save):
    """Test the callback system.

    Registers a ``MagicMock`` as the only callback and checks, through its
    recorded ``method_calls``, the exact order of hooks fired by the
    ``Trainer`` constructor, by ``trainer.fit`` (1 epoch, 3 train batches,
    1 val batch), and by ``trainer.test`` on a fresh trainer (2 test batches).

    Args:
        torch_save: not referenced in the visible body; presumably a pytest
            fixture requested for its side effects — NOTE(review): confirm.
    """

    model = BoringModel()

    callback_mock = MagicMock()

    # Shared options so the fit run and the test run use identical settings.
    trainer_options = dict(
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=3,
        limit_test_batches=2,
        progress_bar_refresh_rate=0,
    )

    # no call yet
    callback_mock.assert_not_called()

    # fit model
    trainer = Trainer(**trainer_options)

    # check that only the two init hooks have been called so far
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    trainer.fit(model)

    # Full expected hook order for fit: a sanity-check validation pass over one
    # batch, three training batches, one validation batch, then the
    # checkpoint save, epoch/train/fit end hooks and the 'fit' teardown.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'fit'),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 1, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 2, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_save_checkpoint(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]

    # Clear the recorded calls and build a fresh trainer so that only the
    # hooks fired by the constructor and by trainer.test are checked below.
    callback_mock.reset_mock()
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    # NOTE(review): this expected list pins on_fit_start/on_fit_end and a 'fit'
    # teardown during trainer.test, alongside the 'test' teardown; it records
    # current behavior rather than asserting that it is the intended one.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'test'),
        call.on_fit_start(trainer, model),
        call.on_test_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
        call.on_test_batch_start(trainer, model, ANY, 0, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_test_batch_start(trainer, model, ANY, 1, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_test_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
        call.teardown(trainer, model, 'test'),
    ]