def test_base_no_logger_warning():
    """
    Starting training on a Trainer built with ``logger=False`` must emit a
    UserWarning saying histograms cannot be logged.
    """
    data_monitor = TrainingDataMonitor()
    trainer = Trainer(logger=False, callbacks=[data_monitor])
    expected_message = "Cannot log histograms because Trainer has no logger"
    # The hook is invoked directly; no LightningModule is supplied.
    with pytest.warns(UserWarning, match=expected_message):
        data_monitor.on_train_start(trainer, pl_module=None)
def test_base_unsupported_logger_warning(tmpdir):
    """
    Wrapping the logger in a ``LoggerCollection`` must make the monitor emit a
    UserWarning that this logger setup is unsupported.
    """
    data_monitor = TrainingDataMonitor()
    # A single TensorBoard logger wrapped in a collection is enough to trigger the warning.
    wrapped_logger = LoggerCollection([TensorBoardLogger(tmpdir)])
    trainer = Trainer(logger=wrapped_logger, callbacks=[data_monitor])
    with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"):
        data_monitor.on_train_start(trainer, pl_module=None)
def test_training_data_monitor(log_histogram, tmpdir, datadir):
    """
    The TrainingDataMonitor must call ``log_histogram`` once per tensor in the
    batch handed to ``on_train_batch_start``. Each call is keyed as
    ``training_step/[<position or key>/]<shape>``. Non-tensor entries produce
    no call.
    """
    # NOTE(review): `log_histogram` is presumably a mock injected by a
    # mock.patch decorator or fixture not visible here -- confirm.
    monitor = TrainingDataMonitor()
    model = LitMNIST(data_dir=datadir)
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,
        callbacks=[monitor],
    )
    monitor.on_train_start(trainer, model)

    def feed(batch):
        # Push one batch through the hook as batch 0 of dataloader 0 and
        # report every histogram call recorded so far.
        monitor.on_train_batch_start(trainer, model, batch=batch, batch_idx=0, dataloader_idx=0)
        return log_histogram.call_args_list

    # A bare tensor: one call, keyed only by its shape.
    lone = torch.rand(2, 3, 4)
    assert feed(lone) == [call(lone, "training_step/[2, 3, 4]")]
    log_histogram.reset_mock()

    # A tuple: tensors are keyed by position; the string entry yields no call.
    first, second = torch.rand(2, 3, 4), torch.rand(5)
    assert feed((first, second, "non-tensor")) == [
        call(first, "training_step/0/[2, 3, 4]"),
        call(second, "training_step/1/[5]"),
    ]
    log_histogram.reset_mock()

    # A dict: tensors are keyed by dict key; the string value yields no call.
    x0, x1 = torch.rand(2, 3, 4), torch.rand(5)
    assert feed({"x0": x0, "x1": x1, "non-tensor": "non-tensor"}) == [
        call(x0, "training_step/x0/[2, 3, 4]"),
        call(x1, "training_step/x1/[5]"),
    ]