Ejemplo n.º 1
0
def test_trainer_callback_hook_system_test(tmpdir):
    """Check the order of callback hooks recorded during ``trainer.test``.

    A ``MagicMock`` is registered as the trainer's only callback, the test
    loop is run for two batches, and the mock's recorded ``method_calls``
    are compared against the full expected hook sequence.

    Args:
        tmpdir: directory handed to the trainer as ``default_root_dir``
            (presumably the pytest ``tmpdir`` fixture -- confirm).
    """
    boring_model = BoringModel()
    recorder = MagicMock()
    trainer_kwargs = {
        'default_root_dir': tmpdir,
        'callbacks': [recorder],
        'max_epochs': 1,
        'limit_test_batches': 2,
        'progress_bar_refresh_rate': 0,
    }
    trainer = Trainer(**trainer_kwargs)

    trainer.test(boring_model)

    # Hooks expected before the first batch is processed.
    expected_calls = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, boring_model),
        call.setup(trainer, boring_model, 'test'),
        call.on_configure_sharded_model(trainer, boring_model),
        call.on_test_start(trainer, boring_model),
        call.on_epoch_start(trainer, boring_model),
        call.on_test_epoch_start(trainer, boring_model),
    ]

    # One start/end pair per test batch (indices 0 and 1). The trailing 0 is
    # presumably the dataloader index -- confirm against the hook signature.
    for batch_idx in range(2):
        expected_calls.append(
            call.on_test_batch_start(trainer, boring_model, ANY, batch_idx, 0))
        expected_calls.append(
            call.on_test_batch_end(trainer, boring_model, ANY, ANY, batch_idx, 0))

    # Hooks expected once all batches are done.
    expected_calls.extend([
        call.on_test_epoch_end(trainer, boring_model),
        call.on_epoch_end(trainer, boring_model),
        call.on_test_end(trainer, boring_model),
        call.teardown(trainer, boring_model, 'test'),
    ])

    assert recorder.method_calls == expected_calls
Ejemplo n.º 2
0
    def _get_callback_expected_on_calls_when_testing(self, params):
        """Build the ordered list of callback calls expected from a test run.

        Args:
            params: mapping whose ``'steps'`` entry gives the number of test
                batches; batches are numbered from 1 to ``params['steps']``.

        Returns:
            A list of ``unittest.mock.call`` objects: one ``on_test_begin``,
            then an ``on_test_batch_begin``/``on_test_batch_end`` pair per
            batch, then one ``on_test_end``. Values that cannot be predicted
            (``time``, ``test_loss``, ``size``) are matched with ``ANY``.
        """
        # Logs shared by every batch: wildcards plus the "test_"-prefixed
        # batch metrics stored on the instance.
        per_batch_logs = {"time": ANY, "test_loss": ANY}
        for name, value in zip(self.batch_metrics_names,
                               self.batch_metrics_values):
            per_batch_logs["test_" + name] = value

        expected_calls = [call.on_test_begin({})]
        for step in range(1, params['steps'] + 1):
            # Fresh dict per batch, so adding epoch metrics below cannot leak
            # into the per-batch expectations.
            step_logs = {'batch': step, 'size': ANY}
            step_logs.update(per_batch_logs)
            expected_calls.append(call.on_test_batch_begin(step, {}))
            expected_calls.append(call.on_test_batch_end(step, step_logs))

        # The final logs carry the batch metrics and the epoch metrics.
        final_logs = {"time": ANY, "test_loss": ANY}
        final_logs.update(per_batch_logs)
        for name, value in zip(self.epoch_metrics_names,
                               self.epoch_metrics_values):
            final_logs["test_" + name] = value
        expected_calls.append(call.on_test_end(final_logs))

        return expected_calls
Ejemplo n.º 3
0
    def _test_callbacks_test(self, params, result_log):
        """Assert that the mock callback recorded the expected test-run calls.

        Args:
            params: mapping whose ``'batch'`` entry gives the number of test
                batches; batches are numbered from 1 to ``params['batch']``.
            result_log: the exact argument expected in the ``on_test_end``
                call.

        The first recorded call must be ``set_model(self.model)``; every call
        after it must match the expected sequence, with ``loss``, ``time`` and
        ``size`` matched by ``ANY``.
        """
        # Batch metrics stored on the instance, plus wildcard entries.
        batch_logs = dict(zip(self.batch_metrics_names,
                              self.batch_metrics_values))
        batch_logs.update(loss=ANY, time=ANY)

        expected_calls = [call.on_test_begin({})]
        for step in range(1, params['batch'] + 1):
            step_logs = {'batch': step, 'size': ANY}
            step_logs.update(batch_logs)
            expected_calls.append(call.on_test_batch_begin(step, {}))
            expected_calls.append(call.on_test_batch_end(step, step_logs))
        expected_calls.append(call.on_test_end(result_log))

        recorded = self.mock_callback.method_calls

        # The leading set_model call is checked separately from the rest.
        self.assertEqual(call.set_model(self.model), recorded[0])

        # One extra recorded call accounts for set_model.
        self.assertEqual(len(recorded), len(expected_calls) + 1)
        self.assertEqual(recorded[1:], expected_calls)
