def test_special_tokens_unaffacted_by_save_load(self):
    """Special-token mapping must survive a save/load round trip.

    Saves the tokenizer to a temporary directory, reloads it with
    ``MBart50Tokenizer.from_pretrained``, and checks that the fairseq
    token-to-id mapping is unchanged.
    """
    original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
    # TemporaryDirectory cleans up after itself; the original mkdtemp()
    # leaked the directory on every test run.
    with tempfile.TemporaryDirectory() as tmpdirname:
        self.tokenizer.save_pretrained(tmpdirname)
        new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
    self.assertDictEqual(new_tok.fairseq_tokens_to_ids,
                         original_special_tokens)
    def test_tokenization_mbart50(self):
        """Compare the Rust MBart50 tokenizer against the Python baseline.

        Encodes every example with both tokenizers and raises when the
        produced ids differ — unless the difference is a pure reordering
        of the same multiset of ids — or the special-tokens masks disagree.
        """
        # Given
        self.base_tokenizer = MBart50Tokenizer.from_pretrained(
            'facebook/mbart-large-50-many-to-many-mmt',
            do_lower_case=False,
            cache_dir=self.test_dir)
        self.rust_tokenizer = PyMBart50Tokenizer(get_from_cache(
            'https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model'
        ),
                                                 do_lower_case=False)
        self.base_tokenizer.src_lang = "fr_XX"
        output_baseline = []
        for example in self.examples:
            output_baseline.append(
                self.base_tokenizer.encode_plus(
                    example.text_a,
                    add_special_tokens=True,
                    return_overflowing_tokens=True,
                    return_special_tokens_mask=True,
                    max_length=128))

        # When
        # Note: the original sentence piece tokenizer strips trailing spaces
        output_rust = self.rust_tokenizer.encode_list(
            ["fr_XX " + example.text_a.strip() for example in self.examples],
            max_len=256,
            truncation_strategy='longest_first',
            stride=0)

        # Then
        for idx, (rust,
                  baseline) in enumerate(zip(output_rust, output_baseline)):
            baseline_ids = baseline['input_ids']
            # The original duplicated the same raise block in two branches;
            # the merged condition is equivalent: a mismatch is fatal unless
            # it is a reordering (same length, equal Counters).
            if rust.token_ids != baseline_ids and (
                    len(rust.token_ids) != len(baseline_ids)
                    or Counter(rust.token_ids) != Counter(baseline_ids)):
                raise AssertionError(
                    f'Difference in tokenization for {self.rust_tokenizer.__class__}: \n '
                    f'Sentence a: {self.examples[idx].text_a} \n'
                    f'Sentence b: {self.examples[idx].text_b} \n'
                    f'Token mismatch: {self.get_token_diff(rust.token_ids, baseline_ids)} \n'
                    f'Rust: {rust.token_ids} \n'
                    f'Python {baseline_ids}')
            assert (
                rust.special_tokens_mask == baseline['special_tokens_mask'])
# Example #3 (score: 0)
 def __init__(self, model_path: str, device: str = 'cuda') -> None:
     """Load the mBART model and tokenizer from *model_path* onto *device*."""
     self.device = device
     self.tokenizer = MBart50Tokenizer.from_pretrained(model_path)
     model = MBartForConditionalGeneration.from_pretrained(model_path)
     self.model = model.to(device)
 def setUpClass(cls):
     """Build the shared en_XX -> ro_RO tokenizer fixture for the class."""
     cls.tokenizer = MBart50Tokenizer.from_pretrained(
         cls.checkpoint_name,
         src_lang="en_XX",
         tgt_lang="ro_RO",
     )
     cls.pad_token_id = 1
     return cls
# Example #5 (score: 0)
def download_model():
    """Fetch the mBART-50 many-to-many checkpoint and its tokenizer."""
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    return (MBartForConditionalGeneration.from_pretrained(checkpoint),
            MBart50Tokenizer.from_pretrained(checkpoint))
