Code example #1
 def testSaveLoadTrie(self):
     """Round-trips a trie through save/load and checks the snapshot is isolated."""
     save_path = TrieDecoderUtilsTest.save_path
     trie = trie_decoder_utils.DecoderTrie(TrieDecoderUtilsTest.vocab_path)
     for question, answer in TrieDecoderUtilsTest.test_trie_entries:
         trie.insert_question(question, answer)

     extra_key = 'newkey'
     self.assertNotIn(extra_key, trie)

     # Persist the trie and verify the file actually appears on disk.
     self.assertFalse(os.path.exists(save_path))
     trie.save_to_file(save_path)
     self.assertTrue(os.path.exists(save_path))

     # Mutations made after saving must not leak into the stored copy.
     trie[extra_key] = 'foo'
     self.assertIn(extra_key, trie)
     loaded_trie = trie_decoder_utils.DecoderTrie.load_from_file(save_path)
     self.assertNotIn(extra_key, loaded_trie)
Code example #2
    def __init__(self, hparams_path, source_prefix, out_dir,
                 environment_server_address):
        """Constructor for the reformulator.

        Builds one training model and one inference model per decoding
        strategy, all sharing a single TensorFlow graph and session.

        Args:
          hparams_path: Path to json hparams file.
          source_prefix: A prefix that is added to every question before
            translation which should be used for adding tags like <en> <2en>.
            Can be empty or None in which case the prefix is ignored.
          out_dir: Directory where the model output will be written.
          environment_server_address: Address of the environment server.

        Raises:
          ValueError: if model architecture is not known.
        """
        self.hparams = load_hparams(hparams_path, out_dir)
        # NOTE(review): asserts are stripped under `python -O`; raise
        # ValueError instead if these checks must hold in production.
        assert self.hparams.num_buckets == 1, "No bucketing when in server mode."
        assert not self.hparams.server_mode, (
            "server_mode set to True but not "
            "running as server.")

        self.hparams.environment_server = environment_server_address
        # Optional SentencePiece subword segmentation model.
        if self.hparams.subword_option == "spm":
            self.sentpiece = sentencepiece_processor.SentencePieceProcessor()
            self.sentpiece.Load(self.hparams.subword_model.encode("utf-8"))
        self.source_prefix = source_prefix

        # Select the model class from the configured architecture.
        if not self.hparams.attention:
            model_creator = nmt_model.Model
        elif self.hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")

        # Trie used to constrain decoding; optionally preloaded from disk.
        self.trie = trie_decoder_utils.DecoderTrie(
            vocab_path=self.hparams.tgt_vocab_file,
            eos_token=self.hparams.eos,
            subword_option=self.hparams.subword_option,
            subword_model=self.hparams.get("subword_model"),
            optimize_ngrams_len=self.hparams.optimize_ngrams_len)
        if self.hparams.trie_path is not None and tf.gfile.Exists(
                self.hparams.trie_path):
            self.trie.populate_from_text_file(self.hparams.trie_path)

        combined_graph = tf.Graph()
        self.train_model = model_helper.create_train_model(
            model_creator,
            self.hparams,
            graph=combined_graph,
            trie=self.trie,
            use_placeholders=True)

        # Create different inference models for beam search, sampling and
        # greedy decoding.  Each build mutates `self.hparams`
        # (infer_mode / beam_width), so the defaults are saved first and
        # restored only after ALL models exist.  (The previous version
        # restored them mid-sequence, after TRIE_GREEDY, leaving hparams
        # stuck on "trie_beam_search" once construction finished.)
        default_infer_mode = self.hparams.infer_mode
        default_beam_width = self.hparams.beam_width
        self.hparams.use_rl = False

        def _build_infer_model(infer_mode, beam_width, trie=None):
            """Builds one inference model on the shared graph for one decoding mode."""
            self.hparams.infer_mode = infer_mode
            self.hparams.beam_width = beam_width
            # Only pass `trie` when requested, preserving the original calls
            # that omitted the keyword entirely for non-trie modes.
            kwargs = {"trie": trie} if trie is not None else {}
            return model_helper.create_infer_model(
                model_creator, self.hparams, graph=combined_graph, **kwargs)

        request = reformulator_pb2.ReformulatorRequest
        self.infer_models = {}
        self.infer_models[request.GREEDY] = _build_infer_model("greedy", 0)
        self.infer_models[request.SAMPLING] = _build_infer_model("sample", 0)
        self.infer_models[request.BEAM_SEARCH] = _build_infer_model(
            "beam_search", max(1, default_beam_width))
        self.infer_models[request.TRIE_GREEDY] = _build_infer_model(
            "trie_greedy", 0, trie=self.trie)
        self.infer_models[request.TRIE_SAMPLE] = _build_infer_model(
            "trie_sample", 0, trie=self.trie)
        self.infer_models[request.TRIE_BEAM_SEARCH] = _build_infer_model(
            "trie_beam_search", max(1, default_beam_width), trie=self.trie)

        # Restore decoding defaults now that every inference model is built.
        self.hparams.infer_mode = default_infer_mode
        self.hparams.beam_width = default_beam_width

        self.hparams.use_rl = True
        self.sess = tf.Session(graph=combined_graph,
                               config=misc_utils.get_config_proto())

        with combined_graph.as_default():
            # TODO(review): hard-coded local checkpoint path — presumably a
            # debugging override of loading from `out_dir` (see the
            # commented-out call below); restore the `out_dir` variant
            # before deploying.
            p1 = 'C:/Users/aanamika/Documents/QuestionGeneration/active-qa-master/tmp/active-qa/temp/translate.ckpt-1460356'
            print('p1:', p1)
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.tables_initializer())
            _, global_step = model_helper.create_or_load_model(
                self.train_model.model, p1, self.sess, "train")
            # self.train_model.model, out_dir, self.sess, "train")
            self.last_save_step = global_step

        self.summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, "train_log"), self.train_model.graph)
        self.checkpoint_path = os.path.join(out_dir, "translate.ckpt")
        self.trie_save_path = os.path.join(out_dir, "trie")
Code example #3
 def testCreateEmptyTrie(self):
     """A freshly constructed trie is empty and resolves the EOS token index."""
     empty_trie = trie_decoder_utils.DecoderTrie(
         TrieDecoderUtilsTest.vocab_path, '</s>')
     self.assertLen(empty_trie, 0)
     # The EOS index is exposed as a string; '</s>' sits at position 2
     # of the vocabulary file used by this test.
     self.assertEqual(empty_trie.eos_idx, '2')