def celeba_dataloader(bs, IMAGE_SIZE, CROP, DATA_PATH):
    """Build a ``CelebADataModule`` with train/val image preprocessing.

    Despite the name, this returns a data *module*, not a ``DataLoader``.

    Args:
        bs: Batch size handed to the data module.
        IMAGE_SIZE: Size argument passed to ``transforms.Resize``.
        CROP: Side length for ``transforms.CenterCrop``; the crop is skipped
            when this is ``None`` or not positive.
        DATA_PATH: Directory the data module reads CelebA from (with
            ``download=True``).

    Returns:
        A ``CelebADataModule`` built with ``target_type='attr'``,
        ``download=True`` and ``batch_size=bs``.
    """
    # Deterministic preprocessing shared by both splits.
    base = []
    if CROP is not None and CROP > 0:
        base.append(transforms.CenterCrop(CROP))
    # NOTE(review): with an int argument, Resize scales the *shorter* edge, so
    # the result is only square when the center crop above ran — confirm
    # callers pass CROP > 0 when the model expects square inputs.
    base.append(transforms.Resize(IMAGE_SIZE))
    base.append(transforms.ToTensor())

    # Random augmentation belongs on the training split only; previously the
    # flip was also applied to validation images, making val loss noisy.
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), *base])
    val_transform = transforms.Compose(base)

    return CelebADataModule(
        data_dir=DATA_PATH,
        target_type='attr',
        train_transform=train_transform,
        val_transform=val_transform,
        download=True,
        batch_size=bs,
    )
def validation_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) self.log_dict({f"val_{k}": v for k, v in logs.items()}) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) if __name__ == "__main__": parser = ArgumentParser(description='Hyperparameters for our experiments') parser.add_argument('--latent-dim', type=int, default=128, help="size of latent dim for our vae") parser.add_argument('--epochs', type=int, default=50, help="num epochs") parser.add_argument('--gpus', type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1") parser.add_argument('--bs', type=int, default=256, help="batch size") parser.add_argument('--alpha', type=int, default=1, help="kl coeff") parser.add_argument('--beta', type=int, default=1, help="mmd coeff") parser.add_argument('--lr', type=int, default=1e-3, help="learning rate") hparams = parser.parse_args() m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, beta=hparams.beta, lr=hparams.lr) dm = CelebADataModule(data_dir=DATA_PATH, target_type='attr', train_transform=transform, val_transform=transform, download=True, batch_size=hparams.bs) trainer= Trainer(gpus = hparams.gpus, max_epochs = hparams.epochs) trainer.fit(m, datamodule=dm) torch.save(m.state_dict(), "hingevae-celeba-conv.ckpt")
dest="filename", metavar='FILE', help='path to the config file', default='configs/vae.yaml') args = parser.parse_args() config = get_config(args.filename) vae = assembler(config, "training") # Load data trans = [] trans.append(transforms.RandomHorizontalFlip()) if config["exp_params"]["crop_size"] > 0: trans.append(transforms.CenterCrop(config["exp_params"]["crop_size"])) trans.append(transforms.Resize(config["exp_params"]["img_size"])) trans.append(transforms.ToTensor()) transform = transforms.Compose(trans) dm = CelebADataModule(data_dir=config["exp_params"]["data_path"], target_type='attr', train_transform=transform, val_transform=transform, download=True, batch_size=config["exp_params"]["batch_size"]) # Run Training Loop trainer = Trainer(gpus=config["trainer_params"]["gpus"], max_epochs=config["trainer_params"]["max_epochs"]) trainer.fit(vae, datamodule=dm) torch.save(vae.state_dict(), f"{vae.model.name}_celeba_conv.ckpt")
def training_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False) return loss def validation_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) self.log_dict({f"val_{k}": v for k, v in logs.items()}) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) if __name__ == "__main__": m = AE(input_height=IMAGE_SIZE) runner = Trainer(gpus=2, gradient_clip_val=0.5, max_epochs=15) dm = CelebADataModule(data_dir=DATA_PATH, target_type='attr', train_transform=transform, val_transform=transform, download=True, batch_size=BATCH_SIZE, num_workers=3) runner.fit(m, datamodule=dm) torch.save(m.state_dict(), "ae-celeba-latent-dim-256.ckpt")