예제 #1
0
파일: main.py 프로젝트: zhangxiaowbl/DG-RNN
def main():
    """Train or evaluate the FC model, depending on ``args.phase``.

    Builds train/valid ``DataBowl`` datasets and loaders, copies the
    validation dataset's ``vocab`` and ``relation`` onto ``args``, then
    creates the model, the loss and an Adam optimizer.

    * ``args.phase == 'train'``: runs ``train`` and ``test`` once per epoch,
      then writes the last ``cui_con_dict`` to
      ``../result/cons/<model>/<predict_day>/<n>.json`` (``n`` is the number
      of entries already in that directory) and appends the first element of
      ``best_auc`` to ``../result/log.txt``.
    * ``args.phase == 'test'``: loads weights from ``args.resume`` and runs
      ``test`` once on the validation loader.

    The ``print`` calls take a single argument so that the output is the
    same under Python 2 and Python 3.
    """
    train_set = data_loader.DataBowl(args, phase='train')
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    valid_set = data_loader.DataBowl(args, phase='valid')
    valid_loader = DataLoader(valid_set,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True)
    # The model reads these from args, so they must be set before it is built.
    args.vocab = valid_set.vocab
    args.relation = valid_set.relation

    net, loss = model.FCModel(args), model.Loss()

    net = _cuda(net, 0)
    loss = _cuda(loss, 0)

    optimizer = torch.optim.Adam(list(net.parameters()), args.lr)

    best_auc = [0, 0, 0, 0, 0, 0]

    cui_con_dict = {}
    if args.phase == 'train':
        for epoch in range(args.epochs):
            train(train_loader, net, loss, epoch, optimizer, best_auc)
            best_auc, cui_con_dict = test(valid_loader, net, loss, epoch,
                                          best_auc, 'valid', cui_con_dict)
            print(args.words)

        # Each run gets the next free index as its file name.
        cons_dir = '../result/cons/{:s}/{:d}'.format(
            args.model, args.predict_day)
        py_op.mkdir(cons_dir)
        num = len(os.listdir(cons_dir))
        py_op.mywritejson(os.path.join(cons_dir, '{:d}.json'.format(num)),
                          cui_con_dict)

        print('best auc {}'.format(best_auc))
        auc = best_auc[0]
        with open('../result/log.txt', 'a') as f:
            f.write('#model {:s} #auc {:3.4f}\n'.format(args.model, auc))

    elif args.phase == 'test':
        net.load_state_dict(torch.load(args.resume))
        test(valid_loader, net, loss, 0, best_auc, 'valid', cui_con_dict)
예제 #2
0
def get_data():
    """Load the dataset JSON files and collect the test patient ids.

    Reads the vocab, admission-year, patient-admission, demographics and
    case/control JSON files from ``<args.data_dir>/<args.dataset>``, builds
    a ``DataBowl`` in the ``'DGVis'`` phase, and reads ``id_name_dict.json``
    from ``args.file_dir``. It then gathers every test/valid case together
    with its controls into a set. Nothing is returned in the visible code.
    """
    dataset_dir = os.path.join(args.data_dir, args.dataset)

    vocab_name = args.dataset[:-4].lower() + 'vocab.json'
    vocab_list = py_op.myreadjson(os.path.join(dataset_dir, vocab_name))
    aid_year_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'aid_year_dict.json'))
    pid_aid_did_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'pid_aid_did_dict.json'))
    pid_demo_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'pid_demo_dict.json'))
    case_control_data = py_op.myreadjson(
        os.path.join(dataset_dir, 'case_control_data.json'))

    # De-duplicate cases that appear in both the test and the valid split.
    merged_cases = case_control_data['case_test'] + case_control_data[
        'case_valid']
    case_test = list(set(merged_cases))
    case_control_dict = case_control_data['case_control_dict']

    dataset = data_loader.DataBowl(args, phase='DGVis')
    id_name_dict = py_op.myreadjson(
        os.path.join(args.file_dir, 'id_name_dict.json'))

    # Candidate patients are all those with admission records.
    selected_pid_set = set(pid_aid_did_dict)

    # Every case plus all of its matched controls.
    test_set = set()
    for case_id in case_test:
        test_set.add(case_id)
        test_set.update(case_control_dict[case_id])
예제 #3
0
def get_model(model_file=os.path.join(args.result_dir, 'mimic-kg-gp.ckpt')):
    """Build an ``FCModel`` and load its weights from ``model_file``.

    The validation ``DataBowl`` supplies ``vocab`` and ``relation``, which
    are copied onto ``args`` before the model is constructed. The default
    checkpoint path is computed once, when the function is defined.
    """
    valid_set = data_loader.DataBowl(args, phase='valid')
    for attr in ('vocab', 'relation'):
        setattr(args, attr, getattr(valid_set, attr))

    network = model.FCModel(args)
    model.Loss()  # built and discarded, exactly as the original did
    network = _cuda(network)

    state = torch.load(model_file)
    network.load_state_dict(state)
    return network
