def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
    """ Test that the ModuleDataMonitor logs the inputs and outputs of each submodule. """
    monitor = ModuleDataMonitor(submodules=True)
    model = ModuleDataMonitorModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,
        callbacks=[monitor],
    )
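    # Drive the callback hooks by hand instead of running a full fit:
    # on_train_start registers the forward hooks, and on_train_batch_start
    # tells the monitor that the upcoming forward pass belongs to a loggable step.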
    monitor.on_train_start(trainer, model)
    monitor.on_train_batch_start(
        trainer, model, batch=None, batch_idx=0, dataloader_idx=0
    )

    example_input = torch.rand(2, 6, 2)
    output = model(example_input)
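    # Forward hooks fire as each submodule finishes its forward pass, so the nested
    # layer2.sub_layer is logged before its parent layer2, and the root module's
    # unnamed "input/" and "output/" entries come last.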
    assert log_histogram.call_args_list == [
        call(model.layer1_input, "input/layer1/[2, 12]"),
        call(model.layer1_output, "output/layer1/[2, 5]"),
        call(model.layer2_input, "input/layer2.sub_layer/[2, 5]"),
        call(model.layer2_output, "output/layer2.sub_layer/[2, 2]"),
        call(model.layer2_input, "input/layer2/[2, 5]"),
        call(model.layer2_output, "output/layer2/[2, 2]"),
        call(example_input, "input/[2, 6, 2]"),
        call(output, "output/[2, 2]"),
    ]


@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def test_module_data_monitor_forward(log_histogram, tmpdir):
    """ Test that the default ModuleDataMonitor logs inputs and outputs of model's forward. """
    monitor = ModuleDataMonitor(submodules=None)
    model = ModuleDataMonitorModel()
    trainer = Trainer(
        default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor],
    )
    monitor.on_train_start(trainer, model)
    monitor.on_train_batch_start(
        trainer, model, batch=None, batch_idx=0, dataloader_idx=0
    )

    example_input = torch.rand(2, 6, 2)
    output = model(example_input)
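    # With submodules=None (the default), only the root module's forward
    # input and output are logged.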
    assert log_histogram.call_args_list == [
        call(example_input, "input/[2, 6, 2]"),
        call(output, "output/[2, 2]"),
    ]