コード例 #1
0
from GAN.utils import vis_grid
from GAN.utils.data import transform_skeleton, inverse_transform_skeleton
from GAN.utils.init import InitNormal

from keras.optimizers import Adam, SGD, RMSprop

if __name__ == '__main__':
    # Script entry point: build a Generator/Discriminator pair, wrap them in
    # a GAN and fit it on the stream returned by `people` for the
    # 'protocol/PPPS.txt' path file.
    npxw, npxh = 64, 128
    nbatch = 128
    nmax = 100 * nbatch

    from load import people, load_all
    va_data, tr_stream, _ = people(
        pathfile='protocol/PPPS.txt',
        size=(npxw, npxh),
        batch_size=nbatch,
    )

    # Both networks are configured with the same filter count, number of
    # scales and initialiser scale; the discriminator's input size is taken
    # from the generator's `g_size`.
    g = Generator(
        g_size=(8, npxh, npxw),
        g_nb_filters=128,
        g_nb_coding=500,
        g_scales=4,
        g_init=InitNormal(scale=0.002),
    )
    d = Discriminator(
        d_size=g.g_size,
        d_nb_filters=128,
        d_scales=4,
        d_init=InitNormal(scale=0.002),
    )
    gan = GAN(g, d)

    from keras.optimizers import Adam, SGD, RMSprop
    optimizer = Adam(lr=0.0002, beta_1=0.5, decay=1e-5)
    # The skeleton-specific transform pair is handed to `fit` together with
    # the training stream and the directory given as `save_dir`.
    gan.fit(
        tr_stream,
        save_dir='./samples/parsing_skeleton/',
        k=1,
        nbatch=nbatch,
        nmax=nmax,
        opt=optimizer,
        transform=transform_skeleton,
        inverse_transform=inverse_transform_skeleton,
    )
    

コード例 #2
0
ファイル: init_with_ae.py プロジェクト: GRSEB9S/Keras-GAN-2
    def random_stream():
        """Yield an endless sequence of random batches taken from ``x``.

        Every batch selects ``nbatch`` distinct indices along the first axis
        of ``x`` (sampling without replacement) and reorders the axes of the
        selection with ``transpose(0, 2, 3, 1)``.
        """
        while True:
            picked = np.random.choice(x.shape[0], size=nbatch, replace=False)
            batch = x[picked]
            yield batch.transpose(0, 2, 3, 1)

    return x, y, random_stream


if __name__ == '__main__':
    # Initialise the generator and discriminator from the saved
    # 'mnist_ae_*' weight files (autoencoder initialisation), then wrap them
    # for GAN training.
    nbatch = 128
    x, y, stream = get_mnist(nbatch)

    g = Generator(g_size=(1, 28, 28),
                  g_nb_filters=64,
                  g_nb_coding=200,
                  g_scales=2,
                  g_FC=[1024],
                  g_init=InitNormal(scale=0.05))
    d = Discriminator(d_size=g.g_size,
                      d_nb_filters=64,
                      d_scales=2,
                      d_FC=[1024],
                      d_init=InitNormal(scale=0.05))

    # After each load, print the sum of the model's first weight array
    # (presumably a quick sanity check that the weights were restored).
    # print(...) with a single argument behaves the same on Python 2 and is
    # valid syntax on Python 3, unlike the print statement.
    g.load_weights('models/mnist_ae_g.h5')
    print(g.get_weights()[0].sum())
    d.load_weights('models/mnist_ae_d.h5')
    print(d.get_weights()[0].sum())

    # The AEGAN wrapper is used here; plain GAN(g, d) is the alternative.
    gan = AEGAN(g, d)
コード例 #3
0
ファイル: reid_gan.py プロジェクト: GRSEB9S/Keras-GAN-2
from GAN.utils.data import transform, inverse_transform
from GAN.utils.init import InitNormal

if __name__ == '__main__':
    # Script entry point: fit a GAN on the stream returned by `people` for
    # the 'protocol/cuhk01-train.txt' path file.
    npxw, npxh = 64, 128
    nbatch = 128
    nmax = 100 * nbatch

    from load import people
    va_data, tr_stream, _ = people(
        pathfile='protocol/cuhk01-train.txt',
        size=(npxw, npxh),
        batch_size=nbatch,
    )

    # 3-channel generator; the discriminator input size mirrors `g.g_size`.
    g = Generator(
        g_size=(3, npxh, npxw),
        g_nb_filters=128,
        g_nb_coding=200,
        g_scales=4,
        g_init=InitNormal(scale=0.002),
    )
    d = Discriminator(
        d_size=g.g_size,
        d_nb_filters=128,
        d_scales=4,
        d_init=InitNormal(scale=0.002),
    )
    gan = GAN(g, d)

    from keras.optimizers import Adam, SGD, RMSprop
    # NOTE(review): save_dir is a machine-specific absolute path.
    fit_options = dict(
        save_dir='/home/shaofan/Projects/JSTL/transfer/gan/',
        k=1,
        nbatch=nbatch,
        nmax=nmax,
        opt=Adam(lr=0.0003, beta_1=0.5, decay=1e-5),
    )
    gan.fit(tr_stream, **fit_options)
コード例 #4
0
ファイル: WGAN_mnist.py プロジェクト: GRSEB9S/Keras-GAN-2
from models_WGAN import generator_upsampling, generator_deconv
from models_WGAN import discriminator

if __name__ == '__main__':
    nbatch = 64
    x, y, stream = get_mnist(nbatch)

    init = InitNormal(scale=0.02)
    #   g = MLP(g_size=(1, 28, 28),
    #               g_nb_filters=128,
    #               g_nb_coding=50,
    #               g_init=init)
    g = Generator(g_size=(1, 28, 28),
                  g_nb_filters=32,
                  g_FC=[1024],
                  g_nb_coding=100,
                  g_scales=2,
                  g_init=init)

    #   d = Sequential([
    #           Flatten(input_shape=g.output_shape[1:]),
    #           Dense(128, init=init),
    #           Activation('relu'),
    #           Dense(128, init=init),
    #           Activation('relu'),
    #           Dense(128, init=init),
    #           Activation('relu'),
    #           Dense(1, init=init),
    #       ])
    d = Critic(d_size=(1, 28, 28),
               d_nb_filters=32,