Example #1
class ExsBuilder:
    """ExsBuilder produces a list of examples given a document set"""
    def __init__(self,
                 bert_model='bert-base-uncased',
                 file_emb='',
                 vocab_size=150000,
                 min_src_nsents=1,
                 max_src_nsents=50,
                 min_src_ntokens_per_sent=3,
                 max_src_ntokens_per_sent=100):
        logger.info('=== Initializing an example builder'.ljust(80, '='))
        self.min_src_nsents = min_src_nsents
        self.max_src_nsents = max_src_nsents
        self.min_src_ntokens_per_sent = min_src_ntokens_per_sent
        self.max_src_ntokens_per_sent = max_src_ntokens_per_sent

        logger.debug(f'Loading BERT pre-trained model [{bert_model}]')
        self.tokB = BertTokenizer.from_pretrained(bert_model)
        self.tokC = None
        if file_emb != '':
            logger.debug('Loading the WBMET dictionary for custom tokenizer')
            self.tokC = Tokenizer(vocab_size=vocab_size)
            self.tokC.from_pretrained(file_emb)
        self.doc_lbl_freq = [0, 0]  # document-level [irrel, rel]
        self.ext_lbl_freq = [0, 0]  # token-level [irrel, rel]

    @staticmethod
    def tokenize(data, src_keys=['title', 'body'], tgt_key='text'):
        """Use Stanford CoreNLP tokenizer to tokenize all the documents."""
        REMAP = {
            "-LRB-": "(",
            "-RRB-": ")",
            "-LCB-": "{",
            "-RCB-": "}",
            "-LSB-": "[",
            "-RSB-": "]",
            "``": '"',
            "''": '"'
        }
        with CoreNLPClient(annotators=['tokenize', 'ssplit'], threads=CPU_CNT)\
                as client:
            for did, d in tqdm(data.items()):
                text = ''
                for k in src_keys:
                    text += d[k] + ' '
                ann = client.annotate(text.strip())
                tokens = []  # list of tokenized sentences
                for sent in ann.sentence:
                    tokens.append([
                        REMAP[t.word] if t.word in REMAP else t.word.lower()
                        for t in sent.token
                    ])
                d[tgt_key] = tokens
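
    # For reference (illustrative sketch, not from the original source):
    # `tokenize` works in place. Given data such as
    #   {'d1': {'title': 'A BRCA1 (breast cancer) study.', 'body': '...'}}
    # it concatenates the `src_keys` fields, tokenizes and sentence-splits the
    # text with CoreNLP, and stores d1['text'] as a list of lower-cased token
    # lists, e.g. [['a', 'brca1', '(', 'breast', 'cancer', ')', 'study', '.']];
    # REMAP converts CoreNLP's bracket/quote tokens (-LRB-, ``, ...) back to
    # their literal forms when CoreNLP emits them.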

    def encode(self, exs):
        """Convert sequences into indicies and create data entries for
        model inputs"""
        rtn = []
        logger.info('Encoding examples...')
        for qid, did, rel, doc, flds, mesh, keywords in tqdm(exs):
            entry = {
                'qid': qid,
                'did': did,
                'src': [],
                'src_sent_lens': [],
                'tgtB': [],
                'tgtB_sent_lens': [],
                'tgtC': [],
                'tgtC_sent_lens': []
            }

            # src
            for s in doc:  # CoreNLP tokenized sequences (list of sentences)
                if len(s) <= self.min_src_ntokens_per_sent:
                    continue
                src_str = ' '.join(s[:self.max_src_ntokens_per_sent])
                entry['src'] += self.tokB.convert_tokens_to_ids(
                    self.tokB.tokenize(src_str))
                entry['src_sent_lens'].append(len(entry['src']))
            if len(entry['src']) == 0:
                continue

            # tgt - fields
            tgt_tokens = set()  # Used in identifying token-level labels
            for seq in flds:  # flds (disease, gene, demo)
                # BERT
                ids = self.tokB.convert_tokens_to_ids(self.tokB.tokenize(seq))
                tgt_tokens.update(ids)
                entry['tgtB'] += ids
                entry['tgtB_sent_lens'].append(len(entry['tgtB']))
                # BMET
                ids = self.tokC.convert_tokens_to_ids(self.tokC.tokenize(seq))
                ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
                entry['tgtC'] += ids
                entry['tgtC_sent_lens'].append(len(entry['tgtC']))

            # tgt - mesh
            mesh = [f'εmesh_{t}' for t in mesh[0].lower().split()]
            ids = self.tokC.convert_tokens_to_ids(mesh)
            ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
            entry['tgtC'] += ids
            entry['tgtC_sent_lens'].append(len(entry['tgtC']))

            # tgt - keywords
            seq = ' '.join(keywords)
            ids = self.tokC.convert_tokens_to_ids(self.tokC.tokenize(seq))
            ids = list(filter(lambda x: x > 1, ids))  # Remove UNKs
            tgt_tokens.update(ids)
            entry['tgtC'] += ids
            entry['tgtC_sent_lens'].append(len(entry['tgtC']))
            entry['token_labels'] = \
                [1 if t in tgt_tokens else 0 for t in entry['src']]
            sum_ = sum(entry['token_labels'])
            self.ext_lbl_freq[0] += len(entry['token_labels']) - sum_
            self.ext_lbl_freq[1] += sum_
            entry['doc_label'] = 0 if rel == 0 else 1
            rtn.append(entry)
        return rtn
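
    # For reference (illustrative sketch, not from the original source): each
    # entry produced by `encode` has roughly the following layout:
    #   {'qid': ..., 'did': ...,
    #    'src': [BERT wordpiece ids of the kept sentences],
    #    'src_sent_lens': [cumulative length of 'src' after each kept sentence],
    #    'tgtB': [...], 'tgtB_sent_lens': [...],  # BERT-tokenized topic fields
    #    'tgtC': [...], 'tgtC_sent_lens': [...],  # BMET ids for fields/mesh/keywords
    #    'token_labels': [1 if a src token id also appears in the target, else 0],
    #    'doc_label': 0 or 1}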

    def build_trec_exs(self, topics, docs):
        """For each topic and doc pair, encode them, and construct example list
        """
        exs = list()
        # Tokenize document using Stanford CoreNLP Tokenizer
        logger.debug(
            'Tokenizing %s documents using Stanford CoreNLP '
            'Tokenizer...', len(docs))
        self.tokenize(docs)

        # Add positive examples
        for qid in topics:
            for did, rel in topics[qid]['docs']:
                if did not in docs or \
                        len(docs[did]['text']) < self.min_src_nsents:
                    continue
                d = docs[did]
                # Complete keywords: doc_keywords > doc_mesh > q_mesh
                keywords = d['keywords'] if len(d['keywords']) > 0 \
                    else d['mesh_names']
                if len(keywords) == 0 and rel > 0:
                    keywords = [topics[qid]['mesh'][1]]

                exs.append(
                    (qid, did, rel, d['text'][:self.max_src_nsents],
                     topics[qid]['fields'], topics[qid]['mesh'], keywords))
                self.doc_lbl_freq[int(rel > 0)] += 1

        # Add negative examples
        neg_docs_ids = [did for did, d in docs.items() if not d['pos']]
        qids = random.choices(list(topics.keys()), k=len(neg_docs_ids))
        for i, did in enumerate(neg_docs_ids):
            exs.append(
                (qids[i], did, 0, docs[did]['text'][:self.max_src_nsents],
                 topics[qids[i]]['fields'], topics[qids[i]]['mesh'], []))
            self.doc_lbl_freq[0] += 1
        random.shuffle(exs)
        rtn = self.encode(exs)

        return rtn

    # TODO: the following function will be changed
    def build(self, examples, docs):
        """Bulding examples is done in two modes: one for data preparation and
        the other for prediction.

        In data preparation,
        - `exs` are quries in TREC ref datasets
        - `docs` consists of pos and neg documents prepared by `read_pubmed_docs`

        In prediction,
        - `exs` only contains one query with no labels
        - `docs` the retrieved documents from Solr search results

        """
        # Tokenize documents and build examples with doc_labels
        exs = []
        # Title and Text are multivalued ('text_general' in Solr)
        results = docs
        docs = {}
        for r in results:
            title = ' '.join(r['ArticleTitle'] if 'ArticleTitle' in r else [])
            body = ' '.join(r['AbstractText'] if 'AbstractText' in r else [])
            docs[r['id']] = {'title': title, 'body': body}
        logger.debug(f'Tokenizing {len(docs)} retrieved docs...')
        # `tokenize` works in place and stores the tokenized sentences
        # under the 'text' key of each document
        self.tokenize(docs)

        # Build examples (with dummy label -1)
        qid = list(examples.keys())[0]  # There is only one anyway
        logger.info(f'Preparing examples for {qid}...')
        for did, d in docs.items():
            if len(d['text']) < self.min_src_nsents:
                continue
            exs.append((qid, did, -1, d['text'][:self.max_src_nsents],
                        examples[qid]['topics']))

        data = self.encode(exs)
        return data
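

# A minimal usage sketch (an assumption, not part of the original source) of
# how ExsBuilder might be driven end to end. `load_trec_topics` and
# `load_pubmed_docs` are hypothetical loaders, and the file paths are
# placeholders; the dict layouts simply mirror the accesses made in
# `build_trec_exs` above (topics[qid]['docs'/'fields'/'mesh'],
# docs[did]['title'/'body'/'keywords'/'mesh_names'/'pos']).
if __name__ == '__main__':
    topics = load_trec_topics('data/topics.xml')   # hypothetical loader
    docs = load_pubmed_docs('data/docs.json')      # hypothetical loader
    builder = ExsBuilder(bert_model='bert-base-uncased',
                         file_emb='data/wbmet.vec')  # placeholder embedding file
    entries = builder.build_trec_exs(topics, docs)
    logger.info('Built %d examples (doc label freq: %s)',
                len(entries), builder.doc_lbl_freq)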
Example #2
class Summarizer:
    """Use a test model to generate fielded query sentences from documents"""
    def __init__(self,
                 f_abs,
                 n_best=1,
                 min_length=1,
                 max_length=50,
                 beam_size=4,
                 bert_model='bert-base-uncased'):
        self.n_best = n_best
        self.min_length = min_length
        self.max_length = max_length
        self.beam_size = beam_size
        self.abs_model = self.load_abs_model(f_abs)
        self.eval()
        logger.info(f'Loading BERT Tokenizer [{bert_model}]...')
        self.tokenizerB = BertTokenizer.from_pretrained(bert_model)
        self.spt_ids_B, self.spt_ids_C, self.eos_mapping = get_special_tokens()
        logger.info('Loading custom Tokenizer for using WBMET embeddings')
        self.tokenizerC = Tokenizer(self.abs_model.args.vocab_size)
        self.tokenizerC.from_pretrained(self.abs_model.args.file_dec_emb)

    @staticmethod
    def load_abs_model(f_abs):
        """Load a pre-trained abs model"""
        logger.info(f'Loading an abstractive test model from {f_abs}...')
        data = torch.load(f_abs, map_location=lambda storage, loc: storage)
        mdl = AbstractiveSummarizer(data['args'])
        mdl.load_state_dict(data['model'])
        mdl.cuda()
        return mdl

    def translate(self, docs):
        """Translate a batch of documents."""
        batch_size = docs.inp.size(0)
        spt_ids = self.spt_ids_C
        decode_strategy = BeamSearch(self.beam_size, batch_size, self.n_best,
                                     self.min_length, self.max_length, spt_ids,
                                     self.eos_mapping)
        return self._translate_batch_with_strategy(docs, decode_strategy)

    def _translate_batch_with_strategy(self, batch, decode_strategy):
        """Translate a batch of documents step by step using cache

        :param batch (dict): A batch of documents
        :param decode_strategy (DecodeStrategy): A decode strategy for
            generating translations step by step, e.g., BeamSearch
        """

        # (1) Run the encoder on the src
        ext_scores, hidden_states = \
            self.abs_model.encoder(batch.inp,
                                   attention_mask=batch.mask_inp,
                                   token_type_ids=batch.segs)

        # (2) Prepare decoder and decode_strategy
        self.abs_model.decoder.init_state(batch.inp)
        field_signals = batch.tgt[:, 0]
        fn_map_state, memory_bank, memory_pad_mask = \
            decode_strategy.initialize(hidden_states[-1], batch.src_lens,
                                       field_signals)
        if fn_map_state is not None:
            self.abs_model.decoder.map_state(fn_map_state)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.unsqueeze(-1)
            dec_out, attns = self.abs_model.decoder(decoder_input,
                                                    memory_bank,
                                                    memory_pad_mask,
                                                    step=step)
            log_probs = self.abs_model.generator(dec_out[:, -1, :].squeeze(1))
            # Beam advance
            decode_strategy.advance(log_probs, attns)

            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states.
                memory_bank = memory_bank.index_select(0, select_indices)
                memory_pad_mask = memory_pad_mask.index_select(
                    0, select_indices)

            if self.beam_size > 1 or any_finished:
                self.abs_model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))
        res = {
            'batch': batch,
            'gold_scores': self._gold_score(batch, hidden_states[-1],
                                            batch.mask_inp),
            'scores': decode_strategy.scores,
            'predictions': decode_strategy.predictions,
            'ext_scores': ext_scores,
            'attentions': decode_strategy.attention
        }
        return res

    def results_to_translations(self, data):
        """Convert results into Translation object"""
        batch = data['batch']
        translations = []
        for i, did in enumerate(batch.did):
            src_input_ = batch.inp[i]
            src_ = self.tokenizerB.decode(src_input_)
            topic_ = \
                self.tokenizerC.convert_id_to_token(batch.tgt[i][0].item())
            pred_sents_ = [
                self.tokenizerC.decode(data['predictions'][i][n])
                for n in range(self.n_best)
            ]
            gold_sent_ = self.tokenizerC.decode(batch.tgt[i])
            x = Translation(did=did,
                            src_input=src_input_,
                            src=src_,
                            topic=topic_,
                            ext_scores=data['ext_scores'][i],
                            pred_sents=pred_sents_,
                            pred_scores=data['scores'][i],
                            gold_sent=gold_sent_,
                            gold_score=data['gold_scores'][i],
                            attentions=data['attentions'][i])
            translations.append(x)
        return translations

    def _gold_score(self, batch, memory_bank, memory_pad_mask):
        if hasattr(batch, 'tgt'):
            gs = self._score_target(batch, memory_bank, memory_pad_mask)
            self.abs_model.decoder.init_state(batch.inp)
        else:
            gs = [0] * batch.batch_size
        return gs

    def _score_target(self, batch, memory_bank, memory_pad_mask):
        tgt_in = batch.tgt[:, :-1]
        dec_out, _ = self.abs_model.decoder(tgt_in, memory_bank,
                                            memory_pad_mask)
        log_probs = self.abs_model.generator(dec_out)
        gold = batch.tgt[:, 1:]
        tgt_pad_mask = gold.eq(self.spt_ids_C['[PAD]'])
        log_probs[tgt_pad_mask] = 0
        gold_scores = log_probs.gather(2, gold.unsqueeze(-1))
        gold_scores = gold_scores.sum(dim=1).view(-1)
        return gold_scores.tolist()

    def eval(self):
        """Put the abstractive model into evaluation mode"""
        self.abs_model.eval()
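

# A minimal usage sketch (an assumption, not part of the original source):
# translating retrieved documents and converting the results into Translation
# objects. `run_solr_query` and `build_batches` are hypothetical helpers, and
# the checkpoint path is a placeholder; a batch is assumed to expose the
# attributes used above (inp, mask_inp, segs, tgt, src_lens, did, batch_size).
if __name__ == '__main__':
    summarizer = Summarizer(f_abs='models/abs.pt', beam_size=4, n_best=1)
    retrieved_docs = run_solr_query('melanoma braf')  # hypothetical retrieval
    for batch in build_batches(retrieved_docs):       # hypothetical batching
        results = summarizer.translate(batch)
        for tr in summarizer.results_to_translations(results):
            print(tr.did, tr.pred_sents[0])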