コード例 #1
0
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    """
    Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
    limited by "limit_val_batches" Trainer argument.

    Args:
        tmpdir: pytest temporary directory used as the Trainer's ``default_root_dir``.
        limit_val_batches: forwarded to the Trainer; presumably supplied by a
            ``pytest.mark.parametrize`` decorator that is not visible here.
    """
    # Local import keeps this test self-contained regardless of module-level imports.
    from unittest.mock import patch

    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=-1,
        limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    # The Trainer stores -1 as "unbounded".
    assert trainer.num_sanity_val_steps == float('inf')
    val_dataloaders = model.val_dataloader__multiple()

    # wraps= keeps the real evaluation_forward running while the mock counts its calls.
    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
        trainer.fit(model, val_dataloaders=val_dataloaders)
        # num_val_batches presumably holds one limit-capped batch count per val dataloader.
        assert mocked.call_count == sum(trainer.num_val_batches)
コード例 #2
0
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
    """
    Test that num_sanity_val_steps=-1 runs through all validation data once.

    The expected number of evaluation calls is the full length of every validation
    dataloader; it is NOT reduced by a positive ``limit_val_batches``. The exception is a
    non-positive ``limit_val_batches``, for which this test expects no evaluation calls.

    Args:
        tmpdir: pytest temporary directory used as the Trainer's ``default_root_dir``.
        limit_val_batches: forwarded to the Trainer; presumably supplied by a
            ``pytest.mark.parametrize`` decorator that is not visible here.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
    trainer = Trainer(
        default_root_dir=tmpdir,
        num_sanity_val_steps=-1,
        # Only a non-positive value changes the expected call count asserted below.
        limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    # The Trainer stores -1 as "unbounded".
    assert trainer.num_sanity_val_steps == float('inf')
    val_dataloaders = model.val_dataloader__multiple()

    # wraps= keeps the real evaluation_forward running while the mock counts its calls.
    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
        trainer.fit(model, val_dataloaders=val_dataloaders)
        total_batches = sum(len(dl) for dl in val_dataloaders)
        expected_calls = total_batches if limit_val_batches > 0 else 0
        assert mocked.call_count == expected_calls
コード例 #3
0
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    """
    Check that ``num_sanity_val_steps=-1`` makes the sanity check visit every validation
    batch once, with the per-dataloader batch count capped by the ``limit_val_batches``
    Trainer argument.

    Args:
        tmpdir: pytest temporary directory used as the Trainer's ``default_root_dir``.
        limit_val_batches: forwarded to the Trainer; presumably supplied by a
            ``pytest.mark.parametrize`` decorator that is not visible here.
    """
    model = EvalModelTemplate()
    # Swap in the multi-dataloader variants of the template's validation hooks.
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        num_sanity_val_steps=-1,
        limit_val_batches=limit_val_batches,
        max_steps=1,
    )
    trainer = Trainer(**trainer_kwargs)
    # The Trainer stores -1 as "unbounded".
    assert trainer.num_sanity_val_steps == float('inf')

    loaders = model.val_dataloader__multiple()

    # Spy rather than stub: wraps= forwards every call to the real evaluation_forward.
    real_forward = trainer.evaluation_forward
    with patch.object(trainer, 'evaluation_forward', wraps=real_forward) as forward_spy:
        trainer.fit(model, val_dataloaders=loaders)
        # num_val_batches presumably holds one limit-capped batch count per val dataloader.
        expected_calls = sum(trainer.num_val_batches)
        assert forward_spy.call_count == expected_calls