def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Produce decoder source/target ids by masking a parsed annotation.

        The ids in ``tokens`` (after ``clean_eos_bos``) are mapped back to
        strings through ``vocab``, joined with spaces, upper-cased and parsed
        with ``Annotation``.  The single child of the parsed root is handed to
        ``gen_masked_tree``; the string it returns is split on spaces,
        normalised and mapped back to ids to form the decoder source.  The
        decoder target is derived from that source by
        ``_prepare_dec_target``.

        BOS/EOS entries are then added according to ``use_bos``/``use_eos``.
        For each of them ``should_mask()`` decides whether the entry is masked
        in the source (with the real index in the target) or left visible in
        the source (with the pad index in the target).

        If building the ``Annotation`` raises, the error is printed and a
        fallback pair is returned: a source made only of mask indices and a
        target made only of pad indices, each with one entry per element of
        ``tokens``.

        Returns:
            A ``(dec_source, dec_target)`` tuple of id lists.
        """
        stripped_tokens = self.clean_eos_bos(tokens)
        target_text = " ".join(vocab[i] for i in stripped_tokens).upper()

        try:
            annotation = Annotation(
                target_text,
                accept_flat_intents_slots=self.accept_flat_intents_slots,
            )
        except Exception as e:
            # This should never happen other than when testing
            print(e, target_text)
            fallback_source = [vocab.idx[vocab.mask_token] for _ in tokens]
            fallback_target = [vocab.idx[vocab.pad_token] for _ in tokens]
            return fallback_source, fallback_target

        assert len(annotation.root.children) == 1
        top_node = annotation.root.children[0]
        masked_tree = self.gen_masked_tree(top_node, vocab.mask_token)

        # The masked string is split on single spaces rather than run through
        # the tensorizer's tokenize(): it contains the special MASK token
        # (__MASK__), and if we call tokenize() on it the MASK token may be
        # lower-cased or split in unexpected ways.  As a temporary workaround
        # every piece is lower-cased here unless SPECIAL_TOKENS supplies a
        # replacement for it.
        masked_pieces: List[str] = []
        for piece in masked_tree.split(" "):
            masked_pieces.append(SPECIAL_TOKENS.get(piece, piece.lower()))

        # NOTE(review): ``.get`` is called without a default, so (assuming
        # ``vocab.idx`` is a plain dict) a piece missing from the vocabulary
        # ends up as ``None`` in the source -- confirm this is intended.
        dec_source = list(map(vocab.idx.get, masked_pieces))
        dec_target = self._prepare_dec_target(dec_source, stripped_tokens,
                                              vocab)

        if self.use_bos:
            if self.should_mask():
                bos_source = vocab.get_mask_index()
                bos_target = vocab.get_bos_index()
            else:
                bos_source = vocab.get_bos_index()
                bos_target = vocab.get_pad_index()
            dec_source.insert(0, bos_source)
            dec_target.insert(0, bos_target)

        if self.use_eos:
            if self.should_mask():
                eos_source = vocab.get_mask_index()
                eos_target = vocab.get_eos_index()
            else:
                eos_source = vocab.get_eos_index()
                eos_target = vocab.get_pad_index()
            dec_source.append(eos_source)
            dec_target.append(eos_target)

        return dec_source, dec_target
    def _prepare_dec_target(self, dec_source: List[int],
                            clean_input_tokens: List[int],
                            vocab: Vocabulary) -> List[int]:
        """Build the decoder target that lines up with ``dec_source``.

        Positions whose source entry equals ``vocab.get_mask_index()`` receive
        the entry at the same position in ``clean_input_tokens``; every other
        position receives ``vocab.get_pad_index()``.

        The two input lists are walked in lockstep with ``zip``, so the result
        is only as long as the shorter of the two.

        Returns:
            A new list of target ids.
        """
        target: List[int] = []
        for source_token, original_token in zip(dec_source,
                                                clean_input_tokens):
            if source_token == vocab.get_mask_index():
                # Masked in the source: keep the original token here.
                target.append(original_token)
            else:
                # Not masked in the source: fill with the pad index.
                target.append(vocab.get_pad_index())
        return target
    def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Mask a randomly chosen subset of positions in ``tokens``.

        The number of masked positions is drawn with
        ``self.random.randint(self.minimum_masks, len(tokens))`` and the
        positions themselves with ``self.random.choice`` (without
        replacement).  Chosen positions are replaced by
        ``vocab.get_mask_index()`` in the decoder source; the decoder target
        is derived from that source by ``_prepare_dec_target``.

        NOTE(review): the ``size=``/``replace=`` keywords suggest
        ``self.random`` is a numpy ``RandomState``.  If so, ``randint``'s
        upper bound is exclusive: the draw can never cover every position,
        and it raises when ``len(tokens) <= self.minimum_masks`` -- confirm
        that callers never pass such short sequences.

        Returns:
            A ``(dec_source, dec_target)`` tuple of id lists.
        """
        token_count = len(tokens)
        num_masks = self.random.randint(self.minimum_masks, token_count)
        chosen = self.random.choice(token_count, size=num_masks,
                                    replace=False)
        masked_positions: Set[int] = set(chosen)

        dec_source: List[int] = []
        for position, token in enumerate(tokens):
            if position in masked_positions:
                dec_source.append(vocab.get_mask_index())
            else:
                dec_source.append(token)

        return dec_source, self._prepare_dec_target(dec_source, tokens, vocab)
 def gen_masked_source_target(self, tokens, vocab: Vocabulary):
     """Mask every position of ``tokens``.

     The decoder source holds one ``vocab.get_mask_index()`` entry per element
     of ``tokens``; the decoder target is derived from that source by
     ``_prepare_dec_target``.

     Returns:
         A ``(dec_source, dec_target)`` tuple of id lists.
     """
     fully_masked: List[int] = []
     for _ in tokens:
         fully_masked.append(vocab.get_mask_index())
     return fully_masked, self._prepare_dec_target(fully_masked, tokens, vocab)