def test_trainer_callback_hook_system_test(tmpdir):
    """Check which callback hooks fire, and in which order, during ``trainer.test``."""
    num_test_batches = 2

    model = BoringModel()
    callback_mock = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback_mock],
        max_epochs=1,
        limit_test_batches=num_test_batches,
        progress_bar_refresh_rate=0,
    )

    trainer.test(model)

    # Hooks expected before the first test batch is processed.
    expected_calls = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'test'),
        call.on_configure_sharded_model(trainer, model),
        call.on_test_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
    ]

    # One start/end pair per test batch; the batch payload and outputs are
    # matched with ANY, only the batch index and dataloader index are pinned.
    for batch_idx in range(num_test_batches):
        expected_calls.append(call.on_test_batch_start(trainer, model, ANY, batch_idx, 0))
        expected_calls.append(call.on_test_batch_end(trainer, model, ANY, ANY, batch_idx, 0))

    # Hooks expected once the last batch is done.
    expected_calls.extend([
        call.on_test_epoch_end(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.teardown(trainer, model, 'test'),
    ])

    assert callback_mock.method_calls == expected_calls
def _get_callback_expected_on_calls_when_testing(self, params):
    """
    Build the ordered list of ``mock.call`` objects expected for a test run.

    Args:
        params: Mapping whose ``'steps'`` entry is the number of test batches
            to generate calls for.

    Returns:
        A list starting with ``on_test_begin``, followed by an
        ``on_test_batch_begin`` / ``on_test_batch_end`` pair for every batch
        number from 1 to ``params['steps']``, and ending with ``on_test_end``.
    """

    def prefixed(names, values):
        # Metric names are reported with a "test_" prefix in the logs.
        return {"test_" + name: value for name, value in zip(names, values)}

    logs = {"time": ANY, "test_loss": ANY}
    logs.update(prefixed(self.batch_metrics_names, self.batch_metrics_values))

    expected = [call.on_test_begin({})]
    for step in range(1, params['steps'] + 1):
        batch_logs = {'batch': step, 'size': ANY}
        batch_logs.update(logs)
        expected.append(call.on_test_batch_begin(step, {}))
        expected.append(call.on_test_batch_end(step, batch_logs))

    # The final logs carry the epoch metrics on top of the batch metrics.
    logs.update(prefixed(self.epoch_metrics_names, self.epoch_metrics_values))
    expected.append(call.on_test_end(dict(logs)))
    return expected
def _test_callbacks_test(self, params, result_log):
    """
    Assert that the mocked callback recorded the expected test-phase calls.

    The expected sequence is ``set_model``, ``on_test_begin``, one
    ``on_test_batch_begin`` / ``on_test_batch_end`` pair per batch number from
    1 to ``params['batch']``, and finally ``on_test_end``.

    Args:
        params: Mapping whose ``'batch'`` entry is the number of test batches
            expected.
        result_log: The logs expected as the argument of ``on_test_end``.
    """
    test_batch_dict = dict(zip(self.batch_metrics_names, self.batch_metrics_values), loss=ANY, time=ANY)

    call_list = [call.on_test_begin({})]
    for batch in range(1, params['batch'] + 1):
        call_list.append(call.on_test_batch_begin(batch, {}))
        call_list.append(call.on_test_batch_end(batch, {'batch': batch, 'size': ANY, **test_batch_dict}))
    call_list.append(call.on_test_end(result_log))

    method_calls = self.mock_callback.method_calls
    # Check the count before indexing: an empty call log must be reported as
    # an assertion failure, not as an IndexError from method_calls[0].
    self.assertEqual(len(method_calls), len(call_list) + 1)  # for set_model
    self.assertEqual(call.set_model(self.model), method_calls[0])  # skip set_model
    self.assertEqual(method_calls[1:], call_list)
def _test_callbacks_test(self, params):
    """
    Assert that the mocked callback recorded the expected test-phase calls.

    The expected sequence is ``set_model``, ``set_params``, ``on_test_begin``,
    one ``on_test_batch_begin`` / ``on_test_batch_end`` pair per batch number
    from 1 to ``params['steps']``, and finally ``on_test_end``.

    Args:
        params: Mapping whose ``'steps'`` entry is the number of test batches
            expected; also the argument expected in the ``set_params`` call.
    """
    test_batch_dict = {"time": ANY, "test_loss": ANY}
    test_batch_dict.update({
        "test_" + metric_name: metric
        for metric_name, metric in zip(self.batch_metrics_names, self.batch_metrics_values)
    })

    call_list = [call.on_test_begin({})]
    for batch in range(1, params['steps'] + 1):
        call_list.append(call.on_test_batch_begin(batch, {}))
        call_list.append(call.on_test_batch_end(batch, {'batch': batch, 'size': ANY, **test_batch_dict}))

    # The final logs carry the epoch metrics on top of the batch metrics.
    # "time" and "test_loss" are already present in test_batch_dict.
    test_batch_dict.update({
        "test_" + metric_name: metric
        for metric_name, metric in zip(self.epoch_metrics_names, self.epoch_metrics_values)
    })
    call_list.append(call.on_test_end(dict(test_batch_dict)))

    method_calls = self.mock_callback.method_calls
    # Check the count before indexing: an empty or one-element call log must
    # be reported as an assertion failure, not as an IndexError.
    self.assertEqual(len(method_calls), len(call_list) + 2)  # for set_model and set param
    self.assertEqual(call.set_model(self.model), method_calls[0])  # skip set_model and set param call
    self.assertEqual(call.set_params(params), method_calls[1])
    self.assertEqual(method_calls[2:], call_list)
def test_trainer_callback_system(torch_save):
    """
    Check which callback hooks fire, and in which order, for ``fit`` and ``test``.

    ``torch_save`` is not used in the body; presumably a pytest fixture
    requested for its side effect -- confirm against the fixture definition.
    """
    num_train_batches = 3
    num_test_batches = 2

    model = BoringModel()
    callback_mock = MagicMock()
    trainer_options = {
        'callbacks': [callback_mock],
        'max_epochs': 1,
        'limit_val_batches': 1,
        'limit_train_batches': num_train_batches,
        'limit_test_batches': num_test_batches,
        'progress_bar_refresh_rate': 0,
    }

    # Nothing has touched the callback before a Trainer exists.
    callback_mock.assert_not_called()

    # --- fit -----------------------------------------------------------
    trainer = Trainer(**trainer_options)

    # Constructing the Trainer must trigger exactly the two init hooks.
    assert trainer.callbacks[0] == callback_mock
    init_calls = [call.on_init_start(trainer), call.on_init_end(trainer)]
    assert callback_mock.method_calls == init_calls

    trainer.fit(model)

    # A validation run over a single batch; expected once for the sanity
    # check and once more after the training batches.
    validation_calls = [
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
    ]

    expected_fit_calls = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'fit'),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
    ]
    expected_fit_calls += validation_calls
    expected_fit_calls += [
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
    ]
    for batch_idx in range(num_train_batches):
        expected_fit_calls += [
            call.on_batch_start(trainer, model),
            call.on_train_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_after_backward(trainer, model),
            call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
            call.on_batch_end(trainer, model),
            call.on_train_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
        ]
    expected_fit_calls += validation_calls
    expected_fit_calls += [
        call.on_save_checkpoint(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]
    assert callback_mock.method_calls == expected_fit_calls

    # --- test ----------------------------------------------------------
    # Start from a clean call log and a fresh Trainer.
    callback_mock.reset_mock()
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    expected_test_calls = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'test'),
        call.on_fit_start(trainer, model),
        call.on_test_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
    ]
    for batch_idx in range(num_test_batches):
        expected_test_calls += [
            call.on_test_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_test_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
        ]
    expected_test_calls += [
        call.on_test_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
        call.teardown(trainer, model, 'test'),
    ]
    assert callback_mock.method_calls == expected_test_calls