Beispiel #1
0
def main(args):
    """
    Render the latent-interpolation figure for a trained generator.

    The generator is built from the config file named by ``args.config``,
    its weights are loaded from ``args.generator_file``, and fixed
    source/destination seeds plus a list of psi values are handed to
    ``draw_interp_figure`` together with ``args.output`` as the output path.

    :param args: parsed command line arguments (reads ``config``,
                 ``generator_file`` and ``output``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen.load(args.generator_file)

    # Hard-coded seeds keep the figure identical between runs.
    source_seeds = [1, 32, 44, 86]
    target_seeds = [231, 415, 1515, 16]
    # psi values from 0 to 1 in steps of 0.1, written out literally so the
    # exact values (including the int endpoints 0 and 1) are unchanged.
    psi_values = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    draw_interp_figure(args.output, gen, out_depth=5,
                       src_seeds=source_seeds,
                       dst_seeds=target_seeds,
                       psis=psi_values)

    print('Done.')
def main(args):
    """
    Render the truncation-trick figure 'figure08-truncation-trick.png'.

    The generator is configured from ``args.config`` and loaded from
    ``args.generator_file``; six seeds are drawn at six psi values.

    :param args: parsed command line arguments (reads ``config`` and
                 ``generator_file``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    # Arguments taken from the dataset section; the model section's entries
    # are expanded after them, exactly as in a single keyword call.
    generator_kwargs = dict(resolution=opt.dataset.resolution,
                            num_channels=opt.dataset.channels,
                            structure=opt.structure)
    gen = Generator(**generator_kwargs, **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen.load(args.generator_file)

    seeds = [1, 32, 44, 86, 91, 388]
    psis = [1, 0.7, 0.5, 0, -0.5, -1]
    draw_truncation_trick_figure('figure08-truncation-trick.png', gen,
                                 out_depth=5, seeds=seeds, psis=psis)

    print('Done.')
def main(args):
    """
    Render the truncation-trick figure 'figure-truncation-trick.jpg'.

    Loads a checkpoint dict from ``args.generator_file``, restores the
    generator from its ``'g_ema'`` entry and passes the checkpoint's
    ``'latent_avg'`` tensor (as a numpy array) to the drawing helper.

    :param args: parsed command line arguments (reads ``generator_file``)
    :return: None
    """
    print("Creating generator object ...")
    gen = Generator()

    print("Loading the generator weights from:", args.generator_file)
    checkpoint = torch.load(args.generator_file)
    gen.load_state_dict(checkpoint['g_ema'])
    # ``device`` is not defined in this function; it resolves from module scope.
    gen = gen.to(device)

    latent_avg = checkpoint['latent_avg'].numpy()
    draw_truncation_trick_figure('figure-truncation-trick.jpg',
                                 gen,
                                 latent_avg,
                                 seeds=[91, 388],
                                 psis=[1, 0.7, 0.5, 0, -0.5, -1])

    print('Done.')
def main(args):
    """
    Build a truncation-trick image sequence into ``args.output_dir``.

    The output directory is created if missing, the generator is configured
    from ``args.config`` and loaded with ``load_state_dict`` from
    ``args.generator_file``, then ``build_truncation_trick_seq`` is called.

    :param args: parsed command line arguments (reads ``config``,
                 ``output_dir``, ``generator_file`` and ``num_samples``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    weights = torch.load(args.generator_file)
    gen.load_state_dict(weights)

    # NOTE(review): the keyword is spelled ``num_samles`` here; presumably it
    # matches the helper's parameter name — confirm before renaming it.
    build_truncation_trick_seq(output_dir,
                               gen,
                               out_depth=5,
                               num_samles=args.num_samples)

    print('Done.')
Beispiel #5
0
def main(args):
    """
    Train a generator/discriminator pair with RMSprop for 1500 epochs.

    Each epoch calls ``train`` and then ``val`` (with the same fixed latent
    batch every epoch) and saves both state dicts as ``G_XX.pth`` /
    ``D_XX.pth`` into ``args.save_folder``.

    :param args: parsed command line arguments (reads ``batch_size``,
                 ``num_workers``, ``latent_dim``, ``device``,
                 ``weight_cliping_limit``, ``learning_rate``, ``n_critic``
                 and ``save_folder``)
    :return: None
    """
    data_root = 'hw3_data/face'
    train_set = DataLoader(dataset=FaceDataset(data_root,
                                               mode='train',
                                               normalize=True),
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=args.num_workers)
    # Built but not used below.
    valid_set = DataLoader(dataset=FaceDataset(data_root, mode='test'),
                           batch_size=args.batch_size,
                           shuffle=False,
                           num_workers=args.num_workers)

    device = args.device
    G = Generator(args.latent_dim).to(device)
    D = Discriminator(args.weight_cliping_limit).to(device)
    opt_G = optim.RMSprop(G.parameters(), lr=args.learning_rate)
    opt_D = optim.RMSprop(D.parameters(), lr=args.learning_rate)
    # Drawn once, so every call to ``val`` sees the same latents.
    fixed_latent = torch.randn((32, args.latent_dim, 1, 1)).to(device)

    for epoch in range(1500):
        print('\nepoch: {}'.format(epoch))
        train(train_set, G, D, opt_G, opt_D, device,
              args.latent_dim, args.n_critic)
        val(G, args.latent_dim, device, epoch, fixed_latent)

        torch.save(G.state_dict(),
                   '{}/G_{:02d}.pth'.format(args.save_folder, epoch))
        torch.save(D.state_dict(),
                   '{}/D_{:02d}.pth'.format(args.save_folder, epoch))
Beispiel #6
0
def main(args):
    """
    Render the style-mixing figure 'diagrams/figure-style-mixing.jpg'.

    Restores the generator from the ``'g_ema'`` entry of the checkpoint at
    ``args.generator_file`` and mixes five source seeds into three
    destination seeds over three style-index ranges.

    :param args: parsed command line arguments (reads ``generator_file``)
    :return: None
    """
    print("Creating generator object ...")
    gen = Generator()

    print("Loading the generator weights from:", args.generator_file)
    checkpoint = torch.load(args.generator_file)
    gen.load_state_dict(checkpoint['g_ema'])
    # ``device`` is not defined in this function; it resolves from module scope.
    gen = gen.to(device)

    mixing_ranges = [range(0, 4), range(4, 12), range(12, 18)]
    draw_style_mixing_figure('diagrams/figure-style-mixing.jpg',
                             gen,
                             src_seeds=[639, 1995, 687, 615, 1999],
                             dst_seeds=[888, 888, 888],
                             style_ranges=mixing_ranges)
    print('Done.')
Beispiel #7
0
def main(args):
    """
    Render the style-mixing figure 'figure03-style-mixing.png'.

    The generator is configured from ``args.config`` and loaded through the
    ``load`` helper; five source seeds are mixed into three destination
    seeds over three style-index ranges at ``out_depth=4``.

    :param args: parsed command line arguments (reads ``config`` and
                 ``generator_file``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen = load(gen, args.generator_file)

    source_seeds = [639, 1995, 687, 615, 1999]
    target_seeds = [888, 888, 888]
    style_ranges = [range(0, 2), range(2, 8), range(8, 14)]
    draw_style_mixing_figure('figure03-style-mixing.png',
                             gen,
                             out_depth=4,
                             src_seeds=source_seeds,
                             dst_seeds=target_seeds,
                             style_ranges=style_ranges)
    print('Done.')
Beispiel #8
0
def main(args):
    """
    Sample images from the generator checkpoint 'weights/q2.pth'.

    Seeds torch's RNG with ``args.seed``, loads the weights and delegates to
    ``generate`` with ``args.save_folder`` as the output path.

    Runs on CUDA when it is available and otherwise falls back to the CPU.

    :param args: parsed command line arguments (reads ``seed`` and
                 ``save_folder``)
    :return: None
    """
    torch.manual_seed(args.seed)

    # A hard-coded 'cuda' made the script fail on CPU-only machines;
    # map_location below remaps the checkpoint tensors to whichever is chosen.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    G = Generator(100).to(device)
    G.load_state_dict(torch.load('weights/q2.pth', map_location=device))

    generate(G, device, output_path=args.save_folder, seed=args.seed)
def main(args):
    """
    Generate an ``n_row`` x ``n_col`` grid of samples (Jittor) as 'grid.png'.

    The generator is configured from ``args.config`` and loaded from
    ``args.generator_file``. Latents come from numpy with a fixed seed
    (1000), so the grid is the same on every run. The image is saved into
    ``args.output_dir``, which is created if missing.

    :param args: parsed command line arguments (reads ``config``,
                 ``generator_file``, ``output_dir``, ``n_row`` and ``n_col``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen.load(args.generator_file)

    save_path = args.output_dir
    os.makedirs(save_path, exist_ok=True)
    latent_size = opt.model.gen.latent_size
    out_depth = int(np.log2(opt.dataset.resolution)) - 2

    print("Generating scale synchronized images ...")
    with jt.no_grad():
        np.random.seed(1000)
        n_images = args.n_row * args.n_col
        latents = np.random.randn(n_images, latent_size)
        # np.linalg.norm of a 2-D array without ``axis`` is a single norm over
        # the whole batch (not per row); the batch is scaled so that norm
        # equals sqrt(latent_size).
        latents = latents / np.linalg.norm(latents) * latent_size ** 0.5
        latents = jt.array(latents, dtype='float32')
        ss_image = adjust_dynamic_range(gen(latents, depth=out_depth, alpha=1))
    print("gen done")

    jt.save_image_my(ss_image,
                     os.path.join(save_path, "grid.png"),
                     nrow=args.n_row,
                     normalize=True,
                     scale_each=True,
                     pad_value=128,
                     padding=1)

    print('Done.')
Beispiel #10
0
def main(args):
    """
    Generate images with a trained generator.

    Without ``args.input``: writes ``args.num_samples`` images named
    ``1.png``, ``2.png``, ... into ``args.output_dir``, each from a fresh
    random latent.
    With ``args.input``: loads a latent code from that numpy file, runs it
    through ``gen.g_synthesis`` and saves the single image to ``args.output``.

    Both paths run under ``torch.no_grad()``; this is inference only.

    :param args: parsed command line arguments (reads ``config``,
                 ``generator_file``, ``output_dir``, ``input``,
                 ``num_samples`` and ``output``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    print(opt.model.gen)
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen = load(gen, args.generator_file)

    save_path = args.output_dir
    os.makedirs(save_path, exist_ok=True)
    latent_size = opt.model.gen.latent_size
    out_depth = int(np.log2(opt.dataset.resolution)) - 2

    if args.input is None:
        print("Generating scale synchronized images ...")
        for img_num in tqdm(range(1, args.num_samples + 1)):
            with torch.no_grad():
                point = torch.randn(1, latent_size)
                # Rescale the latent so its norm is sqrt(latent_size).
                point = (point / point.norm()) * (latent_size ** 0.5)
                ss_image = gen(point, depth=out_depth, alpha=1)
                ss_image = adjust_dynamic_range(ss_image)

            save_image(ss_image, os.path.join(save_path, str(img_num) + ".png"))

        print("Generated %d images at %s" % (args.num_samples, save_path))
    else:
        code = np.load(args.input)
        # Add a leading batch dimension of size 1.
        dlatent_in = torch.unsqueeze(torch.from_numpy(code), 0)
        # Previously this forward pass ran with autograd enabled (unlike the
        # sampling branch above), recording a graph that is never used.
        with torch.no_grad():
            ss_image = gen.g_synthesis(dlatent_in, depth=out_depth, alpha=1)
            ss_image = adjust_dynamic_range(ss_image)
        save_image(ss_image, args.output)
Beispiel #11
0
def main(args):
    """
    Render a style-mixing figure with cumulative style ranges to ``args.output``.

    ``opt.model.gen.use_noise`` is set to False before the config is frozen
    (presumably so the generator's noise inputs do not vary the output —
    confirm in ``Generator``). Eight source seeds are mixed into six copies
    of destination seed 21 over the ranges 0-2, 0-4, ..., 0-12.

    :param args: parsed command line arguments (reads ``config``,
                 ``generator_file`` and ``output``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.model.gen.use_noise = False
    opt.freeze()

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    print("Loading the generator weights from:", args.generator_file)
    gen.load(args.generator_file)

    source_seeds = [166, 1721, 1181, 255, 239, 284, 2310, 1140]
    target_seeds = [21] * 6
    # range(0, 2), range(0, 4), ..., range(0, 12): six growing prefixes.
    cumulative_ranges = [range(0, stop) for stop in range(2, 13, 2)]

    draw_style_mixing_figure(args.output,
                             gen,
                             out_depth=5,
                             src_seeds=source_seeds,
                             dst_seeds=target_seeds,
                             style_ranges=cumulative_ranges)
    print('Done.')
Beispiel #12
0
def generate_model(pp, latent_dim, n_epochs, nb_patches):
    """
    Build a generator/discriminator pair and train them with ``fitGAN``.

    Uses ``BCELoss`` as the adversarial loss and one Adam optimizer per
    network. ``device``, ``lr``, ``b1``, ``b2`` and ``Tensor`` are not
    parameters; they are resolved from module scope.

    :param pp: passed through to ``fitGAN`` unchanged
    :param latent_dim: passed through to ``fitGAN``
    :param n_epochs: passed through to ``fitGAN``
    :param nb_patches: passed through to ``fitGAN``
    :return: whatever ``fitGAN`` returns
    """
    # Loss function
    adversarial_loss = torch.nn.BCELoss().to(device)
    # Initialize generator and discriminator
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    # Optimizers (same learning rate and betas for both networks)
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=lr,
                                   betas=(b1, b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=lr,
                                   betas=(b1, b2))

    return fitGAN(pp, generator, discriminator, optimizer_G, optimizer_D,
                  adversarial_loss, latent_dim, n_epochs, nb_patches, Tensor)
Beispiel #13
0
def main(args):
    """
    Convert checkpoint weights into a state dict for this project's Generator.

    Only the ``Gs`` weights returned by ``load_weights`` are used. Keys and
    values are translated with ``key_translate`` / ``weight_translate``, and
    entries whose translated key contains ``'torgb_lod'`` are dropped. Keys
    present on only one side, and shape mismatches, are printed once per key.
    The weights are then loaded with ``strict=False`` and the generator's
    state dict is saved to ``args.output_file``.

    :param args: parsed command line arguments (reads ``config``,
                 ``input_file`` and ``output_file``)
    :return: None
    """
    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)

    # Only the Gs weights are converted; G and D are not used here.
    _state_G, _state_D, state_Gs = load_weights(args.input_file)

    # Translate each key once and drop the to_rgb ("torgb_lod") filters.
    param_dict = {}
    for k, v in state_Gs.items():
        new_key = key_translate(k)
        if 'torgb_lod' not in new_key:
            param_dict[new_key] = weight_translate(k, v)

    sd_shapes = {k: v.shape for k, v in gen.state_dict().items()}
    param_shapes = {k: v.shape for k, v in param_dict.items()}

    # Visit every key exactly once (generator keys first, then keys only in
    # the converted weights). Concatenating the two key lists visited shared
    # keys twice and printed each mismatch twice.
    for k in dict.fromkeys([*sd_shapes, *param_shapes]):
        pds = param_shapes.get(k)
        sds = sd_shapes.get(k)
        if pds is None:
            print("sd only", k, sds)
        elif sds is None:
            print("pd only", k, pds)
        elif sds != pds:
            print("mismatch!", k, pds, sds)

    gen.load_state_dict(param_dict,
                        strict=False)  # needed for the blur kernels
    torch.save(gen.state_dict(), args.output_file)
    print('Done.')
Beispiel #14
0
def main(args):
    """
    Save an ``n_row`` x ``n_col`` grid of samples as 'grid.jpg' (CUDA only).

    Restores the generator from the ``'g_ema'`` entry of the checkpoint at
    ``args.generator_file`` and feeds it 512-dimensional Gaussian latents.
    The image goes into ``args.output_dir``, which is created if missing.

    :param args: parsed command line arguments (reads ``generator_file``,
                 ``output_dir``, ``n_row`` and ``n_col``)
    :return: None
    """
    device = 'cuda'

    print("Creating generator object ...")
    gen = Generator()

    print("Loading the generator weights from:", args.generator_file)
    checkpoint = torch.load(args.generator_file)
    gen.load_state_dict(checkpoint['g_ema'])
    gen = gen.to(device)

    save_path = args.output_dir
    os.makedirs(save_path, exist_ok=True)

    print("Generating scale synchronized images ...")
    num_images = args.n_row * args.n_col
    with torch.no_grad():
        latents = torch.randn(num_images, 512).to(device)
        ss_image = adjust_dynamic_range(gen(latents))

    grid_options = dict(nrow=args.n_row,
                        normalize=True,
                        scale_each=True,
                        pad_value=128,
                        padding=1)
    save_image(ss_image, os.path.join(save_path, "grid.jpg"), **grid_options)

    print('Done.')
Beispiel #15
0
import numpy as np

import torch
from torchvision import utils

from models.GAN import Generator

if __name__ == '__main__':
    device = 'cuda'

    gen = Generator()
    checkpoint = torch.load('./weights/stylegan2-ffhq-config-f.pt')
    gen.load_state_dict(checkpoint['g_ema'])

    # Sample count per resolution; only the 1024 entry is looked up here.
    batch_size = {256: 16, 512: 9, 1024: 9}
    n_sample = batch_size.get(1024, 25)

    g = gen.to(device)

    # Seeded RandomState -> the same latents (and grid) on every run.
    rng = np.random.RandomState(1)
    z = rng.randn(n_sample, 512).astype('float32')

    with torch.no_grad():
        img = g(torch.from_numpy(z).to(device))

    # NOTE(review): newer torchvision releases renamed make_grid's ``range``
    # keyword to ``value_range``; confirm against the installed version.
    utils.save_image(img,
                     'gen' + '.png',
                     nrow=int(n_sample ** 0.5),
                     normalize=True,
                     range=(-1, 1))
Beispiel #16
0
    return args


if __name__ == '__main__':
    args = parse_arguments()

    from config import cfg as opt

    opt.merge_from_file(args.config)
    opt.freeze()

    print("Creating generator object ...")
    # create the generator object
    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)
    out_depth = gen.g_synthesis.depth - 1

    state_G, state_D, state_Gs, dlatent_avg = load_weights(args.input_file)

    # we delete the useless to_rgb filters
    params = {}
    for k, v in state_Gs.items():
        params[k] = v
    param_dict = {
        key_translate(k): weight_translate(k, v)
        for k, v in state_Gs.items() if 'torgb_lod' not in key_translate(k)
    }

    for k, v in dlatent_avg.items():
Beispiel #17
0
                    type=int,
                    default=1,
                    help="number of image channels")
parser.add_argument("--sample_interval",
                    type=int,
                    default=400,
                    help="interval between image samples")
opt = parser.parse_args()

# Image shape as (channels, img_size, img_size).
img_shape = (opt.channels, opt.img_size, opt.img_size)
# is_available() already returns a bool; no conditional expression needed.
cuda = torch.cuda.is_available()

adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator(z_dim=opt.latent_dim, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)

# Move both networks and the loss module to the GPU when one is available.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
if opt.dataset == 'mnist':
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data/mnist",
            train=True,
            transform=transforms.Compose([
                transforms.Resize(opt.img_size),
Beispiel #18
0
from models.GAN import Generator
from torchsummary import summary
import torch
# Print a layer-by-layer summary of the 1024-resolution generator on the GPU.
device = torch.device('cuda')
g = Generator(1024).to(device)
# torchsummary's input_size excludes the batch dimension: (512,) describes a
# single latent vector.
summary(g, (512,))
Beispiel #19
0
            optimizer_D.step()

            if i % config.print_freq == 0:
                logging.info(
                    f'Epoch [{epoch}][{i}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]'
                )
            if i % 100 == 0:
                save_image(gen_imgs.data[:25], epoch, i)


if __name__ == "__main__":
    config = parser_args()
    logging.info(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Both networks are built for images shaped (in_channel, img_size, img_size).
    img_shape = (config.in_channel, config.img_size, config.img_size)
    generator = Generator(config, img_shape).to(device)
    discriminator = Discriminator(config, img_shape).to(device)

    # Log the architectures before training starts.
    logging.info(generator)
    logging.info(discriminator)

    dataset = Mnist(config, 'train')
    loader = Data.DataLoader(dataset,
                             batch_size=config.batch_size,
                             shuffle=True,
                             num_workers=config.workers,
                             pin_memory=True)

    main(generator, discriminator, loader, device, config)
Beispiel #20
0
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    imgs = []
    files = os.listdir('/home/dung/Data/anh')

    for imgfile in files:
        imgfile = '/home/dung/Data/anh/' + imgfile
        img = transform(Image.open(imgfile).convert('RGB'))
        imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'])
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() /
                      n_mean_latent)**0.5

    # percept = lpips.PerceptualLoss(
    #     model='net-lin', net='vgg', use_gpu=device.startswith('cuda'))