Ejemplo n.º 1
0
    def setUp(self):
        """Build a Pegasus tokenizer from the SentencePiece fixture and persist it."""
        super().setUp()

        # Fixture tokenizer: offset disabled, no sentence-level mask token.
        fixture_tokenizer = PegasusTokenizer(
            SAMPLE_VOCAB, offset=0, mask_token_sent=None, mask_token="[MASK]"
        )
        fixture_tokenizer.save_pretrained(self.tmpdirname)
Ejemplo n.º 2
0
def _load_tokenizer(model_name, tokenizer_name, vocab_file):
    """Resolve the tokenizer for *model_name*.

    If *tokenizer_name* is "t5" or "pegasus" and a *vocab_file* is given, a
    sentencepiece tokenizer is built from that vocab; otherwise AutoTokenizer
    is used, falling back to the checkpoint's parent directory name when the
    direct lookup fails (checkpoint dirs are often nested one level below the
    tokenizer files).
    """
    if tokenizer_name == "t5" and vocab_file:
        from transformers import T5Tokenizer
        print(vocab_file)
        tokenizer = T5Tokenizer(vocab_file)
        print("custom tokenizer", tokenizer)
        return tokenizer
    if tokenizer_name == "pegasus" and vocab_file:
        from transformers import PegasusTokenizer
        print(vocab_file)
        tokenizer = PegasusTokenizer(vocab_file)
        print("custom tokenizer", tokenizer)
        return tokenizer
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception:  # was a bare except: — never swallow SystemExit/KeyboardInterrupt
        import os
        parent = os.path.basename(os.path.dirname(model_name))
        print(parent)
        return AutoTokenizer.from_pretrained(parent)


def generate_summaries_or_translations(
    examples: List[str],
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    prefix=None,
    tokenizer_name=None,
    vocab_file=None,
    max_length=None,
    **generate_kwargs,
) -> Dict:
    """Save model.generate results to <out_file>, and return how long it took.

    Args:
        examples: input texts to summarize/translate.
        out_file: output path; one generated hypothesis is written per line.
        model_name: HF hub id or local path of a seq2seq model.
        batch_size: number of examples per generation batch.
        device: torch device string the model is moved to.
        fp16: cast the model to half precision when True.
        task: kept for backward compatibility (currently unused here).
        prefix: text prepended to every example; defaults to the model
            config's ``prefix`` attribute, or "".
        tokenizer_name: "t5" or "pegasus" to force a sentencepiece tokenizer
            built from *vocab_file*; otherwise AutoTokenizer is used.
        vocab_file: sentencepiece vocab used together with *tokenizer_name*.
        max_length: encoder-side truncation length for the tokenizer.
        **generate_kwargs: overrides for the default ``model.generate``
            arguments below (e.g. ``max_length``, ``top_k``).

    Returns:
        dict with ``n_obs``, ``runtime`` (whole seconds) and
        ``seconds_per_sample`` (0.0 when *examples* is empty).
    """
    model_name = str(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if fp16:
        model = model.half()

    tokenizer = _load_tokenizer(model_name, tokenizer_name, vocab_file)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}"
                )  # if this is wrong, check config.model_type.

    start_time = time.time()
    # update config with task specific params
    # use_task_specific_params(model, task)
    if prefix is None:
        prefix = getattr(model.config, "prefix", "") or ""

    # Generation defaults tuned for this model; callers may override any of
    # them via **generate_kwargs (previously the kwargs were silently ignored).
    gen_args = dict(
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=3,
        decoder_start_token_id=1,
        do_sample=True,          # use a sampling strategy
        max_length=128,          # maximum decoding length
        top_k=50,                # exclude tokens outside the top-50 by probability
        top_p=0.95,              # nucleus sampling over 95% cumulative mass
        num_return_sequences=1,
    )
    gen_args.update(generate_kwargs)

    # `with` guarantees the output file is closed even if generation raises
    # (the original leaked the handle on any exception).
    with Path(out_file).open("w", encoding="utf-8") as fout:
        for examples_chunk in tqdm(list(chunks(examples, batch_size))):
            examples_chunk = [prefix + text for text in examples_chunk]
            batch = tokenizer(examples_chunk,
                              return_tensors="pt",
                              truncation=True,
                              padding="longest",
                              max_length=max_length).to(device)

            summaries = model.generate(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                **gen_args,
            )
            print(summaries)
            dec = tokenizer.batch_decode(summaries,
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
            for hypothesis in dec:
                print(hypothesis)
                fout.write(hypothesis + "\n")
                fout.flush()

    runtime = int(time.time() - start_time)  # seconds
    n_obs = len(examples)
    return dict(n_obs=n_obs,
                runtime=runtime,
                # guard against ZeroDivisionError on an empty input list
                seconds_per_sample=round(runtime / n_obs, 4) if n_obs else 0.0)
Ejemplo n.º 3
0
    def setUp(self):
        """Persist a tokenizer built from the SentencePiece test fixture."""
        super().setUp()

        # SentencePiece fixture shared by the tokenizer tests.
        fixture_tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
        fixture_tokenizer.save_pretrained(self.tmpdirname)