def loadAbstractSummarizer():
    """Load the abstractive summarization model together with its tokenizer.

    The checkpoint name is read from the ``abstract_summarizer_model_name``
    setting. Names containing ``"bart"`` are loaded with the BART classes;
    every other name is loaded with the LED classes (the LED model is created
    with ``return_dict_in_generate=True``).

    When ``./models/<model_name>/`` already exists, both objects are loaded
    from that directory. Otherwise they are loaded by name through
    ``from_pretrained`` and then saved into that directory.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    from transformers import BartForConditionalGeneration, BartTokenizer
    from transformers import LEDForConditionalGeneration, LEDTokenizer

    model_name = get_item("abstract_summarizer_model_name")
    model_directory = f'./models/{model_name}/'
    has_local_copy = os.path.exists(model_directory)
    # Single source of truth for where the weights come from; the class
    # selection below no longer has to be repeated per source.
    source = model_directory if has_local_copy else model_name

    if "bart" in model_name:
        model = BartForConditionalGeneration.from_pretrained(source)
        tokenizer = BartTokenizer.from_pretrained(source)
    else:
        model = LEDForConditionalGeneration.from_pretrained(
            source, return_dict_in_generate=True)
        tokenizer = LEDTokenizer.from_pretrained(source)

    if not has_local_copy:
        # Persist freshly loaded weights so the next call finds them locally.
        model.save_pretrained(model_directory)
        tokenizer.save_pretrained(model_directory)
    return model, tokenizer
def get_model_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, with the model on GPU if available.

    The classes are chosen by substring match on ``model_name``:
    ``"pegasus"`` selects the Pegasus classes, ``"bart-large"`` or
    ``"bart-custom-large"`` the BART classes, and anything else falls back to
    ``AutoTokenizer`` / ``AutoModelWithLMHead``.

    Args:
        model_name: Name or path handed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``; the model has been moved to ``'cuda'``
        when ``torch.cuda.is_available()`` and to ``'cpu'`` otherwise.
    """
    import torch

    if torch.cuda.is_available():
        torch_device = 'cuda'
    else:
        torch_device = 'cpu'

    if "pegasus" in model_name:
        from transformers import PegasusForConditionalGeneration, PegasusTokenizer
        tokenizer_cls = PegasusTokenizer
        model_cls = PegasusForConditionalGeneration
    elif "bart-large" in model_name or "bart-custom-large" in model_name:
        # Stock and custom BART checkpoints are loaded the same way.
        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
        tokenizer_cls = BartTokenizer
        model_cls = BartForConditionalGeneration
    else:
        # Everything else (the original comment names T5 and distilbart).
        from transformers import AutoTokenizer, AutoModelWithLMHead
        tokenizer_cls = AutoTokenizer
        model_cls = AutoModelWithLMHead

    tokenizer = tokenizer_cls.from_pretrained(model_name)
    model = model_cls.from_pretrained(model_name).to(torch_device)
    return model, tokenizer
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel,
                         bart_name, num_tokens):
    """Initialise added special tokens with the mean of their sub-token embeddings.

    For every entry of ``tokenizer.unique_no_split_tokens`` that starts with
    ``<<``, the text between the two-character delimiters (``token[2:-2]``)
    is tokenized with a fresh tokenizer loaded from ``bart_name``. The average
    of those sub-token embeddings is written into the decoder embedding row
    of the added token.

    Args:
        tokenizer (BartTokenizer): Tokenizer that already contains the added
            tokens.
        bart_model (BartModel): Model whose ``encoder.embed_tokens`` and
            ``decoder.embed_tokens`` weights are read and updated in place.
        bart_name: Name or path passed to ``BartTokenizer.from_pretrained``
            to obtain the tokenizer used for splitting the token text.
        num_tokens: Smallest id an added token may have; smaller ids trip the
            assertion.

    Raises:
        RuntimeError: If an added token is split into more than one piece by
            ``tokenizer``.
        AssertionError: If an added token's id is below ``num_tokens``.

    Returns:
        BartModel: The same ``bart_model`` object that was passed in.
    """
    _tokenizer = BartTokenizer.from_pretrained(bart_name)
    encoder_weights = bart_model.encoder.embed_tokens.weight.data
    decoder_weights = bart_model.decoder.embed_tokens.weight.data

    for token in tokenizer.unique_no_split_tokens:
        if token[:2] != '<<':  # only the specially marked tokens are handled
            continue

        index = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
        if len(index) > 1:
            raise RuntimeError(f"{token} wrong split")
        index = index[0]
        assert index >= num_tokens, (index, num_tokens, token)

        indexes = _tokenizer.convert_tokens_to_ids(
            _tokenizer.tokenize(token[2:-2]))
        # Indexing a weight tensor with an int returns a view, so the in-place
        # ``+=`` / ``/=`` below would overwrite the first sub-token's own
        # embedding row. Clone first so only the new token's row is modified.
        embed = encoder_weights[indexes[0]].clone()
        # NOTE(review): first piece is read from the encoder table and the
        # rest from the decoder table, as in the original; presumably the two
        # tables share weights - confirm.
        for i in indexes[1:]:
            embed += decoder_weights[i]
        embed /= len(indexes)
        decoder_weights[index] = embed
    return bart_model
