FLAGS = tf.flags.FLAGS

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # Plain assignments: the previous version had trailing commas here
    # (`data_dir = FLAGS.vivi_data_dir,`), which bound 1-tuples rather than
    # the flag strings, so every downstream call received a tuple.
    data_dir = FLAGS.vivi_data_dir
    problem_name = FLAGS.vivi_problem
    ckpt_path = FLAGS.vivi_ckpt

    # --vivi_ckpt may name either a checkpoint or a directory; for a
    # directory, use the most recent checkpoint found in it.
    if tf.io.gfile.isdir(ckpt_path):
        ckpt_path = tf.train.latest_checkpoint(ckpt_path)

    # For back translation, we need a temporary file in the other language
    # before back-translating into the source language.
    tmp_file = '{}.tmp.vivi.txt'.format(FLAGS.paraphrase_from_file)

    if FLAGS.vivi_interactively:
        # With the tuple bug fixed, the flag values can be used directly
        # instead of the previously hardcoded problem name / data directory.
        tf.logging.info('Interactive decoding: problem=%s data_dir=%s ckpt=%s',
                        problem_name, data_dir, ckpt_path)
        decoding.vivi_interactively(problem_name, data_dir, ckpt_path)
    else:
        # Step 1: Translating from source language to the other language.
        # Skipped when tmp_file already exists, so an intermediate file left
        # by a previous run is reused as-is (delete it to force re-decoding).
        if not tf.io.gfile.exists(tmp_file):
            decoding.t2t_decoder(problem_name, data_dir,
                                 FLAGS.paraphrase_from_file, tmp_file,
                                 ckpt_path)
from tensor2tensor.utils import registry

flags = tf.flags
FLAGS = flags.FLAGS


@registry.register_hparams
def transformer_tall9():
    """Build hparams derived from `transformer.transformer_big()`.

    Overrides: hidden_size=768, filter_size=3072, num_hidden_layers=9,
    num_heads=12. Everything else is inherited from `transformer_big`.
    """
    hparams = transformer.transformer_big()
    overrides = (
        ('hidden_size', 768),
        ('filter_size', 3072),
        ('num_hidden_layers', 9),
        ('num_heads', 12),
    )
    for attr_name, attr_value in overrides:
        setattr(hparams, attr_name, attr_value)
    return hparams


@registry.register_hparams
def transformer_tall_18_18():
    """Build `transformer_tall9` hparams with 18 encoder and 18 decoder layers."""
    hparams = transformer_tall9()
    hparams.num_encoder_layers = hparams.num_decoder_layers = 18
    return hparams


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    # Decode --decode_from_file into --decode_to_file. Use --checkpoint_path,
    # or --output_dir when --checkpoint_path is empty/unset.
    problem = FLAGS.problem
    data_dir = FLAGS.data_dir
    source_file = FLAGS.decode_from_file
    target_file = FLAGS.decode_to_file
    checkpoint = FLAGS.checkpoint_path or FLAGS.output_dir
    decoding.t2t_decoder(problem, data_dir, source_file, target_file,
                         checkpoint)
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # --from_ckpt / --to_ckpt may each name a checkpoint or a directory; for a
    # directory, use the most recent checkpoint found in it.
    from_ckpt = FLAGS.from_ckpt
    to_ckpt = FLAGS.to_ckpt
    if tf.io.gfile.isdir(from_ckpt):
        from_ckpt = tf.train.latest_checkpoint(from_ckpt)
    if tf.io.gfile.isdir(to_ckpt):
        to_ckpt = tf.train.latest_checkpoint(to_ckpt)

    if FLAGS.backtranslate_interactively:
        # Use the resolved checkpoints. Previously the raw flag values were
        # passed here, so the directory resolution above was silently ignored
        # in interactive mode.
        decoding.backtranslate_interactively(FLAGS.from_problem,
                                             FLAGS.to_problem,
                                             FLAGS.from_data_dir,
                                             FLAGS.to_data_dir,
                                             from_ckpt,
                                             to_ckpt)
    else:
        # For back translation from file, we need a temporary file in the
        # other language before back-translating into the source language.
        tmp_file = '{}.tmp.txt'.format(FLAGS.paraphrase_from_file)
        # Step 1: Translating from source language to the other language.
        decoding.t2t_decoder(FLAGS.from_problem, FLAGS.from_data_dir,
                             FLAGS.paraphrase_from_file, tmp_file, from_ckpt)
        # Step 2: Translating from the other language (tmp_file) to source.
        decoding.t2t_decoder(FLAGS.to_problem, FLAGS.to_data_dir,
                             tmp_file, FLAGS.paraphrase_to_file, to_ckpt)