def _test_fitting(self, params, logs, has_valid=True):
    """Check the per-epoch logs and the exact callback-call sequence of a fit run.

    ``params`` must provide 'epochs' and 'steps'; ``logs`` is the list of epoch
    log dicts returned by the fit. Metric values are pinned exactly while
    loss/val_loss are matched with ``ANY``.
    """
    n_epochs, n_steps = params['epochs'], params['steps']
    self.assertEqual(len(logs), n_epochs)

    # Expected per-batch metrics; the loss value itself is not pinned down.
    batch_metrics = dict(zip(self.metrics_names, self.metrics_values), loss=ANY)
    expected_log = dict(batch_metrics)
    if has_valid:
        expected_log.update(
            dict(zip(('val_' + name for name in self.metrics_names), self.metrics_values),
                 val_loss=ANY))

    for epoch, log in enumerate(logs, 1):
        self.assertEqual(log, dict(expected_log, epoch=epoch))

    # Rebuild the full event sequence the callback should have received.
    expected_calls = [call.on_train_begin({})]
    for epoch in range(1, n_epochs + 1):
        expected_calls.append(call.on_epoch_begin(epoch, {}))
        for step in range(1, n_steps + 1):
            expected_calls += [
                call.on_batch_begin(step, {}),
                call.on_backward_end(step),
                call.on_batch_end(step, {'batch': step, 'size': ANY, **batch_metrics}),
            ]
        expected_calls.append(call.on_epoch_end(epoch, {'epoch': epoch, **expected_log}))
    expected_calls.append(call.on_train_end({}))

    actual_calls = self.mock_callback.method_calls
    # set_model and set_params are always delivered first, in either order.
    self.assertIn(call.set_model(self.model), actual_calls[:2])
    self.assertIn(call.set_params(params), actual_calls[:2])
    self.assertEqual(len(actual_calls), len(expected_calls) + 2)
    self.assertEqual(actual_calls[2:], expected_calls)
def test_givenAFasttextAddressParser_whenRetrainWithConfigWithCallbacksNewTags_thenCallbackAreUse(
    self,
):
    """Retraining a fasttext parser with new prediction tags must invoke the user callbacks."""
    parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    callback_mock = MagicMock(spec=Callback)
    training_performance = parser.retrain(
        self.new_prediction_data_container,
        self.a_train_ratio,
        epochs=self.a_single_epoch,
        batch_size=self.a_batch_size,
        num_workers=self.a_number_of_workers,
        learning_rate=self.a_learning_rate,
        callbacks=[callback_mock],
        logging_path=self.a_checkpoints_saving_dir,
        prediction_tags=self.with_new_prediction_tags,
    )

    self.assertIsNotNone(training_performance)

    # The training loop must have fired on_train_begin and on_train_end on the callback...
    callback_mock.assert_has_calls([call.on_train_begin({})])
    callback_mock.assert_has_calls([call.on_train_end({})])
    # ...while the callback object itself is never invoked directly (only its methods are).
    callback_mock.assert_not_called()
def test_epoch_delay(self):
    """With an epoch delay, the wrapped callback only sees events from epoch delay+1 onward."""
    epoch_delay = 4
    delay_callback = DelayCallback(self.mock_callback, epoch_delay=epoch_delay)
    train_generator = some_data_generator(DelayCallbackTest.batch_size)
    valid_generator = some_data_generator(DelayCallbackTest.batch_size)
    self.model.fit_generator(train_generator,
                             valid_generator,
                             epochs=DelayCallbackTest.epochs,
                             steps_per_epoch=DelayCallbackTest.steps_per_epoch,
                             validation_steps=DelayCallbackTest.steps_per_epoch,
                             callbacks=[delay_callback])

    params = {'epochs': DelayCallbackTest.epochs, 'steps': DelayCallbackTest.steps_per_epoch}

    # Expected sequence: the first `epoch_delay` epochs are suppressed entirely.
    expected_calls = [call.on_train_begin({})]
    for epoch in range(epoch_delay + 1, DelayCallbackTest.epochs + 1):
        expected_calls.append(call.on_epoch_begin(epoch, {}))
        for step in range(1, params['steps'] + 1):
            expected_calls += [
                call.on_train_batch_begin(step, {}),
                call.on_backward_end(step),
                call.on_train_batch_end(step, {
                    'batch': step,
                    'size': DelayCallbackTest.batch_size,
                    **self.train_dict
                }),
            ]
        expected_calls.append(call.on_epoch_end(epoch, {'epoch': epoch, **self.log_dict}))
    expected_calls.append(call.on_train_end({}))

    actual_calls = self.mock_callback.method_calls
    # set_model and set_params are always delivered first, in either order.
    self.assertIn(call.set_model(self.model), actual_calls[:2])
    self.assertIn(call.set_params(params), actual_calls[:2])
    self.assertEqual(len(actual_calls), len(expected_calls) + 2)
    self.assertEqual(actual_calls[2:], expected_calls)
