コード例 #1
0
ファイル: network_gan.py プロジェクト: xingjinglu/PerfAILibs
def create_model(dis_model='dc',
                 gen_model='dc',
                 cost_type='wasserstein',
                 noise_type='normal',
                 im_size=64,
                 n_chan=3,
                 n_noise=100,
                 n_gen_ftr=64,
                 n_dis_ftr=64,
                 depth=4,
                 n_extra_layers=0,
                 batch_norm=True,
                 gen_squash=None,
                 dis_squash=None,
                 dis_iters=5,
                 wgan_param_clamp=None,
                 wgan_train_sched=False):
    """
    Create a GAN model and associated GAN cost function for image generation

    Arguments:
        dis_model (str): Discriminator type, can be 'mlp' for a simple MLP or
                         'dc' for a DC-GAN style model. (defaults to 'dc')
        gen_model (str): Generator type, can be 'mlp' for a simple MLP or
                         'dc' for a DC-GAN style model. (defaults to 'dc')
        cost_type (str): Cost type, can be 'original', 'modified' following
                         Goodfellow2014 or 'wasserstein' following Arjovsky2017
                         (defaults to 'wasserstein')
        noise_type (str): Noise distribution, can be 'uniform' or 'normal'
                          (defaults to 'normal')
        im_size (int): Image size (defaults to 64)
        n_chan (int): Number of image channels (defaults to 3)
        n_noise (int): Number of noise dimensions (defaults to 100)
        n_gen_ftr (int): Number of generator feature maps (defaults to 64)
        n_dis_ftr (int): Number of discriminator feature maps (defaults to 64)
        depth (int): Depth of layers in case of MLP (defaults to 4)
        n_extra_layers (int): Number of extra conv layers in case of DC (defaults to 0)
        batch_norm (bool): Enable batch normalization (defaults to True)
        gen_squash (str or None): Squashing function at the end of generator (defaults to None)
        dis_squash (str or None): Squashing function at the end of discriminator (defaults to None)
        dis_iters (int): Number of critics for discriminator (defaults to 5)
        wgan_param_clamp (float or None): In case of WGAN weight clamp value, None for others
        wgan_train_sched (bool): Enable training schedule of number of critics (defaults to False)

    Returns:
        tuple: (GAN model, GeneralizedGANCost) ready to be passed to fit().
    """
    assert dis_model in ['mlp', 'dc'], \
        "Unsupported model type for discriminator net, supported: 'mlp' and 'dc'"
    assert gen_model in ['mlp', 'dc'], \
        "Unsupported model type for generator net, supported: 'mlp' and 'dc'"
    assert cost_type in ['original', 'modified', 'wasserstein'], \
        "Unsupported GAN cost function type, supported: 'original', 'modified' and 'wasserstein'"

    # types of final squashing functions
    squash_func = dict(nosquash=Identity(), sym=Tanh(), asym=Logistic())
    # Resolve squashing defaults; explicit user choices (non-None) always win.
    if cost_type == 'wasserstein':
        # WGAN critic outputs an unbounded score, so no squashing by default.
        if gen_model == 'mlp':
            gen_squash = gen_squash or 'nosquash'
        elif gen_model == 'dc':
            gen_squash = gen_squash or 'sym'
        dis_squash = dis_squash or 'nosquash'
    else:  # for all GAN costs other than Wasserstein
        # Classic GAN: tanh generator output, sigmoid discriminator probability.
        gen_squash = gen_squash or 'sym'
        dis_squash = dis_squash or 'asym'

    assert gen_squash in ['nosquash', 'sym', 'asym'], \
        "Unsupported final squashing function for generator," \
        " supported: 'nosquash', 'sym' and 'asym'"
    assert dis_squash in ['nosquash', 'sym', 'asym'], \
        "Unsupported final squashing function for discriminator," \
        " supported: 'nosquash', 'sym' and 'asym'"

    gfa = squash_func[gen_squash]
    dfa = squash_func[dis_squash]

    # create model layers
    if gen_model == 'mlp':
        gen = create_mlp_generator(im_size,
                                   n_chan,
                                   n_gen_ftr,
                                   depth,
                                   batch_norm=False,
                                   finact=gfa)
        # MLP generator consumes a flat noise vector.
        noise_dim = (n_noise, )
    elif gen_model == 'dc':
        gen = create_dc_generator(im_size,
                                  n_chan,
                                  n_noise,
                                  n_gen_ftr,
                                  n_extra_layers,
                                  batch_norm,
                                  finact=gfa)
        # DC generator expects a 1x1 spatial noise "image" with n_noise channels.
        noise_dim = (n_noise, 1, 1)

    if dis_model == 'mlp':
        dis = create_mlp_discriminator(im_size,
                                       n_dis_ftr,
                                       depth,
                                       batch_norm=False,
                                       finact=dfa)
    elif dis_model == 'dc':
        dis = create_dc_discriminator(im_size,
                                      n_chan,
                                      n_dis_ftr,
                                      n_extra_layers,
                                      batch_norm,
                                      finact=dfa)
    layers = GenerativeAdversarial(generator=Sequential(gen, name="Generator"),
                                   discriminator=Sequential(
                                       dis, name="Discriminator"))

    return GAN(layers=layers, noise_dim=noise_dim, noise_type=noise_type, k=dis_iters,
               wgan_param_clamp=wgan_param_clamp, wgan_train_sched=wgan_train_sched), \
        GeneralizedGANCost(costfunc=GANCost(func=cost_type))
