Example #1
File: utils.py Project: lrog/MedCATtrainer
from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.utils.vocab import Vocab


def get_medcat(CDB_MAP, VOCAB_MAP, CAT_MAP, project):
    """Return a CAT for the given project, caching loaded CDBs, vocabs and CATs."""
    cdb_id = project.concept_db.id
    vocab_id = project.vocab.id
    # A CAT is identified by its CDB/vocab pair
    cat_id = str(cdb_id) + "-" + str(vocab_id)

    # Reuse a cached CAT if this CDB/vocab pair was already loaded
    if cat_id in CAT_MAP:
        cat = CAT_MAP[cat_id]
    else:
        if cdb_id in CDB_MAP:
            cdb = CDB_MAP[cdb_id]
        else:
            cdb_path = project.concept_db.cdb_file.path
            cdb = CDB()
            cdb.load_dict(cdb_path)
            CDB_MAP[cdb_id] = cdb

        if vocab_id in VOCAB_MAP:
            vocab = VOCAB_MAP[vocab_id]
        else:
            vocab_path = project.vocab.vocab_file.path
            vocab = Vocab()
            vocab.load_dict(vocab_path)
            VOCAB_MAP[vocab_id] = vocab

        cat = CAT(cdb=cdb, vocab=vocab)
        cat.train = False  # annotation only; disable online training
        CAT_MAP[cat_id] = cat
    return cat
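
A minimal sketch of how this helper might be called. The caches are plain module-level dicts; `project` is hypothetical here and stands in for a MedCATtrainer Django model instance exposing the attributes used above:

# Module-level caches shared across requests
CDB_MAP = {}
VOCAB_MAP = {}
CAT_MAP = {}

# `project` is hypothetical: any object with concept_db/vocab attributes
# carrying `.id` and `.cdb_file.path` / `.vocab_file.path`
cat = get_medcat(CDB_MAP, VOCAB_MAP, CAT_MAP, project)
doc = cat("Patient presents with kidney failure")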
Example #2
def run_cv(cdb_path, data_path, vocab_path, cv=100, nepochs=16, test_size=0.1, lr=1, groups=None, **kwargs):
    from medcat.cat import CAT
    from medcat.utils.vocab import Vocab
    from medcat.cdb import CDB

    use_groups = groups is not None

    f1s = {}
    ps = {}
    rs = {}
    tps = {}
    fns = {}
    fps = {}
    cui_counts = {}
    examples = {}
    for i in range(cv):
        # Reload the CDB and vocab so every run starts from the same state
        cdb = CDB()
        cdb.load_dict(cdb_path)
        vocab = Vocab()
        vocab.load_dict(path=vocab_path)
        cat = CAT(cdb, vocab=vocab)
        cat.train = False  # train_supervised drives all updates
        cat.spacy_cat.MIN_ACC = 0.30
        cat.spacy_cat.MIN_ACC_TH = 0.30

        # Add groups if provided: drop any stale group info, then assign CUIs
        if groups is not None:
            for cui in cdb.cui2info.keys():
                if "group" in cdb.cui2info[cui]:
                    del cdb.cui2info[cui]['group']
            for k, v in groups.items():
                for val in v:
                    cat.add_cui_to_group(val, k)

        fp, fn, tp, p, r, f1, cui_counts, examples = cat.train_supervised(
            data_path=data_path, lr=lr, test_size=test_size,
            use_groups=use_groups, nepochs=nepochs, **kwargs)

        # Accumulate per-CUI metrics across runs
        for key in f1.keys():
            f1s.setdefault(key, []).append(f1[key])
            ps.setdefault(key, []).append(p[key])
            rs.setdefault(key, []).append(r[key])
            tps.setdefault(key, []).append(tp.get(key, 0))
            fps.setdefault(key, []).append(fp.get(key, 0))
            fns.setdefault(key, []).append(fn.get(key, 0))

    # Note: cui_counts and examples reflect only the last run
    return fps, fns, tps, ps, rs, f1s, cui_counts, examples
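
A small usage sketch (file paths are hypothetical) that averages the per-CUI F1 scores collected above:

from statistics import mean

# Hypothetical paths; substitute your own CDB, vocab and annotation export
fps, fns, tps, ps, rs, f1s, cui_counts, examples = run_cv(
    cdb_path="cdb.dat", data_path="projects_export.json",
    vocab_path="vocab.dat", cv=5)

# Mean F1 per CUI across the CV runs, best first
for cui, scores in sorted(f1s.items(), key=lambda kv: -mean(kv[1])):
    print(cui, round(mean(scores), 3))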
Example #3
cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)

# NER settings
cat.config.ner['min_name_len'] = 3
cat.config.ner['upper_case_limit_len'] = 3

# Linking settings
cat.config.linking['disamb_length_limit'] = 3
cat.config.linking['filters'] = {'cuis': set()}
cat.config.linking['train_count_threshold'] = -1
cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 'short': 3}
cat.config.linking['context_vector_weights'] = {'xlong': 0, 'long': 0.4, 'medium': 0.4, 'short': 0.2}
cat.config.linking['weighted_average_function'] = lambda step: max(0.1, 1 - (step ** 2 * 0.0004))

# Use a dynamic per-concept similarity threshold
cat.config.linking['similarity_threshold_type'] = 'dynamic'
cat.config.linking['similarity_threshold'] = 0.35
cat.config.linking['calculate_dynamic_threshold'] = True

# `df` is a pandas DataFrame of raw text documents; unsupervised training
cat.train(df.text.values, fine_tune=True)


cdb.config.general['spacy_disabled_components'] = ['ner', 'parser', 'vectors', 'textcat',
                                                   'entity_linker', 'sentencizer',
                                                   'entity_ruler', 'merge_noun_chunks',
                                                   'merge_entities', 'merge_subtokens']

# IPython magics (notebook only): auto-reload edited modules
%load_ext autoreload
%autoreload 2

# Unsupervised training from a plain-text file (one document per line)
_ = cat.train(open("./tmp_medmentions_text_only.txt", 'r'), fine_tune=False)

_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json",
                         reset_cui_count=True,
                         nepochs=13,
                         train_from_false_positives=True,
                         print_stats=3,
                         test_size=0.1)
cdb.save("/home/ubuntu/data/umls/2020ab/cdb_trained_medmen.dat")
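
Once trained, single documents can be annotated directly; a short sketch using the MedCAT v1 get_entities call (the text and the printed fields are illustrative; exact keys can vary by version):

# Annotate one document and inspect the linked concepts
text = "Patient has a history of chronic kidney disease."
entities = cat.get_entities(text)
for ent in entities['entities'].values():
    print(ent['pretty_name'], ent['cui'], ent['start'], ent['end'])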
Example #4
def run_cv(cdb_path,
           data_path,
           vocab_path,
           cv=100,
           nepochs=16,
           reset_cui_count=True,
           test_size=0.1):
    from medcat.cat import CAT
    from medcat.utils.vocab import Vocab
    from medcat.cdb import CDB

    f1s = {}
    ps = {}
    rs = {}
    tps = {}
    fns = {}
    fps = {}
    cui_counts = {}
    for i in range(cv):
        # Reload the CDB and vocab so every run starts from the same state
        cdb = CDB()
        cdb.load_dict(cdb_path)
        vocab = Vocab()
        vocab.load_dict(path=vocab_path)
        cat = CAT(cdb, vocab=vocab)
        cat.train = False  # train_supervised drives all updates
        cat.spacy_cat.MIN_ACC = 0.30
        cat.spacy_cat.MIN_ACC_TH = 0.30

        fp, fn, tp, p, r, f1, cui_counts = cat.train_supervised(
            data_path=data_path,
            lr=1,
            nepochs=nepochs,
            anneal=True,
            print_stats=True,
            use_filters=True,
            reset_cui_count=reset_cui_count,
            terminate_last=True,
            test_size=test_size)

        # Accumulate per-CUI metrics across runs
        for key in f1.keys():
            f1s.setdefault(key, []).append(f1[key])
            ps.setdefault(key, []).append(p[key])
            rs.setdefault(key, []).append(r[key])
            tps.setdefault(key, []).append(tp.get(key, 0))
            fps.setdefault(key, []).append(fp.get(key, 0))
            fns.setdefault(key, []).append(fn.get(key, 0))

    # Note: cui_counts reflects only the last run
    return fps, fns, tps, ps, rs, f1s, cui_counts
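
The per-CUI lists returned above are easy to summarize in a table; a sketch assuming pandas is available (paths are hypothetical):

import pandas as pd

fps, fns, tps, ps, rs, f1s, cui_counts = run_cv(
    "cdb.dat", "data.json", "vocab.dat", cv=5)

# Mean precision/recall/F1 per CUI across the CV runs
summary = pd.DataFrame({
    'precision': {cui: sum(v) / len(v) for cui, v in ps.items()},
    'recall': {cui: sum(v) / len(v) for cui, v in rs.items()},
    'f1': {cui: sum(v) / len(v) for cui, v in f1s.items()},
})
print(summary.sort_values('f1', ascending=False).head())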
import os
from medcat.prepare_cdb import PrepareCDB  # CDB builder in MedCAT v0.x

# `medcat_path` and `vocab` are assumed to be defined earlier in the notebook
# cdb.load_dict(os.path.join(medcat_path, 'simple_cdb.csv'))


# If you need a special CDB you can build one from a .csv file
preparator = PrepareCDB(vocab=vocab)
csv_paths = [os.path.join(medcat_path, 'simple_cdb.csv')]#, '<another one>', ...]
# csv_paths = [os.path.join(medcat_path, 'attention_cdb.csv')]  # alternative source CSV
cdb = preparator.prepare_csvs(csv_paths)

# Save the new CDB for later
cdb.save_dict(os.path.join(medcat_path, 'simple_cdb.cdb'))
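
To load the saved CDB back later (mirroring the commented-out load_dict call above):

from medcat.cdb import CDB

cdb = CDB()
cdb.load_dict(os.path.join(medcat_path, 'simple_cdb.cdb'))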

# To annotate documents we do
doc = "My simple document with kidney failure"
cat = CAT(cdb=cdb, vocab=vocab)
cat.train = False  # annotate only; no online training
doc_spacy = cat(doc)
# Entities are in
doc_spacy._.ents
# Or to get a json
doc_json = cat.get_json(doc)

# To have a look at the results:
from spacy import displacy
# Note that this will not show all entities, only the longest ones
displacy.serve(doc_spacy, style='ent')

# To run cat on a large number of documents
data = [] # [(<doc_id>, <text>), (<doc_id>, <text>), ...]
docs = cat.multi_processing(data)
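
One way to build the `data` list, assuming the documents sit in a CSV with (hypothetical) `id` and `text` columns:

import pandas as pd

df = pd.read_csv('documents.csv')  # hypothetical file with 'id' and 'text' columns
data = list(zip(df['id'], df['text']))
docs = cat.multi_processing(data)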