def test_forward(num_classes, shape):
    """Check that a forward pass on a random batch yields the expected output shape.

    Args:
        num_classes: Number of classes the classifier is built with; also the
            expected size of the second output dimension.
        shape: Indexable describing the input tensor; ``shape[0]`` is used as the
            expected leading output dimension and ``shape[1]`` as the number of
            input features.
    """
    classifier = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    classifier.eval()

    # The random input is drawn after the model is built, so RNG use stays in the same order.
    predictions = classifier(torch.rand(*shape))

    expected_shape = (shape[0], num_classes)
    assert predictions.shape == expected_shape
def test_jit(tmpdir, jitter, args):
    """Check that the classifier survives JIT conversion and a save/load round trip.

    The model is converted with ``jitter``, written to ``test.pt`` inside
    ``tmpdir``, loaded back, and run on a random ``(1, 16)`` input. The output
    must be a tensor of shape ``(1, 10)``.

    Args:
        tmpdir: Directory in which the serialized model file is written.
        jitter: Callable invoked as ``jitter(model, *args)``; its result is what
            gets saved. NOTE(review): presumably ``torch.jit.script`` or
            ``torch.jit.trace`` - confirm against the test parametrization.
        args: Extra positional arguments forwarded to ``jitter``.
    """
    feature_count, class_count = 16, 10

    eager_model = TemplateSKLearnClassifier(num_features=feature_count, num_classes=class_count)
    eager_model.eval()
    converted = jitter(eager_model, *args)

    # Round-trip through disk so the assertions run against the reloaded artifact.
    checkpoint = os.path.join(tmpdir, "test.pt")
    torch.jit.save(converted, checkpoint)
    restored = torch.jit.load(checkpoint)

    result = restored(torch.rand(1, feature_count))
    assert isinstance(result, torch.Tensor)
    assert result.shape == torch.Size([1, class_count])