Example #1
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)
    analysis = tune.run(
        trainable,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)
Example #2
def train_mnist(config,
                data_dir=None,
                num_epochs=10,
                num_workers=1,
                use_gpu=False,
                callbacks=None):
    # Make sure data is downloaded on all nodes.
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            MNISTDataModule(data_dir=data_dir).prepare_data()

    model = LightningMNISTClassifier(config, data_dir)

    callbacks = callbacks or []

    trainer = pl.Trainer(max_epochs=num_epochs,
                         callbacks=callbacks,
                         progress_bar_refresh_rate=0,
                         plugins=[
                             RayPlugin(num_workers=num_workers,
                                       use_gpu=use_gpu,
                                       init_hook=download_data)
                         ])
    dm = MNISTDataModule(data_dir=data_dir,
                         num_workers=1,
                         batch_size=config["batch_size"])
    trainer.fit(model, dm)
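A minimal direct invocation might look like the following (a sketch: the config keys mirror how `LightningMNISTClassifier` is configured in Example #5, and the data directory reuses Example #1's temp-dir convention; both are assumptions about this project):

import os
import tempfile

if __name__ == "__main__":
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    os.makedirs(data_dir, exist_ok=True)  # the FileLock path requires the directory to exist
    config = {"layer_1": 32, "layer_2": 32, "lr": 1e-2, "batch_size": 32}
    train_mnist(config, data_dir=data_dir, num_workers=2, use_gpu=False)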
Example #3
def run_mnist_prediction(prediction_hparams):
    # DATA
    datamodule = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    datamodule.prepare_data()  # downloads data to given path
    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()
    test_dataloader = datamodule.test_dataloader()

    encoder_model = SimpleConvNet()
    return run_prediction(prediction_hparams, train_dataloader, val_dataloader, test_dataloader,
                          encoder_model, encoder_out_dim=50)
Example #4
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)
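This is the trainable that Example #1 wraps with `tune.with_parameters`. The imports it appears to assume are sketched below; the callback path reflects the Ray 1.x-era Tune/Lightning integration, so verify it against your installed Ray version:

# Assumed imports for this trainable:
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback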
Example #5
def test_predict(tmpdir, ray_start_2_cpus, seed, num_slots):
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }
    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(
        data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
    trainer = get_trainer(
        tmpdir, limit_train_batches=10, max_epochs=1, num_slots=num_slots)
    predict_test(trainer, model, dm)
Example #6
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    data_dir = os.path.abspath("./data")
    model = LightningMNISTClassifier(config, data_dir)
    with FileLock(os.path.expanduser("~/.data.lock")):
        dm = MNISTDataModule(data_dir=data_dir,
                             num_workers=1,
                             batch_size=config["batch_size"])
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)
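The `math.ceil(num_gpus)` line exists because Ray Tune can allocate fractional GPUs per trial while Lightning's `gpus` argument needs an integer. A sketch of how such a trial might be requested, with `config` as in Example #1 (the 0.5 share is illustrative, not prescribed by the source):

# Two trials can share one physical GPU; inside the trainable,
# math.ceil(0.5) == 1, so Lightning still requests a whole device.
tune.run(
    tune.with_parameters(train_mnist_tune, num_epochs=10, num_gpus=0.5),
    resources_per_trial={"cpu": 1, "gpu": 0.5},
    config=config,  # search space as in Example #1
    metric="loss",
    mode="min")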
Example #7
def sweep_iteration(proj):
    wandb.init()    # required to have access to `wandb.config`
    wandb_logger = WandbLogger()
    mnist = MNISTDataModule('../data/')
    # setup model - note how we refer to sweep parameters with wandb.config
    mlp = MLP(
        num_classes=mnist.num_classes, C=C, W=W, H=H,
        num_layer_1=wandb.config.num_layer_1,
        num_layer_2=wandb.config.num_layer_2,
    )
    model = LitModel(datamodule=mnist, arch=mlp, lr=wandb.config.lr)
    trainer = Trainer(
        logger=wandb_logger,    # W&B integration
        # gpus=-1,             # use all GPUs
        max_epochs=3            # number of epochs
    )
    trainer.fit(model, mnist)
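A sketch of how this function would typically be driven by a W&B sweep. The parameter names mirror the `wandb.config` fields read above, but the search method, the ranges, and the project name "my-project" are assumptions:

import functools
import wandb

sweep_config = {
    "method": "random",
    "parameters": {
        "num_layer_1": {"values": [64, 128, 256]},
        "num_layer_2": {"values": [64, 128, 256]},
        "lr": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-1},
    },
}
sweep_id = wandb.sweep(sweep_config, project="my-project")
wandb.agent(sweep_id, function=functools.partial(sweep_iteration, "my-project"))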
Example #8
def test_predict(tmpdir, ray_start_2_cpus, seed, num_slots):
    """Tests if trained model has high accuracy on test set."""
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }
    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(data_dir=tmpdir,
                         num_workers=1,
                         batch_size=config["batch_size"])
    plugin = HorovodRayPlugin(num_slots=num_slots, use_gpu=False)
    trainer = get_trainer(tmpdir,
                          limit_train_batches=20,
                          max_epochs=1,
                          plugins=[plugin])
    predict_test(trainer, model, dm)
Example #9
def test_predict_client(tmpdir, start_ray_client_server_2_cpus, seed,
                        num_slots):
    assert ray.util.client.ray.is_connected()
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }
    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(data_dir=tmpdir,
                         num_workers=1,
                         batch_size=config["batch_size"])
    plugin = HorovodRayPlugin(num_slots=num_slots, use_gpu=False)
    trainer = get_trainer(tmpdir,
                          limit_train_batches=20,
                          max_epochs=1,
                          plugins=[plugin])
    predict_test(trainer, model, dm)
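`get_trainer` is a helper from the surrounding test suite, not a public API. Judging from the call sites in Examples #8 and #9, it presumably wraps `pl.Trainer` roughly as below; this is a hypothetical reconstruction, not the actual helper:

def get_trainer(tmpdir, plugins=None, **trainer_kwargs):
    # Hypothetical: root the Trainer at the test's temp dir and pass
    # the Ray/Horovod plugin straight through.
    return pl.Trainer(default_root_dir=tmpdir,
                      plugins=plugins or [],
                      **trainer_kwargs)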
Example #10
def run_mnist_dvrl(prediction_hparams, dvrl_hparams):
    # DATA
    datamodule = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    datamodule.prepare_data()  # downloads data to given path
    train_dataloader = datamodule.train_dataloader(
        batch_size=dvrl_hparams.get('outer_batch_size', 32))
    val_dataloader = datamodule.val_dataloader(
        batch_size=dvrl_hparams.get('outer_batch_size', 32))
    test_dataloader = datamodule.test_dataloader(
        batch_size=dvrl_hparams.get('outer_batch_size', 32))

    val_split = datamodule.val_split
    encoder_model = SimpleConvNet()
    return run_gumbel(dvrl_hparams,
                      prediction_hparams,
                      train_dataloader,
                      val_dataloader,
                      test_dataloader,
                      val_split,
                      encoder_model,
                      encoder_out_dim=50)
Example #11
import torch
import pytorch_lightning as pl

from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

from .model import HebbMLP
from .train_pipe import TrainPipe

device_id = None

latest_checkpoint = ""

if __name__ == '__main__':
    data_dir = '../data/'
    train_transforms = None
    val_transforms = None
    mnist_dm = MNISTDataModule(data_dir, train_transforms=train_transforms, val_transforms=val_transforms)
    mlp = HebbMLP(1)
    train_pipe = TrainPipe(model=mlp)

    classic_trainer = pl.Trainer(gpus=([device_id] if device_id is not None else None), max_epochs=10)
    print("TRAINING MLP")
    classic_trainer.fit(train_pipe, datamodule=mnist_dm)

    print("SAVE MLP")
    torch.save(train_pipe.state_dict(), 'lateral_model_v1')

    # print("TEST MLP")
    # classic_trainer.test(train_pipe, datamodule=mnist_dm)
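Restoring the weights later mirrors the save above (a sketch; the file name matches the `torch.save` call, and `map_location='cpu'` just makes the load device-independent):

restored_pipe = TrainPipe(model=HebbMLP(1))
restored_pipe.load_state_dict(torch.load('lateral_model_v1', map_location='cpu'))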
Example #12
def mnist_data():
    mnist = MNISTDataModule('../data/', batch_size=512)
    mnist.prepare_data()
    mnist.setup()
    return mnist
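Typical use of the returned datamodule (a sketch; `train_dataloader` and `val_dataloader` are the standard LightningDataModule accessors, and `prepare_data`/`setup` have already run inside the helper):

mnist = mnist_data()
train_loader = mnist.train_dataloader()
val_loader = mnist.val_dataloader()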
Example #13
def download_data():
    # `data_dir` is taken from the enclosing scope (cf. Example #2).
    from filelock import FileLock
    with FileLock(os.path.join(data_dir, ".lock")):
        MNISTDataModule(data_dir=data_dir).prepare_data()
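This is the same download guard that Example #2 installs as the Ray plugin's `init_hook`, so each node downloads MNIST exactly once even when several workers start concurrently. The wiring, repeated from Example #2 for reference:

plugin = RayPlugin(num_workers=num_workers,
                   use_gpu=use_gpu,
                   init_hook=download_data)  # runs on every Ray worker before training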