Пример #1
0
def test_logistic_regression_model(tmpdir, datadir):
    """Train LogisticRegression on MNIST for three epochs and check test accuracy.

    The MNIST datamodule's data hooks (``prepare_data``, ``setup`` and the
    three dataloader getters) are attached directly to the model instance.
    ``trainer.fit`` and ``trainer.test`` are therefore called with the model
    alone. The test passes when the trainer's ``progress_bar_dict`` reports a
    ``test_acc`` of at least 0.9.
    """
    pl.seed_everything(0)

    mnist = MNISTDataModule(num_workers=0, data_dir=datadir)
    model = LogisticRegression(
        input_dim=28 * 28,
        num_classes=10,
        learning_rate=0.001,
    )

    # Borrow the datamodule's data-loading hooks by rebinding them on the model.
    borrowed_hooks = (
        "prepare_data",
        "setup",
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
    )
    for hook_name in borrowed_hooks:
        setattr(model, hook_name, getattr(mnist, hook_name))

    trainer_kwargs = dict(
        max_epochs=3,
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        logger=False,
        checkpoint_callback=False,
    )
    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model)
    trainer.test(model)

    test_accuracy = trainer.progress_bar_dict['test_acc']
    assert test_accuracy >= 0.9
Пример #2
0
def objective(trial):
    """Optuna objective: train a LogisticRegression on grenade data and score it.

    Samples ``bias``, ``learning_rate``, ``l1_strength`` and ``l2_strength``
    from ``trial``. It then trains a 12-input, 2-class ``LogisticRegression``
    with ``max_epochs=200`` on one GPU. An Optuna pruning callback monitors
    ``val_acc`` during training.

    Args:
        trial: The Optuna trial supplying hyperparameter suggestions. It is
            also handed to ``PyTorchLightningPruningCallback``.

    Returns:
        The ``val_acc`` value from the last entry collected by
        ``MetricsCallback``, converted with ``.item()``.
    """
    bias = trial.suggest_categorical("bias", [True, False])
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True)
    # The regularization ranges span 12 orders of magnitude, so sample them on
    # a log scale like learning_rate. A uniform draw over [1e-10, 1e2] lands
    # below 1.0 only ~1% of the time, so weak regularization would almost
    # never be explored and the 1e-10 lower bound would be meaningless.
    l1_strength = trial.suggest_float("l1_strength", 1e-10, 1e2, log=True)
    l2_strength = trial.suggest_float("l2_strength", 1e-10, 1e2, log=True)

    # NOTE(review): presumably MetricsCallback appends the trainer's metrics
    # dict once per validation epoch; confirm against its definition.
    metrics_callback = MetricsCallback()
    datamodule = GrenadeDataModule()
    model = LogisticRegression(input_dim=12,
                               num_classes=2,
                               bias=bias,
                               learning_rate=learning_rate,
                               l1_strength=l1_strength,
                               l2_strength=l2_strength)
    # NOTE(review): gpus=1 is hard-coded, so this objective expects a GPU.
    trainer = Trainer(max_epochs=200,
                      gpus=1,
                      callbacks=[
                          metrics_callback,
                          PyTorchLightningPruningCallback(trial,
                                                          monitor="val_acc")
                      ])
    trainer.fit(model=model, datamodule=datamodule)

    return metrics_callback.metrics[-1]["val_acc"].item()