import numpy as np
import matplotlib.pyplot as plt
import torch

from nvae.dataset import ImageFolderDataset
from nvae.utils import add_sn
from nvae.vae_celeba import NVAE

if __name__ == '__main__':
    # Reconstruction demo: load one CelebA image, show the original,
    # then run it through a pretrained NVAE.
    #
    # Raw string for the Windows path: the original non-raw literal only
    # worked because "\d" and "\i" happen not to be escape sequences; it
    # emits a SyntaxWarning on Python 3.12+ and silently corrupts paths
    # containing real escapes such as "\t" or "\n".
    train_ds = ImageFolderDataset(r"E:\data\img_align_celeba", img_dim=64)

    device = "cpu"

    model = NVAE(z_dim=512, img_dim=(64, 64))
    # Apply spectral normalization before loading weights so the state-dict
    # keys match a checkpoint saved from a spectral-normalized model.
    model.apply(add_sn)
    model.to(device)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(
        torch.load("../checkpoints/ae_ckpt_7_0.080791.pth", map_location=device),
        strict=False)
    model.eval()

    img = train_ds[34].unsqueeze(0).to(device)

    # CHW -> HWC, then undo the [-1, 1] normalization for display.
    # (assumes the dataset yields tensors scaled to [-1, 1] — TODO confirm)
    ori_image = img.permute(0, 2, 3, 1)[0]
    ori_image = (ori_image.numpy() + 1) / 2 * 255
    plt.imshow(ori_image.astype(np.uint8))
    plt.show()

    with torch.no_grad():
        gen_imgs, _ = model(img)
# --- Training-script fragment ---
# NOTE(review): `parser` is an argparse.ArgumentParser created earlier in the
# file (not visible in this chunk); the training loop body also continues
# below this chunk.
parser.add_argument("--n_cpu", type=int, default=16,
                    help="number of cpu threads to use during batch generation")
opt = parser.parse_args()

epochs = opt.epochs
batch_size = opt.batch_size
dataset_path = opt.dataset_path

train_ds = ImageFolderDataset(dataset_path, img_dim=64)
train_dataloader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, num_workers=opt.n_cpu)

# Output directories for model checkpoints and generated samples.
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("output", exist_ok=True)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# M_N = batch_size / dataset_size.
# NOTE(review): presumably the per-batch KL-term weight in the NVAE loss —
# confirm against the NVAE implementation.
model = NVAE(z_dim=512, img_dim=(64, 64), M_N=opt.batch_size / len(train_ds))

# apply Spectral Normalization
model.apply(add_sn)

model.to(device)

if opt.pretrained_weights:
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(torch.load(opt.pretrained_weights, map_location=device),
                          strict=False)

optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

for epoch in range(epochs):
    model.train()
import numpy as np
import torch

from nvae.utils import add_sn
from nvae.vae_celeba import NVAE

if __name__ == '__main__':
    # Sampling demo: draw latents from a standard normal and decode them
    # into a cols x rows grid of generated face images.
    img_size = 64
    z_dim = 512
    cols, rows = 12, 12
    width = cols * img_size
    height = rows * img_size

    device = "cpu"

    model = NVAE(z_dim=z_dim, img_dim=img_size)
    # Apply spectral normalization before loading weights so the state-dict
    # keys match a checkpoint saved from a spectral-normalized model.
    model.apply(add_sn)
    model.to(device)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(torch.load("checkpoints/ae_ckpt_0_0.761000.pth",
                                     map_location=device),
                          strict=False)
    model.eval()

    # Canvas for the image grid.
    # NOTE(review): allocated as (width, height, 3); most HWC image code uses
    # (height, width, 3). Harmless here since rows == cols, but verify against
    # the fill code below this chunk.
    result = np.zeros((width, height, 3), dtype=np.uint8)

    with torch.no_grad():
        # One (z_dim, 2, 2) latent per grid cell.
        # NOTE(review): the 2x2 spatial latent shape must match what the
        # decoder expects for 64x64 output — confirm against NVAE.decoder.
        z = torch.randn((cols * rows, z_dim, 2, 2)).to(device)
        gen_imgs, _ = model.decoder(z)
        gen_imgs = gen_imgs.reshape(rows, cols, 3, img_size, img_size)