def test_predict_sklearn():
    """Check that ``predict`` accepts a scikit-learn ``Bunch`` via the ``"sklearn"`` data source.

    Builds an untrained ``TemplateSKLearnClassifier`` sized from ``DummyDataset``, runs it on
    the iris ``Bunch`` through a ``DataPipeline`` wrapping ``TemplatePreprocess``, and asserts
    that the first prediction is a Python ``int``.
    """
    iris_bunch = datasets.load_iris()
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())

    # ``data_source="sklearn"`` tells predict how to interpret the Bunch input.
    predictions = classifier.predict(
        iris_bunch,
        data_source="sklearn",
        data_pipeline=pipeline,
    )

    first_prediction = predictions[0]
    assert isinstance(first_prediction, int)
def test_predict_numpy():
    """Check that ``predict`` accepts a raw numpy array (no explicit data source).

    A single random row with ``DummyDataset.num_features`` columns is classified by an
    untrained ``TemplateSKLearnClassifier``; the first prediction must be a Python ``int``.
    """
    n_features = DummyDataset.num_features
    # Shape (1, n_features): one sample.
    sample = np.random.rand(1, n_features)

    classifier = TemplateSKLearnClassifier(
        num_features=n_features,
        num_classes=DummyDataset.num_classes,
    )
    pipeline = DataPipeline(preprocess=TemplatePreprocess())

    predictions = classifier.predict(sample, data_pipeline=pipeline)

    first_prediction = predictions[0]
    assert isinstance(first_prediction, int)
# Build the DataModule from the scikit-learn ``Bunch`` (``data_bunch``) created earlier in the script.
# NOTE(review): ``val_split=0.8`` is kept as-is; confirm this value is the intended split.
datamodule = TemplateData.from_sklearn(train_bunch=data_bunch, val_split=0.8)

# 3. Build the model, sizing it from the DataModule and emitting labels via the serializer.
model = TemplateSKLearnClassifier(
    num_features=datamodule.num_features,
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)

# 4. Create the trainer: a single epoch over a single batch of train and val data.
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 5. Train the model.
trainer.fit(model, datamodule=datamodule)

# 6. Persist the trained weights.
CHECKPOINT_PATH = "template_model.pt"
trainer.save_checkpoint(CHECKPOINT_PATH)

# 7. Classify a few hand-written four-feature samples.
examples = [
    np.array([4.9, 3.0, 1.4, 0.2]),
    np.array([6.9, 3.2, 5.7, 2.3]),
    np.array([7.2, 3.0, 5.8, 1.6]),
]
predictions = model.predict(examples)
print(predictions)