Example #1
0
    def testAttentionModelWithInferIndices(self):
        """Indexed inference on a scaled-luong attention model.

        Requests only sentence indices 1 and 2; expects two output lines
        plus one attention-alignment image per requested index.
        """
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=1,
            attention="scaled_luong",
            attention_architecture="standard",
            use_residual=False,
            inference_indices=[1, 2])
        # Indexed decoding only supports one sentence per batch for now.
        # TODO(rzhao): Make infer indices support batch_size > 1.
        hparams.infer_batch_size = 1
        vocab_path = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = "%s.%s" % (vocab_path, hparams.src)
        hparams.tgt_vocab_file = "%s.%s" % (vocab_path, hparams.tgt)

        src_file = "nmt/testdata/test_infer_file"
        model_dir = os.path.join(tf.test.get_temp_dir(),
                                 "attention_infer_with_indices")
        hparams.out_dir = model_dir
        os.makedirs(model_dir)
        trans_file = os.path.join(model_dir, "output_infer")
        ckpt_path = self._createTestInferCheckpoint(hparams, model_dir)
        inference.inference(ckpt_path, src_file, trans_file, hparams)
        # Only the two requested sentences should have been translated.
        with open(trans_file) as trans_f:
            self.assertEqual(2, len(list(trans_f)))
        # Attention models also dump an alignment image per inferred index.
        self.assertTrue(os.path.exists(trans_file + "1.png"))
        self.assertTrue(os.path.exists(trans_file + "2.png"))
Example #2
0
 def testBasicModel(self):
     """Inference smoke test on a plain (no-attention) uni-directional
     model; the 5-line source file should yield 5 translations."""
     hparams = common_test_utils.create_test_hparams(
         encoder_type='uni',
         num_layers=1,
         attention='',
         attention_architecture='',
         use_residual=False)
     ckpt = self._createTestInferenceCheckpint(hparams, 'basic_infer')
     src_file = 'nmt/testdata/test_infer_file'
     trans_file = os.path.join(hparams.out_dir, 'output_infer')
     inference.inference(ckpt, src_file, trans_file, hparams)
     with open(trans_file) as trans_f:
         self.assertEqual(5, len(list(trans_f)))
Example #3
0
    def testMultiWorkers(self):
        """Shards a 5-sentence inference job across 3 workers.

        With batch_size=3 the shards hold 3, 2, and 0 examples
        respectively, which exercises the empty-shard edge case. The
        merged output must still contain all 5 translations.
        """
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=2,
            attention="scaled_luong",
            attention_architecture="standard",
            use_residual=False,
        )
        vocab_path = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = "%s.%s" % (vocab_path, hparams.src)
        hparams.tgt_vocab_file = "%s.%s" % (vocab_path, hparams.tgt)

        src_file = "nmt/testdata/test_infer_file"
        model_dir = os.path.join(tf.test.get_temp_dir(), "multi_worker_infer")
        hparams.out_dir = model_dir
        os.makedirs(model_dir)
        trans_file = os.path.join(model_dir, "output_infer")

        num_workers = 3
        # 5 examples / batch_size 3 -> jobs get 3, 2, and 0 examples.
        hparams.batch_size = 3

        ckpt_path = self._createTestInferCheckpoint(hparams, model_dir)
        inference.inference(
            ckpt_path, src_file, trans_file, hparams, num_workers, jobid=1)
        inference.inference(
            ckpt_path, src_file, trans_file, hparams, num_workers, jobid=2)
        # Job 0 merges the per-worker shards, so it must run after the
        # others; starting it first would block the testing thread.
        inference.inference(
            ckpt_path, src_file, trans_file, hparams, num_workers, jobid=0)

        with open(trans_file) as trans_f:
            self.assertEqual(5, len(list(trans_f)))