def __init__(self,
             device=None,
             checkpoint=None,
             state_dict_key='model',
             pretrained="facebook/bart-large-cnn",
             hg_transformers=True):
    """Load a BART summarization model (and tokenizer on the huggingface path).

    Args:
        device: Device the model is moved to. When falsy, CUDA is used if
            ``torch.cuda.is_available()`` and the CPU otherwise.
        checkpoint: Optional path to a checkpoint file. If the string
            contains ``"semsim"``, the file ``bart_semsim.pt`` in the user
            cache directory is used instead and downloaded first when it is
            not present.
        state_dict_key: Key of the model state dict inside the loaded
            checkpoint.
        pretrained: Name of the pretrained model. Dots are replaced by dashes
            for huggingface; for fairseq only the part after the last ``/``
            is kept and dashes are replaced by dots.
        hg_transformers: Load through huggingface ``transformers`` when true,
            through ``torch.hub`` (``pytorch/fairseq``) otherwise.

    Raises:
        Exception: If ``checkpoint`` is given while ``hg_transformers`` is
            false.
    """
    if not hg_transformers and checkpoint:
        raise Exception(
            "hg_transformers must be set to True in order to load from checkpoint"
        )

    if not device:
        device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    # huggingface uses dashes and fairseq/torchhub uses dots (periods)
    if pretrained:
        if hg_transformers:
            pretrained = pretrained.replace(".", "-")
        else:
            # only use the part after the "/"
            pretrained = pretrained.split("/")[-1].replace("-", ".")

    if checkpoint is not None and "semsim" in checkpoint:
        cache_dir = appdirs.user_cache_dir("DocSum", "HHousen")
        output_file_path = os.path.join(cache_dir, "bart_semsim.pt")
        if not os.path.isfile(output_file_path):
            # exist_ok avoids the check-then-create race of the old code.
            os.makedirs(cache_dir, exist_ok=True)
            gdown.download(
                "https://drive.google.com/uc?id=1CNgK6ZkaqUD239h_6GkLmfUOGgryc2v9",
                output_file_path)
        checkpoint = output_file_path

    if checkpoint:
        # Map to CPU so a checkpoint saved on a GPU also loads on CPU-only
        # hosts; the model is moved to ``device`` further down anyway.
        loaded_checkpoint = torch.load(checkpoint, map_location="cpu")
        model_state_dict = loaded_checkpoint[state_dict_key]

        bart = BartForConditionalGeneration.from_pretrained(
            pretrained, state_dict=model_state_dict)
        # The tokenizer has no weights, so the model state dict is not
        # forwarded to it.
        tokenizer = BartTokenizer.from_pretrained(pretrained)
        self.tokenizer = tokenizer
    else:
        if hg_transformers:
            bart = BartForConditionalGeneration.from_pretrained(pretrained)
            tokenizer = BartTokenizer.from_pretrained(pretrained)
            self.tokenizer = tokenizer
        else:
            # No separate tokenizer attribute is set on the fairseq path.
            bart = torch.hub.load('pytorch/fairseq', pretrained)

    bart.to(device)
    bart.eval()
    # NOTE(review): half precision is applied on every device, including the
    # CPU - confirm this is intended.
    bart.half()

    self.logger = logging.getLogger(__name__)
    self.hg_transformers = hg_transformers
    self.bart = bart
def get_dataloaders(dataset, batch_size, num_workers, dynamic_shape,
                    max_src_length, max_tgt_length, shuffle):
    """Build the ``'train'`` and ``'dev'`` data loaders.

    Each batch is tokenized with the ``facebook/bart-large`` tokenizer and
    returned as a dict with the keys ``src_tokens``, ``src_lengths`` and
    ``prev_output_tokens``.

    Args:
        dataset: Mapping indexed by split name; every example provides the
            keys ``'src'`` and ``'tgt'``.
        batch_size: Batch size for both loaders.
        num_workers: Worker count for both loaders.
        dynamic_shape: Pad to the longest item of the batch when true,
            to the maximum length otherwise.
        max_src_length: Maximum source length given to the tokenizer.
        max_tgt_length: Maximum target length given to the tokenizer.
        shuffle: Whether the ``'train'`` split is shuffled (``'dev'`` never is).

    Returns:
        dict: ``{'train': DataLoader, 'dev': DataLoader}``; incomplete final
        batches are dropped.
    """
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    padding_mode = 'longest' if dynamic_shape else 'max_length'

    def _collate_fn(raw_batch):
        sources = [example['src'] for example in raw_batch]
        targets = [example['tgt'] for example in raw_batch]
        encoded = tokenizer.prepare_seq2seq_batch(
            src_texts=sources,
            tgt_texts=targets,
            max_length=max_src_length,
            max_target_length=max_tgt_length,
            padding=padding_mode,
            return_tensors='pt')

        # Number of attended positions per row of the attention mask.
        lengths = torch.sum(encoded['attention_mask'], dim=1)
        return {
            'src_tokens': encoded['input_ids'],
            'src_lengths': lengths,
            'prev_output_tokens': encoded['labels']
        }

    loaders = {}
    for split in ['train', 'dev']:
        loaders[split] = DataLoader(dataset=dataset[split],
                                    batch_size=batch_size,
                                    collate_fn=_collate_fn,
                                    shuffle=(split == 'train' and shuffle),
                                    drop_last=True,
                                    num_workers=num_workers)
    return loaders
