Example No. 1
def loadAbstractSummarizer():
    import os

    from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
    from transformers import LEDForConditionalGeneration, LEDTokenizer

    model_name = get_item("abstract_summarizer_model_name")

    MODEL_DIRECTORY = f'./models/{model_name}/'

    use_bart = "bart" in model_name
    if os.path.exists(MODEL_DIRECTORY):
        if use_bart:
            model = BartForConditionalGeneration.from_pretrained(
                MODEL_DIRECTORY)
            tokenizer = BartTokenizer.from_pretrained(MODEL_DIRECTORY)
        else:
            model = LEDForConditionalGeneration.from_pretrained(
                MODEL_DIRECTORY, return_dict_in_generate=True)
            tokenizer = LEDTokenizer.from_pretrained(MODEL_DIRECTORY)
    else:
        if use_bart:
            model = BartForConditionalGeneration.from_pretrained(model_name)
            tokenizer = BartTokenizer.from_pretrained(model_name)
        else:
            model = LEDForConditionalGeneration.from_pretrained(
                model_name, return_dict_in_generate=True)
            tokenizer = LEDTokenizer.from_pretrained(model_name)

        model.save_pretrained(MODEL_DIRECTORY)
        tokenizer.save_pretrained(MODEL_DIRECTORY)
    return model, tokenizer
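A minimal usage sketch, assuming get_item("abstract_summarizer_model_name") resolves to a BART checkpoint such as "facebook/bart-large-cnn"; the article text is a placeholder:

model, tokenizer = loadAbstractSummarizer()
inputs = tokenizer(["Long article text to summarize ..."],
                   max_length=1024, truncation=True, return_tensors="pt")
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=142)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))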
Example No. 2
def get_model_tokenizer(model_name):
    import torch
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if "pegasus" in model_name:
        # it's a Pegasus model
        from transformers import PegasusForConditionalGeneration, PegasusTokenizer
        tokenizer = PegasusTokenizer.from_pretrained(model_name)
        model = PegasusForConditionalGeneration.from_pretrained(model_name).to(
            torch_device)
        return model, tokenizer

    elif "bart-large" in model_name:
        # it's a BART model
        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name).to(
            torch_device)
        return model, tokenizer

    elif "bart-custom-large" in model_name:
        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name).to(
            torch_device)
        return model, tokenizer

    else:
        # T5 or distilbart
        from transformers import AutoTokenizer, AutoModelWithLMHead
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelWithLMHead.from_pretrained(model_name).to(
            torch_device)
        return model, tokenizer
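A short usage sketch, assuming a distilbart checkpoint (which falls through to the AutoModel branch); the input text is a placeholder:

model, tokenizer = get_model_tokenizer("sshleifer/distilbart-cnn-12-6")
batch = tokenizer(["Text to summarize ..."], max_length=1024,
                  truncation=True, return_tensors="pt").to(model.device)
summary_ids = model.generate(**batch, num_beams=4, max_length=60)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))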
Example No. 3
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel,
                         bart_name, num_tokens):
    """when initial added tokens, use their averge token emebddings

    Args:
        tokenizer (BartTokenizer): [description]
        bart_model (BartModel): [description]
        bart_name ([type]): [description]
        num_tokens ([type]): [description]

    Raises:
        RuntimeError: [description]

    Returns:
        [type]: [description]
    """
    _tokenizer = BartTokenizer.from_pretrained(bart_name)
    for token in tokenizer.unique_no_split_tokens:
        if token[:2] == '<<':  # special marker token
            index = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
            if len(index) > 1:
                raise RuntimeError(f"{token} wrong split")
            else:
                index = index[0]
            assert index >= num_tokens, (index, num_tokens, token)
            indexes = _tokenizer.convert_tokens_to_ids(
                _tokenizer.tokenize(token[2:-2]))
            embed = bart_model.encoder.embed_tokens.weight.data[indexes[0]]
            for i in indexes[1:]:
                embed += bart_model.decoder.embed_tokens.weight.data[i]
            embed /= len(indexes)
            bart_model.decoder.embed_tokens.weight.data[index] = embed
    return bart_model
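An illustrative setup for calling this helper; the marker tokens and checkpoint below are assumptions, and the snippet presumes a transformers version in which added tokens are tracked in tokenizer.unique_no_split_tokens:

from transformers import BartModel, BartTokenizer

bart_name = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(bart_name)
num_tokens = len(tokenizer)  # vocabulary size before the markers are added
tokenizer.add_tokens(["<<person>>", "<<location>>"], special_tokens=True)

bart_model = BartModel.from_pretrained(bart_name)
bart_model.resize_token_embeddings(len(tokenizer))
bart_model = avg_token_embeddings(tokenizer, bart_model, bart_name, num_tokens)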
Example No. 4
    def __init__(self,
                 device=None,
                 checkpoint=None,
                 state_dict_key='model',
                 pretrained="facebook/bart-large-cnn",
                 hg_transformers=True):
        if not hg_transformers and checkpoint:
            raise Exception(
                "hg_transformers must be set to True in order to load from checkpoint"
            )

        if not device:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")

        # huggingface uses dashes and fairseq/torchhub uses dots (periods)
        if pretrained:
            if hg_transformers:
                pretrained = pretrained.replace(".", "-")
            else:
                # only use the part after the "/"
                pretrained = pretrained.split("/")[-1].replace("-", ".")

        if checkpoint is not None and "semsim" in checkpoint:
            cache_dir = appdirs.user_cache_dir("DocSum", "HHousen")
            output_file_path = os.path.join(cache_dir, "bart_semsim.pt")
            if not os.path.isfile(output_file_path):
                if not os.path.exists(cache_dir):
                    os.makedirs(cache_dir)
                gdown.download(
                    "https://drive.google.com/uc?id=1CNgK6ZkaqUD239h_6GkLmfUOGgryc2v9",
                    output_file_path)
            checkpoint = output_file_path

        if checkpoint:
            loaded_checkpoint = torch.load(checkpoint)
            model_state_dict = loaded_checkpoint[state_dict_key]

            bart = BartForConditionalGeneration.from_pretrained(
                pretrained, state_dict=model_state_dict)
            tokenizer = BartTokenizer.from_pretrained(
                pretrained, state_dict=model_state_dict)
            self.tokenizer = tokenizer
        else:
            if hg_transformers:
                bart = BartForConditionalGeneration.from_pretrained(pretrained)
                tokenizer = BartTokenizer.from_pretrained(pretrained)
                self.tokenizer = tokenizer
            else:
                bart = torch.hub.load('pytorch/fairseq', pretrained)
                bart.to(device)
                bart.eval()
                bart.half()

        self.logger = logging.getLogger(__name__)
        self.hg_transformers = hg_transformers
        self.bart = bart
Example No. 5
def get_dataloaders(dataset, batch_size, num_workers, dynamic_shape,
                    max_src_length, max_tgt_length, shuffle):
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

    def _collate_fn(raw_batch):
        hf_batch = tokenizer.prepare_seq2seq_batch(
            src_texts=[example['src'] for example in raw_batch],
            tgt_texts=[example['tgt'] for example in raw_batch],
            max_length=max_src_length,
            max_target_length=max_tgt_length,
            padding='max_length' if not dynamic_shape else 'longest',
            return_tensors='pt')

        fs_batch = {
            'src_tokens': hf_batch['input_ids'],
            'src_lengths': torch.sum(hf_batch['attention_mask'], dim=1),
            'prev_output_tokens': hf_batch['labels']
        }

        return fs_batch

    return {
        split: DataLoader(dataset=dataset[split],
                          batch_size=batch_size,
                          collate_fn=_collate_fn,
                          shuffle=(split == 'train' and shuffle),
                          drop_last=True,
                          num_workers=num_workers)
        for split in ['train', 'dev']
    }
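A hypothetical call, assuming dataset is a dict with 'train' and 'dev' splits whose examples carry 'src' and 'tgt' strings:

loaders = get_dataloaders(dataset,
                          batch_size=8,
                          num_workers=2,
                          dynamic_shape=True,
                          max_src_length=1024,
                          max_tgt_length=128,
                          shuffle=True)
for fs_batch in loaders['train']:
    print(fs_batch['src_tokens'].shape, fs_batch['src_lengths'])
    break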
Example No. 6
 def __init__(self,
              denumericaliser='BART',
              fields=[("input_ids", "input_text")],
              debug=True,
              skip_special_tokens=True,
              **kwargs):
     if denumericaliser == 'BART':
         self.denumericaliser = BartTokenizer.from_pretrained(
             'facebook/bart-large').decode
     elif denumericaliser == 'BERT':
         self.denumericaliser = BertTokenizer.from_pretrained(
             'bert-base-uncased').decode
     elif denumericaliser == 'Code32k':
         if not os.path.isfile(
                 "datasets/code_search_net/codeBPE.tokenizer.json"):
             download_from_url(
                 "https://storage.googleapis.com/carlos-phd-data/code-search-net-tokenizer/codeBPE.tokenizer.json",
                 "datasets/code_search_net/codeBPE.tokenizer.json")
         code_BPE_tokenizer = Tokenizer.from_file(
             "datasets/code_search_net/codeBPE.tokenizer.json")
         self.denumericaliser = code_BPE_tokenizer.decode
     else:
         self.denumericaliser = denumericaliser
     if debug:
         print(
             f"Denumericaliser. Ex: [0,1,2,3,4,5,6,7,8,9] -> {self.denumericaliser([0,1,2,3,4,5,6,7,8,9])}"
         )
     self.fields = fields
     self.skip_special_tokens = skip_special_tokens
Example No. 7
    def __init__(self,
                 hparams,
                 user_tokens=['<newline>', '<bullet>', '<sep>']):
        super(BartSystem, self).__init__()
        self.hparams = hparams
        self.hparams.model_type = self.hparams.model_type.lower()
        tokenizer = BartTokenizer.from_pretrained(
            self.hparams.tokenizer_name if self.hparams.tokenizer_name else
            self.hparams.model_name_or_path,
            do_lower_case=self.hparams.do_lower_case,
            cache_dir=self.hparams.cache_dir
            if self.hparams.cache_dir else None,
        )

        config = AutoConfig.from_pretrained(
            self.hparams.config_name
            if self.hparams.config_name else self.hparams.model_name_or_path,
            cache_dir=self.hparams.cache_dir
            if self.hparams.cache_dir else None,
            output_past=self.hparams.do_test,
            vocab_size=len(tokenizer))

        model = BartForConditionalGeneration.from_pretrained(
            self.hparams.model_name_or_path,
            from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
            config=config,
            cache_dir=self.hparams.cache_dir
            if self.hparams.cache_dir else None,
        )

        self.config, self.tokenizer, self.model = config, tokenizer, model
        self.loss = []  # for keeping track of average loss
        self.metrics = {}

        self.vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
Example No. 8
    def test_xsum_summarization_same_as_fairseq(self):
        model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
        self.assertFalse(model.config.is_valid_mbart())
        tok = BartTokenizer.from_pretrained("facebook/bart-large")

        PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
        dct = tok.batch_encode_plus(
            [PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
        ).to(torch_device)

        hypotheses_batch = model.generate(
            input_ids=dct["input_ids"],
            attention_mask=dct["attention_mask"],
            num_beams=2,
            max_length=62,
            min_length=11,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
        )

        decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True,)
        self.assertEqual(EXPECTED_SUMMARY, decoded[0])
Example No. 9
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
    fout = Path(out_file).open("w")
    model = BartForConditionalGeneration.from_pretrained(
        "bart-large-cnn",
        output_past=True,
    ).to(device)
    tokenizer = BartTokenizer.from_pretrained("bart-large")
    for batch in tqdm(list(chunks(lns, batch_size))):
        dct = tokenizer.batch_encode_plus(batch,
                                          max_length=1024,
                                          return_tensors="pt",
                                          pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=2.0,
            max_length=142,  # +2 from original because we start at step=1 and stop before max_length
            min_length=56,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
            do_sample=False,
        )
        dec = [
            tokenizer.decode(g,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
            for g in summaries
        ]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
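One way this helper might be driven (the file paths are placeholders; chunks and DEFAULT_DEVICE come from the surrounding module):

lns = [line.rstrip("\n") for line in open("test.source", encoding="utf-8")]
generate_summaries(lns, "test.hypo", batch_size=8)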
Example No. 10
 def __init__(self, data_path, mapping, bart_name, learn_weights) -> None:
     self.data_path = data_path
     self.tokenizer = BartTokenizer.from_pretrained(bart_name)
     self.mapping = mapping  # records the mapping between the original tags and their converted tag strings
     self.original_token_nums = self.tokenizer.vocab_size
     self.learn_weights = learn_weights
     self._add_tags_to_tokens()
Example No. 11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--user_input', action="store_true")
    args = parser.parse_args()

    print("initializing bart tokenizer...")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    print("creating lightseq model...")
    ls_model = lightseq.Transformer("lightseq_bart_base.pb", 128)
    print("creating huggingface model...")
    hf_model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-base")

    while True:
        if args.user_input:
            sentences = [input("input the masked sentence:\n")]
        else:
            sentences = [
                "I love that girl, but <mask> does not <mask> me.",
                "She is so <mask> that I can not help glance at <mask>.",
                "Nothing's gonna <mask> my love for you.",
                "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk."
            ]

        print("tokenizing the sentences...")
        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        inputs_id = inputs["input_ids"]

        ls_generate(ls_model, tokenizer, inputs_id)
        hf_generate(hf_model, tokenizer, inputs_id)

        if not args.user_input:
            break
Example No. 12
def simple_term_counts(data_dir='data/data-simplification/wikismall'):

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
    model = LogisticRegression(max_iter=200)

    for batch in dataloader(data_dir):
        X, y = construct_dataset(batch, tokenizer)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) 

        #apply feature scaling
        X_train = normalize(X_train)
        X_test = normalize(X_test)

        model.fit(X_train, y_train)
        predictions = model.predict(X_test)
        print(accuracy_score(y_test, predictions))

    return
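    # NOTE: the weight-export code below never runs because of the early return above.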

    vocab = get_vocab(tokenizer)
    weights = np.squeeze(model.coef_, axis=0).tolist()

    sorted_weights = filter(lambda x: len(x[1].strip()) > 0, zip(range(tokenizer.vocab_size), vocab, weights))
    sorted_weights = list(sorted(sorted_weights, key=lambda x: x[2]))
    
    with open('data/logr_weights/bart_freq_normalized_ids.txt', 'w') as f:
        for ID, word, weight in sorted_weights:
            f.write(f'{ID} {weight}\n')

    with open('data/logr_weights/bart_freq_normalized_tokens.txt', 'w') as f:
        for ID, word, weight in sorted_weights:
            f.write(f'{word} {weight}\n')
Example No. 13
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file")
    parser.add_argument("--output_file")
    parser.add_argument(
        "--decoder",
        choices=['greedy', 'beam_search', 'random', 'top_k', 'nucleus'])
    args = parser.parse_args()

    model_name = 'sshleifer/distilbart-xsum-1-1'
    model = BartForConditionalGeneration.from_pretrained(model_name).eval()
    tokenizer = BartTokenizer.from_pretrained(model_name)

    # Iterate through input file documents, generating summaries
    outputs = []
    for line in tqdm.tqdm(jsonlines.open(args.input_file)):
        summary, summary_score = generate_summary(model=model,
                                                  tokenizer=tokenizer,
                                                  document=line['document'],
                                                  decoder=args.decoder)

        outputs.append({
            'id': line['id'],
            'generated_summary': summary,
            'generated_summary_score': summary_score
        })

    # Write out the generated summaries to file
    with open(args.output_file, 'w', encoding='utf-8') as f:
        for l in outputs:
            f.write(json.dumps(l, ensure_ascii=False) + '\n')
Example No. 14
def get_summary(text, model, tokenizer, torch_device):
    """
    Get summary
    """

    tokenizer_summarize = BartTokenizer.from_pretrained("bart-large-cnn")
    model_summarize = BartForConditionalGeneration.from_pretrained("bart-large-cnn").to(
        torch_device
    )

    model_summarize.to(torch_device)
    # Set the model in evaluation mode to deactivate the DropOut modules
    model_summarize.eval()

    answers_input_ids = tokenizer_summarize.batch_encode_plus(
        [text], return_tensors="pt", max_length=1024
    )["input_ids"]

    answers_input_ids = answers_input_ids.to(torch_device)

    summary_ids = model_summarize.generate(
        answers_input_ids, num_beams=4, max_length=5, early_stopping=True
    )

    return tokenizer_summarize.decode(
        summary_ids.squeeze(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
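A minimal sketch of a call; note that the model and tokenizer arguments are ignored by the function body, which loads its own "bart-large-cnn" weights (old-style hub name):

import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(get_summary("Long answer text to compress ...", None, None, torch_device))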
Example No. 15
 def load(self, save_dir):
     self._tokenizer = BartTokenizer.from_pretrained(save_dir)
     self.decoder_vocab = DecoderVocabulary(
         self._tokenizer.decoder.values(),
         None,
         pad_token=self._tokenizer.pad_token,
         eos_token=self._tokenizer.eos_token)
Example No. 16
    def load_model_tokenizer(self, pretrained):
        """ Load transformer model and tokenizer for given pre-trained name 
        
        :param pretrained: pre-trained name
        :return: model, tokenizer
        """
        
        model = None
        tokenizer = None
        
        if self.method == "T5":
            if pretrained in T5_PRETRAINED_MODELS:
                model = T5ForConditionalGeneration.from_pretrained(pretrained)
                tokenizer = T5Tokenizer.from_pretrained(pretrained)
        elif self.method == "BART":
            if pretrained in BART_PRETRAINED_MODELS:
                model = BartForConditionalGeneration.from_pretrained(pretrained)
                tokenizer = BartTokenizer.from_pretrained(pretrained)
        elif self.method == "GPT-2":
            if pretrained in GPT2_PRETRAINED_MODELS:
                model = GPT2LMHeadModel.from_pretrained(pretrained)
                model.config.max_length = self.max_length
                tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
        elif self.method == "XLM":
            if pretrained in XLM_PRETRAINED_MODELS:
                model = XLMWithLMHeadModel.from_pretrained(pretrained)
                model.config.max_length = self.max_length
                tokenizer = XLMTokenizer.from_pretrained(pretrained)
        else:
            pass

        return model, tokenizer
Example No. 17
    def __init__(
        self,
        chkpt_path="/Users/byronwallace/code/RoboSum/weights/pl_title_/pl_title_2048.ckpt"
    ):
        self.model = BartForConditionalGeneration.from_pretrained(
            'facebook/bart-large-cnn')
        self.config = BartConfig.from_pretrained('facebook/bart-large-cnn')
        self.tokenizer = BartTokenizer.from_pretrained(
            'facebook/bart-large-cnn')

        # increase position embeddings from 1024 to 2048
        self.add_position_embeddings()

        # now add special tokens (for title and abstract demarcation)
        # as a general note: we'll assume "abstract" is either the
        # actual abstract of extracted text from the same (i.e., punchlines)
        self.add_special_tokens()

        # now load the checkpoint
        print("loading checkpoint", chkpt_path)
        checkpoint = torch.load(chkpt_path, map_location="cpu")
        print("done")

        cnew = {}
        for key, value in checkpoint['state_dict'].items():
            cnew[".".join(key.split('.')[1:])] = value
        self.model.load_state_dict(cnew)
Example No. 18
def load_bart_fever_rte_model(model_name, data_dir):
    processors = {
        "rte": RteProcessor
    }

    output_modes = {
        "rte": "classification"
    }
    # task_name = args.task_name.lower()
    task_name = 'rte'
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()  # [0,1]
    num_labels = len(label_list)
    pretrain_model_dir = '{}/FineTuneOn{}'.format(data_dir, model_name)
    # pretrain_model_dir = 'please enter your pretrain models path here/FineTuneOn{}'.format(model_name)
    # Prepare model
    # cache_dir = os.path.join(str(PYTORCH_TRANSFORMERS_CACHE), '{} model distributed_{}'.format(model_name, -1))
    # # cache_dir = os.path.join(str(PYTORCH_TRANSFORMERS_CACHE), '{} model distributed_{}'.format(model_name, -1))

    model = BartForSequenceClassification.from_pretrained(pretrain_model_dir, num_labels=num_labels)
    tokenizer = BartTokenizer.from_pretrained(pretrain_model_dir)
    # model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
    #           cache_dir=cache_dir,
    #           num_labels=num_labels)
    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    # print(tokenizer)
    return model, tokenizer
Example No. 19
def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
    fout = Path(out_file).open("w")
    model = BartForMaskedLM.from_pretrained(
        "bart-large-cnn",
        output_past=True,
    )
    tokenizer = BartTokenizer.from_pretrained("bart-large")
    for batch in tqdm(list(chunks(lns, batch_size))):
        dct = tokenizer.batch_encode_plus(batch,
                                          max_length=1024,
                                          return_tensors="pt",
                                          pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=2.0,
            max_length=140,
            min_len=55,
            no_repeat_ngram_size=3,
        )
        dec = [
            tokenizer.decode(g,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
            for g in summaries
        ]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()
Example No. 20
    def __init__(self):

        self.nli_model = BartForSequenceClassification.from_pretrained(
            'facebook/bart-large-mnli')
        self.nli_model = self.nli_model.to(DEVICE)
        self.tokenizer = BartTokenizer.from_pretrained(
            'facebook/bart-large-mnli')
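Only the constructor is shown above; the usual zero-shot scoring pattern with facebook/bart-large-mnli looks roughly like this (a sketch, not part of the original class):

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
nli_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

premise = "The new phone ships with a 120 Hz display."
hypothesis = "This text is about technology."
inputs = tokenizer(premise, hypothesis, return_tensors="pt")
logits = nli_model(**inputs).logits  # label order: contradiction, neutral, entailment
print(logits.softmax(dim=-1)[0, 2].item())  # entailment probability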
Example No. 21
    def setUpClass(cls):
        # summarization
        # generate with beam search
        # Note for BART summarization in transformers repo, beam search performs much better
        #  than no beam search, but even their beam search with num_beams=1 is better, implying that something
        #  is broken in the _generate_no_beam_search function

        # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
        cls.model = BartForConditionalGeneration.from_pretrained(
            'bart-large-cnn')
        cls.tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')

        cls.decoding_hyperparams = {'max_length': 40, 'num_beams': 3}

        cls.test_news_article_1 = 'New Zealand says it has stopped community transmission of Covid-19, ' \
                            'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
                            '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned ' \
                            'against complacency, saying it does not mean a total end to new coronavirus cases. ' \
                            'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. ' \
                            'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
                            'Most people will still be required to remain at home at all times and avoid all social interactions.'

        cls.test_news_article_2 = \
            'But officials have warned against complacency, saying it does not mean a total end to new HIV cases. ' \
            'Most people will still be required to remain at home at all times and avoid all social interactions.' \
            'Germany says it has stopped community transmission of HIV, ' \
            'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
            '- Prime Minister Angela Merkle said the virus was "currently" eliminated. ' \
            'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
            'The news comes hours before Germany is set to move out of its toughest level of social restrictions. '
Example No. 22
 def __init__(
     self,
     model_name_or_path,
     tokenizer_name,
     model_cache_dir,
     input_max_length,
     target_max_length,
     summary_column_name,
     document_column_name,
     wandb_project,
     wandb_run_name,
     **kwargs,
 ):
     super().__init__(
         input_max_length,
         target_max_length,
         summary_column_name,
         document_column_name,
         wandb_project,
         wandb_run_name,
     )
     self.tokenizer = BartTokenizer.from_pretrained(
         tokenizer_name if tokenizer_name else model_name_or_path,
         cache_dir=model_cache_dir,
     )
     self.model = BartForConditionalGeneration.from_pretrained(
         model_name_or_path,
         cache_dir=model_cache_dir,
     )
Example No. 23
    def test_xsum_summarization_same_as_fairseq(self):
        model = BartForConditionalGeneration.from_pretrained(
            "facebook/bart-large-xsum").to(torch_device)
        self.assertFalse(model.config.is_valid_mbart())
        tok = BartTokenizer.from_pretrained("facebook/bart-large")

        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
        dct = tok.batch_encode_plus(
            [PGE_ARTICLE],
            max_length=1024,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(torch_device)

        hypotheses_batch = model.generate(
            input_ids=dct["input_ids"],
            attention_mask=dct["attention_mask"],
            num_beams=2,
            max_length=62,
            min_length=11,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
        )

        decoded = tok.batch_decode(
            hypotheses_batch,
            skip_special_tokens=True,
        )
        self.assertEqual(EXPECTED_SUMMARY, decoded[0])
Example No. 24
def main(args):

    tokenizer = BartTokenizer.from_pretrained(args.tokenizer_path)

    proj_dir = Path()
    corpus_dir = proj_dir / "corpus"
    comment_dir = corpus_dir / "comment"

    source_path = comment_dir / args.corpus
    mask_path = comment_dir / args.mask_path

    dm = BartDataModule(source_path=source_path,
                        mask_path=mask_path,
                        tokenizer=tokenizer,
                        batch_size=3,
                        num_workers=1)

    data_loader = dm.train_dataloader()

    for masked_encode, attention_mask, encode in data_loader:
        masked_encode = masked_encode.detach().cpu().numpy()
        encode = encode.detach().cpu().numpy()

        for m, e in zip(masked_encode, encode):
            print(tokenizer.decode(m))
            print(tokenizer.decode(e))
            print(attention_mask)
Example No. 25
def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    fout = Path(out_file).open("w")
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn', output_past=True).to(device)
    tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")

    max_length = 140
    min_length = 55

    for batch in tqdm(list(chunks(examples, batch_size))):
        dct = tokenizer.batch_encode_plus(batch, max_length=64, return_tensors="pt", pad_to_max_length=True)
        print(dct["input_ids"][0])
        print(dct["attention_mask"][0])
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=10.0,
            repetition_penalty=5.0,
            max_length=20,  # +2 from original because we start at step=1 and stop before max_length
            #min_length=min_length + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
        #    decoder_start_token_id=model.config.eos_token_id,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        in_ids = dct["input_ids"].to(device)
        in_dec = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in in_ids]
        for input, hypothesis in zip(in_dec, dec):
            fout.write(input + ' ||| ' + hypothesis + "\n")
            fout.flush()
Example No. 26
    def test_init_and_from_pretrained(self):
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained(
            "facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
        rag = TFRagTokenForGeneration(rag_config, retriever=rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba",
            return_tensors="tf").input_ids
        decoder_input_ids = rag_decoder_tokenizer(
            "Linda Davis", return_tensors="tf").input_ids

        rag(
            input_ids,
            decoder_input_ids=decoder_input_ids,
        )

        # this should not give any warnings
        with tempfile.TemporaryDirectory() as tmpdirname:
            rag.save_pretrained(tmpdirname)
            rag = TFRagTokenForGeneration.from_pretrained(
                tmpdirname, retriever=rag_retriever)
Example No. 27
    def test_diverse_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
        The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
        "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
        The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""

        bart_tokenizer = BartTokenizer.from_pretrained(
            "facebook/bart-large-cnn")
        bart_model = BartForConditionalGeneration.from_pretrained(
            "facebook/bart-large-cnn").to(torch_device)
        input_ids = bart_tokenizer(
            article, return_tensors="pt").input_ids.to(torch_device)

        outputs = bart_model.generate(input_ids,
                                      num_beams=4,
                                      num_return_sequences=2,
                                      num_beam_groups=4,
                                      diversity_penalty=2.0)

        generated_text = bart_tokenizer.batch_decode(outputs,
                                                     skip_special_tokens=True)

        self.assertListEqual(
            generated_text,
            [
                "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
                "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
            ],
        )
Example No. 28
def bart_summarize(input_file):
    model = BartForConditionalGeneration.from_pretrained(
        'facebook/bart-large-cnn')
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

    num_count = get_num_pages(input_file)
    f = open('summarized_bart.txt', 'a+')

    count = 0
    while count < num_count:
        text = pdf_to_text(input_file, count)
        ARTICLE_TO_SUMMARIZE = text
        inputs = tokenizer([ARTICLE_TO_SUMMARIZE],
                           max_length=1024,
                           return_tensors='pt')
        # Generate Summary
        summary_ids = model.generate(inputs['input_ids'],
                                     num_beams=4,
                                     max_length=5,
                                     early_stopping=True)
        summarized_text = [
            tokenizer.decode(g,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
            for g in summary_ids
        ]
        print(summarized_text)
        str1 = ''.join(summarized_text)
        print(str1)
        f.write(str1)
        count += 1

    f.close()
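A hypothetical invocation (the PDF helpers get_num_pages and pdf_to_text come from the surrounding module); note that max_length=5 caps each generated summary at five tokens, so a larger value such as 142 is more typical for facebook/bart-large-cnn:

bart_summarize("paper.pdf")  # appends per-page summaries to summarized_bart.txt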
Example No. 29
    def test_bart_summarization_dataset(self):
        tmp_dir = Path(tempfile.gettempdir())
        articles = [" Sam ate lunch today", "Sams lunch ingredients"]
        summaries = [
            "A very interesting story about what I ate for lunch.",
            "Avocado, celery, turkey, coffee"
        ]
        _dump_articles((tmp_dir / "train.source"), articles)
        _dump_articles((tmp_dir / "train.target"), summaries)
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        max_len_source = max(len(tokenizer.encode(a)) for a in articles)
        max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
        trunc_target = 4
        train_dataset = SummarizationDataset(
            tokenizer,
            data_dir=tmp_dir,
            type_path="train",
            max_source_length=20,
            max_target_length=trunc_target,
        )
        dataloader = DataLoader(train_dataset,
                                batch_size=2,
                                collate_fn=train_dataset.collate_fn)
        for batch in dataloader:
            self.assertEqual(batch["source_mask"].shape,
                             batch["source_ids"].shape)
            # show that articles were trimmed.
            self.assertEqual(batch["source_ids"].shape[1], max_len_source)
            self.assertGreater(
                20, batch["source_ids"].shape[1])  # trimmed significantly

            # show that targets were truncated
            self.assertEqual(batch["target_ids"].shape[1],
                             trunc_target)  # Truncated
            self.assertGreater(max_len_target, trunc_target)  # Truncated
Example No. 30
 def __init__(self,
              numericaliser='BART',
              fields=[("input_text", "input_ids")],
              use_ray=False,
              debug=True,
              max_len=1000,
              **kwargs):
     if numericaliser == 'BART':
         self.numericaliser = BartTokenizer.from_pretrained(
             'facebook/bart-large').encode
     elif numericaliser == 'BERT':
         self.numericaliser = BertTokenizer.from_pretrained(
             'bert-base-uncased').encode
     elif numericaliser == 'Code32k':
         if not os.path.isfile(
                 "datasets/code_search_net/codeBPE.tokenizer.json"):
             download_from_url(
                 "https://storage.googleapis.com/carlos-phd-data/code-search-net-tokenizer/codeBPE.tokenizer.json",
                 "datasets/code_search_net/codeBPE.tokenizer.json")
         code_BPE_tokenizer = Tokenizer.from_file(
             "datasets/code_search_net/codeBPE.tokenizer.json")
         self.custom_tokenizer = code_BPE_tokenizer
         self.numericaliser = self.custom_tokenizer2ids
     else:
         self.numericaliser = numericaliser
     if debug:
         print(
             f"Numericaliser. Ex: 'This is a test' -> {self.numericaliser('This is a test')}"
         )
     self.fields = fields
     self.use_ray = use_ray
     self.max_len = max_len