Пример #1
0
def test_tbd_remove_in_v1_0_0_model_hooks():
    """Legacy (pre-1.0) model hook templates still run but emit deprecation warnings.

    For each legacy template, both ``Trainer.test`` and the internal
    ``Trainer._evaluate`` must emit a deprecation warning whose message matches
    ``'v1.0'``. The resulting metrics must equal the template's value
    (0.6 for ``ModelVer0_6``, 0.7 for ``ModelVer0_7``).
    """
    default_hparams = EvalModelTemplate.get_default_hparams()

    for legacy_cls, expected in ((ModelVer0_6, 0.6), (ModelVer0_7, 0.7)):
        legacy_model = legacy_cls(default_hparams)

        # `test` path: the legacy test hook must warn and still populate callback metrics.
        with pytest.deprecated_call(match='v1.0'):
            test_trainer = Trainer(logger=False)
            test_trainer.test(legacy_model)
        assert test_trainer.callback_metrics == {'test_loss': expected}

        # Validation path via the internal `_evaluate` entry point.
        with pytest.deprecated_call(match='v1.0'):
            eval_trainer = Trainer(logger=False)
            # TODO: why is the `dataloaders` argument required if it is never used?
            eval_result = eval_trainer._evaluate(legacy_model, dataloaders=[[None]], max_batches=1)
        assert eval_result == {'val_loss': expected}
def test_tbd_remove_in_v1_0_0_model_hooks():
    """Legacy (pre-1.0) model hook templates still run but warn about removal in v1.0.

    For each legacy template (``ModelVer0_6`` -> 0.6, ``ModelVer0_7`` -> 0.7), both
    ``Trainer.test`` and the internal ``Trainer._evaluate`` must emit a deprecation
    warning matching ``'will be removed in v1.0'``. The resulting metrics must
    equal the template's value as a tensor.
    """
    # One shared pattern, so every hook call is held to the same, specific
    # deprecation message. Previously the first `test` call used the looser
    # 'v1.0', which any v1.0-related warning would satisfy.
    deprecation_msg = 'will be removed in v1.0'

    for legacy_cls, value in ((ModelVer0_6, 0.6), (ModelVer0_7, 0.7)):
        model = legacy_cls()

        with pytest.deprecated_call(match=deprecation_msg):
            trainer = Trainer(logger=False)
            trainer.test(model)
        assert trainer.callback_metrics == {'test_loss': torch.tensor(value)}

        with pytest.deprecated_call(match=deprecation_msg):
            trainer = Trainer(logger=False)
            # TODO: why is the `dataloaders` argument required if it is never used?
            result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
        assert result == {'val_loss': torch.tensor(value)}
Пример #3
0
def test_tbd_remove_in_v1_0_0_model_hooks():
    """Legacy (pre-1.0) model hook templates still produce their reported metrics.

    Runs ``Trainer.test`` and the internal ``Trainer._evaluate`` on each legacy
    template. Checks that the resulting metrics equal the template's value
    (0.6 for ``ModelVer0_6``, 0.7 for ``ModelVer0_7``).

    NOTE(review): despite the test name, no deprecation warning is asserted here.
    Consider wrapping the calls in ``pytest.deprecated_call`` once the hooks warn.
    """
    hparams = tutils.get_default_hparams()

    for model_cls, loss_value in ((ModelVer0_6, 0.6), (ModelVer0_7, 0.7)):
        model = model_cls(hparams)

        tester = Trainer(logger=False)
        tester.test(model)
        assert tester.callback_metrics == {'test_loss': loss_value}

        evaluator = Trainer(logger=False)
        # TODO: why is the `dataloaders` argument required if it is never used?
        outcome = evaluator._evaluate(model, dataloaders=[[None]], max_batches=1)
        assert outcome == {'val_loss': loss_value}