def __init__(self,
             denumericaliser='BART',
             fields=None,
             debug=True,
             skip_special_tokens=True,
             **kwargs):
    """Configure the callable that turns token ids back into text.

    Args:
        denumericaliser: ``'BART'``, ``'BERT'`` or ``'Code32k'`` to use the
            ``decode`` method of the corresponding tokenizer; any other value
            is stored as-is and used as the decoding callable.
        fields: List of ``(source, target)`` field-name pairs. Defaults to
            ``[("input_ids", "input_text")]``; a fresh list is created per
            instance so instances never share (and mutate) one default list.
        debug: When true, print the decoding of the ids ``0..9``.
        skip_special_tokens: Stored on the instance.
        **kwargs: Accepted and ignored.
    """
    if fields is None:
        fields = [("input_ids", "input_text")]

    if denumericaliser == 'BART':
        self.denumericaliser = BartTokenizer.from_pretrained(
            'facebook/bart-large').decode
    elif denumericaliser == 'BERT':
        self.denumericaliser = BertTokenizer.from_pretrained(
            'bert-base-uncased').decode
    elif denumericaliser == 'Code32k':
        tokenizer_path = "datasets/code_search_net/codeBPE.tokenizer.json"
        # Fetch the tokenizer definition on first use.
        if not os.path.isfile(tokenizer_path):
            download_from_url(
                "https://storage.googleapis.com/carlos-phd-data/code-search-net-tokenizer/codeBPE.tokenizer.json",
                tokenizer_path)
        code_BPE_tokenizer = Tokenizer.from_file(tokenizer_path)
        self.denumericaliser = code_BPE_tokenizer.decode
    else:
        self.denumericaliser = denumericaliser

    if debug:
        print(
            f"Denumericaliser. Ex: [0,1,2,3,4,5,6,7,8,9] -> {self.denumericaliser([0,1,2,3,4,5,6,7,8,9])}"
        )
    self.fields = fields
    self.skip_special_tokens = skip_special_tokens
def __init__(self, hparams, user_tokens=['<newline>', '<bullet>', '<sep>']):
    """Create the tokenizer, config and BART model described by ``hparams``.

    Args:
        hparams: Hyper-parameter namespace. Reads ``model_type`` (lower-cased
            in place), ``tokenizer_name``, ``config_name``,
            ``model_name_or_path``, ``do_lower_case``, ``cache_dir`` and
            ``do_test``.
        user_tokens: Not used anywhere in this method.
    """
    super(BartSystem, self).__init__()
    self.hparams = hparams
    self.hparams.model_type = self.hparams.model_type.lower()

    hp = self.hparams
    cache_dir = hp.cache_dir or None
    tokenizer_source = hp.tokenizer_name or hp.model_name_or_path
    config_source = hp.config_name or hp.model_name_or_path

    tokenizer = BartTokenizer.from_pretrained(
        tokenizer_source,
        do_lower_case=hp.do_lower_case,
        cache_dir=cache_dir,
    )

    # The config's vocabulary size follows the tokenizer's length.
    config = AutoConfig.from_pretrained(config_source,
                                        cache_dir=cache_dir,
                                        output_past=hp.do_test,
                                        vocab_size=len(tokenizer))

    model = BartForConditionalGeneration.from_pretrained(
        hp.model_name_or_path,
        from_tf=bool(".ckpt" in hp.model_name_or_path),
        config=config,
        cache_dir=cache_dir,
    )

    self.config = config
    self.tokenizer = tokenizer
    self.model = model
    self.loss = []  # for keeping track of average loss
    self.metrics = {}

    # Inverted vocabulary: id -> token string.
    self.vocab = {
        token_id: token
        for token, token_id in self.tokenizer.get_vocab().items()
    }
def test_xsum_summarization_same_as_fairseq(self):
    """Check the ``facebook/bart-large-xsum`` summary of a fixed article against a fixed reference string."""
    PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
    EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum")
    model = model.to(torch_device)
    self.assertFalse(model.config.is_valid_mbart())
    tok = BartTokenizer.from_pretrained("facebook/bart-large")

    encoding = tok.batch_encode_plus(
        [PGE_ARTICLE],
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    dct = encoding.to(torch_device)

    generation_kwargs = dict(
        num_beams=2,
        max_length=62,
        min_length=11,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True,
        # decoding starts from the model's EOS token id
        decoder_start_token_id=model.config.eos_token_id,
    )
    hypotheses_batch = model.generate(
        input_ids=dct["input_ids"],
        attention_mask=dct["attention_mask"],
        **generation_kwargs,
    )

    decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True)
    self.assertEqual(EXPECTED_SUMMARY, decoded[0])
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
    """Summarise ``lns`` in batches and write one summary per line to ``out_file``.

    Args:
        lns: Sequence of input texts.
        out_file: Path of the output file; it is opened in ``"w"`` mode.
        batch_size: Number of texts per generation call.
        device: Device the model and the encoded inputs are moved to.
    """
    # The context manager guarantees the output file is closed (and flushed)
    # even when loading or generation raises; the original never closed it.
    with Path(out_file).open("w") as fout:
        model = BartForConditionalGeneration.from_pretrained(
            "bart-large-cnn",
            output_past=True,
        ).to(device)
        tokenizer = BartTokenizer.from_pretrained("bart-large")

        for batch in tqdm(list(chunks(lns, batch_size))):
            dct = tokenizer.batch_encode_plus(batch,
                                              max_length=1024,
                                              return_tensors="pt",
                                              pad_to_max_length=True)
            summaries = model.generate(
                input_ids=dct["input_ids"].to(device),
                attention_mask=dct["attention_mask"].to(device),
                num_beams=4,
                length_penalty=2.0,
                max_length=142,  # +2 from original because we start at step=1 and stop before max_length
                min_length=56,  # +1 from original because we start at step=1
                no_repeat_ngram_size=3,
                early_stopping=True,
                do_sample=False,
            )
            dec = [
                tokenizer.decode(g,
                                 skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False)
                for g in summaries
            ]
            for hypothesis in dec:
                fout.write(hypothesis + "\n")
                # Flush per line so partial results survive an interruption.
                fout.flush()
