def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
    """Check that ``ModuleDataMonitor(submodules=True)`` records every submodule.

    The callback hooks are invoked by hand (``on_train_start`` followed by
    ``on_train_batch_start`` for batch 0) instead of through a training loop.
    A single forward pass is then made. The recorded histogram-logging calls
    must match the expected list exactly and in order: first the named
    submodules (``layer1``, ``layer2.sub_layer``, ``layer2``), then the
    model's own forward input and output.

    NOTE(review): ``log_histogram`` is presumably a fixture that replaces the
    monitor's histogram logger with a mock; confirm against conftest.
    """
    data_monitor = ModuleDataMonitor(submodules=True)
    net = ModuleDataMonitorModel()
    fit_trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,  # presumably so that batch_idx=0 is not skipped
        callbacks=[data_monitor],
    )

    # Drive the callback manually rather than calling fit().
    data_monitor.on_train_start(fit_trainer, net)
    data_monitor.on_train_batch_start(
        fit_trainer, net, batch=None, batch_idx=0, dataloader_idx=0
    )

    batch_input = torch.rand(2, 6, 2)
    model_output = net(batch_input)

    expected_calls = [
        # Named submodules; layer2's child is logged before layer2 itself.
        call(net.layer1_input, "input/layer1/[2, 12]"),
        call(net.layer1_output, "output/layer1/[2, 5]"),
        call(net.layer2_input, "input/layer2.sub_layer/[2, 5]"),
        call(net.layer2_output, "output/layer2.sub_layer/[2, 2]"),
        call(net.layer2_input, "input/layer2/[2, 5]"),
        call(net.layer2_output, "output/layer2/[2, 2]"),
        # The top-level model forward comes last.
        call(batch_input, "input/[2, 6, 2]"),
        call(model_output, "output/[2, 2]"),
    ]
    assert log_histogram.call_args_list == expected_calls
def test_module_data_monitor_forward(log_histogram, tmpdir):
    """Check that ``ModuleDataMonitor(submodules=None)`` logs only the model's forward.

    With ``submodules=None`` no submodule entries are expected. After the
    callback hooks are triggered manually for batch 0 and one forward pass is
    made, exactly two histogram-logging calls should be recorded: the model's
    input, then its output.

    NOTE(review): ``log_histogram`` is presumably a fixture that replaces the
    monitor's histogram logger with a mock; confirm against conftest.
    """
    data_monitor = ModuleDataMonitor(submodules=None)
    net = ModuleDataMonitorModel()
    fit_trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,  # presumably so that batch_idx=0 is not skipped
        callbacks=[data_monitor],
    )

    # Drive the callback manually rather than calling fit().
    data_monitor.on_train_start(fit_trainer, net)
    data_monitor.on_train_batch_start(
        fit_trainer, net, batch=None, batch_idx=0, dataloader_idx=0
    )

    batch_input = torch.rand(2, 6, 2)
    model_output = net(batch_input)

    expected_calls = [
        call(batch_input, "input/[2, 6, 2]"),
        call(model_output, "output/[2, 2]"),
    ]
    assert log_histogram.call_args_list == expected_calls