Ejemplo n.º 1
0
# Container through which the command-line flag values below are read.
FLAGS = tf.flags.FLAGS

if __name__ == '__main__':
    # Script entry point: decode either interactively or from a file,
    # depending on FLAGS.vivi_interactively.
    tf.logging.set_verbosity(tf.logging.INFO)

    # These must be plain assignments. A trailing comma after the value would
    # silently wrap it in a one-element tuple, and the tuple would then be
    # handed to the decoder in place of the flag value.
    data_dir = FLAGS.vivi_data_dir
    problem_name = FLAGS.vivi_problem
    ckpt_path = FLAGS.vivi_ckpt

    # When the flag names a directory, use the latest checkpoint inside it.
    if tf.io.gfile.isdir(ckpt_path):
        ckpt_path = tf.train.latest_checkpoint(ckpt_path)

    # For back translation, we need a temporary file in the other language
    # before back-translating into the source language. Its name is derived
    # from the input file name.
    tmp_file = '{}.tmp.vivi.txt'.format(FLAGS.paraphrase_from_file)

    if FLAGS.vivi_interactively:
        print("%s %s %s" % (data_dir, problem_name, ckpt_path))
        # NOTE(review): the problem name and data directory are hard-coded
        # here instead of using problem_name / data_dir from the flags.
        # Confirm whether this is intentional before switching to the flags.
        decoding.vivi_interactively('translate_vivi', './data/translate_vivi',
                                    ckpt_path)
    else:
        # Step 1: Translating from source language to the other language.
        # Skipped when the temporary file already exists, so an earlier
        # translation is reused rather than recomputed.
        if not tf.io.gfile.exists(tmp_file):
            decoding.t2t_decoder(problem_name, data_dir,
                                 FLAGS.paraphrase_from_file, tmp_file,
                                 ckpt_path)
Ejemplo n.º 2
0
from tensor2tensor.utils import registry

# Short alias for TensorFlow's flags module.
flags = tf.flags
# Container through which flag values are read (e.g. FLAGS.problem and
# FLAGS.data_dir in the entry point of this script).
FLAGS = flags.FLAGS


@registry.register_hparams
def transformer_tall9():
    """Return ``transformer_big`` hparams resized to 9 layers of width 768.

    Starts from ``transformer.transformer_big()`` and overrides the hidden
    size (768), filter size (3072), number of hidden layers (9) and number
    of attention heads (12). All other values are left as the base set
    defines them.
    """
    overrides = (
        ('hidden_size', 768),
        ('filter_size', 3072),
        ('num_hidden_layers', 9),
        ('num_heads', 12),
    )
    hparams = transformer.transformer_big()
    for name, value in overrides:
        setattr(hparams, name, value)
    return hparams


@registry.register_hparams
def transformer_tall_18_18():
    """Return ``transformer_tall9`` hparams with 18 encoder and 18 decoder layers.

    Only ``num_encoder_layers`` and ``num_decoder_layers`` are changed; every
    other value comes from ``transformer_tall9()``.
    """
    hparams = transformer_tall9()
    # Encoder first, then decoder: both are set to the same depth.
    hparams.num_encoder_layers = hparams.num_decoder_layers = 18
    return hparams


if __name__ == '__main__':
    # Script entry point: decode FLAGS.decode_from_file into
    # FLAGS.decode_to_file for the configured problem.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Prefer the explicit checkpoint path; fall back to the output directory
    # when no checkpoint path is set.
    checkpoint = FLAGS.checkpoint_path or FLAGS.output_dir
    decoding.t2t_decoder(
        FLAGS.problem,
        FLAGS.data_dir,
        FLAGS.decode_from_file,
        FLAGS.decode_to_file,
        checkpoint,
    )
Ejemplo n.º 3
0
if __name__ == '__main__':
    # Script entry point: back-translate either interactively or from a file,
    # depending on FLAGS.backtranslate_interactively.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Resolve each checkpoint flag: when it names a directory, use the latest
    # checkpoint inside that directory; otherwise use the value as given.
    from_ckpt = FLAGS.from_ckpt
    to_ckpt = FLAGS.to_ckpt
    if tf.gfile.IsDirectory(from_ckpt):
        from_ckpt = tf.train.latest_checkpoint(from_ckpt)
    if tf.gfile.IsDirectory(to_ckpt):
        to_ckpt = tf.train.latest_checkpoint(to_ckpt)

    if FLAGS.backtranslate_interactively:
        # Pass the resolved checkpoints, not the raw flag values, so that a
        # directory given on the command line is handled the same way as in
        # the file-based branch below.
        decoding.backtranslate_interactively(FLAGS.from_problem,
                                             FLAGS.to_problem,
                                             FLAGS.from_data_dir,
                                             FLAGS.to_data_dir,
                                             from_ckpt, to_ckpt)
    else:
        # For back translation from file, we need a temporary file in the
        # other language before back-translating into the source language.
        # Its name is derived from the input file name.
        tmp_file = '{}.tmp.txt'.format(FLAGS.paraphrase_from_file)

        # Step 1: Translating from source language to the other language.
        decoding.t2t_decoder(FLAGS.from_problem, FLAGS.from_data_dir,
                             FLAGS.paraphrase_from_file, tmp_file, from_ckpt)

        # Step 2: Translating from the other language (tmp_file) to source.
        decoding.t2t_decoder(FLAGS.to_problem, FLAGS.to_data_dir, tmp_file,
                             FLAGS.paraphrase_to_file, to_ckpt)