Code Example #1
    def preprocess(self, examples, volatile=False):
        """Preprocess a batch of EditExamples, converting them into arrays.

        Args:
            examples (list[EditExample])
            volatile (bool): if True, create inference-mode Variables that
                track no gradients (pre-0.4 PyTorch semantics)

        Returns:
            EditorInput
        """
        input_words, target_words = self._batch_editor_examples(examples)
        dynamic_vocabs = self._compute_dynamic_vocabs(input_words,
                                                      self.base_vocab)

        dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
            self.base_source_token_embedder, dynamic_vocabs, self.base_vocab)

        # WARNING: we currently use the same token embedder for both encoder
        # inputs and decoder inputs. In the future, a separate embedder may be
        # used for the decoder.

        encoder_input = self.encoder.preprocess(input_words,
                                                target_words,
                                                dynamic_token_embedder,
                                                volatile=volatile)
        train_decoder_input = HardCopyTrainDecoderInput(
            target_words, dynamic_vocabs, self.base_vocab)
        return EditorInput(encoder_input, train_decoder_input)
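
For context, `_compute_dynamic_vocabs` builds one small vocabulary per example, mapping each input word onto a reserved copy token so the decoder can later emit words outside the base vocabulary. Below is a minimal sketch of that idea using an illustrative `DynamicVocab` class (the token format and the 50-slot cap are made up here, not this codebase's API; only the `copy_token_to_word` attribute mirrors how Code Example #2 uses these objects):

    class DynamicVocab(object):
        """Illustrative per-example vocab: give each unique input word a copy token."""

        def __init__(self, input_words, num_copy_tokens=50):
            self.word_to_copy_token = {}
            self.copy_token_to_word = {}
            for word in input_words:
                if word in self.word_to_copy_token:
                    continue  # word already assigned a copy slot
                if len(self.word_to_copy_token) >= num_copy_tokens:
                    break  # out of copy slots; remaining words would fall back to <unk>
                token = '<copy_{}>'.format(len(self.word_to_copy_token))
                self.word_to_copy_token[word] = token
                self.copy_token_to_word[token] = word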
Code Example #2
    def _edit_batch(self, examples, max_seq_length, beam_size, constrain_vocab):
        # should only run in evaluation mode
        assert not self.training

        input_words, output_words = self._batch_editor_examples(examples)
        base_vocab = self.base_vocab
        dynamic_vocabs = self._compute_dynamic_vocabs(input_words, base_vocab)
        dynamic_token_embedder = DynamicMultiVocabTokenEmbedder(
            self.base_source_token_embedder, dynamic_vocabs, base_vocab)

        encoder_input = self.encoder.preprocess(input_words, output_words, dynamic_token_embedder, volatile=True)
        encoder_output, _ = self.encoder(encoder_input)

        extension_probs_modifiers = []

        if constrain_vocab:
            # restrict decoding to words appearing in the input;
            # the whitelists will contain duplicates, which is fine
            whitelists = [flatten(ex.input_words) for ex in examples]
            # flatten and word_to_forms are helpers defined elsewhere in this codebase
            vocab_constrainer = LexicalWhitelister(whitelists, self.base_vocab, word_to_forms)
            extension_probs_modifiers.append(vocab_constrainer)

        beams, decoder_traces = self.test_decoder_beam.decode(
            examples, encoder_output,
            beam_size=beam_size, max_seq_length=max_seq_length,
            extension_probs_modifiers=extension_probs_modifiers)

        # replace copy tokens in predictions with actual words, modifying beams in-place
        for beam, dyna_vocab in izip(beams, dynamic_vocabs):  # itertools.izip (Python 2)
            copy_to_word = dyna_vocab.copy_token_to_word
            for i, seq in enumerate(beam):
                beam[i] = [copy_to_word.get(w, w) for w in seq]

        return beams, [EditTrace(ex, d_trace.beam_traces[-1], dyna_vocab)
                       for ex, d_trace, dyna_vocab in izip(examples, decoder_traces, dynamic_vocabs)]
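
The final loop is the inverse of the copy mapping: copy tokens in each decoded sequence are translated back to real words via `copy_token_to_word`, while ordinary base-vocab words pass through `dict.get` unchanged. A standalone sketch with fabricated data, reusing the illustrative `DynamicVocab` above (the real code pairs beams with vocabs via Python 2's `itertools.izip`):

    beams = [[['<copy_0>', 'is', 'useful']]]        # one example, one hypothesis
    dynamic_vocabs = [DynamicVocab(['retrieval'])]  # maps '<copy_0>' back to 'retrieval'

    for beam, dyna_vocab in zip(beams, dynamic_vocabs):
        copy_to_word = dyna_vocab.copy_token_to_word
        for i, seq in enumerate(beam):
            beam[i] = [copy_to_word.get(w, w) for w in seq]

    print(beams)  # [[['retrieval', 'is', 'useful']]]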
Code Example #3
    def token_embedder(self, base_vocab, embeds_array, dynamic_vocabs):
        """Wrap a base embedding matrix in a multi-vocab token embedder."""
        word_embeds = SimpleEmbeddings(embeds_array, base_vocab)
        base_embedder = TokenEmbedder(word_embeds)
        return DynamicMultiVocabTokenEmbedder(base_embedder, dynamic_vocabs, base_vocab)
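
One plausible behavior for such a multi-vocab embedder is to look a token up in the base vocabulary first and, for out-of-base-vocab words, fall back to the embedding of the word's copy-token slot. The sketch below illustrates that general technique only; it is an assumption, not the actual `DynamicMultiVocabTokenEmbedder` implementation:

    import numpy as np

    class MultiVocabEmbedderSketch(object):
        """Illustrative only: base-vocab lookup with a copy-slot fallback."""

        def __init__(self, base_word_to_index, base_embeds, copy_embeds):
            self.base_word_to_index = base_word_to_index  # word -> row in base_embeds
            self.base_embeds = base_embeds  # shape (base_vocab_size, dim)
            self.copy_embeds = copy_embeds  # shape (num_copy_tokens, dim)

        def embed(self, word, dynamic_vocab):
            idx = self.base_word_to_index.get(word)
            if idx is not None:
                return self.base_embeds[idx]
            # out-of-base-vocab word: use its copy slot's embedding
            copy_token = dynamic_vocab.word_to_copy_token[word]
            slot = int(copy_token[len('<copy_'):-1])  # parse k from '<copy_k>'
            return self.copy_embeds[slot]

    embedder = MultiVocabEmbedderSketch(
        {'the': 0, 'cat': 1},
        np.zeros((2, 8), dtype=np.float32),
        np.ones((50, 8), dtype=np.float32))
    vec = embedder.embed('retrieval', DynamicVocab(['retrieval']))  # copy-slot vector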