def test_checkpoint_callbacks_are_last(tmpdir):
    """Checkpoint callbacks must end up at the tail of ``trainer.callbacks``.

    After ``_attach_model_callbacks`` runs, every ``ModelCheckpoint`` sits
    behind all other callbacks. The relative order among the checkpoints,
    and among the remaining callbacks, is preserved. Three situations are
    covered:

    * the trainer has no model attached,
    * the attached model's ``configure_callbacks`` returns nothing,
    * the model's callbacks substitute callbacks passed to the ``Trainer``.
    """
    ckpt_a = ModelCheckpoint(tmpdir)
    ckpt_b = ModelCheckpoint(tmpdir)
    summary = ModelSummary()
    stopper = EarlyStopping()
    lr_cb = LearningRateMonitor()
    bar = ProgressBar()

    # Scenario 1: no model reference on the trainer yet.
    trainer = Trainer(callbacks=[ckpt_a, bar, lr_cb, summary, ckpt_b])
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    reordered = [bar, lr_cb, summary, ckpt_a, ckpt_b]
    assert trainer.callbacks == reordered

    # Scenario 2: a model is attached but contributes no callbacks,
    # so the ordering from scenario 1 must be left untouched.
    empty_model = LightningModule()
    empty_model.configure_callbacks = lambda: []
    trainer.model = empty_model
    connector._attach_model_callbacks()
    assert trainer.callbacks == reordered

    # Scenario 3: model-provided callbacks substitute the ones given to the
    # Trainer (here the Trainer's own ModelCheckpoint is replaced).
    overriding_model = LightningModule()
    overriding_model.configure_callbacks = lambda: [ckpt_a, stopper, summary, ckpt_b]
    trainer = Trainer(callbacks=[bar, lr_cb, ModelCheckpoint(tmpdir)])
    trainer.model = overriding_model
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    assert trainer.callbacks == [bar, lr_cb, stopper, summary, ckpt_a, ckpt_b]
def assert_composition(trainer_callbacks, model_callbacks, expected):
    """Merge Trainer and model callbacks and assert the result equals ``expected``.

    A ``Trainer`` is built with ``trainer_callbacks``. A ``LightningModule``
    whose ``configure_callbacks`` returns ``model_callbacks`` is attached, and
    ``CallbackConnector._attach_model_callbacks`` performs the merge.

    Args:
        trainer_callbacks: callbacks passed directly to the ``Trainer``.
        model_callbacks: callbacks returned by the model's ``configure_callbacks``.
        expected: the exact list ``trainer.callbacks`` should equal afterwards.
    """
    model = LightningModule()

    def _configure_callbacks():
        return model_callbacks

    model.configure_callbacks = _configure_callbacks

    # Checkpointing and the progress bar are switched off, presumably so no
    # default callbacks leak into the comparison with ``expected``.
    trainer = Trainer(
        checkpoint_callback=False,
        progress_bar_refresh_rate=0,
        callbacks=trainer_callbacks,
    )
    trainer.model = model
    CallbackConnector(trainer)._attach_model_callbacks()
    assert trainer.callbacks == expected
def test_attach_model_callbacks_override_info(caplog):
    """Overriding Trainer callbacks via ``configure_callbacks`` is reported in the INFO log.

    The model returns a ``LearningRateMonitor`` and an ``EarlyStopping``, both of
    which also appear among the Trainer's callbacks. The captured log text must
    name them as the overridden existing callbacks.
    """
    model = LightningModule()
    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
    trainer_callbacks = [EarlyStopping(), LearningRateMonitor(), ProgressBar()]
    trainer = Trainer(checkpoint_callback=False, callbacks=trainer_callbacks)
    trainer.model = model
    connector = CallbackConnector(trainer)

    with caplog.at_level(logging.INFO):
        connector._attach_model_callbacks()

    expected_message = "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor"
    assert expected_message in caplog.text
def _attach_callbacks(trainer_callbacks, model_callbacks):
    """Build a Trainer, attach model callbacks to it, and return the Trainer.

    The model's ``configure_callbacks`` returns ``model_callbacks``. The
    callbacks are merged with ``trainer_callbacks`` through
    ``CallbackConnector._attach_model_callbacks``.

    The progress bar is enabled only when one of the supplied callbacks
    (from either list) is a ``ProgressBarBase``. Checkpointing and the model
    summary are always disabled.

    Args:
        trainer_callbacks: list of callbacks passed directly to the ``Trainer``.
        model_callbacks: list of callbacks returned by ``configure_callbacks``.

    Returns:
        The ``Trainer`` whose ``callbacks`` now hold the merged result.
    """
    model = LightningModule()
    model.configure_callbacks = lambda: model_callbacks

    wants_progress_bar = False
    for callback in trainer_callbacks + model_callbacks:
        if isinstance(callback, ProgressBarBase):
            wants_progress_bar = True
            break

    trainer = Trainer(
        callbacks=trainer_callbacks,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=wants_progress_bar,
    )
    trainer.model = model
    CallbackConnector(trainer)._attach_model_callbacks()
    return trainer