Example #1
0
def test_gan_container(backend_default):
    """
    Build a GenerativeAdversarial container from a two-layer generator and a
    two-layer discriminator, then verify the combined layer list and that the
    generator/discriminator views map onto its first and second halves.
    """
    weight_init = Gaussian(loc=0.0, scale=0.01)

    # two Affine layers per network, small Gaussian-initialised weights
    gen_net = Sequential([Affine(nout=10, init=weight_init),
                          Affine(nout=100, init=weight_init)])
    dis_net = Sequential([Affine(nout=100, init=weight_init),
                          Affine(nout=1, init=weight_init)])
    gan = GenerativeAdversarial(gen_net, dis_net)

    # the container should expose all four layers, generator ones first
    assert len(gan.layers) == 4
    for position, expected_nout in enumerate((10, 100, 100, 1)):
        assert gan.layers[position].nout == expected_nout

    # each sub-network view must line up with its slice of the full list
    assert gan.generator.layers == gan.layers[0:2]
    assert gan.discriminator.layers == gan.layers[2:4]
Example #2
0
File: gan3D.py  Project: whhopkins/3Dgan
# Keyword-argument bundles for the generator's convolution layers.
conv4 = {
    'init': init,
    'batch_norm': True,
    'activation': lrelu,
    'dilation': {'dil_h': 2, 'dil_w': 2, 'dil_d': 2},
}
conv5 = {
    'init': init,
    'batch_norm': True,
    'activation': lrelu,
    'padding': {'pad_h': 2, 'pad_w': 2, 'pad_d': 0},
    'dilation': {'dil_h': 2, 'dil_w': 2, 'dil_d': 3},
}
conv6 = {
    'init': init,
    'batch_norm': False,
    'activation': lrelu,
    'padding': {'pad_h': 1, 'pad_w': 0, 'pad_d': 3},
}

# Generator: a Linear projection, a reshape into a 4-D volume, then a stack
# of convolutions. 7 * 7 * 8 * 8 == 64 * 7 * 7, so the reshape keeps the
# element count of the Linear output.
G_layers = [
    # TODO(review): the expected input (noise) volume for this layer is
    # still an open question from the original author -- confirm.
    Linear(64 * 7 * 7, init=init),
    Reshape((7, 7, 8, 8)),
    Conv((6, 6, 8, 64), **conv4),
    Conv((6, 5, 8, 6), **conv5),
    Conv((3, 3, 8, 6), **conv6),
    Conv((2, 2, 2, 1), init=init, batch_norm=False, activation=relu),
]
# TODO(review): the original author asked whether an Embedding layer is
# needed here; undecided.

# Wire generator and discriminator (D_layers is defined elsewhere in the file)
# into a single GAN container.
layers = GenerativeAdversarial(
    generator=Sequential(G_layers, name="Generator"),
    discriminator=Sequential(D_layers, name="Discriminator"),
)

# Cost: GANCost with the "modified" objective, wrapped in GeneralizedGANCost
# (not a CrossEntropy cost).
cost = GeneralizedGANCost(costfunc=GANCost(func="modified"))
Example #3
0
class _InvalidGANConfig(ValueError, AssertionError):
    """Raised by ``create_model`` for an unsupported argument value.

    It is a ``ValueError`` (the conventional type for a bad argument value)
    and also an ``AssertionError``, so callers written against the earlier
    ``assert``-based checks still catch it.
    """


def _require_choice(value, choices, message):
    """Raise ``_InvalidGANConfig(message)`` unless ``value`` is in ``choices``."""
    if value not in choices:
        raise _InvalidGANConfig(message)


