def test_task_predict_raises():
    """Calling the removed ``Task.predict`` raises an informative AttributeError."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    task = ClassificationTask(model, loss_fn=F.nll_loss)
    # Only the call under test sits inside ``pytest.raises``: a failure while building
    # the model/task should surface as its own error, not as a message mismatch.
    # ``match`` is applied with ``re.search``, so the dots are escaped to match literally.
    with pytest.raises(AttributeError, match=r"`flash\.Task\.predict` has been removed\."):
        task.predict("args", kwarg="test")
def test_classificationtask_task_predict():
    """Single-sample and batched predictions yield valid, mutually consistent labels."""
    # NOTE(review): ``test_task_predict_raises`` in this file expects ``Task.predict``
    # to raise AttributeError, while this test expects it to return predictions.
    # Both cannot pass against the same ``ClassificationTask`` — confirm which API
    # version this test targets.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    clf_task = ClassificationTask(net, preprocess=DefaultPreprocess())
    dataset = DummyDataset()
    valid_classes = [*range(10)]

    # Predict on one sample (the dataset item's label is ignored).
    first_input, _ = dataset[0]
    single_prediction = clf_task.predict(first_input)
    assert single_prediction[0] in valid_classes

    # Predict on a list of two samples.
    second_input, _ = dataset[1]
    batch_predictions = clf_task.predict([first_input, second_input])
    for label in batch_predictions:
        assert label in valid_classes

    # The first sample's label must not change between single and batched prediction.
    assert single_prediction[0] == batch_predictions[0]