Example #1
0
def test_gan_get_config(tmpdir):
    """Round-trip a small GAN through ``get_config``/``layer_from_config``.

    A one-convolution generator and a two-input discriminator are wrapped in
    a ``GAN``. Its weights are saved to a ``{}`` template path, and the
    original config is dumped as JSON to ``TEST_OUTPUT_DIR``. The GAN is then
    rebuilt from that config, the rebuilt config is dumped as well, and the
    saved weights are loaded into the rebuilt GAN.
    """
    import json

    z_shape = (1, 8, 8)

    # Generator: a single 'same'-padded convolution applied to z.
    z_layer = Input(z_shape, name='z')
    gen_conv = Convolution2D(10, 2, 2, activation='relu',
                             border_mode='same')(z_layer)
    generator = Container(z_layer, gen_conv)

    # Discriminator: concat fake/real on axis 1 -> conv -> flatten -> dense.
    fake_layer, real_layer = Input(z_shape, name='f'), Input(z_shape, name='r')
    joined = merge([fake_layer, real_layer], mode='concat', concat_axis=1)
    features = Convolution2D(5, 2, 2, activation='relu')(joined)
    flat = Flatten()(features)
    score = Dense(1, activation='sigmoid')(flat)
    discriminator = Container([fake_layer, real_layer], gan_outputs(score))

    gan = GAN(generator, discriminator, z_shape, z_shape)
    weights_path = str(tmpdir.mkdir("weights").join("{}.hdf5"))
    gan.save_weights(weights_path)
    original_config = gan.get_config()

    with open(os.path.join(TEST_OUTPUT_DIR, "true_config.json"), 'w+') as fp:
        json.dump(original_config, fp, indent=2)

    restored = layer_from_config(original_config, custom_objects={
        'GAN': GAN,
        'Split': Split,
    })

    with open(os.path.join(TEST_OUTPUT_DIR, "loaded_config.json"), 'w+') as fp:
        json.dump(restored.get_config(), fp, indent=2)
    restored.load_weights(weights_path)
Example #2
0
def mogan(self, gan: GAN, loss_fn, d_optimizer, name="mogan",
          gan_objective=binary_crossentropy, gan_regulizer=None,
          cond_true_ndim=4):
    """Wrap a conditional GAN as a ``MultipleObjectives`` problem.

    The GAN is built with a throw-away ``SGD`` optimiser for the generator
    (its ``g_updates`` are discarded) and ``d_optimizer`` for the
    discriminator. The generator parameters are then optimised jointly for
    ``g_loss`` and the conditional loss ``loss_fn(cond_true, g_outmap)``.
    ``v['d_updates'] + gan.updates`` are applied as additional updates.

    Args:
        self: not referenced by the body; kept for call compatibility.
        gan: GAN with at least one conditional input.
        loss_fn: callable ``(cond_true, g_outmap) -> loss tensor``.
        d_optimizer: optimiser passed to ``gan.build`` for the discriminator.
        name: name of the returned ``MultipleObjectives``.
        gan_objective: objective passed to ``gan.build``.
        gan_regulizer: accepted for API compatibility but currently ignored.
        cond_true_ndim: number of dimensions of the ``cond_true`` placeholder.

    Returns:
        ``MultipleObjectives`` with objectives ``g_loss`` and ``cond_loss``,
        and metrics ``cond_loss`` (mean), ``d_loss`` and ``g_loss``.
    """
    assert len(gan.conditionals) >= 1

    # The generator is trained by MultipleObjectives below, so the updates
    # produced with this dummy optimiser are thrown away.
    g_dummy_opt = SGD()
    v = gan.build(g_dummy_opt, d_optimizer, gan_objective)
    del v['g_updates']

    # Extra input carrying the ground truth for the conditional loss.
    cond_true = K.placeholder(ndim=cond_true_ndim)
    inputs = copy(gan.graph.inputs)
    inputs['cond_true'] = cond_true

    cond_loss = loss_fn(cond_true, v.g_outmap)

    metrics = {
        "cond_loss": cond_loss.mean(),
        "d_loss": v.d_loss,
        "g_loss": v.g_loss,
    }

    # Only generator weights are optimised by the multi-objective optimiser.
    params = flatten([n.trainable_weights
                      for n in gan.get_generator_nodes().values()])

    return MultipleObjectives(
        name, inputs, metrics=metrics, params=params,
        objectives={'g_loss': v['g_loss'], 'cond_loss': cond_loss},
        additional_updates=v['d_updates'] + gan.updates)
