def main(argv):
    """Configure a CNN-decoder GAN text model and start the trainer.

    Reads ``data_dir``, ``batch_length`` and ``batch_size`` from
    ``tf.flags.FLAGS``, obtains the input functions and charset from
    ``make_wikitext_char_input_fn``, builds a ``TextConfig`` around a
    ``make_model_gan_fn`` model and passes it to ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to the trainer.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    # Generator (decoder) hyperparameters; also used for padding_size below.
    layers, kernel_size = 6, 5
    mode = 'cnn'

    generator_fn = make_decoder_cnn_fn(
        bn=True,
        layers=layers,
        kernel_size=kernel_size,
        padding='valid',
    )
    # The discriminator's depth and kernel width are given as literals, not
    # tied to the generator's variables above.
    discriminator_fn = make_discriminator_gan_cnn_ml_fn(
        layers=6,
        kernel_size=5,
        padding='valid',
    )
    discriminator_optimizer = tf.train.AdamOptimizer(1e-5, name='dis_opt')
    generator_optimizer = tf.train.AdamOptimizer(1e-5, name='gen_opt')

    model_fn = make_model_gan_fn(
        charset=charset,
        decoder_fn=generator_fn,
        gan_discriminator_fn=discriminator_fn,
        gan_loss_fn=wgan_losses,
        model_mode=mode,
        dis_opt=discriminator_optimizer,
        gen_opt=generator_optimizer,
        combined=True,
        # NOTE(review): presumably the total length removed by `layers`
        # 'valid' convolutions of width `kernel_size` -- confirm.
        padding_size=layers * (kernel_size - 1),
    )

    config = TextConfig(model_fn=model_fn, input_fns=input_fns, mode=mode)
    text_aae.trainer.main(argv, config=config)
def main(argv):
    """Configure a CNN VAE text model and start the trainer.

    Reads ``data_dir``, ``batch_length`` and ``batch_size`` from
    ``tf.flags.FLAGS``, obtains the input functions and charset from
    ``make_wikitext_char_input_fn``, builds a ``TextConfig`` around a
    ``make_model_vae_fn`` model (using ``encoder_cnn_fn`` and
    ``decoder_cnn_fn``) and passes it to ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to the trainer.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    mode = 'cnn'
    model_fn = make_model_vae_fn(
        charset=charset,
        encoder_fn=encoder_cnn_fn,
        decoder_fn=decoder_cnn_fn,
        model_mode=mode,
    )

    config = TextConfig(model_fn=model_fn, input_fns=input_fns, mode=mode)
    text_aae.trainer.main(argv, config=config)
def main(argv):
    """Configure a ResNet-CNN-decoder GAN text model and start the trainer.

    Reads ``data_dir``, ``batch_length`` and ``batch_size`` from
    ``tf.flags.FLAGS``, obtains the input functions and charset from
    ``make_wikitext_char_input_fn``, builds a ``TextConfig`` around a
    ``make_model_gan_fn`` model and passes it to ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to the trainer.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    # Generator (decoder) hyperparameters; also used for padding_size below.
    layers, kernel_size = 9, 3
    mode = 'cnn'

    generator_fn = make_decoder_resnet_cnn_fn(
        bn=True,
        layers=layers,
        kernel_size=kernel_size,
        activation_fn=tf.nn.leaky_relu,
        padding='valid',
    )

    batch_norm_fn = make_batch_norm(eps=1e-6, clip_var=1.)
    # The discriminator uses its own literal depth and kernel width
    # (kernel_size=5 differs from the generator's 3).
    discriminator_fn = make_discriminator_gan_cnn_ml_fn(
        bn_fn=batch_norm_fn,
        layers=9,
        kernel_size=5,
        padding='valid',
        # NOTE(review): `emedding_scale` looks like a misspelling of
        # `embedding_scale`; kept as-is because it must match the callee's
        # parameter name, which is not visible here -- confirm.
        emedding_scale=10,
        activation_fn=tf.nn.leaky_relu,
    )
    discriminator_optimizer = tf.train.AdamOptimizer(1e-5, name='dis_opt')
    generator_optimizer = tf.train.AdamOptimizer(1e-5, name='gen_opt')

    model_fn = make_model_gan_fn(
        charset=charset,
        decoder_fn=generator_fn,
        gan_discriminator_fn=discriminator_fn,
        gan_loss_fn=wgan_losses,
        model_mode=mode,
        dis_opt=discriminator_optimizer,
        gen_opt=generator_optimizer,
        combined=True,
        # NOTE(review): presumably the length removed by the 'valid'
        # convolutions, doubled for the ResNet decoder -- confirm the
        # reason for the factor of 2 against make_decoder_resnet_cnn_fn.
        padding_size=layers * (kernel_size - 1) * 2,
    )

    config = TextConfig(model_fn=model_fn, input_fns=input_fns, mode=mode)
    text_aae.trainer.main(argv, config=config)
def main(argv):
    """Build a ``TextConfig`` for the autoencoder model and start the trainer.

    Args:
        argv: Command-line arguments, forwarded unchanged to the trainer.
    """
    # NOTE(review): this passes the factory itself via `make_model_fn`,
    # with no input functions or mode -- confirm that TextConfig accepts
    # this keyword.
    trainer_config = TextConfig(
        make_model_fn=make_model_ae_fn,
    )
    text_aae.trainer.main(argv, config=trainer_config)