Пример #1
0
    def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Build masked decoder source/target id sequences for one example.

        The token ids are rendered back to an upper-cased string, parsed into
        an ``Annotation`` tree, and the tree is re-serialized with some nodes
        replaced by the MASK token (via ``self.gen_masked_tree``).

        Args:
            tokens: decoder token ids, possibly wrapped with BOS/EOS.
            vocab: vocabulary mapping between token strings and indices.

        Returns:
            Tuple ``(dec_source, dec_target)`` of index lists. If the string
            cannot be parsed, returns an all-MASK source and all-PAD target
            of the same length as ``tokens``.
        """
        cleaned_tokens = self.clean_eos_bos(tokens)
        original_target_string = " ".join(
            [vocab[idx] for idx in cleaned_tokens]).upper()
        try:
            annotation = Annotation(
                original_target_string,
                accept_flat_intents_slots=self.accept_flat_intents_slots,
            )
        except Exception as e:
            # This should never happen other than when testing
            print(e, original_target_string)
            dec_source = [
                vocab.idx[vocab.mask_token] for _ in range(len(tokens))
            ]
            dec_target = [
                vocab.idx[vocab.pad_token] for _ in range(len(tokens))
            ]
            return dec_source, dec_target
        assert len(annotation.root.children) == 1
        mask_tree_str = self.gen_masked_tree(annotation.root.children[0],
                                             vocab.mask_token)

        # We are calling the .split() instead of the tokenize() of tensorizer
        # because the input str contains special MASK token __MASK__
        # It we call tokenize() on this input_str, it may lower __MASK__ or split
        # in unexpected ways causing issues.
        # Hence temporary workaround is that we call split(" ") and lower all tokens
        # other than MASK tokens

        # Handle special tokens in vocab. Use a distinct name instead of
        # rebinding mask_tree_str (originally a str) to a List[str], which
        # confused readers and type checkers.
        masked_tokens: List[str] = list(
            map(
                lambda token: SPECIAL_TOKENS.get(token, token.lower()),
                mask_tree_str.split(" "),
            ))

        # NOTE(review): vocab.idx.get(t) yields None for out-of-vocab tokens;
        # presumably every token produced above is in-vocab — confirm, or fall
        # back to the UNK index explicitly.
        dec_source = [vocab.idx.get(t) for t in masked_tokens]

        dec_target = self._prepare_dec_target(dec_source, cleaned_tokens,
                                              vocab)

        # Optionally wrap with BOS/EOS; should_mask() decides (per position)
        # whether the special token itself is masked in the source.
        if self.use_bos:
            if self.should_mask():
                dec_source.insert(0, vocab.get_mask_index())
                dec_target.insert(0, vocab.get_bos_index())
            else:
                dec_source.insert(0, vocab.get_bos_index())
                dec_target.insert(0, vocab.get_pad_index())

        if self.use_eos:
            if self.should_mask():
                dec_source.append(vocab.get_mask_index())
                dec_target.append(vocab.get_eos_index())
            else:
                dec_source.append(vocab.get_eos_index())
                dec_target.append(vocab.get_pad_index())
        return dec_source, dec_target
Пример #2
0
    def __init__(
        self,
        model: RNNModel,
        output_layer: Seq2SeqOutputLayer,
        sequence_generator: ScriptedSequenceGenerator,
        src_vocab: Vocabulary,
        trg_vocab: Vocabulary,
        dictfeat_vocab: Vocabulary,
    ):
        """Assemble the seq2seq model from its trained sub-components."""
        BaseModel.__init__(self)

        # Core components; encoder/decoder are exposed directly for callers.
        self.model = model
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.output_layer = output_layer
        self.sequence_generator = sequence_generator

        # EOS tells generation when to stop; PAD is needed when shifting
        # source/target sequences prior to decoding.
        self.trg_eos_index = trg_vocab.get_eos_index()
        self.trg_pad_index = trg_vocab.get_pad_index()

        # All three vocabularies are retained so the exported model can
        # accept raw string input.
        self.src_dict = src_vocab
        self.trg_dict = trg_vocab
        self.dictfeat_dict = dictfeat_vocab

        self.force_eval_predictions = False
Пример #3
0
 def __init__(self, vocab: Vocabulary):
     """Wrap *vocab* in a TorchScript-compatible ScriptVocabulary."""
     super().__init__()
     # -1 is used as the fallback index for any special token the source
     # vocabulary does not define.
     special_indices = {
         "pad_idx": vocab.get_pad_index(-1),
         "bos_idx": vocab.get_bos_index(-1),
         "eos_idx": vocab.get_eos_index(-1),
         "unk_idx": vocab.get_unk_index(-1),
     }
     self.vocab = ScriptVocabulary(list(vocab), **special_indices)
Пример #4
0
    def __init__(
        self,
        model: RNNModel,
        output_layer: Seq2SeqOutputLayer,
        src_vocab: Vocabulary,
        trg_vocab: Vocabulary,
        dictfeat_vocab: Vocabulary,
        generator_config=None,
    ):
        """Assemble the seq2seq model, deferring sequence-generator creation."""
        BaseModel.__init__(self)

        # Core components.
        self.model = model
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.output_layer = output_layer

        # Constructing the sequence generator may JIT-compile and quantize the
        # models it wraps, mutating them. Since generation is inference-only,
        # store a builder closure and build the generator after training.
        if generator_config is not None:
            self.sequence_generator_builder = lambda models: create_module(
                generator_config, models, trg_vocab.get_eos_index())
        self.sequence_generator = None

        # Keep predictions disabled until testing (see the deferred-generator
        # note above). Enabling them during EVAL would require a fresh
        # generator over a copy of the model for each epoch.
        self.force_eval_predictions = False

        # EOS tells generation when to stop; PAD is needed when shifting
        # source/target sequences prior to decoding.
        self.trg_eos_index = trg_vocab.get_eos_index()
        self.trg_pad_index = trg_vocab.get_pad_index()

        # All three vocabularies are retained so the exported model can
        # accept raw string input.
        self.src_dict = src_vocab
        self.trg_dict = trg_vocab
        self.dictfeat_dict = dictfeat_vocab

        log_class_usage(__class__)
Пример #5
0
 def __init__(self, tokenizer: Tokenizer, vocab: Vocabulary, max_seq_len: int):
     """Store the tokenizer, a scripted vocabulary with its lookup helper,
     and the maximum sequence length."""
     super().__init__()
     self.tokenizer = tokenizer
     # PAD/UNK use the vocabulary's own indices; BOS/EOS fall back to -1
     # when the vocabulary does not define them.
     script_vocab = ScriptVocabulary(
         list(vocab),
         pad_idx=vocab.get_pad_index(),
         bos_idx=vocab.get_bos_index(-1),
         eos_idx=vocab.get_eos_index(-1),
         unk_idx=vocab.get_unk_index(),
     )
     self.vocab = script_vocab
     self.vocab_lookup = VocabLookup(script_vocab)
     self.max_seq_len = max_seq_len