Code Example #1
def generate_diagnosis_data():
    sid_hadm_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'sid_hadm_dict.json'))
    hadm_sid_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_sid_dict.json'))

    hadm_map_dict = dict()
    for hadm in hadm_sid_dict:
        sid = hadm_sid_dict[hadm]
        hadm_list = sid_hadm_dict[sid]
        if len(hadm_list) > 1:
            hadm_list = sorted(hadm_list, key=lambda k: int(k))
            idx = hadm_list.index(hadm)
            if idx > 0:
                for h in hadm_list[:idx]:
                    if h not in hadm_map_dict:
                        hadm_map_dict[h] = []
                    hadm_map_dict[h].append(hadm)

    hadm_icd_dict = dict()
    for i_line, line in enumerate(
            open(os.path.join(args.mimic_dir, 'DIAGNOSES_ICD.csv'))):
        if i_line:
            if i_line % 10000 == 0:
                print(i_line)
            line_data = [x.strip('"') for x in py_op.csv_split(line.strip())]
            ROW_ID, SUBJECT_ID, hadm_id, SEQ_NUM, icd = line_data
            if hadm_id in hadm_map_dict:
                for h in hadm_map_dict[hadm_id]:
                    if h not in hadm_icd_dict:
                        hadm_icd_dict[h] = []
                    hadm_icd_dict[h].append(icd)
    hadm_icd_dict = {h: list(set(icds)) for h, icds in hadm_icd_dict.items()}
    py_op.mywritejson(os.path.join(args.data_dir, 'hadm_icd_dict.json'),
                      hadm_icd_dict)
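Note: every example on this page calls into a small helper module, py_op, whose implementation is not shown. A minimal sketch consistent with how myreadjson and mywritejson are used above (the bodies are assumptions, not the project's actual code):

import json

def myreadjson(json_file):
    # Load and return the parsed contents of a JSON file.
    with open(json_file) as f:
        return json.load(f)

def mywritejson(json_file, obj):
    # Serialize obj to a JSON file, overwriting any existing content.
    with open(json_file, 'w') as f:
        json.dump(obj, f, indent=4)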
Code Example #2
def generate_ehr_files():

    hadm_time_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_time_dict.json'))
    hadm_demo_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_demo_dict.json'))
    hadm_sid_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_sid_dict.json'))
    hadm_icd_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_icd_dict.json'))
    hadm_time_drug_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_time_drug_dict.json'))
    groundtruth_dir = os.path.join(args.data_dir, 'train_groundtruth')
    py_op.mkdir(groundtruth_dir)
    ehr_count_dict = dict()

    for hadm_id in hadm_sid_dict:

        time_drug_dict = hadm_time_drug_dict.get(hadm_id, {})
        icd_list = hadm_icd_dict.get(hadm_id, [])
        demo = hadm_demo_dict[hadm_id]
        demo[0] = demo[0] + '1'  # gender token, e.g. 'M' -> 'M1'
        demo[1] = 'A' + str(demo[1] // 9)  # bucket age into 9-year bins
        icd_demo = icd_list + demo

        for icd in icd_demo:
            ehr_count_dict[icd] = ehr_count_dict.get(icd, 0) + 1

        ehr_dict = {'drug': {}, 'icd_demo': icd_demo}

        for setime, drug_list in time_drug_dict.items():
            try:
                stime, etime = setime.split(' -- ')
                start_second = time_to_second(hadm_time_dict[hadm_id])
                # store offsets as whole hours; the 'start -- end' keys are
                # parsed back with int() when the files are read
                stime = str(int((time_to_second(stime) - start_second) / 3600))
                etime = str(int((time_to_second(etime) - start_second) / 3600))
                setime = stime + ' -- ' + etime
                for drug in drug_list:
                    ehr_count_dict[drug] = ehr_count_dict.get(drug, 0) + 1
                ehr_dict['drug'][setime] = list(set(drug_list))
            except:
                pass

        py_op.mywritejson(os.path.join(groundtruth_dir, hadm_id + '.json'),
                          ehr_dict)
        # break
    py_op.mywritejson(os.path.join(args.data_dir, 'ehr_count_dict.json'),
                      ehr_count_dict)
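generate_ehr_files() converts absolute prescription timestamps into hour offsets from ICU admission via a time_to_second helper that is not shown on this page. A plausible sketch, assuming MIMIC's 'YYYY-MM-DD HH:MM:SS' timestamp format:

from datetime import datetime

def time_to_second(t):
    # Parse a timestamp string into seconds since the epoch. The format
    # string is an assumption based on MIMIC-III conventions; note that
    # de-identified MIMIC dates are shifted into the distant future,
    # which datetime handles without trouble.
    return datetime.strptime(t, '%Y-%m-%d %H:%M:%S').timestamp()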
Code Example #3
def select_records_of_variables_not_in_pivoted():
    count_dict = {v: 0 for v in item_id_dict.values()}
    hadm_time_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'hadm_time_dict.json'))
    wf = open(os.path.join(args.mimic_dir, 'sepsis_lab.csv'), 'w')
    for i_line, line in enumerate(
            open(os.path.join(args.mimic_dir, 'LABEVENTS.csv'))):
        if i_line:
            line_data = line.split(',')
            if len(line_data) < 6:  # need at least hadm_id, item_id, time, value columns
                continue
            hadm_id, item_id, ctime = line_data[2:5]
            value = line_data[5]
            if item_id in count_dict and hadm_id in hadm_time_dict:
                # print(line)
                if len(line_data) != 9:
                    print(line)
                # assert len(line_data) == 9
                count_dict[item_id] += 1
                wf.write(line)
        else:
            wf.write(line)
            continue
        if i_line % 10000 == 0:
            print(i_line)
    wf.close()
Code Example #4
File: data_loader.py Project: yinchangchang/TAME
    def __init__(self, args, files, phase='train'):
        assert (phase == 'train' or phase == 'valid' or phase == 'test')
        self.args = args
        self.phase = phase
        self.files = files

        self.feature_mm_dict = py_op.myreadjson(
            os.path.join(args.file_dir,
                         args.dataset + '_feature_mm_dict.json'))
        self.feature_value_dict = py_op.myreadjson(
            os.path.join(
                args.file_dir, args.dataset +
                '_feature_value_dict_{:d}.json'.format(args.split_num)))
        self.n_dd = 40
        if args.dataset in ['MIMIC']:
            self.ehr_list = py_op.myreadjson(
                os.path.join(args.data_dir, args.dataset, 'ehr_list.json'))
            self.ehr_id = {ehr: i + 1 for i, ehr in enumerate(self.ehr_list)}
            self.args.n_ehr = len(self.ehr_id) + 1
Code Example #5
def map_ehr_id():
    print('start')
    ehr_count_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'ehr_count_dict.json'))
    ehr_list = [ehr for ehr, c in ehr_count_dict.items() if c > 100]
    ns = set('0123456789')
    print(ns)
    drug_list = [e for e in ehr_list if e[1] in ns]
    med_list = [e for e in ehr_list if e[1] not in ns]
    print(len(drug_list))
    print(len(med_list))
    py_op.mywritejson(os.path.join(args.data_dir, 'ehr_list.json'), ehr_list)
Code Example #6
def get_data():
    vocab_list = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset,
                     args.dataset[:-4].lower() + 'vocab.json'))
    aid_year_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'aid_year_dict.json'))
    pid_aid_did_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'pid_aid_did_dict.json'))
    pid_demo_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'pid_demo_dict.json'))
    case_control_data = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'case_control_data.json'))
    case_test = list(
        set(case_control_data['case_test'] + case_control_data['case_valid']))
    case_control_dict = case_control_data['case_control_dict']
    dataset = data_loader.DataBowl(args, phase='DGVis')
    id_name_dict = py_op.myreadjson(
        os.path.join(args.file_dir, 'id_name_dict.json'))

    # select patients with a higher ratio of codes in the knowledge graph
    selected_pid_set = set(pid_aid_did_dict)
    test_set = set()
    for case in case_test:
        test_set.add(case)
        for con in case_control_dict[case]:
            test_set.add(con)
Code Example #7
def generate_demo():
    icu_hadm_dict = py_op.myreadjson('../../src/icu_hadm_dict.json')
    py_op.mywritejson(os.path.join(args.data_dir, 'icu_hadm_dict.json'),
                      icu_hadm_dict)

    sid_demo_dict = dict()
    sid_hadm_dict = dict()
    for i_line, line in enumerate(
            open(os.path.join(args.mimic_dir, 'PATIENTS.csv'))):
        if i_line:
            data = line.split(',')
            sid = data[1]
            gender = data[2].replace('"', '')
            dob = data[3][:4]
            sid_demo_dict[sid] = [gender, int(dob)]
    py_op.mywritejson(os.path.join(args.data_dir, 'sid_demo_dict.json'),
                      sid_demo_dict)

    hadm_sid_dict = dict()
    hadm_demo_dict = dict()
    hadm_time_dict = dict()
    for i_line, line in enumerate(
            open(os.path.join(args.mimic_dir, 'ICUSTAYS.csv'))):
        if i_line:
            line = line.replace('"', '')
            data = line.split(',')
            sid = data[1]
            hadm_id = data[2]
            icu_id = data[3]
            intime = data[-3]
            sid_hadm_dict[sid] = sid_hadm_dict.get(sid, []) + [hadm_id]
            if icu_id not in icu_hadm_dict:
                continue
            hadm_sid_dict[hadm_id] = sid
            gender = sid_demo_dict[sid][0]
            dob = sid_demo_dict[sid][1]
            age = int(intime[:4]) - dob
            if age < 18:
                print(age)
            assert age >= 18
            # de-identified MIMIC shifts the DOB of patients over 89,
            # yielding implausible ages; map them back to 90
            if age > 150:
                age = 90
            hadm_demo_dict[hadm_id] = [gender, age]
            hadm_time_dict[hadm_id] = intime
    py_op.mywritejson(os.path.join(args.data_dir, 'hadm_demo_dict.json'),
                      hadm_demo_dict)
    py_op.mywritejson(os.path.join(args.data_dir, 'hadm_time_dict.json'),
                      hadm_time_dict)
    py_op.mywritejson(os.path.join(args.data_dir, 'sid_hadm_dict.json'),
                      sid_hadm_dict)
    py_op.mywritejson(os.path.join(args.data_dir, 'hadm_sid_dict.json'),
                      hadm_sid_dict)
Code Example #8
    def __init__(self):

        # build model
        self.net = get_model()

        # prepare all the test data
        test_data = get_data()
        self.pid_demo_dict, self.pid_aid_did_dict, self.aid_second_dict, self.dataset, self.case_set, \
                self.vocab_list, self.graph_dict, self.id_name_dict = test_data
        # assert len(self.pid_demo_dict) == len(self.pid_aid_did_dict)
        self.pids = list(self.pid_demo_dict.keys())
        self.icd_name_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'icd_name_dict.json'))
Code Example #9
    def get_test_data(self):
        '''
        return all the data needed for visualization:
            pid_aid_did_dict:
                pid: patient id
                aid: admission id
                did: diagnosis id
            aid_date_dict:
                aid: admission id
                date: int, admission's time
            vocab_dict: icd9 -> diagnosis dict
        '''
        aid_date_dict = {
            aid: second_to_date(second)
            for aid, second in self.aid_second_dict.items()
        }
        vocab_list = []
        for vocab in self.vocab_list:
            if vocab in self.id_name_dict:
                vocab_list.append(self.id_name_dict[vocab])
            elif vocab in self.icd_name_dict:
                vocab_list.append(self.icd_name_dict[vocab])
            else:
                vocab = vocab.strip('0')
                if vocab in self.icd_name_dict:
                    vocab_list.append(self.icd_name_dict[vocab])
                else:
                    vocab = vocab[:-1]
                    try:
                        vocab_list.append(
                            self.icd_name_dict[vocab.strip('0')])
                    except:
                        # vocab_list.append(self.icd_name_dict[vocab.strip('0')])
                        vocab_list.append(vocab)
                    assert len(vocab) >= 3
        # vocab_dict = { k:v for k,v in zip(self.vocab_list, vocab_list) }
        vocab_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'id_name_dict.json'))
        # py_op.mywritejson(os.path.join(args.file_dir, 'graph.json'), self.graph)
        try:
            return jsonify(self.pid_aid_did_dict, aid_date_dict, vocab_dict,
                           self.graph)
        except:
            return self.pid_aid_did_dict, aid_date_dict, vocab_dict, self.graph
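get_test_data() relies on a second_to_date helper (and, when it is available, Flask's jsonify) that the snippet does not define. A minimal sketch of the former, assuming it simply renders epoch seconds back into a date string:

from datetime import datetime

def second_to_date(second):
    # Inverse of time_to_second: format epoch seconds as a date string.
    # The output format is an assumption; the frontend only needs a
    # human-readable admission date.
    return datetime.fromtimestamp(second).strftime('%Y-%m-%d')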
Code Example #10
File: main.py Project: yinchangchang/TAME
def wkmeans(n_cluster):
    subtyping_dir = os.path.join(args.result_dir, args.dataset, 'subtyping')
    hadm_id_list = py_op.myreadjson(os.path.join(subtyping_dir, 'hadm_id_list.json'))
    hadm_dist_matrix = np.load(os.path.join(subtyping_dir, 'hadm_dist_matrix.npy'))
    assert len(hadm_dist_matrix) == len(hadm_id_list)

    # initialization
    indices = np.arange(len(hadm_id_list))  # range() cannot be shuffled in Python 3
    np.random.shuffle(indices)
    init_groups = [list(indices[i * 10:i * 10 + 10]) for i in range(n_cluster)]

    groups = init_groups
    for epoch in range(100):
        groups = wkmeans_epoch(hadm_dist_matrix, groups)
        print([len(g) for g in groups])
        if epoch and epoch % 10 == 0:
            cluster_results = []
            for g in groups:
                cluster_results.append([hadm_id_list[i] for i in g])
            py_op.mywritejson(os.path.join(subtyping_dir, 'cluster_results.json'), cluster_results)
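wkmeans() delegates each iteration to wkmeans_epoch, which is not shown. A sketch of one plausible reassignment step (k-means over a precomputed distance matrix; this is a guess at the interface, not the project's implementation):

import numpy as np

def wkmeans_epoch(dist_matrix, groups):
    # Reassign every admission to the group whose current members it is
    # closest to on average; empty groups would yield NaN means, so the
    # caller is assumed to start from non-empty initial groups.
    new_groups = [[] for _ in groups]
    for i in range(len(dist_matrix)):
        mean_dists = [dist_matrix[i, list(g)].mean() for g in groups]
        new_groups[int(np.argmin(mean_dists))].append(i)
    return new_groups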
Code Example #11
    def __init__(self):

        # build model
        model_file = os.path.join(
            args.result_dir,
            '{:s}-kg-gp.ckpt'.format(args.dataset.lower().split('_')[0]))
        self.net = get_model(model_file, 1)
        model_file = os.path.join(
            args.result_dir,
            '{:s}-no-kg.ckpt'.format(args.dataset.lower().split('_')[0]))
        self.net_nokg = get_model(model_file, 0)

        # prepare all the test data
        test_data = get_data()
        self.pid_demo_dict, self.pid_aid_did_dict, self.aid_second_dict, self.dataset, self.case_set, \
                self.vocab_list, self.graph_dict, self.id_name_dict, self.graph = test_data
        # assert len(self.pid_demo_dict) == len(self.pid_aid_did_dict)
        self.pids = list(self.pid_demo_dict.keys())
        self.icd_name_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'icd_name_dict.json'))
Code Example #12
def generate_drug_data():
    hadm_sid_dict = py_op.myreadjson(
        os.path.join(args.data_dir, 'hadm_sid_dict.json'))
    hadm_id_set = set(hadm_sid_dict)
    hadm_time_drug_dict = dict()
    for i_line, line in enumerate(
            open(os.path.join(args.mimic_dir, 'PRESCRIPTIONS.csv'))):
        if i_line:
            if i_line % 10000 == 0:
                print(i_line)
            line_data = [x.strip('"') for x in py_op.csv_split(line.strip())]
            _, SUBJECT_ID, hadm_id, _, startdate, enddate, _, drug, DRUG_NAME_POE, DRUG_NAME_GENERIC, FORMULARY_DRUG_CD, gsn, ndc, PROD_STRENGTH, DOSE_VAL_RX, DOSE_UNIT_RX, FORM_VAL_DISP, FORM_UNIT_DISP, ROUTE = line_data
            if len(hadm_id) and hadm_id in hadm_id_set:
                if hadm_id not in hadm_time_drug_dict:
                    hadm_time_drug_dict[hadm_id] = dict()
                time = startdate + ' -- ' + enddate
                if time not in hadm_time_drug_dict[hadm_id]:
                    hadm_time_drug_dict[hadm_id][time] = []
                hadm_time_drug_dict[hadm_id][time].append(drug)
                # hadm_time_drug_dict[hadm_id][time].append(ndc)
    py_op.mywritejson(os.path.join(args.data_dir, 'hadm_time_drug_dict.json'),
                      hadm_time_drug_dict)
Code Example #13
import os
import sys
import time
import numpy as np
import random
import json
from collections import OrderedDict
from tqdm import tqdm

sys.path.append('../tools')
sys.path.append('models/tools')
import parse, py_op

args = parse.args

drug_dict = py_op.myreadjson(os.path.join(args.file_dir, 'drug_dict.json'))
drug_set = set(drug_dict)


def find_drug(patient_dict):
    keys = sorted([int(key) for key in patient_dict.keys()])
    keys = [str(k) for k in keys]
    new_patient_dict = {}
    for k in keys:
        visit_data = patient_dict[k]
        if len(set(visit_data) & drug_set) > 0:
            return 1, new_patient_dict
        else:
            new_patient_dict[k] = patient_dict[k]
    return 0, patient_dict
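A toy illustration of find_drug: it walks visits in chronological order and stops at the first visit containing a known drug, returning the label together with the visit history before that point (the codes below are made up):

patient_dict = {'1': ['icd_a'], '2': ['drug_x'], '3': ['icd_b']}
drug_set = {'drug_x'}  # stand-in for the real drug_dict.json contents
label, history = find_drug(patient_dict)
# label == 1, history == {'1': ['icd_a']}: only visits before the first
# drug exposure are kept; with no drug anywhere, (0, patient_dict) is
# returned unchanged.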
Code Example #14
File: main.py Project: DGViz/DGViz.github.io
def analyse_contributions(labels, contributions, raw_data):
    cui_con_dict = { }
    # py_op.myreadjson('../result/cui_con_dict.json')

    print(labels.shape, contributions.shape, raw_data.shape)
    # print err

    pos_labels = labels[labels>0.5]
    pos_contributions = contributions[labels>0.5, :]
    pos_contributions[pos_contributions < 0] = 0
    pos_data = raw_data[labels>0.5, :]
    # print pos_labels.shape, pos_contributions.shape
    # print err

    pos_sum = (pos_contributions>0) * pos_contributions + 0.0001
    pos_sum = np.fabs(pos_sum.sum(1)).reshape((-1,1))
    # pos_sum = np.fabs(pos_contributions.sum(1)).reshape((-1,1))

    pos_ratio = pos_contributions / pos_sum
    # print_contributions(pos_contributions)

    # do not reshape
    # pos_ratio = pos_ratio.reshape(-1)
    # pos_data = pos_data.reshape(-1)

    id_name_dict = py_op.myreadjson(os.path.join(args.file_dir, 'id_name_dict.json'))
    ehr_cui_dict = py_op.myreadjson(os.path.join(args.file_dir, 'ehr_cui_dict.json'))
    cui_distance_dict = py_op.myreadjson(os.path.join(args.file_dir, 'cui_distance_dict.json'))
    name_id_dict = { v:ehr_cui_dict.get(k, k) for k,v in id_name_dict.items() }
    name_ratio_dict = { }
    for b_ratio, b_idx in zip(pos_ratio, pos_data):
        idx_ratio = dict()
        for ratio, idx in zip(b_ratio, b_idx):
            idx_ratio[idx] = idx_ratio.get(idx, 0) + ratio

        for idx, ratio in idx_ratio.items():
            id = args.vocab[idx]
            id = str(id)
            if idx>0 and id in id_name_dict:
                name = id_name_dict[id]
                name_ratio_dict[name] = name_ratio_dict.get(name, []) + [ratio]
                cui_con_dict[id] = cui_con_dict.get(id, []) + [ratio]

    cons_dir = '../result/cons'
    num = len(os.listdir(cons_dir))
    # py_op.mywritejson('../result/cons/{:d}.json'.format(num), cui_con_dict)
    contributions_list = []

    name_score_dict = { }
    for n,v in name_ratio_dict.items():
        if len(v) > 4:
            name_score_dict[n] = np.mean(v)
    name_list = sorted(name_score_dict.keys(), key=lambda n: -name_score_dict[n])
    for name in name_list[:30]:
        # if name_id_dict[name] in cui_distance_dict:
        #     print 'contribution rate: {:3.2f}%  {:d}    {:s}    {:s}'.format(100 * name_score_dict[name], cui_distance_dict.get(name_id_dict[name], -1), name_id_dict[name], name)
        # print name_ratio_dict[name]
        print('contribution rate: {:3.4f}%  {:d}    {:s}    {:s}'.format(100 * name_score_dict[name], cui_distance_dict.get(name_id_dict[name], -1), name_id_dict[name], name))
        # contributions_list.append({ 'id': name_id_dict[name], 'contribution': name_score_dict[name] })
    print('{:d} names in total'.format(len(name_list)))
    return cui_con_dict
Code Example #15
    graph_dict = {'edge': {}, 'node': {}}
    for line in open(os.path.join(args.file_dir, 'relation2id.txt')):
        data = line.strip().split()
        if len(data) == 2:
            relation, id = data[0], int(data[1])
            graph_dict['edge'][id] = relation
    for line in open(os.path.join(args.file_dir, 'entity2id.txt')):
        data = line.strip().split()
        if len(data) == 2:
            cui, id = data[0], int(data[1])
            graph_dict['node'][id] = cui

    if 1:
        selected_pid_set = set()
        graph = {'nodes': {}, 'edges': []}
        id_icd9_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'id_icd9_dict.json'))
        icd9_cui_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'icd9_cui_dict.json'))
        edge_dict = {}
        vocab_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'id_name_dict.json'))
        no_name = 0
        for line in open(os.path.join(args.file_dir, 'graph.txt')):
            data = line.strip().split('\t')
            node_f, node_s, relation_type = int(data[0]), int(data[1]), int(
                data[2])
            edge_dict[node_f] = edge_dict.get(
                node_f, []) + [[node_s, int(relation_type)]]

            # build graph
            cui_f = graph_dict['node'][node_f]
Code Example #16
def main():

    assert args.dataset in ['DACMI', 'MIMIC']
    if args.dataset == 'MIMIC':
        args.n_ehr = len(py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'ehr_list.json')))
    args.name_list = py_op.myreadjson(os.path.join(args.file_dir, args.dataset+'_feature_list.json'))[1:]
    args.output_size = len(args.name_list)
    files = sorted(glob(os.path.join(args.data_dir, args.dataset, 'train_with_missing/*.csv')))
    data_splits = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_splits.json'))
    train_files = [f for idx in [0, 1, 2, 3, 4, 5, 6] for f in data_splits[idx]]
    valid_files = [f for idx in [7] for f in data_splits[idx]]
    test_files = [f for idx in [8, 9] for f in data_splits[idx]]
    if args.phase == 'test':
        train_phase, valid_phase, test_phase, train_shuffle = 'test', 'test', 'test', False
    else:
        train_phase, valid_phase, test_phase, train_shuffle = 'train', 'valid', 'test', True
    train_dataset = data_loader.DataBowl(args, train_files, phase=train_phase)
    valid_dataset = data_loader.DataBowl(args, valid_files, phase=valid_phase)
    test_dataset = data_loader.DataBowl(args, test_files, phase=test_phase)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
    args.vocab_size = (args.output_size + 2) * (1 + args.split_num) + 5

    if args.model == 'tame':
        net = tame.AutoEncoder(args)
    loss = myloss.MSELoss(args)

    net = _cuda(net, 0)
    loss = _cuda(loss, 0)

    best_metric = [0, 0]
    start_epoch = 0

    if args.resume:
        p_dict = {'model': net}
        function.load_model(p_dict, args.resume)
        best_metric = p_dict['best_metric']
        start_epoch = p_dict['epoch'] + 1

    parameters_all = []
    for p in net.parameters():
        parameters_all.append(p)

    optimizer = torch.optim.Adam(parameters_all, args.lr)

    if args.phase == 'train':
        for epoch in range(start_epoch, args.epochs):
            print('start epoch :', epoch)
            train_eval(train_loader, net, loss, epoch, optimizer, best_metric)
            best_metric = train_eval(valid_loader, net, loss, epoch, optimizer, best_metric, phase='valid')
        print('best metric', best_metric)

    elif args.phase == 'test':
        folder = os.path.join(args.result_dir, args.dataset, 'imputation_result')
        os.system('rm -r ' + folder)
        os.system('mkdir ' + folder)

        train_eval(train_loader, net, loss, 0, optimizer, best_metric, 'test')
        train_eval(valid_loader, net, loss, 0, optimizer, best_metric, 'test')
        train_eval(test_loader, net, loss, 0, optimizer, best_metric, 'test')
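main() (and train_eval below) use a _cuda helper that none of these snippets define. A plausible sketch, assuming it just moves a tensor or module onto the GPU when one is available (the second argument appears unused and is ignored here):

import torch

def _cuda(x, device=0):
    # Move a tensor or nn.Module to GPU if CUDA is available; otherwise
    # return it unchanged so the same code runs on CPU.
    if torch.cuda.is_available():
        return x.cuda()
    return x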
Code Example #17
def merge_pivoted_data(csv_list):
    name_list = ['hadm_id', 'charttime']
    for k, v in variable_map_dict.items():
        if k not in ['age', 'gender']:
            if len(v):
                name_list.append(v)
            elif k in item_id_dict:
                name_list.append(k)
    name_index_dict = {name: id for id, name in enumerate(name_list)}

    hadm_time_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'hadm_time_dict.json'))
    icu_hadm_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'icu_hadm_dict.json'))
    merge_dir = os.path.join(args.data_dir, args.dataset, 'merge_pivoted')
    os.system('rm -r ' + merge_dir)
    os.system('mkdir ' + merge_dir)
    pivoted_dir = os.path.join(args.result_dir, 'mimic/pivoted_sofa')
    py_op.mkdir(pivoted_dir)

    for fi in csv_list:
        print(fi)
        for i_line, line in enumerate(open(os.path.join(args.mimic_dir, fi))):
            if i_line:
                if not line.strip():  # skip blank lines; split(',') never yields an empty list
                    continue
                line_data = line.strip().split(',')
                line_dict = dict()
                for iv, v in enumerate(line_data):
                    if len(v.strip()):
                        name = head[iv]
                        line_dict[name] = v

                if fi == 'pivoted_sofa.csv':
                    icu_id = line_dict.get('icustay_id', 'xxx')
                    if icu_id not in icu_hadm_dict:
                        continue
                    hadm_id = str(icu_hadm_dict[icu_id])
                    line_dict['hadm_id'] = hadm_id
                    line_dict['charttime'] = line_dict['starttime']

                hadm_id = line_dict.get('hadm_id', 'xxx')
                if hadm_id not in hadm_time_dict:
                    continue
                hadm_time = time_to_second(hadm_time_dict[hadm_id])
                now_time = time_to_second(line_dict['charttime'])
                delta_hour = int((now_time - hadm_time) / 3600)
                line_dict['charttime'] = str(delta_hour)

                if fi == 'pivoted_sofa.csv':
                    sofa_file = os.path.join(pivoted_dir, hadm_id + '.csv')
                    if not os.path.exists(sofa_file):
                        with open(sofa_file, 'w') as f:
                            f.write(sofa_head)
                    wf = open(sofa_file, 'a')
                    sofa_line = [str(delta_hour)] + line.split(',')[4:]
                    wf.write(','.join(sofa_line))
                    wf.close()

                assert 'hadm_id' in line_dict
                assert 'charttime' in line_dict
                new_line = []
                for name in name_list:
                    new_line.append(line_dict.get(name, ''))
                new_line = ','.join(new_line) + '\n'
                hadm_file = os.path.join(merge_dir, hadm_id + '.csv')
                if not os.path.exists(hadm_file):
                    with open(hadm_file, 'w') as f:
                        f.write(','.join(name_list) + '\n')
                wf = open(hadm_file, 'a')
                wf.write(new_line)
                wf.close()

            else:
                if fi == 'pivoted_sofa.csv':
                    sofa_head = ','.join(['time'] +
                                         line.replace('"', '').split(',')[4:])
                # "icustay_id","hr","starttime","endtime","pao2fio2ratio_novent","pao2fio2ratio_vent","rate_epinephrine","rate_norepinephrine","rate_dopamine","rate_dobutamine","meanbp_min","gcs_min","urineoutput","bilirubin_max","creatinine_max","platelet_min","respiration","coagulation","liver","cardiovascular","cns","renal","respiration_24hours","coagulation_24hours","liver_24hours","cardiovascular_24hours","cns_24hours","renal_24hours","sofa_24hours"

                head = line.replace('"', '').strip().split(',')
                head = [h.strip() for h in head]
                # print(line)
                for h in head:
                    if h not in name_index_dict:
                        print(h)
Code Example #18
File: data_loader.py Project: DGViz/DGViz.github.io
    def __init__(self, args, phase='train'):
        assert phase in ('train', 'valid', 'test', 'DGVis')
        self.args = args
        self.phase = phase
        self.vocab = json.load(
            open(
                os.path.join(
                    args.data_dir, args.dataset,
                    args.dataset.lower().replace('json', '') + 'vocab.json')))
        if phase == 'DGVis':
            self.inputs = json.load(
                open(os.path.join(args.data_dir, args.dataset, 'test.json')))
        else:
            self.inputs = json.load(
                open(os.path.join(args.data_dir, args.dataset,
                                  phase + '.json')))
        if phase == 'train':
            group_drug = np.load(
                os.path.join(args.data_dir, args.dataset, 'group_drug.npy'))
            self.index_to_group = dict()
            groups = set(group_drug[0])
            print('Groups: ', groups)
            self.group_to_index = {
                g: {
                    'with_drug': [],
                    'without_drug': []
                }
                for g in groups
            }
            self.indices_without_drug = []
            for index in range(len(group_drug[0])):
                group, drug = group_drug[:, index]
                self.index_to_group[index] = group
                if drug == 0:
                    self.group_to_index[group]['without_drug'].append(index)
                    self.indices_without_drug.append(index)
                else:
                    self.group_to_index[group]['with_drug'].append(index)
        self.seq = args.seq_length
        # print 'seq length', args.seq_length

        self.id_icd9_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'id_icd9_dict.json'))
        self.icd9_cui_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'icd9_cui_dict.json'))

        self.entity_id = dict()
        self.id_entity = []
        for line in open(os.path.join(args.file_dir, 'entity2id.txt')):
            data = line.strip().split()
            if len(data) == 2:
                cui, id = data[0], int(data[1])
                self.entity_id[cui] = id
                self.id_entity.append(cui)

        # information for knowledge graph
        cui_set = set()
        for id in self.vocab:
            id = str(id)
            if id in self.id_icd9_dict:
                icd9 = self.id_icd9_dict[id]
                if icd9 in self.icd9_cui_dict:
                    cui = self.icd9_cui_dict[icd9]
                    cui_set.add(cui)
        # print 'start', len(cui_set)

        relation_set = set()
        for line in open(os.path.join(args.file_dir, 'graph.txt')):
            node_f, node_s, relation = line.strip().split('\t')
            node_f = self.id_entity[int(node_f)]
            node_s = self.id_entity[int(node_s)]
            cui_set.add(node_f)
            cui_set.add(node_s)
            relation_set.add(relation)
        self.vocab = ['null'] + self.vocab + list(set(cui_set) - set(self.vocab))
        # self.vocab = ['null'] + sorted(set(self.vocab) | cui_set)
        # self.vocab_index = { w:i for i,w in enumerate(self.vocab) }
        self.vocab_index = {}
        for i, w in enumerate(self.vocab):
            if w not in self.vocab_index:
                self.vocab_index[w] = i
        self.relation = sorted(relation_set)

        # count the number of events in the knowledge graph
        if args.phase == 'DGVis':
            self.id_icd9_dict = py_op.myreadjson(
                os.path.join(args.file_dir, 'id_icd9_dict.json'))
            self.icd9_cui_dict = py_op.myreadjson(
                os.path.join(args.file_dir, 'icd9_cui_dict.json'))
            self.edge_dict = {}
            for line in open(os.path.join(args.file_dir, 'graph.txt')):
                data = line.strip().split('\t')
                node_f, node_s, relation_type = int(data[0]), int(
                    data[1]), int(data[2])
                self.edge_dict[node_f] = self.edge_dict.get(
                    node_f, []) + [[node_s, int(relation_type)]]
            # ..
            in_kg_set = set()
            for id in self.vocab:
                if id in self.id_icd9_dict:
                    icd = self.id_icd9_dict[id]
                    if icd in self.icd9_cui_dict:
                        cui = self.icd9_cui_dict[icd]
                        if cui in self.entity_id:
                            cui = self.entity_id[cui]
                            if cui in self.edge_dict:
                                in_kg_set.add(id)
            print('There are :', len(in_kg_set), len(self.vocab))
            # ..
            n_in_event = 0
            n_out_event = 0
            for pid_data in self.inputs:
                dp, lp = pid_data
                for vs in dp.values():
                    for fid in vs:
                        if fid in in_kg_set:
                            n_in_event += 1
                        else:
                            n_out_event += 1
            print('in/out', n_in_event, n_out_event)
Code Example #19
def train_eval(data_loader, net, loss, epoch, optimizer, best_metric, phase='train'):
    print(phase)
    lr = get_lr(epoch)
    if phase == 'train':
        net.train()
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        net.eval()

    loss_list, pred_list, label_list, mask_list = [], [], [], []
    feature_mm_dict = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_feature_mm_dict.json'))
    for b, data_list in enumerate(tqdm(data_loader)):
        data, label, mask, files = data_list[:4]
        data = index_value(data)
        if args.model == 'tame':
            pre_input, pre_time, post_input, post_time, dd_list = data_list[4:9]
            pre_input = index_value(pre_input)
            post_input = index_value(post_input)
            pre_time = Variable(_cuda(pre_time))
            post_time = Variable(_cuda(post_time))
            dd_list = Variable(_cuda(dd_list))
            neib = [pre_input, pre_time, post_input, post_time]

        label = Variable(_cuda(label)) # [bs, 1]
        mask = Variable(_cuda(mask)) # [bs, 1]
        if args.dataset in ['MIMIC'] and args.model == 'tame' and args.use_mm:
            output = net(data, neib=neib, dd=dd_list, mask=mask) # [bs, 1]
        elif args.model == 'tame' and args.use_ta:
            output = net(data, neib=neib, mask=mask) # [bs, 1]
        else:
            output = net(data, mask=mask) # [bs, 1]

        if phase == 'test':
            folder = os.path.join(args.result_dir, args.dataset, 'imputation_result')
            output_data = output.data.cpu().numpy()
            mask_data = mask.data.cpu().numpy().max(2)
            for (icu_data, icu_mask, icu_file) in zip(output_data, mask_data, files):
                icu_file = os.path.join(folder, icu_file.split('/')[-1].replace('.csv', '.npy'))
                np.save(icu_file, icu_data)
                if args.dataset == 'MIMIC':
                    with open(os.path.join(args.data_dir, args.dataset, 'train_groundtruth', icu_file.split('/')[-1].replace('.npy', '.csv'))) as f:
                        init_data = f.read().strip().split('\n')
                    # print(icu_file)
                    wf = open(icu_file.replace('.npy', '.csv'), 'w')
                    wf.write(init_data[0] + '\n')
                    item_list = init_data[0].strip().split(',')
                    if len(init_data) <= args.n_visit:
                        try:
                            assert len(init_data) == (icu_mask >= 0).sum() + 1
                        except:
                            pass
                            # print(len(init_data))
                            # print(sum(icu_mask >= 0))
                            # print(icu_file)
                    for init_line, out_line in zip(init_data[1:], icu_data):
                        init_line = init_line.strip().split(',')
                        new_line = [init_line[0]]
                        # assert len(init_line) == len(out_line) + 1
                        for item, iv, ov in zip(item_list[1:], init_line[1:], out_line):
                            if iv.strip() not in ['', 'NA']:
                                new_line.append('{:4.4f}'.format(float(iv.strip())))
                            else:
                                minv, maxv = feature_mm_dict[item]
                                ov = ov * (maxv - minv) + minv
                                new_line.append('{:4.4f}'.format(ov))
                        new_line = ','.join(new_line)
                        wf.write(new_line + '\n')
                    wf.close()


        loss_output = loss(output, label, mask)
        pred_list.append(output.data.cpu().numpy())
        loss_list.append(loss_output.data.cpu().numpy())
        label_list.append(label.data.cpu().numpy())
        mask_list.append(mask.data.cpu().numpy())

        if phase == 'train':
            optimizer.zero_grad()
            loss_output.backward()
            optimizer.step()


    pred = np.concatenate(pred_list, 0)
    label = np.concatenate(label_list, 0)
    mask = np.concatenate(mask_list, 0)
    metric_list = function.compute_nRMSE(pred, label, mask)
    avg_loss = np.mean(loss_list)

    print('\nTrain Epoch %03d (lr %.5f)' % (epoch, lr))
    print('loss: {:3.4f} \t'.format(avg_loss))
    print('metric: {:s}'.format('\t'.join(['{:3.4f}'.format(m) for m in metric_list[:2]])))


    metric = metric_list[0]
    if phase == 'valid' and (best_metric[0] == 0 or best_metric[0] > metric):
        best_metric = [metric, epoch]
        function.save_model({'args': args, 'model': net, 'epoch':epoch, 'best_metric': best_metric})
    metric_list = metric_list[2:]
    name_list = args.name_list
    assert len(metric_list) == len(name_list) * 2
    s = args.model
    for i in range(len(metric_list) // 2):
        name = name_list[i] + ''.join(['.' for _ in range(40 - len(name_list[i]))])
        print('{:s}{:3.4f}......{:3.4f}'.format(name, metric_list[2 * i], metric_list[2 * i + 1]))
        s = s + '  {:3.4f}'.format(metric_list[2 * i])
    if phase != 'train':
        print('\t\t\t\t best epoch: {:d}     best MSE on missing value: {:3.4f} \t'.format(best_metric[1], best_metric[0])) 
        print(s)
    return best_metric
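train_eval() asks a get_lr helper for the epoch's learning rate; the schedule is not shown anywhere on this page. A common step-decay sketch that matches the call site (the breakpoints and factors are assumptions):

def get_lr(epoch):
    # Step decay: full learning rate for the first half of training,
    # then 10x and 100x reductions. Assumes args.lr and args.epochs.
    if epoch < 0.5 * args.epochs:
        return args.lr
    if epoch < 0.8 * args.epochs:
        return 0.1 * args.lr
    return 0.01 * args.lr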
Code Example #20
File: main.py Project: yinchangchang/TAME
def compute_dist_mat():
    files = glob(
        os.path.join(args.result_dir, args.dataset,
                     'imputation_result/*.csv'))  # [:100]
    feature_ms_dict = py_op.myreadjson(
        os.path.join(args.file_dir, args.dataset + '_feature_ms_dict.json'))
    subtyping_dir = os.path.join(args.result_dir, args.dataset, 'subtyping')
    py_op.mkdir(subtyping_dir)
    hadm_id_list = []
    mean_variables = []
    hadm_variable_dict = {}
    all_values = []

    for i_fi, fi in enumerate(tqdm(files)):
        hadm_id = fi.split('/')[-1].split('.')[0]
        hadm_data = []
        for i_line, line in enumerate(open(fi)):
            if i_line:
                line_data = line.strip().split(',')
                line_data = np.array([float(x) for x in line_data])
                if len(line_data) != n_variables + 1:
                    print(i_fi, fi)
                if line_data[0] < 0:
                    continue
                elif line_data[0] < 24:
                    hadm_data.append(line_data)
                else:
                    break
            else:
                head = line.strip().split(',')[1:]
                assert len(head) == n_variables

        values = np.array(hadm_data, dtype=np.float32)
        values = values[-24:]
        times = values[:, 0]
        values = values[:, 1:]

        assert len(values.shape) == 2
        assert values.shape[1] == n_variables

        hadm_variable_dict[hadm_id] = values
        hadm_id_list.append(hadm_id)
        all_values.append(values)

    all_values = np.concatenate(all_values, 0)
    ms = [all_values.mean(0), all_values.std(0)]

    hadm_dist_matrix = np.zeros((len(hadm_id_list), len(hadm_id_list))) - 1
    for i in tqdm(range(len(hadm_id_list))):
        hadm_dist_matrix[i, i] = 0
        for j in range(i + 1, len(hadm_id_list)):
            if hadm_dist_matrix[i, j] >= 0 or i == j:
                continue
            s1 = hadm_variable_dict[hadm_id_list[i]]
            s2 = hadm_variable_dict[hadm_id_list[j]]
            s1 = norm(s1, ms)
            s2 = norm(s2, ms)
            dist_mat = dist_func(s1, s2)
            path = np.zeros([dist_mat.shape[0], dist_mat.shape[1], 3
                             ]) - inf - 1
            compute_dtw(dist_mat, path, hadm_dist_matrix, i, j)

    py_op.mywritejson(os.path.join(subtyping_dir, 'hadm_id_list.json'),
                      hadm_id_list)
    np.save(os.path.join(subtyping_dir, 'hadm_dist_matrix.npy'),
            hadm_dist_matrix)
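compute_dist_mat() leans on three undefined helpers: norm, dist_func, and compute_dtw. Sketches of the first two under the obvious reading (z-scoring with the global statistics, and a pairwise-Euclidean local cost matrix for DTW); compute_dtw itself is omitted:

import numpy as np

def norm(values, ms):
    # Z-score each variable using the global mean/std computed above.
    mean, std = ms
    return (values - mean) / (std + 1e-8)

def dist_func(s1, s2):
    # Local DTW cost: Euclidean distance between every pair of time
    # steps of the two stays, shape (len(s1), len(s2)).
    diff = s1[:, None, :] - s2[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))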
Code Example #21
def get_data():
    vocab_list = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset,
                     args.dataset[:-4].lower() + 'vocab.json'))
    aid_year_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'aid_year_dict.json'))
    pid_aid_did_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'pid_aid_did_dict.json'))
    pid_demo_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'pid_demo_dict.json'))
    case_control_data = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'case_control_data.json'))
    case_test = list(
        set(case_control_data['case_test'] + case_control_data['case_valid']))
    case_control_dict = case_control_data['case_control_dict']
    dataset = data_loader.DataBowl(args, phase='DGVis')
    id_name_dict = py_op.myreadjson(
        os.path.join(args.file_dir, 'id_name_dict.json'))
    graph_dict = {'edge': {}, 'node': {}}
    for line in open(os.path.join(args.file_dir, 'relation2id.txt')):
        data = line.strip().split()
        if len(data) == 2:
            relation, id = data[0], int(data[1])
            graph_dict['edge'][id] = relation
    for line in open(os.path.join(args.file_dir, 'entity2id.txt')):
        data = line.strip().split()
        if len(data) == 2:
            cui, id = data[0], int(data[1])
            if cui in id_name_dict:
                graph_dict['node'][id] = id_name_dict[cui]
            else:
                graph_dict['node'][id] = cui

    aid_second_dict = py_op.myreadjson(
        os.path.join(args.data_dir, args.dataset, 'aid_second_dict.json'))

    for pid, aid_did_dict in pid_aid_did_dict.items():
        n = 0
        aids = sorted(aid_did_dict.keys(),
                      key=lambda aid: int(aid),
                      reverse=True)
        for ia, aid in enumerate(aids):
            n += len(aid_did_dict[aid])
            if n > 120:
                pid_aid_did_dict[pid] = {
                    aid: aid_did_dict[aid]
                    for aid in aids[:ia]
                }
                break

    new_pid_demo_dict = dict()
    pid_list = case_test + [
        c for case in case_test for c in case_control_dict[str(case)]
    ]
    pid_list = [str(pid) for pid in pid_list]
    for pid in pid_list:
        pid = str(pid)
        demo = pid_demo_dict[pid]
        gender = demo[0]
        yob = int(demo[2:])
        if pid not in pid_aid_did_dict:
            continue
        aids = pid_aid_did_dict[pid].keys()
        year = max([aid_year_dict[aid] for aid in aids])
        age = year - yob
        assert age < 100 and age > 0
        new_pid_demo_dict[pid] = [gender, age]

    # return data
    # case_control_dict = { case: [c for c in case_control_dict[case] if c in new_pid_demo_dict] for case in case_test if case in new_pid_demo_dict}

    pid_demo_dict = new_pid_demo_dict
    pid_aid_did_dict = {
        pid: pid_aid_did_dict[pid]
        for pid in new_pid_demo_dict
    }

    # print('case_set', case_control_dict.keys())

    return pid_demo_dict, pid_aid_did_dict, aid_second_dict, dataset, set(
        case_control_dict), vocab_list, graph_dict, id_name_dict
Code Example #22
File: model.py Project: DGViz/DGViz.github.io
    def __init__(self, opt, use_kg=0):
        super(FCModel, self).__init__()
        self.use_kg = use_kg
        self.opt = opt
        self.vocab = opt.vocab
        self.vocab_size = len(opt.vocab)
        self.vocab_index = {w: i for i, w in enumerate(self.vocab)}
        self.input_encoding_size = opt.embed_size
        self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length

        self.ss_prob = 0.0  # Schedule sampling probability

        self.core = LSTMCore(opt, opt.embed_size)
        self.embed = nn.Embedding(self.vocab_size, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, 1, False)
        self.maxpooling = nn.AdaptiveMaxPool1d(1, True)

        # information for knowledge graph
        self.relation = opt.relation
        self.id_icd9_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'id_icd9_dict.json'))
        self.cui_icd9_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'cui_icd9_dict.json'))
        self.icd9_cui_dict = py_op.myreadjson(
            os.path.join(args.file_dir, 'icd9_cui_dict.json'))

        # knowledge graph attention
        self.h2att = nn.Linear(opt.embed_size, kg_embedding_size)
        self.i2att = nn.Linear(kg_embedding_size, kg_embedding_size)
        self.att_hid_size = opt.embed_size
        self.alpha_net_kg = nn.Linear(kg_embedding_size, 1)

        # visit attention
        self.o2att = nn.Linear(opt.embed_size, opt.embed_size)
        self.alpha_net_vs = nn.Linear(opt.embed_size, 1)

        self.kg2kg = nn.Linear(kg_embedding_size, kg_embedding_size)

        # graph embedding contents
        self.entity_id = dict()
        self.id_entity = []
        for line in open(os.path.join(args.file_dir, 'entity2id.txt')):
            data = line.strip().split()
            if len(data) == 2:
                cui, id = data[0], int(data[1])
                self.entity_id[cui] = id
                self.id_entity.append(cui)

        self.relation_id = dict()
        self.id_relation = []
        for line in open(os.path.join(args.file_dir, 'relation2id.txt')):
            data = line.strip().split()
            if len(data) == 2:
                cui, id = data[0], int(data[1])
                self.relation_id[cui] = id
                self.id_relation.append(cui)

        self.edge_dict = {}
        for line in open(os.path.join(args.file_dir, 'graph.txt')):
            data = line.strip().split('\t')
            node_f, node_s, relation_type = int(data[0]), int(data[1]), int(
                data[2])
            # if relation_type == 0:
            #     self.edge_dict[node_f] = self.edge_dict.get(node_f, []) + [[node_s, int(relation_type)]]
            self.edge_dict[node_f] = self.edge_dict.get(
                node_f, []) + [[node_s, int(relation_type)]]

        if args.model == 'GRAM':
            embedding = py_op.myreadjson(
                os.path.join(args.file_dir, 'embedding.vec.json'))
            self.ent_embeddings = Variable(
                _cuda(
                    torch.from_numpy(
                        np.array(embedding['ent_embeddings'],
                                 dtype=np.float32))))
            self.rel_embeddings = Variable(
                _cuda(
                    torch.from_numpy(
                        np.array(embedding['rel_embeddings'],
                                 dtype=np.float32))))

        elif kg_embedding_size in [100, 200, 300]:
            embedding = py_op.myreadjson(
                os.path.join(args.file_dir, 'embedding.vec.json'))
            # self.ent_embeddings = Variable(torch.from_numpy(np.array(embedding['ent_embeddings'], dtype=np.float32)).cuda(async=True))
            # self.rel_embeddings = Variable(torch.from_numpy(np.array(embedding['rel_embeddings'], dtype=np.float32)).cuda(async=True))
            self.ent_embeddings = np.array(embedding['ent_embeddings'],
                                           dtype=np.float32)
            self.rel_embeddings = np.array(embedding['rel_embeddings'],
                                           dtype=np.float32)
        elif kg_embedding_size in [500, 1000]:
            self.ent_embeddings = np.zeros(
                (len(self.entity_id), kg_embedding_size), dtype=np.float32)
            num = 0
            # wf = open('../data/kg/cui2vec_selected.csv', 'w')
            # for line in open('../data/kg/cui2vec_pretrained.csv'):
            for line in open(
                    os.path.join(args.file_dir, 'cui2vec_selected.csv')):
                data = line.strip().split(',')
                cui = data[0].strip('"')
                vec = data[1:]
                try:
                    vec = [float(v) for v in vec]
                except:
                    continue
                if cui in self.entity_id:
                    id = self.entity_id[cui]
                    self.ent_embeddings[id] = vec
                    num += 1
                    # wf.write(line)
                # else:
                # print(cui)

        self.ensemble = nn.Linear(2, 1)
        self.init_weights()
Code Example #23
File: main.py Project: onlyzdd/clinical-fusion
def main():
    args.n_ehr = len(
        json.load(
            open(os.path.join(args.files_dir, 'demo_index_dict.json'),
                 'r'))) + 10
    args.name_list = json.load(
        open(os.path.join(args.files_dir, 'feature_list.json'), 'r'))[1:]
    args.input_size = len(args.name_list)
    files = sorted(glob(os.path.join(args.data_dir, 'resample_data/*.csv')))
    data_splits = json.load(
        open(os.path.join(args.files_dir, 'splits.json'), 'r'))
    train_files = [
        f for idx in [0, 1, 2, 3, 4, 5, 6] for f in data_splits[idx]
    ]
    valid_files = [f for idx in [7] for f in data_splits[idx]]
    test_files = [f for idx in [8, 9] for f in data_splits[idx]]
    if args.phase == 'test':
        train_phase, valid_phase, test_phase, train_shuffle = 'test', 'test', 'test', False
    else:
        train_phase, valid_phase, test_phase, train_shuffle = 'train', 'valid', 'test', True
    train_dataset = data_loader.DataBowl(args, train_files, phase=train_phase)
    valid_dataset = data_loader.DataBowl(args, valid_files, phase=valid_phase)
    test_dataset = data_loader.DataBowl(args, test_files, phase=test_phase)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=train_shuffle,
                              num_workers=args.workers,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    args.vocab_size = args.input_size + 2

    if args.use_unstructure:
        args.unstructure_size = len(
            py_op.myreadjson(os.path.join(args.files_dir,
                                          'vocab_list.json'))) + 10

    # net = icnn.CNN(args)
    # net = cnn.CNN(args)
    net = lstm.LSTM(args)
    # net = torch.nn.DataParallel(net)
    # loss = myloss.Loss(0)
    loss = myloss.MultiClassLoss(0)

    net = _cuda(net, 0)
    loss = _cuda(loss, 0)

    best_metric = [0, 0]
    start_epoch = 0

    if args.resume:
        p_dict = {'model': net}
        function.load_model(p_dict, args.resume)
        best_metric = p_dict['best_metric']
        start_epoch = p_dict['epoch'] + 1

    parameters_all = []
    for p in net.parameters():
        parameters_all.append(p)

    optimizer = torch.optim.Adam(parameters_all, args.lr)

    if args.phase == 'train':
        for epoch in range(start_epoch, args.epochs):
            print('start epoch :', epoch)
            t0 = time.time()
            train_eval(train_loader, net, loss, epoch, optimizer, best_metric)
            t1 = time.time()
            print('Running time:', t1 - t0)
            best_metric = train_eval(valid_loader,
                                     net,
                                     loss,
                                     epoch,
                                     optimizer,
                                     best_metric,
                                     phase='valid')
        print('best metric', best_metric)

    elif args.phase == 'test':
        train_eval(test_loader, net, loss, 0, optimizer, best_metric, 'test')
Code Example #24
File: data_loader.py Project: yinchangchang/TAME
    def get_mm_item(self, idx):
        input_file = self.files[idx]
        output_file = input_file.replace('with_missing', 'groundtruth')

        with open(output_file) as f:
            output_data = f.read().strip().split('\n')
        with open(input_file) as f:
            input_data = f.read().strip().split('\n')

        if self.args.random_missing and self.phase == 'train':
            input_data = []
            valid_list = []
            for line in output_data:
                values = line.strip().split(',')
                input_data.append(values)
                valid = []
                for iv, v in enumerate(values):
                    if v.strip() not in ['', 'NA']:
                        valid.append(1)
                    else:
                        valid.append(0)
                valid_list.append(valid)
            valid_list = np.array(valid_list)
            valid_list[0] = 0
            for i in range(1, valid_list.shape[1]):
                valid = valid_list[:, i]
                indices = np.where(valid > 0)[0]
                np.random.shuffle(indices)
                if len(indices) > 2:
                    input_data[indices[0]][i] = 'NA'
            input_data = [','.join(line) for line in input_data]

        init_input_data = input_data

        if self.args.use_ve == 0:
            input_data = self.pre_filling(input_data)

        assert len(input_data) == len(output_data)

        mask_list, input_list, output_list = [], [], []
        pre_input_list, pre_time_list = [], []
        post_input_list, post_time_list = [], []
        input_split = [x.strip().split(',') for x in init_input_data]

        for iline in range(len(input_data)):
            inp = input_data[iline].strip()
            oup = output_data[iline].strip()
            init_inp = init_input_data[iline].strip()

            if iline == 0:
                feat_list = inp.split(',')
            else:
                in_vs = inp.split(',')
                ou_vs = oup.split(',')
                init_vs = init_inp.split(',')
                ctime = int(inp.split(',')[0])
                mask, input, output = [], [], []
                rd = np.random.random(len(in_vs))
                for i, (iv, ov, ir,
                        init_iv) in enumerate(zip(in_vs, ou_vs, rd, init_vs)):
                    if ir < 0.7:
                        # iv = 'NA'
                        pass

                    # mask semantics: 0 = observed in the input,
                    # 1 = imputation target (missing in input, present in
                    # the ground truth), -1 = missing in both
                    if init_iv not in ['NA', '']:
                        mask.append(0)
                    elif ov not in ['NA', '']:
                        # print('err')
                        mask.append(1)
                    else:
                        mask.append(-1)
                    if self.args.use_ve:
                        input.append(self.map_input(iv, feat_list, i))
                    else:
                        input.append(self.map_output(iv, feat_list, i))
                    output.append(self.map_output(ov, feat_list, i))
                mask_list.append(mask)
                input_list.append(input)
                output_list.append(output)
                # pre and post info
                pre_input, pre_time = self.get_pre_info(
                    input_split, iline, feat_list)
                pre_input_list.append(pre_input)
                pre_time_list.append(pre_time)
                post_input, post_time = self.get_post_info(
                    input_split, iline, feat_list)
                post_input_list.append(post_input)
                post_time_list.append(post_time)

        if len(mask_list) < self.args.n_visit:
            for _ in range(self.args.n_visit - len(mask_list)):
                # pad empty visit
                mask = [-1 for _ in range(self.args.output_size + 1)]
                vs = [0 for _ in range(self.args.output_size + 1)]
                mask_list.append(mask)
                input_list.append(vs)
                output_list.append(vs)
                pre_input_list.append(vs)
                pre_time_list.append(vs)
                post_input_list.append(vs)
                post_time_list.append(vs)
            # print(mask_list)
        else:
            mask_list = mask_list[:self.args.n_visit]
            input_list = input_list[:self.args.n_visit]
            output_list = output_list[:self.args.n_visit]
            pre_input_list = pre_input_list[:self.args.n_visit]
            pre_time_list = pre_time_list[:self.args.n_visit]
            post_input_list = post_input_list[:self.args.n_visit]
            post_time_list = post_time_list[:self.args.n_visit]

        # print(mask_list)
        mask_list = np.array(mask_list, dtype=np.int64)
        output_list = np.array(output_list, dtype=np.float32)
        pre_time_list = np.array(pre_time_list, dtype=np.int64)
        post_time_list = np.array(post_time_list, dtype=np.int64)
        if self.args.value_embedding == 'no' or self.args.use_ve == 0:
            input_list = np.array(input_list, dtype=np.float32)
            pre_input_list = np.array(pre_input_list, dtype=np.float32)
            post_input_list = np.array(post_input_list, dtype=np.float32)
        else:
            input_list = np.array(input_list, dtype=np.int64)
            pre_input_list = np.array(pre_input_list, dtype=np.int64)
            post_input_list = np.array(post_input_list, dtype=np.int64)

        input_list = input_list[:, 1:]
        output_list = output_list[:, 1:]
        mask_list = mask_list[:, 1:]
        pre_input_list = pre_input_list[:, 1:]
        pre_time_list = pre_time_list[:, 1:]
        post_input_list = post_input_list[:, 1:]
        post_time_list = post_time_list[:, 1:]

        time_list = [x[0] for x in input_split][1:]
        max_time = int(time_list[min(self.args.n_visit,
                                     len(time_list) - 1)]) + 1

        if self.args.dataset in ['MIMIC']:
            ehr_dict = py_op.myreadjson(
                os.path.join(
                    input_file.replace('with_missing',
                                       'groundtruth').replace('.csv',
                                                              '.json')))
        else:
            ehr_dict = dict()
        icd_list = [
            self.ehr_id[e] for e in ehr_dict.get('icd_demo', {})
            if e in self.ehr_id
        ]
        icd_list = set(icd_list)
        # note: the ICD/demo codes are discarded on the next line, so only
        # drug codes end up in dd_list below
        icd_list = set()
        drug_dict = ehr_dict.get('drug', {})
        visit_dict = dict()
        for i in range(-250, max_time + 1):
            visit_dict[i] = sorted(icd_list)
        for k, drug_list in drug_dict.items():
            stime, etime = k.split(' -- ')
            id_list = list(
                set([self.ehr_id[e] for e in drug_list if e in self.ehr_id]))
            if len(id_list):
                for t in range(max(0, int(stime)), int(etime)):
                    if t in visit_dict:
                        visit_dict[t] = visit_dict[t] + id_list
        for k, v in visit_dict.items():
            v = list(set(v))
            visit_dict[k] = v
            # if self.n_dd < len(v):
            #     self.n_dd = max(self.n_dd, len(v))
            #     print(self.n_dd)
        dd_list = np.zeros((len(input_list), self.n_dd), dtype=np.int64)
        for i, t in enumerate(time_list[:self.args.n_visit]):
            if int(t) not in visit_dict:
                continue
            id_list = visit_dict[int(t)]
            if len(id_list):
                id_list = np.array(id_list, dtype=np.int64)
                if len(id_list) > self.n_dd:
                    np.random.shuffle(id_list)
                    dd_list[i] = id_list[-self.n_dd:]
                else:
                    dd_list[i][:len(id_list)] = id_list

        # assert pre_time_list.max() < 256
        # assert post_time_list.max() < 256
        assert pre_time_list.min() >= 0
        assert post_time_list.min() >= 0
        pre_time_list[pre_time_list > 200] = 200
        post_time_list[post_time_list > 200] = 200
        assert len(mask_list[0]) == self.args.output_size
        assert len(mask_list[0]) == len(pre_input_list[0])

        # print(input_list.shape)
        return torch.from_numpy(input_list), torch.from_numpy(output_list), torch.from_numpy(mask_list), input_file,\
                torch.from_numpy(pre_input_list), torch.from_numpy(pre_time_list), torch.from_numpy(post_input_list), \
                torch.from_numpy(post_time_list), torch.from_numpy(dd_list)