示例#1
0
def main(argv):
    """Configure the GAN text model with a CNN decoder and start the trainer.

    Reads the data directory, batch length and batch size from
    ``tf.flags.FLAGS``, builds the input functions and charset, assembles a
    ``TextConfig`` around a GAN model function, and passes it to
    ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to
            ``text_aae.trainer.main``.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    model_mode = 'cnn'
    layers = 6
    kernel_size = 5

    decoder_fn = make_decoder_cnn_fn(
        bn=True,
        layers=layers,
        kernel_size=kernel_size,
        padding='valid',
    )
    # The discriminator's depth and kernel width are spelled out as literals
    # rather than tied to the decoder's local settings above.
    discriminator_fn = make_discriminator_gan_cnn_ml_fn(
        layers=6,
        kernel_size=5,
        padding='valid',
    )
    discriminator_optimizer = tf.train.AdamOptimizer(1e-5, name='dis_opt')
    generator_optimizer = tf.train.AdamOptimizer(1e-5, name='gen_opt')

    model_fn = make_model_gan_fn(
        charset=charset,
        decoder_fn=decoder_fn,
        gan_discriminator_fn=discriminator_fn,
        gan_loss_fn=wgan_losses,
        model_mode=model_mode,
        dis_opt=discriminator_optimizer,
        gen_opt=generator_optimizer,
        combined=True,
        # NOTE(review): presumably the total length trimmed by `layers`
        # 'valid' convolutions of width `kernel_size` -- confirm.
        padding_size=layers * (kernel_size - 1),
    )
    config = TextConfig(
        model_fn=model_fn,
        input_fns=input_fns,
        mode=model_mode,
    )
    text_aae.trainer.main(argv, config=config)
示例#2
0
def main(argv):
    """Configure the VAE text model with CNN encoder/decoder and start the trainer.

    Reads the data directory, batch length and batch size from
    ``tf.flags.FLAGS``, builds the input functions and charset, wraps a VAE
    model function in a ``TextConfig``, and passes it to
    ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to
            ``text_aae.trainer.main``.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    model_mode = 'cnn'
    model_fn = make_model_vae_fn(
        charset=charset,
        encoder_fn=encoder_cnn_fn,
        decoder_fn=decoder_cnn_fn,
        model_mode=model_mode,
    )
    config = TextConfig(
        model_fn=model_fn,
        input_fns=input_fns,
        mode=model_mode,
    )
    text_aae.trainer.main(argv, config=config)
示例#3
0
def main(argv):
    """Configure the GAN text model with a ResNet CNN decoder and start the trainer.

    Reads the data directory, batch length and batch size from
    ``tf.flags.FLAGS``, builds the input functions and charset, assembles a
    ``TextConfig`` around a GAN model function, and passes it to
    ``text_aae.trainer.main``.

    Args:
        argv: Command-line arguments, forwarded unchanged to
            ``text_aae.trainer.main``.
    """
    flags = tf.flags.FLAGS
    input_fns, charset = make_wikitext_char_input_fn(
        data_dir=flags.data_dir,
        batch_length=flags.batch_length,
        batch_size=flags.batch_size,
    )

    model_mode = 'cnn'
    layers = 9
    kernel_size = 3

    decoder_fn = make_decoder_resnet_cnn_fn(
        bn=True,
        layers=layers,
        kernel_size=kernel_size,
        activation_fn=tf.nn.leaky_relu,
        padding='valid',
    )
    discriminator_fn = make_discriminator_gan_cnn_ml_fn(
        bn_fn=make_batch_norm(eps=1e-6, clip_var=1.),
        layers=9,
        # The discriminator uses a wider kernel (5) than the decoder (3).
        kernel_size=5,
        padding='valid',
        # NOTE(review): keyword is spelled `emedding_scale`; kept as-is
        # because it must match the callee's parameter name -- confirm
        # whether the callee really uses this spelling.
        emedding_scale=10,
        activation_fn=tf.nn.leaky_relu,
    )
    discriminator_optimizer = tf.train.AdamOptimizer(1e-5, name='dis_opt')
    generator_optimizer = tf.train.AdamOptimizer(1e-5, name='gen_opt')

    model_fn = make_model_gan_fn(
        charset=charset,
        decoder_fn=decoder_fn,
        gan_discriminator_fn=discriminator_fn,
        gan_loss_fn=wgan_losses,
        model_mode=model_mode,
        dis_opt=discriminator_optimizer,
        gen_opt=generator_optimizer,
        combined=True,
        # NOTE(review): the extra factor of 2 presumably reflects two
        # convolutions per residual layer -- confirm against the decoder.
        padding_size=layers * (kernel_size - 1) * 2,
    )
    config = TextConfig(
        model_fn=model_fn,
        input_fns=input_fns,
        mode=model_mode,
    )
    text_aae.trainer.main(argv, config=config)