# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Optimization for GPT2 model"""

import torch
import torch.nn as nn

from transformers.modeling_gpt2 import Attention, GPT2Model

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace

logger = get_logger(__name__)


@replace(Attention)
class AttentionV2(Attention):
    """GPT2 attention optimized for beam search: the key/value cache for the
    input prompt is stored once per sample and shared across beams."""

    def __init__(self, nx, n_ctx, config, scale=False, num_beams=1):
        super().__init__(nx=nx, n_ctx=n_ctx, config=config, scale=scale)

        # Shared prompt key/value cache and the prompt length it covers.
        self.cache_input_key = None
        self.cache_input_value = None
        self.cache_input_len = -1
        self.num_beams = num_beams

    def _attn(self, q, k, v, attention_mask=None, head_mask=None,
              output_attentions=False):
        # Fold the beam dimension out of the batch so that every beam of a
        # sample attends against one shared copy of the prompt cache.
        w1 = torch.einsum(
            "bmhtd,bnhsd->bmhts",
            q.view((q.size(0) // self.num_beams, self.num_beams) + q.shape[1:]),
            # The source snippet is truncated at this point; the second
            # operand is presumably the shared prompt-key cache
            # (self.cache_input_key), and the rest of the method is not shown.
            self.cache_input_key)
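
# The reshape above is the core trick: prompt keys/values are cached once per
# sample and broadcast across beams. A minimal, standalone shape check of the
# einsum (toy dimensions and hypothetical names, not fastseq code):
if __name__ == "__main__":
    batch, beams, heads, tgt, src, dim = 2, 3, 4, 1, 5, 8
    toy_q = torch.randn(batch * beams, heads, tgt, dim)    # one row per beam
    cached_k = torch.randn(batch, 1, heads, src, dim)      # one cache per sample
    w = torch.einsum(
        "bmhtd,bnhsd->bmhts",
        toy_q.view(batch, beams, heads, tgt, dim),         # split beams out of the batch
        cached_k)                                          # size-1 'n' axis is summed away
    assert w.shape == (batch, beams, heads, tgt, src)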

Example #2

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Register models"""

import logging
import sys

from fastseq.logging import get_logger

logger = get_logger(__name__, logging.INFO)

import fastseq.models.prophetnet_fs
import fastseq.models.unilm_hf

try:
    import fastseq.models.modeling_auto_hf
except ImportError as error:
    logger.warning('transformers cannot be imported: %s', error)
except:
    logger.error("Unexpected error: %s", sys.exc_info()[0])
    raise
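
# For context, a minimal sketch of the register-by-import pattern that bare
# imports like the ones above typically rely on (hypothetical names, not
# fastseq's actual registration code):
MODEL_REGISTRY = {}

def register_model(name):
    """Record a class in MODEL_REGISTRY as a side effect of module import."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register_model("toy_model")
class ToyModel:
    """Importing the defining module is enough to register this class."""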

Example #3

# This function is excerpted from the fastseq_cli generation script. Helpers
# it references (use_task_specific_params, trim_batch, TokenizeDataset,
# PostProcess, IOProcess, GENERATE_FINISHED, DEFAULT_DEVICE) are defined
# elsewhere in that module and are assumed to be in scope; the imports below
# are a best-effort reconstruction (the generation-output classes have moved
# between transformers versions).
import logging
import sys
from multiprocessing import Queue
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoTokenizer)
from transformers.generation_utils import (BeamSampleDecoderOnlyOutput,
                                           BeamSampleEncoderDecoderOutput,
                                           BeamSearchDecoderOnlyOutput,
                                           BeamSearchEncoderDecoderOutput)


def generate_summaries_or_translations_fast(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    decoder_start_token_id=None,
    no_repeat_ngram_size=None,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
    preprocess_workers=2,
    postprocess_workers=2,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_tokenizer_length=None,
    max_gen_length=None,
    max_new_tokens=None,
    use_causal_lm=False,
    output_summaries_only=False,
    output_sequence_scores=False,
    num_beams=None,
    eos_token_id=None,
    temperature=None,
    top_k=None,
    top_p=None,
    do_sample=None,
    repetition_penalty=None,
    num_return_sequences=None,
    padding_side=None,
    use_slow_tokenizer=False,
    **gen_kwargs,
) -> None:
    """Run generation"""
    import fastseq  #pylint: disable=import-outside-toplevel
    from fastseq.logging import get_logger #pylint: disable=import-outside-toplevel
    global logger
    logger = get_logger(__name__, logging.INFO)
    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)
    if use_causal_lm:
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=not use_slow_tokenizer)
        # Causal LMs such as GPT2 ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=not use_slow_tokenizer)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
    if hasattr(tokenizer, 'model_max_length') and max_tokenizer_length is not None:
        tokenizer.model_max_length = max_tokenizer_length
    if padding_side is not None:
        tokenizer.padding_side = padding_side

    # update config with summarization specific params
    use_task_specific_params(model, task)

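    # Pipeline layout: DataLoader workers tokenize inputs, the main process
    # runs model.generate(), PostProcess workers detokenize, and a single
    # IOProcess writes results to out_file in input order.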
    data_queue = Queue()
    msg_queue = Queue()
    p_list = []

    for _ in range(postprocess_workers):
        p = PostProcess(tokenizer, data_queue, msg_queue,
            skip_special_tokens, clean_up_tokenization_spaces)
        p_list.append(p)
        p.start()

    io_process = IOProcess(msg_queue, fout)
    io_process.start()

    dataset = TokenizeDataset(examples, tokenizer, model_name,
        model.config.prefix, return_tensors, truncation, padding,
        max_tokenizer_length)
    training_generator = torch.utils.data.DataLoader(dataset,
            batch_size=batch_size, num_workers=preprocess_workers,
            drop_last=False)
    try:
        for ind, batch in tqdm(enumerate(training_generator)):
            input_ids, attention_mask = batch
            input_ids = input_ids.view(input_ids.size(0), -1).to(device)
            attention_mask = attention_mask.view(input_ids.size(0), -1).to(device)
            # Drop columns that are padding across the entire batch.
            input_ids, attention_mask = trim_batch(
                input_ids, tokenizer.pad_token_id, attention_mask)
            try:
                summaries = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_start_token_id=decoder_start_token_id,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_length=max_gen_length,
                    max_new_tokens=max_new_tokens,
                    output_scores=output_sequence_scores,
                    return_dict_in_generate=output_sequence_scores,
                    num_beams=num_beams,
                    eos_token_id=eos_token_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    num_return_sequences=num_return_sequences,
                    **gen_kwargs,
                )
            except:
                logger.exception(sys.exc_info()[0])
                for p in p_list:
                    p.terminate()
                io_process.terminate()
                data_queue.close()
                msg_queue.close()
                sys.exit(1)
            if output_sequence_scores:
                sequences = summaries.sequences
            else:
                sequences = summaries
            scores_cpu = None
            if output_sequence_scores:
                # Only beam-based outputs carry per-sequence scores; fall back
                # to NaN for greedy/sampling outputs.
                if type(summaries) in (BeamSearchEncoderDecoderOutput,
                                       BeamSearchDecoderOnlyOutput,
                                       BeamSampleDecoderOnlyOutput,
                                       BeamSampleEncoderDecoderOutput):
                    scores_cpu = summaries.sequences_scores.cpu()
                else:
                    scores_cpu = torch.Tensor([float('nan')] * sequences.shape[0])
            if output_summaries_only:
                # For causal LMs the prompt is echoed back; strip it.
                sequences = sequences[:, input_ids.shape[-1]:]
            sequences_cpu = sequences.cpu()
            if num_return_sequences is not None and num_return_sequences > 1:
                # Group the flat (batch * num_return_sequences) rows by sample.
                sequences_cpu = sequences_cpu.reshape(
                    [-1, num_return_sequences, sequences_cpu.shape[-1]])
                if scores_cpu is not None:
                    scores_cpu = scores_cpu.reshape([-1, num_return_sequences])
            data_queue.put((ind, sequences_cpu, scores_cpu))
    except:
        logger.exception(sys.exc_info()[0])
        for p in p_list:
            p.terminate()
        io_process.terminate()
        data_queue.close()
        msg_queue.close()
        sys.exit(1)

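    # Signal completion: postprocess workers drain data_queue and exit, then
    # the writer process is told to finish.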
    data_queue.put((-1, GENERATE_FINISHED, None))
    for p in p_list:
        p.join()
    msg_queue.put((-1, GENERATE_FINISHED, None))
    io_process.join()
    fout.close()
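

# A hypothetical invocation, assuming a summarization checkpoint and an input
# file with one article per line (all names and values are illustrative):
if __name__ == "__main__":
    with open("articles.txt", encoding="utf-8") as f:
        articles = [line.strip() for line in f]
    generate_summaries_or_translations_fast(
        articles,
        out_file="summaries.txt",
        model_name="facebook/bart-large-cnn",
        batch_size=32,
        fp16=True,
        num_beams=4,
        max_gen_length=142,
    )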