def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
    """Build a masked decoder source and its matching decoder target.

    The ids in ``tokens`` (after ``self.clean_eos_bos``) are mapped back to
    token strings, joined with spaces, upper-cased and parsed into an
    ``Annotation``. The single child of the annotation root is rendered by
    ``self.gen_masked_tree`` using ``vocab.mask_token``; that rendering,
    mapped back to ids, is the decoder source, and
    ``self._prepare_dec_target`` derives the decoder target from it.
    Optional BOS/EOS positions are then added according to ``self.use_bos``
    and ``self.use_eos``.

    Args:
        tokens: Token ids of one sequence.
        vocab: Vocabulary used to convert between ids and token strings.

    Returns:
        A ``(dec_source, dec_target)`` tuple of id lists. If the string cannot
        be parsed into an ``Annotation``, both lists have ``len(tokens)``
        entries: every source id is the mask id and every target id is the
        pad id.

    Raises:
        AssertionError: If the parsed annotation root does not have exactly
            one child.
    """
    cleaned_tokens = self.clean_eos_bos(tokens)
    original_target_string = " ".join(
        [vocab[idx] for idx in cleaned_tokens]
    ).upper()
    try:
        annotation = Annotation(
            original_target_string,
            accept_flat_intents_slots=self.accept_flat_intents_slots,
        )
    except Exception as e:
        # This should never happen other than when testing
        print(e, original_target_string)
        dec_source = [vocab.idx[vocab.mask_token] for _ in range(len(tokens))]
        dec_target = [vocab.idx[vocab.pad_token] for _ in range(len(tokens))]
        return dec_source, dec_target

    # Explicit check rather than a bare ``assert`` so the guard survives
    # ``python -O``; AssertionError is kept so callers see the same type.
    root_children = annotation.root.children
    if len(root_children) != 1:
        raise AssertionError(
            f"expected exactly one root child, got {len(root_children)}"
        )
    mask_tree_str = self.gen_masked_tree(root_children[0], vocab.mask_token)

    # We call .split(" ") instead of the tensorizer's tokenize() because the
    # input string contains the special MASK token __MASK__. If we called
    # tokenize() on it, it might lower-case __MASK__ or split it in
    # unexpected ways. As a temporary workaround, split on single spaces,
    # map special tokens through SPECIAL_TOKENS, and lower-case all others.
    mask_tree_tokens: List[str] = [
        SPECIAL_TOKENS.get(token, token.lower())
        for token in mask_tree_str.split(" ")
    ]

    # NOTE(review): dict.get() yields None for any token not present in
    # vocab.idx; such None entries are passed on to _prepare_dec_target and
    # returned as-is. Confirm downstream code never receives them.
    dec_source = [vocab.idx.get(t) for t in mask_tree_tokens]
    dec_target = self._prepare_dec_target(dec_source, cleaned_tokens, vocab)

    # should_mask() is consulted separately for the BOS and the EOS position.
    if self.use_bos:
        if self.should_mask():
            dec_source.insert(0, vocab.get_mask_index())
            dec_target.insert(0, vocab.get_bos_index())
        else:
            dec_source.insert(0, vocab.get_bos_index())
            dec_target.insert(0, vocab.get_pad_index())
    if self.use_eos:
        if self.should_mask():
            dec_source.append(vocab.get_mask_index())
            dec_target.append(vocab.get_eos_index())
        else:
            dec_source.append(vocab.get_eos_index())
            dec_target.append(vocab.get_pad_index())
    return dec_source, dec_target
def __init__(
    self,
    model: RNNModel,
    output_layer: Seq2SeqOutputLayer,
    sequence_generator: ScriptedSequenceGenerator,
    src_vocab: Vocabulary,
    trg_vocab: Vocabulary,
    dictfeat_vocab: Vocabulary,
):
    """Store an RNN model, its output layer and a ready-made sequence generator.

    Args:
        model: RNN model; its ``encoder`` and ``decoder`` attributes are also
            exposed directly as ``self.encoder`` and ``self.decoder``.
        output_layer: Output layer stored as ``self.output_layer``.
        sequence_generator: Already-built scripted generator stored as
            ``self.sequence_generator``.
        src_vocab: Source vocabulary, kept for export.
        trg_vocab: Target vocabulary, kept for export; also supplies the EOS
            and PAD indices.
        dictfeat_vocab: Dictionary-feature vocabulary, kept for export.
    """
    BaseModel.__init__(self)

    # Components, assigned in this fixed order.
    self.model = model
    self.encoder, self.decoder = self.model.encoder, self.model.decoder
    self.output_layer = output_layer
    self.sequence_generator = sequence_generator

    # EOS marks where generation stops; PAD is used when shifting
    # source/target prior to decoding.
    self.trg_eos_index, self.trg_pad_index = (
        trg_vocab.get_eos_index(),
        trg_vocab.get_pad_index(),
    )

    # Vocabularies are retained so the exported model can handle string input.
    self.src_dict, self.trg_dict, self.dictfeat_dict = (
        src_vocab,
        trg_vocab,
        dictfeat_vocab,
    )

    self.force_eval_predictions = False
