def disambiguate(self, cuis, entity, name, doc):
        vectors = self.get_context_vectors(entity, doc)
        filters = self.config.linking['filters']

        # If we are in trainer mode we want to filter concepts before disambiguation;
        # the reason is not documented here, but the filtering is needed.
        if self.config.linking['filter_before_disamb']:
            # DEBUG
            self.log.debug("Is trainer, subsetting CUIs")
            self.log.debug("CUIs before: {}".format(cuis))

            cuis = [cui for cui in cuis if check_filters(cui, filters)]
            # DEBUG
            self.log.debug("CUIs after: {}".format(cuis))

        if cuis:  # Maybe none are left after filtering
            # Calculate similarity for each cui
            similarities = [self._similarity(cui, vectors) for cui in cuis]
            # DEBUG
            self.log.debug("Similarities: {}".format([
                (sim, cui) for sim, cui in zip(cuis, similarities)
            ]))

            # Prefer primary
            if self.config.linking.get('prefer_primary_name', 0) > 0:
                self.log.debug("Preferring primary names")
                for i, cui in enumerate(cuis):
                    if similarities[i] > 0:
                        if self.cdb.name2cuis2status.get(name, {}).get(
                                cui, '') in {'P', 'PD'}:
                            old_sim = similarities[i]
                            similarities[i] = min(
                                0.99, similarities[i] +
                                similarities[i] * self.config.linking.get(
                                    'prefer_primary_name', 0))
                            # DEBUG
                            self.log.debug(
                                "CUI: {}, Name: {}, Old sim: {:.3f}, New sim: {:.3f}"
                                .format(cui, name, old_sim, similarities[i]))

            if self.config.linking.get('prefer_frequent_concepts', 0) > 0:
                self.log.debug("Preferring frequent concepts")
                # Prefer frequent concepts
                cnts = [self.cdb.cui2count_train.get(cui, 0) for cui in cuis]
                m = min(cnts) if min(cnts) > 0 else 1
                scales = [
                    np.log10(cnt / m) *
                    self.config.linking.get('prefer_frequent_concepts', 0)
                    if cnt > 10 else 0 for cnt in cnts
                ]
                similarities = [
                    min(0.99, sim + sim * scales[i])
                    for i, sim in enumerate(similarities)
                ]

            # Pick the candidate with the highest similarity
            mx = np.argmax(similarities)
            return cuis[mx], similarities[mx]
        else:
            return None, 0
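
A minimal sketch of the frequency boost applied above, with invented counts and config values (only the formula and the 0.99 cap come from the snippet; the numbers are illustrative):

import numpy as np

# Toy inputs - three candidate CUIs with context similarities and training counts
similarities = [0.40, 0.35, 0.30]
cnts = [500, 20, 0]                   # cui2count_train per candidate
prefer_frequent = 0.35                # config.linking['prefer_frequent_concepts']

m = min(cnts) if min(cnts) > 0 else 1
scales = [np.log10(cnt / m) * prefer_frequent if cnt > 10 else 0 for cnt in cnts]
boosted = [min(0.99, sim + sim * scales[i]) for i, sim in enumerate(similarities)]
print(boosted)  # the frequently trained CUI gets the largest relative boost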
Example 2
def add_annotations(spacy_doc, user, project, document, existing_annotations,
                    cat):
    spacy_doc._.ents.sort(key=lambda x: len(x.text), reverse=True)

    tkns_in = []
    ents = []
    existing_annos_intervals = [(ann.start_ind, ann.end_ind)
                                for ann in existing_annotations]

    def check_ents(ent):
        return any(
            (ea[0] < ent.start_char < ea[1]) or (ea[0] < ent.end_char < ea[1])
            for ea in existing_annos_intervals)

    for ent in spacy_doc._.ents:
        if not check_ents(ent) and check_filters(
                ent._.cui, cat.config.linking['filters']):
            to_add = True
            for tkn in ent:
                if tkn in tkns_in:
                    to_add = False
            if to_add:
                for tkn in ent:
                    tkns_in.append(tkn)
                ents.append(ent)

    for ent in ents:
        label = ent._.cui

        # Add the concept info to the Concept table if it doesn't exist
        if not Concept.objects.filter(cui=label).exists():
            concept = Concept()
            concept.cui = label
            update_concept_model(concept, project.concept_db, cat.cdb)

        if not Entity.objects.filter(label=label).exists():
            # Create the entity
            entity = Entity()
            entity.label = label
            entity.save()
        else:
            entity = Entity.objects.get(label=label)

        cui_count_limit = cat.config.general.get("cui_count_limit", -1)
        pcc = ProjectCuiCounter.objects.filter(project=project,
                                               entity=entity).first()
        if pcc is not None:
            cui_count = pcc.count
        else:
            cui_count = 1

        if cui_count_limit < 0 or cui_count <= cui_count_limit:
            if AnnotatedEntity.objects.filter(
                    project=project,
                    document=document,
                    start_ind=ent.start_char,
                    end_ind=ent.end_char).count() == 0:
                # If this entity doesn't exist already
                ann_ent = AnnotatedEntity()
                ann_ent.user = user
                ann_ent.project = project
                ann_ent.document = document
                ann_ent.entity = entity
                ann_ent.value = ent.text
                ann_ent.start_ind = ent.start_char
                ann_ent.end_ind = ent.end_char
                ann_ent.acc = ent._.context_similarity

                MIN_ACC = cat.config.linking.get(
                    'similarity_threshold_trainer', 0.2)
                if ent._.context_similarity < MIN_ACC:
                    ann_ent.deleted = True
                    ann_ent.validated = True

                ann_ent.save()
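
A small, self-contained illustration of the overlap test that check_ents performs against existing annotations (the intervals and offsets are made up; the strict inequalities mirror the snippet):

# Invented (start_ind, end_ind) pairs for already-saved annotations
existing_annos_intervals = [(10, 20), (40, 55)]

def overlaps(start_char, end_char):
    return any((s < start_char < e) or (s < end_char < e)
               for s, e in existing_annos_intervals)

print(overlaps(12, 18))  # True  - starts inside (10, 20)
print(overlaps(20, 30))  # False - touching a boundary does not count as overlap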
Example 3
    def __call__(self, doc):
        r''' Link detected entities in the document to CUIs, or update the
        context model if training is enabled.
        '''
        cnf_l = self.config.linking
        linked_entities = []

        doc_tkns = [tkn for tkn in doc if not tkn._.to_skip]
        doc_tkn_ids = [tkn.idx for tkn in doc_tkns]

        if cnf_l['train']:
            # Run training
            for entity in doc._.ents:
                # Check whether it has a detected name
                if entity._.detected_name is not None:
                    name = entity._.detected_name
                    cuis = entity._.link_candidates

                    if len(name) >= cnf_l['disamb_length_limit']:
                        if len(cuis) == 1:
                            # 'N' means the name must be disambiguated: it is not the preferred
                            # name of the concept and it also links to other concepts.
                            if self.cdb.name2cuis2status[name][cuis[0]] != 'N':
                                self._train(cui=cuis[0],
                                            entity=entity,
                                            doc=doc)
                                entity._.cui = cuis[0]
                                entity._.context_similarity = 1
                                linked_entities.append(entity)
                        else:
                            for cui in cuis:
                                if self.cdb.name2cuis2status[name][cui] in {
                                        'P', 'PD'
                                }:
                                    self._train(cui=cui,
                                                entity=entity,
                                                doc=doc)
                                    # One name should not be 'P' for two CUIs,
                                    # but it can happen - and we do not care.
                                    entity._.cui = cui
                                    entity._.context_similarity = 1
                                    linked_entities.append(entity)
        else:
            for entity in doc._.ents:
                self.log.debug("Linker started with entity: {}".format(entity))
                # Check whether it has a detected name
                if entity._.link_candidates is not None:
                    if entity._.detected_name is not None:
                        name = entity._.detected_name
                        cuis = entity._.link_candidates

                        if len(cuis) > 0:
                            do_disambiguate = False
                            if len(name) < cnf_l['disamb_length_limit']:
                                do_disambiguate = True
                            elif len(cuis) == 1 and self.cdb.name2cuis2status[
                                    name][cuis[0]] in {'N', 'PD'}:
                                # 'PD' means the name is preferred but still needs disambiguation; 'N' is always disambiguated
                                do_disambiguate = True
                            elif len(cuis) > 1:
                                do_disambiguate = True

                            if do_disambiguate:
                                cui, context_similarity = self.context_model.disambiguate(
                                    cuis, entity, name, doc)
                            else:
                                cui = cuis[0]
                                if self.config.linking[
                                        'always_calculate_similarity']:
                                    context_similarity = self.context_model.similarity(
                                        cui, entity, doc)
                                else:
                                    context_similarity = 1  # Direct link; similarity is not needed
                    else:
                        # No name detected, just disambiguate
                        cui, context_similarity = self.context_model.disambiguate(
                            entity._.link_candidates, entity, 'unk-unk', doc)

                    # Add the annotation if it exists and if above threshold and in filters
                    if cui and check_filters(cui,
                                             self.config.linking['filters']):
                        th_type = self.config.linking.get(
                            'similarity_threshold_type', 'static')
                        if (th_type == 'static' and context_similarity >= self.config.linking['similarity_threshold']) or \
                           (th_type == 'dynamic' and context_similarity >= self.cdb.cui2average_confidence[cui] * self.config.linking['similarity_threshold']):
                            entity._.cui = cui
                            entity._.context_similarity = context_similarity
                            linked_entities.append(entity)

        doc._.ents = linked_entities
        self._create_main_ann(doc)
        return doc
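
The branching above is driven by cdb.name2cuis2status; a hypothetical entry could look like the following (the status codes match the comments in the snippet, while the names and CUIs are invented):

# Hypothetical structure - real CDBs are built from a vocabulary, not by hand
name2cuis2status = {
    "ca": {
        "C0006826": "N",   # 'N'  - must always be disambiguated
        "C0006675": "P",   # 'P'  - preferred name, can be linked directly
    },
    "calcium": {
        "C0006675": "PD",  # 'PD' - preferred, but still disambiguated
    },
}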
Example 4
    def _print_stats(self,
                     data,
                     epoch=0,
                     use_filters=False,
                     use_overlaps=False,
                     use_cui_doc_limit=False,
                     use_groups=False):
        r''' TODO: Refactor and make nice
        Print metrics on a dataset (F1, P, R); it will also print the concepts with the most FPs, FNs, and TPs.

        Args:
            data (list of dict):
                The json object that we get from MedCATtrainer on export.
            epoch (int):
                Used during training, so we know which epoch it is.
            use_filters (boolean):
                Each project in MedCATtrainer can have filters; whether to respect those filters
                when calculating metrics.
            use_overlaps (boolean):
                Allow overlapping entities; nearly always False as it is very difficult to annotate overlapping entities.
            use_cui_doc_limit (boolean):
                If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words
                if the document was annotated for that CUI. Useful in very specific situations when during the annotation
                process the set of CUIs changed.
            use_groups (boolean):
                If True concepts that have groups will be combined and stats will be reported on groups.

        Returns:
            fps (dict):
                False positives for each CUI
            fns (dict):
                False negatives for each CUI
            tps (dict):
                True positives for each CUI
            cui_prec (dict):
                Precision for each CUI
            cui_rec (dict):
                Recall for each CUI
            cui_f1 (dict):
                F1 for each CUI
            cui_counts (dict):
                Number of occurrences for each CUI
            examples (dict):
                Examples for each of the fp, fn, tp. Format will be examples['fp'][cui] = <list_of_examples>
        '''
        tp = 0
        fp = 0
        fn = 0
        fps = {}
        fns = {}
        tps = {}
        cui_prec = {}
        cui_rec = {}
        cui_f1 = {}
        cui_counts = {}
        examples = {'fp': {}, 'fn': {}, 'tp': {}}

        fp_docs = set()
        fn_docs = set()
        # Backup for filters
        _filters = deepcopy(self.config.linking['filters'])
        # Shortcut for filters
        filters = self.config.linking['filters']

        for pind, project in tqdm(enumerate(data['projects']),
                                  desc="Stats project",
                                  total=len(data['projects']),
                                  leave=False):
            if use_filters:
                if type(project.get('cuis', None)) == str:
                    # Old filters
                    filters['cuis'] = process_old_project_filters(
                        cuis=project.get('cuis', None),
                        type_ids=project.get('tuis', None),
                        cdb=self.cdb)
                elif type(project.get('cuis', None)) == list:
                    # New filters
                    filters['cuis'] = project.get('cuis')

            start_time = time.time()
            for dind, doc in tqdm(enumerate(project['documents']),
                                  desc='Stats document',
                                  total=len(project['documents']),
                                  leave=False):
                if type(doc['annotations']) == list:
                    anns = doc['annotations']
                elif type(doc['annotations']) == dict:
                    anns = doc['annotations'].values()

                # Apply document-level filtering if enabled
                if use_cui_doc_limit:
                    _cuis = set([ann['cui'] for ann in anns])
                    if _cuis:
                        filters['cuis'] = _cuis

                spacy_doc = self(doc['text'])

                if use_overlaps:
                    p_anns = spacy_doc._.ents
                else:
                    p_anns = spacy_doc.ents

                anns_norm = []
                anns_norm_neg = []
                anns_examples = []
                anns_norm_cui = []
                for ann in anns:
                    cui = ann['cui']
                    if not use_filters or check_filters(cui, filters):
                        if use_groups:
                            cui = self.cdb.addl_info['cui2group'].get(cui, cui)

                        if ann.get('validated',
                                   True) and (not ann.get('killed', False) and
                                              not ann.get('deleted', False)):
                            anns_norm.append((ann['start'], cui))
                            anns_examples.append({
                                "text": doc['text'][max(0, ann['start'] - 60):ann['end'] + 60],
                                "cui": cui,
                                "source value": ann['value'],
                                "acc": 1,
                                "project index": pind,
                                "document index": dind,
                            })
                        elif ann.get('validated', True) and (ann.get(
                                'killed', False) or ann.get('deleted', False)):
                            anns_norm_neg.append((ann['start'], cui))

                        if ann.get("validated", True):
                            # Used to check whether someone annotated this CUI in this document
                            anns_norm_cui.append(cui)
                            cui_counts[cui] = cui_counts.get(cui, 0) + 1

                p_anns_norm = []
                p_anns_examples = []
                for ann in p_anns:
                    cui = ann._.cui
                    if use_groups:
                        cui = self.cdb.addl_info['cui2group'].get(cui, cui)

                    p_anns_norm.append((ann.start_char, cui))
                    p_anns_examples.append({
                        "text": doc['text'][max(0, ann.start_char - 60):ann.end_char + 60],
                        "cui": cui,
                        "source value": ann.text,
                        "acc": float(ann._.context_similarity),
                        "project index": pind,
                        "document index": dind,
                    })

                for iann, ann in enumerate(p_anns_norm):
                    cui = ann[1]
                    if ann in anns_norm:
                        tp += 1
                        tps[cui] = tps.get(cui, 0) + 1

                        example = p_anns_examples[iann]
                        examples['tp'][cui] = examples['tp'].get(
                            cui, []) + [example]
                    else:
                        fp += 1
                        fps[cui] = fps.get(cui, 0) + 1
                        fp_docs.add(doc.get('name', 'unk'))

                        # Add example for this FP prediction
                        example = p_anns_examples[iann]
                        if ann in anns_norm_neg:
                            # Means that it really was annotated as negative
                            example['real_fp'] = True

                        examples['fp'][cui] = examples['fp'].get(
                            cui, []) + [example]

                for iann, ann in enumerate(anns_norm):
                    if ann not in p_anns_norm:
                        cui = ann[1]
                        fn += 1
                        fn_docs.add(doc.get('name', 'unk'))

                        fns[cui] = fns.get(cui, 0) + 1
                        examples['fn'][cui] = examples['fn'].get(
                            cui, []) + [anns_examples[iann]]

        try:
            prec = tp / (tp + fp)
            rec = tp / (tp + fn)
            f1 = 2 * (prec * rec) / (prec + rec)
            print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(
                epoch, prec, rec, f1))
            print("Docs with false positives: {}\n".format("; ".join(
                [str(x) for x in list(fp_docs)[0:10]])))
            print("Docs with false negatives: {}\n".format("; ".join(
                [str(x) for x in list(fn_docs)[0:10]])))

            # Sort fns & prec
            fps = {
                k: v
                for k, v in sorted(
                    fps.items(), key=lambda item: item[1], reverse=True)
            }
            fns = {
                k: v
                for k, v in sorted(
                    fns.items(), key=lambda item: item[1], reverse=True)
            }
            tps = {
                k: v
                for k, v in sorted(
                    tps.items(), key=lambda item: item[1], reverse=True)
            }

            # F1 per concept
            for cui in tps.keys():
                prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0))
                rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0))
                f1 = 2 * (prec * rec) / (prec + rec)
                cui_prec[cui] = prec
                cui_rec[cui] = rec
                cui_f1[cui] = f1

            # Get top 10
            pr_fps = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui])
                      for cui in list(fps.keys())[0:10]]
            pr_fns = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui])
                      for cui in list(fns.keys())[0:10]]
            pr_tps = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui])
                      for cui in list(tps.keys())[0:10]]

            print("\n\nFalse Positives\n")
            for one in pr_fps:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("\n\nFalse Negatives\n")
            for one in pr_fns:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("\n\nTrue Positives\n")
            for one in pr_tps:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("*" * 110 + "\n")

        except Exception:
            traceback.print_exc()

        self.config.linking['filters'] = _filters

        return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples
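
A hedged usage sketch for the method above: load a MedCATtrainer export and unpack the returned tuple (the JSON shape follows the keys read in the code; cat is assumed to be an already-initialised model exposing this method, and the file name is invented):

import json

with open("medcattrainer_export.json") as f:
    data = json.load(f)   # {"projects": [{"documents": [{"text": ..., "annotations": [...]}]}]}

fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = cat._print_stats(
    data, use_filters=True, use_groups=False)

# Per-CUI F1 for the ten most frequently annotated concepts
for cui in sorted(cui_counts, key=cui_counts.get, reverse=True)[:10]:
    print(cui, round(cui_f1.get(cui, 0), 3))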
Example 5
def add_annotations(spacy_doc, user, project, document, existing_annotations,
                    cat):
    spacy_doc._.ents.sort(key=lambda x: len(x.text), reverse=True)

    tkns_in = []
    ents = []
    existing_annos_intervals = [(ann.start_ind, ann.end_ind)
                                for ann in existing_annotations]

    def check_ents(ent):
        return any(
            (ea[0] < ent.start_char < ea[1]) or (ea[0] < ent.end_char < ea[1])
            for ea in existing_annos_intervals)

    for ent in spacy_doc._.ents:
        if not check_ents(ent) and check_filters(
                ent._.cui, cat.config.linking['filters']):
            to_add = True
            for tkn in ent:
                if tkn in tkns_in:
                    to_add = False
            if to_add:
                for tkn in ent:
                    tkns_in.append(tkn)
                ents.append(ent)

    for ent in ents:
        label = ent._.cui
        tuis = list(cat.cdb.cui2type_ids.get(label, []))

        # Add the concept info to the Concept table if it doesn't exist
        cnt = Concept.objects.filter(cui=label).count()
        if cnt == 0:
            pretty_name = cat.cdb.cui2preferred_name.get(label, label)

            concept = Concept()
            concept.pretty_name = pretty_name
            concept.cui = label
            concept.tui = ','.join(tuis)
            concept.semantic_type = ','.join([
                cat.cdb.addl_info['type_id2name'].get(tui, '') for tui in tuis
            ])
            concept.desc = cat.cdb.addl_info['cui2description'].get(label, '')
            concept.synonyms = ",".join(
                cat.cdb.addl_info['cui2original_names'].get(label, []))
            concept.cdb = project.concept_db
            concept.save()

        cnt = Entity.objects.filter(label=label).count()
        if cnt == 0:
            # Create the entity
            entity = Entity()
            entity.label = label
            entity.save()
        else:
            entity = Entity.objects.get(label=label)

        cui_count_limit = cat.config.general.get("cui_count_limit", -1)
        pcc = ProjectCuiCounter.objects.filter(project=project,
                                               entity=entity).first()
        if pcc is not None:
            cui_count = pcc.count
        else:
            cui_count = 1

        if cui_count_limit < 0 or cui_count <= cui_count_limit:
            if AnnotatedEntity.objects.filter(
                    project=project,
                    document=document,
                    start_ind=ent.start_char,
                    end_ind=ent.end_char).count() == 0:
                # If this entity doesn't exist already
                ann_ent = AnnotatedEntity()
                ann_ent.user = user
                ann_ent.project = project
                ann_ent.document = document
                ann_ent.entity = entity
                ann_ent.value = ent.text
                ann_ent.start_ind = ent.start_char
                ann_ent.end_ind = ent.end_char
                ann_ent.acc = ent._.context_similarity

                MIN_ACC = cat.config.linking.get(
                    'similarity_threshold_trainer', 0.2)
                if ent._.context_similarity < MIN_ACC:
                    ann_ent.deleted = True
                    ann_ent.validated = True

                ann_ent.save()
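
A compact sketch of the two gates above that decide whether a new AnnotatedEntity is stored and whether it is pre-marked as deleted (the limits and similarity values are invented; the defaults -1 and 0.2 mirror the snippet):

# Invented values for illustration
cui_count_limit = 10        # cat.config.general.get("cui_count_limit", -1)
cui_count = 3               # count from ProjectCuiCounter, or 1 if none exists
context_similarity = 0.15
MIN_ACC = 0.2               # cat.config.linking.get('similarity_threshold_trainer', 0.2)

if cui_count_limit < 0 or cui_count <= cui_count_limit:
    deleted = validated = context_similarity < MIN_ACC
    print("annotation stored; pre-marked deleted:", deleted)
else:
    print("skipped: project already has enough annotations for this CUI")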