def __init__(self, data_path, mapping, bart_name, learn_weights) -> None:
    """Store the configuration, load the tokenizer and register the tag tokens.

    Args:
        data_path: Stored on the instance as ``data_path``.
        mapping: Records how each original tag maps to its converted tag
            string.
        bart_name: Name or path passed to ``BartTokenizer.from_pretrained``.
        learn_weights: Stored on the instance as ``learn_weights``.
    """
    tokenizer = BartTokenizer.from_pretrained(bart_name)

    self.data_path = data_path
    self.tokenizer = tokenizer
    self.mapping = mapping
    # Vocabulary size captured before ``_add_tags_to_tokens`` runs.
    self.original_token_nums = tokenizer.vocab_size
    self.learn_weights = learn_weights

    self._add_tags_to_tokens()
def main():
    """Run the same masked sentences through the lightseq and huggingface BART models.

    With ``--user_input`` the sentence is read from standard input in an
    endless loop; without it a fixed set of example sentences is processed
    once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--user_input', action="store_true")
    args = parser.parse_args()

    print("initializing bart tokenizer...")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    print("creating lightseq model...")
    ls_model = lightseq.Transformer("lightseq_bart_base.pb", 128)
    print("creating huggingface model...")
    hf_model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-base")

    example_sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk."
    ]

    interactive = args.user_input
    keep_running = True
    while keep_running:
        if interactive:
            sentences = [input("input the masked sentence:\n")]
        else:
            sentences = example_sentences

        print("tokenizing the sentences...")
        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        inputs_id = inputs["input_ids"]

        ls_generate(ls_model, tokenizer, inputs_id)
        hf_generate(hf_model, tokenizer, inputs_id)

        # The fixed examples are processed exactly once.
        keep_running = bool(interactive)
def simple_term_counts(data_dir='data/data-simplification/wikismall'):
    """Fit a logistic regression on a batch from ``data_dir`` and print its test accuracy.

    Features and labels come from ``construct_dataset``; 20% of each batch is
    held out for testing and both parts are passed through ``normalize``.

    Args:
        data_dir: Directory passed to ``dataloader``.
    """
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
    model = LogisticRegression(max_iter=200)

    for batch in dataloader(data_dir):
        features, labels = construct_dataset(batch, tokenizer)
        X_train, X_test, y_train, y_test = train_test_split(features,
                                                            labels,
                                                            test_size=0.2)

        # apply feature scaling
        X_train = normalize(X_train)
        X_test = normalize(X_test)

        model.fit(X_train, y_train)
        accuracy = accuracy_score(y_test, model.predict(X_test))
        print(accuracy)

        # NOTE(review): unconditional return - everything below is
        # unreachable. The flattened source does not show whether this return
        # sits inside the loop (as assumed here) or right after it; confirm
        # the intent and either drop the return or the dead code.
        return

    vocab = get_vocab(tokenizer)
    weights = np.squeeze(model.coef_, axis=0).tolist()

    # Keep tokens that are not pure whitespace, ordered by ascending weight.
    sorted_weights = [
        row for row in zip(range(tokenizer.vocab_size), vocab, weights)
        if len(row[1].strip()) > 0
    ]
    sorted_weights.sort(key=lambda row: row[2])

    with open('data/logr_weights/bart_freq_normalized_ids.txt', 'w') as f:
        for ID, word, weight in sorted_weights:
            f.write(f'{ID} {weight}\n')

    with open('data/logr_weights/bart_freq_normalized_tokens.txt', 'w') as f:
        for ID, word, weight in sorted_weights:
            f.write(f'{word} {weight}\n')
def main():
    """Generate a summary for every document of a JSON-lines file.

    Command-line arguments:
        --input_file: JSON-lines file; every record provides ``id`` and
            ``document``.
        --output_file: Destination file; one JSON object per line with the
            keys ``id``, ``generated_summary`` and
            ``generated_summary_score``.
        --decoder: One of ``greedy``, ``beam_search``, ``random``, ``top_k``
            or ``nucleus``; forwarded to ``generate_summary``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file")
    parser.add_argument("--output_file")
    parser.add_argument(
        "--decoder",
        choices=['greedy', 'beam_search', 'random', 'top_k', 'nucleus'])
    args = parser.parse_args()

    model_name = 'sshleifer/distilbart-xsum-1-1'
    model = BartForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = BartTokenizer.from_pretrained(model_name)

    # Iterate through input file documents, generating summaries.
    # The reader is used as a context manager so the input file is closed
    # once all documents are read (the original left it open).
    outputs = []
    with jsonlines.open(args.input_file) as reader:
        for line in tqdm.tqdm(reader):
            summary, summary_score = generate_summary(
                model=model,
                tokenizer=tokenizer,
                document=line['document'],
                decoder=args.decoder)
            outputs.append({
                'id': line['id'],
                'generated_summary': summary,
                'generated_summary_score': summary_score
            })

    # Write out the generated summaries to file
    with open(args.output_file, 'w', encoding='utf-8') as f:
        for record in outputs:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