Example #3
0
def test_gan_custom_layer_graph():
    """Compile and run a function over the named custom layers ``g1``/``d1``.

    The generator concatenates ``z`` and ``gen_cond`` on axis 1 and applies
    conv ``g1``. The discriminator concatenates its ``fake``/``real`` inputs
    on axis 0 and applies conv ``d1``. ``compile_custom_layers`` is then
    called on one random batch of 64.
    """
    z_shape = (1, 8, 8)
    batch_shape = (64,) + z_shape

    z_layer = Input(shape=z_shape, name='z')
    cond_layer = Input(shape=(1, 8, 8), name='gen_cond')
    gen_inputs = [z_layer, cond_layer]
    gen_concat = merge(gen_inputs, mode='concat', concat_axis=1)
    g1_out = Convolution2D(1, 2, 2, activation='relu',
                           name='g1',
                           border_mode='same')(gen_concat)
    generator = Container(gen_inputs, g1_out)

    fake_layer, real_layer = (Input(z_shape, name='fake'),
                              Input(z_shape, name='real'))
    dis_inputs = [fake_layer, real_layer]
    dis_concat = merge(dis_inputs, mode='concat', concat_axis=0)
    d1_out = Convolution2D(5, 2, 2, name='d1', activation='relu')(dis_concat)
    flat = Flatten()(d1_out)
    score = Dense(1, activation='sigmoid')(flat)
    discriminator = Container(dis_inputs, gan_outputs(score))

    gan = GAN(generator, discriminator, z_shape=z_shape, real_shape=z_shape)
    gan.build('adam', 'adam', gan_binary_crossentropy)
    custom_fn = gan.compile_custom_layers(['g1', 'd1'])

    # Draw in the same order as before so the RNG stream is unchanged.
    z_batch = np.random.uniform(-1, 1, batch_shape)
    real_batch = np.random.uniform(-1, 1, batch_shape)
    cond_batch = np.random.uniform(-1, 1, batch_shape)
    for array in (z_batch, real_batch, cond_batch):
        print(array.shape)
    custom_fn({'z': z_batch, 'gen_cond': cond_batch, 'real': real_batch})
def construct_gan(sample, nb_fake, nb_real,
                  mask_weights_path=(
                      "models/train_autoencoder_mask_generator_adam_n128_gray/"
                      "mask_generator.hdf5")):
    """Build the mask-blending GAN described by a hyperopt ``sample``.

    A ``dcgan_small_generator`` mask generator is frozen
    (``trainable = False``), and its weights are loaded from
    ``mask_weights_path`` after the graph has been assembled. The graph also
    contains a sigmoid ``dcgan_discriminator`` and a mask driver that maps
    ``sample['z_dim']`` to the mask generator's input size. It is built by
    ``mask_blending_gan_hyperopt``.

    Args:
        sample: hyperparameter dict; reads ``z_dim``, ``offset_inputs``,
            ``offset_nb_units``, ``merge`` and ``nb_merge_conv_layers``.
        nb_fake: forwarded unchanged to ``mask_blending_gan_hyperopt``.
        nb_real: forwarded unchanged to ``mask_blending_gan_hyperopt``.
        mask_weights_path: HDF5 weight file for the mask generator. The
            default is the path that used to be hard-coded.

    Returns:
        ``GAN(graph)``.
    """
    # Previously this message was printed twice (copy-paste leftover).
    print('construct_gan')
    g_mask = dcgan_small_generator(nb_units=128, input_dim=22)
    # The mask generator is pre-trained; keep its weights fixed.
    g_mask.trainable = False
    d = dcgan_discriminator(out_activation='sigmoid')

    mask_driver = get_mask_driver(input_dim=sample['z_dim'],
                                  output_dim=g_mask.input_shape[1])
    graph = mask_blending_gan_hyperopt(
        mask_driver,
        g_mask,
        d,
        offset_inputs=sample['offset_inputs'],
        offset_nb_units=sample['offset_nb_units'],
        merge=sample['merge'],
        nb_merge_conv_layers=sample['nb_merge_conv_layers'],
        z_dim=sample['z_dim'],
        nb_fake=nb_fake,
        nb_real=nb_real,
    )
    g_mask.load_weights(mask_weights_path)
    return GAN(graph)
Example #5
0
def dcmogan(generator_fn, discriminator_fn, batch_size=128):
    """Build a MOGAN whose generator is conditioned on a grid configuration.

    The generator graph has two inputs:
    - ``z`` (``nb_g_z`` = 20 values);
    - ``grid_config`` (``NUM_CONFIGS + NUM_MIDDLE_CELLS + len(CONFIG_ROTS)``
      values).

    ``grid_config`` passes through a ReLU ``Dense`` node (``dense1``) and is
    concatenated with ``z``. The result feeds ``generator_fn``'s network,
    which emits two channels (``output``). A sigmoid ``Dense`` node on
    ``dense1`` produces ``alpha``, which ``reconstruct`` uses to blend the
    two channels.

    Args:
        generator_fn: factory called as
            ``generator_fn(input_dim=..., nb_output_channels=2)``.
        discriminator_fn: zero-argument factory for the discriminator.
        batch_size: batch size baked into ``z_shape`` and the reshapes.

    Returns:
        ``(mogan, grid_loss_weight)``. ``grid_loss_weight`` is a Theano shared
        float32 scalar (initially 1) that multiplies the grid loss; it is
        presumably returned so callers can anneal it.
    """
    nb_g_z = 20
    nb_grid_config = NUM_CONFIGS + NUM_MIDDLE_CELLS + len(CONFIG_ROTS)
    ff_generator = generator_fn(input_dim=2 * nb_g_z, nb_output_channels=2)

    g = Graph()
    g.add_input("z", (nb_g_z, ))
    g.add_input("grid_config", (nb_grid_config, ))
    g.add_node(Dense(nb_g_z, activation='relu'), "dense1", input="grid_config")
    g.add_node(ff_generator,
               "dcgan",
               inputs=["z", "dense1"],
               merge_mode='concat')
    g.add_output("output", input="dcgan")
    g.add_node(Dense(1, activation='sigmoid'),
               "alpha",
               input="dense1",
               create_output=True)

    def reconstruct(g_outmap):
        """Blend the two generator channels with the learned ``alpha``.

        The sigmoid ``alpha`` is remapped from [0, 1] to [0.5, 1], so the
        first channel always contributes at least half of the result.
        """
        g_out = g_outmap["output"]
        alpha = g_outmap["alpha"]
        alpha = 0.5 * alpha + 0.5
        alpha = alpha.reshape((batch_size, 1, 1, 1))
        m = g_out[:, :1]
        v = g_out[:, 1:]
        return (alpha * m + (1 - alpha) * v).reshape(
            (batch_size, 1, TAG_SIZE, TAG_SIZE))

    # np.cast was removed in NumPy 2.0. np.asarray(..., dtype=...) is its
    # documented replacement and yields the same 0-d float32 array.
    grid_loss_weight = theano.shared(np.asarray(1, dtype=np.float32))

    def grid_loss(grid_idx, g_outmap):
        """Weighted MSE between the first output channel and the grid mask."""
        g_out = g_outmap['output']
        m = g_out[:, :1]
        b = binary_mask(grid_idx, ignore=0.0, white=1.)
        return grid_loss_weight * mse(b, m)

    gan = GAN(g,
              asgraph(discriminator_fn(), input_name=GAN.d_input),
              z_shape=(batch_size, nb_g_z),
              reconstruct_fn=reconstruct)
    mogan = MOGAN(gan,
                  grid_loss,
                  lambda: Adam(lr=0.0002, beta_1=0.5),
                  gan_regulizer=GAN.L2Regularizer())

    return mogan, grid_loss_weight