Ejemplo n.º 4
0
    def _test_callbacks_test(self, params):
        """Assert that the mock callback recorded the expected test-run calls.

        Args:
            params: mapping whose ``'steps'`` entry gives the number of test
                batches; batches are numbered from 1 to ``params['steps']``.
                The same object is expected as the ``set_params`` argument.

        The first two recorded calls must be ``set_model(self.model)`` and
        ``set_params(params)``; every call after them must match the expected
        sequence, with ``time``, ``test_loss`` and ``size`` matched by ``ANY``.
        """
        # Logs shared by every batch: wildcards plus the "test_"-prefixed
        # batch metrics stored on the instance.
        per_batch_logs = {"time": ANY, "test_loss": ANY}
        for name, value in zip(self.batch_metrics_names,
                               self.batch_metrics_values):
            per_batch_logs["test_" + name] = value

        expected_calls = [call.on_test_begin({})]
        for step in range(1, params['steps'] + 1):
            # Fresh dict per batch, so adding epoch metrics below cannot leak
            # into the per-batch expectations.
            step_logs = {'batch': step, 'size': ANY}
            step_logs.update(per_batch_logs)
            expected_calls.append(call.on_test_batch_begin(step, {}))
            expected_calls.append(call.on_test_batch_end(step, step_logs))

        # The final logs carry the batch metrics and the epoch metrics.
        final_logs = {"time": ANY, "test_loss": ANY}
        final_logs.update(per_batch_logs)
        for name, value in zip(self.epoch_metrics_names,
                               self.epoch_metrics_values):
            final_logs["test_" + name] = value
        expected_calls.append(call.on_test_end(final_logs))

        recorded = self.mock_callback.method_calls

        # The two leading bookkeeping calls are checked separately.
        self.assertEqual(call.set_model(self.model), recorded[0])
        self.assertEqual(call.set_params(params), recorded[1])

        # Two extra recorded calls account for set_model and set_params.
        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)
Ejemplo n.º 5
0
def test_trainer_callback_system(torch_save):
    """Check the order of callback hooks recorded during ``fit`` and ``test``.

    A ``MagicMock`` is registered as the only callback. Its recorded
    ``method_calls`` are compared against a hard-coded hook sequence at three
    points: right after the trainer is constructed, after ``trainer.fit``,
    and (with the mock reset and a new trainer) after ``trainer.test``.

    Args:
        torch_save: not referenced in the body; presumably a pytest fixture
            requested for its side effect -- confirm.
    """
    boring_model = BoringModel()
    recorder = MagicMock()

    trainer_kwargs = {
        'callbacks': [recorder],
        'max_epochs': 1,
        'limit_val_batches': 1,
        'limit_train_batches': 3,
        'limit_test_batches': 2,
        'progress_bar_refresh_rate': 0,
    }

    # The mock must be untouched before any trainer exists.
    recorder.assert_not_called()

    trainer = Trainer(**trainer_kwargs)

    # Construction alone should record only the two init hooks.
    assert trainer.callbacks[0] == recorder
    assert recorder.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    trainer.fit(boring_model)

    # A validation pass over a single batch; the fit sequence contains it
    # twice (once inside the sanity check, once after the training batches).
    validation_pass = [
        call.on_validation_start(trainer, boring_model),
        call.on_validation_epoch_start(trainer, boring_model),
        call.on_validation_batch_start(trainer, boring_model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, boring_model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, boring_model),
        call.on_validation_end(trainer, boring_model),
    ]

    fit_expected = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, boring_model, 'fit'),
        call.on_fit_start(trainer, boring_model),
        call.on_pretrain_routine_start(trainer, boring_model),
        call.on_pretrain_routine_end(trainer, boring_model),
        call.on_sanity_check_start(trainer, boring_model),
    ]
    fit_expected += validation_pass
    fit_expected += [
        call.on_sanity_check_end(trainer, boring_model),
        call.on_train_start(trainer, boring_model),
        call.on_epoch_start(trainer, boring_model),
        call.on_train_epoch_start(trainer, boring_model),
    ]
    # Six hooks per training batch, for batch indices 0, 1 and 2.
    for batch_idx in range(3):
        fit_expected += [
            call.on_batch_start(trainer, boring_model),
            call.on_train_batch_start(trainer, boring_model, ANY, batch_idx, 0),
            call.on_after_backward(trainer, boring_model),
            call.on_before_zero_grad(trainer, boring_model, trainer.optimizers[0]),
            call.on_batch_end(trainer, boring_model),
            call.on_train_batch_end(trainer, boring_model, ANY, ANY, batch_idx, 0),
        ]
    fit_expected += validation_pass
    fit_expected += [
        call.on_save_checkpoint(trainer, boring_model),
        call.on_epoch_end(trainer, boring_model),
        call.on_train_epoch_end(trainer, boring_model, ANY),
        call.on_train_end(trainer, boring_model),
        call.on_fit_end(trainer, boring_model),
        call.teardown(trainer, boring_model, 'fit'),
    ]

    assert recorder.method_calls == fit_expected

    # Start from a clean record and a new trainer for the test phase.
    recorder.reset_mock()
    trainer = Trainer(**trainer_kwargs)
    trainer.test(boring_model)

    # NOTE(review): as written, the expected test sequence also contains
    # on_fit_start / on_fit_end and a 'fit' teardown -- confirm this matches
    # the library version under test.
    test_expected = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, boring_model, 'test'),
        call.on_fit_start(trainer, boring_model),
        call.on_test_start(trainer, boring_model),
        call.on_test_epoch_start(trainer, boring_model),
    ]
    # One start/end pair per test batch, for batch indices 0 and 1.
    for batch_idx in range(2):
        test_expected += [
            call.on_test_batch_start(trainer, boring_model, ANY, batch_idx, 0),
            call.on_test_batch_end(trainer, boring_model, ANY, ANY, batch_idx, 0),
        ]
    test_expected += [
        call.on_test_epoch_end(trainer, boring_model),
        call.on_test_end(trainer, boring_model),
        call.on_fit_end(trainer, boring_model),
        call.teardown(trainer, boring_model, 'fit'),
        call.teardown(trainer, boring_model, 'test'),
    ]

    assert recorder.method_calls == test_expected