def get_summary(text, model, tokenizer, torch_device):
    """Return a very short summary of ``text`` produced by ``bart-large-cnn``.

    Args:
        text: Text to summarise.
        model: Not used; the summarization model is loaded inside the function.
        tokenizer: Not used; the summarization tokenizer is loaded inside the
            function.
        torch_device: Device for the model and the encoded input.

    Returns:
        str: Decoded summary (generation is capped at ``max_length=5``).
    """
    checkpoint = "bart-large-cnn"
    tokenizer_summarize = BartTokenizer.from_pretrained(checkpoint)
    model_summarize = BartForConditionalGeneration.from_pretrained(checkpoint)
    model_summarize = model_summarize.to(torch_device)
    # Set the model in evaluation mode to deactivate the DropOut modules
    model_summarize.eval()

    encoded = tokenizer_summarize.batch_encode_plus([text],
                                                    return_tensors="pt",
                                                    max_length=1024)
    answers_input_ids = encoded["input_ids"].to(torch_device)

    summary_ids = model_summarize.generate(answers_input_ids,
                                           num_beams=4,
                                           max_length=5,
                                           early_stopping=True)

    # A single text was encoded, so the batch dimension is squeezed away.
    summary = tokenizer_summarize.decode(
        summary_ids.squeeze(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return summary
def load(self, save_dir):
    """Restore the tokenizer from ``save_dir`` and rebuild the decoder vocabulary.

    Args:
        save_dir: Directory (or name) passed to ``BartTokenizer.from_pretrained``.
    """
    tokenizer = BartTokenizer.from_pretrained(save_dir)
    self._tokenizer = tokenizer

    # The vocabulary is built from the values of the tokenizer's ``decoder``
    # mapping and reuses the tokenizer's pad and EOS tokens.
    self.decoder_vocab = DecoderVocabulary(tokenizer.decoder.values(),
                                           None,
                                           pad_token=tokenizer.pad_token,
                                           eos_token=tokenizer.eos_token)
def load_model_tokenizer(self, pretrained):
    """
    Load the transformer model and tokenizer for a pre-trained name.

    The classes are selected by ``self.method`` (``"T5"``, ``"BART"``,
    ``"GPT-2"`` or ``"XLM"``). For GPT-2 and XLM the model config's
    ``max_length`` is set to ``self.max_length``.

    :param pretrained: pre-trained name
    :return: ``(model, tokenizer)``; ``(None, None)`` when the method is
        unknown or ``pretrained`` is not in that method's list of
        pre-trained models
    """
    nothing = (None, None)
    method = self.method

    if method == "T5":
        if pretrained not in T5_PRETRAINED_MODELS:
            return nothing
        model = T5ForConditionalGeneration.from_pretrained(pretrained)
        tokenizer = T5Tokenizer.from_pretrained(pretrained)
        return model, tokenizer

    if method == "BART":
        if pretrained not in BART_PRETRAINED_MODELS:
            return nothing
        model = BartForConditionalGeneration.from_pretrained(pretrained)
        tokenizer = BartTokenizer.from_pretrained(pretrained)
        return model, tokenizer

    if method == "GPT-2":
        if pretrained not in GPT2_PRETRAINED_MODELS:
            return nothing
        model = GPT2LMHeadModel.from_pretrained(pretrained)
        model.config.max_length = self.max_length
        tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
        return model, tokenizer

    if method == "XLM":
        if pretrained not in XLM_PRETRAINED_MODELS:
            return nothing
        model = XLMWithLMHeadModel.from_pretrained(pretrained)
        model.config.max_length = self.max_length
        tokenizer = XLMTokenizer.from_pretrained(pretrained)
        return model, tokenizer

    # Unsupported method: nothing is loaded.
    return nothing
def __init__(
        self,
        chkpt_path="/Users/byronwallace/code/RoboSum/weights/pl_title_/pl_title_2048.ckpt"
):
    """Build the ``facebook/bart-large-cnn`` model and load fine-tuned weights into it.

    Args:
        chkpt_path: Path of the checkpoint file; its ``'state_dict'`` entry
            is loaded into the model after the first dotted component of
            every key has been removed.
    """
    base_name = 'facebook/bart-large-cnn'
    self.model = BartForConditionalGeneration.from_pretrained(base_name)
    self.config = BartConfig.from_pretrained(base_name)
    self.tokenizer = BartTokenizer.from_pretrained(base_name)

    # Per the original author's notes: grow the position embeddings from 1024
    # to 2048, then add the special tokens that mark title and abstract
    # (where "abstract" is either the real abstract or text extracted from
    # the same document, i.e. punchlines).
    self.add_position_embeddings()
    self.add_special_tokens()

    # now load the checkpoint
    print("loading checkpoint", chkpt_path)
    checkpoint = torch.load(chkpt_path, map_location="cpu")
    print("done")

    # Strip the leading "<prefix>." from every parameter name.
    stripped_state = {
        key.partition('.')[2]: value
        for key, value in checkpoint['state_dict'].items()
    }
    self.model.load_state_dict(stripped_state)
def load_bart_fever_rte_model(model_name, data_dir):
    """Load the fine-tuned BART sequence classifier and its tokenizer.

    Both are read from ``<data_dir>/FineTuneOn<model_name>``. The number of
    labels is taken from ``RteProcessor().get_labels()``.

    Args:
        model_name: Suffix of the fine-tuned model directory name.
        data_dir: Directory that contains the fine-tuned model directory.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        ValueError: If the (fixed) task name is not a known processor.
    """
    task_name = 'rte'
    processors = {"rte": RteProcessor}
    output_modes = {"rte": "classification"}

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    pretrain_model_dir = '{}/FineTuneOn{}'.format(data_dir, model_name)

    # Prepare model
    model = BartForSequenceClassification.from_pretrained(
        pretrain_model_dir, num_labels=num_labels)
    tokenizer = BartTokenizer.from_pretrained(pretrain_model_dir)
    return model, tokenizer
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
    """Summarise ``lns`` in batches and write one summary per line to ``out_file``.

    Args:
        lns: Sequence of input texts.
        out_file: Path of the output file; it is opened in ``"w"`` mode.
        batch_size: Number of texts per generation call.
        device: Device the model and the encoded inputs are moved to.
    """
    # The context manager guarantees the output file is closed even when
    # loading or generation raises; the original never closed it.
    with Path(out_file).open("w") as fout:
        # The inputs are moved to ``device`` below, so the model has to live
        # there as well; otherwise generation fails on any non-CPU device.
        model = BartForMaskedLM.from_pretrained(
            "bart-large-cnn",
            output_past=True,
        ).to(device)
        tokenizer = BartTokenizer.from_pretrained("bart-large")

        for batch in tqdm(list(chunks(lns, batch_size))):
            dct = tokenizer.batch_encode_plus(batch,
                                              max_length=1024,
                                              return_tensors="pt",
                                              pad_to_max_length=True)
            # NOTE(review): ``min_len`` is kept as written; it is presumably
            # the keyword of the ``generate`` method this class exposes in
            # the installed library version - confirm before renaming it to
            # ``min_length``.
            summaries = model.generate(
                input_ids=dct["input_ids"].to(device),
                attention_mask=dct["attention_mask"].to(device),
                num_beams=4,
                length_penalty=2.0,
                max_length=140,
                min_len=55,
                no_repeat_ngram_size=3,
            )
            dec = [
                tokenizer.decode(g,
                                 skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False)
                for g in summaries
            ]
            for hypothesis in dec:
                fout.write(hypothesis + "\n")
                # Flush per line so partial results survive an interruption.
                fout.flush()
def __init__(self):
    """Load the ``facebook/bart-large-mnli`` classifier (moved to ``DEVICE``) and its tokenizer."""
    checkpoint = 'facebook/bart-large-mnli'

    nli_model = BartForSequenceClassification.from_pretrained(checkpoint)
    self.nli_model = nli_model.to(DEVICE)
    self.tokenizer = BartTokenizer.from_pretrained(checkpoint)
def setUpClass(cls):
    """Load the shared ``bart-large-cnn`` model/tokenizer and the test fixtures.

    Sets ``cls.model``, ``cls.tokenizer``, ``cls.decoding_hyperparams`` and
    two news-article strings used by the tests.
    """
    # Summarization with beam search. The original author's note: for BART
    # summarization in the transformers repo, beam search performs much better
    # than no beam search, and even beam search with num_beams=1 is better,
    # implying that something is broken in the _generate_no_beam_search
    # function. See ``examples/summarization/bart/evaluate_cnn.py`` for a
    # longer example.
    checkpoint = 'bart-large-cnn'
    cls.model = BartForConditionalGeneration.from_pretrained(checkpoint)
    cls.tokenizer = BartTokenizer.from_pretrained(checkpoint)
    cls.decoding_hyperparams = dict(max_length=40, num_beams=3)

    cls.test_news_article_1 = (
        'New Zealand says it has stopped community transmission of Covid-19, '
        'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday '
        '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned '
        'against complacency, saying it does not mean a total end to new coronavirus cases. '
        'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. '
        'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. '
        'Most people will still be required to remain at home at all times and avoid all social interactions.'
    )

    # Second fixture: the text is kept exactly as in the original, including
    # its sentence order and spacing.
    cls.test_news_article_2 = (
        'But officials have warned against complacency, saying it does not mean a total end to new HIV cases. '
        'Most people will still be required to remain at home at all times and avoid all social interactions.'
        'Germany says it has stopped community transmission of HIV, '
        'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday '
        '- Prime Minister Angela Merkle said the virus was "currently" eliminated. '
        'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. '
        'The news comes hours before Germany is set to move out of its toughest level of social restrictions. '
    )
def __init__(
    self,
    model_name_or_path,
    tokenizer_name,
    model_cache_dir,
    input_max_length,
    target_max_length,
    summary_column_name,
    document_column_name,
    wandb_project,
    wandb_run_name,
    **kwargs,
):
    """Initialise the base class, then load the BART tokenizer and model.

    Args:
        model_name_or_path: Name or path of the model; also used for the
            tokenizer when ``tokenizer_name`` is falsy.
        tokenizer_name: Optional name or path of the tokenizer.
        model_cache_dir: Cache directory passed to both ``from_pretrained``
            calls.
        input_max_length: Forwarded to the base class.
        target_max_length: Forwarded to the base class.
        summary_column_name: Forwarded to the base class.
        document_column_name: Forwarded to the base class.
        wandb_project: Forwarded to the base class.
        wandb_run_name: Forwarded to the base class.
        **kwargs: Accepted and ignored.
    """
    super().__init__(
        input_max_length,
        target_max_length,
        summary_column_name,
        document_column_name,
        wandb_project,
        wandb_run_name,
    )

    tokenizer_source = tokenizer_name or model_name_or_path
    self.tokenizer = BartTokenizer.from_pretrained(
        tokenizer_source,
        cache_dir=model_cache_dir,
    )
    self.model = BartForConditionalGeneration.from_pretrained(
        model_name_or_path,
        cache_dir=model_cache_dir,
    )
def test_xsum_summarization_same_as_fairseq(self):
    """Check the ``facebook/bart-large-xsum`` summary of ``PGE_ARTICLE`` against a fixed reference string."""
    EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."

    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large-xsum")
    model = model.to(torch_device)
    self.assertFalse(model.config.is_valid_mbart())
    tok = BartTokenizer.from_pretrained("facebook/bart-large")

    encoding = tok.batch_encode_plus(
        [PGE_ARTICLE],
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    dct = encoding.to(torch_device)

    generation_kwargs = dict(
        num_beams=2,
        max_length=62,
        min_length=11,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True,
        # decoding starts from the model's EOS token id
        decoder_start_token_id=model.config.eos_token_id,
    )
    hypotheses_batch = model.generate(
        input_ids=dct["input_ids"],
        attention_mask=dct["attention_mask"],
        **generation_kwargs,
    )

    decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True)
    self.assertEqual(EXPECTED_SUMMARY, decoded[0])
def main(args):
    """Print the decoded masked/original pairs of every training batch.

    Args:
        args: Namespace providing ``tokenizer_path``, ``corpus`` and
            ``mask_path``; the latter two are resolved relative to
            ``corpus/comment`` under the current directory.
    """
    tokenizer = BartTokenizer.from_pretrained(args.tokenizer_path)

    comment_dir = Path() / "corpus" / "comment"
    data_module = BartDataModule(source_path=comment_dir / args.corpus,
                                 mask_path=comment_dir / args.mask_path,
                                 tokenizer=tokenizer,
                                 batch_size=3,
                                 num_workers=1)

    for masked_encode, attention_mask, encode in data_module.train_dataloader():
        masked_ids = masked_encode.detach().cpu().numpy()
        original_ids = encode.detach().cpu().numpy()

        # Show each masked sequence followed by its unmasked counterpart.
        for masked_row, original_row in zip(masked_ids, original_ids):
            print(tokenizer.decode(masked_row))
            print(tokenizer.decode(original_row))
        print(attention_mask)
def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    """Generate a short continuation for every example and write ``input ||| output`` lines.

    Args:
        examples: Input texts.
        out_file: Path of the output file; it is opened in ``"w"`` mode.
        model_name: Currently unused; the checkpoint is fixed to
            ``facebook/bart-large-cnn``.
        batch_size: Number of texts per generation call.
        device: Device the model and the encoded inputs are moved to.
    """
    # Model and tokenizer now share one identifier; the original loaded the
    # tokenizer through the bare legacy name ``bart-large-cnn``.
    checkpoint = 'facebook/bart-large-cnn'

    # The context manager guarantees the output file is closed even when
    # loading or generation raises; the original never closed it.
    with Path(out_file).open("w") as fout:
        model = BartForConditionalGeneration.from_pretrained(checkpoint, output_past=True).to(device)
        tokenizer = BartTokenizer.from_pretrained(checkpoint)

        for batch in tqdm(list(chunks(examples, batch_size))):
            dct = tokenizer.batch_encode_plus(batch, max_length=64, return_tensors="pt", pad_to_max_length=True)
            print(dct["input_ids"][0])
            print(dct["attention_mask"][0])

            input_ids = dct["input_ids"].to(device)
            summaries = model.generate(
                input_ids=input_ids,
                attention_mask=dct["attention_mask"].to(device),
                num_beams=4,
                length_penalty=10.0,
                repetition_penalty=5.0,
                max_length=20,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )

            hypotheses = [
                tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for g in summaries
            ]
            # Decode the (padded/truncated) inputs again so every output line
            # shows exactly what the model saw.
            sources = [
                tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                for ids in input_ids
            ]
            for source, hypothesis in zip(sources, hypotheses):
                fout.write(source + ' ||| ' + hypothesis + "\n")
                # Flush per line so partial results survive an interruption.
                fout.flush()
def test_init_and_from_pretrained(self):
    """Build a TF RAG token model, run one forward pass and round-trip it through save/load."""
    rag_config = self.get_rag_config()

    generator_tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    question_tok = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base")
    rag_retriever = RagRetriever(
        rag_config,
        question_encoder_tokenizer=question_tok,
        generator_tokenizer=generator_tok,
    )

    # The model itself is built from the published base configuration.
    rag_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
    rag = TFRagTokenForGeneration(rag_config, retriever=rag_retriever)

    question = question_tok("who sings does he love me with reba",
                            return_tensors="tf")
    answer = generator_tok("Linda Davis", return_tensors="tf")
    input_ids = question.input_ids
    decoder_input_ids = answer.input_ids

    rag(
        input_ids,
        decoder_input_ids=decoder_input_ids,
    )

    # this should not give any warnings
    with tempfile.TemporaryDirectory() as tmpdirname:
        rag.save_pretrained(tmpdirname)
        rag = TFRagTokenForGeneration.from_pretrained(
            tmpdirname, retriever=rag_retriever)
def test_diverse_beam_search(self):
    """Check the two sequences returned by group beam search against fixed reference texts."""
    article = """Justin Timberlake and Jessica Biel, welcome to parenthood. The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""
    expected_summaries = [
        "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
        "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
    ]

    checkpoint = "facebook/bart-large-cnn"
    bart_tokenizer = BartTokenizer.from_pretrained(checkpoint)
    bart_model = BartForConditionalGeneration.from_pretrained(checkpoint)
    bart_model = bart_model.to(torch_device)

    encoded = bart_tokenizer(article, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Four beams in four groups with a diversity penalty; two sequences
    # are returned.
    outputs = bart_model.generate(input_ids,
                                  num_beams=4,
                                  num_return_sequences=2,
                                  num_beam_groups=4,
                                  diversity_penalty=2.0)

    generated_text = bart_tokenizer.batch_decode(outputs,
                                                 skip_special_tokens=True)

    self.assertListEqual(generated_text, expected_summaries)
def bart_summarize(input_file):
    """Summarise every page of ``input_file`` and append the results to ``summarized_bart.txt``.

    Each page's text is obtained with ``pdf_to_text``, summarised with
    ``facebook/bart-large-cnn`` (generation capped at ``max_length=5``),
    printed, and appended to the output file without a separator.

    Args:
        input_file: Document passed to ``get_num_pages`` and ``pdf_to_text``.
    """
    model = BartForConditionalGeneration.from_pretrained(
        'facebook/bart-large-cnn')
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

    num_count = get_num_pages(input_file)

    # ``with`` closes the file even if extraction or generation raises; the
    # original only closed it after a fully successful run.
    with open('summarized_bart.txt', 'a+') as f:
        for count in range(num_count):
            text = pdf_to_text(input_file, count)
            inputs = tokenizer([text], max_length=1024, return_tensors='pt')

            # Generate Summary
            summary_ids = model.generate(inputs['input_ids'],
                                         num_beams=4,
                                         max_length=5,
                                         early_stopping=True)
            summarized_text = [
                tokenizer.decode(g,
                                 skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False)
                for g in summary_ids
            ]
            print(summarized_text)

            str1 = ''.join(summarized_text)
            print(str1)
            f.write(str1)
def test_bart_summarization_dataset(self):
    """Check that ``SummarizationDataset`` trims sources and truncates targets."""
    articles = [" Sam ate lunch today", "Sams lunch ingredients"]
    summaries = [
        "A very interesting story about what I ate for lunch.",
        "Avocado, celery, turkey, coffee"
    ]

    tmp_dir = Path(tempfile.gettempdir())
    _dump_articles((tmp_dir / "train.source"), articles)
    _dump_articles((tmp_dir / "train.target"), summaries)

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    # Longest tokenized source / target, measured with the same tokenizer.
    max_len_source = max(len(tokenizer.encode(text)) for text in articles)
    max_len_target = max(len(tokenizer.encode(text)) for text in summaries)

    trunc_target = 4
    train_dataset = SummarizationDataset(
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset,
                            batch_size=2,
                            collate_fn=train_dataset.collate_fn)

    for batch in dataloader:
        source_ids = batch["source_ids"]
        target_ids = batch["target_ids"]

        self.assertEqual(batch["source_mask"].shape, source_ids.shape)

        # show that articles were trimmed.
        self.assertEqual(source_ids.shape[1], max_len_source)
        self.assertGreater(20, source_ids.shape[1])  # trimmed significantly

        # show that targets were truncated
        self.assertEqual(target_ids.shape[1], trunc_target)  # Truncated
        self.assertGreater(max_len_target, trunc_target)  # Truncated
def __init__(self,
             numericaliser='BART',
             fields=None,
             use_ray=False,
             debug=True,
             max_len=1000,
             **kwargs):
    """Configure the callable that turns text into token ids.

    Args:
        numericaliser: ``'BART'`` or ``'BERT'`` to use the ``encode`` method
            of the corresponding tokenizer, ``'Code32k'`` to load the code
            BPE tokenizer and encode through ``self.custom_tokenizer2ids``;
            any other value is stored as-is and used as the encoding callable.
        fields: List of ``(source, target)`` field-name pairs. Defaults to
            ``[("input_text", "input_ids")]``; a fresh list is created per
            instance so instances never share (and mutate) one default list.
        use_ray: Stored on the instance.
        debug: When true, print the encoding of ``'This is a test'``.
        max_len: Stored on the instance.
        **kwargs: Accepted and ignored.
    """
    if fields is None:
        fields = [("input_text", "input_ids")]

    if numericaliser == 'BART':
        self.numericaliser = BartTokenizer.from_pretrained(
            'facebook/bart-large').encode
    elif numericaliser == 'BERT':
        self.numericaliser = BertTokenizer.from_pretrained(
            'bert-base-uncased').encode
    elif numericaliser == 'Code32k':
        tokenizer_path = "datasets/code_search_net/codeBPE.tokenizer.json"
        # Fetch the tokenizer definition on first use.
        if not os.path.isfile(tokenizer_path):
            download_from_url(
                "https://storage.googleapis.com/carlos-phd-data/code-search-net-tokenizer/codeBPE.tokenizer.json",
                tokenizer_path)
        code_BPE_tokenizer = Tokenizer.from_file(tokenizer_path)
        self.custom_tokenizer = code_BPE_tokenizer
        self.numericaliser = self.custom_tokenizer2ids
    else:
        self.numericaliser = numericaliser

    if debug:
        print(
            f"Numericaliser. Ex: 'This is a test' -> {self.numericaliser('This is a test')}"
        )
    self.fields = fields
    self.use_ray = use_ray
    self.max_len = max_len