def _test_callbacks_train(self, params, logs, has_valid=True, steps=None):  # pylint: disable=too-many-arguments
    """Check epoch logs and the exact callback-call sequence of a training run.

    Batch and epoch metrics are pinned exactly; loss/val_loss and timing values
    are matched with ``ANY``. ``steps`` overrides ``params['steps']`` when the
    number of executed batches differs from the declared one.
    """
    n_steps = params['steps'] if steps is None else steps
    self.assertEqual(len(logs), params['epochs'])

    # Per-batch metrics (with loose loss/time) and per-epoch metrics.
    train_batch_dict = dict(zip(self.batch_metrics_names, self.batch_metrics_values),
                            loss=ANY,
                            time=ANY)
    expected_log = {
        **train_batch_dict,
        **dict(zip(self.epoch_metrics_names, self.epoch_metrics_values))
    }
    if has_valid:
        expected_log.update(
            dict(zip(('val_' + name for name in self.batch_metrics_names),
                     self.batch_metrics_values),
                 val_loss=ANY))
        expected_log.update(
            dict(zip(('val_' + name for name in self.epoch_metrics_names),
                     self.epoch_metrics_values)))

    for epoch, log in enumerate(logs, 1):
        self.assertEqual(log, dict(expected_log, epoch=epoch))

    # Rebuild the full event sequence the callback should have received.
    expected_calls = [call.on_train_begin({})]
    for epoch in range(1, params['epochs'] + 1):
        expected_calls.append(call.on_epoch_begin(epoch, {}))
        for step in range(1, n_steps + 1):
            expected_calls += [
                call.on_train_batch_begin(step, {}),
                call.on_backward_end(step),
                call.on_train_batch_end(step, {'batch': step, 'size': ANY, **train_batch_dict}),
            ]
        expected_calls.append(call.on_epoch_end(epoch, {'epoch': epoch, **expected_log}))
    expected_calls.append(call.on_train_end({}))

    actual_calls = self.mock_callback.method_calls
    # set_model and set_params are always delivered first, in either order.
    self.assertIn(call.set_model(self.model), actual_calls[:2])
    self.assertIn(call.set_params(params), actual_calls[:2])
    self.assertEqual(len(actual_calls), len(expected_calls) + 2)  # +2 for set_model/set_params
    self.assertEqual(actual_calls[2:], expected_calls)
def test_givenABPEmbAddressParser_whenRetrainWithConfigWithCallbacks_thenCallbackAreUse(
        self):
    """Retraining a BPEmb parser with user callbacks must route train begin/end events to them."""
    parser = AddressParser(model_type=self.a_bpemb_model_type,
                           device=self.a_torch_device,
                           verbose=self.verbose)

    callback_mock = MagicMock(spec=Callback)
    training_performance = parser.retrain(
        self.training_container,
        self.a_train_ratio,
        epochs=self.a_single_epoch,
        batch_size=self.a_batch_size,
        num_workers=self.a_number_of_workers,
        learning_rate=self.a_learning_rate,
        callbacks=[callback_mock],
        logging_path=self.a_checkpoints_saving_dir)

    self.assertIsNotNone(training_performance)

    # The training loop must have fired on_train_begin and on_train_end on the callback...
    callback_mock.assert_has_calls([call.on_train_begin({})])
    callback_mock.assert_has_calls([call.on_train_end({})])
    # ...while the callback object itself is never invoked directly (only its methods are).
    callback_mock.assert_not_called()
