Пример #1
0
    parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
    opt = parser.parse_args()

    epochs = opt.epochs
    batch_size = opt.batch_size
    dataset_path = opt.dataset_path

    train_ds = ImageFolderDataset(dataset_path, img_dim=64)
    train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=opt.n_cpu)

    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("output", exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = NVAE(z_dim=512, img_dim=(64, 64), M_N=opt.batch_size / len(train_ds))

    # apply Spectral Normalization
    model.apply(add_sn)

    model.to(device)

    if opt.pretrained_weights:
        model.load_state_dict(torch.load(opt.pretrained_weights, map_location=device), strict=False)

    optimizer = torch.optim.Adamax(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

    for epoch in range(epochs):
        model.train()
Пример #2
0
import torch

from nvae.dataset import ImageFolderDataset
from nvae.utils import add_sn
from nvae.vae_celeba import NVAE
import numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Reconstruction demo: show one dataset image, then feed it through the
    # model with weights restored from a checkpoint.

    # Raw string literal: "\d" and "\i" are not valid escape sequences, so the
    # plain literal only produced the intended path by accident and triggers
    # a warning on current Python versions. The path value itself is unchanged.
    train_ds = ImageFolderDataset(r"E:\data\img_align_celeba", img_dim=64)

    device = "cpu"
    model = NVAE(z_dim=512, img_dim=(64, 64))
    # Apply spectral normalization to the model's sub-modules.
    model.apply(add_sn)
    model.to(device)

    # strict=False lets loading proceed even if checkpoint and model keys
    # do not match exactly.
    model.load_state_dict(torch.load("../checkpoints/ae_ckpt_7_0.080791.pth",
                                     map_location=device),
                          strict=False)

    model.eval()

    # Take sample 34 and add a leading batch axis.
    img = train_ds[34].unsqueeze(0).to(device)
    # Move channels last for matplotlib and drop the batch axis again.
    ori_image = img.permute(0, 2, 3, 1)[0]
    # (x + 1) / 2 * 255 rescales [-1, 1] to [0, 255]
    # (presumably the dataset yields values in [-1, 1] — confirm).
    ori_image = (ori_image.numpy() + 1) / 2 * 255
    plt.imshow(ori_image.astype(np.uint8))
    plt.show()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        gen_imgs, _ = model(img)
Пример #3
0
    epochs = opt.epochs
    batch_size = opt.batch_size
    dataset_path = opt.dataset_path

    train_ds = ImageFolderDataset(dataset_path, img_dim=64)
    train_dataloader = DataLoader(train_ds,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=opt.n_cpu)

    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("output", exist_ok=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = NVAE(z_dim=512, img_dim=(64, 64))

    # apply Spectral Normalization
    model.apply(add_sn)

    model.to(device)

    if opt.pretrained_weights:
        model.load_state_dict(torch.load(opt.pretrained_weights,
                                         map_location=device),
                              strict=False)

    warmup_kl = WarmupKLLoss(init_weights=[1., 1. / 2, 1. / 8],
                             steps=[4500, 3000, 1500],
                             M_N=opt.batch_size / len(train_ds),
                             eta_M_N=5e-6,
Пример #4
0
import torch

from nvae.utils import add_sn
from nvae.vae_celeba import NVAE
import numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # Sampling demo: decode random latent vectors with a checkpointed model
    # and display each generated image.
    device = "cpu"
    model = NVAE(z_dim=512, img_dim=(64, 64))
    # Apply spectral normalization to the model's sub-modules.
    model.apply(add_sn)
    model.to(device)

    # strict=False lets loading proceed even if checkpoint and model keys
    # do not match exactly.
    model.load_state_dict(torch.load("checkpoints/ae_ckpt_103_0.074130.pth",
                                     map_location=device),
                          strict=False)

    model.eval()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # 25 latent samples drawn from a standard normal distribution.
        z = torch.randn((25, 512)).to(device)
        gen_imgs, _ = model.decoder(z)
        # Move channels last so matplotlib can display each image.
        gen_imgs = gen_imgs.permute(0, 2, 3, 1)
        for gen_img in gen_imgs:
            pixels = gen_img.cpu().numpy() * 255
            # Saturate to the valid byte range before casting: converting an
            # out-of-range float to uint8 wraps around (or is platform
            # dependent) instead of clamping, which shows up as speckle noise.
            pixels = np.clip(pixels, 0, 255).astype(np.uint8)

            plt.imshow(pixels)
            plt.show()
Пример #5
0
import numpy as np
import torch

from nvae.utils import add_sn
from nvae.vae_celeba import NVAE

if __name__ == '__main__':
    # Grid-sampling demo: decode cols * rows random latents with a
    # checkpointed model; the visible part of the script ends after reshaping
    # the generated images into a (rows, cols, ...) grid.

    # Per-image side length, latent channel count, and grid layout.
    img_size = 64
    z_dim = 512
    cols, rows = 12, 12
    # Pixel dimensions of the full mosaic.
    width = cols * img_size
    height = rows * img_size

    device = "cpu"
    # NOTE(review): img_dim is passed as a scalar here — confirm NVAE accepts
    # an int as well as a (height, width) pair.
    model = NVAE(z_dim=z_dim, img_dim=img_size)
    # Apply add_sn to every sub-module (presumably spectral normalization,
    # judging by the name — confirm in nvae.utils).
    model.apply(add_sn)
    model.to(device)

    # strict=False lets loading proceed even if checkpoint and model keys
    # do not match exactly.
    model.load_state_dict(torch.load("checkpoints/ae_ckpt_0_0.761000.pth",
                                     map_location=device),
                          strict=False)

    model.eval()

    # Output canvas for the mosaic.
    # NOTE(review): allocated as (width, height, 3); image arrays are
    # conventionally (height, width, 3). Harmless while cols == rows, but
    # verify against the code that fills `result` before using a non-square grid.
    result = np.zeros((width, height, 3), dtype=np.uint8)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # One latent tensor of shape (z_dim, 2, 2) per grid cell.
        z = torch.randn((cols * rows, z_dim, 2, 2)).to(device)
        gen_imgs, _ = model.decoder(z)
        # Arrange the flat batch as a rows x cols grid of 3-channel images.
        gen_imgs = gen_imgs.reshape(rows, cols, 3, img_size, img_size)