def test_base_log_interval_override(
    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
):
    """Check that the callback's own ``log_every_n_steps`` wins over the Trainer's.

    The Trainer is configured to log every step, while the monitor receives the
    parametrized interval; the number of histogram calls must follow the monitor.
    """
    callback = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
    mnist_model = LitMNIST(data_dir=datadir, num_workers=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        # Deliberately different from the callback's interval, so the test
        # can tell which of the two settings was honoured.
        log_every_n_steps=1,
        max_steps=max_steps,
        callbacks=[callback],
    )
    trainer.fit(mnist_model)

    # Each logging event in this setup produces two histogram calls (one per tensor).
    tensors_per_log_event = 2
    assert log_histogram.call_count == expected_calls * tensors_per_log_event
def test_mnist(tmpdir, datadir):
    """Smoke-train LitMNIST on ~1% of each split for one epoch and bound its loss."""
    seed_everything()
    mnist_model = LitMNIST(data_dir=datadir, num_workers=0)

    # Tiny fractions of every split keep this test fast.
    tiny_split = 0.01
    trainer = Trainer(
        limit_train_batches=tiny_split,
        limit_val_batches=tiny_split,
        max_epochs=1,
        limit_test_batches=tiny_split,
        default_root_dir=tmpdir,
    )
    trainer.fit(mnist_model)

    final_train_loss = trainer.callback_metrics["train_loss"]
    assert final_train_loss <= 2.2, "mnist failed"
def test_base_log_interval_fallback(
    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir
):
    """
    Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer.

    Args:
        log_histogram: mocked histogram logger whose call count is inspected.
        tmpdir: pytest temporary directory used as the Trainer root dir.
        log_every_n_steps: interval given to the Trainer only; the callback has none.
        max_steps: number of training steps to run.
        expected_calls: expected number of logging events.
        datadir: shared dataset directory fixture, the same one the interval-override
            test and ``test_mnist`` use.
    """
    # No interval on the callback: it must inherit the Trainer's setting.
    monitor = TrainingDataMonitor()
    # Use the shared ``datadir`` fixture like the sibling tests instead of
    # LitMNIST's default data location.
    model = LitMNIST(data_dir=datadir, num_workers=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=log_every_n_steps,
        max_steps=max_steps,
        callbacks=[monitor],
    )
    trainer.fit(model)
    assert log_histogram.call_count == (expected_calls * 2)  # 2 tensors per log call
def test_mnist(tmpdir):
    """Fit and test LitMNIST on ~1% of each split, then bound the reported loss.

    NOTE(review): another ``test_mnist`` (taking ``datadir``) also appears in this
    source; if both end up in the same module, this later definition shadows it
    and pytest will only collect one of them.
    """
    reset_seed()
    mnist_model = LitMNIST(data_dir=tmpdir)

    fraction = 0.01
    trainer = pl.Trainer(
        limit_train_batches=fraction,
        limit_val_batches=fraction,
        max_epochs=1,
        limit_test_batches=fraction,
        default_root_dir=tmpdir,
    )
    trainer.fit(mnist_model)
    trainer.test(mnist_model)

    reported_loss = trainer.callback_metrics['loss']
    assert reported_loss <= 2.0, 'mnist failed'
def test_training_data_monitor(log_histogram, tmpdir):
    """
    Verify the histograms TrainingDataMonitor logs for batches entering training_step.

    Three batch shapes are exercised: a bare tensor, a tuple and a dict. Tensor
    members are logged under a name built from their position/key and shape.
    Non-tensor members do not appear in the expected calls.
    """
    monitor = TrainingDataMonitor()
    model = LitMNIST()
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,
        callbacks=[monitor],
    )
    monitor.on_train_start(trainer, model)

    def feed(batch):
        # Push a single batch through the hook under test.
        monitor.on_train_batch_start(
            trainer, model, batch=batch, batch_idx=0, dataloader_idx=0
        )

    # Case 1: a bare tensor is logged directly under "training_step/<shape>".
    lone_tensor = torch.rand(2, 3, 4)
    feed(lone_tensor)
    assert log_histogram.call_args_list == [
        call(lone_tensor, "training_step/[2, 3, 4]"),
    ]
    log_histogram.reset_mock()

    # Case 2: tuple members are named by their index.
    tuple_batch = (torch.rand(2, 3, 4), torch.rand(5), "non-tensor")
    feed(tuple_batch)
    first, second = tuple_batch[0], tuple_batch[1]
    assert log_histogram.call_args_list == [
        call(first, "training_step/0/[2, 3, 4]"),
        call(second, "training_step/1/[5]"),
    ]
    log_histogram.reset_mock()

    # Case 3: dict members are named by their key.
    dict_batch = {
        "x0": torch.rand(2, 3, 4),
        "x1": torch.rand(5),
        "non-tensor": "non-tensor",
    }
    feed(dict_batch)
    assert log_histogram.call_args_list == [
        call(dict_batch["x0"], "training_step/x0/[2, 3, 4]"),
        call(dict_batch["x1"], "training_step/x1/[5]"),
    ]