コード例 #1
0
def basic_ae_example():
    """Fit a pl_bolts ``AE`` on CIFAR-10, then load a pretrained checkpoint.

    Builds CIFAR-10 train/validation datasets (empty-string root,
    ``download=True``), wraps them in data loaders, fits an ``AE`` with a
    two-GPU DDP trainer, and finally loads the ``"cifar10-resnet18"``
    checkpoint into a second ``AE`` and calls ``freeze()`` on the result.
    """
    from pl_bolts.models.autoencoders import AE
    from pytorch_lightning.plugins import DDPPlugin

    def load_cifar10(is_train):
        # Both splits share the root, the download flag and the tensor
        # transform; only the ``train`` flag differs.
        return torchvision.datasets.CIFAR10(
            "",
            train=is_train,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )

    def make_loader(dataset):
        # Alternative noted by the original author: additionally pass
        # persistent_workers=True to the DataLoader.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=32,
            num_workers=12,
        )

    # Data.
    training_set = load_cifar10(True)
    validation_set = load_cifar10(False)
    training_batches = make_loader(training_set)
    validation_batches = make_loader(validation_set)

    # Model.
    # Override any part of this AE to build your own variation, e.g.:
    #
    #     class MyAEFlavor(AE):
    #         def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
    #             encoder = YourSuperFancyEncoder(...)
    #             return encoder
    #
    #     model = MyAEFlavor(...)
    autoencoder = AE(input_height=32)

    # Fit.
    ddp_plugin = DDPPlugin(find_unused_parameters=False)
    fit_runner = pl.Trainer(gpus=2, accelerator="ddp", plugins=ddp_plugin)
    fit_runner.fit(autoencoder, training_batches, validation_batches)

    #--------------------
    # CIFAR-10 pretrained model:
    pretrained = AE(input_height=32)
    print(AE.pretrained_weights_available())
    # Rebind to whatever from_pretrained returns before freezing.
    pretrained = pretrained.from_pretrained("cifar10-resnet18")
    pretrained.freeze()
コード例 #2
0
def test_ae(tmpdir):
    """Smoke-test ``AE``: a ``fast_dev_run`` fit followed by a test pass.

    ``tmpdir`` serves both as the model's ``data_dir`` and as the trainer's
    ``default_root_dir``.
    """
    reset_seed()

    autoencoder = AE(data_dir=tmpdir, batch_size=2)
    smoke_trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    # Run the fit stage first, then the test stage, on the same model.
    for run_stage in (smoke_trainer.fit, smoke_trainer.test):
        run_stage(autoencoder)
コード例 #3
0
def test_from_pretrained(tmpdir):
    """Check pretrained-weight loading for ``VAE`` and ``AE``.

    Verifies that:
      * both classes advertise at least one pretrained checkpoint,
      * the known checkpoint names load without raising, and
      * an unknown checkpoint name raises ``KeyError`` for *each* model.

    Args:
        tmpdir: pytest temporary directory fixture (unused by the body, kept
            for interface compatibility).
    """
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Capture the exception instead of a bare flag so that a failure reports
    # what actually went wrong.
    load_error = None

    try:
        vae = vae.from_pretrained('cifar10-resnet18')
        vae = vae.from_pretrained(
            'stl10-resnet18'
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained('cifar10-resnet18')
    except Exception as err:
        load_error = err

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Each model is checked in its own try block. With a single shared block
    # the first KeyError would skip the second call, leaving it unverified.
    for model, bad_name in ((vae, 'abc'), (ae, 'xyz')):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
コード例 #4
0
def test_ae(tmpdir, datadir, dm_cls):
    """Fit an ``AE`` for one ``fast_dev_run`` on a datamodule from ``dm_cls``.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        datadir: directory passed to the datamodule as ``data_dir``.
        dm_cls: datamodule class to instantiate.
    """
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)

    # The last entry of the datamodule's size() is used as the input height.
    dims = datamodule.size()
    autoencoder = AE(input_height=dims[-1])

    trainer_options = dict(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        gpus=None,
    )
    Trainer(**trainer_options).fit(autoencoder, datamodule=datamodule)
コード例 #5
0
def test_ae(tmpdir, datadir, dm_cls):
    """Fit an ``AE`` via ``fast_dev_run`` and require ``fit`` to return 1.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        datadir: directory passed to the datamodule as ``data_dir``.
        dm_cls: datamodule class to instantiate.
    """
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)
    input_height = datamodule.size()[-1]
    autoencoder = AE(input_height=input_height)

    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=None,
    )

    # NOTE(review): relies on fit() returning 1 on success -- confirm this
    # holds for the installed Lightning version.
    assert runner.fit(autoencoder, datamodule) == 1
コード例 #6
0
def test_from_pretrained(datadir):
    """Check pretrained-weight loading and a forward pass for ``VAE``/``AE``.

    Verifies that:
      * both classes advertise at least one pretrained checkpoint,
      * the known checkpoint names load and a forward pass on one validation
        batch runs without raising, and
      * an unknown checkpoint name raises ``KeyError`` for *each* model.

    Args:
        datadir: directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # Lightning's documented hook order is prepare_data() before setup().
    dm.prepare_data()
    dm.setup()

    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Capture the exception instead of a bare flag so that a failure reports
    # what actually went wrong.
    load_error = None

    try:
        vae = vae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights (first batch only)
        for x, _ in data_loader:
            vae(x)
            break

        vae = vae.from_pretrained(
            "stl10-resnet18"
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights (first batch only)
        with torch.no_grad():
            for x, _ in data_loader:
                ae(x)
                break

    except Exception as err:
        load_error = err

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Each model is checked in its own try block. With a single shared block
    # the first KeyError would skip the second call, leaving it unverified.
    for model, bad_name in ((vae, "abc"), (ae, "xyz")):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
コード例 #7
0
import argparse

import pytorch_lightning as pl

from pl_bolts.models.autoencoders import AE

from ae_datasets import iCLEVRDataModule
from pytorch_lightning.callbacks import ModelCheckpoint

# Collect model, dataset and Trainer options on a single CLI parser.
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(
    iCLEVRDataModule.add_dataset_specific_args(
        AE.add_model_specific_args(parser)))

opt = parser.parse_args()
hparams = vars(opt)

# The model and the datamodule each receive the full set of parsed options
# as keyword arguments.
model = AE(**hparams)
data = iCLEVRDataModule(**hparams)

# Checkpoint callback configured with an epoch-numbered filename template.
checkpoint = ModelCheckpoint(filepath="{epoch:02d}.ckpt")

trainer = pl.Trainer.from_argparse_args(
    opt,
    log_gpu_memory='all',
    checkpoint_callback=checkpoint,
)
trainer.fit(model, data)