Example #4
0
 def testAttentionModel(self):
     """Inference smoke test for a scaled-luong attention model; the
     5-line source file should yield 5 translations."""
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
     )
     ckpt = self._createTestInferenceCheckpint(hparams,
                                               "attention_infer")
     src_file = "nmt/testdata/test_infer_file"
     trans_file = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt, src_file, trans_file, hparams)
     with open(trans_file) as trans_f:
         self.assertEqual(5, len(list(trans_f)))
Example #5
0
 def testBasicModelWithInferIndices(self):
     """Inference restricted to sentence index 0 on a basic model;
     exactly one translation should be written."""
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="",
         attention_architecture="",
         use_residual=False,
         inference_indices=[0])
     ckpt = self._createTestInferenceCheckpint(
         hparams, "basic_infer_with_indices")
     src_file = "nmt/testdata/test_infer_file"
     trans_file = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt, src_file, trans_file, hparams)
     with open(trans_file) as trans_f:
         self.assertEqual(1, len(list(trans_f)))
Example #6
0
    def testBasicModelWithMultipleTranslations(self):
        """Beam-search inference emitting 2 translations per input; the
        5-line source file should therefore produce 10 output lines."""
        hparams = common_test_utils.create_test_hparams(
            encoder_type='uni',
            num_layers=1,
            attention='',
            attention_architecture='',
            use_residual=False,
            num_translations_per_input=2,
            beam_width=2)
        hparams.infer_mode = 'beam_search'

        ckpt = self._createTestInferenceCheckpint(hparams,
                                                  'multi_basice_infer')
        src_file = 'nmt/testdata/test_infer_file'
        trans_file = os.path.join(hparams.out_dir, 'output_infer')
        inference.inference(ckpt, src_file, trans_file, hparams)
        # 5 inputs x 2 translations each = 10 lines.
        with open(trans_file) as trans_f:
            self.assertEqual(10, len(list(trans_f)))
Example #7
0
 def testAttentionModelWithInferIndices(self):
     """Indexed inference on an attention model: expect one output line
     and one alignment image per requested sentence index."""
     hparams = common_test_utils.create_test_hparams(
         encoder_type="uni",
         num_layers=1,
         attention="scaled_luong",
         attention_architecture="standard",
         use_residual=False,
         inference_indices=[1, 2])
     # Indexed decoding only supports one sentence per batch for now.
     # TODO(rzhao): Make infer indices support batch_size > 1.
     hparams.infer_batch_size = 1
     ckpt = self._createTestInferenceCheckpint(
         hparams, "attention_infer_with_indices")
     src_file = "nmt/testdata/test_infer_file"
     trans_file = os.path.join(hparams.out_dir, "output_infer")
     inference.inference(ckpt, src_file, trans_file, hparams)
     with open(trans_file) as trans_f:
         self.assertEqual(2, len(list(trans_f)))
     # Attention models also dump an alignment image per inferred index.
     self.assertTrue(os.path.exists(trans_file + "1.png"))
     self.assertTrue(os.path.exists(trans_file + "2.png"))
Example #8
0
    def testBasicModel(self):
        """End-to-end inference smoke test for a plain uni-directional
        model; all 5 source sentences should be translated."""
        hparams = common_test_utils.create_test_hparams(
            encoder_type="uni",
            num_layers=1,
            attention="",
            attention_architecture="",
            use_residual=False,
        )
        vocab_path = "nmt/testdata/test_infer_vocab"
        hparams.src_vocab_file = "%s.%s" % (vocab_path, hparams.src)
        hparams.tgt_vocab_file = "%s.%s" % (vocab_path, hparams.tgt)

        src_file = "nmt/testdata/test_infer_file"
        model_dir = os.path.join(tf.test.get_temp_dir(), "basic_infer")
        hparams.out_dir = model_dir
        os.makedirs(model_dir)
        trans_file = os.path.join(model_dir, "output_infer")
        ckpt_path = self._createTestInferCheckpoint(hparams, model_dir)
        inference.inference(ckpt_path, src_file, trans_file, hparams)
        with open(trans_file) as trans_f:
            self.assertEqual(5, len(list(trans_f)))
