# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Optimization for GPT2 model"""

import torch
import torch.nn as nn

from transformers.modeling_gpt2 import Attention, GPT2Model

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace

logger = get_logger(__name__)


@replace(Attention)
class AttentionV2(Attention):
    """Attention patched for beam search: the key/value of the input prompt is
    cached once per source sequence instead of being duplicated num_beams
    times."""

    def __init__(self, nx, n_ctx, config, scale=False, num_beams=1):
        super().__init__(nx=nx, n_ctx=n_ctx, config=config, scale=scale)
        self.cache_input_key = None
        self.cache_input_value = None
        self.cache_input_len = -1
        self.num_beams = num_beams

    def _attn(self, q, k, v, attention_mask=None, head_mask=None,
              output_attentions=False):
        # Attend the beam-expanded queries over the shared cached input key.
        # NOTE: the second operand is an assumption (the statement is cut off
        # in the source); `self.cache_input_key`, shaped
        # (batch, 1, heads, src_len, head_dim), matches the "bnhsd" subscripts
        # and the cache attributes defined above.
        w1 = torch.einsum(
            "bmhtd,bnhsd->bmhts",
            q.view((q.size(0) // self.num_beams, self.num_beams) + q.shape[1:]),
            self.cache_input_key)
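
# A minimal self-checking sketch (illustrative, not fastseq code) of the einsum
# pattern used in `_attn` above: queries are expanded `num_beams` times, while
# the prompt's key is stored once per source sequence and broadcast over the
# beam dimension. All shapes below are assumptions chosen for the demo.
if __name__ == "__main__":
    batch, num_beams, heads, head_dim = 2, 4, 12, 64
    tgt_len, src_len = 1, 128  # one new decode step over a cached prompt

    q = torch.randn(batch * num_beams, heads, tgt_len, head_dim)
    cached_key = torch.randn(batch, 1, heads, src_len, head_dim)

    # Reshape q to (batch, num_beams, heads, tgt, dim); the singleton "n" axis
    # of the cached key is summed out by einsum, i.e. broadcast across beams.
    w = torch.einsum(
        "bmhtd,bnhsd->bmhts",
        q.view((q.size(0) // num_beams, num_beams) + q.shape[1:]),
        cached_key)
    assert w.shape == (batch, num_beams, heads, tgt_len, src_len)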
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Register models"""

import logging
import sys

from fastseq.logging import get_logger

logger = get_logger(__name__, logging.INFO)

import fastseq.models.prophetnet_fs
import fastseq.models.unilm_hf

try:
    import fastseq.models.modeling_auto_hf
except ImportError:
    logger.warning('transformers could not be imported.')
except:
    logger.error("Unexpected error: {}".format(sys.exc_info()[0]))
    raise
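
# Usage sketch (the model name is a placeholder): importing fastseq triggers
# the registrations above, so the optimized classes are patched in before any
# model is loaded through transformers.
#
#     import fastseq
#     from transformers import AutoModelForSeq2SeqLM
#
#     model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")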
def generate_summaries_or_translations_fast(
    examples: list,
    out_file: str,
    model_name: str,
    batch_size: int = 8,
    device: str = DEFAULT_DEVICE,
    fp16=False,
    task="summarization",
    decoder_start_token_id=None,
    no_repeat_ngram_size=None,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
    preprocess_workers=2,
    postprocess_workers=2,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_tokenizer_length=None,
    max_gen_length=None,
    max_new_tokens=None,
    use_causal_lm=False,
    output_summaries_only=False,
    output_sequence_scores=False,
    num_beams=None,
    eos_token_id=None,
    temperature=None,
    top_k=None,
    top_p=None,
    do_sample=None,
    repetition_penalty=None,
    num_return_sequences=None,
    padding_side=None,
    use_slow_tokenizer=False,
    **gen_kwargs,
) -> None:
    """Run generation and write the results to `out_file`."""
    import fastseq  # pylint: disable=import-outside-toplevel
    from fastseq.logging import get_logger  # pylint: disable=import-outside-toplevel

    global logger
    logger = get_logger(__name__, logging.INFO)

    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)

    if use_causal_lm:
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=not use_slow_tokenizer)
        tokenizer.pad_token = tokenizer.eos_token
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=not use_slow_tokenizer)

    if fp16:
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)
    if hasattr(tokenizer, 'model_max_length') and max_tokenizer_length is not None:
        tokenizer.model_max_length = max_tokenizer_length
    if padding_side is not None:
        tokenizer.padding_side = padding_side

    # Update config with summarization-specific params.
    use_task_specific_params(model, task)

    data_queue = Queue()
    msg_queue = Queue()

    # Post-processing workers decode generated token ids back to text.
    p_list = []
    for _ in range(postprocess_workers):
        p = PostProcess(tokenizer, data_queue, msg_queue,
                        skip_special_tokens, clean_up_tokenization_spaces)
        p_list.append(p)
        p.start()

    # A single writer process serializes decoded results to the output file.
    io_process = IOProcess(msg_queue, fout)
    io_process.start()

    dataset = TokenizeDataset(examples, tokenizer, model_name,
                              model.config.prefix, return_tensors, truncation,
                              padding, max_tokenizer_length)
    # DataLoader used purely for parallel tokenization; no training happens.
    training_generator = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=preprocess_workers,
        drop_last=False)

    try:
        for ind, batch in tqdm(enumerate(training_generator)):
            input_ids, attention_mask = batch
            input_ids = input_ids.view(input_ids.size(0), -1).to(device)
            attention_mask = attention_mask.view(input_ids.size(0), -1).to(device)
            # Drop columns that are all padding to shrink the batch.
            input_ids, attention_mask = trim_batch(
                input_ids, tokenizer.pad_token_id, attention_mask)
            try:
                summaries = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_start_token_id=decoder_start_token_id,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_length=max_gen_length,
                    max_new_tokens=max_new_tokens,
                    output_scores=output_sequence_scores,
                    return_dict_in_generate=output_sequence_scores,
                    num_beams=num_beams,
                    eos_token_id=eos_token_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    num_return_sequences=num_return_sequences,
                    **gen_kwargs,
                )
            except:
                logger.exception(sys.exc_info()[0])
                for p in p_list:
                    p.terminate()
                io_process.terminate()
                data_queue.close()
                msg_queue.close()
                sys.exit(1)

            if output_sequence_scores:
                sequences = summaries.sequences
            else:
                sequences = summaries

            scores_cpu = None
            if output_sequence_scores:
                if type(summaries) in [
                        BeamSearchEncoderDecoderOutput,
                        BeamSearchDecoderOnlyOutput,
                        BeamSampleDecoderOnlyOutput,
                        BeamSampleEncoderDecoderOutput]:
                    scores_cpu = summaries.sequences_scores.cpu()
                else:
                    # Greedy/sampling outputs carry no sequence scores.
                    scores_cpu = torch.Tensor([float('nan')] * sequences.shape[0])

            if output_summaries_only:
                # For causal LMs, strip the prompt tokens from the output.
                sequences = sequences[:, input_ids.shape[-1]:]

            sequences_cpu = sequences.cpu()
            if num_return_sequences is not None and num_return_sequences > 1:
                sequences_cpu = sequences_cpu.reshape(
                    [-1, num_return_sequences, sequences_cpu.shape[-1]])
                if scores_cpu is not None:
                    scores_cpu = scores_cpu.reshape([-1, num_return_sequences])
            data_queue.put((ind, sequences_cpu, scores_cpu))
    except:
        logger.exception(sys.exc_info()[0])
        for p in p_list:
            p.terminate()
        io_process.terminate()
        data_queue.close()
        msg_queue.close()
        sys.exit(1)

    # Signal the worker and writer processes to drain their queues and exit.
    data_queue.put((-1, GENERATE_FINISHED, None))
    for p in p_list:
        p.join()
    msg_queue.put((-1, GENERATE_FINISHED, None))
    io_process.join()
    fout.close()
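
# A hypothetical invocation sketch; the file paths, model name, and parameter
# values below are placeholders, not defaults taken from fastseq.
if __name__ == "__main__":
    articles = [line.strip() for line in
                Path("input.source").open(encoding="utf-8")]
    generate_summaries_or_translations_fast(
        examples=articles,
        out_file="output.hypo",
        model_name="facebook/bart-large-cnn",
        batch_size=32,
        fp16=True,
        num_beams=4,
        max_gen_length=140,
    )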