Exemplo n.º 1
0
import torch

from nvae.dataset import ImageFolderDataset
from nvae.utils import add_sn
from nvae.vae_celeba import NVAE
import numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Load one dataset image and a pretrained NVAE checkpoint, display the
    # original image, then run the image through the model.

    # Raw string: in a normal literal "\d" and "\i" are invalid escape
    # sequences (DeprecationWarning; SyntaxWarning since Python 3.12).
    train_ds = ImageFolderDataset(r"E:\data\img_align_celeba", img_dim=64)

    device = "cpu"
    model = NVAE(z_dim=512, img_dim=(64, 64))
    # Spectral normalization is applied before loading weights, presumably so
    # the module structure matches the checkpoint it was trained with.
    model.apply(add_sn)
    model.to(device)

    # strict=False: missing/unexpected checkpoint keys are ignored instead of
    # raising an error.
    model.load_state_dict(torch.load("../checkpoints/ae_ckpt_7_0.080791.pth",
                                     map_location=device),
                          strict=False)

    model.eval()

    # Add a batch dimension; assumes the dataset yields (C, H, W) tensors.
    img = train_ds[34].unsqueeze(0).to(device)

    # (1, C, H, W) -> (H, W, C) for matplotlib, then map the [-1, 1] range
    # implied by this formula back to [0, 255]. Clipping keeps values that are
    # slightly out of range from wrapping around in the uint8 cast.
    ori_image = img.permute(0, 2, 3, 1)[0]
    ori_image = (ori_image.numpy() + 1) / 2 * 255
    plt.imshow(np.clip(ori_image, 0, 255).astype(np.uint8))
    plt.show()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        gen_imgs, _ = model(img)
Exemplo n.º 2
0
    # Training-script setup. NOTE(review): `parser` (presumably an
    # argparse.ArgumentParser) and the earlier options it defines -- epochs,
    # batch_size, dataset_path, pretrained_weights -- are created outside
    # this fragment.
    # Number of DataLoader worker processes (used as num_workers below).
    parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
    opt = parser.parse_args()

    epochs = opt.epochs
    batch_size = opt.batch_size
    dataset_path = opt.dataset_path

    # Images are loaded at 64 px, matching the (64, 64) img_dim of the model.
    train_ds = ImageFolderDataset(dataset_path, img_dim=64)
    train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=opt.n_cpu)

    # Output directories, relative to the current working directory.
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("output", exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # M_N = batch size / dataset size; presumably a minibatch weighting factor
    # used inside NVAE's loss -- TODO confirm against nvae.vae_celeba.
    model = NVAE(z_dim=512, img_dim=(64, 64), M_N=opt.batch_size / len(train_ds))

    # apply Spectral Normalization (add_sn is called on every submodule)
    model.apply(add_sn)

    model.to(device)

    # Optional warm start. strict=False ignores missing/unexpected keys in
    # the checkpoint instead of raising.
    if opt.pretrained_weights:
        model.load_state_dict(torch.load(opt.pretrained_weights, map_location=device), strict=False)

    optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)
    # NOTE(review): T_max is fixed at 15 and does not follow opt.epochs;
    # whether scheduler.step() runs per epoch or per batch is not visible in
    # this fragment -- confirm the intended annealing period.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

    # Training loop; only the start of its body is visible here.
    for epoch in range(epochs):
        model.train()
Exemplo n.º 3
0
import numpy as np
import torch

from nvae.utils import add_sn
from nvae.vae_celeba import NVAE

if __name__ == '__main__':
    # Sample random latents, decode them with a pretrained NVAE, and arrange
    # the outputs as a rows x cols grid of img_size x img_size images.

    img_size = 64
    z_dim = 512
    cols, rows = 12, 12
    # Pixel size of the full grid canvas.
    width = cols * img_size
    height = rows * img_size

    device = "cpu"
    # NOTE(review): img_dim is passed as an int here, whereas the other
    # scripts in this file pass (64, 64) -- confirm NVAE accepts both forms.
    model = NVAE(z_dim=z_dim, img_dim=img_size)
    # Spectral normalization is applied before loading weights, presumably so
    # the module structure matches the checkpoint.
    model.apply(add_sn)
    model.to(device)

    # strict=False: missing/unexpected checkpoint keys are ignored instead of
    # raising. The path is relative to the current working directory.
    model.load_state_dict(torch.load("checkpoints/ae_ckpt_0_0.761000.pth",
                                     map_location=device),
                          strict=False)

    model.eval()

    # NOTE(review): image arrays are conventionally (height, width, channels);
    # this allocates (width, height, 3). Harmless while cols == rows, but it
    # should be (height, width, 3) if the grid is ever non-square.
    result = np.zeros((width, height, 3), dtype=np.uint8)

    with torch.no_grad():
        # One latent per grid cell, each a z_dim x 2 x 2 map; presumably the
        # layout NVAE.decoder expects -- TODO confirm.
        z = torch.randn((cols * rows, z_dim, 2, 2)).to(device)
        # decoder returns a pair; only the images are used here.
        gen_imgs, _ = model.decoder(z)
        # Regroup the flat batch into a (rows, cols) grid of 3-channel images;
        # assumes decoder output is (N, 3, img_size, img_size) -- TODO confirm.
        gen_imgs = gen_imgs.reshape(rows, cols, 3, img_size, img_size)