def testAttentionModelWithInferIndices(self):
    """Run attention-model inference with `inference_indices` set.

    Builds test hparams for a 1-layer "uni" encoder with "scaled_luong"
    attention and inference_indices=[1, 2], creates a test checkpoint, and
    calls `inference.inference` on the test input file. The test then checks
    that the output file has exactly two lines and that files named
    `<output path>1.png` and `<output path>2.png` exist.
    """
    hparams = common_test_utils.create_test_hparams(
        encoder_type="uni",
        num_layers=1,
        attention="scaled_luong",
        attention_architecture="standard",
        use_residual=False,
        inference_indices=[1, 2])
    # TODO(rzhao): Make infer indices support batch_size > 1.
    hparams.infer_batch_size = 1

    # Vocabulary files share a prefix and differ only in the language suffix.
    vocab_base = "nmt/testdata/test_infer_vocab"
    hparams.add_hparam("src_vocab_file", ".".join([vocab_base, hparams.src]))
    hparams.add_hparam("tgt_vocab_file", ".".join([vocab_base, hparams.tgt]))

    source_path = "nmt/testdata/test_infer_file"
    work_dir = os.path.join(
        tf.test.get_temp_dir(), "attention_infer_with_indices")
    hparams.add_hparam("out_dir", work_dir)
    os.makedirs(work_dir)
    result_path = os.path.join(work_dir, "output_infer")

    checkpoint = self._createTestInferCheckpoint(hparams, work_dir)
    inference.inference(checkpoint, source_path, result_path, hparams)

    with open(result_path) as result_file:
        produced_lines = result_file.readlines()
        self.assertEqual(2, len(produced_lines))

    # One image file is expected per requested inference index.
    for index in (1, 2):
        image_path = result_path + str(index) + ".png"
        self.assertTrue(os.path.exists(image_path))
def testMultiWorkers(self):
    """Run inference split across three worker jobs and check the output.

    Builds test hparams for a 2-layer "uni" encoder with "scaled_luong"
    attention, creates a test checkpoint, then calls `inference.inference`
    once per job id with num_workers=3. The test finally checks that the
    output file contains five lines.
    """
    hparams = common_test_utils.create_test_hparams(
        encoder_type="uni",
        num_layers=2,
        attention="scaled_luong",
        attention_architecture="standard",
        use_residual=False,
    )

    # Vocabulary files share a prefix and differ only in the language suffix.
    vocab_base = "nmt/testdata/test_infer_vocab"
    hparams.add_hparam("src_vocab_file", ".".join([vocab_base, hparams.src]))
    hparams.add_hparam("tgt_vocab_file", ".".join([vocab_base, hparams.tgt]))

    source_path = "nmt/testdata/test_infer_file"
    work_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer")
    hparams.add_hparam("out_dir", work_dir)
    os.makedirs(work_dir)
    result_path = os.path.join(work_dir, "output_infer")

    num_workers = 3

    # There are 5 examples; batch_size=3 gives job 0 three examples, job 1 two
    # examples, and job 2 none. This helps exercise some edge cases.
    hparams.batch_size = 3

    checkpoint = self._createTestInferCheckpoint(hparams, work_dir)

    # Note: job 0 must be started last; otherwise it will block the testing
    # thread.
    for job_id in (1, 2, 0):
        inference.inference(
            checkpoint, source_path, result_path, hparams, num_workers,
            jobid=job_id)

    with open(result_path) as result_file:
        produced_lines = result_file.readlines()
        self.assertEqual(5, len(produced_lines))
def testBasicModel(self):
    """Run inference with a model that has no attention and check the output.

    Builds test hparams for a 1-layer "uni" encoder with empty `attention`
    and `attention_architecture` settings, creates a test checkpoint, and
    calls `inference.inference` on the test input file. The test then checks
    that the output file contains five lines.
    """
    hparams = common_test_utils.create_test_hparams(
        encoder_type="uni",
        num_layers=1,
        attention="",
        attention_architecture="",
        use_residual=False,
    )

    # Vocabulary files share a prefix and differ only in the language suffix.
    vocab_base = "nmt/testdata/test_infer_vocab"
    hparams.add_hparam("src_vocab_file", ".".join([vocab_base, hparams.src]))
    hparams.add_hparam("tgt_vocab_file", ".".join([vocab_base, hparams.tgt]))

    source_path = "nmt/testdata/test_infer_file"
    work_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer")
    hparams.add_hparam("out_dir", work_dir)
    os.makedirs(work_dir)
    result_path = os.path.join(work_dir, "output_infer")

    checkpoint = self._createTestInferCheckpoint(hparams, work_dir)
    inference.inference(checkpoint, source_path, result_path, hparams)

    with open(result_path) as result_file:
        produced_lines = result_file.readlines()
        self.assertEqual(5, len(produced_lines))