class ExsBuilder:
    """ExsBuilder produces a list of examples given a document set"""

    def __init__(self, bert_model='bert-base-uncased', file_emb='',
                 vocab_size=150000, min_src_nsents=1, max_src_nsents=50,
                 min_src_ntokens_per_sent=3, max_src_ntokens_per_sent=100):
        logger.info('=== Initializing an example builder'.ljust(80, '='))
        self.min_src_nsents = min_src_nsents
        self.max_src_nsents = max_src_nsents
        self.min_src_ntokens_per_sent = min_src_ntokens_per_sent
        self.max_src_ntokens_per_sent = max_src_ntokens_per_sent
        logger.debug(f'Loading BERT pre-trained model [{bert_model}]')
        self.tokB = BertTokenizer.from_pretrained(bert_model)
        self.tokC = None
        if file_emb != '':
            logger.debug('Loading the WBMET dictionary for custom tokenizer')
            self.tokC = Tokenizer(vocab_size=vocab_size)
            self.tokC.from_pretrained(file_emb)
        self.doc_lbl_freq = [0, 0]  # document-level [irrel, rel]
        self.ext_lbl_freq = [0, 0]  # token-level [irrel, rel]

    @staticmethod
    def tokenize(data, src_keys=['title', 'body'], tgt_key='text'):
        """Use the Stanford CoreNLP tokenizer to tokenize all the documents."""
        REMAP = {"-LRB-": "(", "-RRB-": ")", "-LCB-": "{", "-RCB-": "}",
                 "-LSB-": "[", "-RSB-": "]", "``": '"', "''": '"'}
        with CoreNLPClient(annotators=['tokenize', 'ssplit'],
                           threads=CPU_CNT) as client:
            for did, d in tqdm(data.items()):
                text = ''
                for k in src_keys:
                    text += d[k] + ' '
                ann = client.annotate(text.strip())
                tokens = []  # list of tokenized sentences
                for sent in ann.sentence:
                    tokens.append([
                        REMAP[t.word] if t.word in REMAP else t.word.lower()
                        for t in sent.token
                    ])
                d[tgt_key] = tokens

    def encode(self, exs):
        """Convert sequences into indices and create data entries for model
        inputs"""
        rtn = []
        logger.info('Encoding examples...')
        for qid, did, rel, doc, flds, mesh, keywords in tqdm(exs):
            entry = {
                'qid': qid, 'did': did,
                'src': [], 'src_sent_lens': [],
                'tgtB': [], 'tgtB_sent_lens': [],
                'tgtC': [], 'tgtC_sent_lens': []
            }
            # src
            for s in doc:  # CoreNLP tokenized sequences (list of sentences)
                if len(s) <= self.min_src_ntokens_per_sent:
                    continue
                src_str = ' '.join(s[:self.max_src_ntokens_per_sent])
                entry['src'] += self.tokB.convert_tokens_to_ids(
                    self.tokB.tokenize(src_str))
                entry['src_sent_lens'].append(len(entry['src']))
            if len(entry['src']) == 0:
                continue
            # tgt - fields
            tgt_tokens = set()  # Used in identifying token-level labels
            for seq in flds:  # flds (disease, gene, demo)
                # BERT
                ids = self.tokB.convert_tokens_to_ids(self.tokB.tokenize(seq))
                tgt_tokens.update(ids)
                entry['tgtB'] += ids
                entry['tgtB_sent_lens'].append(len(entry['tgtB']))
                # BMET
                ids = self.tokC.convert_tokens_to_ids(self.tokC.tokenize(seq))
                ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
                entry['tgtC'] += ids
                entry['tgtC_sent_lens'].append(len(entry['tgtC']))
            # tgt - mesh
            mesh = [f'εmesh_{t}' for t in mesh[0].lower().split()]
            ids = self.tokC.convert_tokens_to_ids(mesh)
            ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
            entry['tgtC'] += ids
            entry['tgtC_sent_lens'].append(len(entry['tgtC']))
            # tgt - keywords
            seq = ' '.join(keywords)
            ids = self.tokC.convert_tokens_to_ids(self.tokC.tokenize(seq))
            ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
            tgt_tokens.update(ids)
            entry['tgtC'] += ids
            entry['tgtC_sent_lens'].append(len(entry['tgtC']))

            entry['token_labels'] = \
                [1 if t in tgt_tokens else 0 for t in entry['src']]
            sum_ = sum(entry['token_labels'])
            self.ext_lbl_freq[0] += len(entry['token_labels']) - sum_
            self.ext_lbl_freq[1] += sum_
            entry['doc_label'] = 0 if rel == 0 else 1
            rtn.append(entry)
        return rtn

    def build_trec_exs(self, topics, docs):
        """For each topic and doc pair, encode them and construct the example
        list"""
        exs = list()
        # Tokenize documents using the Stanford CoreNLP Tokenizer
        logger.debug('Tokenizing %s documents using Stanford CoreNLP '
                     'Tokenizer...', len(docs))
        self.tokenize(docs)
        # Add positive examples
        for qid in topics:
            for did, rel in topics[qid]['docs']:
                if did not in docs or \
                        len(docs[did]['text']) < self.min_src_nsents:
                    continue
                d = docs[did]
                # Complete keywords: doc_keywords > doc_mesh > q_mesh
                keywords = d['keywords'] if len(d['keywords']) > 0 \
                    else d['mesh_names']
                if len(keywords) == 0 and rel > 0:
                    keywords = [topics[qid]['mesh'][1]]
                exs.append((qid, did, rel, d['text'][:self.max_src_nsents],
                            topics[qid]['fields'], topics[qid]['mesh'],
                            keywords))
                self.doc_lbl_freq[int(rel > 0)] += 1
        # Add negative examples
        neg_docs_ids = [did for did, d in docs.items() if not d['pos']]
        qids = random.choices(list(topics.keys()), k=len(neg_docs_ids))
        for i, did in enumerate(neg_docs_ids):
            exs.append((qids[i], did, 0,
                        docs[did]['text'][:self.max_src_nsents],
                        topics[qids[i]]['fields'], topics[qids[i]]['mesh'],
                        []))
            self.doc_lbl_freq[0] += 1
        random.shuffle(exs)
        rtn = self.encode(exs)
        return rtn

    # todo. Following function will be changed
    def build(self, examples, docs):
        """Building examples is done in two modes: one for data preparation
        and the other for prediction.

        In data preparation,
          - `exs` are queries in TREC ref datasets
          - `docs` consists of pos and neg documents prepared by
            `read_pubmed_docs`
        In prediction,
          - `exs` only contains one query with no labels
          - `docs` contains the retrieved documents from Solr search results
        """
        # Tokenize documents and build examples with doc_labels
        exs = []
        # Title and Text are multivalued ('text_general' in Solr)
        results = docs
        docs = {}
        for r in results:
            title = ' '.join(r['ArticleTitle'] if 'ArticleTitle' in r else [])
            body = ' '.join(r['AbstractText'] if 'AbstractText' in r else [])
            docs[r['id']] = (title + ' ' + body).strip()
        logger.debug(f'Tokenizing {len(docs)} retrieved docs...')
        pos_docs = self.tokenize(docs)
        # Build examples (with dummy label -1)
        qid = list(examples.keys())[0]  # There's only one anyways
        logger.info(f'Preparing examples for {qid}...')
        for did, text in pos_docs.items():
            if len(pos_docs[did]) < self.min_src_nsents:
                continue
            exs.append((qid, did, -1, pos_docs[did][:self.max_src_nsents],
                        examples[qid]['topics']))
        data = self.encode(exs)
        return data
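# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original pipeline): it shows the
# shapes that `ExsBuilder.build_trec_exs` reads from `topics` and `docs`
# (e.g., topics[qid]['docs'], docs[did]['title'/'body'/'keywords']). The ids,
# texts, MeSH values, and the `file_emb` path are made-up placeholders, and
# running it requires a reachable CoreNLP server plus the BERT vocab files.
def _example_build_trec_exs():  # pragma: no cover
    topics = {
        'q1': {
            'docs': [('d1', 1)],  # (doc id, relevance)
            'fields': ['breast cancer', 'brca1', 'adult female'],
            'mesh': ['D001943', 'Breast Neoplasms'],  # toy MeSH id / name
        }
    }
    docs = {
        'd1': {
            'title': 'BRCA1 mutations in early-onset breast cancer.',
            'body': 'We report a germline BRCA1 variant in a young patient.',
            'keywords': ['brca1'],
            'mesh_names': ['breast neoplasms'],
            'pos': True,  # marks the document as a positive example
        }
    }
    builder = ExsBuilder(file_emb='wbmet.vec')  # hypothetical embedding file
    return builder.build_trec_exs(topics, docs)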
class Summarizer:
    """Use a test model to generate fielded query sentences from documents"""

    def __init__(self, f_abs, n_best=1, min_length=1, max_length=50,
                 beam_size=4, bert_model='bert-base-uncased'):
        self.n_best = n_best
        self.min_length = min_length
        self.max_length = max_length
        self.beam_size = beam_size
        self.abs_model = self.load_abs_model(f_abs)
        self.eval()
        logger.info(f'Loading BERT Tokenizer [{bert_model}]...')
        self.tokenizerB = BertTokenizer.from_pretrained(bert_model)
        self.spt_ids_B, self.spt_ids_C, self.eos_mapping = get_special_tokens()
        logger.info('Loading custom Tokenizer for using WBMET embeddings')
        self.tokenizerC = Tokenizer(self.abs_model.args.vocab_size)
        self.tokenizerC.from_pretrained(self.abs_model.args.file_dec_emb)

    @staticmethod
    def load_abs_model(f_abs):
        """Load a pre-trained abs model"""
        logger.info(f'Loading an abstractive test model from {f_abs}...')
        data = torch.load(f_abs, map_location=lambda storage, loc: storage)
        mdl = AbstractiveSummarizer(data['args'])
        mdl.load_state_dict(data['model'])
        return mdl.cuda()

    def translate(self, docs):
        """Translate a batch of documents."""
        batch_size = docs.inp.size(0)
        spt_ids = self.spt_ids_C
        decode_strategy = BeamSearch(self.beam_size, batch_size, self.n_best,
                                     self.min_length, self.max_length,
                                     spt_ids, self.eos_mapping)
        return self._translate_batch_with_strategy(docs, decode_strategy)

    def _translate_batch_with_strategy(self, batch, decode_strategy):
        """Translate a batch of documents step by step using cache

        :param batch: A batch of documents
        :param decode_strategy (DecodeStrategy): A decode strategy for
            generating translations step by step, e.g., BeamSearch
        """
        # (1) Run the encoder on the src
        ext_scores, hidden_states = \
            self.abs_model.encoder(batch.inp,
                                   attention_mask=batch.mask_inp,
                                   token_type_ids=batch.segs)

        # (2) Prepare decoder and decode_strategy
        self.abs_model.decoder.init_state(batch.inp)
        field_signals = batch.tgt[:, 0]
        fn_map_state, memory_bank, memory_pad_mask = \
            decode_strategy.initialize(hidden_states[-1], batch.src_lens,
                                       field_signals)
        if fn_map_state is not None:
            self.abs_model.decoder.map_state(fn_map_state)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.unsqueeze(-1)
            dec_out, attns = self.abs_model.decoder(decoder_input, memory_bank,
                                                    memory_pad_mask, step=step)
            log_probs = self.abs_model.generator(dec_out[:, -1, :].squeeze(1))

            # Beam advance
            decode_strategy.advance(log_probs, attns)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break
            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states
                memory_bank = memory_bank.index_select(0, select_indices)
                memory_pad_mask = memory_pad_mask.index_select(
                    0, select_indices)
            if self.beam_size > 1 or any_finished:
                self.abs_model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        res = {
            'batch': batch,
            'gold_scores': self._gold_score(batch, hidden_states[-1],
                                            batch.mask_inp),
            'scores': decode_strategy.scores,
            'predictions': decode_strategy.predictions,
            'ext_scores': ext_scores,
            'attentions': decode_strategy.attention
        }
        return res

    def results_to_translations(self, data):
        """Convert results into Translation objects"""
        batch = data['batch']
        translations = []
        for i, did in enumerate(batch.did):
            src_input_ = batch.inp[i]
            src_ = self.tokenizerB.decode(src_input_)
            topic_ = \
                self.tokenizerC.convert_id_to_token(batch.tgt[i][0].item())
            pred_sents_ = [
                self.tokenizerC.decode(data['predictions'][i][n])
                for n in range(self.n_best)
            ]
            gold_sent_ = self.tokenizerC.decode(batch.tgt[i])
            x = Translation(did=did, src_input=src_input_, src=src_,
                            topic=topic_, ext_scores=data['ext_scores'][i],
                            pred_sents=pred_sents_,
                            pred_scores=data['scores'][i],
                            gold_sent=gold_sent_,
                            gold_score=data['gold_scores'][i],
                            attentions=data['attentions'][i])
            translations.append(x)
        return translations

    def _gold_score(self, batch, memory_bank, memory_pad_mask):
        if hasattr(batch, 'tgt'):
            gs = self._score_target(batch, memory_bank, memory_pad_mask)
            self.abs_model.decoder.init_state(batch.inp)
        else:
            gs = [0] * batch.batch_size
        return gs

    def _score_target(self, batch, memory_bank, memory_pad_mask):
        tgt_in = batch.tgt[:, :-1]
        dec_out, _ = self.abs_model.decoder(tgt_in, memory_bank,
                                            memory_pad_mask)
        log_probs = self.abs_model.generator(dec_out)
        gold = batch.tgt[:, 1:]
        tgt_pad_mask = gold.eq(self.spt_ids_C['[PAD]'])
        log_probs[tgt_pad_mask] = 0
        gold_scores = log_probs.gather(2, gold.unsqueeze(-1))
        gold_scores = gold_scores.sum(dim=1).view(-1)
        return gold_scores.tolist()

    def eval(self):
        self.abs_model.eval()
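# ---------------------------------------------------------------------------
# Illustrative driver sketch (not the original inference script): it wires a
# trained abstractive checkpoint into `Summarizer.translate`. The checkpoint
# path and `load_test_batches` are hypothetical placeholders; each batch is
# assumed to expose the fields used above (inp, mask_inp, segs, tgt, src_lens,
# did), and Translation is assumed to expose its constructor kwargs as
# attributes.
def _example_summarize():  # pragma: no cover
    summ = Summarizer(f_abs='abs_model.pt',  # hypothetical checkpoint path
                      n_best=1, beam_size=4, max_length=50)
    for batch in load_test_batches('test_exs.pt'):  # hypothetical loader
        with torch.no_grad():
            results = summ.translate(batch)
        for tr in summ.results_to_translations(results):
            print(tr.did, tr.pred_sents[0])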