Code example #1
    def test_multilingual_translation(self):
        model = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt")
        tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt")

        translator = pipeline(task="translation",
                              model=model,
                              tokenizer=tokenizer)
        # Missing src_lang, tgt_lang
        with self.assertRaises(ValueError):
            translator("This is a test")

        outputs = translator("This is a test",
                             src_lang="en_XX",
                             tgt_lang="ar_AR")
        self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])

        outputs = translator("This is a test",
                             src_lang="en_XX",
                             tgt_lang="hi_IN")
        self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])

        # src_lang, tgt_lang can be defined at pipeline call time
        translator = pipeline(task="translation",
                              model=model,
                              tokenizer=tokenizer,
                              src_lang="en_XX",
                              tgt_lang="ar_AR")
        outputs = translator("This is a test")
        self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
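
For reference, a minimal standalone sketch of the same pipeline usage outside a test harness (model name as above; the exact output text depends on the model version):

from transformers import pipeline

# src_lang/tgt_lang can be fixed when the pipeline is built, as in the test above.
translator = pipeline(task="translation",
                      model="facebook/mbart-large-50-many-to-many-mmt",
                      src_lang="en_XX",
                      tgt_lang="hi_IN")
print(translator("This is a test"))  # e.g. [{'translation_text': '...'}]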
Code example #2
    def __init__(
        self,
        model_or_path: str = "facebook/mbart-large-50-many-to-many-mmt",
        tokenizer_path: str = None,
        device: str = "auto",
        model_options: dict = None,
        tokenizer_options: dict = None,
    ):
        """
        Instantiates a multilingual transformer model for translation.

        {{params}}
        {{model_or_path}} The path or the name of the model. Equivalent to the first argument of AutoModel.from_pretrained().
        {{device}} "cpu", "gpu" or "auto". If it's set to "auto", will try to select a GPU when available or else fallback to CPU.
        {{tokenizer_path}} The path to the tokenizer, only if it is different from `model_or_path`; otherwise, leave it as `None`.
        {{model_options}} The keyword arguments passed to the transformer model, which is an mBART-large model for conditional generation.
        {{tokenizer_options}} The keyword arguments passed to the tokenizer, which is an mBART-50 fast tokenizer.
        """
        self.model_or_path = model_or_path
        self.device = _select_device(device)

        # Resolve default values
        tokenizer_path = tokenizer_path or self.model_or_path
        model_options = model_options or {}
        tokenizer_options = tokenizer_options or {}

        self.tokenizer = MBart50TokenizerFast.from_pretrained(
            tokenizer_path, **tokenizer_options)

        if model_or_path.endswith(".pt"):
            self.bart_model = torch.load(model_or_path,
                                         map_location=self.device).eval()
        else:
            self.bart_model = (MBartForConditionalGeneration.from_pretrained(
                self.model_or_path, **model_options).to(self.device).eval())
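
A usage sketch for this constructor; the enclosing class name (MBartTranslator below) is an assumption made for illustration, and _select_device is the project's own helper:

# Hypothetical usage of the constructor above; the class name is not from the source.
translator = MBartTranslator(
    model_or_path="facebook/mbart-large-50-many-to-many-mmt",
    device="auto",  # selects a GPU when available, otherwise falls back to CPU
)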
Code example #3
 def __init__(self) -> None:
     self.model = MBartForConditionalGeneration.from_pretrained(
         "facebook/mbart-large-50-many-to-many-mmt"
     )
     self.tokenizer = MBart50TokenizerFast.from_pretrained(
         "facebook/mbart-large-50-many-to-many-mmt"
     )
Code example #4
 def __init__(self):
     try:
         # use Facebook's mBART-50 many-to-many translation model
         model_name = "facebook/mbart-large-50-many-to-many-mmt"
         self.model = MBartForConditionalGeneration.from_pretrained(
             model_name)
         self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
     except Exception as e:
         logging.error(f"Error initializing model. {e}")
Code example #5
 def __init__(self, infile: str, src_lang: str, tgt_lang: str,
              max_len: int):
     self.tokenizer = MBart50TokenizerFast.from_pretrained(
         "facebook/mbart-large-en-ro", src_lang=src_lang, tgt_lang=tgt_lang)
     self.data = pd.read_csv(infile, sep="\t", error_bad_lines=False)
     self.max_len = max_len
     self.src_lang = src_lang
     self.tgt_lang = tgt_lang
     self.preprocess_data()
Code example #6
    def __init__(self):

        self.model = MBartForConditionalGeneration.from_pretrained(
            'facebook/mbart-large-50-many-to-many-mmt')
        self.tokenizer = MBart50TokenizerFast.from_pretrained(
            'facebook/mbart-large-50-many-to-many-mmt')
        self.supported_langs = [
            'en_XX', 'gu_IN', 'hi_IN', 'bn_IN', 'ml_IN', 'mr_IN', 'ta_IN',
            'te_IN'
        ]
Code example #7
    def check_model(self, use_cache):
        from optimum.intel.openvino import OVMBartForConditionalGeneration
        from transformers import MBart50TokenizerFast

        model = OVMBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt",
            use_cache=use_cache,
            from_pt=True)
        tokenizer = MBart50TokenizerFast.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt")

        article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
        tokenizer.src_lang = "hi_IN"
        encoded_hi = tokenizer(article_hi, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_hi,
            forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])

        expected_tokens = [[
            2,
            250008,
            636,
            21861,
            8,
            96,
            242,
            136840,
            222939,
            1103,
            242,
            379,
            653,
            242,
            53,
            10,
            452,
            8,
            29806,
            128683,
            22,
            51712,
            5,
            2,
        ]]

        self.assertListEqual(generated_tokens.tolist(), expected_tokens)

        decoded_fr = tokenizer.batch_decode(generated_tokens,
                                            skip_special_tokens=True)[0]
        self.assertEqual(
            decoded_fr,
            "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."
        )
Code example #8
 def __init__(self, config):
     model_name = config.get("model_name", None)
     model_path = config.get("model_path", None)
     device = config.get("device", 0)  # default on gpu 0
     self.tokenizer = MBart50TokenizerFast.from_pretrained(model_path)
     self.model = MBartForConditionalGeneration.from_pretrained(model_path)
     self.model.eval()
     self.model.half()
     self.device = torch.device(
         "cpu" if device < 0 else "cuda:{}".format(device))
     if self.device.type == "cuda":
         self.model = self.model.to(self.device)
Code example #9
    def __init__(self):
        self._tokenizer = MBart50TokenizerFast.from_pretrained(TOKENIZER_PATH)

        special_tokens = [
            'masc', 'fem', 'neut', 'undefined_g', 'past', 'pres', 'fut',
            'sing', 'plur', 'undefined_n'
        ]

        self.special_tokens = {
            k: v
            for k, v in zip(special_tokens,
                            self._tokenizer.additional_special_tokens)
        }

        self.bos_token_id = self._tokenizer.bos_token_id
        self.eos_token_id = self._tokenizer.eos_token_id

        # `property` only works as a class attribute; store the vocab size directly instead.
        self.vocab_size = len(self._tokenizer.get_vocab())
Code example #10
    def load(self, path):
        """
        Loads a model specified by path.

        Args:
            path: model path

        Returns:
            (model, tokenizer)
        """

        if path.startswith("Helsinki-NLP"):
            model = MarianMTModel.from_pretrained(path)
            tokenizer = MarianTokenizer.from_pretrained(path)
        else:
            model = MBartForConditionalGeneration.from_pretrained(path)
            tokenizer = MBart50TokenizerFast.from_pretrained(path)

        # Apply model initialization routines
        model = self.prepare(model)

        return (model, tokenizer)
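
A brief usage sketch: Helsinki-NLP paths resolve to Marian weights, while any other path is treated as an mBART-50 checkpoint. The receiver object name below is hypothetical:

# Hypothetical call; `translation` stands for whatever object exposes this load() method.
model, tokenizer = translation.load("facebook/mbart-large-50-many-to-many-mmt")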
Code example #11
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import pandas as pd

#########
# Mbart50
#########

path_to_new_dataset = '../../../03_dataset/task_01/subtask1-document/additional_training_data'

model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt")

translate_sentence = 'I like icecream.'

# translate English to German
tokenizer.src_lang = "en_XX"
encoded_en = tokenizer(translate_sentence, return_tensors="pt")
generated_tokens = model.generate(
    **encoded_en, forced_bos_token_id=tokenizer.lang_code_to_id["de_DE"])
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

# ACLED (EN)
acled_en_pos = pd.read_json(f"{path_to_new_dataset}/acled_eng.json",
                            lines=True).rename(columns={
                                "notes": "text",
                                "label": "label"
                            })
acled_en_pos_select = acled_en_pos[:6928]
Code example #12
def get_pipeline():
    model = MBartForConditionalGeneration.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt")
    return model, tokenizer
Code example #13
# hf-experiments
# @author Loreto Parisi (loretoparisi at gmail dot com)
# Copyright (c) 2020-2021 Loreto Parisi (loretoparisi at gmail dot com)
# HF: https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt

import os
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

article_en = "The head of the United Nations says there is no military solution in Syria"
model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt",
    cache_dir=os.getenv("cache_dir", "../../models"))
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-one-to-many-mmt",
    src_lang="en_XX",
    cache_dir=os.getenv("cache_dir", "../../models"))

model_inputs = tokenizer(article_en, return_tensors="pt")

# translate from English to Hindi
generated_tokens = model.generate(
    **model_inputs, forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => 'संयुक्त राष्ट्र के नेता कहते हैं कि सीरिया में कोई सैन्य समाधान नहीं है'

# translate from English to Chinese
generated_tokens = model.generate(
    **model_inputs, forced_bos_token_id=tokenizer.lang_code_to_id["zh_CN"])
decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => '联合国首脑说,叙利亚没有军事解决办法'
print(decoded)
Code example #14
def main(params):
    """ Finetunes the mBart50 model on some languages and
    then evaluates the BLEU score for each direction."""

    if params.wandb:
        wandb.init(project='mnmt', entity='nlp-mnmt-project', group='finetuning',
            config={k: v for k, v in params.__dict__.items() if isinstance(v, (float, int, str, list))})

    new_root_path = params.location
    new_name = params.name
    logger = logging.TrainLogger(params)
    logger.make_dirs()
    logger.save_params()

    # load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50").to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # scale in terms of max lr
    lr_scale = params.max_lr * np.sqrt(params.warmup_steps)
    scheduler = WarmupDecay(optimizer, params.warmup_steps, 1, lr_scale=lr_scale)

    # set dropout
    model.config.dropout = params.dropout 
    model.config.attention_dropout = params.dropout

    def pipeline(dataset, langs, batch_size, max_len):

        cols = ['input_ids_' + l for l in langs]

        def tokenize_fn(example):
            """apply tokenization"""
            l_tok = []
            for lang in langs:
                encoded = tokenizer.encode(example[lang])
                encoded[0] = tokenizer.lang_code_to_id[LANG_CODES[lang]]
                l_tok.append(encoded)
            return {'input_ids_' + l: tok for l, tok in zip(langs, l_tok)}

        def pad_seqs(examples):
            """Apply padding"""
            ex_langs = list(zip(*[tuple(ex[col] for col in cols) for ex in examples]))
            ex_langs = tuple(pad_sequence(x, batch_first=True, max_len=max_len) for x in ex_langs)
            return ex_langs

        dataset = filter_languages(dataset, langs)
        dataset = dataset.map(tokenize_fn)
        dataset.set_format(type='torch', columns=cols)
        num_examples = len(dataset)
        print('-'.join(langs) + ' : {} examples.'.format(num_examples))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                batch_size=batch_size,
                                                collate_fn=pad_seqs)
        return dataloader, num_examples

    # load data
    dataset = load_dataset('ted_multi')
    train_dataset = dataset['train']
    test_dataset = dataset['validation' if params.split == 'val' else 'test']

    # preprocess splits for each direction
    num_train_examples = {}
    train_dataloaders, val_dataloaders, test_dataloaders = {}, {}, {}
    for l1, l2 in combinations(params.langs, 2):
        train_dataloaders[l1+'-'+l2], num_train_examples[l1+'-'+l2] = pipeline(
            train_dataset, [l1, l2], params.batch_size, params.max_len)
        test_dataloaders[l1+'-'+l2], _ = pipeline(test_dataset, [l1, l2], params.batch_size, params.max_len)

    # print dataset sizes
    for direction, num in num_train_examples.items():
        print(direction, ': {} examples.'.format(num))

    def freeze_layers(layers, unfreeze=False):
        for n in layers:
            for parameter in model.model.encoder.layers[n].parameters():
                parameter.requires_grad = unfreeze

    # define loss function
    if params.label_smoothing is not None:
        loss_object = LabelSmoothingLoss(params.label_smoothing)
        loss_fn = lambda out, tar: loss_object(out.logits, tar)
    else:
        loss_fn = lambda out, tar: out.loss

    # train the model
    _target = torch.tensor(1.0).to(device)
    def train_step(x, y, aux=False):

        y_inp, y_tar = y[:,:-1].contiguous(), y[:,1:].contiguous()
        enc_mask, dec_mask = (x != 0), (y_inp != 0)

        x, y_inp, y_tar, enc_mask, dec_mask = to_devices(
          (x, y_inp, y_tar, enc_mask, dec_mask), device)

        model.train()
        if aux: freeze_layers(params.frozen_layers, unfreeze=True)
        output = model(input_ids=x, decoder_input_ids=y_inp,
                   labels=y_tar, attention_mask=enc_mask,
                   decoder_attention_mask=dec_mask)
        optimizer.zero_grad()
        loss = loss_fn(output, y_tar)
        loss.backward(retain_graph=aux)

        if aux: freeze_layers(params.frozen_layers)
        torch.set_grad_enabled(aux)

        x_enc = output.encoder_last_hidden_state
        y_enc = model.model.encoder(y_inp, attention_mask=dec_mask)['last_hidden_state']
        x_enc = torch.max(x_enc + -999 * (1-enc_mask.type(x_enc.dtype)).unsqueeze(-1), dim=1)[0]
        y_enc = torch.max(y_enc + -999 * (1-dec_mask.type(y_enc.dtype)).unsqueeze(-1), dim=1)[0]
        aux_loss = F.cosine_embedding_loss(x_enc, y_enc, _target)
        scaled_aux_loss = params.aux_strength * aux_loss
        
        torch.set_grad_enabled(True)
        if aux: scaled_aux_loss.backward()

        optimizer.step()
        scheduler.step()

        accuracy = accuracy_fn(output.logits, y_tar)

        return loss.item(), aux_loss.item(), accuracy.item()

    # prepare iterators
    iterators = {direction: iter(loader) for direction, loader in train_dataloaders.items()}

    # compute sampling probabilities (and set zero-shot directions to 0)
    num_examples = num_train_examples.copy()
    zero_shots = [(params.zero_shot[i]+'-'+params.zero_shot[i+1]) for i in range(0, len(params.zero_shot), 2)]
    for d in zero_shots:
        num_examples[d] = 0
    directions, num_examples = list(num_examples.keys()), np.array(list(num_examples.values()))
    dir_dist = (num_examples ** params.temp) / ((num_examples ** params.temp).sum())

    #train
    losses, aux_losses, accs = [], [], []
    start_ = time.time()
    for i in range(params.train_steps):

        # sample a direction
        direction = directions[int(np.random.choice(len(num_examples), p=dir_dist))]
        try: # check iterator is not exhausted
            x, y = next(iterators[direction])
        except StopIteration:
            iterators[direction] = iter(train_dataloaders[direction])
            x, y = next(iterators[direction])
        x, y = get_direction(x, y, sample=not params.single_direction)
           
        # train on the direction
        loss, aux_loss, acc = train_step(x, y, aux=params.auxiliary)
        losses.append(loss)
        aux_losses.append(aux_loss)
        accs.append(acc)

        if i % params.verbose == 0:
            print('Batch {} Loss {:.4f} Aux Loss {:.4f} Acc {:.4f} in {:.4f} secs per batch'.format(
                i, np.mean(losses[-params.verbose:]), np.mean(aux_losses[-params.verbose:]),
                np.mean(accs[-params.verbose:]), (time.time() - start_)/(i+1)))
        if params.wandb:
            wandb.log({'train_loss':loss, 'aux_loss':aux_loss, 'train_acc':acc})

    # save results
    if params.save:
        logger.save_model(params.train_steps, model, optimizer, scheduler=scheduler)
    
    train_results = {'loss':[np.mean(losses)], 'aux_loss':[np.mean(aux_losses)], 'accuracy':[np.mean(accs)]}
    pd.DataFrame(train_results).to_csv(logger.root_path + '/train_results.csv', index=False)

    # evaluate the model
    def evaluate(x, y, y_code, bleu):
        y_inp, y_tar = y[:,:-1].contiguous(), y[:,1:].contiguous()
        enc_mask = (x != 0)
        x, y_inp, y_tar, enc_mask = to_devices(
          (x, y_inp, y_tar, enc_mask), device)
        
        model.eval()
        y_pred = model.generate(input_ids=x, decoder_start_token_id=y_code,
            attention_mask=enc_mask, max_length=params.max_len+1,
            num_beams=params.num_beams, length_penalty=params.length_penalty,
            early_stopping=True)
        bleu(y_pred[:,1:], y_tar)

    test_results = {}
    for direction, loader in test_dataloaders.items():
        alt_direction = '-'.join(reversed(direction.split('-')))
        bleu1, bleu2 = BLEU(), BLEU()
        bleu1.set_excluded_indices([0, 2])
        bleu2.set_excluded_indices([0, 2])
        x_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[0]]]
        y_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[-1]]]

        start_ = time.time()
        for i, (x, y) in enumerate(loader):
            if params.test_batches is not None:
                if i > params.test_batches:
                    break

            evaluate(x, y, y_code, bleu1)
            if not params.single_direction:
                evaluate(y, x, x_code, bleu2)
            if i % params.verbose == 0:
                bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
                print('Batch {} Bleu1 {:.4f} Bleu2 {:.4f} in {:.4f} secs per batch'.format(
                    i, bl1, bl2, (time.time() - start_)/(i+1)))
                if params.wandb:
                    wandb.log({'Bleu1':bl1, 'Bleu2':bl2})

        test_results[direction] = [bleu1.get_metric()]
        test_results[alt_direction] = [bleu2.get_metric()]

    # save test_results
    pd.DataFrame(test_results).to_csv(logger.root_path + '/test_results.csv', index=False)

    if params.wandb:
        wandb.finish()
Code example #15
def main(params):
    """ Evaluates a finetuned model on the test or validation dataset."""

    # load model and tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
    config = MBartConfig.from_pretrained("facebook/mbart-large-50")
    model = MBartForConditionalGeneration(config).to(device)
    checkpoint_location = params.location + '/' + params.name + '/checkpoint/checkpoint'
    model, _, _, _ = logging.load_checkpoint(checkpoint_location, device,
                                             model)

    def pipeline(dataset, langs, batch_size, max_len):

        cols = ['input_ids_' + l for l in langs]

        def tokenize_fn(example):
            """apply tokenization"""
            l_tok = []
            for lang in langs:
                encoded = tokenizer.encode(example[lang])
                encoded[0] = tokenizer.lang_code_to_id[LANG_CODES[lang]]
                l_tok.append(encoded)
            return {'input_ids_' + l: tok for l, tok in zip(langs, l_tok)}

        def pad_seqs(examples):
            """Apply padding"""
            ex_langs = list(
                zip(*[tuple(ex[col] for col in cols) for ex in examples]))
            ex_langs = tuple(
                pad_sequence(x, batch_first=True, max_len=max_len)
                for x in ex_langs)
            return ex_langs

        dataset = filter_languages(dataset, langs)
        dataset = dataset.map(tokenize_fn)
        dataset.set_format(type='torch', columns=cols)
        num_examples = len(dataset)
        print('-'.join(langs) + ' : {} examples.'.format(num_examples))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 collate_fn=pad_seqs)
        return dataloader, num_examples

    # load data
    if params.split == 'val':
        test_dataset = load_dataset('ted_multi', split='validation')
    elif params.split == 'test':
        test_dataset = load_dataset('ted_multi', split='test')
    elif params.split == 'combine':
        test_dataset = load_dataset('ted_multi', split='validation+test')
    else:
        raise NotImplementedError

    # preprocess splits for each direction
    test_dataloaders = {}
    for l1, l2 in combinations(params.langs, 2):
        test_dataloaders[l1 + '-' + l2], _ = pipeline(test_dataset, [l1, l2],
                                                      params.batch_size,
                                                      params.max_len)

    # evaluate the model
    def evaluate(x, y, y_code, bleu):
        y_inp, y_tar = y[:, :-1].contiguous(), y[:, 1:].contiguous()
        enc_mask = (x != 0)
        x, y_inp, y_tar, enc_mask = to_devices((x, y_inp, y_tar, enc_mask),
                                               device)

        model.eval()
        y_pred = model.generate(input_ids=x,
                                decoder_start_token_id=y_code,
                                attention_mask=enc_mask,
                                max_length=x.size(1) + 1,
                                num_beams=params.num_beams,
                                length_penalty=params.length_penalty,
                                early_stopping=True)
        bleu(y_pred[:, 1:], y_tar)

    test_results = {}
    for direction, loader in test_dataloaders.items():
        alt_direction = '-'.join(reversed(direction.split('-')))
        bleu1, bleu2 = BLEU(), BLEU()
        bleu1.set_excluded_indices([0, 2])
        bleu2.set_excluded_indices([0, 2])
        x_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')[0]]]
        y_code = tokenizer.lang_code_to_id[LANG_CODES[direction.split('-')
                                                      [-1]]]

        start_ = time.time()
        for i, (x, y) in enumerate(loader):
            if params.test_batches is not None:
                if i > params.test_batches:
                    break

            evaluate(x, y, y_code, bleu1)
            if not params.single_direction:
                evaluate(y, x, x_code, bleu2)
            if i % params.verbose == 0:
                bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
                print(
                    'Batch {} Bleu1 {:.4f} Bleu2 {:.4f} in {:.4f} secs per batch'
                    .format(i, bl1, bl2, (time.time() - start_) / (i + 1)))

        bl1, bl2 = bleu1.get_metric(), bleu2.get_metric()
        test_results[direction] = [bl1]
        test_results[alt_direction] = [bl2]
        print(direction, bl1, bl2)

    # save test_results
    pd.DataFrame(test_results).to_csv(params.location + '/' + params.name +
                                      '/test_results.csv',
                                      index=False)
Code example #16
File: model.py  Project: andim461/project-ML
from transformers import PretrainedConfig
import pytorch_lightning as pl
from transformers import MBartConfig
from transformers import MBart50TokenizerFast
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.mbart import MBartForConditionalGeneration
from transformers.optimization import AdamW, get_constant_schedule_with_warmup
import torch

tokenizer = MBart50TokenizerFast.from_pretrained("./tokenizer",
                                                 src_lang='ru_RU',
                                                 tgt_lang='ru_RU')


class Seq2SeqConfig(PretrainedConfig):

    model_type = "mbart"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self,
                 vocab_size=50265,
                 max_position_embeddings=1024,
                 encoder_layers=12,
                 encoder_ffn_dim=4096,
                 encoder_attention_heads=16,
                 decoder_layers=12,
                 decoder_ffn_dim=4096,
                 decoder_attention_heads=16,
                 encoder_layerdrop=0.0,
                 decoder_layerdrop=0.0,
                 use_cache=True,