Example #6
0
def test_gan_save_weights(tmpdir):
    """Saved GAN weights can be restored sub-network by sub-network.

    Two GANs with identical architecture are built, and their fresh weights
    must differ. ``gan.save_weights`` receives a ``{}`` template path. The
    test then loads ``generator.hdf5`` into the second GAN's generator and
    checks the generator weights match. Finally it loads
    ``discriminator.hdf5`` and checks all weights match.
    """
    z_shape = (1, 8, 8)
    # NOTE(review): gen_cond is created once and shared by both generators,
    # while z is re-created per call; confirm this sharing is intended.
    gen_cond = Input(shape=(1, 8, 8), name='gen_cond')

    def get_generator():
        z = Input(shape=z_shape, name='z')

        inputs = [z, gen_cond]
        gen_input = merge(inputs, mode='concat', concat_axis=1)
        gen_output = Convolution2D(10, 2, 2, activation='relu',
                                   border_mode='same')(gen_input)
        return Container(inputs, gen_output)

    def get_discriminator():
        f, r = Input(z_shape, name='f'), Input(z_shape, name='r')
        inputs = [f, r]
        dis_input = merge(inputs, mode='concat', concat_axis=1)
        dis_conv = Convolution2D(5, 2, 2, activation='relu')(dis_input)
        dis_flatten = Flatten()(dis_conv)
        dis = Dense(1, activation='sigmoid')(dis_flatten)
        return Container(inputs, gan_outputs(dis))

    def weights_equal(layers_a, layers_b):
        """Return True iff every paired trainable weight holds equal values."""
        return all(
            (sw.get_value() == lw.get_value()).all()
            for s, l in zip(layers_a, layers_b)
            for sw, lw in zip(s.trainable_weights, l.trainable_weights))

    gan = GAN(get_generator(), get_discriminator(), z_shape=z_shape,
              real_shape=z_shape)
    # Use tmpdir.join like the load calls below, so save and load paths are
    # built the same way.
    gan.save_weights(str(tmpdir.join("{}.hdf5")))

    gan_load = GAN(get_generator(), get_discriminator(), z_shape=z_shape,
                   real_shape=z_shape)

    # Fresh initialisations must differ, otherwise the checks below would
    # pass trivially.
    assert not weights_equal(gan.layers, gan_load.layers)

    gan_load.generator.load_weights(str(tmpdir.join("generator.hdf5")))
    assert weights_equal(gan.generator.layers, gan_load.generator.layers)

    gan_load.discriminator.load_weights(str(tmpdir.join("discriminator.hdf5")))
    assert weights_equal(gan.layers, gan_load.layers)
Example #7
0
def test_gan_utility_funcs(simple_gan: GAN):
    """Smoke-test ``interpolate``, ``random_z_point`` and ``neighborhood``.

    Points returned by ``neighborhood`` with ``std=0.05`` must stay close to
    the first returned point (mean absolute difference below 0.1).
    """
    simple_gan.build('adam', 'adam', gan_binary_crossentropy)
    simple_gan.compile()

    point_shape = simple_gan_z_shape[1:]
    start = np.zeros(point_shape, dtype=np.float32)
    end = np.zeros(point_shape, dtype=np.float32)
    simple_gan.interpolate(start, end)

    center = simple_gan.random_z_point()
    neighbors = simple_gan.neighborhood(center, std=0.05)

    # Broadcasting against the first neighbour replaces stacking copies of it.
    deviation = np.abs(np.asarray(neighbors) - neighbors[0])
    assert deviation.mean() < 0.1
Example #8
0
def test_gan_stop_regularizer():
    """Check ``GAN.StopRegularizer`` on either side of ``reg.high``.

    With the generator loss at ``reg.high + 2``, the regularised
    discriminator loss must be 0. At ``reg.high - 0.2`` it must equal the
    input discriminator loss (1).
    """
    reg = GAN.StopRegularizer()

    def regularize(g_offset):
        """Apply ``reg`` to g_loss = high + g_offset and d_loss = 1."""
        # np.cast was removed in NumPy 2.0; np.asarray(x, dtype=...) is the
        # documented drop-in replacement.
        g_loss = theano.shared(
            np.asarray(reg.high.get_value() + g_offset, dtype='float32'))
        d_loss = theano.shared(np.asarray(1., dtype='float32'))
        _, d_reg = reg(g_loss, d_loss)
        return d_reg, d_loss

    d_reg, _ = regularize(2)
    assert d_reg.eval() == 0

    d_reg, d_loss = regularize(-0.2)
    assert d_reg.eval() == d_loss.eval()