def create_model(dis_model='dc',
                 gen_model='dc',
                 cost_type='wasserstein',
                 noise_type='normal',
                 im_size=64,
                 n_chan=3,
                 n_noise=100,
                 n_gen_ftr=64,
                 n_dis_ftr=64,
                 depth=4,
                 n_extra_layers=0,
                 batch_norm=True,
                 gen_squash=None,
                 dis_squash=None,
                 dis_iters=5,
                 wgan_param_clamp=None,
                 wgan_train_sched=False):
    """
    Create a GAN model and associated GAN cost function for image generation

    Arguments:
        dis_model (str): Discriminator type, can be 'mlp' for a simple MLP or
                         'dc' for a DC-GAN style model. (defaults to 'dc')
        gen_model (str): Generator type, can be 'mlp' for a simple MLP or
                         'dc' for a DC-GAN style model. (defaults to 'dc')
        cost_type (str): Cost type, can be 'original', 'modified' following
                         Goodfellow2014 or 'wasserstein' following Arjovsky2017
                         (defaults to 'wasserstein')
        noise_type (str): Noise distribution, can be 'uniform' or 'normal'
                          (defaults to 'normal')
        im_size (int): Image size (defaults to 64)
        n_chan (int): Number of image channels (defaults to 3)
        n_noise (int): Number of noise dimensions (defaults to 100)
        n_gen_ftr (int): Number of generator feature maps (defaults to 64)
        n_dis_ftr (int): Number of discriminator feature maps (defaults to 64)
        depth (int): Depth of layers in case of MLP (defaults to 4)
        n_extra_layers (int): Number of extra conv layers in case of DC (defaults to 0)
        batch_norm (bool): Enable batch normalization for the DC models; the
                           MLP models are always built without it
                           (defaults to True)
        gen_squash (str or None): Squashing function at the end of the
                                  generator: 'nosquash' (Identity), 'sym'
                                  (Tanh) or 'asym' (Logistic). If None (or
                                  otherwise falsy), the default is 'nosquash'
                                  for an MLP generator with the Wasserstein
                                  cost and 'sym' in every other case.
        dis_squash (str or None): Squashing function at the end of the
                                  discriminator, same choices as gen_squash.
                                  If None (or otherwise falsy), the default is
                                  'nosquash' for the Wasserstein cost and
                                  'asym' otherwise.
        dis_iters (int): Number of critics for discriminator, passed to GAN
                         as ``k`` (defaults to 5)
        wgan_param_clamp (float or None): In case of WGAN weight clamp value, None for others
        wgan_train_sched (bool): Enable training schedule of number of critics (defaults to False)

    Returns:
        tuple: ``(GAN model, GeneralizedGANCost)``.

    Raises:
        ValueError: if the model type, cost type or a squashing function is
            not supported. The error is also an ``AssertionError`` for
            backward compatibility.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a bad model type fail later with an
    # UnboundLocalError (gen / dis / noise_dim never assigned).
    _require_choice(dis_model, ('mlp', 'dc'),
                    "Unsupported model type for discriminator net, supported: 'mlp' and 'dc'")
    _require_choice(gen_model, ('mlp', 'dc'),
                    "Unsupported model type for generator net, supported: 'mlp' and 'dc'")
    _require_choice(cost_type, ('original', 'modified', 'wasserstein'),
                    "Unsupported GAN cost function type, supported: 'original', 'modified' "
                    "and 'wasserstein'")

    # Resolve default final squashing functions from the cost and model types.
    if cost_type == 'wasserstein':
        if gen_model == 'mlp':
            gen_squash = gen_squash or 'nosquash'
        elif gen_model == 'dc':
            gen_squash = gen_squash or 'sym'
        dis_squash = dis_squash or 'nosquash'
    else:  # for all GAN costs other than Wasserstein
        gen_squash = gen_squash or 'sym'
        dis_squash = dis_squash or 'asym'

    squash_choices = ('nosquash', 'sym', 'asym')
    _require_choice(gen_squash, squash_choices,
                    "Unsupported final squashing function for generator,"
                    " supported: 'nosquash', 'sym' and 'asym'")
    _require_choice(dis_squash, squash_choices,
                    "Unsupported final squashing function for discriminator,"
                    " supported: 'nosquash', 'sym' and 'asym'")

    # types of final squashing functions (built only once inputs are valid)
    squash_func = dict(nosquash=Identity(), sym=Tanh(), asym=Logistic())
    gfa = squash_func[gen_squash]
    dfa = squash_func[dis_squash]

    # create model layers
    if gen_model == 'mlp':
        # MLP generator is always built without batch norm, regardless of
        # the ``batch_norm`` argument.
        gen = create_mlp_generator(im_size,
                                   n_chan,
                                   n_gen_ftr,
                                   depth,
                                   batch_norm=False,
                                   finact=gfa)
        noise_dim = (n_noise, )
    elif gen_model == 'dc':
        gen = create_dc_generator(im_size,
                                  n_chan,
                                  n_noise,
                                  n_gen_ftr,
                                  n_extra_layers,
                                  batch_norm,
                                  finact=gfa)
        noise_dim = (n_noise, 1, 1)

    if dis_model == 'mlp':
        # MLP discriminator is likewise always built without batch norm.
        dis = create_mlp_discriminator(im_size,
                                       n_dis_ftr,
                                       depth,
                                       batch_norm=False,
                                       finact=dfa)
    elif dis_model == 'dc':
        dis = create_dc_discriminator(im_size,
                                      n_chan,
                                      n_dis_ftr,
                                      n_extra_layers,
                                      batch_norm,
                                      finact=dfa)
    layers = GenerativeAdversarial(generator=Sequential(gen, name="Generator"),
                                   discriminator=Sequential(
                                       dis, name="Discriminator"))

    return GAN(layers=layers, noise_dim=noise_dim, noise_type=noise_type, k=dis_iters,
               wgan_param_clamp=wgan_param_clamp, wgan_train_sched=wgan_train_sched), \
        GeneralizedGANCost(costfunc=GANCost(func=cost_type))