    def test_full_tokenizer(self):
        tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            # fmt: off
            [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."],
            # fmt: on
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [
                value + tokenizer.fairseq_offset
                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
            ],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            # fmt: off
            [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "<unk>", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "<unk>", "."],
            # fmt: on
        )
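
        # A minimal sketch added for illustration (not part of the original test): the
        # ids above are shifted because MBart50Tokenizer reserves the lowest ids for
        # fairseq special tokens, so a plain SentencePiece id maps to
        # `sp_id + tokenizer.fairseq_offset`.
        sp_id = tokenizer.sp_model.piece_to_id("▁This")
        self.assertEqual(
            tokenizer.convert_tokens_to_ids(["▁This"])[0],
            sp_id + tokenizer.fairseq_offset,
        )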
    def test_special_tokens_unaffected_by_save_load(self):
        tmpdirname = tempfile.mkdtemp()
        original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
        self.tokenizer.save_pretrained(tmpdirname)
        new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
        self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
    def setUp(self):
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = MBart50Tokenizer(SAMPLE_VOCAB,
                                     src_lang="en_XX",
                                     tgt_lang="ro_RO",
                                     keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)
    def test_tokenization_mbart50(self):
        # Given
        self.base_tokenizer = MBart50Tokenizer.from_pretrained(
            'facebook/mbart-large-50-many-to-many-mmt',
            do_lower_case=False,
            cache_dir=self.test_dir)
        self.rust_tokenizer = PyMBart50Tokenizer(
            get_from_cache(
                'https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model'
            ),
            do_lower_case=False)
        self.base_tokenizer.src_lang = "fr_XX"
        output_baseline = []
        for example in self.examples:
            output_baseline.append(
                self.base_tokenizer.encode_plus(
                    example.text_a,
                    add_special_tokens=True,
                    return_overflowing_tokens=True,
                    return_special_tokens_mask=True,
                    max_length=128))

        # When
        # Note: the original SentencePiece tokenizer strips trailing spaces
        output_rust = self.rust_tokenizer.encode_list(
            ["fr_XX " + example.text_a.strip() for example in self.examples],
            max_len=256,
            truncation_strategy='longest_first',
            stride=0)

        # Then
        for idx, (rust, baseline) in enumerate(zip(output_rust, output_baseline)):
            if rust.token_ids != baseline['input_ids']:
                if len(rust.token_ids) == len(baseline['input_ids']):
                    if Counter(rust.token_ids) != Counter(
                            baseline['input_ids']):
                        raise AssertionError(
                            f'Difference in tokenization for {self.rust_tokenizer.__class__}: \n '
                            f'Sentence a: {self.examples[idx].text_a} \n'
                            f'Sentence b: {self.examples[idx].text_b} \n'
                            f'Token mismatch: {self.get_token_diff(rust.token_ids, baseline["input_ids"])} \n'
                            f'Rust: {rust.token_ids} \n'
                            f'Python {baseline["input_ids"]}')
                else:
                    raise AssertionError(
                        f'Difference in tokenization for {self.rust_tokenizer.__class__}: \n '
                        f'Sentence a: {self.examples[idx].text_a} \n'
                        f'Sentence b: {self.examples[idx].text_b} \n'
                        f'Token mismatch: {self.get_token_diff(rust.token_ids, baseline["input_ids"])} \n'
                        f'Rust: {rust.token_ids} \n'
                        f'Python {baseline["input_ids"]}')
            assert rust.special_tokens_mask == baseline['special_tokens_mask']
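
    # `get_token_diff` is referenced above but not defined in this snippet. Below is a
    # hypothetical sketch of such a helper (not the original implementation): it trims
    # the common prefix and suffix of both id sequences and returns the differing
    # middle slices.
    def get_token_diff(self, rust_ids, python_ids):
        start = 0
        while (start < min(len(rust_ids), len(python_ids))
               and rust_ids[start] == python_ids[start]):
            start += 1
        end_rust, end_python = len(rust_ids), len(python_ids)
        while (end_rust > start and end_python > start
               and rust_ids[end_rust - 1] == python_ids[end_python - 1]):
            end_rust -= 1
            end_python -= 1
        return rust_ids[start:end_rust], python_ids[start:end_python]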
Example #5
    def __init__(self, model_path: str, device: str = 'cuda') -> None:
        self.device = device
        self.model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)
        self.tokenizer = MBart50Tokenizer.from_pretrained(model_path)
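
    # Hypothetical usage sketch (not from the original source): translating with the
    # model and tokenizer loaded in __init__. `src_lang` and `tgt_lang` are assumed to
    # be valid mBART-50 language codes such as "en_XX" or "ro_RO"; the target language
    # is selected via forced_bos_token_id.
    def translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
        self.tokenizer.src_lang = src_lang
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        generated = self.model.generate(
            **inputs,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang])
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]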
    @classmethod
    def setUpClass(cls):
        cls.tokenizer: MBart50Tokenizer = MBart50Tokenizer.from_pretrained(
            cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
        )
        cls.pad_token_id = 1
        return cls
Example #7
def download_model():
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50Tokenizer.from_pretrained(model_name)
    return model, tokenizer
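
# Hypothetical follow-up (not part of the original snippet): the downloaded model and
# tokenizer can be cached locally with save_pretrained and reloaded offline later. The
# directory name below is only an example.
model, tokenizer = download_model()
model.save_pretrained("./mbart-large-50-many-to-many-mmt")
tokenizer.save_pretrained("./mbart-large-50-many-to-many-mmt")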
Example #8
    def get_tokenizer(self, save_dir, config, src_lang, tgt_lang):
        tokenizer_args = {
            'do_lower_case': False,
            'do_basic_tokenize': False,
            'cache_dir': self._cache,
            'use_fast': self._use_fast(),
            'src_lang': src_lang,
            'tgt_lang': tgt_lang
        }
        if save_dir is not None:
            tokenizer_args.update({
                'pretrained_model_name_or_path': save_dir,
                'config': config
            })
        else:
            tokenizer_args.update(
                {'pretrained_model_name_or_path': self._pretrained_name})

        model_is_marian = isinstance(config, MarianConfig)
        model_is_mbart = isinstance(config, MBartConfig)
        model_is_m2m100 = isinstance(config, M2M100Config)
        model_is_t5 = isinstance(config, T5Config)

        # hack until huggingface provides mbart50 config
        if model_is_mbart and 'mbart-50' in config.name_or_path:
            self._tokenizer = MBart50Tokenizer.from_pretrained(
                **tokenizer_args)
        elif model_is_m2m100:
            self._tokenizer = M2M100Tokenizer.from_pretrained(**tokenizer_args)
        else:
            self._tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)

        # some tokenizers like MBart do not set src_lang and tgt_lang when initialized; take care of it here
        self._tokenizer.src_lang = src_lang
        self._tokenizer.tgt_lang = tgt_lang

        # define input prefix to add before every input text
        input_prefix = ''
        if model_is_marian and tgt_lang:
            input_prefix = f'>>{tgt_lang}<< '
        elif model_is_t5:
            t5_task = f'translation_{src_lang}_to_{tgt_lang}'
            # TODO add support for summarization
            # t5_task = 'summarization'
            input_prefix = config.task_specific_params[t5_task]['prefix']

        self.input_prefix = input_prefix

        # We only include the base tokenizers since `isinstance` checks for inheritance
        if isinstance(self._tokenizer, (BertTokenizer, BertTokenizerFast)):
            self._tokenizer.is_piece_fn = lambda wp: wp.startswith('##')
        elif isinstance(self._tokenizer,
                        (XLMRobertaTokenizer, XLMRobertaTokenizerFast,
                         MarianTokenizer, M2M100Tokenizer)):
            self._tokenizer.is_piece_fn = lambda wp: not wp.startswith(
                SPIECE_UNDERLINE)
        elif isinstance(self._tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)):
            self._tokenizer.is_piece_fn = lambda wp: not wp.startswith('Ġ')

        # make sure we assigned is_piece_fn
        assert self._tokenizer.is_piece_fn
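
    # Hypothetical illustration (not part of the original class): `input_prefix` and
    # `is_piece_fn` are meant to be used downstream roughly like this. `is_piece_fn`
    # returns True for a continuation piece, so consecutive pieces can be merged back
    # into whole words after tokenization.
    def tokenize_into_words(self, text):
        pieces = self._tokenizer.tokenize(self.input_prefix + text)
        words = []
        for piece in pieces:
            if words and self._tokenizer.is_piece_fn(piece):
                words[-1] += piece  # continuation piece: append to the current word
            else:
                words.append(piece)
        return words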
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    config_yaml_path,
    encoder_config_path,
    decoder_config_path,
    add_adapter,
    adapter_kernel_size,
    adapter_stride,
    decoder_start_token_id,
    encoder_output_dim,
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    # load configs
    encoder_config = Wav2Vec2Config.from_pretrained(
        encoder_config_path,
        add_adapter=True,
        adapter_stride=adapter_stride,
        adapter_kernel_size=adapter_kernel_size,
        use_auth_token=True,
        output_hidden_size=encoder_output_dim,
    )
    decoder_config = MBartConfig.from_pretrained(decoder_config_path)

    # load model
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path],
        arg_overrides={
            "config_yaml": config_yaml_path,
            "data": "/".join(dict_path.split("/")[:-1]),
            "w2v_path": checkpoint_path,
            "load_pretrained_decoder_from": None,
        },
    )
    model = model[0].eval()

    # load feature extractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        encoder_config_path, use_auth_token=True)

    # set weights for wav2vec2 encoder
    hf_encoder = Wav2Vec2Model(encoder_config)

    recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    # load decoder weights
    hf_decoder = MBartForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(
        model.decoder.state_dict(), strict=False)
    logger.warning(
        f"The following keys are missing when loading the decoder weights: {missing_keys}"
    )
    logger.warning(
        f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}"
    )

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder,
                                           decoder=hf_decoder)
    hf_wav2vec.config.tie_word_embeddings = False

    tokenizer = MBart50Tokenizer(dict_path)
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "mbart50"
    config["feature_extractor_type"] = "wav2vec2"

    config["decoder_start_token_id"] = tokenizer.eos_token_id
    config["forced_bos_token_id"] = 250004
    config["forced_eos_token_id"] = tokenizer.eos_token_id

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
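
# Hypothetical CLI entry point (not shown in the snippet above). The flag names simply
# mirror the function parameters and the defaults are illustrative; they are not
# necessarily those of the original conversion script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--pytorch_dump_folder_path", type=str, required=True)
    parser.add_argument("--dict_path", type=str, required=True)
    parser.add_argument("--config_yaml_path", type=str, required=True)
    parser.add_argument("--encoder_config_path", type=str, required=True)
    parser.add_argument("--decoder_config_path", type=str, required=True)
    parser.add_argument("--add_adapter", action="store_true")
    parser.add_argument("--adapter_kernel_size", type=int, default=3)
    parser.add_argument("--adapter_stride", type=int, default=2)
    parser.add_argument("--decoder_start_token_id", type=int, default=None)
    parser.add_argument("--encoder_output_dim", type=int, default=1024)
    args = parser.parse_args()

    convert_wav2vec2_checkpoint(
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.dict_path,
        args.config_yaml_path,
        args.encoder_config_path,
        args.decoder_config_path,
        args.add_adapter,
        args.adapter_kernel_size,
        args.adapter_stride,
        args.decoder_start_token_id,
        args.encoder_output_dim,
    )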