def test_classificationtask_task_predict():
    """``task.predict`` accepts one sample or a list of samples and returns class ids."""
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    task = ClassificationTask(backbone)
    dataset = DummyDataset()
    valid_classes = list(range(10))

    # predicting a single sample
    first_sample, _ = dataset[0]
    single_pred = task.predict(first_sample)
    assert single_pred[0] in valid_classes

    # predicting a list of samples
    second_sample, _ = dataset[1]
    batch_pred = task.predict([first_sample, second_sample])
    assert all(cls in valid_classes for cls in batch_pred)

    # the same sample must get the same class whether predicted alone or in a list
    assert single_pred[0] == batch_pred[0]
def test_task_finetune(tmpdir: str):
    """``Trainer.finetune`` with the ``NoFreeze`` strategy completes a fast dev run."""
    classifier = DummyClassifier()
    # identical, independent train / validation loaders over the dummy data
    train_loader, val_loader = (
        torch.utils.data.DataLoader(DummyDataset()) for _ in range(2)
    )
    task = ClassificationTask(classifier, F.nll_loss)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    outcome = trainer.finetune(task, train_loader, val_loader, strategy=NoFreeze())
    assert outcome
def test_task_fit(tmpdir: str):
    """Plain ``Trainer.fit`` on a ``ClassificationTask`` completes a fast dev run."""
    # ``F.nll_loss`` is paired with a LogSoftmax head. ``dim=1`` is the class
    # axis of the 2-D output of Flatten -> Linear; it is the axis PyTorch's
    # deprecated implicit-dim rule picks for 2-D input, so stating it keeps
    # the output identical while avoiding the implicit-dim deprecation warning.
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10),
        nn.LogSoftmax(dim=1),
    )
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
    task = ClassificationTask(model, F.nll_loss)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    result = trainer.fit(task, train_dl, val_dl)
    assert result
def test_classificationtask_train(tmpdir: str, metrics: Any):
    """Fit then test a ``ClassificationTask`` and check the test loss is reported."""
    # ``dim=1`` is the class axis of the 2-D Flatten -> Linear output; it is
    # what the deprecated implicit-dim rule chose, so the output is unchanged
    # but the implicit-dimension deprecation warning goes away.
    # NOTE(review): ``F.nll_loss`` expects log-probabilities, yet this head
    # emits probabilities. Kept as Softmax because ``metrics`` may require
    # values in [0, 1] — confirm whether LogSoftmax is acceptable here.
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10),
        nn.Softmax(dim=1),
    )
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
    task = ClassificationTask(model, F.nll_loss, metrics=metrics)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    result = trainer.fit(task, train_dl, val_dl)
    assert result
    result = trainer.test(task, val_dl)
    assert "test_nll_loss" in result[0]
def test_classification_task_trainer_predict(tmpdir):
    """``Trainer.predict`` returns one prediction list per batch, each holding class ids."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    task = ClassificationTask(model)
    ds = PredictDummyDataset()
    batch_size = 3
    predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    trainer = pl.Trainer(default_root_dir=tmpdir)
    predictions = trainer.predict(task, predict_dl)

    # The DataLoader keeps a trailing partial batch (drop_last defaults to
    # False), so expect ``len(ds) // batch_size`` full batches plus one of
    # size ``len(ds) % batch_size`` when that remainder is non-zero.
    num_full, remainder = divmod(len(ds), batch_size)
    expected_sizes = [batch_size] * num_full + ([remainder] if remainder else [])

    assert len(predictions) == len(expected_sizes)
    for batch_pred, expected_size in zip(predictions, expected_sizes):
        assert len(batch_pred) == expected_size
        assert all(y < 10 for y in batch_pred)
def test_task_datapipeline_save(tmpdir):
    """An attribute set on ``task.data_pipeline`` survives a checkpoint round-trip."""
    backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    loader = torch.utils.data.DataLoader(DummyDataset())
    task = ClassificationTask(backbone, F.nll_loss)

    # tag the pipeline so it can be recognised after reloading
    task.data_pipeline.test = True

    # train for a single batch so there is something to checkpoint
    trainer_options = dict(
        default_root_dir=tmpdir,
        limit_train_batches=1,
        max_epochs=1,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        logger=False,
    )
    trainer = pl.Trainer(**trainer_options)
    trainer.fit(task, loader)

    ckpt_file = str(tmpdir / "model.ckpt")
    trainer.save_checkpoint(ckpt_file)

    # restore from disk and confirm the tag came back with it
    restored = ClassificationTask.load_from_checkpoint(ckpt_file, model=backbone)
    assert restored.data_pipeline.test
def test_classificationtask_trainer_predict(tmpdir):
    """``Trainer.predict`` batches predictions correctly and yields valid class ids."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    task = ClassificationTask(model)
    ds = DummyDataset()
    batch_size = 3
    predict_dl = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        collate_fn=task.data_pipeline.collate_fn,
    )
    trainer = pl.Trainer(default_root_dir=tmpdir)
    expected = list(range(10))
    predictions = trainer.predict(task, predict_dl)
    predictions = predictions[0]  # TODO(tchaton): why do we need this?

    # every batch except the last must be full
    for pred in predictions[:-1]:
        assert len(pred) == batch_size
    # The last batch holds the remainder, or is full when len(ds) is an exact
    # multiple of batch_size (the bare ``%`` would wrongly expect 0 there).
    assert len(predictions[-1]) == (len(ds) % batch_size or batch_size)

    # every prediction, including those in the last batch, is a valid class
    for pred in predictions:
        assert all(c in expected for c in pred)
# 1. Load a basic backbone
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# 2. Load a dataset
dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())

# 3. Split the data randomly (the three sizes must sum to len(dataset))
train, val, test = random_split(dataset, [50000, 5000, 5000])  # type: ignore

# 4. Create the model
# NOTE(review): 1e-2 (previously written as ``10e-3``, the same value) is ten
# times Adam's common default of 1e-3 — confirm this learning rate is intended.
classifier = ClassificationTask(
    model,
    loss_fn=nn.functional.cross_entropy,
    optimizer=optim.Adam,
    learning_rate=1e-2,
)

# 5. Create the trainer
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=128,
    limit_val_batches=128,
)

# 6. Train the model
# DataLoader defaults to batch_size=1 and no shuffling, which would make each
# limited epoch see only 128 images in a fixed order; batch and shuffle the
# training data instead.
batch_size = 64
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size)
trainer.fit(classifier, train_loader, val_loader)

# 7. Test the model
# The loader is passed positionally because the keyword name differs across
# Lightning versions (``test_dataloaders`` vs ``dataloaders``).
results = trainer.test(classifier, DataLoader(test, batch_size=batch_size))