コード例 #2
0
ファイル: mnist_dcgan.py プロジェクト: anlthms/neon
         init=init,
         batch_norm=False,
         activation=Logistic(shortcut=False))
]

# Bundle the generator and discriminator stacks into a single adversarial
# container (G_layers / D_layers are defined earlier in this script).
layers = GenerativeAdversarial(generator=Sequential(G_layers,
                                                    name="Generator"),
                               discriminator=Sequential(D_layers,
                                                        name="Discriminator"))

# setup cost function as CrossEntropy
# cost_type="dis" selects the discriminator-side GAN cost; original_cost
# toggles Goodfellow's original vs. the modified generator objective.
cost = GeneralizedCost(
    costfunc=GANCost(cost_type="dis", original_cost=args.original_cost))

# setup optimizer
# beta_1=0.5 follows the DC-GAN paper's recommendation for Adam.
optimizer = Adam(learning_rate=0.0005, beta_1=0.5)

# initialize model
# Noise is shaped as a (channels, height, width) tensor feeding the generator;
# k controls how many discriminator updates run per generator update.
noise_dim = (2, 7, 7)
gan = GAN(layers=layers, noise_dim=noise_dim, k=args.kbatch)

# configure callbacks
callbacks = Callbacks(gan, eval_set=valid_set, **args.callback_args)
# Periodically dump generated sample images next to this script.
callbacks.add_callback(GANPlotCallback(filename=splitext(__file__)[0], hw=27))
# run fit
gan.fit(train_set,
        optimizer=optimizer,
        num_epochs=args.epochs,
        cost=cost,
        callbacks=callbacks)
コード例 #3
0
# setup optimizer
optimizer = RMSProp(learning_rate=1e-3, decay_rate=0.9, epsilon=1e-8)
# alternative: optimizer = GradientDescentMomentum(learning_rate=1e-3, momentum_coef=0.9)

# setup cost function as Binary CrossEntropy
# 'modified' is the non-saturating generator objective from Goodfellow 2014.
cost = GeneralizedGANCost(costfunc=GANCost(func="modified"))

# training hyper-parameters
nb_epochs = 2
batch_size = 10
latent_size = 50
inb_classes = 2
nb_test = 100

# initialize model
# BUG FIX: (latent_size) is just an int in parentheses — noise_dim must be a
# 1-tuple (cf. noise_dim = (n_noise, ) elsewhere in this project).
noise_dim = (latent_size,)
gan = GAN(layers=layers, noise_dim=noise_dim)

# configure callbacks
callbacks = Callbacks(gan, eval_set=valid_set, **args.callback_args)
# Results directory and timestamped base filename for any plot output.
fdir = ensure_dirs_exist(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), 'results/'))
fname = os.path.splitext(os.path.basename(__file__))[0] +\
    '_[' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + ']'