Example #9
0
    def _testGNMTModel(self, architecture):
        """End-to-end check of a 4-layer GNMT model for one attention architecture.

        Builds train, eval, and infer graphs in turn (each in a fresh graph
        against a local in-process cluster) and asserts — via the `_assert*`
        golden-value helpers — the trainable variable names, the training
        loss, selected weight values after training, the eval loss / predict
        count, and the inference logits.

        Args:
            architecture: attention architecture name forwarded to
                create_test_hparams; also used in the golden-value prefix.
        """
        hparams = common_test_utils.create_test_hparams(
            encoder_type='gnmt',
            num_layers=4,
            attention='scaled_luong',
            attention_architecture=architecture)

        # One local worker, zero parameter servers; all three sessions
        # below connect to this same target.
        workers, _ = tf.test.create_local_cluster(1, 0)
        worker = workers[0]

        # Expected trainable variables: 1 bi-directional encoder layer plus
        # 3 uni-directional ones, and a 4-layer decoder with attention on
        # the bottom cell.
        # pylint: disable=line-too-long
        expected_var_names = [
            'dynamic_seq2seq/encoder/embedding_encoder:0',
            'dynamic_seq2seq/decoder/embedding_decoder:0',
            'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/encoder/bidirectional_rnn/fw/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/encoder/bidirectional_rnn/bw/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/decoder/memory_layer/kernel:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_0_attention/attention/luong_attention/attention_g:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_1/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_2/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel:0',
            'dynamic_seq2seq/decoder/multi_rnn_cell/cell_3/basic_lstm_cell/bias:0',
            'dynamic_seq2seq/decoder/output_projection/kernel:0'
        ]
        # pylint: enable=line-too-long

        # Prefix used to look up golden values for this architecture.
        test_prefix = 'GNMTModel_%s' % architecture
        # Phase 1: train graph — check variable names, take train steps,
        # then spot-check a few weights against golden values.
        with tf.Graph().as_default():
            with tf.Session(worker.target,
                            config=self._get_session_config()) as sess:
                train_m = self._createTestTrainModel(gnmt_model.GNMTModel,
                                                     hparams, sess)

                m_vars = tf.trainable_variables()
                self._assertModelVariableNames(expected_var_names,
                                               [v.name for v in m_vars],
                                               test_prefix)
                # Grab handles to specific weights (reuse=True so no new
                # variables are created) before running train steps.
                with tf.variable_scope('dynamic_seq2seq', reuse=True):
                    last_enc_weight = tf.get_variable(
                        'encoder/rnn/multi_rnn_cell/cell_2/basic_lstm_cell/kernel'
                    )
                    last_dec_weight = tf.get_variable(
                        'decoder/multi_rnn_cell/cell_3/basic_lstm_cell/kernel')
                    mem_layer_weight = tf.get_variable(
                        'decoder/memory_layer/kernel')
                self._assertTrainStepsLoss(train_m, sess, test_prefix)

                self._assertModelVariable(last_enc_weight, sess,
                                          '%s/last_enc_weight' % test_prefix)
                self._assertModelVariable(last_dec_weight, sess,
                                          '%s/last_dec_weight' % test_prefix)
                self._assertModelVariable(mem_layer_weight, sess,
                                          '%s/mem_layer_weight' % test_prefix)

        # Phase 2: eval graph — check eval loss and predict count.
        with tf.Graph().as_default():
            with tf.Session(worker.target,
                            config=self._get_session_config()) as sess:
                eval_m = self._createTestEvalModel(gnmt_model.GNMTModel,
                                                   hparams, sess)
                self._assertEvalLossAndPredictCount(eval_m, sess, test_prefix)

        # Phase 3: infer graph — check inference logits.
        with tf.Graph().as_default():
            with tf.Session(worker.target,
                            config=self._get_session_config()) as sess:
                infer_m = self._createTestInferModel(gnmt_model.GNMTModel,
                                                     hparams, sess)
                self._assertInferLogits(infer_m, sess, test_prefix)