def test_full_tokenizer(self):
    """Round-trip check: tokenize -> ids -> tokens against known SentencePiece output.

    Uses the local SAMPLE_VOCAB fixture; ids are shifted by ``fairseq_offset`` and
    out-of-vocab pieces come back as ``<unk>`` on the reverse conversion.
    """
    tokenizer = MBart50Tokenizer(SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True)

    # Simple sentence: pieces and their offset-adjusted ids.
    simple_tokens = tokenizer.tokenize("This is a test")
    self.assertListEqual(simple_tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
    expected_simple_ids = [raw_id + tokenizer.fairseq_offset for raw_id in [285, 46, 10, 170, 382]]
    self.assertListEqual(tokenizer.convert_tokens_to_ids(simple_tokens), expected_simple_ids)

    # Harder sentence: digits and an accented character exercise the vocab edges.
    hard_tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
    # fmt: off
    expected_hard_tokens = [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."]
    # fmt: on
    self.assertListEqual(hard_tokens, expected_hard_tokens)

    hard_ids = tokenizer.convert_tokens_to_ids(hard_tokens)
    expected_hard_ids = [
        raw_id + tokenizer.fairseq_offset
        for raw_id in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
    ]
    self.assertListEqual(hard_ids, expected_hard_ids)

    # Converting back maps unknown ids ("9", "é") to the <unk> piece.
    round_tripped = tokenizer.convert_ids_to_tokens(hard_ids)
    # fmt: off
    expected_round_trip = [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "<unk>", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "<unk>", "."]
    # fmt: on
    self.assertListEqual(round_tripped, expected_round_trip)
def test_special_tokens_unaffacted_by_save_load(self):
    """Saving and reloading the tokenizer must preserve the fairseq special-token id map."""
    original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
    # Fix: the original used tempfile.mkdtemp() and never removed the directory,
    # leaking a temp dir per test run. TemporaryDirectory cleans up automatically.
    with tempfile.TemporaryDirectory() as tmpdirname:
        self.tokenizer.save_pretrained(tmpdirname)
        new_tok = MBart50Tokenizer.from_pretrained(tmpdirname)
    self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
def setUp(self):
    """Build a tokenizer from the SentencePiece fixture and persist it for the tests."""
    super().setUp()
    # We have a SentencePiece fixture for testing.
    fixture_tokenizer = MBart50Tokenizer(
        SAMPLE_VOCAB, src_lang="en_XX", tgt_lang="ro_RO", keep_accents=True
    )
    fixture_tokenizer.save_pretrained(self.tmpdirname)
def test_tokenization_mbart50(self):
    """Compare the Python MBart50 tokenizer against the Rust implementation.

    Token-id sequences are allowed to be permutations of each other when they
    have equal length (SentencePiece tie-breaking); any other difference fails.
    """
    # Given
    self.base_tokenizer = MBart50Tokenizer.from_pretrained(
        'facebook/mbart-large-50-many-to-many-mmt', do_lower_case=False, cache_dir=self.test_dir)
    self.rust_tokenizer = PyMBart50Tokenizer(get_from_cache(
        'https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model'
    ), do_lower_case=False)
    self.base_tokenizer.src_lang = "fr_XX"
    output_baseline = []
    for example in self.examples:
        output_baseline.append(
            self.base_tokenizer.encode_plus(
                example.text_a,
                add_special_tokens=True,
                return_overflowing_tokens=True,
                return_special_tokens_mask=True,
                max_length=128))

    # When
    # Note: the original sentence piece tokenizer strips trailing spaces
    output_rust = self.rust_tokenizer.encode_list(
        ["fr_XX " + example.text_a.strip() for example in self.examples],
        max_len=256,
        truncation_strategy='longest_first',
        stride=0)

    # Then
    for idx, (rust, baseline) in enumerate(zip(output_rust, output_baseline)):
        # Fix: the original raised the exact same AssertionError from both the
        # equal-length and unequal-length branches; collapsed into one condition.
        if rust.token_ids != baseline['input_ids'] and (
                len(rust.token_ids) != len(baseline['input_ids'])
                or Counter(rust.token_ids) != Counter(baseline['input_ids'])):
            raise AssertionError(
                f'Difference in tokenization for {self.rust_tokenizer.__class__}: \n '
                f'Sentence a: {self.examples[idx].text_a} \n'
                f'Sentence b: {self.examples[idx].text_b} \n'
                f'Token mismatch: {self.get_token_diff(rust.token_ids, baseline["input_ids"])} \n'
                f'Rust: {rust.token_ids} \n'
                f'Python {baseline["input_ids"]}')
        assert (
            rust.special_tokens_mask == baseline['special_tokens_mask'])
def __init__(self, model_path: str, device: str = 'cuda') -> None:
    """Load an MBart model and its tokenizer from *model_path* onto *device*."""
    self.device = device
    # Tokenizer is CPU-only; only the model is moved to the target device.
    self.tokenizer = MBart50Tokenizer.from_pretrained(model_path)
    self.model = MBartForConditionalGeneration.from_pretrained(model_path).to(device)
def setUpClass(cls):
    """Create one shared en_XX -> ro_RO tokenizer fixture for the whole class."""
    shared_tokenizer = MBart50Tokenizer.from_pretrained(
        cls.checkpoint_name, src_lang="en_XX", tgt_lang="ro_RO"
    )
    cls.tokenizer: MBart50Tokenizer = shared_tokenizer
    # NOTE(review): pad id is hard-coded rather than read from the tokenizer — confirm
    # it matches shared_tokenizer.pad_token_id for this checkpoint.
    cls.pad_token_id = 1
    return cls
def download_model():
    """Fetch the mBART-50 many-to-many checkpoint and its tokenizer from the Hub."""
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    tokenizer = MBart50Tokenizer.from_pretrained(checkpoint)
    model = MBartForConditionalGeneration.from_pretrained(checkpoint)
    return model, tokenizer
def get_tokenizer(self, save_dir, config, src_lang, tgt_lang):
    """Instantiate and configure ``self._tokenizer`` for the given model config.

    Selects the tokenizer class from *config* (MBart-50 / M2M100 special-cased,
    AutoTokenizer otherwise), forces src/tgt languages, derives ``self.input_prefix``
    (Marian language tag or T5 task prefix), and attaches an ``is_piece_fn``
    predicate that tells whether a wordpiece continues the previous token.

    Args:
        save_dir: local directory of a saved tokenizer, or ``None`` to load
            ``self._pretrained_name`` from the hub/cache.
        config: model config object used for type dispatch.
        src_lang / tgt_lang: language codes applied to the tokenizer.
    """
    tokenizer_args = {
        'do_lower_case': False,
        'do_basic_tokenize': False,
        'cache_dir': self._cache,
        'use_fast': self._use_fast(),
        'src_lang': src_lang,
        'tgt_lang': tgt_lang
    }
    # Prefer a locally saved tokenizer when available; otherwise fall back to
    # the pretrained name configured on this instance.
    if save_dir is not None:
        tokenizer_args.update({
            'pretrained_model_name_or_path': save_dir,
            'config': config
        })
    else:
        tokenizer_args.update(
            {'pretrained_model_name_or_path': self._pretrained_name})
    model_is_marian = isinstance(config, MarianConfig)
    model_is_mbart = isinstance(config, MBartConfig)
    model_is_m2m100 = isinstance(config, M2M100Config)
    model_is_t5 = isinstance(config, T5Config)
    # hack until huggingface provides mbart50 config
    if model_is_mbart and 'mbart-50' in config.name_or_path:
        self._tokenizer = MBart50Tokenizer.from_pretrained(
            **tokenizer_args)
    elif model_is_m2m100:
        self._tokenizer = M2M100Tokenizer.from_pretrained(**tokenizer_args)
    else:
        self._tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
    # some tokenizers like Mbart do not set src_lang and tgt_lan when initialized; take care of it here
    self._tokenizer.src_lang = src_lang
    self._tokenizer.tgt_lang = tgt_lang
    # define input prefix to add before every input text
    input_prefix = ''
    if model_is_marian and tgt_lang:
        # Marian multilingual models select the target language via a >>lang<< tag.
        input_prefix = f'>>{tgt_lang}<< '
    elif model_is_t5:
        t5_task = f'translation_{src_lang}_to_{tgt_lang}'
        # TODO add support for summarization
        # t5_task = 'summarization'
        input_prefix = config.task_specific_params[t5_task]['prefix']
    self.input_prefix = input_prefix
    # We only include the base tokenizers since `isinstance` checks for inheritance
    if isinstance(self._tokenizer, (BertTokenizer, BertTokenizerFast)):
        # WordPiece: continuation pieces are marked with a leading '##'.
        self._tokenizer.is_piece_fn = lambda wp: wp.startswith('##')
    elif isinstance(self._tokenizer,
                    (XLMRobertaTokenizer, XLMRobertaTokenizerFast,
                     MarianTokenizer, M2M100Tokenizer)):
        # SentencePiece: word-initial pieces carry the underline; others continue.
        self._tokenizer.is_piece_fn = lambda wp: not wp.startswith(
            SPIECE_UNDERLINE)
    elif isinstance(self._tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)):
        # Byte-level BPE: word-initial pieces start with the 'Ġ' space marker.
        self._tokenizer.is_piece_fn = lambda wp: not wp.startswith('Ġ')
    # make sure we assigned is_piece_fn
    assert self._tokenizer.is_piece_fn
def convert_wav2vec2_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    dict_path,
    config_yaml_path,
    encoder_config_path,
    decoder_config_path,
    add_adapter,
    adapter_kernel_size,
    adapter_stride,
    decoder_start_token_id,
    encoder_output_dim,
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    # load configs
    encoder_config = Wav2Vec2Config.from_pretrained(
        encoder_config_path,
        add_adapter=True,
        adapter_stride=adapter_stride,
        adapter_kernel_size=adapter_kernel_size,
        use_auth_token=True,
        output_hidden_size=encoder_output_dim,
    )
    decoder_config = MBartConfig.from_pretrained(decoder_config_path)

    # load model
    # fairseq loads an ensemble; only the first (and only) model is used below.
    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path],
        arg_overrides={
            "config_yaml": config_yaml_path,
            # fairseq expects the data dir, i.e. the directory containing dict_path.
            "data": "/".join(dict_path.split("/")[:-1]),
            "w2v_path": checkpoint_path,
            "load_pretrained_decoder_from": None,
        },
    )
    model = model[0].eval()

    # load feature extractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        encoder_config_path, use_auth_token=True)

    # set weights for wav2vec2 encoder
    hf_encoder = Wav2Vec2Model(encoder_config)

    recursively_load_weights_wav2vec2(model.encoder, hf_encoder)

    # load decoder weights
    # strict=False: adapter/cross-attention keys may not line up 1:1; the
    # mismatches are surfaced via the warnings below rather than raising.
    hf_decoder = MBartForCausalLM(decoder_config)
    missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(
        model.decoder.state_dict(), strict=False)
    logger.warning(
        f"The following keys are missing when loading the decoder weights: {missing_keys}"
    )
    logger.warning(
        f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}"
    )

    hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder,
                                           decoder=hf_decoder)
    # Encoder and decoder come from different models, so embeddings stay untied.
    hf_wav2vec.config.tie_word_embeddings = False

    tokenizer = MBart50Tokenizer(dict_path)
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    # Propagate the tokenizer's special-token ids into the combined config.
    config = hf_wav2vec.config.to_dict()
    config["pad_token_id"] = tokenizer.pad_token_id
    config["bos_token_id"] = tokenizer.bos_token_id
    config["eos_token_id"] = tokenizer.eos_token_id
    config["tokenizer_class"] = "mbart50"
    config["feature_extractor_type"] = "wav2vec2"

    # mBART decoding starts from EOS, then forces the target-language token.
    config["decoder_start_token_id"] = tokenizer.eos_token_id
    # NOTE(review): 250004 is hard-coded — presumably a language id in the
    # mBART-50 vocab; confirm it matches the intended target language.
    config["forced_bos_token_id"] = 250004
    config["forced_eos_token_id"] = tokenizer.eos_token_id

    hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)