def test_init(self):
    """Build a TRAIN-mode ``model.Model`` from the shared test fixtures.

    The test contains no assertions: it passes as long as creating the
    hparams, the iterator/vocab tables and the model raises no exception.
    """
    train_mode = tf.estimator.ModeKeys.TRAIN
    test_hparams = common_test_utils.create_test_hparams()
    # create_test_iterator is unpacked into (iterator, source vocab table,
    # target vocab table).
    (test_iterator, source_table,
     target_table) = common_test_utils.create_test_iterator(
         test_hparams, train_mode)
    train_model = model.Model(test_hparams, train_mode, test_iterator,
                              source_table, target_table)
Example No. 2
0
 def testAttentionModel(self):
     """Run inference with a scaled-Luong attention model; expect 5 output lines."""
     attention_hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
     )
     checkpoint = self._createTestInferCheckpoint(attention_hparams,
                                                  "attention_infer")
     source_path = "nmt/testdata/test_infer_file"
     result_path = os.path.join(attention_hparams.out_dir, "output_infer")
     inference.inference(checkpoint, source_path, result_path,
                         attention_hparams)
     # Count the lines written to the inference output file.
     with open(result_path) as result_file:
         line_count = sum(1 for _ in result_file)
     self.assertEqual(5, line_count)
Example No. 3
0
 def testBasicModelWithInferIndices(self):
     """Infer only index 0 with a no-attention model; expect a single output line."""
     basic_hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="",
         attention_architecture="",
         use_residual=False,
         inference_indices=[0])
     checkpoint = self._createTestInferCheckpoint(basic_hparams,
                                                  "basic_infer_with_indices")
     source_path = "nmt/testdata/test_infer_file"
     result_path = os.path.join(basic_hparams.out_dir, "output_infer")
     inference.inference(checkpoint, source_path, result_path, basic_hparams)
     # inference_indices=[0] selects one input, so one line is expected.
     with open(result_path) as result_file:
         line_count = sum(1 for _ in result_file)
     self.assertEqual(1, line_count)
Example No. 4
0
    def testMultiWorkers(self):
        """Split inference across 3 workers and check the combined output.

        Runs ``inference.inference`` once per jobid, then expects 5 lines in
        the final output file.
        """
        worker_hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=2,
            attention="scaled_luong",
            attention_architecture="standard",
            use_residual=False,
        )
        worker_count = 3

        # There are 5 examples; setting batch_size=3 gives job0 3 examples,
        # job1 2 examples and job2 0 examples, which exercises some edge cases.
        worker_hparams.batch_size = 3

        checkpoint = self._createTestInferCheckpoint(worker_hparams,
                                                     "multi_worker_infer")
        source_path = "nmt/testdata/test_infer_file"
        result_path = os.path.join(worker_hparams.out_dir, "output_infer")

        # Job 0 must be started last; otherwise it blocks the testing thread.
        for job in (1, 2, 0):
            inference.inference(checkpoint,
                                source_path,
                                result_path,
                                worker_hparams,
                                worker_count,
                                jobid=job)

        with open(result_path) as result_file:
            line_count = sum(1 for _ in result_file)
        self.assertEqual(5, line_count)
Example No. 5
0
    def testBasicModelWithMultipleTranslations(self):
        """Beam-search inference with 2 translations per input; expect 10 lines."""
        beam_hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=1,
            attention="",
            attention_architecture="",
            use_residual=False,
            num_translations_per_input=2,
            beam_width=2,
        )
        beam_hparams.infer_mode = "beam_search"

        checkpoint = self._createTestInferCheckpoint(beam_hparams,
                                                     "multi_basic_infer")
        source_path = "nmt/testdata/test_infer_file"
        result_path = os.path.join(beam_hparams.out_dir, "output_infer")
        inference.inference(checkpoint, source_path, result_path, beam_hparams)
        # The test expects 10 lines with num_translations_per_input=2.
        with open(result_path) as result_file:
            line_count = sum(1 for _ in result_file)
        self.assertEqual(10, line_count)
Example No. 6
0
 def testAttentionModelWithInferIndices(self):
     """Infer indices 1 and 2 with attention; check line count and .png files."""
     attention_hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
         inference_indices=[1, 2])
     # TODO(rzhao): Make infer indices support batch_size > 1.
     attention_hparams.infer_batch_size = 1
     checkpoint = self._createTestInferCheckpoint(
         attention_hparams, "attention_infer_with_indices")
     source_path = "nmt/testdata/test_infer_file"
     result_path = os.path.join(attention_hparams.out_dir, "output_infer")
     inference.inference(checkpoint, source_path, result_path,
                         attention_hparams)
     with open(result_path) as result_file:
         line_count = sum(1 for _ in result_file)
     self.assertEqual(2, line_count)
     # One "<output path><index>.png" file is expected per inferred index.
     for index in (1, 2):
         self.assertTrue(os.path.exists("%s%d.png" % (result_path, index)))