Exemplo n.º 1
0
def test_train(data_format, grayscale, tmp_path):
    """Dry-run train a resnet18 wrapper and check that a model file is written.

    The same dataset is passed as train, validation and test data.
    """
    # TODO: Test for grayscale = False and a different size (e.g. a size=28,
    #   grayscale=True dataset with the "simple-cnn" model).
    # TODO: Test both resnet18 and simple-cnn with a few different
    #   configurations of data.
    # TODO: Test with and without val/test data.
    # TODO: Check that something was written to writer and experiment.
    dataset = create_dataset(
        data_format=data_format,
        grayscale=grayscale,
        tmp_path=tmp_path,
    )
    model_wrapper = TorchImageClassificationWrapper("resnet18", {}, tmp_path)

    # The writer is created with write_to_disk=False.
    summary_writer = SummaryWriter(write_to_disk=False)
    dummy_experiment = DummyExperiment()
    model_wrapper._train(
        train_data=dataset,
        val_data=dataset,
        test_data=dataset,
        writer=summary_writer,
        experiment=dummy_experiment,
        dry_run=True,
    )

    saved_model = tmp_path / "model.pt"
    assert isinstance(model_wrapper.model, nn.Module)
    assert saved_model.exists()
Exemplo n.º 2
0
def test_load(tmp_path):
    """Train a random-forest model into tmp_path, then load it back."""
    dataset = create_dataset()
    save_dir = tmp_path

    # Dry-run training with save pointing at the temporary directory.
    train(
        "random-forest",
        train_data=dataset,
        test_data=dataset,
        dry_run=True,
        save=save_dir,
    )

    restored = load(save_dir)
    assert isinstance(restored, ModelWrapper)
    assert restored.model is not None
Exemplo n.º 3
0
def wrapper(tmp_path):
    """Return a resnet18 TorchImageClassificationWrapper after a dry-run training.

    Training uses a numpy-format, non-grayscale dataset of size 224 and no
    validation or test data.
    """
    dataset = create_dataset(data_format="numpy", grayscale=False, size=224)
    model_wrapper = TorchImageClassificationWrapper("resnet18", {}, tmp_path)

    summary_writer = SummaryWriter(write_to_disk=False)
    dummy_experiment = DummyExperiment()
    model_wrapper._train(
        train_data=dataset,
        val_data=None,
        test_data=None,
        writer=summary_writer,
        experiment=dummy_experiment,
        dry_run=True,
    )
    return model_wrapper
Exemplo n.º 4
0
def wrapper(tmp_path):
    """Return a random-forest SklearnImageClassificationWrapper after a dry run.

    Training uses a non-grayscale dataset and no validation or test data.
    """
    dataset = create_dataset(grayscale=False)
    model_wrapper = SklearnImageClassificationWrapper(
        "random-forest", {}, tmp_path
    )

    summary_writer = SummaryWriter(write_to_disk=False)
    dummy_experiment = DummyExperiment()
    model_wrapper._train(
        train_data=dataset,
        val_data=None,
        test_data=None,
        writer=summary_writer,
        experiment=dummy_experiment,
        dry_run=True,
    )
    return model_wrapper
Exemplo n.º 5
0
def test_load_image(tmp_path):
    """Check the type and shape returned by load_image for torch and numpy."""
    dataset_dir = create_dataset(
        grayscale=False, data_format="files", tmp_path=tmp_path
    )

    # Take the first PNG found anywhere below the dataset directory.
    png_path = next(dataset_dir.rglob("*.png"))

    # First the default call (torch tensor expected), then to_numpy=True.
    cases = (
        ({}, torch.Tensor),
        ({"to_numpy": True}, np.ndarray),
    )
    for extra_kwargs, expected_type in cases:
        image = load_image(png_path, resize=50, crop=40, **extra_kwargs)
        assert isinstance(image, expected_type)
        assert image.shape == (3, 40, 40)
Exemplo n.º 6
0
def test_train(data_format, grayscale, tmp_path):
    """Dry-run train a random-forest wrapper and check model, scaler and file.

    The same size-28 dataset is passed as train, validation and test data.
    """
    dataset = create_dataset(
        data_format=data_format,
        grayscale=grayscale,
        size=28,
        tmp_path=tmp_path,
    )
    model_wrapper = SklearnImageClassificationWrapper(
        "random-forest", {}, tmp_path
    )

    summary_writer = SummaryWriter(write_to_disk=False)
    dummy_experiment = DummyExperiment()
    model_wrapper._train(
        train_data=dataset,
        val_data=dataset,
        test_data=dataset,
        writer=summary_writer,
        experiment=dummy_experiment,
        dry_run=True,
    )

    saved_model = tmp_path / "model.joblib"
    assert isinstance(model_wrapper.model, RandomForestClassifier)
    assert model_wrapper.scaler is not None
    assert saved_model.exists()
Exemplo n.º 7
0
def test_train(tmp_path):
    """Exercise train() with an invalid model name, without and with saving."""
    dataset = create_dataset()

    # An unknown model name is expected to raise ValueError.
    with pytest.raises(ValueError):
        train("non-existing-model-123", None)

    # TODO: A check that an unknown config parameter, e.g.
    #   config={"non-existing-parameter": 123}, raises ValueError is disabled.

    # The save=False run has to come first, while tmp_path is still empty.
    unsaved_wrapper = train(
        "random-forest", train_data=dataset, dry_run=True, save=False
    )
    assert isinstance(unsaved_wrapper, ModelWrapper)
    assert not list(tmp_path.iterdir())  # directory must still be empty

    # Saving runs: first with all datasets, then with only train data.
    info_file = tmp_path / "info.yml"
    model_file = tmp_path / "model.joblib"
    optional_data = (
        {"val_data": dataset, "test_data": dataset},
        {},
    )
    for extra_data in optional_data:
        saved_wrapper = train(
            "random-forest",
            train_data=dataset,
            dry_run=True,
            save=tmp_path,
            **extra_data,
        )
        assert isinstance(saved_wrapper, ModelWrapper)
        assert info_file.exists()
        assert model_file.exists()
Exemplo n.º 8
0
def files_data(tmp_path):
    """Return a "files"-format dataset created with seed 0 and tmp_path."""
    dataset = create_dataset(tmp_path=tmp_path, seed=0, data_format="files")
    return dataset
Exemplo n.º 9
0
def torch_data():
    """Return a "torch"-format, non-grayscale dataset created with seed 0."""
    dataset = create_dataset(grayscale=False, seed=0, data_format="torch")
    return dataset
Exemplo n.º 10
0
def numpy_data():
    """Return a "numpy"-format, non-grayscale dataset created with seed 0."""
    dataset = create_dataset(grayscale=False, seed=0, data_format="numpy")
    return dataset