示例#1
0
from GAN.utils.init import InitNormal

from keras.optimizers import Adam, SGD, RMSprop

if __name__ == '__main__':
    # Batch size, and the `nmax` value (100 batches' worth) forwarded to GAN.fit.
    nbatch = 128
    nmax = nbatch * 100
    # Image width / height passed to the data loader as `size=(npxw, npxh)`.
    npxw, npxh = 64, 128

    from load import people
    # `people` returns a 3-tuple; only the training stream is used below.
    va_data, tr_stream, _ = people(pathfile='protocol/PPPS.txt',
                                   size=(npxw, npxh),
                                   batch_size=nbatch)

    # The discriminator is sized from the generator's output (`g.g_size`)
    # so the two networks always agree.
    g = Generator(g_size=(8, npxh, npxw),
                  g_nb_filters=128,
                  g_nb_coding=500,
                  g_scales=4,
                  g_init=InitNormal(scale=0.002))
    d = Discriminator(d_size=g.g_size,
                      d_nb_filters=128,
                      d_scales=4,
                      d_init=InitNormal(scale=0.002))
    gan = GAN(g, d)

    # Adam comes from the module-level keras.optimizers import; no re-import.
    # NOTE(review): transform_skeleton / inverse_transform_skeleton are not
    # defined in the visible part of this file — confirm they are in scope.
    gan.fit(tr_stream,
            save_dir='./samples/parsing_skeleton/',
            k=1,
            nbatch=nbatch,
            nmax=nmax,
            opt=Adam(lr=0.0002, beta_1=0.5, decay=1e-5),
            transform=transform_skeleton,
            inverse_transform=inverse_transform_skeleton)
    



