def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
    """Test logging interval set by log_every_n_steps argument."""
    monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
    model = LitMNIST(data_dir=datadir, num_workers=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=1,
        max_steps=max_steps,
        callbacks=[monitor],
    )

    trainer.fit(model)
    assert log_histogram.call_count == (expected_calls * 2)  # 2 tensors per log call
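# The ``log_histogram`` mock and the ``log_every_n_steps``/``max_steps``/``expected_calls``
# arguments are not defined in this excerpt; tests written in this style usually receive
# them from ``mock.patch`` and ``pytest.mark.parametrize`` decorators. A minimal sketch
# under that assumption -- the patch target and parameter values below are illustrative,
# not taken from the source:
import pytest
from unittest import mock


@pytest.mark.parametrize("log_every_n_steps,max_steps,expected_calls", [(1, 5, 5), (5, 10, 2)])
@mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
def sketch_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls):
    ...  # body identical to test_base_log_interval_override above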
Example #2
def test_mnist(tmpdir, datadir):
    seed_everything()

    model = LitMNIST(data_dir=datadir, num_workers=0)
    trainer = Trainer(
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        max_epochs=1,
        limit_test_batches=0.01,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)
    loss = trainer.callback_metrics["train_loss"]
    assert loss <= 2.2, "mnist failed"
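# The tests in this excerpt rely on imports that are not shown. A plausible set, assuming
# the classes come from PyTorch Lightning and Lightning Bolts (module paths are an
# assumption, not taken from the source):
import torch
from unittest.mock import call

from pytorch_lightning import Trainer, seed_everything
from pl_bolts.callbacks import TrainingDataMonitor
from pl_bolts.models import LitMNIST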
def test_base_log_interval_fallback(
    log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls
):
    """ Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer. """
    monitor = TrainingDataMonitor()
    model = LitMNIST(num_workers=0)
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=log_every_n_steps,
        max_steps=max_steps,
        callbacks=[monitor],
    )
    trainer.fit(model)
    assert log_histogram.call_count == (expected_calls * 2)  # 2 tensors per log call
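# What the fallback test above checks, written as a usage sketch: a monitor constructed
# without ``log_every_n_steps`` should pick up the interval configured on the Trainer
# (hypothetical helper, not part of the original test suite):
def sketch_fallback_interval(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        log_every_n_steps=10,               # interval defined on the Trainer
        max_steps=20,
        callbacks=[TrainingDataMonitor()],  # no interval given, falls back to 10
    )
    trainer.fit(LitMNIST(num_workers=0))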
Example #4
def test_mnist(tmpdir):
    reset_seed()

    model = LitMNIST(data_dir=tmpdir)
    trainer = pl.Trainer(limit_train_batches=0.01,
                         limit_val_batches=0.01,
                         max_epochs=1,
                         limit_test_batches=0.01,
                         default_root_dir=tmpdir)
    trainer.fit(model)
    trainer.test(model)
    loss = trainer.callback_metrics['loss']

    assert loss <= 2.0, 'mnist failed'
def test_training_data_monitor(log_histogram, tmpdir):
    """ Test that the TrainingDataMonitor logs histograms of data points going into training_step. """
    monitor = TrainingDataMonitor()
    model = LitMNIST()
    trainer = Trainer(
        default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor],
    )
    monitor.on_train_start(trainer, model)

    # single tensor
    example_data = torch.rand(2, 3, 4)
    monitor.on_train_batch_start(
        trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0
    )
    assert log_histogram.call_args_list == [
        call(example_data, "training_step/[2, 3, 4]"),
    ]

    log_histogram.reset_mock()

    # tuple
    example_data = (torch.rand(2, 3, 4), torch.rand(5), "non-tensor")
    monitor.on_train_batch_start(
        trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0
    )
    assert log_histogram.call_args_list == [
        call(example_data[0], "training_step/0/[2, 3, 4]"),
        call(example_data[1], "training_step/1/[5]"),
    ]

    log_histogram.reset_mock()

    # dict
    example_data = {
        "x0": torch.rand(2, 3, 4),
        "x1": torch.rand(5),
        "non-tensor": "non-tensor",
    }
    monitor.on_train_batch_start(
        trainer, model, batch=example_data, batch_idx=0, dataloader_idx=0
    )
    assert log_histogram.call_args_list == [
        call(example_data["x0"], "training_step/x0/[2, 3, 4]"),
        call(example_data["x1"], "training_step/x1/[5]"),
    ]
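# Outside the test suite, the callback is attached to a Trainer in the usual way. A
# minimal sketch using the same classes as above (import paths as assumed earlier in
# this excerpt):
from pytorch_lightning import Trainer
from pl_bolts.callbacks import TrainingDataMonitor
from pl_bolts.models import LitMNIST

# Log a histogram of every tensor entering training_step every 25 global steps.
monitor = TrainingDataMonitor(log_every_n_steps=25)
trainer = Trainer(max_epochs=1, callbacks=[monitor])
trainer.fit(LitMNIST())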