def test_init_callback(self):
    """Check which callbacks a trainer holds right after construction.

    Three configurations are compared against an expected list through
    ``self.check_callbacks_equality``:

    * no arguments: ``DEFAULT_CALLBACKS`` plus ``ProgressCallback``;
    * ``callbacks=[MyTestTrainerCallback]``: the previous list plus
      ``MyTestTrainerCallback``;
    * ``disable_tqdm=True``: ``DEFAULT_CALLBACKS`` plus ``PrinterCallback``.
    """
    trainer = self.get_trainer()
    expected = [*DEFAULT_CALLBACKS, ProgressCallback]
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    # A callback handed over at init is expected in addition to the defaults.
    trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
    expected += [MyTestTrainerCallback]
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    # TrainingArguments.disable_tqdm decides whether ProgressCallback or
    # PrinterCallback is used, so the expectation is rebuilt from scratch.
    trainer = self.get_trainer(disable_tqdm=True)
    expected = [*DEFAULT_CALLBACKS, PrinterCallback]
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)
def test_add_remove_callback(self):
    """Exercise ``remove_callback``, ``pop_callback`` and ``add_callback``.

    Each operation is run twice: first addressing the callback by its class
    (``DefaultFlowCallback``), then by an instance taken from the trainer's
    callback handler. After every mutation the handler's callbacks are
    compared with a locally maintained expected list through
    ``self.check_callbacks_equality``.
    """
    expected = [*DEFAULT_CALLBACKS, ProgressCallback]
    trainer = self.get_trainer()

    # --- Addressing the callback by class -------------------------------
    trainer.remove_callback(DefaultFlowCallback)
    expected.remove(DefaultFlowCallback)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    # A fresh trainer has the callback again; popping must hand back an
    # object of the requested class and leave the same reduced list.
    trainer = self.get_trainer()
    popped = trainer.pop_callback(DefaultFlowCallback)
    self.assertEqual(popped.__class__, DefaultFlowCallback)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    trainer.add_callback(DefaultFlowCallback)
    expected.insert(0, DefaultFlowCallback)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    # --- Addressing the callback by instance ----------------------------
    trainer = self.get_trainer()
    # NOTE(review): the expected list drops DefaultFlowCallback below, so
    # this presumes the handler's first callback is a DefaultFlowCallback
    # instance; confirm against the handler's ordering.
    first_callback = trainer.callback_handler.callbacks[0]
    trainer.remove_callback(first_callback)
    expected.remove(DefaultFlowCallback)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    trainer = self.get_trainer()
    first_callback = trainer.callback_handler.callbacks[0]
    popped = trainer.pop_callback(first_callback)
    # Popping by instance must return the instance that was passed in.
    self.assertEqual(first_callback, popped)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)

    trainer.add_callback(first_callback)
    expected.insert(0, DefaultFlowCallback)
    self.check_callbacks_equality(trainer.callback_handler.callbacks, expected)