def test_rich_model_summary_import_error(monkeypatch):
    """Constructing ``RichModelSummary`` without ``rich`` raises ``ModuleNotFoundError``.

    The module-level ``_RICH_AVAILABLE`` flag is patched to ``False`` so the
    missing-dependency path runs even when ``rich`` is actually installed.

    Renamed from ``test_rich_progress_bar_import_error``: another test in this
    module uses that name. Under the old name this definition was silently
    shadowed and never collected by pytest.
    """
    import pytorch_lightning.callbacks.rich_model_summary as rich_model_summary_module

    monkeypatch.setattr(rich_model_summary_module, "_RICH_AVAILABLE", False)
    with pytest.raises(
        ModuleNotFoundError,
        match="`RichModelSummary` requires `rich` to be installed.",
    ):
        RichModelSummary()
def test_rich_summary_tuples(mock_table_add_row, mock_console):
    """Check that tuple entries become strings and that the console prints as expected."""
    callback = RichModelSummary()

    class TestModel(BoringModel):
        @property
        def example_input_array(self) -> Any:
            return torch.randn(4, 32)

    data = summarize(TestModel())._get_summary_data()
    callback.summarize(
        summary_data=data,
        total_parameters=1,
        trainable_parameters=1,
        model_size=1,
    )

    # One print for the summary itself, one for the breakdown of model parameters.
    assert mock_console.call_count == 2

    # The first row handed to the table must already contain stringified values.
    # NOTE(review): positional arg 0 is skipped; presumably it is the Table
    # instance recorded by an autospec'd mock — confirm against the decorators.
    first_row_args, _ = mock_table_add_row.call_args_list[0]
    assert first_row_args[1:] == ("0", "layer", "Linear", "66 ", "[4, 32]", "[4, 2]")
def test_rich_progress_bar_import_error():
    """Building a ``Trainer`` with ``RichModelSummary`` fails without ``rich``.

    This path can only be exercised when ``rich`` is genuinely unavailable.
    Previously the test silently passed with no assertion when ``rich`` was
    installed; it now reports itself as skipped in that case.
    """
    if _RICH_AVAILABLE:
        pytest.skip("`rich` is installed, so the missing-dependency error cannot be triggered.")
    with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."):
        Trainer(callbacks=RichModelSummary())