Ejemplo n.º 1
0
def mask_blending_gan(offset_generator, mask_generator, discriminator,
                      nb_fake=64, nb_real=32):
    """Build a Graph GAN whose generator output comes from PyramidBlending.

    A Dense+ReLU layer ('dense_mask') projects the latent input
    ``GAN.z_name`` to the mask generator's input width. That projection
    feeds ``mask_generator`` and, together with the raw latent input,
    ``offset_generator``. The offset generator's output goes through
    ``PyramidBlending`` (constructed with ``mask_generator``) to form the
    generator output. The discriminator receives the generated batch and a
    batch of real samples, concatenated along axis 0.

    Args:
        offset_generator: model with a 2-D ``input_shape``. It receives
            ``[z, dense_mask]``, so the latent width is its input width minus
            the mask generator's input width.
        mask_generator: model with a 2-D ``input_shape``.
        discriminator: model fed ``[generator output, real]`` concatenated
            along axis 0.
        nb_fake: fixed batch size of the latent input.
        nb_real: fixed batch size of the real input.

    Returns:
        The assembled ``Graph``, with outputs registered via ``gan_outputs``.
    """
    assert len(mask_generator.input_shape) == 2, \
        "mask_generator must have a 2-D input_shape"
    assert len(offset_generator.input_shape) == 2, \
        "offset_generator must have a 2-D input_shape"

    g = Graph()
    mask_input_dim = mask_generator.input_shape[1]
    # offset_generator consumes [z, dense_mask], so z gets whatever input
    # width remains after the mask projection.
    z_shape = (nb_fake, offset_generator.input_shape[1] - mask_input_dim)
    g.add_input(GAN.z_name, batch_input_shape=z_shape)

    g.add_node(Dense(mask_input_dim, activation='relu'), 'dense_mask',
               input=GAN.z_name)
    g.add_node(mask_generator, 'mask_generator', input='dense_mask')
    g.add_node(offset_generator, 'gen_offset',
               inputs=[GAN.z_name, 'dense_mask'])
    g.add_node(PyramidBlending(mask_generator, input_pyramid_layers=3,
                               mask_pyramid_layers=2),
               GAN.generator_name,
               input='gen_offset')

    # Real samples share the generator's per-sample output shape.
    real_shape = (nb_real,) + g.nodes[GAN.generator_name].output_shape[1:]
    g.add_input(GAN.real_name, batch_input_shape=real_shape)
    # Reference the real input by GAN.real_name (not a string literal) so the
    # discriminator is always wired to the input registered just above.
    g.add_node(discriminator, "discriminator",
               inputs=[GAN.generator_name, GAN.real_name], concat_axis=0)
    # Row ranges into the axis-0 concatenation [fake (nb_fake) | real (nb_real)]:
    # all fakes for the generator loss, the second half of the fakes for the
    # discriminator (presumed semantics of gan_outputs; confirm).
    gan_outputs(g, fake_for_gen=(0, nb_fake),
                fake_for_dis=(nb_fake // 2, nb_fake),
                real=(nb_fake, nb_fake + nb_real))
    return g
Ejemplo n.º 2
0
def mask_blending_gan_new(offset_generator,
                          mask_generator,
                          discriminator,
                          nb_fake=64,
                          nb_real=32):
    """Build a Graph GAN with a batch-normalised driver for the mask generator.

    The latent input ``GAN.z_name`` goes through
    Dense(32) -> BatchNormalization -> ReLU -> Dense(mask_input_dim) ->
    BatchNormalization -> identity ``Layer`` named 'driver', which feeds
    ``mask_generator``. ``offset_generator`` reads the latent input directly.
    Its output goes through ``PyramidBlending`` (constructed with
    ``mask_generator``) and then through an identity ``Layer`` carrying an
    ``ActivityInBoundsRegularizer(-1, 1)``. That last layer is the generator
    output. The discriminator receives the generated batch and a batch of
    real samples, concatenated along axis 0.

    Args:
        offset_generator: model with a 2-D ``input_shape``.
        mask_generator: model with a 2-D ``input_shape``.
        discriminator: model fed ``[generator output, real]`` concatenated
            along axis 0.
        nb_fake: fixed batch size of the latent input.
        nb_real: fixed batch size of the real input. The slice ranges below
            assume ``nb_real <= nb_fake``.

    Returns:
        The assembled ``Graph``, with outputs registered via ``gan_outputs``.
    """
    assert len(mask_generator.input_shape) == 2, \
        "mask_generator must have a 2-D input_shape"
    assert len(offset_generator.input_shape) == 2, \
        "offset_generator must have a 2-D input_shape"

    g = Graph()
    mask_input_dim = mask_generator.input_shape[1]
    # NOTE(review): 'gen_offset' below is fed only z, whose width is
    # offset_generator.input_shape[1] - mask_input_dim, not
    # offset_generator.input_shape[1]. Confirm this mismatch is intended.
    z_shape = (nb_fake, offset_generator.input_shape[1] - mask_input_dim)

    g.add_input(GAN.z_name, batch_input_shape=z_shape)

    # Driver MLP: z -> Dense(32) -> BN -> ReLU -> Dense(mask_input_dim) -> BN.
    g.add_node(Dense(32), 'gen_driver_dense_1', input=GAN.z_name)
    g.add_node(BatchNormalization(),
               'gen_driver_bn_1',
               input='gen_driver_dense_1')
    g.add_node(Activation('relu'), 'gen_driver_act_1', input='gen_driver_bn_1')

    g.add_node(Dense(mask_input_dim),
               'gen_driver_dense_2',
               input='gen_driver_act_1')
    g.add_node(BatchNormalization(),
               'gen_driver_bn_2',
               input='gen_driver_dense_2')
    # Identity layer that gives the driver output a stable node name.
    g.add_node(Layer(), 'driver', input='gen_driver_bn_2')

    g.add_node(mask_generator, 'mask_generator', input='driver')
    g.add_node(offset_generator, 'gen_offset', input=GAN.z_name)
    g.add_node(PyramidBlending(mask_generator,
                               input_pyramid_layers=3,
                               mask_pyramid_layers=2),
               'blending',
               input='gen_offset')

    # Identity layer carrying an activity regularizer with bounds (-1, 1);
    # presumably penalises generator outputs outside that range (confirm).
    reg_layer = Layer()
    act = ActivityInBoundsRegularizer(-1, 1)
    act.set_layer(reg_layer)
    reg_layer.regularizers = [act]
    g.add_node(reg_layer, GAN.generator_name, input='blending')

    # Real samples share the generator's per-sample output shape.
    real_shape = (nb_real, ) + g.nodes[GAN.generator_name].output_shape[1:]
    g.add_input(GAN.real_name, batch_input_shape=real_shape)
    # Reference the real input by GAN.real_name (not a string literal) so the
    # discriminator is always wired to the input registered just above.
    g.add_node(discriminator,
               "discriminator",
               inputs=[GAN.generator_name, GAN.real_name],
               concat_axis=0)
    # Row ranges into the axis-0 concatenation [fake (nb_fake) | real (nb_real)].
    # The first nb_fake - nb_real fakes go to the generator loss. The last
    # nb_real fakes go to the discriminator, presumably to balance the
    # nb_real reals.
    gan_outputs(g,
                fake_for_gen=(0, nb_fake - nb_real),
                fake_for_dis=(nb_fake - nb_real, nb_fake),
                real=(nb_fake, nb_fake + nb_real))
    return g
Ejemplo n.º 3
0
def test_gan_get_config(tmpdir):
    """Round-trip a small GAN through get_config / layer_from_config.

    Saves the GAN's weights under ``tmpdir``, writes its config to
    ``TEST_OUTPUT_DIR/true_config.json``, rebuilds a GAN from that config
    (registering the custom ``GAN`` and ``Split`` classes), writes the
    rebuilt config to ``TEST_OUTPUT_DIR/loaded_config.json`` and loads the
    saved weights into the rebuilt model.
    """
    import json

    z_shape = (1, 8, 8)

    # Generator: a single same-padded conv over the latent input.
    z_in = Input(z_shape, name='z')
    gen_out = Convolution2D(10, 2, 2, activation='relu', border_mode='same')(z_in)
    generator = Container(z_in, gen_out)

    # Discriminator: concat fake/real along axis 1, conv, flatten, sigmoid.
    fake_in, real_in = Input(z_shape, name='f'), Input(z_shape, name='r')
    joined = merge([fake_in, real_in], mode='concat', concat_axis=1)
    features = Convolution2D(5, 2, 2, activation='relu')(joined)
    flat = Flatten()(features)
    score = Dense(1, activation='sigmoid')(flat)
    discriminator = Container([fake_in, real_in], gan_outputs(score))

    gan = GAN(generator, discriminator, z_shape, z_shape)
    # The "{}" placeholder is handed to save_weights/load_weights unchanged;
    # presumably GAN fills it in per sub-model (confirm).
    weights_path = str(tmpdir.mkdir("weights").join("{}.hdf5"))
    gan.save_weights(weights_path)
    saved_config = gan.get_config()

    true_config_path = os.path.join(TEST_OUTPUT_DIR, "true_config.json")
    with open(true_config_path, 'w+') as out_file:
        json.dump(saved_config, out_file, indent=2)

    custom = {
        'GAN': GAN,
        'Split': Split,
    }
    rebuilt = layer_from_config(saved_config, custom_objects=custom)

    loaded_config_path = os.path.join(TEST_OUTPUT_DIR, "loaded_config.json")
    with open(loaded_config_path, 'w+') as out_file:
        json.dump(rebuilt.get_config(), out_file, indent=2)
    rebuilt.load_weights(weights_path)
Ejemplo n.º 4
0
def test_gan_custom_layer_graph():
    """Smoke-test ``GAN.compile_custom_layers`` with a conditional generator.

    The generator concatenates the latent input 'z' with a conditioning
    input 'gen_cond' along axis 1 and applies conv layer 'g1'. The
    discriminator concatenates its 'fake'/'real' inputs along axis 0 and
    applies conv layer 'd1'. The function compiled for layers 'g1' and 'd1'
    is then called once on random data.
    """
    z_shape = (1, 8, 8)
    cond_shape = (1, 8, 8)
    batch_size = 64

    z = Input(shape=z_shape, name='z')
    gen_cond = Input(shape=cond_shape, name='gen_cond')

    gen_inputs = [z, gen_cond]
    gen_input = merge(gen_inputs, mode='concat', concat_axis=1)
    gen_output = Convolution2D(1, 2, 2, activation='relu',
                               name='g1',
                               border_mode='same')(gen_input)
    generator = Container(gen_inputs, gen_output)

    fake, real = Input(z_shape, name='fake'), Input(z_shape, name='real')
    dis_inputs = [fake, real]
    dis_input = merge(dis_inputs, mode='concat', concat_axis=0)
    dis_conv = Convolution2D(5, 2, 2, name='d1', activation='relu')(dis_input)
    dis_flatten = Flatten()(dis_conv)
    dis = Dense(1, activation='sigmoid')(dis_flatten)
    discriminator = Container(dis_inputs, gan_outputs(dis))

    gan = GAN(generator, discriminator, z_shape=z_shape, real_shape=z_shape)
    gan.build('adam', 'adam', gan_binary_crossentropy)
    fn = gan.compile_custom_layers(['g1', 'd1'])

    # Use distinct names for the numeric batches so the Input tensors above
    # are not shadowed. Each batch is sized from its own input's shape.
    z_batch = np.random.uniform(-1, 1, (batch_size,) + z_shape)
    real_batch = np.random.uniform(-1, 1, (batch_size,) + z_shape)
    cond_batch = np.random.uniform(-1, 1, (batch_size,) + cond_shape)
    fn({'z': z_batch, 'gen_cond': cond_batch, 'real': real_batch})
Ejemplo n.º 5
0
 def get_discriminator():
     """Return a small discriminator Container over inputs 'f' and 'r'.

     Both inputs have shape ``z_shape``, a name resolved from an outer scope
     (not defined here). They are concatenated along axis 1 and scored by
     conv -> flatten -> Dense(1, sigmoid); the score is wrapped with
     ``gan_outputs``.
     """
     fake_in = Input(z_shape, name='f')
     real_in = Input(z_shape, name='r')
     pair = [fake_in, real_in]
     hidden = merge(pair, mode='concat', concat_axis=1)
     hidden = Convolution2D(5, 2, 2, activation='relu')(hidden)
     hidden = Flatten()(hidden)
     score = Dense(1, activation='sigmoid')(hidden)
     return Container(pair, gan_outputs(score))
Ejemplo n.º 6
0
 def discriminator(x):
     """Score the concatenation of ``x`` with Flatten -> Dense(1).

     ``x`` is passed to ``concat``, presumably a list of fake and real
     tensors (confirm). The row ranges handed to ``gan_outputs`` assume 10
     fake samples (rows 0-10) followed by 10 real samples (rows 10-20).
     """
     head = sequential([
         Flatten(),
         Dense(1),
     ])
     scores = head(concat(x))
     return gan_outputs(scores,
                        fake_for_gen=(0, 10),
                        fake_for_dis=(0, 10),
                        real=(10, 20))
Ejemplo n.º 7
0
def mask_blending_gan_new(offset_generator, mask_generator, discriminator,
                          nb_fake=64, nb_real=32):
    """Build a Graph GAN with a batch-normalised driver for the mask generator.

    The latent input ``GAN.z_name`` goes through
    Dense(32) -> BatchNormalization -> ReLU -> Dense(mask_input_dim) ->
    BatchNormalization -> identity ``Layer`` named 'driver', which feeds
    ``mask_generator``. ``offset_generator`` reads the latent input directly.
    Its output goes through ``PyramidBlending`` (constructed with
    ``mask_generator``) and then through an identity ``Layer`` carrying an
    ``ActivityInBoundsRegularizer(-1, 1)``. That last layer is the generator
    output. The discriminator receives the generated batch and a batch of
    real samples, concatenated along axis 0.

    Args:
        offset_generator: model with a 2-D ``input_shape``.
        mask_generator: model with a 2-D ``input_shape``.
        discriminator: model fed ``[generator output, real]`` concatenated
            along axis 0.
        nb_fake: fixed batch size of the latent input.
        nb_real: fixed batch size of the real input. The slice ranges below
            assume ``nb_real <= nb_fake``.

    Returns:
        The assembled ``Graph``, with outputs registered via ``gan_outputs``.
    """
    assert len(mask_generator.input_shape) == 2, \
        "mask_generator must have a 2-D input_shape"
    assert len(offset_generator.input_shape) == 2, \
        "offset_generator must have a 2-D input_shape"

    g = Graph()
    mask_input_dim = mask_generator.input_shape[1]
    # NOTE(review): 'gen_offset' below is fed only z, whose width is
    # offset_generator.input_shape[1] - mask_input_dim, not
    # offset_generator.input_shape[1]. Confirm this mismatch is intended.
    z_shape = (nb_fake, offset_generator.input_shape[1] - mask_input_dim)

    g.add_input(GAN.z_name, batch_input_shape=z_shape)

    # Driver MLP: z -> Dense(32) -> BN -> ReLU -> Dense(mask_input_dim) -> BN.
    g.add_node(Dense(32), 'gen_driver_dense_1', input=GAN.z_name)
    g.add_node(BatchNormalization(), 'gen_driver_bn_1',
               input='gen_driver_dense_1')
    g.add_node(Activation('relu'), 'gen_driver_act_1',
               input='gen_driver_bn_1')

    g.add_node(Dense(mask_input_dim), 'gen_driver_dense_2',
               input='gen_driver_act_1')
    g.add_node(BatchNormalization(), 'gen_driver_bn_2',
               input='gen_driver_dense_2')
    # Identity layer that gives the driver output a stable node name.
    g.add_node(Layer(), 'driver', input='gen_driver_bn_2')

    g.add_node(mask_generator, 'mask_generator', input='driver')
    g.add_node(offset_generator, 'gen_offset',
               input=GAN.z_name)
    g.add_node(PyramidBlending(mask_generator, input_pyramid_layers=3,
                               mask_pyramid_layers=2),
               'blending', input='gen_offset')

    # Identity layer carrying an activity regularizer with bounds (-1, 1);
    # presumably penalises generator outputs outside that range (confirm).
    reg_layer = Layer()
    act = ActivityInBoundsRegularizer(-1, 1)
    act.set_layer(reg_layer)
    reg_layer.regularizers = [act]
    g.add_node(reg_layer, GAN.generator_name, input='blending')

    # Real samples share the generator's per-sample output shape.
    real_shape = (nb_real,) + g.nodes[GAN.generator_name].output_shape[1:]
    g.add_input(GAN.real_name, batch_input_shape=real_shape)
    # Reference the real input by GAN.real_name (not a string literal) so the
    # discriminator is always wired to the input registered just above.
    g.add_node(discriminator, "discriminator",
               inputs=[GAN.generator_name, GAN.real_name], concat_axis=0)
    # Row ranges into the axis-0 concatenation [fake (nb_fake) | real (nb_real)].
    # The first nb_fake - nb_real fakes go to the generator loss. The last
    # nb_real fakes go to the discriminator, presumably to balance the
    # nb_real reals.
    gan_outputs(g, fake_for_gen=(0, nb_fake - nb_real),
                fake_for_dis=(nb_fake - nb_real, nb_fake),
                real=(nb_fake, nb_fake + nb_real))
    return g
Ejemplo n.º 8
0
def mask_blending_gan(offset_generator,
                      mask_generator,
                      discriminator,
                      nb_fake=64,
                      nb_real=32):
    """Build a Graph GAN whose generator output comes from PyramidBlending.

    A Dense+ReLU layer ('dense_mask') projects the latent input
    ``GAN.z_name`` to the mask generator's input width. That projection
    feeds ``mask_generator`` and, together with the raw latent input,
    ``offset_generator``. The offset generator's output goes through
    ``PyramidBlending`` (constructed with ``mask_generator``) to form the
    generator output. The discriminator receives the generated batch and a
    batch of real samples, concatenated along axis 0.

    Args:
        offset_generator: model with a 2-D ``input_shape``. It receives
            ``[z, dense_mask]``, so the latent width is its input width minus
            the mask generator's input width.
        mask_generator: model with a 2-D ``input_shape``.
        discriminator: model fed ``[generator output, real]`` concatenated
            along axis 0.
        nb_fake: fixed batch size of the latent input.
        nb_real: fixed batch size of the real input.

    Returns:
        The assembled ``Graph``, with outputs registered via ``gan_outputs``.
    """
    assert len(mask_generator.input_shape) == 2, \
        "mask_generator must have a 2-D input_shape"
    assert len(offset_generator.input_shape) == 2, \
        "offset_generator must have a 2-D input_shape"

    g = Graph()
    mask_input_dim = mask_generator.input_shape[1]
    # offset_generator consumes [z, dense_mask], so z gets whatever input
    # width remains after the mask projection.
    z_shape = (nb_fake, offset_generator.input_shape[1] - mask_input_dim)
    g.add_input(GAN.z_name, batch_input_shape=z_shape)

    g.add_node(Dense(mask_input_dim, activation='relu'),
               'dense_mask',
               input=GAN.z_name)
    g.add_node(mask_generator, 'mask_generator', input='dense_mask')
    g.add_node(offset_generator,
               'gen_offset',
               inputs=[GAN.z_name, 'dense_mask'])
    g.add_node(PyramidBlending(mask_generator,
                               input_pyramid_layers=3,
                               mask_pyramid_layers=2),
               GAN.generator_name,
               input='gen_offset')

    # Real samples share the generator's per-sample output shape.
    real_shape = (nb_real, ) + g.nodes[GAN.generator_name].output_shape[1:]
    g.add_input(GAN.real_name, batch_input_shape=real_shape)
    # Reference the real input by GAN.real_name (not a string literal) so the
    # discriminator is always wired to the input registered just above.
    g.add_node(discriminator,
               "discriminator",
               inputs=[GAN.generator_name, GAN.real_name],
               concat_axis=0)
    # Row ranges into the axis-0 concatenation [fake (nb_fake) | real (nb_real)]:
    # all fakes for the generator loss, the second half of the fakes for the
    # discriminator (presumed semantics of gan_outputs; confirm).
    gan_outputs(g,
                fake_for_gen=(0, nb_fake),
                fake_for_dis=(nb_fake // 2, nb_fake),
                real=(nb_fake, nb_fake + nb_real))
    return g
Ejemplo n.º 9
0
def simple_gan():
    """Build a small dense GAN from the module-level ``simple_gan_*`` sizes.

    Generator: Dense layers 'g1' and 'g2' (ReLU, ``simple_gan_nb_z`` units)
    and 'g3' (sigmoid, ``simple_gan_nb_out`` units), applied to latent input
    'z'. Discriminator: inputs 'fake' and 'real' concatenated along axis 0,
    then Dense 'd1' (ReLU, 20 units) and Dense 'd2' (sigmoid, 1 unit).

    Returns:
        A ``GAN`` built from the two Containers, with the batch dimension
        stripped from ``simple_gan_z_shape`` and ``simple_gan_real_shape``.
    """
    z = Input(batch_shape=simple_gan_z_shape, name='z')
    gen_stack = sequential([
        Dense(simple_gan_nb_z, activation='relu', name='g1'),
        Dense(simple_gan_nb_z, activation='relu', name='g2'),
        Dense(simple_gan_nb_out, activation='sigmoid', name='g3'),
    ])
    gen_out = gen_stack(z)

    fake = Input(batch_shape=simple_gan_real_shape, name='fake')
    real = Input(batch_shape=simple_gan_real_shape, name='real')

    # NOTE(review): 'd1' hard-codes input_dim=2; presumably this matches the
    # per-sample width of simple_gan_real_shape. Confirm.
    dis_stack = sequential([
        Dense(20, activation='relu', input_dim=2, name='d1'),
        Dense(1, activation='sigmoid', name='d2')
    ])
    dis_out = dis_stack(concat([fake, real], axis=0))

    generator = Container(z, gen_out)
    discriminator = Container([fake, real], gan_outputs(dis_out))
    return GAN(generator, discriminator,
               simple_gan_z_shape[1:], simple_gan_real_shape[1:])
Ejemplo n.º 10
0
def test_gan_graph():
    """Build, compile and sample from a GAN with a conditional generator.

    The generator concatenates latent input 'z' with conditioning input
    'gen_cond' along axis 1 and applies a same-padded conv. The
    discriminator concatenates its 'f'/'r' inputs along axis 1. The test
    then generates 64 samples with an all-zero condition.
    """
    z_shape = (1, 8, 8)
    z = Input(shape=z_shape, name='z')
    gen_cond = Input(shape=(1, 8, 8), name='gen_cond')

    gen_inputs = [z, gen_cond]
    gen_input = merge(gen_inputs, mode='concat', concat_axis=1)
    # One output channel so the generated samples match the discriminator's
    # fake input and real_shape, both (1, 8, 8). The previous 10 channels
    # did not match that input.
    gen_output = Convolution2D(1, 2, 2, activation='relu',
                               border_mode='same')(gen_input)
    generator = Container(gen_inputs, gen_output)

    f, r = Input(z_shape, name='f'), Input(z_shape, name='r')
    dis_inputs = [f, r]
    dis_input = merge(dis_inputs, mode='concat', concat_axis=1)
    dis_conv = Convolution2D(5, 2, 2, activation='relu')(dis_input)
    dis_flatten = Flatten()(dis_conv)
    dis = Dense(1, activation='sigmoid')(dis_flatten)
    discriminator = Container(dis_inputs, gan_outputs(dis))

    gan = GAN(generator, discriminator, z_shape=z_shape, real_shape=z_shape)
    gan.build('adam', 'adam', gan_binary_crossentropy)
    gan.compile()
    gan.generate({'gen_cond': np.zeros((64,) + z_shape)}, nb_samples=64)
Ejemplo n.º 11
0
 def discriminator(x):
     """Flatten the concatenated inputs and score them with a Dense(1) head.

     ``x`` is passed to ``concat``, presumably a list of fake and real
     tensors (confirm). ``gan_outputs`` receives fixed row ranges: rows
     [0, 10) are fakes for both generator and discriminator losses, and rows
     [10, 20) are reals.
     """
     model = sequential([Flatten(), Dense(1)])
     joined = concat(x)
     return gan_outputs(model(joined),
                        fake_for_gen=(0, 10),
                        fake_for_dis=(0, 10),
                        real=(10, 20))