def test_gan_container(backend_default):
    """
    Wrap a two-layer generator and a two-layer discriminator in a
    GenerativeAdversarial container and check the resulting layer wiring.

    The container is expected to expose all four Affine layers through
    ``layers`` (generator layers first, then discriminator layers) and to
    keep each sub-network reachable as ``generator`` / ``discriminator``.
    """
    init = Gaussian(loc=0.0, scale=0.01)

    # two small Affine stacks; the output sizes double as layer fingerprints
    gen_net = Sequential([Affine(nout=10, init=init),
                          Affine(nout=100, init=init)])
    dis_net = Sequential([Affine(nout=100, init=init),
                          Affine(nout=1, init=init)])
    gan = GenerativeAdversarial(gen_net, dis_net)

    # flattened view: generator layers followed by discriminator layers
    expected_nout = (10, 100, 100, 1)
    assert len(gan.layers) == 4
    for layer, nout in zip(gan.layers, expected_nout):
        assert layer.nout == nout

    # each sub-network must own its own half of the flattened list
    assert gan.generator.layers == gan.layers[:2]
    assert gan.discriminator.layers == gan.layers[2:4]
# Keyword presets for the generator's Conv layers. Each one carries height,
# width and depth entries (``*_h``, ``*_w``, ``*_d``) for dilation and/or
# padding.
conv4 = {
    'init': init,
    'batch_norm': True,
    'activation': lrelu,
    'dilation': {'dil_h': 2, 'dil_w': 2, 'dil_d': 2},
}
conv5 = {
    'init': init,
    'batch_norm': True,
    'activation': lrelu,
    'padding': {'pad_h': 2, 'pad_w': 2, 'pad_d': 0},
    'dilation': {'dil_h': 2, 'dil_w': 2, 'dil_d': 3},
}
conv6 = {
    'init': init,
    'batch_norm': False,
    'activation': lrelu,
    'padding': {'pad_h': 1, 'pad_w': 0, 'pad_d': 3},
}

G_layers = [
    # 64 * 7 * 7 == 7 * 7 * 8 * 8, so the Linear output holds exactly as many
    # elements as the Reshape target below.
    # NOTE(review): the original author left an open question about the
    # input volume at this point -- confirm the intended shape.
    Linear(64 * 7 * 7, init=init),
    Reshape((7, 7, 8, 8)),
    Conv((6, 6, 8, 64), **conv4),
    Conv((6, 5, 8, 6), **conv5),
    Conv((3, 3, 8, 6), **conv6),
    Conv((2, 2, 2, 1), init=init, batch_norm=False, activation=relu),
]

# NOTE(review): the original author also left an open question about
# Embedding here -- unresolved.
layers = GenerativeAdversarial(
    generator=Sequential(G_layers, name="Generator"),
    discriminator=Sequential(D_layers, name="Discriminator"),
)

