def preprocess(self, examples, volatile=False):
    """Convert a batch of EditExamples into the array form the model consumes.

    Args:
        examples (list[EditExample]): batch of edit examples
        volatile (bool): forwarded to encoder preprocessing
            (legacy PyTorch volatile flag — presumably True means inference mode)

    Returns:
        EditorInput: paired encoder input and train-decoder input
    """
    src_words, tgt_words = self._batch_editor_examples(examples)
    dyna_vocabs = self._compute_dynamic_vocabs(src_words, self.base_vocab)
    embedder = DynamicMultiVocabTokenEmbedder(
        self.base_source_token_embedder, dyna_vocabs, self.base_vocab)

    # WARNING: the encoder and the decoder currently share this same token
    # embedder; a separate decoder-side embedder may be introduced later.
    enc_input = self.encoder.preprocess(
        src_words, tgt_words, embedder, volatile=volatile)
    dec_input = HardCopyTrainDecoderInput(
        tgt_words, dyna_vocabs, self.base_vocab)
    return EditorInput(enc_input, dec_input)
def _edit_batch(self, examples, max_seq_length, beam_size, constrain_vocab):
    """Beam-decode edited outputs for a batch of examples.

    Args:
        examples (list[EditExample]): batch to decode
        max_seq_length (int): cap on decoded sequence length
        beam_size (int): beam width for the test-time decoder
        constrain_vocab (bool): if True, restrict extensions to words
            derivable from each example's input words

    Returns:
        tuple: (beams, edit_traces) aligned with `examples`
    """
    assert not self.training  # decoding is only valid in evaluation mode

    src_words, tgt_words = self._batch_editor_examples(examples)
    vocab = self.base_vocab
    dyna_vocabs = self._compute_dynamic_vocabs(src_words, vocab)
    embedder = DynamicMultiVocabTokenEmbedder(
        self.base_source_token_embedder, dyna_vocabs, vocab)
    enc_input = self.encoder.preprocess(
        src_words, tgt_words, embedder, volatile=True)
    enc_output, _ = self.encoder(enc_input)

    modifiers = []
    if constrain_vocab:
        # whitelists may contain duplicate words; that's fine
        whitelists = [flatten(ex.input_words) for ex in examples]
        modifiers.append(
            LexicalWhitelister(whitelists, self.base_vocab, word_to_forms))

    beams, decoder_traces = self.test_decoder_beam.decode(
        examples, enc_output,
        beam_size=beam_size,
        max_seq_length=max_seq_length,
        extension_probs_modifiers=modifiers)

    # Swap copy tokens for the actual words they stand for,
    # mutating each beam list in-place.
    for beam, dyna_vocab in izip(beams, dyna_vocabs):
        to_word = dyna_vocab.copy_token_to_word
        for idx, seq in enumerate(beam):
            beam[idx] = [to_word.get(tok, tok) for tok in seq]

    edit_traces = [EditTrace(ex, d_trace.beam_traces[-1], dv)
                   for ex, d_trace, dv
                   in izip(examples, decoder_traces, dyna_vocabs)]
    return beams, edit_traces
def token_embedder(self, base_vocab, embeds_array, dynamic_vocabs):
    """Build a DynamicMultiVocabTokenEmbedder from a raw embedding matrix.

    Args:
        base_vocab: base vocabulary the embedding rows are indexed by
        embeds_array: embedding matrix aligned with base_vocab
        dynamic_vocabs: per-example dynamic vocabularies

    Returns:
        DynamicMultiVocabTokenEmbedder
    """
    return DynamicMultiVocabTokenEmbedder(
        TokenEmbedder(SimpleEmbeddings(embeds_array, base_vocab)),
        dynamic_vocabs, base_vocab)