from unittest import mock
from unittest.mock import call

import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger

# these imports assume the pl_bolts package layout for the callback and model
from pl_bolts.callbacks import TrainingDataMonitor
from pl_bolts.models import LitMNIST


def test_base_no_logger_warning():
    """ Test that a warning is displayed when the Trainer has no logger. """
    monitor = TrainingDataMonitor()
    trainer = Trainer(logger=False, callbacks=[monitor])
    with pytest.warns(
        UserWarning, match="Cannot log histograms because Trainer has no logger"
    ):
        monitor.on_train_start(trainer, pl_module=None)


def test_base_unsupported_logger_warning(tmpdir):
    """ Test that a warning is displayed when an unsupported logger is used. """
    monitor = TrainingDataMonitor()
    trainer = Trainer(logger=LoggerCollection([TensorBoardLogger(tmpdir)]),
                      callbacks=[monitor])
    with pytest.warns(UserWarning,
                      match="does not support logging with LoggerCollection"):
        monitor.on_train_start(trainer, pl_module=None)


# the patch target below is assumed from the pl_bolts module layout; mocking
# log_histogram lets the test assert on the exact calls the monitor makes
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_training_data_monitor(log_histogram, tmpdir, datadir):
    """ Test that the TrainingDataMonitor logs histograms of data points going into training_step. """
    monitor = TrainingDataMonitor()
    model = LitMNIST(data_dir=datadir)
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,
        callbacks=[monitor],
    )
    monitor.on_train_start(trainer, model)
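    # on_train_start is called manually so the monitor picks up the trainer's logger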

    # single tensor
    example_data = torch.rand(2, 3, 4)
    monitor.on_train_batch_start(trainer,
                                 model,
                                 batch=example_data,
                                 batch_idx=0,
                                 dataloader_idx=0)
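    # a single tensor is tagged with the hook name followed by the tensor's shape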
    assert log_histogram.call_args_list == [
        call(example_data, "training_step/[2, 3, 4]"),
    ]

    log_histogram.reset_mock()

    # tuple
    example_data = (torch.rand(2, 3, 4), torch.rand(5), "non-tensor")
    monitor.on_train_batch_start(trainer,
                                 model,
                                 batch=example_data,
                                 batch_idx=0,
                                 dataloader_idx=0)
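    # tuple entries are tagged by positional index; the non-tensor entry is skipped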
    assert log_histogram.call_args_list == [
        call(example_data[0], "training_step/0/[2, 3, 4]"),
        call(example_data[1], "training_step/1/[5]"),
    ]

    log_histogram.reset_mock()

    # dict
    example_data = {
        "x0": torch.rand(2, 3, 4),
        "x1": torch.rand(5),
        "non-tensor": "non-tensor",
    }
    monitor.on_train_batch_start(trainer,
                                 model,
                                 batch=example_data,
                                 batch_idx=0,
                                 dataloader_idx=0)
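    # dict values are tagged by their keys; the non-tensor value is skipped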
    assert log_histogram.call_args_list == [
        call(example_data["x0"], "training_step/x0/[2, 3, 4]"),
        call(example_data["x1"], "training_step/x1/[5]"),
    ]