from transformers import MarianConfig

# tqdm_nice, DEFAULT_UPDATE_DICT and MIRROR_DIR are helpers/constants
# defined elsewhere in the source module.
def bulk_update_local_configs(models,
                              update_dict=DEFAULT_UPDATE_DICT,
                              save_dir=MIRROR_DIR):
    failures = []
    for slug in tqdm_nice(models):
        assert slug.startswith('opus-mt')
        try:
            cfg = MarianConfig.from_pretrained(f'Helsinki-NLP/{slug}')
        except OSError:
            failures.append(slug)
            continue
        for k, v in update_dict.items():
            setattr(cfg, k, v)
        # if a new value depends on a cfg value, add code here
        # e.g. cfg.decoder_start_token_id = cfg.pad_token_id

        dest_dir = save_dir / 'Helsinki-NLP' / slug
        if not dest_dir.exists():
            print(f'making {dest_dir}')
            dest_dir.mkdir(parents=True, exist_ok=True)
        cfg.save_pretrained(dest_dir)
        # round-trip check: the saved config should load back as a Marian config
        assert MarianConfig.from_pretrained(dest_dir).model_type == 'marian'
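
For context, a minimal usage sketch. In the source module, DEFAULT_UPDATE_DICT and MIRROR_DIR are defined before the function; the names and values below are illustrative assumptions, not the script's real defaults.

from pathlib import Path

# Illustrative stand-ins (assumptions) for the module-level defaults.
update_dict = {'num_beams': 6, 'max_length': 512}
mirror_dir = Path('marian_mirror')

bulk_update_local_configs(['opus-mt-en-de', 'opus-mt-fr-en'],
                          update_dict=update_dict,
                          save_dir=mirror_dir)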
Example #2
def config(self):
    config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
    return config
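
A minimal sketch of how this fixture's config might be used: instantiating MarianMTModel from the tiny config gives a small, randomly initialized model that is cheap enough for unit tests. The test below is an illustrative assumption, not part of the original snippet.

from transformers import MarianConfig, MarianMTModel

def test_tiny_marian_builds():
    config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
    model = MarianMTModel(config)  # random weights; only the shapes come from the config
    assert model.config.model_type == "marian"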
Example #3

import logging
from functools import partial

import pandas as pd
from datasets import Dataset
from transformers import (DataCollatorForSeq2Seq, MarianConfig, MarianMTModel,
                          MarianTokenizer, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)

# preprocess_function, compute_metrics and metric are defined elsewhere
# in the source script (see the sketch after this example).
def main(args):
    df = pd.read_csv(args.input_fname,
                     encoding='utf-8')[[args.source_lang, args.target_lang]]
    logging.info(f'Loaded dataframe of shape {df.shape}')

    # convert the dataframe to a dictionary of translation pairs
    j = {'translation': []}
    for row in df.itertuples():
        j['translation'].append({args.source_lang: row[1], args.target_lang: row[2]})

    train_dataset = Dataset.from_dict(j)
    raw_datasets = train_dataset.train_test_split(test_size=args.valid_pct,
                                                  seed=args.seed)
    logging.info(f'Datasets created {raw_datasets}')

    tokenizer = MarianTokenizer.from_pretrained(args.output_dir)
    logging.info(f'Tokenizer loaded from {args.output_dir}')

    # tokenize the datasets
    tokenized_datasets = raw_datasets.map(
        partial(preprocess_function,
                tokenizer=tokenizer,
                max_input_length=args.max_input_length,
                max_target_length=args.max_target_length,
                source_lang=args.source_lang,
                target_lang=args.target_lang),
        batched=True,
    )
    logging.info(f'Tokenized datasets: {tokenized_datasets}')

    # drop pairs where either side of the raw text is 2 characters or fewer
    tokenized_datasets = tokenized_datasets.filter(
        lambda example: len(example['translation'][args.source_lang]) > 2)
    tokenized_datasets = tokenized_datasets.filter(
        lambda example: len(example['translation'][args.target_lang]) > 2)
    logging.info(
        f'Tokenized datasets after filtering out sequences of 2 or fewer characters: {tokenized_datasets}'
    )

    config = MarianConfig.from_pretrained(args.output_dir)
    model = MarianMTModel(config)  # randomly initialized; only the config is read
    logging.info(f'Initialized model from config at {args.output_dir}')

    training_args = Seq2SeqTrainingArguments(
        args.output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end=True
        load_best_model_at_end=True,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=args.weight_decay,
        save_total_limit=args.save_total_limit,
        num_train_epochs=args.num_train_epochs,
        predict_with_generate=True,
        fp16=args.fp16,
        seed=args.seed,
    )
    logging.info(f'Training config {training_args}')

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Seq2SeqTrainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics,
                                tokenizer=tokenizer,
                                metric=metric,
                                metric_tokenize=args.metric_tokenize),
    )
    logging.info('Trainer created')

    trainer.train()

    model.save_pretrained(f"{args.output_dir}_best")
    tokenizer.save_pretrained(f"{args.output_dir}_best")
    logging.info('Best model saved')

    # quick smoke test: translate two sample sentences on CPU
    model.cpu()
    src_text = ['我爱你', '国王有很多心事。我明白']
    translated = model.generate(
        **tokenizer(src_text, return_tensors="pt", padding=True))
    print([tokenizer.decode(t, skip_special_tokens=True) for t in translated])
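
The script calls preprocess_function and compute_metrics (and a global metric) without showing them. Below is a minimal sketch of what they might look like, assuming sacreBLEU as the metric and the standard translation fine-tuning recipe; these bodies are assumptions, not the original helpers.

import numpy as np

def preprocess_function(examples, tokenizer, max_input_length,
                        max_target_length, source_lang, target_lang):
    # batched map: examples['translation'] is a list of {src: ..., tgt: ...} dicts
    inputs = [pair[source_lang] for pair in examples['translation']]
    targets = [pair[target_lang] for pair in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # tokenize targets with the target-side tokenizer state
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def compute_metrics(eval_preds, tokenizer, metric, metric_tokenize):
    preds, labels = eval_preds
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # -100 marks ignored label positions; map them back to pad tokens before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds,
                            references=[[label] for label in decoded_labels],
                            tokenize=metric_tokenize)
    return {'bleu': result['score']}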