def __init__(self, vocab: Vocabulary):
    """Convert ``vocab`` into a ``ScriptVocabulary`` stored as ``self.vocab``.

    Every special-token index lookup (pad, bos, eos, unk) is given -1 as its
    fallback argument (presumably the value used when the token is absent —
    confirm against ``Vocabulary``).
    """
    super().__init__()
    entries = list(vocab)
    fallback = -1
    self.vocab = ScriptVocabulary(
        entries,
        pad_idx=vocab.get_pad_index(fallback),
        bos_idx=vocab.get_bos_index(fallback),
        eos_idx=vocab.get_eos_index(fallback),
        unk_idx=vocab.get_unk_index(fallback),
    )
def __init__(
    self,
    model: RNNModel,
    output_layer: Seq2SeqOutputLayer,
    src_vocab: Vocabulary,
    trg_vocab: Vocabulary,
    dictfeat_vocab: Vocabulary,
    generator_config=None,
):
    """Store an RNN model and its output layer; defer sequence-generator creation.

    Args:
        model: RNN model; its ``encoder`` and ``decoder`` attributes are also
            exposed directly as ``self.encoder`` and ``self.decoder``.
        output_layer: Output layer stored as ``self.output_layer``.
        src_vocab: Source vocabulary, kept for export.
        trg_vocab: Target vocabulary, kept for export; also supplies the EOS
            and PAD indices.
        dictfeat_vocab: Dictionary-feature vocabulary, kept for export.
        generator_config: Optional config. When given,
            ``self.sequence_generator_builder`` is set to a callable that
            takes ``models`` and returns
            ``create_module(generator_config, models, <target EOS index>)``.
            When ``None``, that attribute is not set at all.
    """
    BaseModel.__init__(self)

    # Components, assigned in this fixed order.
    self.model = model
    self.encoder, self.decoder = self.model.encoder, self.model.decoder
    self.output_layer = output_layer

    # Sequence generation is meant for inference only, on the trained
    # model(s). Building the generator may apply TorchScript JIT compilation
    # and quantization, which modify the input model, so only a builder is
    # stored here and the generator itself is created after training.
    if generator_config is not None:

        def _build_sequence_generator(models):
            return create_module(
                generator_config, models, trg_vocab.get_eos_index()
            )

        self.sequence_generator_builder = _build_sequence_generator
    self.sequence_generator = None

    # Predictions stay disabled until testing, since no generator exists yet.
    # If needed during the EVAL stage, a fresh generator built from a copy of
    # the model should be used for each epoch.
    self.force_eval_predictions = False

    # EOS marks where generation stops; PAD is used when shifting
    # source/target prior to decoding.
    self.trg_eos_index, self.trg_pad_index = (
        trg_vocab.get_eos_index(),
        trg_vocab.get_pad_index(),
    )

    # Vocabularies are retained so the exported model can handle string input.
    self.src_dict, self.trg_dict, self.dictfeat_dict = (
        src_vocab,
        trg_vocab,
        dictfeat_vocab,
    )

    log_class_usage(__class__)
def __init__(self, tokenizer: Tokenizer, vocab: Vocabulary, max_seq_len: int):
    """Hold a tokenizer, a scripted copy of ``vocab`` and a max sequence length.

    Args:
        tokenizer: Tokenizer stored as ``self.tokenizer``.
        vocab: Vocabulary converted to a ``ScriptVocabulary``. Pad and unk
            indices are looked up with no fallback argument; bos and eos
            lookups use -1 as their fallback argument.
        max_seq_len: Stored as ``self.max_seq_len``.
    """
    super().__init__()
    self.tokenizer = tokenizer

    script_vocab = ScriptVocabulary(
        list(vocab),
        pad_idx=vocab.get_pad_index(),
        bos_idx=vocab.get_bos_index(-1),
        eos_idx=vocab.get_eos_index(-1),
        unk_idx=vocab.get_unk_index(),
    )
    self.vocab = script_vocab
    # The lookup helper wraps the same ScriptVocabulary object stored above.
    self.vocab_lookup = VocabLookup(script_vocab)

    self.max_seq_len = max_seq_len