예제 #4
0
def get_model(model_file, use_kg):
    """Build an ``FCModel`` and load its weights from ``model_file``.

    Args:
        model_file: Path of the checkpoint passed to ``torch.load``.
        use_kg: Forwarded unchanged as the second argument of
            ``model.FCModel``.

    Returns:
        The model after ``_cuda`` and ``load_state_dict``.

    If the direct load fails, the checkpoint is re-read with every tensor
    mapped to the CPU and loaded again. An error from that second attempt
    propagates to the caller.
    """
    dataset = data_loader.DataBowl(args, phase='valid')
    args.vocab = dataset.vocab
    args.relation = dataset.relation

    net, _ = model.FCModel(args, use_kg), model.Loss()
    net = _cuda(net)
    try:
        net.load_state_dict(torch.load(model_file))
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must still
        # abort the program instead of triggering a second load.
        cpu_state = torch.load(model_file, map_location=torch.device('cpu'))
        # Build a fresh dict rather than rewriting the one being iterated.
        net.load_state_dict({k: v.cpu() for k, v in cpu_state.items()})
    return net
예제 #5
0
def main():
    """Train or evaluate the LSTM model, depending on ``args.phase``.

    Reads ``demo_index_dict.json``, ``feature_list.json`` and
    ``splits.json`` from ``args.files_dir`` to set ``args.n_ehr``,
    ``args.name_list`` and ``args.input_size`` and to pick the file lists:
    splits 0-6 for training, 7 for validation, 8-9 for testing. When
    ``args.phase == 'test'`` all three datasets are built in ``'test'``
    phase and the training loader is not shuffled.

    * ``args.phase == 'train'``: for each epoch from ``start_epoch`` to
      ``args.epochs`` runs ``train_eval`` on the training loader, prints the
      elapsed time, then runs it on the validation loader and keeps the
      returned ``best_metric``.
    * ``args.phase == 'test'``: runs ``train_eval`` once on the test loader.

    If ``args.resume`` is set, ``function.load_model`` is called with a dict
    holding the model; ``best_metric`` and ``epoch`` are then read back from
    that dict (presumably filled in by ``load_model`` -- confirm there).
    """

    def _read_json(name):
        # ``with`` closes the handle; ``json.load(open(...))`` leaked it.
        with open(os.path.join(args.files_dir, name), 'r') as handle:
            return json.load(handle)

    def _collect(split_ids):
        # Flatten the file lists of the requested splits, in order.
        return [f for idx in split_ids for f in data_splits[idx]]

    def _make_loader(dataset, shuffle):
        return DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=shuffle,
                          num_workers=args.workers,
                          pin_memory=True)

    args.n_ehr = len(_read_json('demo_index_dict.json')) + 10
    # The first entry of the feature list is dropped.
    args.name_list = _read_json('feature_list.json')[1:]
    args.input_size = len(args.name_list)
    data_splits = _read_json('splits.json')

    train_files = _collect([0, 1, 2, 3, 4, 5, 6])
    valid_files = _collect([7])
    test_files = _collect([8, 9])

    if args.phase == 'test':
        train_phase, valid_phase, test_phase, train_shuffle = 'test', 'test', 'test', False
    else:
        train_phase, valid_phase, test_phase, train_shuffle = 'train', 'valid', 'test', True
    train_dataset = data_loader.DataBowl(args, train_files, phase=train_phase)
    valid_dataset = data_loader.DataBowl(args, valid_files, phase=valid_phase)
    test_dataset = data_loader.DataBowl(args, test_files, phase=test_phase)
    train_loader = _make_loader(train_dataset, train_shuffle)
    valid_loader = _make_loader(valid_dataset, False)
    test_loader = _make_loader(test_dataset, False)

    args.vocab_size = args.input_size + 2

    if args.use_unstructure:
        args.unstructure_size = len(
            py_op.myreadjson(os.path.join(args.files_dir,
                                          'vocab_list.json'))) + 10

    net = lstm.LSTM(args)
    loss = myloss.MultiClassLoss(0)

    net = _cuda(net, 0)
    loss = _cuda(loss, 0)

    best_metric = [0, 0]
    start_epoch = 0

    if args.resume:
        p_dict = {'model': net}
        function.load_model(p_dict, args.resume)
        best_metric = p_dict['best_metric']
        start_epoch = p_dict['epoch'] + 1

    optimizer = torch.optim.Adam(list(net.parameters()), args.lr)

    if args.phase == 'train':
        for epoch in range(start_epoch, args.epochs):
            print('start epoch :', epoch)
            t0 = time.time()
            train_eval(train_loader, net, loss, epoch, optimizer, best_metric)
            t1 = time.time()
            print('Running time:', t1 - t0)
            best_metric = train_eval(valid_loader,
                                     net,
                                     loss,
                                     epoch,
                                     optimizer,
                                     best_metric,
                                     phase='valid')
        print('best metric', best_metric)

    elif args.phase == 'test':
        train_eval(test_loader, net, loss, 0, optimizer, best_metric, 'test')
