Exemplo n.º 1
0
def run_mnist_prediction(prediction_hparams):
    """Run the prediction pipeline on MNIST with a ``SimpleConvNet`` encoder.

    Args:
        prediction_hparams: Hyperparameters forwarded unchanged as the first
            argument of ``run_prediction``.

    Returns:
        Whatever ``run_prediction`` returns.
    """
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    # Per the original note, this downloads the data to the given path.
    mnist.prepare_data()

    # Loaders are created in train / val / test order, which is also the
    # positional order ``run_prediction`` receives them in.
    loaders = (
        mnist.train_dataloader(),
        mnist.val_dataloader(),
        mnist.test_dataloader(),
    )

    encoder = SimpleConvNet()
    return run_prediction(prediction_hparams, *loaders, encoder,
                          encoder_out_dim=50)
Exemplo n.º 2
0
def run_mnist_dvrl(prediction_hparams, dvrl_hparams):
    """Run ``run_gumbel`` on MNIST with a ``SimpleConvNet`` encoder.

    Args:
        prediction_hparams: Hyperparameters forwarded unchanged as the second
            argument of ``run_gumbel``.
        dvrl_hparams: Mapping-like object forwarded as the first argument of
            ``run_gumbel``. Its ``'outer_batch_size'`` entry (default ``32``)
            is used as the batch size of all three dataloaders.

    Returns:
        Whatever ``run_gumbel`` returns.
    """
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    # Per the original note, this downloads the data to the given path.
    mnist.prepare_data()

    # One lookup shared by the train, validation and test loaders.
    outer_batch_size = dvrl_hparams.get('outer_batch_size', 32)
    train_loader = mnist.train_dataloader(batch_size=outer_batch_size)
    val_loader = mnist.val_dataloader(batch_size=outer_batch_size)
    test_loader = mnist.test_dataloader(batch_size=outer_batch_size)

    # Read after the loaders are built, matching the original statement order.
    validation_split = mnist.val_split
    encoder = SimpleConvNet()

    positional_args = (
        dvrl_hparams,
        prediction_hparams,
        train_loader,
        val_loader,
        test_loader,
        validation_split,
        encoder,
    )
    return run_gumbel(*positional_args, encoder_out_dim=50)
Exemplo n.º 3
0
def mnist_data():
    """Create an ``MNISTDataModule`` and run its preparation hooks.

    The module is constructed with ``'../data/'`` as its first positional
    argument and ``batch_size=512``; ``prepare_data()`` and then ``setup()``
    are called on it before it is handed back.

    Returns:
        The ``MNISTDataModule`` instance after both hooks have run.
    """
    datamodule = MNISTDataModule('../data/', batch_size=512)

    # Order matters: prepare first, then set up.
    datamodule.prepare_data()
    datamodule.setup()

    return datamodule