Example #1
import os

from transformers import MBartConfig, MBartModel

def train_MBart(data_path, tokenizer, output_path):
    #tiny toy configuration so the example runs quickly; real training needs much larger dimensions
    model_config = MBartConfig(vocab_size=300,
                               d_model=10,
                               encoder_layers=1,
                               decoder_layers=1,
                               encoder_attention_heads=1,
                               decoder_attention_heads=1,
                               encoder_ffn_dim=10,
                               decoder_ffn_dim=10,
                               max_position_embeddings=512)
    model = MBartModel(config=model_config)

    sentences = {} #associates lang_id with list of sentences
    
    #read data files and separate language data into different lists
    lang_id = 0 #counter for languages in dataset
    for sentence_file in sorted(os.listdir(data_path)):  #sorted so the src/tgt order is deterministic
        with open(os.path.join(data_path, sentence_file), 'r') as data:
            sentences[lang_id] = [line.strip() for line in data]
        lang_id += 1

    #create token sequences to pass into the model (assumes exactly two files: source first, then target)
    src_texts, tgt_texts = (sentences[i] for i in sorted(sentences))
    batch = tokenizer.prepare_seq2seq_batch(src_texts=src_texts,
                                            tgt_texts=tgt_texts,
                                            return_tensors='pt')
    
    
    #single forward pass; note that no loss is computed and no weights are updated here
    model(input_ids=batch['input_ids'],
          decoder_input_ids=batch['labels'])
    model.save_pretrained(output_path)
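
For reference, a minimal sketch of how this function might be invoked. The data directory and output path are hypothetical placeholders, and the tokenizer is assumed to be a custom SentencePiece model (like the one loaded in Example #2) whose vocabulary fits the vocab_size=300 set in the config above; this is a usage illustration, not part of the original code.

from transformers import MBartTokenizer

tokenizer = MBartTokenizer(vocab_file='./tokenizer_de_hsb.model')  #assumed custom SentencePiece model
train_MBart('./data/', tokenizer, './untrained_model/')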
Example #2
from transformers import MBartForConditionalGeneration, MBartTokenizer, MBartModel, MBartConfig

TGT_DATA = "./data_tgt_de.txt"
SRC_DATA = "./data_source_hsb.txt"

#Read the parallel data files (one sentence per line)
with open(SRC_DATA) as f:
    src_txts = [line.strip() for line in f]

with open(TGT_DATA) as f:
    tgt_txts = [line.strip() for line in f]

#Load the custom SentencePiece model as the tokenizer
tokenizer = MBartTokenizer.from_pretrained('./tokenizer_de_hsb.model')
#NB: mBART's predefined language codes do not cover Upper Sorbian (hsb),
#so en_XX is kept as a stand-in; the German target maps to de_DE
batch = tokenizer.prepare_seq2seq_batch(src_texts=src_txts,
                                        src_lang="en_XX",
                                        tgt_texts=tgt_txts,
                                        tgt_lang="de_DE",
                                        return_tensors="pt")
#batch now holds input_ids, attention_mask and labels tensors
#size the embedding table to the custom tokenizer rather than mBART's default 250k vocabulary
config = MBartConfig(vocab_size=len(tokenizer))
model = MBartModel(config)
model(input_ids=batch['input_ids'],
      decoder_input_ids=batch['labels'])  #forward pass only; no loss, no weight updates
model.save_pretrained('./trained_model')
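
As written, both examples run a single forward pass and then save randomly initialized weights: MBartModel has no language-modeling head, so no loss is ever computed and nothing is learned. Below is a minimal sketch of one actual training step using the already-imported MBartForConditionalGeneration; the optimizer choice and learning rate are illustrative assumptions, not part of the original code.

import torch

model = MBartForConditionalGeneration(config)  #adds an LM head so a loss can be computed
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  #illustrative hyperparameters

model.train()
outputs = model(input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'])  #passing labels makes the model return a cross-entropy loss
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()

model.save_pretrained('./trained_model')

In practice this step would run in a loop over many epochs and mini-batches, and padding positions in labels should be masked out (set to -100) so they do not contribute to the loss.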