Esempio n. 1
0
def toplayer(args):
    """Dump per-class sigmoid scores of ``models.myecgnet`` for every training record.

    Each line of ``config.train_label`` is copied (without its newline) to
    ``<args.ex>.txt``, followed by one tab-separated score per model output.

    Args:
        args: CLI namespace; ``args.ckpt`` is a checkpoint loaded with
            ``torch.load`` whose ``'state_dict'`` entry holds the weights, and
            ``args.ex`` is the stem of the output file.
    """
    from dataset import transform

    utils.mkdirs(config.sub_dir)
    model = models.myecgnet()
    model.load_state_dict(
        torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = '%s.txt' % args.ex
    # Context managers guarantee both files are closed even if inference fails
    # half-way (the label file handle used to be leaked).
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.train_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in tqdm(label_file):
            record_id = line.split('\t')[0]
            file_path = os.path.join(config.train_dir, record_id)
            df = pd.read_csv(file_path, sep=' ').values
            # NOTE(review): input layout (1, 1, 8, 2500) is hard-coded here.
            x = transform(df).reshape((1, 1, 8, 2500)).to(device)
            output = torch.sigmoid(model(x)).squeeze().cpu().numpy()
            scores = ''.join('\t' + str(p) for p in output)
            fout.write(line.strip('\n') + scores + '\n')
Esempio n. 2
0
def test_new(args):
    """Write thresholded predictions of a ``ResMlp`` model that also consumes sex/age.

    For every line of ``config.test_label`` the line is copied to
    ``<config.sub_dir>/subA_<timestamp>.txt`` followed by the tab-separated
    names of all classes whose sigmoid score exceeds 0.5.

    Args:
        args: CLI namespace; ``args.ckpt`` is a checkpoint whose
            ``'state_dict'`` entry holds the model weights.
    """
    from dataset import transform
    from data_process import name2index, sex2Index, age2Index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    sex2idx = sex2Index(config.test_label)
    age2idx = age2Index(config.test_label)
    utils.mkdirs(config.sub_dir)
    # model
    model = ResMlp(ResMlpParams)
    model.load_state_dict(
        torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = '%s/subA_%s.txt' % (config.sub_dir, time.strftime("%Y%m%d%H%M"))
    # Both files are closed even if a record fails to load.
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in label_file:
            record_id = line.split('\t')[0]
            # sex2idx / age2idx values are used as tensors (unsqueeze below).
            sex = sex2idx[record_id].unsqueeze(0).to(device)
            age = age2idx[record_id].unsqueeze(0).to(device)
            file_path = os.path.join(config.test_dir, record_id)
            df = pd.read_csv(file_path, sep=' ').values
            # Conv1d needs a 3-D input, so add a leading batch dimension.
            x = transform(df).unsqueeze(0).to(device)
            output = torch.sigmoid(model(x, sex, age)).squeeze().cpu().numpy()
            labels = ''.join('\t' + idx2name[i]
                             for i, out in enumerate(output) if out > 0.5)
            fout.write(line.strip('\n') + labels + '\n')
Esempio n. 3
0
def test(args):
    """Write thresholded ``myecgnet`` predictions for the test set to ``result.txt``.

    Every line of ``config.test_label`` is copied followed by the tab-separated
    names of classes whose sigmoid score exceeds 0.5. The predicted class
    indices are also echoed to stdout, each followed by a comma.

    Args:
        args: CLI namespace; ``args.ckpt`` is a checkpoint whose
            ``'state_dict'`` entry holds the model weights.
    """
    from dataset import transform
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    utils.mkdirs(config.sub_dir)
    model = models.myecgnet()
    model.load_state_dict(
        torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = 'result.txt'
    # Context managers close both files even when a record fails to load.
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in label_file:
            record_id = line.split('\t')[0]
            file_path = os.path.join(config.test_dir, record_id)
            df = pd.read_csv(file_path, sep=' ').values
            x = transform(df).unsqueeze(0).to(device)
            output = torch.sigmoid(model(x)).squeeze().cpu().numpy()
            ixs = [i for i, out in enumerate(output) if out > 0.5]
            fout.write(line.strip('\n')
                       + ''.join('\t' + idx2name[i] for i in ixs) + '\n')
            for i in ixs:
                print(i, end=',')
    print('\n', end='')
Esempio n. 4
0
def test(args):
    """Run fold checkpoints of ``config.model_name`` over the test set and write ``result.txt``.

    ``config.test_label`` and ``config.test_dir`` are first overwritten from
    ``args``. Sigmoid scores of every fold are summed (not averaged), and each
    label line is written back followed by the tab-separated names of classes
    whose summed score exceeds 0.5.

    Args:
        args: CLI namespace with ``test_label`` and ``test_dir`` attributes.
    """
    from dataset import ECGDataset_test
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}

    sub_file = 'result.txt'
    config.test_label = args.test_label
    config.test_dir = args.test_dir
    # Read the label file once; the same lines drive both inference and output.
    with open(config.test_label, encoding='utf-8') as label_file:
        label_lines = label_file.readlines()
    path_all = [os.path.join(config.test_dir, line.split('\t')[0])
                for line in label_lines]
    test_dataset = ECGDataset_test(path_all)
    test_load = DataLoader(test_dataset,
                           batch_size=64,
                           num_workers=6,
                           shuffle=False)

    print(len(test_dataset))
    # Summed fold scores; the width comes from the model output instead of a
    # hard-coded class count.
    out = None
    for fold in range(1):  # only fold 0 is currently evaluated
        model_save_dir = '%s/%s' % (config.ckpt,
                                    config.model_name + '_' + str(fold))
        print(model_save_dir)
        model = getattr(models, config.model_name)()
        model.load_state_dict(
            torch.load(os.path.join(model_save_dir, config.best_w),
                       map_location='cpu')['state_dict'])
        model = model.to(device)
        model.eval()
        batch_scores = []
        with torch.no_grad():
            for inputs, target in tqdm.tqdm(test_load):
                inputs = inputs.to(device)
                output, _ = model(inputs)
                # Move each batch to CPU so the result never depends on
                # whether `device` matches torch.cuda.is_available().
                batch_scores.append(torch.sigmoid(output).cpu())
        if not batch_scores:
            continue
        fold_scores = torch.cat(batch_scores, 0).numpy()
        out = fold_scores if out is None else out + fold_scores

    with open(sub_file, 'w', encoding='utf-8') as fout:
        for num, line in enumerate(label_lines):
            probs = out[num] if out is not None else ()
            labels = ''.join('\t' + idx2name[i]
                             for i, p in enumerate(probs) if p > 0.5)
            fout.write(line.strip('\n') + labels + '\n')
Esempio n. 5
0
def test(args):
    """Predict labels for ``config.test_label`` and write ``<sub_dir>/subA_<timestamp>.txt``.

    When ``config.kind == 2`` and ``config.top4_catboost`` are both set, the work is
    delegated to ``top4_catboost_test``. Otherwise each label line
    (``id\\tage\\tsex...``) is copied followed by the tab-separated names of
    classes whose sigmoid score exceeds 0.5. The forward call depends on
    ``config.kind``: 1 passes an extra ``[age, sex]`` tensor, 2 unpacks a
    two-output model, and anything else is a plain single-output model.

    Args:
        args: CLI namespace; ``args.ckpt`` is a checkpoint whose
            ``'state_dict'`` entry holds the model weights.
    """
    if config.kind == 2 and config.top4_catboost:
        top4_catboost_test(args)  # catboost
        return
    from dataset import transform
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    utils.mkdirs(config.sub_dir)
    # model
    model = getattr(models,
                    config.model_name)(num_classes=config.num_classes,
                                       channel_size=config.channel_size)
    model.load_state_dict(
        torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = '%s/subA_%s.txt' % (config.sub_dir,
                                   time.strftime("%Y%m%d%H%M"))
    print(sub_file)
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in label_file:
            line = line.strip('\n')
            fields = line.split('\t')
            record_id, age, sex = fields[0], fields[1], fields[2]
            # Missing demographics are encoded as -999.
            age = int(age) if age else -999
            sex = {'FEMALE': 0, 'MALE': 1, '': -999}[sex]
            file_path = os.path.join(config.test_dir, record_id)
            df = utils.read_csv(file_path,
                                sep=' ',
                                channel_size=config.channel_size)
            x = transform(df.values).unsqueeze(0).to(device)
            fr = torch.tensor([age, sex],
                              dtype=torch.float32).unsqueeze(0).to(device)
            if config.kind == 1:
                output = torch.sigmoid(model(x, fr)).squeeze().cpu().numpy()
            elif config.kind == 2:
                output, _ = model(x)
                output = torch.sigmoid(output).squeeze().cpu().numpy()
                if config.top4_DeepNN:
                    # NOTE(review): after this slice, position i no longer equals
                    # the original class index unless top4_tag_list is
                    # 0..k-1, yet idx2name[i] below is looked up by position --
                    # confirm idx2name matches the sliced order.
                    output = output[config.top4_tag_list]
            else:
                output = torch.sigmoid(model(x)).squeeze().cpu().numpy()

            labels = ''.join('\t' + idx2name[i]
                             for i, out in enumerate(output) if out > 0.5)
            fout.write(line + labels + '\n')
Esempio n. 6
0
def top4_make_dateset(model, dataloader, file_path, save=True):
    """Build (or load a cached) feature table for the second-stage models.

    If ``file_path`` already exists it is read with ``pd.read_csv`` and
    returned as-is. Otherwise every ``(inputs, fr, target, other_f)`` batch of
    ``dataloader`` becomes rows of
    ``[dnn1_*, dnn2_*, other_f_*, sex, age, <class names>]``:

    * ``dnn1_*`` - sigmoid outputs of ``model``, optionally restricted to
      ``config.top4_tag_list`` when ``config.top4_DeepNN_tag`` is set;
    * ``dnn2_*`` - the model's second output;
    * both DNN groups are zero-width when ``config.top4_DeepNN`` is false.

    Args:
        model: two-output network; only used when ``config.top4_DeepNN``.
        dataloader: iterable yielding ``(inputs, fr, target, other_f)``.
        file_path: CSV cache location.
        save: write the freshly built table to ``file_path`` when true.

    Returns:
        pandas.DataFrame with the columns listed above; the ``sex``/``age``
        and label columns are cast to ``int``.
    """
    print('make dataset')
    if os.path.exists(file_path):
        return pd.read_csv(file_path, index_col=None)
    values = []
    if config.top4_DeepNN:
        model.eval()
        with torch.no_grad():
            for inputs, fr, target, other_f in dataloader:
                inputs = inputs.to(device)
                output, out1 = model(inputs)
                output = torch.sigmoid(output).cpu()
                out1 = out1.cpu()
                if config.top4_DeepNN_tag:
                    output = output[:, config.top4_tag_list]
                values.append(
                    torch.cat([output, out1, other_f, fr, target], dim=1))
    else:
        for inputs, fr, target, other_f in dataloader:
            # Zero-width placeholders keep the column layout identical to the
            # DNN branch.
            output = torch.zeros(inputs.shape[0], 0)
            out1 = torch.ones(inputs.shape[0], 0)
            values.append(
                torch.cat([output, out1, other_f, fr, target], dim=1))
    values = torch.cat(values, dim=0)
    # Column counts are taken from the last batch; all batches share widths.
    columnslist = ['dnn1_%d' % i for i in range(output.size(1))]
    columnslist += ['dnn2_%d' % i for i in range(out1.size(1))]
    print('len_dnn_feature', len(columnslist))
    columnslist += ['other_f_%d' % i for i in range(other_f.size(1))]
    # NOTE(review): top4_catboost_test builds fr as [age, sex] but names these
    # columns ['sex', 'age'] just like here; confirm the dataloader's fr order.
    columnslist += ['sex', 'age']
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    num_labels = target.size(1)
    columnslist += [idx2name[i] for i in range(num_labels)]
    df = pd.DataFrame(values.numpy(), columns=columnslist)
    # The two demographic columns plus every label column are integer-valued.
    # Derive the label count from `target` instead of hard-coding 55 classes.
    int_cols = columnslist[-2 - num_labels:]
    df[int_cols] = df[int_cols].astype(int)
    if save:
        df.to_csv(file_path, index=None)
    return df
Esempio n. 7
0
def test_ensemble(args):
    """Average sigmoid scores of several checkpoints and write ``./result.txt``.

    One ensemble member is built per entry of ``config.model_ckpts``, loading
    ``<args.ckpt>/<model_ckpts[k]>/best_weight.pth``. Each test record gets the
    derived limb leads III, aVR, aVL and aVF computed from leads I and II.
    Each label line is written back followed by the tab-separated names of
    classes whose mean score exceeds 0.5.

    Args:
        args: CLI namespace; ``args.ckpt`` is the checkpoint root directory.

    Raises:
        ValueError: if ``config.model_ckpts`` is empty.
    """
    from dataset import transform
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}

    # The ensemble size is the number of checkpoint directories. The original
    # used len(config.model_names), which counts *characters* when
    # model_names is the single architecture string that getattr() requires.
    kfold = len(config.model_ckpts)
    if kfold == 0:
        raise ValueError('config.model_ckpts is empty; nothing to ensemble')
    if isinstance(config.model_names, str):
        arch_names = [config.model_names] * kfold
    else:
        arch_names = list(config.model_names)

    model = []
    for fold in range(kfold):
        net = getattr(models, arch_names[fold])()
        net.load_state_dict(
            torch.load(os.path.join(args.ckpt, config.model_ckpts[fold],
                                    "best_weight.pth"),
                       map_location='cpu')['state_dict'])
        net = net.to(device)
        net.eval()
        model.append(net)

    sub_file = './result.txt'
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in tqdm(label_file):
            record_id = line.split('\t')[0]
            file_path = os.path.join(config.test_dir, record_id)
            df = pd.read_csv(file_path, sep=' ')
            # Derived limb leads from leads I and II.
            df['III'] = df['II'] - df['I']
            df['aVR'] = -(df['I'] + df['II']) / 2
            df['aVL'] = df['I'] - df['II'] / 2
            df['aVF'] = df['II'] - df['I'] / 2
            x = transform(df.values).unsqueeze(0).to(device)
            output = sum(torch.sigmoid(net(x)).squeeze().cpu().numpy()
                         for net in model) / kfold
            labels = ''.join('\t' + idx2name[i]
                             for i, out in enumerate(output) if out > 0.5)
            fout.write(line.strip('\n') + labels + '\n')
Esempio n. 8
0
def test(args):
    """Predict labels using the ECG plus per-record age/sex from a pickle.

    ``config.test_age_sex`` is a pickled mapping from record stem (id before
    the first '.') to a dict with ``'age'`` and ``'sex'`` arrays. Each line
    of ``config.test_label`` is copied to
    ``<config.sub_dir>/subA_<timestamp>.txt`` followed by the tab-separated
    names of classes whose sigmoid score exceeds 0.5.

    Args:
        args: CLI namespace; ``args.ckpt`` is a checkpoint whose
            ``'state_dict'`` entry holds the model weights.
    """
    from dataset import transform
    from data_process import name2index
    import pickle
    from tqdm import tqdm

    # NOTE: pickle.load can execute arbitrary code - only load trusted files.
    with open(config.test_age_sex, 'rb') as fh:
        test_age_sex = pickle.load(fh)

    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    utils.mkdirs(config.sub_dir)
    # model
    model = getattr(models, config.model_name)()
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = '%s/subA_%s.txt' % (config.sub_dir, time.strftime("%Y%m%d%H%M"))
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in tqdm(label_file):
            record_id = line.split('\t')[0]
            file_path = os.path.join(config.test_dir, record_id)
            df = add_4(pd.read_csv(file_path, sep=' ').values)
            meta = test_age_sex[record_id.split('.')[0]]

            age = torch.tensor(meta['age'].copy(), dtype=torch.float).unsqueeze(0).to(device)
            sex = torch.tensor(meta['sex'].copy(), dtype=torch.float).unsqueeze(0).to(device)

            x = transform(df).unsqueeze(0).to(device)
            output = torch.sigmoid(model(x, age, sex)).squeeze().cpu().numpy()
            labels = ''.join('\t' + idx2name[i]
                             for i, out in enumerate(output) if out > 0.5)
            fout.write(line.strip('\n') + labels + '\n')
Esempio n. 9
0
def _loader_workers(cpus):
    """Return the DataLoader worker count: one fewer than ``cpus`` but at least 1."""
    return max(1, cpus - 1)


def _stack_predictions(batches, num_classes):
    """Stack per-batch score arrays into one float64 ``(N, num_classes)`` array.

    An empty ``(0, num_classes)`` float64 array is prepended. An empty
    ``batches`` list therefore still yields the right shape, and the result
    dtype matches the previous ``np.vstack`` accumulation seeded with
    ``np.empty``.
    """
    return np.concatenate([np.empty(shape=[0, num_classes])] + list(batches),
                          axis=0)


def deep_predict(ckpt_dir=r'ckpt/resnet34'):
    """Return sigmoid scores of ``config.model_name`` for every test record.

    Args:
        ckpt_dir: directory containing ``config.best_w``; defaults to the
            previously hard-coded ``ckpt/resnet34``.

    Returns:
        numpy array of shape ``(num_records, config.num_classes)``, one row per
        line of ``config.test_label`` in file order (the DataLoader is not
        shuffled).
    """
    # model
    model = getattr(models, config.model_name)()
    model.load_state_dict(torch.load(os.path.join(ckpt_dir, config.best_w),
                                     map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()

    with open(config.test_label, encoding='utf-8') as label_file:
        test_data = [os.path.join(config.test_dir, line.split('\t')[0])
                     for line in label_file]

    print(len(test_data))
    from dataset import ECGDataTest
    test_dataset = ECGDataTest(test_data)

    from multiprocessing import cpu_count
    print(cpu_count())
    # The old `if processes < 0` guard was unreachable (cpu_count() >= 1), so a
    # single-core machine got 0 workers; clamp to the intended minimum of 1.
    processes = _loader_workers(cpu_count())

    test_dataloader = DataLoader(test_dataset, batch_size=64, num_workers=processes)

    # Collect batches and concatenate once instead of re-copying the growing
    # array with np.vstack on every batch.
    batch_scores = []
    with torch.no_grad():
        for inputs in test_dataloader:
            inputs = inputs.to(device)
            batch_scores.append(torch.sigmoid(model(inputs)).cpu().numpy())
    return _stack_predictions(batch_scores, config.num_classes)
Esempio n. 10
0
 def test(self):
     """Write thresholded predictions of ``./best_w.pth`` to ``./result.txt``.

     Class names come from ``./hf_round2_arrythmia.txt``. Each line of
     ``self.opt.test_label`` is copied followed by the tab-separated names of
     classes whose sigmoid score exceeds 0.5. Indices beyond the name table
     are skipped. Record ids and predicted indices are printed to stdout.
     """
     name2idx = name2index('./hf_round2_arrythmia.txt')
     idx2name = {idx: name for name, idx in name2idx.items()}
     print(idx2name)
     best_w = './best_w.pth'
     model = getattr(models, self.opt.model_name)()
     model.load_state_dict(
         torch.load(best_w, map_location='cpu')['state_dict'])
     model = model.to(self.device)
     model.eval()
     sub_file = './result.txt'
     # Both files are closed even if a record fails to load.
     with open(sub_file, 'w', encoding='utf-8') as fout, \
             open(self.opt.test_label, encoding='utf-8') as label_file, \
             torch.no_grad():
         for line in label_file:
             record_id = line.split('\t')[0]
             file_path = os.path.join(self.opt.test_dir, record_id)
             df = add_feature(pd.read_csv(file_path, sep=' ')).values
             x = transform(df).unsqueeze(0).to(self.device)
             output = torch.sigmoid(model(x)).squeeze().cpu().numpy()
             ixs = [i for i, out in enumerate(output) if out > 0.5]
             print(record_id, ixs)
             # Guard against output indices the name table does not cover.
             labels = ''.join('\t' + idx2name[i]
                              for i in ixs if i < len(idx2name))
             fout.write(line.strip('\n') + labels + '\n')
     print("writing finish.")
Esempio n. 11
0
    hh[:, i] = hh[:, i] > lim[i]
ecoc = np.array(pd.read_csv('ecoc.csv', header=None))
# Each sample's full code word is the concatenation of h and the thresholded hh.
h = np.concatenate((h, hh), axis=1)
# hamming[i, j] = number of positions where sample i differs from ECOC row j,
# computed once (the old loops evaluated every distance twice).
hamming = np.sum(h[:, None, :] != ecoc[None, :, :], axis=2)
minhmd = hamming.min(axis=1)
# Mark every ECOC row tied at the minimum distance.
ecout = np.zeros_like(hh)
ecout[:, :ecoc.shape[0]] = hamming == minhmd[:, None]
# More than five tied rows: fall back to the thresholded per-class hh row.
ambiguous = np.sum(ecout, axis=1) > 5
ecout[ambiguous] = hh[ambiguous]
print(np.sum(ecout, axis=0).astype('int'))
np.savetxt('ecout.txt', ecout, '%d', '\t')
name2idx = name2index(config.arrythmia)
idx2name = {idx: name for name, idx in name2idx.items()}
sub_file = './result.txt'
# Close both the label file and the output file deterministically.
with open(sub_file, 'w', encoding='utf-8') as fout, \
        open(config.test_label, encoding='utf-8') as label_file:
    for cur, line in enumerate(label_file):
        labels = ''.join('\t' + idx2name[i]
                         for i, out in enumerate(ecout[cur]) if out > 0.5)
        fout.write(line.strip('\n') + labels + '\n')
Esempio n. 12
0
from options import Options
from dataset import load_data
from model import Ecg
from data_process import train, name2index
##
def main():
    """Train the ECG model from command-line options.

    Parses ``Options``, builds the class-name/index mappings from
    ``opt.arrythmia`` and passes them to ``data_process.train`` (presumably
    data preparation - TODO confirm). It then loads the train/val data with
    ``load_data`` and trains an ``Ecg`` model.
    """
    # torch.manual_seed(config.seed)
    # torch.cuda.manual_seed(config.seed)

    # ARGUMENTS
    opt = Options().parse()
    print(opt)
    name2idx = name2index(opt.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}

    train(name2idx, idx2name, config=opt)

    # LOAD DATA
    train_dataset, train_dataloader, val_dataset, val_dataloader = load_data(opt)

    # LOAD MODEL
    model = Ecg(opt, train_dataset, train_dataloader, val_dataset, val_dataloader)

    # TRAIN MODEL
    model.train()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Esempio n. 13
0
def top4_catboost_test(args):
    """Predict test labels with CatBoost models stacked on optional DNN features.

    For each line of ``config.test_label`` (``id\\tage\\tsex...``), a one-row
    frame ``[dnn1_*, dnn2_*, other_f_*, sex, age]`` is built from three parts:
    the DNN outputs (only when ``config.top4_DeepNN``), the result of
    ``get_other_features``, and the demographic pair. The frame is scored with
    ``model_list_predict``. The line is then written to
    ``<config.sub_dir>/subA_<timestamp>.txt`` followed by the tab-separated
    names of classes scoring above 0.5.

    Args:
        args: CLI namespace; ``args.ckpt`` is the DNN checkpoint (only read
            when ``config.top4_DeepNN``).
    """
    print('top4_catboost_test')
    from dataset import transform
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    utils.mkdirs(config.sub_dir)
    if config.top4_DeepNN:
        # model
        model = getattr(models,
                        config.model_name)(num_classes=config.num_classes,
                                           channel_size=config.channel_size)
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
        model = model.to(device)
        model.eval()
        print(config.model_name, args.ckpt)
    else:
        model = None
        print('no', config.model_name)
    sub_file = '%s/subA_%s.txt' % (config.sub_dir, time.strftime("%Y%m%d%H%M"))
    print(sub_file)
    # Every prediction below needs the CatBoost models. They used to be loaded
    # only when config.kind == 2, which left model_list unbound (NameError)
    # for any other kind.
    model_list = load_model_list(
        os.path.join(config.ckpt, config.top4_catboost_model))
    with open(sub_file, 'w', encoding='utf-8') as fout, \
            open(config.test_label, encoding='utf-8') as label_file, \
            torch.no_grad():
        for line in label_file:
            line = line.strip('\n')
            fields = line.split('\t')
            record_id, age, sex = fields[0], fields[1], fields[2]
            # Missing demographics are encoded as -999.
            age = int(age) if age else -999
            sex = {'FEMALE': 0, 'MALE': 1, '': -999}[sex]
            file_path = os.path.join(config.test_dir, record_id)
            ecg_df = utils.read_csv(file_path,
                                    sep=' ',
                                    channel_size=config.channel_size)
            fr = torch.tensor([age, sex], dtype=torch.float32)
            if config.top4_DeepNN:
                x = transform(ecg_df.values).unsqueeze(0).to(device)
                output, out1 = model(x)
                output = torch.sigmoid(output).squeeze().cpu().numpy()
                out1 = out1.squeeze().cpu().numpy()
                if config.top4_DeepNN_tag:
                    output = output[config.top4_tag_list]
            else:
                # Zero-length DNN features keep the column layout consistent.
                output, out1 = torch.zeros(0).numpy(), torch.ones(0).numpy()
            r_features_file = os.path.join(config.r_test_dir,
                                           record_id.replace('.txt', '.fea'))
            other_f = get_other_features(ecg_df, r_features_file)
            feature_values = np.concatenate((output, out1, other_f, fr))
            columnslist = ['dnn1_%d' % i for i in range(len(output))]
            columnslist += ['dnn2_%d' % i for i in range(len(out1))]
            columnslist += ['other_f_%d' % i for i in range(len(other_f))]
            # NOTE(review): fr is [age, sex] but the columns are named
            # ['sex', 'age'], the same order top4_make_dateset uses; kept as-is
            # to match the trained models - confirm the training fr order.
            columnslist += ['sex', 'age']
            feature_df = pd.DataFrame(feature_values.reshape(1, -1),
                                      columns=columnslist)
            cat_cols = feature_df.columns[config.top4_cat_features]
            feature_df[cat_cols] = feature_df[cat_cols].astype(int)
            output = model_list_predict(model_list, feature_df).squeeze()

            labels = ''.join('\t' + idx2name[i]
                             for i, out in enumerate(output) if out > 0.5)
            fout.write(line + labels + '\n')