def test_vae(tmpdir, dm_cls):
    seed_everything()
    dm = dm_cls(batch_size=2, num_workers=0)
    model = VAE(*dm.size())
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, dm)
    trainer.test(model, datamodule=dm)
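The dm_cls argument is supplied by the test harness and is not shown above. A minimal sketch of how it might be wired up, assuming pytest parametrization over pl_bolts datamodules (the decorator and the class choices are assumptions, not part of the original snippet):

# Hypothetical wiring for dm_cls (an assumption; the snippet above does not
# show where dm_cls comes from).
import pytest
from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule

@pytest.mark.parametrize("dm_cls", [MNISTDataModule, CIFAR10DataModule])
def test_vae(tmpdir, dm_cls):
    ...  # body as in the example above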
Example #2
def test_vae(tmpdir):
    reset_seed()

    model = VAE(data_dir=tmpdir, batch_size=2)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model)
    trainer.test(model)
    loss = trainer.callback_metrics['loss']

    assert loss > 0, 'VAE failed'
Example #3
def test_vae(tmpdir):
    reset_seed()

    model = VAE(data_dir=tmpdir, batch_size=2)
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model)
    results = trainer.test(model)[0]
    loss = results['test_loss']

    assert loss > 0, 'VAE failed'
Example #4
def test_vae(tmpdir, dm_cls):
    seed_everything()
    dm = dm_cls(batch_size=2, num_workers=0)
    model = VAE(*dm.size())
    trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir, deterministic=True)
    trainer.fit(model, dm)
    results = trainer.test(model, datamodule=dm)[0]
    loss = results["test_loss"]

    assert loss > 0, "VAE failed"
Example #5
    def __init__(self, env):
        super().__init__(env)

        # Assumes raycast observation space of len 20
        self.observation_space = gym.spaces.Discrete(5)
        self.v = VAE.load_from_checkpoint(
            "/home/denpak/logs/vae_model/last.ckpt")

        for param in self.v.parameters():
            param.requires_grad = False
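Since self.v is a LightningModule, the loop above can also be written with the built-in freeze() call, as Example #10 does:

        # Equivalent: LightningModule.freeze() sets requires_grad=False on all
        # parameters and puts the module in eval mode.
        self.v.freeze()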
Example #6
def test_vae(tmpdir, datadir, dm_cls):
    seed_everything()

    dm = dm_cls(data_dir=datadir, batch_size=4)
    model = VAE(input_height=dm.size()[-1])
    trainer = Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        gpus=None,
    )

    trainer.fit(model, datamodule=dm)
Example #7
def test_vae(tmpdir):
    seed_everything()

    model = VAE(data_dir=tmpdir, batch_size=2, num_workers=0)
    trainer = pl.Trainer(fast_dev_run=True,
                         default_root_dir=tmpdir,
                         deterministic=True)
    trainer.fit(model)
    results = trainer.test(model)[0]
    loss = results['test_loss']

    assert loss > 0, 'VAE failed'
Example #8
def test_vae(tmpdir, datadir, dm_cls):
    seed_everything()

    dm = dm_cls(data_dir=datadir, batch_size=4)
    model = VAE(input_height=dm.size()[-1])
    trainer = pl.Trainer(fast_dev_run=True,
                         default_root_dir=tmpdir,
                         max_epochs=1,
                         gpus=None)

    result = trainer.fit(model, dm)
    assert result == 1  # fit() returned 1 on success in older PyTorch Lightning versions
Example #9
def test_from_pretrained(tmpdir):
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    exception_raised = False

    try:
        vae = vae.from_pretrained('cifar10-resnet18')
        vae = vae.from_pretrained(
            'stl10-resnet18'
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained('cifar10-resnet18')
    except Exception:
        exception_raised = True

    assert exception_raised is False, "error in loading weights"

    keyerror = False

    try:
        vae = vae.from_pretrained('abc')
        ae = ae.from_pretrained('xyz')
    except KeyError:
        keyerror = True

    assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
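The flag-and-assert pattern above can be written more compactly with pytest.raises; a minimal, equivalent sketch:

# More idiomatic equivalent of the keyerror flag check above:
import pytest

def test_illegal_checkpoint_name():
    vae = VAE(input_height=32)
    with pytest.raises(KeyError):
        vae.from_pretrained('abc')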
Example #10
def simple_vae_example():
    from pl_bolts.models.autoencoders import VAE
    from pytorch_lightning.plugins import DDPPlugin

    # Data.
    train_dataset = torchvision.datasets.CIFAR10(
        "",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor())
    val_dataset = torchvision.datasets.CIFAR10(
        "",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=32,
                                               num_workers=12)
    #train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12, persistent_workers=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=32,
                                             num_workers=12)
    #val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12, persistent_workers=True)

    # Model.
    model = VAE(input_height=32)
    """
	# Override any part of this AE to build your own variation.
	class MyVAEFlavor(VAE):
		def get_posterior(self, mu, std):
			# Do something other than the default.
			P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
			return P

	model = MyVAEFlavor(...)
	"""

    # Fit.
    trainer = pl.Trainer(gpus=2,
                         accelerator="ddp",
                         plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, train_loader, val_loader)

    #--------------------
    # CIFAR-10 pretrained model:
    vae = VAE(input_height=32)
    print(VAE.pretrained_weights_available())
    vae = vae.from_pretrained("cifar10-resnet18")

    vae.freeze()
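    With the pretrained weights loaded and frozen, the encoder can be used directly for inference; a minimal sketch using the encoder/fc_mu attributes that Example #13 relies on:

    # Encode one CIFAR-10 image with the frozen, pretrained VAE
    # (encoder/fc_mu as used in Example #13).
    x, _ = val_dataset[0]
    with torch.no_grad():
        z = vae.fc_mu(vae.encoder(x.unsqueeze(0)))
    print(z.shape)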
Example #11
def test_from_pretrained(datadir):
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    dm.prepare_data()
    dm.setup()

    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    exception_raised = False

    try:
        vae = vae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights
        with torch.no_grad():
            for x, y in data_loader:
                vae(x)
                break

        vae = vae.from_pretrained(
            "stl10-resnet18"
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights
        with torch.no_grad():
            for x, y in data_loader:
                ae(x)
                break

    except Exception:
        exception_raised = True

    assert exception_raised is False, "error in loading weights"

    keyerror = False

    try:
        vae.from_pretrained("abc")
        ae.from_pretrained("xyz")
    except KeyError:
        keyerror = True

    assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
Example #12
    from pl_bolts.datamodules import MNISTDataModule, ImagenetDataModule
    '''
    datamodule = SprayDataModule(dataset_parms=config['exp_params'],
                                 trainer_parms=config['trainer_params'],
                                 model_params=config['model_params'])
    model = VAE(hidden_dim = 128,
                latent_dim = 32,
                input_channels = config['model_params']['in_channels'],
                input_width = config['exp_params']['img_size'],
                input_height = config['exp_params']['img_size'],
                batch_size = config['exp_params']['batch_size'],
                learning_rate = config['exp_params']['LR'],
                datamodule = datamodule)
    '''
    datamodule = MNISTDataModule(data_dir='./')
    #datamodule = ImagenetDataModule(data_dir='./')

    model = VAE(datamodule=datamodule)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=config['trainer_params']['max_epochs'],
        checkpoint_callback=checkpoint_callback)  #, profiler=True)

    trainer.fit(model)

    output_dir = 'trained_models/'
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
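    To restore the saved weights later, load the state dict back into a freshly constructed model; a minimal sketch:

    # Reload the saved weights into a new model instance.
    model = VAE(datamodule=datamodule)
    model.load_state_dict(torch.load(os.path.join(output_dir, "model.pt")))
    model.eval()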
Example #13
    def _data_loader(self,
                     dataset: Dataset,
                     shuffle: bool = False) -> DataLoader:
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers,
                          drop_last=self.drop_last,
                          pin_memory=self.pin_memory)


#model = VAE(64,latent_dim=5)
#trainer = pl.Trainer(gpus=1)
data = FishDataModule()
#trainer.fit(model,data)

v = VAE.load_from_checkpoint(
    "/home/denpak/Research/BBEEncoders/vae_model/last.ckpt")
d = FishDataset(data_dir="/home/denpak/Research/BBEProject/Recordings",
                labels_file="labels.csv")
x = torch.unsqueeze(d[0][0], 0)
x = v.encoder(x)
print(v.fc_mu(x)[0])
print(v.fc_var(x)[0])

x = torch.unsqueeze(d[1][0], 0)
x = v.encoder(x)
print(v.fc_mu(x)[0].detach().numpy())
print(v.fc_var(x)[0])
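fc_mu and fc_var return the mean and log-variance of the approximate posterior; a latent sample comes from the reparameterization trick. A minimal sketch over the encoded features x from above:

# Sample a latent via the reparameterization trick: z = mu + sigma * eps,
# with sigma = exp(0.5 * log_var).
mu, log_var = v.fc_mu(x), v.fc_var(x)
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)
print(z)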