def test_predict_sklearn():
    """Check that predictions can be produced from a scikit-learn ``Bunch``.

    The iris ``Bunch`` is passed to ``predict`` with ``data_source="sklearn"``
    through a ``DataPipeline`` built on ``TemplatePreprocess``; the first
    prediction must be a plain ``int``.
    """
    iris = datasets.load_iris()
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(
        iris,
        data_source="sklearn",
        data_pipeline=pipeline,
    )
    first = predictions[0]
    assert isinstance(first, int)


def test_predict_numpy():
    """Check that predictions can be produced from a plain numpy array.

    A single random row with ``DummyDataset.num_features`` columns is passed
    to ``predict`` without an explicit data source; the first prediction must
    be a plain ``int``.
    """
    sample = np.random.rand(1, DummyDataset.num_features)
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())
    predictions = classifier.predict(sample, data_pipeline=pipeline)
    first = predictions[0]
    assert isinstance(first, int)

# Example #3
0
# 1. Load the data: the iris dataset as a scikit-learn ``Bunch``.
# (``data_bunch`` was previously used below without ever being defined.)
data_bunch = datasets.load_iris()

# 2. Wrap the data in a datamodule.
# NOTE(review): if ``val_split`` is the validation fraction, 0.8 holds out most
# of the data for validation — confirm this is intended.
datamodule = TemplateData.from_sklearn(
    train_bunch=data_bunch,
    val_split=0.8,
)

# 3. Build the model, sizing it from the datamodule. ``Labels`` presumably maps
# raw model outputs to class labels — verify against its definition.
model = TemplateSKLearnClassifier(
    num_features=datamodule.num_features,
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)

# 4. Create the trainer: one epoch, limited to one training and one validation batch.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)

# 5. Train the model.
trainer.fit(model, datamodule=datamodule)

# 6. Save it!
trainer.save_checkpoint("template_model.pt")

# 7. Classify a few examples (four feature values per sample).
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)