Code Example #1
def test_generate_fp16(self):
    # Build a config and dummy inputs via the test harness.
    config, input_dict = self.model_tester.prepare_config_and_inputs()
    input_ids = input_dict["input_ids"]
    # Token id 1 is treated as padding when building the attention mask.
    attention_mask = input_ids.ne(1).to(torch_device)
    model = MarianMTModel(config).eval().to(torch_device)
    if torch_device == "cuda":
        model.half()  # half precision is only exercised on GPU
    # Greedy generation with explicit inputs, then beam sampling with multiple
    # return sequences (without input_ids, generate starts from a BOS-only prompt).
    model.generate(input_ids, attention_mask=attention_mask)
    model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
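
The test above depends on the transformers test harness (self.model_tester and torch_device), so it does not run on its own. Below is a minimal standalone sketch of the same fp16 generation path; the Helsinki-NLP/opus-mt-en-de checkpoint is an illustrative choice, not something named in the original.

import torch
from transformers import MarianMTModel, MarianTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint chosen purely for illustration; any MarianMT checkpoint works.
name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name).eval().to(device)
if device == "cuda":
    model.half()  # fp16 only on GPU, mirroring the test

batch = tokenizer(["I love you."], return_tensors="pt", padding=True).to(device)
out = model.generate(**batch, num_beams=4, num_return_sequences=3)
print(tokenizer.batch_decode(out, skip_special_tokens=True))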
Code Example #2
import logging
from functools import partial

import pandas as pd
from datasets import Dataset
from transformers import (
    DataCollatorForSeq2Seq,
    MarianConfig,
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# preprocess_function, compute_metrics, and metric are defined elsewhere in
# the script; plausible sketches are given after this example.


def main(args):
    # load the parallel corpus, keeping only the source/target columns
    df = pd.read_csv(args.input_fname,
                     encoding='utf-8')[[args.source_lang, args.target_lang]]
    logging.info(f'Loaded {df.shape}')

    # convert the dataframe into a dict of translation pairs
    j = {'translation': []}
    for i in df.itertuples():
        j['translation'] += [{args.source_lang: i[1], args.target_lang: i[2]}]

    train_dataset = Dataset.from_dict(j)
    raw_datasets = train_dataset.train_test_split(test_size=args.valid_pct,
                                                  seed=args.seed)
    logging.info(f'Datasets created {raw_datasets}')

    tokenizer = MarianTokenizer.from_pretrained(args.output_dir)
    logging.info(f'Tokenizer loaded from {args.output_dir}')

    # tokenize the train/test splits
    tokenized_datasets = raw_datasets.map(
        partial(preprocess_function,
                tokenizer=tokenizer,
                max_input_length=args.max_input_length,
                max_target_length=args.max_target_length,
                source_lang=args.source_lang,
                target_lang=args.target_lang),
        batched=True,
    )
    logging.info(f'Tokenized datasets: {tokenized_datasets}')

    # filter out pairs whose raw strings are shorter than 3 characters
    tokenized_datasets = tokenized_datasets.filter(
        lambda example: len(example['translation'][args.source_lang]) > 2)
    tokenized_datasets = tokenized_datasets.filter(
        lambda example: len(example['translation'][args.target_lang]) > 2)
    logging.info(
        f'Tokenized datasets after filtering out sequences shorter than 3 characters: {tokenized_datasets}'
    )

    config = MarianConfig.from_pretrained(args.output_dir)
    model = MarianMTModel(config)  # randomly initialized, trained from scratch
    logging.info(f'Initialized model from config at {args.output_dir}')

    training_args = Seq2SeqTrainingArguments(
        args.output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",  # must match evaluation_strategy when load_best_model_at_end is set
        load_best_model_at_end=True,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=args.weight_decay,
        save_total_limit=args.save_total_limit,
        num_train_epochs=args.num_train_epochs,
        predict_with_generate=True,
        fp16=args.fp16,
        seed=args.seed,
    )
    logging.info(f'Training config {training_args}')

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    trainer = Seq2SeqTrainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=partial(compute_metrics,
                                tokenizer=tokenizer,
                                metric=metric,
                                metric_tokenize=args.metric_tokenize),
    )
    logging.info('Trainer created')

    trainer.train()

    model.save_pretrained(f"{args.output_dir}_best")
    tokenizer.save_pretrained(f"{args.output_dir}_best")
    logging.info('Best model saved')

    # quick smoke test: run the trained model on a couple of sample sentences
    model.cpu()
    src_text = ['我爱你', '国王有很多心事。我明白']
    translated = model.generate(
        **tokenizer(src_text, return_tensors="pt", padding=True))
    print([tokenizer.decode(t, skip_special_tokens=True) for t in translated])
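
The script above references three names defined outside this excerpt: preprocess_function, compute_metrics, and metric. Their originals are not shown here, so the following is only a plausible sketch, assuming sacrebleu (via datasets.load_metric) as the metric.

import numpy as np
from datasets import load_metric

metric = load_metric('sacrebleu')

def preprocess_function(examples, tokenizer, max_input_length,
                        max_target_length, source_lang, target_lang):
    # pull parallel sentences out of the nested 'translation' field
    inputs = [ex[source_lang] for ex in examples['translation']]
    targets = [ex[target_lang] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # tokenize targets in target-language mode
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def compute_metrics(eval_preds, tokenizer, metric, metric_tokenize):
    preds, labels = eval_preds
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # -100 marks ignored label positions; swap in the pad id before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds,
                            references=[[label] for label in decoded_labels],
                            tokenize=metric_tokenize)
    return {'bleu': result['score']}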