Example #1
    def predict_for_EL(self, X, y=None):
        predictions = []
        annotated_documents = []
        layer_weights = []
        # secondary model that exposes the attention ('scaled_dot_product') layer outputs
        model = Model(
            inputs=self.model.input,
            outputs=self.model.get_layer(name='scaled_dot_product').output)
        for idx, document in enumerate(X):

            # get tokens in BIO format
            tokens, _ = transform_annotated_document_to_bio_format(document)
            lengths = [len(tokens)]
            x_test = self.p.transform([tokens])

            # predict tag probabilities and decode them back to BIO tags
            output = self.model.predict(x=x_test)
            tags = self.p.inverse_transform(output, lengths)
            predictions.append(tags)

            annotated_documents.append(
                transform_bio_tags_to_annotated_document(
                    tokens, tags[0], document))

            # collect the attention-layer output for this document
            layer_weights.append(model.predict(x=x_test))

        return annotated_documents, predictions, layer_weights
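The method above taps an intermediate Keras layer by building a second `Model` over the same input graph. A minimal, self-contained sketch of that pattern follows; the toy network, layer sizes, and reuse of the layer name are illustrative and not the NERDS model itself.

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Toy network with a named hidden layer, mirroring how predict_for_EL taps
# the 'scaled_dot_product' layer of the trained tagger (illustrative only).
inputs = Input(shape=(16,))
hidden = Dense(8, name='scaled_dot_product')(inputs)
outputs = Dense(4, activation='softmax')(hidden)
toy_model = Model(inputs=inputs, outputs=outputs)

# A second Model sharing the same input graph exposes the hidden activations.
extractor = Model(inputs=toy_model.input,
                  outputs=toy_model.get_layer(name='scaled_dot_product').output)

x = np.random.rand(2, 16).astype('float32')
tag_probs = toy_model.predict(x)   # shape (2, 4), the usual predictions
layer_out = extractor.predict(x)   # shape (2, 8), the intermediate outputs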
Example #2
    def transform(self, X, y=None):
        """ Annotates the list of `Document` objects that are provided as
            input and returns a list of `AnnotatedDocument` objects.
        """
        log.info(
            "Annotating named entities in {} documents with BiLSTM...".format(
                len(X)))
        annotated_documents = []
        for idx, document in enumerate(X):

            # get tokens in BIO format
            tokens, _ = transform_annotated_document_to_bio_format(document)

            lengths = [len(tokens)]
            x_test = self.p.transform([tokens])

            # get predicted tags
            output = self.model.predict(x=x_test)
            tags = self.p.inverse_transform(output, lengths)

            # annotate a document
            annotated_documents.append(
                transform_bio_tags_to_annotated_document(
                    tokens, tags[0], document))
            # log progress
            log_progress(log, idx, len(X))

        return annotated_documents
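The `transform` method above delegates span reconstruction to `transform_bio_tags_to_annotated_document`. As a rough illustration of what that reconstruction involves, here is a hypothetical stand-alone helper (not the library function) that collapses BIO tags into token-index spans:

# Hypothetical helper illustrating the idea behind
# transform_bio_tags_to_annotated_document: collapse BIO tags over a token
# sequence into (start_token, end_token, entity_type) spans.
def bio_to_spans(tokens, tags):
    spans = []
    start, etype = None, None
    for i, tag in enumerate(tags):
        if tag.startswith('B_'):
            if start is not None:
                spans.append((start, i - 1, etype))
            start, etype = i, tag[2:]
        elif tag.startswith('I_') and start is not None:
            continue
        else:  # 'O', or an I_ tag without a preceding B_
            if start is not None:
                spans.append((start, i - 1, etype))
            start, etype = None, None
    if start is not None:
        spans.append((start, len(tags) - 1, etype))
    return spans

tokens = ['Patients', 'with', 'breast', 'cancer', '.']
tags = ['O', 'O', 'B_Disease', 'I_Disease', 'O']
print(bio_to_spans(tokens, tags))  # [(2, 3, 'Disease')]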
def annotated_docs_to_tokens(docs, sentence_pad=False):
    """Align the sentence-split, re-tokenized text of each document with its
    BIO tokens and labels, optionally padding sentences with <s>-style
    boundary markers.
    """
    text_list = []
    label_list = []
    tokens_list = []
    for doc in docs:
        if sentence_pad:
            text = [[r'<s>'] + text_to_tokens(sent) + [r'<\s>']
                    for sent in text_to_sentences(doc.plain_text_)[0]
                    if len(sent.split()) > 0]
        else:
            text = [
                text_to_tokens(sent)
                for sent in text_to_sentences(doc.plain_text_)[0]
                if len(sent.split()) > 0
            ]

        text_list.append(text)

        # record the flat positions of the <s>/<\s> pad tokens
        count = 0
        pad_index = []
        for line in text:
            for idx, word in enumerate(line):
                if word == r'<s>' or word == r'<\s>':
                    pad_index.append(count + idx)
            count += len(line)

        tokens, labels = transform_annotated_document_to_bio_format(doc)

        # walk the re-tokenized text and the BIO tokens in parallel; when the
        # tokenizer split a word differently, replace it in place with the
        # matching BIO token(s) so the two sequences stay aligned
        count = 0
        for i, line in enumerate(text_list[-1]):
            for j, word in enumerate(line):
                if word not in [r'<s>', r'<\s>'] and word != tokens[count]:
                    k = 0
                    if tokens[count] in word:
                        text_list[-1][i][j] = tokens[count + k]
                        k += 1
                    while count + k < len(tokens) and tokens[count + k] in word:
                        text_list[-1][i].insert(j + k, tokens[count + k])
                        k += 1
                    count += 1
                elif word not in [r'<s>', r'<\s>']:
                    count += 1

        # insert 'O' labels and <s> placeholders at the pad positions so all
        # three lists stay aligned
        for i in pad_index:
            labels.insert(i, 'O')
            tokens.insert(i, r'<s>')

        label_list.append(labels)
        tokens_list.append(tokens)

    return text_list, label_list, tokens_list
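The padding step at the end of `annotated_docs_to_tokens` can be looked at in isolation: once the flat positions of the boundary markers are known, the BIO labels and tokens are padded at the same indices so all three lists stay aligned. A small self-contained sketch with made-up tokens and labels:

# Hypothetical sketch of the padding step in annotated_docs_to_tokens.
sentences = [['<s>', 'Breast', 'cancer', '<\\s>'], ['<s>', 'is', 'treated', '<\\s>']]
tokens = ['Breast', 'cancer', 'is', 'treated']
labels = ['B_Disease', 'I_Disease', 'O', 'O']

# flat positions of the boundary markers across all sentences
pad_index, count = [], 0
for line in sentences:
    for idx, word in enumerate(line):
        if word in ('<s>', '<\\s>'):
            pad_index.append(count + idx)
    count += len(line)

# pad labels and tokens at the same positions
for i in pad_index:
    labels.insert(i, 'O')
    tokens.insert(i, '<s>')

print(tokens)  # ['<s>', 'Breast', 'cancer', '<s>', '<s>', 'is', 'treated', '<s>']
print(labels)  # ['O', 'B_Disease', 'I_Disease', 'O', 'O', 'O', 'O', 'O']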
def EL_set(docs, toD_mesh, id2idx_dict):
    """Build the entity-linking set: for each annotated document, map every
    normalization to a MeSH D index and its token-level mask.
    """
    data_dict = {}
    all_labels = []
    for idx, doc in enumerate(docs):
        _, bio_labels = transform_annotated_document_to_bio_format(doc)
        entity_list = get_normalizations(doc)
        masks = get_masks(bio_labels, len(entity_list))

        label, mask_list = [], []
        for i in range(len(entity_list)):
            # keep only the first ID of composite ('+' or '|') normalizations
            if '+' in entity_list[i]:
                entity_list[i] = entity_list[i].split('+')[0]
            elif '|' in entity_list[i]:
                entity_list[i] = entity_list[i].split('|')[0]
            # IDs outside the D-MeSH dictionary are converted (e.g. MeSH C / OMIM to D MeSH)
            if entity_list[i] not in id2idx_dict:
                item = toD_mesh.transform(entity_list[i])
                if item is not None:
                    if item not in id2idx_dict:
                        print(f"D MeSH {item} not found in Disease list. "
                              f"Skipping this normalization...")
                        continue
                    entity_list[i] = item
                else:
                    print(f"D MeSH equivalent of {entity_list[i]} not found. "
                          f"Skipping this normalization...")
                    continue
        
            # store the label index and the token-level mask for this entity
            label.append(torch.tensor(id2idx_dict[entity_list[i]]))
            mask = masks[i].tolist()
            mask_list.append(torch.tensor(mask))

            all_labels.append(entity_list[i])

        data_dict[doc.identifier] = (label, mask_list)
    
    return data_dict
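`EL_set` first reduces composite normalizations to a single concept ID before looking it up in `id2idx_dict`; anything still missing is routed through `toD_mesh.transform` in the real code. A tiny stand-alone sketch of the reduction step (the IDs below are illustrative only):

# Hypothetical sketch of the ID-normalization step in EL_set: composite
# annotations joined by '+' or '|' keep only their first concept ID.
def normalize_concept_id(raw_id):
    if '+' in raw_id:
        return raw_id.split('+')[0]
    if '|' in raw_id:
        return raw_id.split('|')[0]
    return raw_id

print(normalize_concept_id('D003920+D007674'))  # D003920
print(normalize_concept_id('D003920|C537775'))  # D003920
print(normalize_concept_id('D007674'))          # D007674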
Example #5
def construct_data(data,
                   annotated_docs,
                   predictions,
                   scope_note,
                   id_dict,
                   ctd_file,
                   c2m_file,
                   use_ELMO=True,
                   elmo_model=None,
                   elmo_dim=1024,
                   device=torch.device('cpu')):
    """ re-format the data in easily trainable format using pytorch generators
    """
    text = []  # sentence
    text_emb = []
    scope = []  # scope note
    m_id = []  # mesh ID
    mask_list = []  # mask list
    label = []  # labels for positive and negative examples

    toD = Convert2D(ctd_file, c2m_file)
    skipped_id = []
    for idx, pred_doc in enumerate(annotated_docs):
        tags = predictions[idx]

        o_doc = data[idx]

        tokens, bio_labels = transform_annotated_document_to_bio_format(o_doc)

        new_tags = check_tags(bio_labels, tags)
        entity_list = get_normalizations(o_doc, copy.deepcopy(pred_doc))

        masks = get_masks(new_tags, len(entity_list))

        for i in range(len(entity_list)):
            # keep only the first ID of composite ('+' or '|') normalizations
            if '+' in entity_list[i]:
                entity_list[i] = entity_list[i].split('+')[0]
            elif '|' in entity_list[i]:
                entity_list[i] = entity_list[i].split('|')[0]
            # IDs outside the D-MeSH dictionary are converted (e.g. MeSH C / OMIM to D MeSH)
            if entity_list[i] not in id_dict:
                item = toD.transform(entity_list[i])
                if item is not None:
                    if item not in id_dict:
                        print(
                            f"D MeSH {item} not found in Disease list. Skipping this normalization..."
                        )
                        skipped_id.append(item)
                        continue
                    entity_list[i] = item
                else:
                    print(
                        f"D MeSH equivalent of {entity_list[i]} not found. Skipping this normalization..."
                    )
                    skipped_id.append(entity_list[i])
                    continue
            note = []
            # sample fields: text, scope_note, MeSH ID, mask, positive label
            if use_ELMO:
                t = [[r'<s>'] + text_to_tokens(sent) + [r'<\s>']
                     for sent in text_to_sentences(pred_doc.plain_text_)[0]
                     if len(sent.split()) > 0]

                char_id = batch_to_ids(t).to(device)
                with torch.no_grad():
                    elmo_emb = elmo_model(char_id)
                t_emb = elmo_emb['elmo_representations'][0].view(
                    -1, elmo_dim).detach().cpu()
                # drop all-zero (padding) token vectors and re-stack
                t_emb = torch.stack(
                    [tensor for tensor in t_emb
                     if len(np.nonzero(tensor.numpy())[0]) != 0],
                    dim=0)
                text_emb.append(t_emb)
                text.extend(t)

                note = scope_note[id_dict[entity_list[i]]]
                note = batch_to_ids(note).to(device)
                with torch.no_grad():
                    elmo_emb = elmo_model(note)
                note = elmo_emb['elmo_representations'][0].view(
                    -1, elmo_dim).detach().cpu()
                scope.append(note)
                mask = masks[i].tolist()
                mask = adjust_mask(mask, t, tokens)
                mask_list.append(torch.tensor(mask))

            else:
                t = text_to_tokens(pred_doc.plain_text_)
                text.append(t)
                _ = [
                    note.extend(line[1:-1])
                    for line in scope_note[id_dict[entity_list[i]]]
                    if len(line) > 1
                ]
                scope.append(note)
                mask = masks[i].tolist()
                mask = adjust_mask(mask, [t], tokens)
                mask_list.append(torch.tensor(mask))

                assert (len(t) == len(mask)
                        ), 'Length of mask is not equal to length of sentence.'

            m_id.append(entity_list[i])
            label.append(1)

    print('Total skipped: ', len(skipped_id), ' unique skips: ',
          len(set(skipped_id)))
    sample = []
    for i in range(len(text)):
        sample.append(
            (text[i], text_emb[i], scope[i], m_id[i], mask_list[i], label[i]))

    return sample, text
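`construct_data` relies on AllenNLP's `batch_to_ids` and an `Elmo` module to embed both the sentence tokens and the scope notes. A minimal sketch of that embedding call, assuming AllenNLP is installed and the ELMo option/weight files have been downloaded; the file names and example sentence are placeholders:

import torch
from allennlp.modules.elmo import Elmo, batch_to_ids

# Placeholder paths: point these at real downloaded ELMo files.
options_file = "elmo_options.json"
weight_file = "elmo_weights.hdf5"
elmo_model = Elmo(options_file, weight_file, num_output_representations=1, dropout=0)

# Embed a batch of tokenized sentences, as construct_data does for `t` and the
# scope notes, then flatten to one 1024-d vector per token.
sentences = [['<s>', 'Breast', 'cancer', '<\\s>']]
char_ids = batch_to_ids(sentences)            # (batch, tokens, 50) char tensor
with torch.no_grad():
    out = elmo_model(char_ids)
emb = out['elmo_representations'][0]          # (batch, tokens, 1024)
token_emb = emb.reshape(-1, 1024)             # flattened, as in construct_data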
def get_formatted_data(x_data, pred_data, tags, weights, device='cpu'):
    """Pool the attention weights of every correctly predicted entity span and
    pair each pooled embedding with its normalized concept from the gold data.
    """
    entity_list = []
    entity_emb_list = []
    for doc_idx, doc in enumerate(x_data):
        pred_doc = pred_data[doc_idx]
        pred_tag = tags[doc_idx]
        pred_weight = weights[doc_idx]
        max_sent_length = len(pred_tag)

        # 1 for predicted entity tokens, 0 for 'O' tokens
        mask = torch.tensor([0 if tag == 'O' else 1 for tag in pred_tag],
                            dtype=torch.float,
                            device=device)
        mask = mask.unsqueeze(1).expand(-1, pred_weight.shape[1])
        # consider only the attention weight of the sentence and keep only the entity weights
        pred_weight = pred_weight[0:max_sent_length] * mask

        tokens, bio_labels = transform_annotated_document_to_bio_format(doc)
        update = False
        emb = []
        for lbl_idx, label in enumerate(pred_tag):
            # check whether the 'Begin' label matches between true and predicted tags
            if 'B_' in label and 'B_' in bio_labels[lbl_idx]:
                # update if the full entity span is found (single or multi word)
                if update:
                    update = False
                    emb = np.asarray(emb)
                    # take average pooling of the (single or) multi word entity
                    entity_emb_list.append(np.mean(emb, axis=0))
                    emb = []
                # check whether the next true and predicted tags match; needed for a complete 'BI' sequence
                if lbl_idx + 1 < max_sent_length and pred_tag[
                        lbl_idx + 1] != bio_labels[lbl_idx + 1]:
                    continue
                update = True
                emb.append(pred_weight[lbl_idx].cpu().detach().numpy())
            # check whether the 'Inside' label matches between true and predicted tags
            elif 'I_' in label and 'I_' in bio_labels[lbl_idx]:
                # check whether the next true and predicted tags match; needed for a complete 'BI' sequence
                if lbl_idx + 1 < max_sent_length and pred_tag[
                        lbl_idx + 1] != bio_labels[lbl_idx + 1]:
                    update = False
                    continue
                # add only if 'Begin' tag matches
                if update:
                    emb.append(pred_weight[lbl_idx].cpu().detach().numpy())

            if 'O' in label or lbl_idx == max_sent_length - 1:
                # update if the full entity span is found (single or multi word)
                if update:
                    update = False
                    emb = np.asarray(emb)
                    # take average pooling of the (single or) multi word entity
                    entity_emb_list.append(np.mean(emb, axis=0))
                    emb = []

        # find the corresponding normalized concept in true annotated docs
        for annotation in doc.annotations:
            # max predicted annotations
            max_ann_length = len(pred_doc.annotations)
            ann_counter = 0
            # iterate in predicted annotations
            complete_flag = False
            while ann_counter < max_ann_length:
                # find the matching annotation text and offset in the predicted annotations
                # TODO: properly fix inverted comma issue in the whole NERDS
                if '"' in pred_doc.annotations[ann_counter].text:
                    flag = annotation.text.replace(
                        ' ',
                        '') == pred_doc.annotations[ann_counter].text.replace(
                            '"', '').replace(' ', '')
                else:
                    flag = False
                if (annotation.text == pred_doc.annotations[ann_counter].text
                        or flag) and annotation.offset == pred_doc.annotations[
                            ann_counter].offset:
                    ann_identifier = annotation.identifier
                    pred_doc.annotations.remove(
                        pred_doc.annotations[ann_counter])
                    for norm in doc.normalizations:
                        if norm.argument_id == ann_identifier:
                            entity_list.append(norm.preferred_term)
                            complete_flag = True
                            break
                if complete_flag:
                    break
                ann_counter += 1

        assert (len(entity_list) == len(entity_emb_list)
                ), 'Entity ID and embedding mismatch'

    return entity_list, entity_emb_list
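The pooling performed inside `get_formatted_data` reduces the per-token attention weights of one `B_ ... I_` entity span to a single vector by mean pooling. A tiny NumPy sketch of that step with made-up weights:

import numpy as np

# Hypothetical sketch of the pooling in get_formatted_data: the per-token
# attention weights of one entity span are stacked and mean-pooled into a
# single entity embedding.
token_weights = np.array([
    [0.1, 0.2, 0.3],   # 'breast'  (B_Disease)
    [0.3, 0.2, 0.1],   # 'cancer'  (I_Disease)
])
entity_embedding = np.mean(token_weights, axis=0)
print(entity_embedding)  # [0.2 0.2 0.2]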
Example #7
    def transform(self, X, y=None):
        """ Annotates the list of `Document` objects that are provided as
            input and returns a list of `AnnotatedDocument` objects.
        """
        log.info(
            "Annotating named entities in {} documents with BiLSTM...".format(
                len(X)))
        annotated_documents = []
        for idx, document in enumerate(X):

            # get tokens
            tokens, _ = transform_annotated_document_to_bio_format(document)

            # encode tokens and pad the sequence
            coded_tokens = [
                self.encoder.encode_word(token) for token in tokens
            ]
            x = pad_sequences(maxlen=self.encoder.max_len,
                              sequences=[coded_tokens],
                              padding="post",
                              value=self.encoder.encode_word(PAD_WORD))
            inputs = [x]

            # add encoded and padded char sequences if needed
            if self.config.get_parameter("use_char_emb"):
                c = [[self.encoder.encode_char(char) for char in token]
                     for token in tokens]
                c = pad_sequences(
                    maxlen=self.encoder.max_len_char,
                    sequences=c,
                    padding="post",
                    value=self.encoder.encode_char(PAD_CHAR)).tolist()
                # add padding chars for padding words
                for i in range(len(tokens), self.encoder.max_len):
                    c.append([self.encoder.encode_char(PAD_CHAR)] *
                             self.encoder.max_len_char)
                c = np.array([c], ndmin=3)
                inputs.append(c)

            # add encoded and padded POS tag sequences if needed
            if self.config.get_parameter("use_pos_emb"):
                pos_tags = tokens_to_pos_tags(tokens)
                coded_pos_tags = [
                    self.encoder.encode_pos(pos) for pos in pos_tags
                ]
                p = pad_sequences(maxlen=self.encoder.max_len,
                                  sequences=[coded_pos_tags],
                                  padding="post",
                                  value=self.encoder.encode_pos(PAD_POS))
                inputs.append(p)

            # get predicted tags
            output = self.model.predict(x=inputs)
            coded_tags = np.argmax(output, axis=-1)[0]
            tags = [self.encoder.decode_tag(idx) for idx in coded_tags]
            tags = tags[:len(tokens)]

            # annotate a document
            annotated_documents.append(
                transform_bio_tags_to_annotated_document(
                    tokens, tags, document))
            # log progress
            log_progress(log, idx, len(X))

        return annotated_documents
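The padding in this `transform` uses Keras `pad_sequences` to bring every encoded sequence to a fixed length before prediction. A minimal sketch of that call; the maximum length, token ids, and PAD value are illustrative:

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Post-pad an encoded token sequence to a fixed maximum length with a
# dedicated PAD id, as the transform above does for each document.
coded_tokens = [12, 7, 45, 3]
x = pad_sequences(maxlen=8, sequences=[coded_tokens], padding="post", value=0)
print(x)  # [[12  7 45  3  0  0  0  0]]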