#im_args = dict(filename=os.path.join(fdir, fname), hw=24,
#               num_samples=args.batch_size, nchan=1, sym_range=True)
#callbacks.add_callback(GANPlotCallback(**im_args))
# Track generator/discriminator costs during training.
callbacks.add_callback(GANCostCallback())
#callbacks.add_save_best_state_callback("./best_state.pkl")
gan.fit(train_set,
コード例 #4
0
# setup optimizer
# optimizer = RMSProp(learning_rate=1e-4, decay_rate=0.9, epsilon=1e-8)
optimizer = GradientDescentMomentum(learning_rate=1e-3, momentum_coef=0.9)
# optimizer = Adam(learning_rate=1e-3)

# setup cost function as Wasserstein (WGAN) cost
cost = GeneralizedGANCost(costfunc=GANCost(func="wasserstein"))

# training hyper-parameters
nb_epochs = 15
latent_size = 200
inb_classes = 2
nb_test = 100

# initialize model
# BUG FIX: (latent_size) is an int, not a tuple — noise_dim must be a 1-tuple.
# k=5 critic iterations per generator step and weight clamping per WGAN.
noise_dim = (latent_size,)
gan = GAN(layers=layers, noise_dim=noise_dim, k=5, wgan_param_clamp=0.9)

# configure callbacks
callbacks = Callbacks(gan, eval_set=valid_set)
callbacks.add_callback(GANCostCallback())
#callbacks.add_save_best_state_callback("./best_state.pkl")

# FIX: print as a function call — valid in both Python 2 and 3 (the original
# Python-2-only statement form is a SyntaxError on Python 3).
print('starting training')
# run fit
gan.fit(train_set, num_epochs=nb_epochs, optimizer=optimizer,
        cost=cost, callbacks=callbacks)

# gan.save_params('our_gan.prm')

# Draw fresh latent vectors and wrap them for inference through the generator.
x_new = np.random.randn(100, latent_size)
# BUG FIX: lshape=(latent_size) is an int; pass a 1-tuple like noise_dim above.
inference_set = HDF5Iterator(x_new, None, nclass=2, lshape=(latent_size,))
コード例 #5
0
ファイル: mnist_dcgan.py プロジェクト: StevenLOL/neon
            Conv((7, 7, 1), name="D_out",
                 init=init, batch_norm=False,
                 activation=Logistic(shortcut=False))]

# Bundle the generator and discriminator stacks into a single adversarial
# container (G_layers / D_layers are defined earlier in this script).
layers = GenerativeAdversarial(generator=Sequential(G_layers, name="Generator"),
                               discriminator=Sequential(D_layers, name="Discriminator"))

# setup cost function as CrossEntropy
# 'modified' is the non-saturating generator objective from Goodfellow 2014.
cost = GeneralizedGANCost(costfunc=GANCost(func="modified"))

# setup optimizer
# beta_1=0.5 follows the DC-GAN paper's recommendation for Adam.
optimizer = Adam(learning_rate=0.0005, beta_1=0.5)

# initialize model
# Noise is shaped (channels, height, width); k controls discriminator updates
# per generator update.
noise_dim = (2, 7, 7)
gan = GAN(layers=layers, noise_dim=noise_dim, k=args.kbatch)

# configure callbacks
callbacks = Callbacks(gan, eval_set=valid_set, **args.callback_args)
# Results directory and timestamped base filename for plot images.
fdir = ensure_dirs_exist(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'results/'))
fname = os.path.splitext(os.path.basename(__file__))[0] +\
    '_[' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + ']'
im_args = dict(filename=os.path.join(fdir, fname), hw=27,
               num_samples=args.batch_size, nchan=1, sym_range=True)
# Dump generated sample images and track G/D costs during training.
callbacks.add_callback(GANPlotCallback(**im_args))
callbacks.add_callback(GANCostCallback())

# run fit
gan.fit(train_set, optimizer=optimizer,
        num_epochs=args.epochs, cost=cost, callbacks=callbacks)