예제 #6
0
def main():
    """Train or run the ``tame`` imputation auto-encoder per ``args.phase``.

    ``args.dataset`` must be ``'DACMI'`` or ``'MIMIC'``. The feature list
    and split file are read from ``args.file_dir``: splits 0-6 are used for
    training, 7 for validation and 8-9 for testing. When
    ``args.phase == 'test'`` all three datasets are built in ``'test'``
    phase and the training loader is not shuffled.

    * ``args.phase == 'train'``: for each epoch from ``start_epoch`` to
      ``args.epochs`` runs ``train_eval`` on the training loader and then on
      the validation loader, keeping the returned ``best_metric``.
    * ``args.phase == 'test'``: recreates
      ``<result_dir>/<dataset>/imputation_result`` as an empty directory and
      runs ``train_eval`` in ``'test'`` phase on all three loaders.

    Raises:
        AssertionError: if ``args.dataset`` is not a supported dataset.
        ValueError: if ``args.model`` is not ``'tame'``. The original code
            failed later with a NameError on the unbound ``net``.
    """
    # Local import so the block does not depend on the file's import list.
    import shutil

    def _make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                          num_workers=args.workers, pin_memory=True)

    assert args.dataset in ['DACMI', 'MIMIC']
    if args.dataset == 'MIMIC':
        args.n_ehr = len(py_op.myreadjson(os.path.join(args.data_dir, args.dataset, 'ehr_list.json')))
    # The first entry of the feature list is dropped.
    args.name_list = py_op.myreadjson(os.path.join(args.file_dir, args.dataset+'_feature_list.json'))[1:]
    args.output_size = len(args.name_list)
    data_splits = py_op.myreadjson(os.path.join(args.file_dir, args.dataset + '_splits.json'))
    train_files = [f for idx in [0, 1, 2, 3, 4, 5, 6] for f in data_splits[idx]]
    valid_files = [f for idx in [7] for f in data_splits[idx]]
    test_files = [f for idx in [8, 9] for f in data_splits[idx]]
    if args.phase == 'test':
        train_phase, valid_phase, test_phase, train_shuffle = 'test', 'test', 'test', False
    else:
        train_phase, valid_phase, test_phase, train_shuffle = 'train', 'valid', 'test', True
    train_dataset = data_loader.DataBowl(args, train_files, phase=train_phase)
    valid_dataset = data_loader.DataBowl(args, valid_files, phase=valid_phase)
    test_dataset = data_loader.DataBowl(args, test_files, phase=test_phase)
    train_loader = _make_loader(train_dataset, train_shuffle)
    valid_loader = _make_loader(valid_dataset, False)
    test_loader = _make_loader(test_dataset, False)
    args.vocab_size = (args.output_size + 2) * (1 + args.split_num) + 5

    if args.model == 'tame':
        net = tame.AutoEncoder(args)
    else:
        raise ValueError('unsupported model: {!r}'.format(args.model))
    loss = myloss.MSELoss(args)

    net = _cuda(net, 0)
    loss = _cuda(loss, 0)

    best_metric = [0, 0]
    start_epoch = 0

    if args.resume:
        p_dict = {'model': net}
        function.load_model(p_dict, args.resume)
        best_metric = p_dict['best_metric']
        start_epoch = p_dict['epoch'] + 1

    optimizer = torch.optim.Adam(list(net.parameters()), args.lr)

    if args.phase == 'train':
        for epoch in range(start_epoch, args.epochs):
            # Single-argument print: same output on Python 2 and 3.
            print('start epoch : {}'.format(epoch))
            train_eval(train_loader, net, loss, epoch, optimizer, best_metric)
            best_metric = train_eval(valid_loader, net, loss, epoch, optimizer, best_metric, phase='valid')
        print('best metric {}'.format(best_metric))

    elif args.phase == 'test':
        folder = os.path.join(args.result_dir, args.dataset, 'imputation_result')
        # Start from an empty result directory without building shell
        # command strings out of a path.
        shutil.rmtree(folder, ignore_errors=True)
        if not os.path.isdir(folder):
            os.makedirs(folder)

        train_eval(train_loader, net, loss, 0, optimizer, best_metric, 'test')
        train_eval(valid_loader, net, loss, 0, optimizer, best_metric, 'test')
        train_eval(test_loader, net, loss, 0, optimizer, best_metric, 'test')
