Ejemplo n.º 1
0
def test_task_predict_raises():
    """Calling ``predict`` on a ClassificationTask must raise AttributeError.

    The expected error message states that ``flash.Task.predict`` has been
    removed. Only the ``predict`` call is wrapped in ``pytest.raises`` so
    that an ``AttributeError`` raised while building the model or the task
    cannot make this test pass by accident.
    """
    # dim=1 is the dimension Softmax would pick implicitly for the 2-D output
    # of the Linear layer; stating it avoids the implicit-dim deprecation
    # warning without changing the model.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10),
                          nn.Softmax(dim=1))
    task = ClassificationTask(model, loss_fn=F.nll_loss)
    # `match` is applied with re.search, so the dots are escaped to be
    # matched literally rather than as "any character".
    with pytest.raises(AttributeError,
                       match=r"`flash\.Task\.predict` has been removed\."):
        task.predict("args", kwarg="test")
Ejemplo n.º 2
0
def test_classificationtask_task_predict():
    """Check ClassificationTask.predict on a single sample and on a list.

    Every predicted label must be one of the 10 output classes. The
    prediction for the first sample must be the same whether it is
    predicted alone or as the first element of a list.
    """
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    clf = ClassificationTask(net, preprocess=DefaultPreprocess())
    dataset = DummyDataset()
    valid_labels = list(range(10))

    # Predict a lone sample.
    first_sample, _first_target = dataset[0]
    single_out = clf.predict(first_sample)
    assert single_out[0] in valid_labels

    # Predict a list holding two samples.
    second_sample, _second_target = dataset[1]
    batch_out = clf.predict([first_sample, second_sample])
    for label in batch_out:
        assert label in valid_labels

    # The first sample's prediction must not depend on being batched.
    assert single_out[0] == batch_out[0]