Example 1
class CAT(object):
    r'''
    The main MedCAT class used to annotate documents. It is built on top of spaCy
    and works as a spaCy pipeline. Creates an instance of a spaCy pipeline that can
    be used as a spaCy nlp model.

    Args:
        cdb (medcat.cdb.CDB):
            The concept database that will be used for NER+L
        config (medcat.config.Config):
            Global configuration for medcat
        vocab (medcat.vocab.Vocab, optional):
            Vocabulary used for vector embeddings and spelling. Default: None
        meta_cats (list of medcat.meta_cat.MetaCAT, optional):
            A list of models that will be applied sequentially on each
            detected annotation.

    Attributes (limited):
        cdb (medcat.cdb.CDB):
            Concept database used with this CAT instance, please do not assign
            this value directly.
        config (medcat.config.Config):
            The global configuration for medcat. Usually cdb.config can be used for this
            field.
        vocab (medcat.utils.vocab.Vocab):
            The vocabulary object used with this instance, please do not assign
            this value directly.
        config - WILL BE REMOVED - TEMPORARY PLACEHOLDER

    Examples:
        >>>cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
        >>>spacy_doc = cat("Put some text here")
        >>>print(spacy_doc.ents) # Detected entities
    '''
    log = logging.getLogger(__package__)
    # Add file and console handlers
    log = add_handlers(log)

    def __init__(self, cdb, config, vocab, meta_cats=[]):
        self.cdb = cdb
        self.vocab = vocab
        # Take config from the cdb
        self.config = config

        # Set log level
        self.log.setLevel(self.config.general['log_level'])

        # Build the pipeline
        self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
        self.nlp.add_tagger(tagger=partial(tag_skip_and_punct,
                                           config=self.config),
                            name='skip_and_punct',
                            additional_fields=['is_punct'])

        spell_checker = BasicSpellChecker(cdb_vocab=self.cdb.vocab,
                                          config=self.config,
                                          data_vocab=vocab)
        self.nlp.add_token_normalizer(spell_checker=spell_checker,
                                      config=self.config)

        # Add NER
        self.ner = NER(self.cdb, self.config)
        self.nlp.add_ner(self.ner)

        # Add LINKER
        self.linker = Linker(self.cdb, vocab, self.config)
        self.nlp.add_linker(self.linker)

        # Add meta_annotation classes if they exist
        self._meta_annotations = False
        for meta_cat in meta_cats:
            self.nlp.add_meta_cat(meta_cat, meta_cat.category_name)
            self._meta_annotations = True

        # Set max document length
        self.nlp.nlp.max_length = self.config.preprocessing.get(
            'max_document_length', 1000000)

    def get_spacy_nlp(self):
        ''' Returns the spacy pipeline with MedCAT
        '''
        return self.nlp.nlp

    def __call__(self, text, do_train=False):
        r'''
        Push the text through the pipeline.

        Args:
            text (string):
                The text to be annotated; if it is longer than
                self.config.preprocessing['max_document_length'] it will be trimmed to that length.
            do_train (bool, defaults to `False`):
                Runs the document through the pipeline in training mode if True. Prefer the
                self.train() function for training; this flag is kept only for a few special cases.
        Returns:
            A spacy document with the extracted entities
        '''
        # Should we train? Do not use this for training unless you know what you
        # are doing; use the self.train() function instead.
        self.config.linking['train'] = do_train

        if text and len(text) > 0:
            return self.nlp(text[0:self.config.preprocessing.
                                 get('max_document_length', 1000000)])
        else:
            return None
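
    # Usage sketch (hedged), assuming a CAT instance built from an existing
    # cdb, config and vocab as described in the class docstring:
    #     cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
    #     doc = cat("The patient presented with severe hypertension.")
    #     for ent in doc.ents:
    #         print(ent.text, ent._.cui, float(ent._.context_similarity))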

    def _print_stats(self,
                     data,
                     epoch=0,
                     use_filters=False,
                     use_overlaps=False,
                     use_cui_doc_limit=False,
                     use_groups=False):
        r''' TODO: Refactor and make nice
        Print metrics on a dataset (F1, P, R); it will also print the concepts with the most FPs, FNs and TPs.

        Args:
            data (list of dict):
                The json object that we get from MedCATtrainer on export.
            epoch (int):
                Used during training, so we know which epoch it is.
            use_filters (boolean):
                Each project in MedCATtrainer can have filters; should those filters be respected
                when calculating metrics.
            use_overlaps (boolean):
                Allow overlapping entities, nearly always False as it is very difficult to annotate overlapping entities.
            use_cui_doc_limit (boolean):
                If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words
                if the document was annotated for that CUI. Useful in very specific situations when during the annotation
                process the set of CUIs changed.
            use_groups (boolean):
                If True concepts that have groups will be combined and stats will be reported on groups.

        Returns:
            fps (dict):
                False positives for each CUI
            fns (dict):
                False negatives for each CUI
            tps (dict):
                True positives for each CUI
            cui_prec (dict):
                Precision for each CUI
            cui_rec (dict):
                Recall for each CUI
            cui_f1 (dict):
                F1 for each CUI
            cui_counts (dict):
                Number of occurrences for each CUI
            examples (dict):
                Examples for each of the fp, fn, tp. Format is examples['fp'][cui] = [<list_of_examples>]
        '''
        tp = 0
        fp = 0
        fn = 0
        fps = {}
        fns = {}
        tps = {}
        cui_prec = {}
        cui_rec = {}
        cui_f1 = {}
        cui_counts = {}
        examples = {'fp': {}, 'fn': {}, 'tp': {}}

        fp_docs = set()
        fn_docs = set()
        # Backup for filters
        _filters = deepcopy(self.config.linking['filters'])
        # Shortcut for filters
        filters = self.config.linking['filters']

        for pind, project in tqdm(enumerate(data['projects']),
                                  desc="Stats project",
                                  total=len(data['projects']),
                                  leave=False):
            if use_filters:
                if type(project.get('cuis', None)) == str:
                    # Old filters
                    filters['cuis'] = process_old_project_filters(
                        cuis=project.get('cuis', None),
                        type_ids=project.get('tuis', None),
                        cdb=self.cdb)
                elif type(project.get('cuis', None)) == list:
                    # New filters
                    filters['cuis'] = project.get('cuis')

            start_time = time.time()
            for dind, doc in tqdm(enumerate(project['documents']),
                                  desc='Stats document',
                                  total=len(project['documents']),
                                  leave=False):
                if type(doc['annotations']) == list:
                    anns = doc['annotations']
                elif type(doc['annotations']) == dict:
                    anns = doc['annotations'].values()

                # Apply document level filtering, if enabled
                if use_cui_doc_limit:
                    _cuis = set([ann['cui'] for ann in anns])
                    if _cuis:
                        filters['cuis'] = _cuis

                spacy_doc = self(doc['text'])

                if use_overlaps:
                    p_anns = spacy_doc._.ents
                else:
                    p_anns = spacy_doc.ents

                anns_norm = []
                anns_norm_neg = []
                anns_examples = []
                anns_norm_cui = []
                for ann in anns:
                    cui = ann['cui']
                    if not use_filters or check_filters(cui, filters):
                        if use_groups:
                            cui = self.cdb.addl_info['cui2group'].get(cui, cui)

                        if ann.get('validated',
                                   True) and (not ann.get('killed', False) and
                                              not ann.get('deleted', False)):
                            anns_norm.append((ann['start'], cui))
                            anns_examples.append({
                                "text":
                                doc['text'][max(0, ann['start'] -
                                                60):ann['end'] + 60],
                                "cui":
                                cui,
                                "source value":
                                ann['value'],
                                "acc":
                                1,
                                "project index":
                                pind,
                                "document inedex":
                                dind
                            })
                        elif ann.get('validated', True) and (ann.get(
                                'killed', False) or ann.get('deleted', False)):
                            anns_norm_neg.append((ann['start'], cui))

                        if ann.get("validated", True):
                            # This is used to test whether someone annotated this CUI in this document
                            anns_norm_cui.append(cui)
                            cui_counts[cui] = cui_counts.get(cui, 0) + 1

                p_anns_norm = []
                p_anns_examples = []
                for ann in p_anns:
                    cui = ann._.cui
                    if use_groups:
                        cui = self.cdb.addl_info['cui2group'].get(cui, cui)

                    p_anns_norm.append((ann.start_char, cui))
                    p_anns_examples.append({
                        "text":
                        doc['text'][max(0, ann.start_char - 60):ann.end_char +
                                    60],
                        "cui":
                        cui,
                        "source value":
                        ann.text,
                        "acc":
                        float(ann._.context_similarity),
                        "project index":
                        pind,
                        "document inedex":
                        dind
                    })

                for iann, ann in enumerate(p_anns_norm):
                    cui = ann[1]
                    if ann in anns_norm:
                        tp += 1
                        tps[cui] = tps.get(cui, 0) + 1

                        example = p_anns_examples[iann]
                        examples['tp'][cui] = examples['tp'].get(
                            cui, []) + [example]
                    else:
                        fp += 1
                        fps[cui] = fps.get(cui, 0) + 1
                        fp_docs.add(doc.get('name', 'unk'))

                        # Add example for this FP prediction
                        example = p_anns_examples[iann]
                        if ann in anns_norm_neg:
                            # Means that it really was annotated as negative
                            example['real_fp'] = True

                        examples['fp'][cui] = examples['fp'].get(
                            cui, []) + [example]

                for iann, ann in enumerate(anns_norm):
                    if ann not in p_anns_norm:
                        cui = ann[1]
                        fn += 1
                        fn_docs.add(doc.get('name', 'unk'))

                        fns[cui] = fns.get(cui, 0) + 1
                        examples['fn'][cui] = examples['fn'].get(
                            cui, []) + [anns_examples[iann]]

        try:
            prec = tp / (tp + fp)
            rec = tp / (tp + fn)
            f1 = 2 * (prec * rec) / (prec + rec)
            print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(
                epoch, prec, rec, f1))
            print("Docs with false positives: {}\n".format("; ".join(
                [str(x) for x in list(fp_docs)[0:10]])))
            print("Docs with false negatives: {}\n".format("; ".join(
                [str(x) for x in list(fn_docs)[0:10]])))

            # Sort fps, fns & tps by frequency
            fps = {
                k: v
                for k, v in sorted(
                    fps.items(), key=lambda item: item[1], reverse=True)
            }
            fns = {
                k: v
                for k, v in sorted(
                    fns.items(), key=lambda item: item[1], reverse=True)
            }
            tps = {
                k: v
                for k, v in sorted(
                    tps.items(), key=lambda item: item[1], reverse=True)
            }

            # F1 per concept
            for cui in tps.keys():
                prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0))
                rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0))
                f1 = 2 * (prec * rec) / (prec + rec)
                cui_prec[cui] = prec
                cui_rec[cui] = rec
                cui_f1[cui] = f1

            # Get top 10
            pr_fps = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui])
                      for cui in list(fps.keys())[0:10]]
            pr_fns = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui])
                      for cui in list(fns.keys())[0:10]]
            pr_tps = [(self.cdb.cui2preferred_name.get(
                cui,
                list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui])
                      for cui in list(tps.keys())[0:10]]

            print("\n\nFalse Positives\n")
            for one in pr_fps:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("\n\nFalse Negatives\n")
            for one in pr_fns:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("\n\nTrue Positives\n")
            for one in pr_tps:
                print("{:70} - {:20} - {:10}".format(
                    str(one[0])[0:69],
                    str(one[1])[0:19], one[2]))
            print("*" * 110 + "\n")

        except Exception as e:
            traceback.print_exc()

        self.config.linking['filters'] = _filters

        return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples
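
    # Sketch (hedged): `data` is a MedCATtrainer export loaded with json.load();
    # the returned dictionaries are keyed by CUI:
    #     fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples = \
    #         cat._print_stats(data, use_filters=True)
    #     worst_f1 = sorted(cui_f1.items(), key=lambda kv: kv[1])[:10]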

    def train(self, data_iterator, fine_tune=True, progress_print=1000):
        """ Runs training on the data, note that the maximum lenght of a line
        or document is 1M characters. Anything longer will be trimmed.

        data_iterator:
            Simple iterator over sentences/documents, e.g. a open file
            or an array or anything that we can use in a for loop.
        fine_tune:
            If False old training will be removed
        progress_print:
            Print progress after N lines
        """
        if not fine_tune:
            self.log.info("Removing old training data!")
            self.cdb.reset_training()

        cnt = 0
        for line in data_iterator:
            if line is not None and line:
                # Convert to string
                line = str(line).strip()

                try:
                    _ = self(line, do_train=True)
                except Exception as e:
                    self.log.warning("LINE: '{}...' \t WAS SKIPPED".format(
                        line[0:100]))
                    self.log.warning("BECAUSE OF: " + str(e))
                if cnt % progress_print == 0:
                    self.log.info("DONE: " + str(cnt))
                cnt += 1

        self.config.linking['train'] = False
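
    # Unsupervised training sketch (hedged): any iterable of raw documents works,
    # e.g. an open text file with one document per line (path is hypothetical):
    #     with open("unsupervised_corpus.txt") as f:
    #         cat.train(f, fine_tune=True, progress_print=10000)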

    def add_cui_to_group(self, cui, group_name, reset_all_groups=False):
        r'''
        Adds a CUI to a group, will appear in cdb.addl_info['cui2group']

        Args:
            cui (str):
                The concept to be added
            group_name (str):
                The group to which the concept will be added
            reset_all_groups (boolean):
                If True it will reset all existing groups and remove them.

        Examples:
            >>> cat.add_cui_to_group("S-17", 'pain')
        '''

        # Reset if needed
        if reset_all_groups:
            self.cdb.addl_info['cui2group'] = {}

        # Add group_name
        self.cdb.addl_info['cui2group'][cui] = group_name

    def unlink_concept_name(self, cui, name, preprocessed_name=False):
        r'''
        Unlink a concept name from the CUI (or all CUIs if full_unlink), removes the link from
        the Concept Database (CDB). As a consequence medcat will never again link the `name`
        to this CUI - meaning the name will not be detected as a concept in the future.

        Args:
            cui (str):
                The CUI from which the `name` will be removed
            name (str):
                The span of text to be removed from the linking dictionary
            preprocessed_name (boolean, defaults to False):
                If True the `name` is assumed to be already preprocessed and will be used as-is.
        Examples:
            >>> # To never again link C0020538 to HTN
            >>> cat.unlink_concept_name('C0020538', 'htn', False)
        '''

        cuis = [cui]
        if preprocessed_name:
            names = {name: 'nothing'}
        else:
            names = prepare_name(name, self, {}, self.config)

        # If full unlink find all CUIs
        if self.config.general.get('full_unlink', False):
            for name in names:
                cuis.extend(self.cdb.name2cuis.get(name, []))

        # Remove name from all CUIs
        for cui in cuis:
            self.cdb.remove_names(cui=cui, names=names)

    def add_and_train_concept(self,
                              cui,
                              name,
                              spacy_doc=None,
                              spacy_entity=None,
                              ontologies=set(),
                              name_status='A',
                              type_ids=set(),
                              description='',
                              full_build=True,
                              negative=False,
                              devalue_others=False,
                              do_add_concept=True):
        r''' Add a name to an existing concept, or add a new concept, or do nothing if the name and concept already exist. Perform
        training if spacy_entity and spacy_doc are set.

        Args:
            cui (str):
                CUI of the concept
            name (str):
                Name to be linked to the concept (in the case of MedCATtrainer this is simply the
                selected value in text, no preprocessing or anything needed).
            spacy_doc (spacy.tokens.Doc):
                Spacy representation of the document that was manually annotated.
            spacy_entity (List[spacy.tokens.Token]):
                Given the spacy document, this is the annotated span of text - list of annotated tokens that are marked with this CUI.
            negative (bool):
                Is this a negative or positive example.
            devalue_others (bool):
                If True, CUIs that link to this name but are not `cui` will receive negative
                training, provided negative=False.

            **other:
                Refer to CDB.add_concept
        '''

        names = prepare_name(name, self, {}, self.config)
        if do_add_concept:
            self.cdb.add_concept(cui=cui,
                                 names=names,
                                 ontologies=ontologies,
                                 name_status=name_status,
                                 type_ids=type_ids,
                                 description=description,
                                 full_build=full_build)

        if spacy_entity is not None and spacy_doc is not None:
            # Train Linking
            self.linker.context_model.train(cui=cui,
                                            entity=spacy_entity,
                                            doc=spacy_doc,
                                            negative=negative,
                                            names=names)

            if not negative and devalue_others:
                # Find all cuis
                cuis = set()
                for name in names:
                    cuis.update(self.cdb.name2cuis.get(name, []))
                # Remove the cui for which we just added positive training
                cuis.remove(cui)
                # Add negative training for all other CUIs that link to these names
                for _cui in cuis:
                    self.linker.context_model.train(cui=_cui,
                                                    entity=spacy_entity,
                                                    doc=spacy_doc,
                                                    negative=True)
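
    # Sketch (hedged): register a new name for a concept and, when a manually
    # annotated span is available, run one positive training step on it. The CUI
    # is hypothetical; tkns_from_doc is the helper also used in train_supervised:
    #     spacy_doc = cat("Patient complains of severe headache")
    #     span = tkns_from_doc(spacy_doc=spacy_doc, start=28, end=36)
    #     cat.add_and_train_concept(cui='C0018681', name='headache',
    #                               spacy_doc=spacy_doc, spacy_entity=span)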

    def train_supervised(self,
                         data_path,
                         reset_cui_count=False,
                         nepochs=1,
                         print_stats=0,
                         use_filters=False,
                         terminate_last=False,
                         use_overlaps=False,
                         use_cui_doc_limit=False,
                         test_size=0,
                         devalue_others=False,
                         use_groups=False,
                         never_terminate=False,
                         train_from_false_positives=False):
        r''' TODO: Refactor, left from old
        Run supervised training on a dataset from MedCATtrainer. Note that this is closer to
        simulated online training than true supervised training.

        Args:
            data_path (str):
                The path to the json file that we get from MedCATtrainer on export.
            reset_cui_count (boolean):
                Used for training with weight_decay (annealing). Each concept has a count that is there
                from the beginning of the CDB, that count is used for annealing. Resetting the count will
                significantly increase the training impact. This will reset the count only for concepts
                that exist in the training data.
            nepochs (int):
                Number of epochs for which to run the training.
            print_stats (int):
                If > 0 it will print stats every print_stats epochs.
            use_filters (boolean):
                Each project in MedCATtrainer can have filters; should those filters be respected
                when calculating metrics.
            terminate_last (boolean):
                If true, concept termination will be done after all training.
            use_overlaps (boolean):
                Allow overlapping entities, nearly always False as it is very difficult to annotate overlapping entities.
            use_cui_doc_limit (boolean):
                If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words
                if the document was annotated for that CUI. Useful in very specific situations when during the annotation
                process the set of CUIs changed.
            test_size (float):
                If > 0 the data set will be split into train/test based on this ratio. Should be between 0 and 1.
                Usually 0.1 is fine.
            devalue_others(bool):
                Check add_name for more details.
            use_groups (boolean):
                If True concepts that have groups will be combined and stats will be reported on groups.
            never_terminate (boolean):
                If True no termination will be applied
            train_from_false_positives (boolean):
                If True it will use false positive examples detected by medcat and train from them as negative examples.

        Returns:
            fp (dict):
                False positives for each CUI
            fn (dict):
                False negatives for each CUI
            tp (dict):
                True positives for each CUI
            p (dict):
                Precision for each CUI
            r (dict):
                Recall for each CUI
            f1 (dict):
                F1 for each CUI
            cui_counts (dict):
                Number of occurrences for each CUI
            examples (dict):
                FP/FN examples of sentences for each CUI
        '''
        fp = fn = tp = p = r = f1 = cui_counts = examples = {}
        with open(data_path) as f:
            data = json.load(f)

        if test_size == 0:
            self.log.info("Running without a test set, or train=test")
            test_set = data
            train_set = data
        else:
            train_set, test_set, _, _ = make_mc_train_test(data,
                                                           self.cdb,
                                                           test_size=test_size)

        if print_stats > 0:
            self._print_stats(test_set,
                              use_filters=use_filters,
                              use_cui_doc_limit=use_cui_doc_limit,
                              use_overlaps=use_overlaps,
                              use_groups=use_groups)

        if reset_cui_count:
            # Get all CUIs
            cuis = []
            for project in train_set['projects']:
                for doc in project['documents']:
                    if type(doc['annotations']) == list:
                        doc_annotations = doc['annotations']
                    elif type(doc['annotations']) == dict:
                        doc_annotations = doc['annotations'].values()

                    for ann in doc_annotations:
                        cuis.append(ann['cui'])
            for cui in set(cuis):
                if cui in self.cdb.cui2count_train:
                    self.cdb.cui2count_train[cui] = 10

        # Remove entities that were terminated
        if not never_terminate:
            for project in train_set['projects']:
                for doc in project['documents']:
                    if type(doc['annotations']) == list:
                        doc_annotations = doc['annotations']
                    elif type(doc['annotations']) == dict:
                        doc_annotations = doc['annotations'].values()

                    for ann in doc_annotations:
                        if ann.get('killed', False):
                            self.unlink_concept_name(ann['cui'], ann['value'])

        for epoch in tqdm(range(nepochs), desc='Epoch', leave=False):
            # Print acc before training
            for project in tqdm(train_set['projects'],
                                desc='Project',
                                leave=False,
                                total=len(train_set['projects'])):
                for i_doc, doc in tqdm(enumerate(project['documents']),
                                       desc='Document',
                                       leave=False,
                                       total=len(project['documents'])):
                    spacy_doc = self(doc['text'])
                    # Compatibility with old output where annotations are a list
                    if type(doc['annotations']) == list:
                        doc_annotations = doc['annotations']
                    elif type(doc['annotations']) == dict:
                        doc_annotations = doc['annotations'].values()

                    for ann in doc_annotations:
                        if not ann.get('killed', False):
                            cui = ann['cui']
                            start = ann['start']
                            end = ann['end']
                            spacy_entity = tkns_from_doc(spacy_doc=spacy_doc,
                                                         start=start,
                                                         end=end)
                            deleted = ann.get('deleted', False)
                            self.add_and_train_concept(
                                cui=cui,
                                name=ann['value'],
                                spacy_doc=spacy_doc,
                                spacy_entity=spacy_entity,
                                negative=deleted,
                                devalue_others=devalue_others)
                    if train_from_false_positives:
                        fps = get_false_positives(doc, spacy_doc)

                        for fp in fps:
                            self.add_and_train_concept(cui=fp._.cui,
                                                       name=fp.text,
                                                       spacy_doc=spacy_doc,
                                                       spacy_entity=fp,
                                                       negative=True,
                                                       do_add_concept=False)

            if terminate_last and not never_terminate:
                # Remove entities that were terminated, but after all training is done
                for project in train_set['projects']:
                    for doc in project['documents']:
                        if type(doc['annotations']) == list:
                            doc_annotations = doc['annotations']
                        elif type(doc['annotations']) == dict:
                            doc_annotations = doc['annotations'].values()

                        for ann in doc_annotations:
                            if ann.get('killed', False):
                                self.unlink_concept_name(
                                    ann['cui'], ann['value'])

            if print_stats > 0 and (epoch + 1) % print_stats == 0:
                fp, fn, tp, p, r, f1, cui_counts, examples = self._print_stats(
                    test_set,
                    epoch=epoch + 1,
                    use_filters=use_filters,
                    use_cui_doc_limit=use_cui_doc_limit,
                    use_overlaps=use_overlaps,
                    use_groups=use_groups)
        return fp, fn, tp, p, r, f1, cui_counts, examples
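
    # Supervised training sketch (hedged): 'export.json' stands in for a
    # MedCATtrainer project export (hypothetical path):
    #     fp, fn, tp, p, r, f1, cui_counts, examples = cat.train_supervised(
    #         data_path="./export.json", nepochs=3, print_stats=1,
    #         use_filters=True, test_size=0.1)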

    def get_entities(self,
                     text,
                     only_cui=False,
                     addl_info=['cui2icd10', 'cui2ontologies']):
        r''' Get entities

        text:  text to be annotated
        return:  entities
        '''
        cnf_annotation_output = getattr(self.config, 'annotation_output', {})
        doc = self(text)
        out = {'entities': {}, 'tokens': []}
        if doc is not None:
            out_ent = {}
            if self.config.general.get('show_nested_entities', False):
                _ents = doc._.ents
            else:
                _ents = doc.ents

            if cnf_annotation_output.get("lowercase_context", True):
                doc_tokens = [tkn.text_with_ws.lower() for tkn in list(doc)]
            else:
                doc_tokens = [tkn.text_with_ws for tkn in list(doc)]

            if cnf_annotation_output.get('doc_extended_info', False):
                # Add tokens if extended info
                out['tokens'] = doc_tokens

            context_left = cnf_annotation_output.get('context_left', -1)
            context_right = cnf_annotation_output.get('context_right', -1)
            doc_extended_info = cnf_annotation_output.get(
                'doc_extended_info', False)

            for ind, ent in enumerate(_ents):
                cui = str(ent._.cui)
                if not only_cui:
                    out_ent['pretty_name'] = self.cdb.cui2preferred_name.get(
                        cui, '')
                    out_ent['cui'] = cui
                    out_ent['tuis'] = list(self.cdb.cui2type_ids.get(cui, ''))
                    out_ent['types'] = [
                        self.cdb.addl_info['type_id2name'].get(tui, '')
                        for tui in out_ent['tuis']
                    ]
                    out_ent['source_value'] = ent.text
                    out_ent['detected_name'] = str(ent._.detected_name)
                    out_ent['acc'] = float(ent._.context_similarity)
                    out_ent['context_similarity'] = float(
                        ent._.context_similarity)
                    out_ent['start'] = ent.start_char
                    out_ent['end'] = ent.end_char
                    for addl in addl_info:
                        tmp = self.cdb.addl_info[addl].get(cui, [])
                        out_ent[addl.split("2")[-1]] = list(tmp) if type(
                            tmp) == set else tmp
                    out_ent['id'] = ent._.id
                    out_ent['meta_anns'] = {}

                    if doc_extended_info:
                        out_ent['start_tkn'] = ent.start
                        out_ent['end_tkn'] = ent.end

                    if context_left > 0 and context_right > 0:
                        out_ent['context_left'] = doc_tokens[
                            max(ent.start - context_left, 0):ent.start]
                        out_ent['context_right'] = doc_tokens[
                            ent.end:min(ent.end +
                                        context_right, len(doc_tokens))]
                        out_ent['context_center'] = doc_tokens[ent.start:ent.
                                                               end]

                    if hasattr(ent._, 'meta_anns') and ent._.meta_anns:
                        out_ent['meta_anns'] = ent._.meta_anns

                    out['entities'][out_ent['id']] = dict(out_ent)
                else:
                    out['entities'][ent._.id] = cui

        return out
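
    # Sketch (hedged) of the returned structure:
    #     out = cat.get_entities("She is on ramipril for hypertension")
    #     # out == {'entities': {<id>: {'pretty_name': ..., 'cui': ..., 'tuis': ...,
    #     #                             'start': ..., 'end': ..., 'meta_anns': ...}},
    #     #         'tokens': []}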

    def get_json(self,
                 text,
                 only_cui=False,
                 addl_info=['cui2icd10', 'cui2ontologies']):
        """ Get output in json format

        text:  text to be annotated
        return:  json with fields {'entities': <>, 'text': text}
        """
        ents = self.get_entities(text, only_cui,
                                 addl_info=addl_info)['entities']
        out = {'annotations': ents, 'text': text}

        return json.dumps(out)

    def multiprocessing(self,
                        in_data,
                        nproc=8,
                        batch_size=100,
                        only_cui=False,
                        addl_info=[]):
        r''' Run multiprocessing, NOT FOR TRAINING

        in_data:  an iterator or array with format: [(id, text), (id, text), ...]
        nproc:  number of processors
        batch_size:  number of documents sent to each worker process at a time

        return:  a list of tuples: [(id, doc_json), (id, doc_json), ...]
        '''
        if self._meta_annotations:
            # Hack for torch using multithreading, which is not good here
            import torch
            torch.set_num_threads(1)

        # Create the input output for MP
        in_q = Queue(maxsize=4 * nproc)
        manager = Manager()
        out_dict = manager.dict()
        out_dict['processed'] = []

        # Create processes
        procs = []
        for i in range(nproc):
            p = Process(target=self._mp_cons,
                        kwargs={
                            'in_q': in_q,
                            'out_dict': out_dict,
                            'pid': i,
                            'only_cui': only_cui,
                            'addl_info': addl_info
                        })
            p.start()
            procs.append(p)

        data = []
        for id, text in in_data:
            data.append((id, str(text)))
            if len(data) == batch_size:
                in_q.put(data)
                data = []
        # Put the last batch if it exists
        if len(data) > 0:
            in_q.put(data)

        for _ in range(nproc):  # tell workers we're done
            in_q.put(None)

        for p in procs:
            p.join()

        # Close the queue as it can cause memory leaks
        in_q.close()

        out = []
        for key in out_dict.keys():
            if 'pid' in key:
                data = out_dict[key]
                out.extend(data)

        # Sometimes necessary to free memory
        out_dict.clear()
        del out_dict

        return out
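
    # Multiprocessing sketch (hedged): `in_data` is any iterable of (id, text)
    # pairs, e.g. rows streamed from a database or CSV:
    #     docs = [(1, "Patient has diabetes"), (2, "No acute distress")]
    #     results = cat.multiprocessing(docs, nproc=2, batch_size=1)
    #     # results == [(id, {'entities': ..., 'tokens': [], 'text': ...}), ...]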

    def _mp_cons(self, in_q, out_dict, pid=0, only_cui=False, addl_info=[]):
        cnt = 0
        out = []
        while True:
            if not in_q.empty():
                data = in_q.get()
                if data is None:
                    out_dict['pid: {}'.format(pid)] = out
                    break

                for id, text in data:
                    try:
                        doc = self.get_entities(text=text,
                                                only_cui=only_cui,
                                                addl_info=addl_info)
                        doc['text'] = text
                        out.append((id, doc))
                    except Exception as e:
                        self.log.warning("Exception in _mp_cons")
                        self.log.warning(e, stack_info=True)

            sleep(1)
Example 2
class MakeVocab(object):
    r'''
    Create a new vocab from a text file.

    Args:
        config (medcat.config.Config):
            Global configuration for medcat.
        cdb (medcat.cdb.CDB):
            The concept database that will be added on top of the Vocab built from the text file.
        vocab (medcat.utils.vocab.Vocab, optional):
            Vocabulary to be extended, leave as None if you want to make a new Vocab. Default: None
        word_tokenizer (<function>):
            A custom tokenizer for word splitting - used if embeddings are BERT or similar.
            Default: None

    Examples:
        To make a vocab and train word embeddings:

        >>>cdb = <your existing cdb>
        >>>maker = MakeVocab(cdb=cdb, config=config)
        >>>maker.make(data_iterator, out_folder="./output/")
        >>>maker.add_vectors(in_path="./output/data.txt")
    '''
    log = logging.getLogger(__name__)

    def __init__(self, config, cdb=None, vocab=None, word_tokenizer=None):
        self.cdb = cdb
        self.config = config
        self.w2v = None
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = Vocab()

        # Build the required spacy pipeline
        self.nlp = Pipe(tokenizer=spacy_split_all, config=config)
        self.nlp.add_tagger(tagger=tag_skip_and_punct,
                            name='skip_and_punct',
                            additional_fields=['is_punct'])

        # Get the tokenizer
        if word_tokenizer is not None:
            self.tokenizer = word_tokenizer
        else:
            self.tokenizer = self._tok

        # Used for saving if the real path is not set
        self.vocab_path = "./tmp_vocab.dat"

    def _tok(self, text):
        return [text]

    def make(self,
             iter_data,
             out_folder,
             join_cdb=True,
             normalize_tokens=False):
        r'''
        Make a vocab - without vectors initially. This will create two files in the out_folder:
        - vocab.dat -> The vocabulary without vectors
        - data.txt -> The tokenized dataset prepared for training of word2vec or similar embeddings.

        Args:
            iter_data (Iterator):
                An iterator over sentences or documents. Can also be a simple array of text documents/sentences.
            out_folder (string):
                A path to a folder where all the results will be saved
            join_cdb (bool):
                Should the words from the CDB be added to the Vocab. Default: True
            normalize_tokens (bool, defaults to False):
                If set, tokens will be lemmatized - tends to work better in some cases where the difference
                between e.g. plural/singular should be ignored. But in general not so important if the dataset is big enough.
        '''
        # Save the preprocessed data, used for emb training
        out_path = Path(out_folder) / "data.txt"
        vocab_path = Path(out_folder) / "vocab.dat"
        self.vocab_path = vocab_path
        out = open(out_path, 'w', encoding='utf-8')

        for ind, doc in enumerate(iter_data):
            if ind % 10000 == 0:
                self.log.info("Vocab builder at: " + str(ind))
                print(ind)

            doc = self.nlp.nlp.tokenizer(doc)
            line = ""

            for token in doc:
                if token.is_space or token.is_punct:
                    continue

                if len(token.lower_) > 0:
                    if normalize_tokens:
                        self.vocab.inc_or_add(token._.norm)
                    else:
                        self.vocab.inc_or_add(token.lower_)

                if normalize_tokens:
                    line = line + " " + "_".join(token._.norm.split(" "))
                else:
                    line = line + " " + "_".join(token.lower_.split(" "))

            out.write(line.strip())
            out.write("\n")
        out.close()

        if join_cdb and self.cdb:
            for word in self.cdb.vocab.keys():
                if word not in self.vocab:
                    self.vocab.add_word(word)
                else:
                    # Update the count with the counts from the new dataset
                    self.cdb.vocab[word] += self.vocab[word]

        # Save the vocab also
        self.vocab.save(path=self.vocab_path)

    def add_vectors(self,
                    in_path=None,
                    w2v=None,
                    overwrite=False,
                    data_iter=None,
                    workers=14,
                    niter=2,
                    min_count=10,
                    window=10,
                    vsize=300,
                    unigram_table_size=100000000):
        r'''
        Add vectors to an existing vocabulary and save changes to the vocab_path.

        Args:
            in_path (String):
                Path to the data.txt that was created by the MakeVocab.make() function.
            w2v (Word2Vec, optional):
                An existing word2vec instance. Default: None
            overwrite (bool):
                If True it will overwrite existing vectors in the vocabulary. Default: False
            data_iter (iterator):
                If you want to provide a custom iterator over the data use this. If set, in_path is not needed.
            **: Word2Vec arguments

        Returns:
            A trained word2vec model.
        '''
        if w2v is None:
            if data_iter is None:
                data = SimpleIter(in_path)
            else:
                data = data_iter
            w2v = Word2Vec(data,
                           window=window,
                           min_count=min_count,
                           workers=workers,
                           size=vsize,
                           iter=niter)

        for word in w2v.wv.vocab.keys():
            if word in self.vocab:
                if overwrite:
                    self.vocab.add_vec(word, w2v.wv.get_vector(word))
                else:
                    if self.vocab.vec(word) is None:
                        self.vocab.add_vec(word, w2v.wv.get_vector(word))

        # Save the vocab again, now with vectors
        self.vocab.make_unigram_table(table_size=unigram_table_size)
        self.vocab.save(path=self.vocab_path)
        return w2v

    def destroy_pipe(self):
        self.nlp.destroy()
Example 3
class CDBMaker(object):
    r''' Given a CSV as shown in https://github.com/CogStack/MedCAT/tree/master/examples/<example> it creates a CDB or
    updates an existing one.

    Args:
        config (`medcat.config.Config`):
            Global config for MedCAT.
        cdb (`medcat.cdb.CDB`, optional):
            If set the `CDBMaker` will update the existing `CDB` with new concepts in the CSV.
        name_max_words (`int`, defaults to `20`):
            Names with more words will be skipped during the build of a CDB
    '''
    log = logging.getLogger(__package__)
    log = add_handlers(log)

    def __init__(self, config, cdb=None, name_max_words=20):
        self.config = config
        # Set log level
        self.log.setLevel(self.config.general['log_level'])

        # To make life a bit easier
        self.cnf_cm = config.cdb_maker

        if cdb is None:
            self.cdb = CDB(config=self.config)
        else:
            self.cdb = cdb

        # Build the required spacy pipeline
        self.nlp = Pipe(tokenizer=spacy_split_all, config=config)
        self.nlp.add_tagger(tagger=tag_skip_and_punct,
                            name='skip_and_punct',
                            additional_fields=['is_punct'])


    def prepare_csvs(self, csv_paths, sep=',', encoding=None, escapechar=None, index_col=False, full_build=False, only_existing_cuis=False, **kwargs):
        r''' Compile one or multiple CSVs into a CDB.

        Args:
            csv_paths (`List[str]`):
                An array of paths to the csv files that should be processed
            full_build (`bool`, defaults to `False`):
                If False only the core portions of the CDB will be built (the ones required for
                the functioning of MedCAT). If True, everything will be added to the CDB - this
                usually includes concept descriptions, various forms of names etc (take care that
                this option produces a much larger CDB).
            sep (`str`, defaults to `,`):
                If necessary a custom separator for the csv files
            encoding (`str`, optional):
                Encoding to be used for reading the CSV file
            escapechar (`str`, optional):
                Escape char for the CSV
            index_col (`bool`, defaults to `False`):
                Index column for pandas read_csv
            only_existing_cuis (`bool`, defaults to False):
                If True no new CUIs will be added, but only linked names will be extended. Mainly used when
                enriching names of a CDB (e.g. SNOMED with UMLS terms).
        Returns:
            `medcat.cdb.CDB` with the new concepts added.

        Note:
            **kwargs:
                Will be passed to pandas for CSV reading
            csv:
                Examples of the CSV used to make the CDB can be found on [GitHub](link)
        '''

        useful_columns = ['cui', 'name', 'ontologies', 'name_status', 'type_ids', 'description']
        name_status_options = {'A', 'P', 'N'}

        for csv_path in csv_paths:
            # Read CSV, everything is converted to strings
            df = pandas.read_csv(csv_path, sep=sep, encoding=encoding, escapechar=escapechar, index_col=index_col, dtype=str, **kwargs)
            df = df.fillna('')

            # Find which columns to use from the CSV
            cols = []
            col2ind = {}
            for col in list(df.columns):
                if str(col).lower().strip() in useful_columns:
                    col2ind[str(col).lower().strip()] = len(cols)
                    cols.append(col)

            self.log.info("Started importing concepts from: {}".format(csv_path))
            _time = None # Used to check speed
            _logging_freq = np.ceil(len(df[cols]) / 100)
            for row_id, row in enumerate(df[cols].values):
                if row_id % _logging_freq == 0:
                    # Print some stats
                    if _time is None:
                        # Add last time if it does not exist
                        _time = datetime.datetime.now()
                    # Get current time
                    ctime = datetime.datetime.now()
                    # Get time difference
                    timediff = ctime - _time
                    self.log.info("Current progress: {:.0f}% at {:.3f}s per {} rows".format(
                        (row_id / len(df)) * 100, timediff.microseconds/10**6 + timediff.seconds, (len(df[cols]) // 100)))
                    # Set previous time to current time
                    _time = ctime

                # This must exist
                cui = row[col2ind['cui']].strip().upper()

                if not only_existing_cuis or (only_existing_cuis and cui in self.cdb.cui2names):
                    if 'ontologies' in col2ind:
                        ontologies = set([ontology.strip() for ontology in row[col2ind['ontologies']].upper().split(self.cnf_cm['multi_separator']) if
                                         len(ontology.strip()) > 0])
                    else:
                        ontologies = set()

                    if 'name_status' in col2ind:
                        name_status = row[col2ind['name_status']].strip().upper()

                        # Must be allowed
                        if name_status not in name_status_options:
                            name_status = 'A'
                    else:
                        # Defaults to A - meaning automatic
                        name_status = 'A'

                    if 'type_ids' in col2ind:
                        type_ids = set([type_id.strip() for type_id in row[col2ind['type_ids']].upper().split(self.cnf_cm['multi_separator']) if
                                        len(type_id.strip()) > 0])
                    else:
                        type_ids = set()

                    # Get the ones that do not need any changing
                    if 'description' in col2ind:
                        description = row[col2ind['description']].strip()
                    else:
                        description = ""

                    # We can have multiple versions of a name
                    names = {} # {'name': {'tokens': [<str>], 'snames': [<str>]}}

                    raw_names = [raw_name.strip() for raw_name in row[col2ind['name']].split(self.cnf_cm['multi_separator']) if 
                                 len(raw_name.strip()) > 0]
                    for raw_name in raw_names:
                        raw_name = raw_name.strip()
                        prepare_name(raw_name, self.nlp, names, self.config)

                        if self.config.cdb_maker.get('remove_parenthesis', 0) > 0 and name_status == 'P':
                            # Should we remove the content in parenthesis from primary names and add them also
                            raw_name = PH_REMOVE.sub(" ", raw_name).strip()
                            if len(raw_name) >= self.config.cdb_maker['remove_parenthesis']:
                                prepare_name(raw_name, self.nlp, names, self.config)

                    self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids,
                                         description=description, full_build=full_build)
                    # DEBUG
                    self.log.debug("\n\n**** Added\n CUI: {}\n Names: {}\n Ontologies: {}\n Name status: {}\n".format(cui, names, ontologies, name_status) + \
                                   " Type IDs: {}\n Description: {}\n Is full build: {}".format(
                                   type_ids, description, full_build))

        return self.cdb

    def destroy_pipe(self):
        self.nlp.destroy()
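
# CDBMaker usage sketch (hedged): the CSV path is hypothetical and should contain
# at least 'cui' and 'name' columns as described in prepare_csvs:
#     config = Config()
#     maker = CDBMaker(config)
#     cdb = maker.prepare_csvs(csv_paths=['./concepts.csv'])
#     maker.destroy_pipe()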
Example 4
class NerArchiveTests(unittest.TestCase):
    def setUp(self) -> None:
        self.config = Config()
        self.config.general['log_level'] = logging.INFO
        cdb = CDB(config=self.config)

        self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
        self.nlp.add_tagger(tagger=tag_skip_and_punct,
                            name='skip_and_punct',
                            additional_fields=['is_punct'])

        # Add a couple of names
        cdb.add_names(cui='S-229004',
                      names=prepare_name('Movar', self.nlp, {}, self.config))
        cdb.add_names(cui='S-229004',
                      names=prepare_name('Movar viruses', self.nlp, {},
                                         self.config))
        cdb.add_names(cui='S-229005',
                      names=prepare_name('CDB', self.nlp, {}, self.config))
        # Check
        #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}}

        self.vocab_path = "./tmp_vocab.dat"
        if not os.path.exists(self.vocab_path):
            import requests
            tmp = requests.get(
                "https://medcat.rosalind.kcl.ac.uk/media/vocab.dat")
            with open(self.vocab_path, 'wb') as f:
                f.write(tmp.content)

        vocab = Vocab.load(self.vocab_path)
        # Make the pipeline
        self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config)
        self.nlp.add_tagger(tagger=tag_skip_and_punct,
                            name='skip_and_punct',
                            additional_fields=['is_punct'])
        spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab,
                                          config=self.config,
                                          data_vocab=vocab)
        self.nlp.add_token_normalizer(spell_checker=spell_checker,
                                      config=self.config)
        ner = NER(cdb, self.config)
        self.nlp.add_ner(ner)

        # Add Linker
        link = Linker(cdb, vocab, self.config)
        self.nlp.add_linker(link)

        self.text = "CDB - I was running and then Movar    Virus attacked and CDb"

    def tearDown(self) -> None:
        self.nlp.destroy()

    def test_limits_for_tokens_and_uppercase(self):
        self.config.ner['max_skip_tokens'] = 1
        self.config.ner['upper_case_limit_len'] = 4
        self.config.linking['disamb_length_limit'] = 2

        d = self.nlp(self.text)

        assert len(d._.ents) == 2
        assert d._.ents[0]._.link_candidates[0] == 'S-229004'

    def test_change_limit_for_skip(self):
        self.config.ner['max_skip_tokens'] = 3
        d = self.nlp(self.text)
        assert len(d._.ents) == 3

    def test_change_limit_for_upper_case(self):
        self.config.ner['upper_case_limit_len'] = 3
        d = self.nlp(self.text)
        assert len(d._.ents) == 4

    def test_check_name_length_limit(self):
        self.config.ner['min_name_len'] = 4
        d = self.nlp(self.text)
        assert len(d._.ents) == 2

    def test_speed(self):
        text = "CDB - I was running and then Movar    Virus attacked and CDb"
        text = text * 300
        self.config.general['spell_check'] = True
        start = timer()
        for i in range(50):
            d = self.nlp(text)
        end = timer()
        print("Time: ", end - start)

    def test_without_spell_check(self):
        # Now without spell check
        self.config.general['spell_check'] = False
        start = timer()
        for i in range(50):
            d = self.nlp(self.text)
        end = timer()
        print("Time: ", end - start)

    def test_for_linker(self):
        self.config = Config()
        self.config.general['log_level'] = logging.DEBUG
        cdb = CDB(config=self.config)

        # Add a couple of names
        cdb.add_names(cui='S-229004',
                      names=prepare_name('Movar', self.nlp, {}, self.config))
        cdb.add_names(cui='S-229004',
                      names=prepare_name('Movar viruses', self.nlp, {},
                                         self.config))
        cdb.add_names(cui='S-229005',
                      names=prepare_name('CDB', self.nlp, {}, self.config))
        cdb.add_names(cui='S-2290045',
                      names=prepare_name('Movar', self.nlp, {}, self.config))
        # Check
        #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}}

        cuis = list(cdb.cui2names.keys())
        for cui in cuis[0:50]:
            vectors = {
                'short': np.random.rand(300),
                'long': np.random.rand(300),
                'medium': np.random.rand(300)
            }
            cdb.update_context_vector(cui, vectors, negative=False)

        d = self.nlp(self.text)
        vocab = Vocab.load(self.vocab_path)
        cm = ContextModel(cdb, vocab, self.config)
        cm.train_using_negative_sampling('S-229004')
        self.config.linking['train_count_threshold'] = 0

        cm.train('S-229004', d._.ents[1], d)

        cm.similarity('S-229004', d._.ents[1], d)

        cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d)
Example 5
from medcat.linking.vector_context_model import ContextModel
from functools import partial
from medcat.linking.context_based_linker import Linker
from medcat.config import Config
import logging
from medcat.cdb import CDB
import os
import requests
# The following imports are assumed from the MedCAT package layout; they provide
# Pipe, the tokenizer/tagger helpers and name preparation used below.
from medcat.pipe import Pipe
from medcat.preprocessing.tokenizers import spacy_split_all
from medcat.preprocessing.taggers import tag_skip_and_punct
from medcat.preprocessing.cleaners import prepare_name

config = Config()
config.general['log_level'] = logging.INFO
cdb = CDB(config=config)

nlp = Pipe(tokenizer=spacy_split_all, config=config)
nlp.add_tagger(tagger=partial(tag_skip_and_punct, config=config),
               name='skip_and_punct',
               additional_fields=['is_punct'])

# Add a couple of names
cdb.add_names(cui='S-229004', names=prepare_name('Movar', nlp, {}, config))
cdb.add_names(cui='S-229004',
              names=prepare_name('Movar viruses', nlp, {}, config))
cdb.add_names(cui='S-229005', names=prepare_name('CDB', nlp, {}, config))
# Check
#assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}}

vocab_path = "./tmp_vocab.dat"
if not os.path.exists(vocab_path):
    tmp = requests.get("https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat")
    with open(vocab_path, 'wb') as f:
        f.write(tmp.content)