# Example #6 (score: 0)
    def get_tokenizer(self, save_dir, config, src_lang: str, tgt_lang: str):
        """Instantiate the tokenizer matching *config*, set its language
        pair, and derive the input prefix prepended to every source text.

        Args:
            save_dir: local directory holding a saved tokenizer, or None
                to load from ``self._pretrained_name``.
            config: model config object; its concrete class selects the
                tokenizer implementation and the prefix scheme.
            src_lang: source-language code (interpolated into T5 task names).
            tgt_lang: target-language code (interpolated into Marian/T5 prefixes).

        Side effects: assigns ``self._tokenizer`` and ``self.input_prefix``,
        and attaches an ``is_piece_fn`` predicate to the tokenizer.
        """
        tokenizer_args = {
            'do_lower_case': False,
            'do_basic_tokenize': False,
            'cache_dir': self._cache,
            'use_fast': self._use_fast(),
            'src_lang': src_lang,
            'tgt_lang': tgt_lang
        }
        if save_dir is not None:
            tokenizer_args.update({
                'pretrained_model_name_or_path': save_dir,
                'config': config
            })
        else:
            tokenizer_args.update(
                {'pretrained_model_name_or_path': self._pretrained_name})

        model_is_marian = isinstance(config, MarianConfig)
        model_is_mbart = isinstance(config, MBartConfig)
        model_is_m2m100 = isinstance(config, M2M100Config)
        model_is_t5 = isinstance(config, T5Config)

        # hack until huggingface provides mbart50 config: mBART-50 shares
        # MBartConfig, so it is recognized by checkpoint name instead.
        if model_is_mbart and 'mbart-50' in config.name_or_path:
            self._tokenizer = MBart50Tokenizer.from_pretrained(
                **tokenizer_args)
        elif model_is_m2m100:
            self._tokenizer = M2M100Tokenizer.from_pretrained(**tokenizer_args)
        else:
            self._tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)

        # some tokenizers like Mbart do not set src_lang and tgt_lan when initialized; take care of it here
        self._tokenizer.src_lang = src_lang
        self._tokenizer.tgt_lang = tgt_lang

        # define input prefix to add before every input text
        input_prefix = ''
        if model_is_marian and tgt_lang:
            # Marian multilingual models select the target language via a
            # '>>lang<<' tag in the input text.
            input_prefix = f'>>{tgt_lang}<< '
        elif model_is_t5:
            t5_task = f'translation_{src_lang}_to_{tgt_lang}'
            # TODO add support for summarization
            # t5_task = 'summarization'
            # NOTE(review): assumes the config declares this task in
            # task_specific_params — KeyError otherwise; confirm upstream.
            input_prefix = config.task_specific_params[t5_task]['prefix']

        self.input_prefix = input_prefix

        # We only include the base tokenizers since `isinstance` checks for inheritance
        if isinstance(self._tokenizer, (BertTokenizer, BertTokenizerFast)):
            # WordPiece marks continuation pieces with a '##' prefix.
            self._tokenizer.is_piece_fn = lambda wp: wp.startswith('##')
        elif isinstance(self._tokenizer,
                        (XLMRobertaTokenizer, XLMRobertaTokenizerFast,
                         MarianTokenizer, M2M100Tokenizer)):
            # SentencePiece marks word *starts* with the underline glyph, so
            # a continuation piece is one that does NOT start with it.
            self._tokenizer.is_piece_fn = lambda wp: not wp.startswith(
                SPIECE_UNDERLINE)
        elif isinstance(self._tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)):
            # Byte-level BPE marks word starts with 'Ġ' (encoded space).
            self._tokenizer.is_piece_fn = lambda wp: not wp.startswith('Ġ')

        # make sure we assigned is_piece_fn
        # NOTE(review): if no branch above matched, this raises
        # AttributeError (not AssertionError) — presumably intentional as a
        # loud failure for unsupported tokenizer families; confirm.
        assert self._tokenizer.is_piece_fn