示例#2
0
def _build_model(args, model_name):
    """Construct the model selected by ``model_name``.

    Args:
        args: parsed command-line namespace; reads ``lr`` and the
            model-specific options used in each branch below.
        model_name: ``args.model`` lower-cased.

    Returns:
        ``(model, final_transforms)``, where ``final_transforms`` is the list
        of model-specific transforms appended after ``ToTensor()``.

    Raises:
        ValueError: if ``model_name`` is not realnvp, vae, gan or glow.
    """
    final_transforms = []

    if model_name == "realnvp":
        datainfo = DataInfo("celeba", 3, 64)
        prior = Normal(loc=0.0, scale=1.0)
        hyperparams = Hyperparameters(base_dim=32, res_blocks=2)

        model = RealNVP(datainfo, prior, hyperparams, args.lr)
    elif model_name == "vae":
        model = VAE(latent_dim=args.latent_dim,
                    lr=args.lr,
                    kl_weight=args.kl_weight)
    elif model_name == "gan":
        # Per-channel (x - 0.5) / 0.5, applied after ToTensor().
        final_transforms = [
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        model = GAN(nz=args.latent_dim,
                    ngf=args.generator_features,
                    ndf=args.discriminator_features,
                    lr=args.lr)
    elif model_name == "glow":
        final_transforms = [preprocess]

        model = Glow((64, 64, 3),
                     flow_coupling=args.flow_coupling,
                     lr=args.lr,
                     warmup=args.warmup)
    else:
        raise ValueError(
            f"{args.model} is not a valid model name. Use -h flag for help.")

    return model, final_transforms


def _glow_warm_start(model, train_loader, args):
    """Run one no-grad forward pass over the first training batches.

    Concatenates at most ``args.n_init_batches`` batches (all of them if it
    is None), feeds them through ``model`` in train mode on the GPU when
    ``args.gpus`` is positive, then moves the model back to CPU.

    NOTE(review): presumably this triggers Glow's data-dependent
    initialisation before Lightning training starts — confirm against Glow.
    """
    model.train()
    if args.gpus and args.gpus > 0:
        model.to('cuda')

    # islice(it, n) is equivalent to islice(it, None, n).
    init_batches = torch.cat(
        [x for x, _ in islice(train_loader, args.n_init_batches)]
    ).to(model.device)

    with torch.no_grad():
        model(init_batches)

    # Move back to CPU before the model is handed to the Trainer.
    model.cpu()


def main(args):
    """Train the model named by ``args.model`` on CelebA with PyTorch Lightning.

    Pipeline: center-crop 168 -> resize 64 (plus a random horizontal flip
    for training only) -> ToTensor -> model-specific transforms.
    Glow models get a forward-pass warm start before training. Training
    uses a TestTube logger and a ModelCheckpoint that also keeps the last
    checkpoint.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """
    pl.seed_everything(args.seed)
    # Lower-case once rather than at every comparison.
    model_name = args.model.lower()

    initial_transforms = [
        transforms.CenterCrop(168),
        transforms.Resize(64),
    ]
    augmentation_transforms = [
        transforms.RandomHorizontalFlip(p=0.5),
    ]

    model, final_transforms = _build_model(args, model_name)

    train_transform = transforms.Compose(initial_transforms +
                                         augmentation_transforms +
                                         [transforms.ToTensor()] +
                                         final_transforms)
    test_transform = transforms.Compose(initial_transforms +
                                        [transforms.ToTensor()] +
                                        final_transforms)

    train_dataset = datasets.CelebA(root=args.data_root,
                                    split="train",
                                    transform=train_transform)
    # NOTE(review): validation data comes from the "test" split; torchvision's
    # CelebA also offers split="valid" — confirm which one is intended.
    val_dataset = datasets.CelebA(root=args.data_root,
                                  split="test",
                                  transform=test_transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    if model_name == "glow":
        _glow_warm_start(model, train_loader, args)

    tt_logger = TestTubeLogger(args.logs_root, name=model_name)
    checkpoint_callback = ModelCheckpoint(
        verbose=True,
        save_last=True,
    )

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=tt_logger,
        deterministic=True,
        min_epochs=5,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, train_loader, val_loader)
示例#3
0
if __name__ == '__main__':
    # 128 images per batch; nmax is 100 batches' worth.
    nbatch = 128
    nmax = 100 * nbatch
    npxw, npxh = 64, 128

    from load import people
    loaded = people(pathfile='protocol/cuhk01-train.txt',
                    size=(npxw, npxh),
                    batch_size=nbatch)
    # Only the training stream (second element) is used below.
    va_data, tr_stream, _ = loaded

    gen_config = dict(g_size=(3, npxh, npxw),
                      g_nb_filters=128,
                      g_nb_coding=200,
                      g_scales=4,
                      g_init=InitNormal(scale=0.002))
    g = Generator(**gen_config)

    # The discriminator consumes exactly what the generator produces.
    dis_config = dict(d_size=g.g_size,
                      d_nb_filters=128,
                      d_scales=4,
                      d_init=InitNormal(scale=0.002))
    d = Discriminator(**dis_config)
    gan = GAN(g, d)

    from keras.optimizers import Adam, SGD, RMSprop
    optimizer = Adam(lr=0.0003, beta_1=0.5, decay=1e-5)
    gan.fit(tr_stream,
            save_dir='/home/shaofan/Projects/JSTL/transfer/gan/',
            k=1,
            nbatch=nbatch,
            nmax=nmax,
            opt=optimizer)
    # Alternative optimizer previously tried here: RMSprop(lr=0.01).
示例#4
0

class mnist_stream():
    """Callable source of random mini-batches of rescaled images.

    Raw images shaped (N, H, W) are mapped from [0, 255] to [-1, 1] and given
    a singleton axis at position 1, so ``self.x`` has shape (N, 1, H, W).
    By default the images are the MNIST training split.
    """

    def __init__(self, x_train=None):
        """Load (or accept) the images and rescale them.

        Args:
            x_train: optional array of raw images shaped (N, H, W), values
                expected in [0, 255]. When None (the default), the MNIST
                training images from ``mnist.load_data()`` are used.
        """
        if x_train is None:
            (x_train, _), (_, _) = mnist.load_data()
        # [0, 255] -> [-1, 1], then (N, H, W) -> (N, 1, H, W).
        self.x = np.expand_dims(np.asarray(x_train) / 127.5 - 1, 1)

    def __call__(self, bs):
        """Return ``bs`` distinct images drawn uniformly without replacement.

        Uses the global NumPy RNG; numpy raises ValueError if ``bs`` exceeds
        the number of stored images.
        """
        idx = np.random.choice(self.x.shape[0], size=(bs, ), replace=False)
        return self.x[idx]


if __name__ == '__main__':
    # Latent-code length and the image shape shared by generator and critic.
    coding = 200
    img_shape = (1, 28, 28)

    # Both networks use nf=64 and scale=2; only their FC widths differ.
    g = basic_gen((coding, ), img_shape, nf=64, scale=2, FC=[256])
    d = basic_dis(img_shape, nf=64, scale=2, FC=[512])
    gan = GAN(g, d, init=InitNormal(scale=0.02))

    from keras.optimizers import Adam, SGD, RMSprop
    data_source = mnist_stream()
    optimizer = Adam(lr=0.0002, beta_1=0.5)
    training = {
        'niter': 5000,
        'save_dir': './quickshots/mnist',
        'k': 3,
        'save_iter': 100,
        'nbatch': 128,
        'opt': optimizer,
    }
    gan.fit(data_source, **training)
示例#5
0
    return x, y, random_stream


if __name__ == '__main__':
    nbatch = 128
    x, y, stream = get_mnist(nbatch)

    # Generator and discriminator mirror each other: same filter count,
    # number of scales, FC width and init scale.
    nf, n_scales, fc_width, init_scale = 64, 2, 1024, 0.05

    g = Generator(g_size=(1, 28, 28),
                  g_nb_filters=nf,
                  g_nb_coding=200,
                  g_scales=n_scales,
                  g_FC=[fc_width],
                  g_init=InitNormal(scale=init_scale))
    d = Discriminator(d_size=g.g_size,
                      d_nb_filters=nf,
                      d_scales=n_scales,
                      d_FC=[fc_width],
                      d_init=InitNormal(scale=init_scale))
    gan = GAN(g, d)

    from keras.optimizers import Adam, SGD, RMSprop
    adam = Adam(lr=0.0002, beta_1=0.5, decay=1e-5)
    gan.fit(stream, save_dir='./labs/mnist', k=1, nbatch=nbatch, opt=adam)

# 10/26/16: if the weights are not initialized properly, all-zero values
#           can appear in the BN layer and cause NaNs.
#           Solution: use leaky ReLU.
示例#6
0
from GAN.utils.data import transform, inverse_transform
from GAN.utils.init import InitNormal

if __name__ == '__main__':
    import os

    g = Generator(g_size=(1, 28, 28),
                  g_nb_filters=64,
                  g_nb_coding=200,
                  g_scales=2,
                  g_FC=[1024],
                  g_init=InitNormal(scale=0.05))
    d = Discriminator(d_size=g.g_size,
                      d_nb_filters=64,
                      d_scales=2,
                      d_FC=[1024],
                      d_init=InitNormal(scale=0.05))
    gan = GAN(g, d)

    # NOTE(review): `vis_grid` and `reverse` are not imported or defined
    # before use in the visible source — confirm they are in scope.

    # 10x10 grid from the generator *before* any checkpoint is loaded.
    vis_grid(inverse_transform(g.random_generate(100)), (10, 10), 'sample.png')

    # Load saved generator weights (checkpoint index 600). os.path.join avoids
    # the double slash the old '{}/{}' template produced with a trailing '/'.
    ckpt_dir, ckpt_index = 'samples/mnist_bck', 600
    g.load_weights(os.path.join(ckpt_dir, '{}_gen_params.h5'.format(ckpt_index)))
    vis_grid(inverse_transform(g.random_generate(100)), (10, 10),
             'sample2.png')

    img = g.random_generate(128)
    reverse(g, img, savedir='reverse/mnist/')

    def reverse(self, target, min=-1.0, max=+1.0, savedir=None):
        import theano
        import theano.tensor as T
        from ..utils.vis import vis_grid
        from ..utils.data import transform, inverse_transform