def test_rich_progress_bar_import_error(monkeypatch):
    """Constructing ``RichModelSummary`` without ``rich`` should raise ``ModuleNotFoundError``.

    Despite its name, this test exercises ``RichModelSummary`` (not a progress bar).
    The module-level ``_RICH_AVAILABLE`` flag of ``rich_model_summary`` is patched to
    ``False`` so the missing-dependency path is taken even when ``rich`` is installed.
    """
    import pytorch_lightning.callbacks.rich_model_summary as summary_module

    # Simulate an environment where `rich` is not importable; monkeypatch undoes this
    # after the test.
    monkeypatch.setattr(summary_module, "_RICH_AVAILABLE", False)

    expected_message = "`RichModelSummary` requires `rich` to be installed."
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        RichModelSummary()
Пример #2
0
def test_rich_summary_tuples(mock_table_add_row, mock_console):
    """Ensure that tuples are converted into string, and print is called correctly.

    ``mock_table_add_row`` and ``mock_console`` are presumably mocks injected by
    decorators or fixtures not visible here (TODO confirm). Their call records are
    inspected after ``RichModelSummary.summarize`` runs.
    """
    callback = RichModelSummary()

    class TestModel(BoringModel):
        @property
        def example_input_array(self) -> Any:
            # An example input lets the summary report input/output sizes,
            # which are checked in the final assertion below.
            return torch.randn(4, 32)

    data = summarize(TestModel())._get_summary_data()

    callback.summarize(
        summary_data=data,
        total_parameters=1,
        trainable_parameters=1,
        model_size=1,
    )

    # Two console calls: the summary table plus the model-parameter breakdown.
    assert mock_console.call_count == 2

    # Every column of the first table row, including the size columns, must be a string.
    first_row_args, _ = mock_table_add_row.call_args_list[0]
    assert first_row_args[1:] == ("0", "layer", "Linear", "66  ", "[4, 32]", "[4, 2]")
Пример #3
0
def test_rich_progress_bar_import_error():
    """Creating a ``Trainer`` with ``RichModelSummary`` should raise ``ImportError`` without ``rich``.

    This variant relies on the real environment instead of patching the availability flag.
    When ``rich`` is installed, the error path cannot be reached, so the test is reported
    as skipped instead of passing without asserting anything.
    """
    # NOTE(review): another test in this file uses the same function name; if both live in
    # one module, this later definition shadows the earlier one — consider renaming.
    if _RICH_AVAILABLE:
        pytest.skip("`rich` is installed; this test only covers the missing-dependency path.")
    with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."):
        Trainer(callbacks=RichModelSummary())