Example #1
0
def test_model_saves_with_example_output(tmpdir):
    """Test that ONNX model saves when provided with example output.

    Briefly fits a ``BoringModel`` with ``fast_dev_run=True``, then computes
    the model's output on a random ``(1, 32)`` input in eval mode. It exports
    to ONNX, passing that output as ``example_outputs``. Only the existence of
    the written file is checked.
    """
    model = BoringModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)

    file_path = os.path.join(tmpdir, "model.onnx")
    input_sample = torch.randn((1, 32))
    model.eval()
    # Call the module rather than ``forward`` directly so any registered hooks
    # run. Autograd is disabled because the output is only an export example.
    with torch.no_grad():
        example_outputs = model(input_sample)
    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
    assert os.path.exists(file_path)


def test_eval_loop_config(tmpdir):
    """Evaluation entry points fail clearly when an eval step or its data is missing.

    Each case starts from a fresh ``BoringModel`` and disables exactly one hook
    by setting it to ``None``. It then checks that the matching ``Trainer``
    entry point raises ``MisconfigurationException`` with a message matching
    the given regex.
    """
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # Each entry: (hook disabled on the model, trainer entry point, message regex).
    cases = [
        # has val step but no val data
        ("val_dataloader", trainer.validate,
         r"No `val_dataloader\(\)` method defined"),
        # has val data but no val step
        ("validation_step", trainer.validate,
         r"No `validation_step\(\)` method defined"),
        # has test step but no test data
        ("test_dataloader", trainer.test,
         r"No `test_dataloader\(\)` method defined"),
        # has test data but no test step
        ("test_step", trainer.test,
         r"No `test_step\(\)` method defined"),
        # has predict step but no predict data
        ("predict_dataloader", trainer.predict,
         r"No `predict_dataloader\(\)` method defined"),
        # has predict data but no predict_step
        ("predict_step", trainer.predict,
         r"`predict_step` cannot be None."),
        # has predict data but no forward
        ("forward", trainer.predict,
         r"requires `forward` method to run."),
    ]

    for hook_name, run_entry_point, pattern in cases:
        # A fresh model per case keeps each disabled hook isolated.
        model = BoringModel()
        setattr(model, hook_name, None)
        with pytest.raises(MisconfigurationException, match=pattern):
            run_entry_point(model)