示例#1
0
                           wgan_param_clamp=0.01,
                           wgan_train_sched=True)

# RMSProp optimizer with fixed hyper-parameters
optimizer = RMSProp(learning_rate=5e-5,
                    decay_rate=0.99,
                    epsilon=1e-8)

# training data loader, built from the 'train' entry of the manifest
train = make_loader(args.manifest['train'],
                    args.manifest_root,
                    model.be,
                    args.subset_pct,
                    random_seed)

# callbacks: start from the user-supplied callback arguments
callbacks = Callbacks(model, **args.callback_args)

# output location: a 'results/' directory beside this script (path resolved
# through realpath).
# NOTE(review): ensure_dirs_exist presumably creates the directory and hands
# the path back -- confirm against its definition.
script_dir = os.path.dirname(os.path.realpath(__file__))
fdir = ensure_dirs_exist(os.path.join(script_dir, 'results/'))

# output file stem: <script name without extension>_[<timestamp>]
script_stem = os.path.splitext(os.path.basename(__file__))[0]
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
fname = script_stem + '_[' + timestamp + ']'

# keyword arguments forwarded to GANPlotCallback
# NOTE(review): hw / nchan look like image size and channel count -- confirm
# against GANPlotCallback.
im_args = {
    'filename': os.path.join(fdir, fname),
    'hw': 64,
    'num_samples': args.batch_size,
    'nchan': 3,
    'sym_range': True,
}

# register the GAN-specific callbacks: plotting first, then cost
plot_callback = GANPlotCallback(**im_args)
callbacks.add_callback(plot_callback)
cost_callback = GANCostCallback()
callbacks.add_callback(cost_callback)

# train the model
fit_kwargs = dict(optimizer=optimizer,
                  num_epochs=args.epochs,
                  cost=cost,
                  callbacks=callbacks)
model.fit(train, **fit_kwargs)
示例#2
0
         init=init,
         batch_norm=False,
         activation=Logistic(shortcut=False))
]

# wrap the generator and discriminator layer lists in named Sequential
# containers and combine them into one GenerativeAdversarial container
generator_net = Sequential(G_layers, name="Generator")
discriminator_net = Sequential(D_layers, name="Discriminator")
layers = GenerativeAdversarial(generator=generator_net,
                               discriminator=discriminator_net)

# cost: a GANCost of type "dis" wrapped in a GeneralizedCost;
# args.original_cost is forwarded to GANCost unchanged
gan_costfunc = GANCost(cost_type="dis", original_cost=args.original_cost)
cost = GeneralizedCost(costfunc=gan_costfunc)

# Adam optimizer with fixed hyper-parameters
optimizer = Adam(learning_rate=0.0005, beta_1=0.5)

# build the GAN model; noise_dim and args.kbatch are passed straight through
noise_dim = (2, 7, 7)
gan = GAN(layers=layers, noise_dim=noise_dim, k=args.kbatch)

# callbacks: user-supplied arguments evaluated on valid_set, plus a plot
# callback whose output name is this script's path without its extension
callbacks = Callbacks(gan, eval_set=valid_set, **args.callback_args)
plot_callback = GANPlotCallback(filename=splitext(__file__)[0], hw=27)
callbacks.add_callback(plot_callback)

# train the model
fit_kwargs = dict(optimizer=optimizer,
                  num_epochs=args.epochs,
                  cost=cost,
                  callbacks=callbacks)
gan.fit(train_set, **fit_kwargs)