def test_checkpoint_callbacks_are_last(tmpdir):
    """Checkpoint callbacks must always end up at the tail of ``trainer.callbacks``.

    The checkpoint callbacks keep their relative order among themselves, and the
    non-checkpoint callbacks keep theirs as well.
    """
    ckpt_a = ModelCheckpoint(tmpdir)
    ckpt_b = ModelCheckpoint(tmpdir)
    summary_cb = ModelSummary()
    early_stop_cb = EarlyStopping()
    lr_cb = LearningRateMonitor()
    bar_cb = ProgressBar()

    # Case 1: no model has been assigned to the trainer.
    trainer = Trainer(callbacks=[ckpt_a, bar_cb, lr_cb, summary_cb, ckpt_b])
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    checkpoints_last = [bar_cb, lr_cb, summary_cb, ckpt_a, ckpt_b]
    assert trainer.callbacks == checkpoints_last

    # Case 2: a model is assigned, but its configure_callbacks returns nothing.
    model = LightningModule()
    model.configure_callbacks = lambda: []
    trainer.model = model
    connector._attach_model_callbacks()
    assert trainer.callbacks == checkpoints_last

    # Case 3: the model's callbacks substitute ones passed to the Trainer
    # (the Trainer's own ModelCheckpoint is expected to disappear).
    model = LightningModule()
    model.configure_callbacks = lambda: [ckpt_a, early_stop_cb, summary_cb, ckpt_b]
    trainer = Trainer(callbacks=[bar_cb, lr_cb, ModelCheckpoint(tmpdir)])
    trainer.model = model
    connector = CallbackConnector(trainer)
    connector._attach_model_callbacks()
    assert trainer.callbacks == [bar_cb, lr_cb, early_stop_cb, summary_cb, ckpt_a, ckpt_b]
# Example no. 2
 def assert_composition(trainer_callbacks, model_callbacks, expected):
     """Merge ``model_callbacks`` into a Trainer built with ``trainer_callbacks`` and compare.

     The Trainer is constructed with ``checkpoint_callback=False`` and
     ``progress_bar_refresh_rate=0``. After the connector attaches the model's
     callbacks, ``trainer.callbacks`` must equal ``expected``.
     """
     module = LightningModule()

     def _model_callbacks():
         return model_callbacks

     module.configure_callbacks = _model_callbacks
     trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
     trainer.model = module
     CallbackConnector(trainer)._attach_model_callbacks()
     assert trainer.callbacks == expected
def test_attach_model_callbacks_override_info(caplog):
    """Overriding Trainer callbacks via ``configure_callbacks`` must show up in the log.

    Logging is captured at INFO level while the connector attaches the model's
    callbacks. The text must name the Trainer callbacks that were overridden.
    """
    module = LightningModule()
    module.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
    trainer = Trainer(
        checkpoint_callback=False,
        callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()],
    )
    trainer.model = module
    connector = CallbackConnector(trainer)
    with caplog.at_level(logging.INFO):
        connector._attach_model_callbacks()

    expected_message = "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor"
    assert expected_message in caplog.text
# Example no. 4
 def _attach_callbacks(trainer_callbacks, model_callbacks):
     """Build a Trainer around ``trainer_callbacks`` and merge ``model_callbacks`` into it.

     ``enable_progress_bar`` is True only when some callback in either list is a
     ``ProgressBarBase``. ``enable_checkpointing`` and ``enable_model_summary``
     are both passed as False.

     Returns:
         The Trainer, after ``CallbackConnector._attach_model_callbacks`` has run on it.
     """
     module = LightningModule()
     module.configure_callbacks = lambda: model_callbacks
     combined = trainer_callbacks + model_callbacks
     wants_progress_bar = any(isinstance(callback, ProgressBarBase) for callback in combined)
     trainer = Trainer(
         enable_checkpointing=False,
         enable_progress_bar=wants_progress_bar,
         enable_model_summary=False,
         callbacks=trainer_callbacks,
     )
     trainer.model = module
     connector = CallbackConnector(trainer)
     connector._attach_model_callbacks()
     return trainer