def train_g64_d64_pyramid_preprocess_2layer():
    """Train a two-layer dense preprocessor stacked in front of a DCGAN generator.

    The function:

    * builds a DCGAN generator via ``dcgan_generator`` and loads weights from
      ``models/dcgan_g64_d64_fine_tune/generator_0060.hdf5``;
    * puts two ReLU ``Dense`` layers in front of it and marks the generator
      as non-trainable (``g.trainable = False``);
    * trains the stacked model with ``fit_generator`` using the ``'adam'``
      optimizer and ``to_keras_loss(pyramid_loss)`` as loss;
    * writes checkpoints (through the ``SaveModels`` callback), the training
      history (``history.json``) and the model config
      (``network_config.json``) into the output directory.

    Takes no arguments and returns ``None``; all hyper-parameters are local
    constants.
    """
    nb_units = 64
    generator_input_dim = 100
    preprocess_input_dim = 50
    preprocess_nb_hidden = 256
    batch_size = 128
    nb_batches_per_epoch = 100
    nb_epoch = 200
    output_dir = "models/train_g64_d64_pyramid_preprocess_2layer"
    os.makedirs(output_dir, exist_ok=True)

    g = dcgan_generator(nb_units, generator_input_dim)
    g.load_weights("models/dcgan_g64_d64_fine_tune/generator_0060.hdf5")

    p = Sequential()
    p.add(Dense(preprocess_nb_hidden, activation='relu',
                input_dim=preprocess_input_dim))
    # Only the first layer of a Sequential model declares its input size.
    # This layer is fed by the hidden layer above, so it must not repeat
    # ``input_dim=preprocess_input_dim``; its output width is taken from the
    # generator's first layer so that it can feed ``g`` directly.
    p.add(Dense(g.layers[0].input_shape[1], activation='relu'))
    g.trainable = False
    p.add(g)

    save = SaveModels({"pyramid_{epoch:04d}.hdf5": p}, every_epoch=20,
                      output_dir=output_dir)

    # The preprocessor input is the grid parameters followed by noise; the
    # noise fills whatever is left of ``preprocess_input_dim``.
    nb_z_param = preprocess_input_dim - nb_normalized_params()

    def generator():
        """Yield ``(concat([param, z], axis=1), grid_idx)`` training batches."""
        for z, (param, grid_idx) in zip(
                z_generator((batch_size, nb_z_param)),
                grids_lecture_generator(batch_size)):
            yield np.concatenate([param, z], axis=1), grid_idx

    # Sanity output: shape of one input batch, drawn from a throwaway
    # generator instance so the training generator is left untouched.
    print(next(generator())[0].shape)

    print("Compiling...")
    start = time.time()
    p.compile('adam', to_keras_loss(pyramid_loss))
    print("Done Compiling in {0:.2f}s".format(time.time() - start))

    history = p.fit_generator(generator(), nb_batches_per_epoch * batch_size,
                              nb_epoch, verbose=1, callbacks=[save])

    with open(os.path.join(output_dir, "history.json"), 'w+') as f:
        json.dump(history.history, f)
    with open(os.path.join(output_dir, "network_config.json"), 'w+') as f:
        f.write(p.to_json())
def gan_grid_idx(generator, discriminator, batch_size=128, nb_z=20,
                 reconstruct_fn=None):
    """Wrap ``generator`` and ``discriminator`` into a ``GAN`` with grid inputs.

    A ``Graph`` is built around ``generator`` with two inputs, ``'z'`` of
    shape ``(nb_z,)`` and ``'grid_params'`` of shape
    ``(nb_normalized_params(),)``. Both are fed to the generator node in the
    order ``['grid_params', 'z']``, and the node is exposed as ``'output'``.
    The discriminator is converted with ``asgraph`` using ``GAN.d_input`` as
    its input name.

    Args:
        generator: model added as the ``'generator'`` node of the graph.
        discriminator: model passed to ``asgraph``.
        batch_size: first entry of the z shape handed to ``GAN``.
        nb_z: size of the ``'z'`` input.
        reconstruct_fn: forwarded unchanged to ``GAN``.

    Returns:
        The ``GAN`` built from both graphs and the ``(batch_size, nb_z)``
        z shape.
    """
    params_shape = (nb_normalized_params(),)
    noise_shape = (batch_size, nb_z)

    gen_graph = Graph()
    # Input declaration order is kept as-is: 'z' first, then 'grid_params'.
    gen_graph.add_input('z', input_shape=noise_shape[1:])
    gen_graph.add_input('grid_params', input_shape=params_shape)
    gen_graph.add_node(generator, 'generator', inputs=['grid_params', 'z'])
    gen_graph.add_output('output', input='generator')

    disc_graph = asgraph(discriminator, input_name=GAN.d_input)

    return GAN(gen_graph, disc_graph, noise_shape,
               reconstruct_fn=reconstruct_fn)
def dcgan_pyramid_generator(nb_dcgan_units=64, z_dim=50, nb_fake=64):
    """Build a ``Graph`` joining a DCGAN generator with ``PyramidBlending``.

    The graph has three inputs with a fixed batch size of ``nb_fake``:
    ``GAN.z_name`` (``z_dim`` wide), ``"gen_grid_params"``
    (``nb_normalized_params()`` wide) and ``"gen_grid_idx"`` of shape
    ``(nb_fake, 1, 64, 64)``. The z and grid-parameter inputs feed both a
    2-unit ReLU ``Dense`` node (``"tag_mean"``) and the ``"dcgan_generator"``
    node. The generator output and ``"gen_grid_idx"`` then go into a
    ``PyramidBlending`` node registered as ``GAN.generator_name`` with
    ``concat_axis=1``.

    Args:
        nb_dcgan_units: first argument passed to ``dcgan_generator``.
        z_dim: width of the ``GAN.z_name`` input.
        nb_fake: batch size used for every graph input.

    Returns:
        The assembled ``Graph``.
    """
    nb_grid_params = nb_normalized_params()
    # The DCGAN generator receives z and the grid parameters together, so
    # its input width is the sum of both.
    dcgan_input_dim = nb_grid_params + z_dim
    latent_sources = (GAN.z_name, "gen_grid_params")

    graph = Graph()
    declared_inputs = (
        (GAN.z_name, (nb_fake, z_dim)),
        ("gen_grid_params", (nb_fake, nb_grid_params)),
        ("gen_grid_idx", (nb_fake, 1, 64, 64)),
    )
    for input_name, input_shape in declared_inputs:
        graph.add_input(input_name, batch_input_shape=input_shape)

    tag_mean_layer = Dense(2, activation='relu', init=normal(0.005))
    graph.add_node(tag_mean_layer, "tag_mean", inputs=list(latent_sources))

    dcgan = dcgan_generator(nb_dcgan_units, input_dim=dcgan_input_dim)
    graph.add_node(dcgan, "dcgan_generator", inputs=list(latent_sources))

    # The blending layer is constructed with the tag_mean layer object itself.
    blending = PyramidBlending(tag_mean_layer)
    graph.add_node(blending, GAN.generator_name,
                   inputs=["dcgan_generator", "gen_grid_idx"],
                   concat_axis=1)
    return graph