def test_logistic_regression_model(tmpdir, datadir):
    """Fit ``LogisticRegression`` on the MNIST datamodule and check test accuracy.

    The global seed is fixed to 0, the model is trained for three epochs and
    then tested; the test passes when the ``test_acc`` entry of the trainer's
    progress-bar dict is at least 0.9.

    Args:
        tmpdir: Directory handed to the trainer as ``default_root_dir``.
        datadir: Directory handed to ``MNISTDataModule`` as ``data_dir``.
    """
    pl.seed_everything(0)

    mnist = MNISTDataModule(num_workers=0, data_dir=datadir)
    classifier = LogisticRegression(
        input_dim=28 * 28,
        num_classes=10,
        learning_rate=0.001,
    )

    # The trainer below is only given the model, so the datamodule's data
    # hooks are attached to the model itself (same order as the hooks list).
    data_hooks = (
        "prepare_data",
        "setup",
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
    )
    for hook_name in data_hooks:
        setattr(classifier, hook_name, getattr(mnist, hook_name))

    trainer_options = dict(
        max_epochs=3,
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        logger=False,
        checkpoint_callback=False,
    )
    trainer = pl.Trainer(**trainer_options)

    trainer.fit(classifier)
    trainer.test(classifier)

    test_accuracy = trainer.progress_bar_dict['test_acc']
    assert test_accuracy >= 0.9
def objective(trial):
    """Train one ``LogisticRegression`` configuration and return its last ``val_acc``.

    Hyperparameters are drawn from ``trial`` in a fixed order (``bias``,
    ``learning_rate``, ``l1_strength``, ``l2_strength``), a model is fitted on
    ``GrenadeDataModule`` for up to 200 epochs on one GPU, and the ``val_acc``
    of the final entry recorded by ``MetricsCallback`` is returned.

    Args:
        trial: Object providing ``suggest_categorical`` and ``suggest_float``
            (presumably an Optuna trial -- confirm against the caller); it is
            also handed to ``PyTorchLightningPruningCallback``.

    Returns:
        The result of ``.item()`` on the last recorded ``val_acc`` value.
    """
    # Dict literals evaluate in source order, so the suggestions are requested
    # from the trial in the order listed here.
    hparams = {
        "bias": trial.suggest_categorical("bias", [True, False]),
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
        "l1_strength": trial.suggest_float("l1_strength", 1e-10, 1e2),
        "l2_strength": trial.suggest_float("l2_strength", 1e-10, 1e2),
    }

    recorder = MetricsCallback()
    data = GrenadeDataModule()
    classifier = LogisticRegression(input_dim=12, num_classes=2, **hparams)

    pruner = PyTorchLightningPruningCallback(trial, monitor="val_acc")
    trainer = Trainer(
        max_epochs=200,
        gpus=1,
        callbacks=[recorder, pruner],
    )
    trainer.fit(model=classifier, datamodule=data)

    final_metrics = recorder.metrics[-1]
    return final_metrics["val_acc"].item()