# cost: a GANCost built with func="modified", wrapped in GeneralizedGANCost
cost = GeneralizedGANCost(costfunc=GANCost(func="modified"))
def create_model(dis_model='dc', gen_model='dc',
                 cost_type='wasserstein', noise_type='normal',
                 im_size=64, n_chan=3, n_noise=100, n_gen_ftr=64, n_dis_ftr=64,
                 depth=4, n_extra_layers=0, batch_norm=True,
                 gen_squash=None, dis_squash=None, dis_iters=5,
                 wgan_param_clamp=None, wgan_train_sched=False):
    """
    Assemble a GAN model together with its matching cost for image generation.

    Arguments:
        dis_model (str): Discriminator architecture, 'mlp' or 'dc'
                         (defaults to 'dc')
        gen_model (str): Generator architecture, 'mlp' or 'dc'
                         (defaults to 'dc')
        cost_type (str): GAN cost, one of 'original', 'modified' or
                         'wasserstein' (defaults to 'wasserstein')
        noise_type (str): Noise distribution, 'uniform' or 'normal'; passed
                          through to GAN unchanged (defaults to 'normal')
        im_size (int): Image size (defaults to 64)
        n_chan (int): Number of image channels (defaults to 3)
        n_noise (int): Number of noise dimensions (defaults to 100)
        n_gen_ftr (int): Number of generator feature maps (defaults to 64)
        n_dis_ftr (int): Number of discriminator feature maps (defaults to 64)
        depth (int): Depth handed to the MLP builders (defaults to 4)
        n_extra_layers (int): Number of extra layers handed to the DC
                              builders (defaults to 0)
        batch_norm (bool): Batch normalization flag for the DC builders; the
                           MLP builders are always called with
                           batch_norm=False (defaults to True)
        gen_squash (str or None): Final generator activation: 'nosquash'
                                  (Identity), 'sym' (Tanh) or 'asym'
                                  (Logistic). When not given: with the
                                  Wasserstein cost, 'nosquash' for an MLP
                                  generator and 'sym' for a DC generator;
                                  with any other cost, 'sym'.
        dis_squash (str or None): Final discriminator activation, same names
                                  as gen_squash. When not given: 'nosquash'
                                  with the Wasserstein cost, otherwise 'asym'.
        dis_iters (int): Passed to GAN as ``k`` (defaults to 5)
        wgan_param_clamp (float or None): Weight clamp value passed to GAN
                                          (defaults to None)
        wgan_train_sched (bool): Training schedule flag passed to GAN
                                 (defaults to False)

    Returns:
        tuple: (GAN model, GeneralizedGANCost wrapping GANCost(func=cost_type))
    """
    supported_nets = ('mlp', 'dc')
    supported_costs = ('original', 'modified', 'wasserstein')
    assert dis_model in supported_nets, \
        "Unsupported model type for discriminator net, supported: 'mlp' and 'dc'"
    assert gen_model in supported_nets, \
        "Unsupported model type for generator net, supported: 'mlp' and 'dc'"
    assert cost_type in supported_costs, \
        "Unsupported GAN cost function type, supported: 'original', 'modified' and 'wasserstein'"

    # squash name -> activation object used as the network's final activation
    squash_func = {'nosquash': Identity(), 'sym': Tanh(), 'asym': Logistic()}

    # fill in squashing choices the caller left unset
    if cost_type != 'wasserstein':
        if not gen_squash:
            gen_squash = 'sym'
        if not dis_squash:
            dis_squash = 'asym'
    else:
        if not gen_squash:
            if gen_model == 'mlp':
                gen_squash = 'nosquash'
            elif gen_model == 'dc':
                gen_squash = 'sym'
        if not dis_squash:
            dis_squash = 'nosquash'

    squash_names = ('nosquash', 'sym', 'asym')
    assert gen_squash in squash_names, \
        "Unsupported final squashing function for generator," \
        " supported: 'nosquash', 'sym' and 'asym'"
    assert dis_squash in squash_names, \
        "Unsupported final squashing function for discriminator," \
        " supported: 'nosquash', 'sym' and 'asym'"

    gen_final_act = squash_func[gen_squash]
    dis_final_act = squash_func[dis_squash]

    # generator layers; the noise shape depends on the generator architecture
    if gen_model == 'mlp':
        gen_layers = create_mlp_generator(im_size, n_chan, n_gen_ftr, depth,
                                          batch_norm=False,
                                          finact=gen_final_act)
        noise_dim = (n_noise, )
    elif gen_model == 'dc':
        gen_layers = create_dc_generator(im_size, n_chan, n_noise, n_gen_ftr,
                                         n_extra_layers, batch_norm,
                                         finact=gen_final_act)
        noise_dim = (n_noise, 1, 1)

    # discriminator layers
    if dis_model == 'mlp':
        dis_layers = create_mlp_discriminator(im_size, n_dis_ftr, depth,
                                              batch_norm=False,
                                              finact=dis_final_act)
    elif dis_model == 'dc':
        dis_layers = create_dc_discriminator(im_size, n_chan, n_dis_ftr,
                                             n_extra_layers, batch_norm,
                                             finact=dis_final_act)

    layers = GenerativeAdversarial(
        generator=Sequential(gen_layers, name="Generator"),
        discriminator=Sequential(dis_layers, name="Discriminator"))

    model = GAN(layers=layers, noise_dim=noise_dim, noise_type=noise_type,
                k=dis_iters, wgan_param_clamp=wgan_param_clamp,
                wgan_train_sched=wgan_train_sched)
    cost = GeneralizedGANCost(costfunc=GANCost(func=cost_type))
    return model, cost