Example #1
def test_rich_progress_bar_refresh_rate():
    progress_bar = RichProgressBar(refresh_rate_per_second=1)
    assert progress_bar.is_enabled
    assert not progress_bar.is_disabled
    progress_bar = RichProgressBar(refresh_rate_per_second=0)
    assert not progress_bar.is_enabled
    assert progress_bar.is_disabled
Example #2
def test_rich_progress_bar_custom_theme(tmpdir):
    """Test to ensure that custom theme styles are used."""
    with mock.patch.multiple(
            "pytorch_lightning.callbacks.progress.rich_progress",
            CustomBarColumn=DEFAULT,
            BatchesProcessedColumn=DEFAULT,
            CustomTimeColumn=DEFAULT,
            ProcessingSpeedColumn=DEFAULT,
    ) as mocks:
        theme = RichProgressBarTheme()

        progress_bar = RichProgressBar(theme=theme)
        progress_bar.on_train_start(Trainer(tmpdir), BoringModel())

        assert progress_bar.theme == theme
        args, kwargs = mocks["CustomBarColumn"].call_args
        assert kwargs["complete_style"] == theme.progress_bar
        assert kwargs["finished_style"] == theme.progress_bar_finished

        args, kwargs = mocks["BatchesProcessedColumn"].call_args
        assert kwargs["style"] == theme.batch_progress

        args, kwargs = mocks["CustomTimeColumn"].call_args
        assert kwargs["style"] == theme.time

        args, kwargs = mocks["ProcessingSpeedColumn"].call_args
        assert kwargs["style"] == theme.processing_speed
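For context, a minimal usage sketch of the theme API exercised by this test, assuming the same field names; the Rich style strings below are illustrative choices, not values from the source.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# Each field maps onto one of the columns asserted in the test above.
theme = RichProgressBarTheme(
    description="green_yellow",       # text shown next to the bar
    progress_bar="green1",            # CustomBarColumn complete_style
    progress_bar_finished="green1",   # CustomBarColumn finished_style
    batch_progress="green_yellow",    # BatchesProcessedColumn style
    time="grey82",                    # CustomTimeColumn style
    processing_speed="grey82",        # ProcessingSpeedColumn style
    metrics="grey82",                 # style of the logged metrics text
)

trainer = Trainer(callbacks=RichProgressBar(theme=theme))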
Example #3
def test_rich_progress_bar_colab_light_theme_update(*_):
    theme = RichProgressBar().theme
    assert theme.description == "black"
    assert theme.batch_progress == "black"
    assert theme.metrics == "black"

    theme = RichProgressBar(
        theme=RichProgressBarTheme(description="blue", metrics="red")).theme
    assert theme.description == "blue"
    assert theme.batch_progress == "black"
    assert theme.metrics == "red"
Example #4
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate,
                                             train_batches, val_batches,
                                             expected_call_count):
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=train_batches,
        limit_val_batches=val_batches,
        max_epochs=1,
        callbacks=RichProgressBar(refresh_rate=refresh_rate),
    )

    trainer.progress_bar_callback.on_train_start(trainer, model)
    with mock.patch.object(trainer.progress_bar_callback.progress,
                           "update",
                           wraps=trainer.progress_bar_callback.progress.update
                           ) as progress_update:
        trainer.fit(model)
        assert progress_update.call_count == expected_call_count

    if train_batches > 0:
        fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
        assert fit_main_bar.completed == train_batches + val_batches
        assert fit_main_bar.total == train_batches + val_batches
        assert fit_main_bar.visible
    if val_batches > 0:
        fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
        assert fit_val_bar.completed == val_batches
        assert fit_val_bar.total == val_batches
        assert not fit_val_bar.visible
Example #5
def test_rich_progress_bar_import_error(monkeypatch):
    import pytorch_lightning.callbacks.progress.rich_progress as imports

    monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
    with pytest.raises(ModuleNotFoundError,
                       match="`RichProgressBar` requires `rich` >= 10.2.2."):
        RichProgressBar()
Example #6
    def _prepare_callbacks(self, callbacks=None) -> List:
        """Prepares the necessary callbacks for the Trainer based on the configuration.

        Returns:
            List: A list of callbacks
        """
        callbacks = [] if callbacks is None else callbacks
        if self.config.early_stopping is not None:
            early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(
                monitor=self.config.early_stopping,
                min_delta=self.config.early_stopping_min_delta,
                patience=self.config.early_stopping_patience,
                verbose=False,
                mode=self.config.early_stopping_mode,
            )
            callbacks.append(early_stop_callback)
        if self.config.checkpoints:
            ckpt_name = f"{self.name}-{self.uid}"
            ckpt_name = ckpt_name.replace(" ",
                                          "_") + "_{epoch}-{valid_loss:.2f}"
            model_checkpoint = pl.callbacks.ModelCheckpoint(
                monitor=self.config.checkpoints,
                dirpath=self.config.checkpoints_path,
                filename=ckpt_name,
                save_top_k=self.config.checkpoints_save_top_k,
                mode=self.config.checkpoints_mode,
            )
            callbacks.append(model_checkpoint)
            self.config.checkpoint_callback = True
        else:
            self.config.checkpoint_callback = False
        if self.config.progress_bar == "rich":
            callbacks.append(RichProgressBar())
        logger.debug(f"Callbacks used: {callbacks}")
        return callbacks
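A hypothetical call site for this helper, for illustration only: the enclosing class is not part of the excerpt, so the `_prepare_trainer` name and the `max_epochs` config field below are assumptions.

    def _prepare_trainer(self, callbacks=None) -> "pl.Trainer":
        """Sketch of a call site: wire the prepared callbacks into the Trainer."""
        callbacks = self._prepare_callbacks(callbacks)
        # max_epochs is a hypothetical config field, used here only for illustration.
        return pl.Trainer(
            max_epochs=self.config.max_epochs,
            callbacks=callbacks,
        )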
Example #7
def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
    """Test the completed and total counter for rich progress bar when using val_check_interval."""
    progress_bar = RichProgressBar()
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        val_check_interval=2,
        max_epochs=1,
        limit_train_batches=7,
        limit_val_batches=4,
        callbacks=[progress_bar],
    )
    trainer.fit(model)

    fit_main_progress_bar = progress_bar.progress.tasks[1]
    assert fit_main_progress_bar.completed == 7 + 3 * 4
    assert fit_main_progress_bar.total == 7 + 3 * 4

    fit_val_bar = progress_bar.progress.tasks[2]
    assert fit_val_bar.completed == 4
    assert fit_val_bar.total == 4

    trainer.validate(model)
    val_bar = progress_bar.progress.tasks[0]
    assert val_bar.completed == 4
    assert val_bar.total == 4
Example #8
def test_rich_progress_bar_import_error():

    if not _RICH_AVAILABLE:
        with pytest.raises(
                ImportError,
                match="`RichProgressBar` requires `rich` to be installed."):
            Trainer(callbacks=RichProgressBar())
Example #9
def test_rich_progress_bar(progress_update, tmpdir, dataset):
    class TestModel(BoringModel):
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)

    model = TestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=RichProgressBar(),
    )

    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)

    assert progress_update.call_count == 8
Example #10
def test_rich_progress_bar_callback():
    trainer = Trainer(callbacks=RichProgressBar())

    progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]

    assert len(progress_bars) == 1
    assert isinstance(trainer.progress_bar_callback, RichProgressBar)
Example #11
def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=4,
        callbacks=RichProgressBar(refresh_rate=0),
    )
    trainer.fit(BoringModel())
    assert progress_update.call_count == 0
Example #12
def test_rich_progress_bar(tmpdir, dataset):
    class TestModel(BoringModel):
        def train_dataloader(self):
            return DataLoader(dataset=dataset)

        def val_dataloader(self):
            return DataLoader(dataset=dataset)

        def test_dataloader(self):
            return DataLoader(dataset=dataset)

        def predict_dataloader(self):
            return DataLoader(dataset=dataset)

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_epochs=1,
        callbacks=RichProgressBar(),
    )
    model = TestModel()

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.update"
    ) as mocked:
        trainer.fit(model)
    # 3 for main progress bar and 1 for val progress bar
    assert mocked.call_count == 4

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.update"
    ) as mocked:
        trainer.validate(model)
    assert mocked.call_count == 1

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.update"
    ) as mocked:
        trainer.test(model)
    assert mocked.call_count == 1

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.update"
    ) as mocked:
        trainer.predict(model)
    assert mocked.call_count == 1
Example #13
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):

    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=6,
        limit_val_batches=6,
        max_epochs=1,
        callbacks=RichProgressBar(refresh_rate=refresh_rate),
    )

    trainer.fit(model)

    assert progress_update.call_count == expected_call_count
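The excerpt omits the pytest plumbing that supplies progress_update, refresh_rate and expected_call_count. One way that plumbing might look is sketched below; the patch target is taken from Example #12, and the parametrize values are placeholders rather than the counts used in the original suite.

import pytest
from unittest import mock

# Hypothetical plumbing: patch Progress.update so the injected `progress_update` mock
# can count refreshes, and parametrize the refresh rate. Counts below are placeholders.
@pytest.mark.parametrize("refresh_rate,expected_call_count", [
    (1, 12),  # placeholder; real counts depend on the Lightning version and batch limits
    (3, 5),   # placeholder
])
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
    ...  # body as in the test above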
Example #14
def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
    # Calling `reset` means continuing on the same progress bar.
    model = BoringModel()

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.reset",
            autospec=True) as mock_progress_reset:
        progress_bar = RichProgressBar(leave=leave)
        trainer = Trainer(
            default_root_dir=tmpdir,
            num_sanity_val_steps=0,
            limit_train_batches=1,
            max_epochs=6,
            callbacks=progress_bar,
        )
        trainer.fit(model)
    assert mock_progress_reset.call_count == reset_call_count
Example #15
def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
    model = BoringModel()

    progress_bar = RichProgressBar()
    num_sanity_val_steps = 3

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=num_sanity_val_steps,
        limit_train_batches=1,
        limit_val_batches=limit_val_batches,
        max_epochs=1,
        callbacks=progress_bar,
    )

    trainer.fit(model)
    assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
Example #16
def test_rich_progress_bar_metric_display_task_id(tmpdir):
    class CustomModel(BoringModel):
        def training_step(self, *args, **kwargs):
            res = super().training_step(*args, **kwargs)
            self.log("train_loss", res["loss"], prog_bar=True)
            return res

    progress_bar = RichProgressBar()
    model = CustomModel()
    trainer = Trainer(default_root_dir=tmpdir, callbacks=progress_bar, fast_dev_run=True)

    trainer.fit(model)
    main_progress_bar_id = progress_bar.main_progress_bar_id
    val_progress_bar_id = progress_bar.val_progress_bar_id
    rendered = progress_bar.progress.columns[-1]._renderable_cache

    for key in ("loss", "v_num", "train_loss"):
        assert key in rendered[main_progress_bar_id][1]
        assert key not in rendered[val_progress_bar_id][1]
Example #17
def test_rich_progress_bar(progress_update, tmpdir):

    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=0,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        max_steps=1,
        callbacks=RichProgressBar(),
    )

    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert progress_update.call_count == 6
Example #18
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
    """Test to ensure that when the user keyboard interrupts, we close the progress bar."""
    class TestModel(BoringModel):
        def on_train_start(self) -> None:
            raise KeyboardInterrupt

    model = TestModel()

    with mock.patch(
            "pytorch_lightning.callbacks.progress.rich_progress.Progress.stop",
            autospec=True) as mock_progress_stop:
        progress_bar = RichProgressBar()
        trainer = Trainer(
            default_root_dir=tmpdir,
            fast_dev_run=True,
            callbacks=progress_bar,
        )

        trainer.fit(model)
    mock_progress_stop.assert_called_once()
Example #19
def test_rich_model_summary_callback():
    trainer = Trainer(callbacks=RichProgressBar())

    assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks)
    assert isinstance(trainer.progress_bar_callback, RichProgressBar)