Code Example #1
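All of the following snippets assume roughly the same module-level setup. A minimal sketch of the imports they rely on; the exact import paths are an assumption based on the identifiers used:

import os

import tensorflow as tf

from nmt import inference
from nmt.utils import common_test_utils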
    def testAttentionModelWithInferIndices(self):
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=1,
            attention="scaled_luong",
            attention_architecture="standard",
            use_residual=False,
            inference_indices=[1, 2])
        # TODO(rzhao): Make infer indices support batch_size > 1.
        hparams.infer_batch_size = 1
        vocab_prefix = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = vocab_prefix + "." + hparams.src
        hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt

        infer_file = "nmt/testdata/test_infer_file"
        out_dir = os.path.join(tf.test.get_temp_dir(),
                               "attention_infer_with_indices")
        hparams.out_dir = out_dir
        os.makedirs(out_dir)
        output_infer = os.path.join(out_dir, "output_infer")
        ckpt = self._createTestInferCheckpoint(hparams, out_dir)
        inference.inference(ckpt, infer_file, output_infer, hparams)
        with open(output_infer) as f:
            self.assertEqual(2, len(list(f)))
        # One attention alignment image is written per inference index.
        self.assertTrue(os.path.exists(output_infer + str(1) + ".png"))
        self.assertTrue(os.path.exists(output_infer + str(2) + ".png"))
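Several of these tests call a checkpoint-creation helper that is not shown in the excerpts. A minimal sketch of what it might look like, assuming the nmt_model, attention_model, and model_helper modules from the same codebase; the construction details are an assumption, not the verified implementation:

def _createTestInferCheckpoint(self, hparams, out_dir):
    # Sketch (assumed): build an inference model matching hparams, save a
    # fresh checkpoint under out_dir, and return its path.
    if not hparams.attention:
        model_creator = nmt_model.Model
    else:
        model_creator = attention_model.AttentionModel
    infer_model = model_helper.create_infer_model(model_creator, hparams)
    with self.test_session(graph=infer_model.graph) as sess:
        loaded_model, _ = model_helper.create_or_load_model(
            infer_model.model, out_dir, sess, "infer")
        return loaded_model.saver.save(
            sess, os.path.join(out_dir, "translate.ckpt"))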
Code Example #2
 def testBasicModel(self):
     hparams = common_test_utils.create_test_hparams(
         encoder_type='uni',
         num_layers=1,
         attention='',
         attention_architecture='',
         use_residual=False)
     ckpt_path = self._createTestInferenceCheckpoint(hparams, 'basic_infer')
     infer_file = 'nmt/testdata/test_infer_file'
     output_infer = os.path.join(hparams.out_dir, 'output_infer')
     inference.inference(ckpt_path, infer_file, output_infer, hparams)
     with open(output_infer) as f:
         self.assertEqual(5, len(list(f)))
Code Example #3
 def testAttentionModel(self):
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
     )
     ckpt_path = self._createTestInferenceCheckpoint(
         hparams, "attention_infer")
     infer_file = "nmt/testdata/test_infer_file"
     output_infer = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt_path, infer_file, output_infer, hparams)
     with open(output_infer) as f:
         self.assertEqual(5, len(list(f)))
Code Example #4
 def testBasicModelWithInferIndices(self):
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="",
         attention_architecture="",
         use_residual=False,
         inference_indices=[0])
     ckpt_path = self._createTestInferenceCheckpoint(
         hparams, "basic_infer_with_indices")
     infer_file = "nmt/testdata/test_infer_file"
     output_infer = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt_path, infer_file, output_infer, hparams)
     with open(output_infer) as f:
         self.assertEqual(1, len(list(f)))
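In these tests the number of output lines always equals len(inference_indices). A hypothetical helper illustrating the assumed behavior; the real filtering happens inside inference.inference, and this name does not exist in the codebase:

def select_inference_lines(input_lines, inference_indices):
    # Hypothetical: decode only the listed 0-based input lines.
    return [input_lines[i] for i in inference_indices]

# With the 5-line test_infer_file and inference_indices=[0], exactly one
# line is decoded, matching the assertEqual(1, ...) above.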
Code Example #5
    def testBasicModelWithMultipleTranslations(self):
        hparams = common_test_utils.create_test_hparams(
            encoder_type='uni',
            num_layers=1,
            attention='',
            attention_architecture='',
            use_residual=False,
            num_translations_per_input=2,
            beam_width=2)
        hparams.infer_mode = 'beam_search'

        ckpt_path = self._createTestInferenceCheckpoint(
            hparams, 'multi_basic_infer')
        infer_file = 'nmt/testdata/test_infer_file'
        output_infer = os.path.join(hparams.out_dir, 'output_infer')
        inference.inference(ckpt_path, infer_file, output_infer, hparams)
        with open(output_infer) as f:
            self.assertEqual(10, len(list(f)))
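The expected count follows from one output line per translation per input; a quick check of the arithmetic (the per-input layout is inferred from the assertion, not a documented contract):

num_inputs = 5                       # lines in nmt/testdata/test_infer_file
num_translations_per_input = 2
expected_lines = num_inputs * num_translations_per_input  # 5 * 2 == 10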
Code Example #6
 def testAttentionModelWithInferIndices(self):
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
         inference_indices=[1, 2])
     # TODO(rzhao): Make infer indices support batch_size > 1.
     hparams.infer_batch_size = 1
      ckpt_path = self._createTestInferenceCheckpoint(
         hparams, "attention_infer_with_indices")
     infer_file = "nmt/testdata/test_infer_file"
     output_infer = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt_path, infer_file, output_infer, hparams)
     with open(output_infer) as f:
         self.assertEqual(2, len(list(f)))
      # One attention alignment image is written per inference index.
      self.assertTrue(os.path.exists(output_infer + str(1) + ".png"))
      self.assertTrue(os.path.exists(output_infer + str(2) + ".png"))
Code Example #7
    def testBasicModel(self):
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=1,
            attention="",
            attention_architecture="",
            use_residual=False,
        )
        vocab_prefix = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = vocab_prefix + "." + hparams.src
        hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt

        infer_file = "nmt/testdata/test_infer_file"
        out_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer")
        hparams.out_dir = out_dir
        os.makedirs(out_dir)
        output_infer = os.path.join(out_dir, "output_infer")
        ckpt = self._createTestInferCheckpoint(hparams, out_dir)
        inference.inference(ckpt, infer_file, output_infer, hparams)
        with open(output_infer) as f:
            self.assertEqual(5, len(list(f)))
Code Example #8
    def testMultiWorkers(self):
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=2,
            attention="scaled_luong",
            attention_architecture="standard",
            use_residual=False,
        )
        vocab_prefix = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = vocab_prefix + "." + hparams.src
        hparams.tgt_vocab_file = vocab_prefix + "." + hparams.tgt

        infer_file = "nmt/testdata/test_infer_file"
        out_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer")
        hparams.out_dir = out_dir
        os.makedirs(out_dir)
        output_infer = os.path.join(out_dir, "output_infer")

        num_workers = 3

        # There are 5 examples; batch_size=3 gives job0 3 examples, job1 2
        # examples, and job2 0 examples, which exercises some edge cases.
        hparams.batch_size = 3

        ckpt = self._createTestInferCheckpoint(hparams, out_dir)
        inference.inference(ckpt,
                            infer_file,
                            output_infer,
                            hparams,
                            num_workers,
                            jobid=1)

        inference.inference(ckpt,
                            infer_file,
                            output_infer,
                            hparams,
                            num_workers,
                            jobid=2)

        # Note: Need to start job 0 at the end; otherwise, it will block the testing
        # thread.
        inference.inference(ckpt,
                            infer_file,
                            output_infer,
                            hparams,
                            num_workers,
                            jobid=0)

        with open(output_infer) as f:
            self.assertEqual(5, len(list(f)))
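The job sizes in the comment follow from batch-aligned sharding: whole batches are assigned to jobs in order, so with 5 examples and batch_size=3, job0 takes the first full batch, job1 the remainder, and job2 nothing. A small sketch of that arithmetic; the sharding scheme is an assumption inferred from the comment, not shown code:

def examples_per_job(total, batch_size, num_workers):
    # Assumed batch-aligned sharding: job i handles examples
    # [i * batch_size, (i + 1) * batch_size).
    counts = []
    for jobid in range(num_workers):
        start = jobid * batch_size
        end = min(start + batch_size, total)
        counts.append(max(0, end - start))
    return counts

assert examples_per_job(5, 3, 3) == [3, 2, 0]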
Code Example #9
      infer_input_file = os.path.join(trans_dir, source_filename + '.csv')
      del_temp_file = True
    else:
      del_temp_file = False

    
    # Check the model path: use the explicit checkpoint if given; otherwise
    # take the latest checkpoint from the model directory.
    model_dir = args.model_dir if args.model_dir else hparams.out_dir
    ckpt = args.ckpt
    if not ckpt:
      ckpt = tf.train.latest_checkpoint(model_dir)
    
    # decode
    inference.inference(ckpt, infer_input_file, trans_file, hparams)

    if args.rescore:
      rescore_dir = args.rescore_logdir or os.path.join(model_dir, "post_process")
      post_process.rescore(trans_file, trans_file + "_prob", infer_input_file,
                           rescore_dir,
                           os.path.join(trans_dir, source_filename + "_rescore"))
      print("Done")
    
      # Create a report for m-mod (21 AAs + 3 tokens) or p-mod (24 AAs + 3 tokens).
      if hparams.tgt_vocab_size == 24:
        report_utils.main(trans_dir, input_dir, 'm-mod')
      elif hparams.tgt_vocab_size == 27:
        report_utils.main(trans_dir, input_dir, 'p-mod')
    
    if del_temp_file:
      # Assumption: remove the temporary CSV created earlier in this function.
      os.remove(infer_input_file)