def test_trainer_callback_system(torch_save):
    """Test the callback system: assert the exact hook sequence for fit() and test().

    ``torch_save`` is presumably a patch fixture injected by a mock decorator on
    this test — TODO confirm against the decorator outside this chunk.
    """
    model = BoringModel()

    callback_mock = MagicMock()

    trainer_options = dict(
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=3,
        limit_test_batches=2,
        progress_bar_refresh_rate=0,
    )

    # no call yet
    callback_mock.assert_not_called()

    # fit model
    trainer = Trainer(**trainer_options)

    # check that only the two init calls exist so far
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    trainer.fit(model)

    # Full hook sequence for one epoch of fit: sanity-check validation pass,
    # three training batches (limit_train_batches=3), one validation pass,
    # checkpointing, then teardown. The exact order is the contract under test.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'fit'),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # batch 0
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
        # batch 1
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 1, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
        # batch 2
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 2, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_batch_end(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
        # end-of-epoch validation
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_save_checkpoint(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]

    callback_mock.reset_mock()
    trainer = Trainer(**trainer_options)
    trainer.test(model)

    # Hook sequence for .test(): two test batches (limit_test_batches=2);
    # note teardown fires for both the 'fit' and 'test' stages.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.setup(trainer, model, 'test'),
        call.on_fit_start(trainer, model),
        call.on_test_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
        call.on_test_batch_start(trainer, model, ANY, 0, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_test_batch_start(trainer, model, ANY, 1, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_test_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
        call.teardown(trainer, model, 'test'),
    ]
def test_trainer_callback_hook_system_fit(_, tmpdir):
    """Test the callback hook system for fit.

    The leading ``_`` parameter is presumably a patch object injected by a mock
    decorator on this test — TODO confirm against the decorator outside this chunk.
    """
    model = BoringModel()
    callback_mock = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=3,
        progress_bar_refresh_rate=0,
    )

    # check that only the two init calls exist so far
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    # fit model
    trainer.fit(model)

    # Full hook sequence for one epoch of fit: accelerator/model setup,
    # sanity-check validation, three training batches (limit_train_batches=3),
    # one validation pass, checkpointing, then teardown.
    # NOTE: here on_before_zero_grad precedes on_after_backward, and
    # on_train_batch_end precedes on_batch_end — the exact order is the contract.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'fit'),
        call.on_configure_sharded_model(trainer, model),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        call.on_validation_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # batch 0
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_after_backward(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # batch 1
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 1, 0),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_after_backward(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_batch_end(trainer, model),
        # batch 2
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 2, 0),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_after_backward(trainer, model),
        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
        call.on_batch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_epoch_end(trainer, model),
        # end-of-epoch validation
        call.on_validation_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_save_checkpoint(
            trainer, model),  # should take ANY but we are inspecting signature for BC
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]
def _get_callback_expected_on_calls_when_training(self,
                                                  params,
                                                  logs,
                                                  has_valid=True,
                                                  steps=None,
                                                  valid_steps=10):
    # pylint: disable=too-many-arguments
    """Validate the per-epoch logs, then build and return the exact ``call``
    sequence a callback should receive during training (including the
    per-batch validation events when ``has_valid`` is true).
    """
    n_steps = params['steps'] if steps is None else steps

    # Per-batch training metrics; time and loss values are matched loosely.
    train_batch_dict = dict(zip(self.batch_metrics_names, self.batch_metrics_values),
                            time=ANY,
                            loss=ANY)
    expected_log = {
        **train_batch_dict,
        **dict(zip(self.epoch_metrics_names, self.epoch_metrics_values))
    }
    if has_valid:
        val_batch_dict = dict(
            zip(('val_' + name for name in self.batch_metrics_names),
                self.batch_metrics_values),
            val_loss=ANY)
        val_epochs_dict = dict(
            zip(('val_' + name for name in self.epoch_metrics_names),
                self.epoch_metrics_values))
        expected_log.update(val_batch_dict)
        expected_log.update(val_epochs_dict)

    for epoch, log in enumerate(logs, 1):
        self.assertEqual(log, dict(expected_log, epoch=epoch))

    expected_calls = [call.on_train_begin({})]
    for epoch in range(1, params['epochs'] + 1):
        expected_calls.append(call.on_epoch_begin(epoch, {}))
        for step in range(1, n_steps + 1):
            expected_calls += [
                call.on_train_batch_begin(step, {}),
                call.on_backward_end(step),
                call.on_train_batch_end(step, {
                    'batch': step,
                    'size': ANY,
                    **train_batch_dict
                }),
            ]
        if has_valid:
            # One full validation pass per epoch.
            expected_calls.append(call.on_valid_begin({}))
            for step in range(1, valid_steps + 1):
                expected_calls += [
                    call.on_valid_batch_begin(step, {}),
                    call.on_valid_batch_end(step, {
                        'batch': step,
                        'size': ANY,
                        'time': ANY,
                        **val_batch_dict
                    }),
                ]
            expected_calls.append(
                call.on_valid_end({
                    'time': ANY,
                    **val_batch_dict,
                    **val_epochs_dict
                }))
        expected_calls.append(call.on_epoch_end(epoch, logs[epoch - 1]))
    expected_calls.append(call.on_train_end({}))
    return expected_calls