Example #9
0
def test_gan_learn_simple_distribution():
    """Fit a small GAN for one epoch on ``sample_circle`` data.

    Runs once per regularizer in the loop. When ``visual_debug`` is set, a
    ``Plotter`` callback is passed ``TEST_OUTPUT_DIR + "/epoches_plot"``.
    """
    def sample_multivariate(nb_samples):
        # Alternative data source (unused; see the commented-out line below).
        # A covariance matrix must be symmetric positive semi-definite; the
        # previous [[0.5, 0.1], [0.2, 0.4]] was not symmetric.
        mean = (0.2, 0)
        cov = [[0.5, 0.1],
               [0.1, 0.4]]
        return np.random.multivariate_normal(mean, cov, (nb_samples,))

    nb_samples = 600
    # X = sample_multivariate(nb_samples)
    X = sample_circle(nb_samples)

    # NOTE(review): both entries are plain GAN.Regularizer(); possibly two
    # different regularizers were intended. Confirm.
    for r in (GAN.Regularizer(), GAN.Regularizer()):
        gan = simple_gan()
        gan.add_gan_regularizer(r)
        gan.build('adam', 'adam', gan_binary_crossentropy)
        gan.compile()
        callbacks = []
        if visual_debug:
            callbacks.append(Plotter(X, TEST_OUTPUT_DIR + "/epoches_plot"))
        z = np.random.uniform(-1, 1, (len(X), simple_gan_nb_z))
        gan.fit({'real': X, 'z': z}, nb_epoch=1, verbose=0,
                callbacks=callbacks, batch_size={'real': 32, 'z': 96})
Example #10
0
def mogan(self,
          gan: GAN,
          loss_fn,
          d_optimizer,
          name="mogan",
          gan_objective=binary_crossentropy,
          gan_regulizer=None,
          cond_true_ndim=4):
    """Wrap a conditional GAN as a ``MultipleObjectives`` problem.

    The GAN is built with a throw-away ``SGD`` optimiser for the generator
    (its ``g_updates`` are discarded) and ``d_optimizer`` for the
    discriminator. The generator parameters are then optimised jointly for
    ``g_loss`` and the conditional loss ``loss_fn(cond_true, g_outmap)``.
    ``v['d_updates'] + gan.updates`` are applied as additional updates.

    Args:
        self: not referenced by the body; kept for call compatibility.
        gan: GAN with at least one conditional input.
        loss_fn: callable ``(cond_true, g_outmap) -> loss tensor``.
        d_optimizer: optimiser passed to ``gan.build`` for the discriminator.
        name: name of the returned ``MultipleObjectives``.
        gan_objective: objective passed to ``gan.build``.
        gan_regulizer: accepted for API compatibility but currently ignored.
        cond_true_ndim: number of dimensions of the ``cond_true`` placeholder.

    Returns:
        ``MultipleObjectives`` with objectives ``g_loss`` and ``cond_loss``,
        and metrics ``cond_loss`` (mean), ``d_loss`` and ``g_loss``.
    """
    assert len(gan.conditionals) >= 1

    # The generator is trained by MultipleObjectives below, so the updates
    # produced with this dummy optimiser are thrown away.
    g_dummy_opt = SGD()
    v = gan.build(g_dummy_opt, d_optimizer, gan_objective)
    del v['g_updates']

    # Extra input carrying the ground truth for the conditional loss.
    cond_true = K.placeholder(ndim=cond_true_ndim)
    inputs = copy(gan.graph.inputs)
    inputs['cond_true'] = cond_true

    cond_loss = loss_fn(cond_true, v.g_outmap)

    metrics = {
        "cond_loss": cond_loss.mean(),
        "d_loss": v.d_loss,
        "g_loss": v.g_loss,
    }

    # Only generator weights are optimised by the multi-objective optimiser.
    params = flatten(
        [n.trainable_weights for n in gan.get_generator_nodes().values()])

    return MultipleObjectives(name,
                              inputs,
                              metrics=metrics,
                              params=params,
                              objectives={
                                  'g_loss': v['g_loss'],
                                  'cond_loss': cond_loss
                              },
                              additional_updates=v['d_updates'] + gan.updates)