예제 #7
0
def _read_id_pairs(path):
    """Yield ``(name, index)`` for every two-field line of the file at *path*.

    Lines that do not split into exactly two whitespace-separated fields
    are skipped. The file is closed once the generator is exhausted.
    """
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split()
            if len(fields) == 2:
                yield fields[0], int(fields[1])


def get_data():
    """Load patient data and the knowledge-graph id tables.

    Steps:
      1. Read the dataset JSON files and build the ``'DGVis'`` ``DataBowl``.
      2. Build ``graph_dict`` from ``relation2id.txt`` and ``entity2id.txt``.
         Entity names come from ``id_name_dict`` when available, otherwise
         the raw entity string is kept.
      3. For each patient, walk admissions in descending integer id order
         and keep those seen before the running count of entries exceeds
         120.
      4. For every test/valid case and its controls, compute
         ``[gender, age]`` with age = latest admission year minus year of
         birth.

    Returns:
        ``(pid_demo_dict, pid_aid_did_dict, aid_second_dict, dataset,
        case_set, vocab_list, graph_dict, id_name_dict)``. The first two
        dicts are restricted to the patients kept in step 4, and
        ``case_set`` is the key set of ``case_control_dict``.

    Raises:
        AssertionError: if a computed age is not strictly between 0 and 100.
    """
    dataset_dir = os.path.join(args.data_dir, args.dataset)
    vocab_list = py_op.myreadjson(
        os.path.join(dataset_dir, args.dataset[:-4].lower() + 'vocab.json'))
    aid_year_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'aid_year_dict.json'))
    pid_aid_did_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'pid_aid_did_dict.json'))
    pid_demo_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'pid_demo_dict.json'))
    case_control_data = py_op.myreadjson(
        os.path.join(dataset_dir, 'case_control_data.json'))
    case_test = list(
        set(case_control_data['case_test'] + case_control_data['case_valid']))
    case_control_dict = case_control_data['case_control_dict']
    dataset = data_loader.DataBowl(args, phase='DGVis')
    id_name_dict = py_op.myreadjson(
        os.path.join(args.file_dir, 'id_name_dict.json'))

    graph_dict = {'edge': {}, 'node': {}}
    relation_path = os.path.join(args.file_dir, 'relation2id.txt')
    for relation, relation_id in _read_id_pairs(relation_path):
        graph_dict['edge'][relation_id] = relation
    entity_path = os.path.join(args.file_dir, 'entity2id.txt')
    for cui, entity_id in _read_id_pairs(entity_path):
        graph_dict['node'][entity_id] = id_name_dict.get(cui, cui)

    aid_second_dict = py_op.myreadjson(
        os.path.join(dataset_dir, 'aid_second_dict.json'))

    # Truncate each patient's history. Only values of existing keys are
    # replaced, so the dict keeps its size while it is being iterated.
    for pid, aid_did_dict in pid_aid_did_dict.items():
        n = 0
        aids = sorted(aid_did_dict, key=int, reverse=True)
        for ia, aid in enumerate(aids):
            n += len(aid_did_dict[aid])
            if n > 120:
                pid_aid_did_dict[pid] = {
                    kept: aid_did_dict[kept]
                    for kept in aids[:ia]
                }
                break

    new_pid_demo_dict = dict()
    pid_list = case_test + [
        c for case in case_test for c in case_control_dict[str(case)]
    ]
    for pid in (str(p) for p in pid_list):
        demo = pid_demo_dict[pid]
        # assumes demo is a string: gender char, one separator, birth year
        # -- TODO confirm against the file that produces pid_demo_dict
        gender = demo[0]
        yob = int(demo[2:])
        if pid not in pid_aid_did_dict:
            continue
        aids = pid_aid_did_dict[pid]
        if not aids:
            # Truncation can leave no admissions when the first one alone
            # exceeds the limit; max() would raise ValueError on an empty
            # list, so skip the patient like those without any records.
            continue
        year = max(aid_year_dict[aid] for aid in aids)
        age = year - yob
        assert 0 < age < 100, 'implausible age {} for pid {}'.format(age, pid)
        new_pid_demo_dict[pid] = [gender, age]

    pid_demo_dict = new_pid_demo_dict
    pid_aid_did_dict = {
        pid: pid_aid_did_dict[pid]
        for pid in new_pid_demo_dict
    }

    return pid_demo_dict, pid_aid_did_dict, aid_second_dict, dataset, set(
        case_control_dict), vocab_list, graph_dict, id_name_dict