Example #11
0
def get_bravo():
    """Build the "bravo" GAN from two ``Sequential`` models.

    The generator has four 2x2 'same' convolutions (three ReLU, final
    sigmoid). The discriminator is conv -> dropout -> flatten -> two dense
    layers. The layer signatures look like the old Keras 0.x API (input size
    passed explicitly); confirm.

    Uses the module-level ``batch_size``: ``z_shape`` is
    ``(batch_size // 2, 1, 32, 32)``, with 2 generator conditionals and 1
    discriminator conditional.
    """
    # (nb_filter, input channels, activation) for each generator conv layer.
    generator_specs = (
        (10, 3, 'relu'),
        (20, 10, 'relu'),
        (10, 20, 'relu'),
        (1, 10, 'sigmoid'),
    )
    gen = Sequential()
    for nb_filter, nb_in, act in generator_specs:
        gen.add(Convolution2D(nb_filter, nb_in, 2, 2,
                              border_mode='same', activation=act))

    dis = Sequential()
    dis.add(Convolution2D(5, 2, 2, 2, border_mode='same', activation='relu'))
    dis.add(Dropout(0.5))
    dis.add(Flatten())
    dis.add(Dense(32*32*5, 512, activation="relu"))
    dis.add(Dense(512, 1, activation="sigmoid"))

    z_shape = (batch_size//2, 1, 32, 32)
    return GAN(gen, dis, z_shape=z_shape,
               num_gen_conditional=2, num_dis_conditional=1)
Example #12
0
    def simple_gan():
        """Build a small dense GAN over 2-dimensional data.

        Generator: input size ``simple_gan_nb_z + simple_gan_nb_cond``, ReLU
        layers of 20, 50 and 50 units, then a ``Dense(simple_gan_nb_out)``
        with no activation argument. Discriminator: a 25-unit ReLU layer and
        one sigmoid unit, converted with ``asgraph`` to use ``GAN.d_input``.
        """
        generator = Sequential()
        generator.add(Dense(20, activation='relu',
                            input_dim=simple_gan_nb_z + simple_gan_nb_cond))
        for width in (50, 50):
            generator.add(Dense(width, activation='relu'))
        generator.add(Dense(simple_gan_nb_out))

        dis_seq = Sequential()
        dis_seq.add(Dense(25, activation='relu', input_dim=2))
        dis_seq.add(Dense(1, activation='sigmoid'))
        dis_graph = asgraph(dis_seq, input_name=GAN.d_input)

        return GAN(generator, dis_graph, simple_gan_z_shape,
                   reconstruct_fn=reconstruction_fn)
Example #13
0
def gan_grid_idx(generator,
                 discriminator,
                 batch_size=128,
                 nb_z=20,
                 reconstruct_fn=None):
    """Wrap ``generator`` in a Graph fed by ``grid_params`` and ``z``.

    The generator graph has inputs ``z`` (``nb_z`` values) and
    ``grid_params`` (``nb_normalized_params()`` values). Both are passed to
    ``generator`` in the order ``['grid_params', 'z']``, and its output is
    exposed as ``output``. The discriminator is converted with ``asgraph``
    to use ``GAN.d_input``.

    Returns:
        ``GAN`` with ``z_shape == (batch_size, nb_z)``.
    """
    nb_grid_params = nb_normalized_params()
    z_shape = (batch_size, nb_z)

    g_graph = Graph()
    g_graph.add_input('z', input_shape=(nb_z, ))
    g_graph.add_input('grid_params', input_shape=(nb_grid_params, ))
    g_graph.add_node(generator, 'generator', inputs=['grid_params', 'z'])
    g_graph.add_output('output', input='generator')

    return GAN(g_graph,
               asgraph(discriminator, input_name=GAN.d_input),
               z_shape,
               reconstruct_fn=reconstruct_fn)
Example #14
0
def simple_gan():
    """Build a small functional-API GAN with named dense layers.

    Generator ``g1``-``g3``: two ReLU layers of ``simple_gan_nb_z`` units,
    then a sigmoid layer of ``simple_gan_nb_out`` units. Discriminator
    ``d1``-``d2``: applied to ``fake`` and ``real`` concatenated along axis 0,
    then wrapped by ``gan_outputs``.
    """
    z = Input(batch_shape=simple_gan_z_shape, name='z')
    gen_layers = [
        Dense(simple_gan_nb_z, activation='relu', name='g1'),
        Dense(simple_gan_nb_z, activation='relu', name='g2'),
        Dense(simple_gan_nb_out, activation='sigmoid', name='g3'),
    ]
    generated = sequential(gen_layers)(z)

    fake = Input(batch_shape=simple_gan_real_shape, name='fake')
    real = Input(batch_shape=simple_gan_real_shape, name='real')

    dis_layers = [
        Dense(20, activation='relu', input_dim=2, name='d1'),
        Dense(1, activation='sigmoid', name='d2'),
    ]
    dis_stack = sequential(dis_layers)
    score = dis_stack(concat([fake, real], axis=0))

    return GAN(Container(z, generated),
               Container([fake, real], gan_outputs(score)),
               simple_gan_z_shape[1:], simple_gan_real_shape[1:])
Example #15
0
def mogan_learn_bw_grid(generator,
                        discriminator,
                        optimizer_fn,
                        batch_size=128,
                        nb_z=20,
                        variation_weight=0.3):
    """Build a MOGAN that learns a black/white grid mask plus variation.

    Channel 0 of the generator ``output`` is trained by ``grid_loss``
    towards ``binary_mask(grid_idx, ...)``. The remaining channels
    (``output[:, 1:]``) are rotated by ``z_rot90`` in ``reconstruct``, and
    that result is what the GAN uses as the generated sample.

    Args:
        generator: generator network forwarded to ``gan_with_z_rot90_grid_idx``.
        discriminator: discriminator network forwarded likewise.
        optimizer_fn: forwarded to ``MOGAN``.
        batch_size: forwarded to ``gan_with_z_rot90_grid_idx``.
        nb_z: forwarded to ``gan_with_z_rot90_grid_idx``.
        variation_weight: white cells of the target mask are
            ``1 - variation_weight``. The default 0.3 is the value that used
            to be hard-coded.

    Returns:
        ``(mogan, grid_loss_weight)``. ``grid_loss_weight`` is a Theano shared
        float32 scalar (initially 1) that multiplies the grid loss.
    """
    def reconstruct(g_outmap):
        """Return the variation channels rotated by ``z_rot90``."""
        # NOTE: an earlier variant blended the channels as
        # clip(m + alphas * v, 0, 1), with alphas from binary_mask(grid_idx).
        # That blend is disabled, so its unused mask/alpha computations
        # were removed.
        combined = g_outmap["output"][:, 1:]
        return rotate_by_multiple_of_90(combined, g_outmap['z_rot90'])

    # np.cast was removed in NumPy 2.0. np.asarray(..., dtype=...) is its
    # documented replacement and yields the same 0-d float32 array.
    grid_loss_weight = theano.shared(np.asarray(1, dtype=np.float32))

    def grid_loss(g_outmap):
        """Weighted MSE between channel 0 and the grid's binary mask."""
        g_out = g_outmap['output']
        grid_idx = g_outmap["grid_idx"]
        m = g_out[:, :1]
        b = binary_mask(grid_idx,
                        ignore=0,
                        black=0,
                        white=1 - variation_weight)
        return grid_loss_weight * mse(b, m)

    gan = gan_with_z_rot90_grid_idx(generator,
                                    discriminator,
                                    batch_size=batch_size,
                                    nb_z=nb_z,
                                    reconstruct_fn=reconstruct)
    # FIXME
    mogan = MOGAN(gan,
                  grid_loss,
                  optimizer_fn,
                  gan_regulizer=GAN.L2Regularizer())
    return mogan, grid_loss_weight
Example #16
0
def mogan_pyramid(generator,
                  discriminator,
                  optimizer_fn,
                  batch_size=128,
                  nb_z=20,
                  gan_objective=binary_crossentropy,
                  d_loss_grad_weight=0):
    """Build a MOGAN whose conditional objective is a pyramid loss.

    The GAN comes from ``gan_grid_idx``. The conditional loss is
    ``pyramid_loss(cond_true, g_out['output']).loss``, where ``cond_true``
    holds the grid indices.

    Returns:
        The ``MOGAN`` instance.
    """
    def tag_loss(cond_true, g_out_dict):
        # cond_true is interpreted as the grid indices for pyramid_loss.
        return pyramid_loss(cond_true, g_out_dict['output']).loss

    gan = gan_grid_idx(generator, discriminator, batch_size, nb_z)
    # FIXME
    return MOGAN(gan,
                 tag_loss,
                 optimizer_fn,
                 gan_regulizer=GAN.L2Regularizer(),
                 gan_objective=gan_objective,
                 d_loss_grad_weight=d_loss_grad_weight)
Example #17
0
def test_gan_graph():
    """Build, compile and sample from a GAN with a conditional generator.

    The generator concatenates ``z`` with ``gen_cond`` on axis 1 and applies
    a 'same' convolution. The discriminator concatenates ``f``/``r`` on
    axis 1. The test only checks that ``generate`` runs for 64 samples.
    """
    z_shape = (1, 8, 8)

    z_layer = Input(shape=z_shape, name='z')
    cond_layer = Input(shape=(1, 8, 8), name='gen_cond')
    gen_inputs = [z_layer, cond_layer]
    gen_concat = merge(gen_inputs, mode='concat', concat_axis=1)
    gen_conv = Convolution2D(10, 2, 2, activation='relu',
                             border_mode='same')(gen_concat)
    generator = Container(gen_inputs, gen_conv)

    fake_layer, real_layer = Input(z_shape, name='f'), Input(z_shape, name='r')
    dis_inputs = [fake_layer, real_layer]
    dis_concat = merge(dis_inputs, mode='concat', concat_axis=1)
    dis_conv = Convolution2D(5, 2, 2, activation='relu')(dis_concat)
    dis_flat = Flatten()(dis_conv)
    dis_score = Dense(1, activation='sigmoid')(dis_flat)
    discriminator = Container(dis_inputs, gan_outputs(dis_score))

    gan = GAN(generator, discriminator, z_shape=z_shape, real_shape=z_shape)
    gan.build('adam', 'adam', gan_binary_crossentropy)
    gan.compile()

    cond_batch = np.zeros((64,) + z_shape)
    gan.generate({'gen_cond': cond_batch}, nb_samples=64)
Example #18
0
def test_mask_blending_generator():
    """Smoke-test ``mask_blending_generator`` wired into a ``GAN``.

    Every sub-network hook is replaced by a tiny stand-in network. The GAN
    is built with Adam optimisers and ``gan_binary_crossentropy``, and each
    generator layer's name, output shape and regularizers are printed. One
    batch of 10 samples is then generated from ``np.random.sample`` (uniform
    ``[0, 1)``) z vectors of length 32.

    NOTE(review): a second function with this exact name is defined later in
    this file. At import time it replaces this one, so pytest only collects
    the later definition.
    """
    nb_driver = 20

    # Mask driver stand-in: one Dense projection of z to nb_driver units.
    def driver(z):
        return Dense(nb_driver)(z)

    # Mask generator stand-in: Dense(16) -> (1, 4, 4) -> 16x upsampling,
    # giving a 1-channel 64x64 map.
    def mask_generator(x):
        return sequential([
            Dense(16),
            Reshape((1, 4, 4)),
            UpSampling2D((16, 16))
        ])(x)

    # Factory for merge hooks: optionally max-pool by `subsample`, then
    # apply a 1-filter 3x3 'same' convolution.
    def merge_mask(subsample):
        def call(x):
            if subsample:
                x = MaxPooling2D(subsample)(x)
            return Convolution2D(1, 3, 3, border_mode='same')(x)
        return call

    # Lighting stand-in: conv over the concatenated inputs, returned as a
    # 2-tuple of 4x-upsampled copies (the trailing comma is redundant).
    def light_generator(ins):
        seq = sequential([
            Convolution2D(1, 3, 3, border_mode='same')
        ])(concat(ins))
        return UpSampling2D((4, 4))(seq), UpSampling2D((4, 4))(seq),

    # Offset front stand-in: Dense(16) -> (1, 4, 4) -> 4x upsampling.
    def offset_front(x):
        return sequential([
            Dense(16),
            Reshape((1, 4, 4)),
            UpSampling2D((4, 4))
        ])(concat(x))

    # Offset middle stand-in: default-size upsampling of the concatenation.
    def offset_middle(x):
        return UpSampling2D()(concat(x))

    # Offset back stand-in: returns (feature_map, 1-channel conv of it).
    def offset_back(x):
        feature_map = sequential([
            UpSampling2D(),
        ])(concat(x))
        return feature_map, Convolution2D(1, 3, 3,
                                          border_mode='same')(feature_map)

    # Mask post-processing stand-in: one 3x3 conv over the concatenation.
    def mask_post(x):
        return sequential([
            Convolution2D(1, 3, 3, border_mode='same')
        ])(concat(x))

    # Blending-weight stand-in: Flatten -> Dense(1).
    def mask_weight_blending(x):
        return sequential([
            Flatten(),
            Dense(1),
        ])(x)

    # Discriminator stand-in passed to GAN as a plain function. The (0, 10) /
    # (10, 20) tuples are presumably batch-index ranges for the fake/real
    # splits; confirm against gan_outputs.
    def discriminator(x):
        return gan_outputs(sequential([
            Flatten(),
            Dense(1),
        ])(concat(x)), fake_for_gen=(0, 10), fake_for_dis=(0, 10),
                           real=(10, 20))

    # z_for_* are presumably index ranges into the 32-dim z vector
    # (z_shape below); confirm against mask_blending_generator.
    gen = mask_blending_generator(
        mask_driver=driver,
        mask_generator=mask_generator,
        light_merge_mask16=merge_mask(None),
        offset_merge_light16=merge_mask((4, 4)),
        offset_merge_mask16=merge_mask(None),
        offset_merge_mask32=merge_mask(None),
        lighting_generator=light_generator,
        offset_front=offset_front,
        offset_middle=offset_middle,
        offset_back=offset_back,
        mask_weight_blending32=mask_weight_blending,
        mask_weight_blending64=mask_weight_blending,
        mask_postprocess=mask_post,
        z_for_driver=(0, 10),
        z_for_offset=(10, 20),
        z_for_bits=(20, 32),
    )
    z_shape = (32, )
    real_shape = (1, 64, 64)
    gan = GAN(gen, discriminator, z_shape, real_shape)
    gan.build(Adam(), Adam(), gan_binary_crossentropy)
    # Debug output: name, output shape and regularizers of generator layers.
    for l in gan.gen_layers:
        print("{}: {}, {}".format(
            l.name, l.output_shape, getattr(l, 'regularizers', [])))
    bs = 10
    z_in = np.random.sample((bs,) + z_shape)
    gan.compile_generate()
    gan.generate({'z': z_in})
Example #19
0
def test_mask_blending_generator():
    """Smoke-test ``mask_blending_generator`` wired into a ``GAN``.

    Every sub-network hook is replaced by a tiny stand-in network. The GAN
    is built with Adam optimisers and ``gan_binary_crossentropy``, and each
    generator layer's name, output shape and regularizers are printed. One
    batch of 10 samples is then generated from ``np.random.sample`` (uniform
    ``[0, 1)``) z vectors of length 32.

    NOTE(review): this duplicates the name of an earlier function in this
    file. Being defined later, this definition is the one pytest collects.
    """
    nb_driver = 20

    # Mask driver stand-in: one Dense projection of z to nb_driver units.
    def driver(z):
        return Dense(nb_driver)(z)

    # Mask generator stand-in: Dense(16) -> (1, 4, 4) -> 16x upsampling,
    # giving a 1-channel 64x64 map.
    def mask_generator(x):
        return sequential(
            [Dense(16), Reshape((1, 4, 4)),
             UpSampling2D((16, 16))])(x)

    # Factory for merge hooks: optionally max-pool by `subsample`, then
    # apply a 1-filter 3x3 'same' convolution.
    def merge_mask(subsample):
        def call(x):
            if subsample:
                x = MaxPooling2D(subsample)(x)
            return Convolution2D(1, 3, 3, border_mode='same')(x)

        return call

    # Lighting stand-in: conv over the concatenated inputs, returned as a
    # 2-tuple of 4x-upsampled copies (the trailing comma is redundant).
    def light_generator(ins):
        seq = sequential([Convolution2D(1, 3, 3,
                                        border_mode='same')])(concat(ins))
        return UpSampling2D((4, 4))(seq), UpSampling2D((4, 4))(seq),

    # Offset front stand-in: Dense(16) -> (1, 4, 4) -> 4x upsampling.
    def offset_front(x):
        return sequential(
            [Dense(16), Reshape((1, 4, 4)),
             UpSampling2D((4, 4))])(concat(x))

    # Offset middle stand-in: default-size upsampling of the concatenation.
    def offset_middle(x):
        return UpSampling2D()(concat(x))

    # Offset back stand-in: returns (feature_map, 1-channel conv of it).
    def offset_back(x):
        feature_map = sequential([
            UpSampling2D(),
        ])(concat(x))
        return feature_map, Convolution2D(1, 3, 3,
                                          border_mode='same')(feature_map)

    # Mask post-processing stand-in: one 3x3 conv over the concatenation.
    def mask_post(x):
        return sequential([Convolution2D(1, 3, 3,
                                         border_mode='same')])(concat(x))

    # Blending-weight stand-in: Flatten -> Dense(1).
    def mask_weight_blending(x):
        return sequential([
            Flatten(),
            Dense(1),
        ])(x)

    # Discriminator stand-in passed to GAN as a plain function. The (0, 10) /
    # (10, 20) tuples are presumably batch-index ranges for the fake/real
    # splits; confirm against gan_outputs.
    def discriminator(x):
        return gan_outputs(sequential([
            Flatten(),
            Dense(1),
        ])(concat(x)),
                           fake_for_gen=(0, 10),
                           fake_for_dis=(0, 10),
                           real=(10, 20))

    # z_for_* are presumably index ranges into the 32-dim z vector
    # (z_shape below); confirm against mask_blending_generator.
    gen = mask_blending_generator(
        mask_driver=driver,
        mask_generator=mask_generator,
        light_merge_mask16=merge_mask(None),
        offset_merge_light16=merge_mask((4, 4)),
        offset_merge_mask16=merge_mask(None),
        offset_merge_mask32=merge_mask(None),
        lighting_generator=light_generator,
        offset_front=offset_front,
        offset_middle=offset_middle,
        offset_back=offset_back,
        mask_weight_blending32=mask_weight_blending,
        mask_weight_blending64=mask_weight_blending,
        mask_postprocess=mask_post,
        z_for_driver=(0, 10),
        z_for_offset=(10, 20),
        z_for_bits=(20, 32),
    )
    z_shape = (32, )
    real_shape = (1, 64, 64)
    gan = GAN(gen, discriminator, z_shape, real_shape)
    gan.build(Adam(), Adam(), gan_binary_crossentropy)
    # Debug output: name, output shape and regularizers of generator layers.
    for l in gan.gen_layers:
        print("{}: {}, {}".format(l.name, l.output_shape,
                                  getattr(l, 'regularizers', [])))
    bs = 10
    z_in = np.random.sample((bs, ) + z_shape)
    gan.compile_generate()
    gan.generate({'z': z_in})
def objective(sample):
    """Hyperopt objective: train and score one mask-blending GAN config.

    Casts the integer-valued hyperparameters in ``sample`` (in place). It
    then builds, compiles and trains the GAN with ``fit_generator`` and
    saves the graph weights and JSON under ``models/hyperopt/<md5>/``, where
    ``<md5>`` is the MD5 of ``str(sample)``.

    Returns:
        Hyperopt result dict containing:
        - ``status``: ``STATUS_OK``, or ``STATUS_FAIL`` if any step raised;
        - ``output_dir``;
        - ``loss``: ``0.`` until training finishes, then mean ``g_loss``
          minus mean ``d_loss``;
        - ``compile_time`` and ``hist``, when those steps were reached.

    Raises:
        SystemExit: on ``KeyboardInterrupt``. Previously a ``return`` inside
            ``finally`` silently discarded this exit.
    """
    def md5(x):
        """Return the hex MD5 digest of ``str(x)``."""
        m = hashlib.md5()
        m.update(str(x).encode())
        return m.hexdigest()

    intify = [
        'z_dim', 'offset_nb_units', 'discriminator_nb_units',
        'nb_merge_conv_layers'
    ]
    for k in intify:
        sample[k] = int(sample[k])

    nb_fake = 96
    nb_real = 32
    nb_epoch = 25
    batches_per_epoch = 128

    output_dir = 'models/hyperopt/{}/'.format(md5(sample))
    # A repeated sample maps to the same directory; don't abort the whole
    # hyperopt run because it already exists.
    os.makedirs(output_dir, exist_ok=True)
    print(sample)
    info = {
        'status': STATUS_OK,
        'output_dir': output_dir,
    }
    try:
        info['loss'] = 0.
        gan = construct_gan(sample, nb_fake, nb_real)
        if sample['gan_regularizer'] == 'stop':
            gan.add_gan_regularizer(GAN.StopRegularizer())
        print(info)
        compile_time = compile(gan, sample)
        print("Done Compiling in {0:.2f}s".format(compile_time))
        info['compile_time'] = compile_time
        vis = VisualiseGAN(10**2,
                           output_dir,
                           preprocess=lambda x: np.clip(x, -1, 1))

        hist = gan.fit_generator(generator(nb_fake, nb_real, sample['z_dim']),
                                 batches_per_epoch,
                                 nb_epoch,
                                 verbose=1,
                                 callbacks=[vis])
        g_loss = np.mean(np.array(hist.history['g_loss']))
        d_loss = np.mean(np.array(hist.history['d_loss']))
        info['loss'] = g_loss - d_loss
        info['hist'] = hist.history
        gan.graph.save_weights(output_dir + 'graph.hdf5')
        with open(output_dir + 'gan.json', 'w') as f:
            f.write(gan.graph.to_json())
    except KeyboardInterrupt:
        # Propagates now that there is no `return` in a `finally` clause.
        sys.exit(1)
    except Exception:
        # A failed trial is reported to hyperopt instead of crashing the run.
        err_type, err_value, tb = sys.exc_info()
        print(err_type)
        print(err_value)
        traceback.print_tb(tb, limit=50)
        info['status'] = STATUS_FAIL
    return info