def main(args):
    print('********** Intervention Experiment **********')
    # Only the patient-data object is needed here; the data loaders go unused.
    p_data, _, _, _ = load_data(batch_size=100, path='./data', cv=args.cv)
    # Compute and plot population-level summary statistics for every intervention.
    for intervention_id in range(len(intervention_list)):
        summary_stat(intervention_ID=intervention_id, patient_data=p_data, cv=args.cv)
        plot_summary_stat(intervention_id)
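The `main` entry points in this file are driven by an argparse namespace; a minimal, hypothetical wrapper for the intervention experiment above (only the `--cv` flag is implied by `args.cv`) could look like this:

import argparse

if __name__ == '__main__':
    # Hypothetical CLI wrapper; only --cv is implied by the code above.
    parser = argparse.ArgumentParser(description='Intervention experiment')
    parser.add_argument('--cv', type=int, default=0, help='cross-validation fold to load')
    main(parser.parse_args())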
def main(data, generator_type, output_path, predictor_model):
    print('********** Running Generator Baseline Experiment **********')
    with open('config.json') as config_file:
        configs = json.load(config_file)[data]['feature_generator_explainer']

    experiment = 'feature_generator_explainer'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if data == 'mimic':
        p_data, train_loader, valid_loader, test_loader = load_data(
            batch_size=configs['batch_size'], path='./data')
        feature_size = p_data.feature_size
    elif data == 'ghg':
        p_data, train_loader, valid_loader, test_loader = load_ghg_data(
            configs['batch_size'])
        feature_size = p_data.feature_size
    elif data == 'simulation_spike':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'],
            path='./data_generator/data/simulated_data',
            data_type='spike')
        feature_size = p_data.shape[1]

    elif data == 'simulation':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'], path='./data/simulated_data')
        feature_size = p_data.shape[1]

    # This comparison needs a trained conditional generator; build the
    # feature-generator experiment (as in the other examples in this file)
    # and load its checkpoint.
    exp = FeatureGeneratorExplainer(
        train_loader,
        valid_loader,
        test_loader,
        feature_size,
        patient_data=p_data,
        output_path=output_path,
        predictor_model=predictor_model,
        generator_hidden_size=configs['encoding_size'],
        prediction_size=1,
        generator_type=generator_type,
        data=data,
        experiment=experiment + '_' + generator_type)
    exp.generator.load_state_dict(
        torch.load(os.path.join('./ckpt/%s/%s.pt' % (data, generator_type))))

    testset = list(exp.test_loader.dataset)
    test_signals = torch.stack(([x[0] for x in testset])).to(device)

    true_generator = TrueFeatureGenerator()

    # Sample from the learned generator and from the true generative process
    # for S test trajectories, then compare the two sample distributions per
    # feature with a two-sample KS test. The sample buffers are allocated once,
    # outside the loop, so that all S trajectories are kept.
    S = 100
    ffc_sample = np.zeros((test_signals.shape[1], test_signals.shape[-1] * S))
    true_sample = np.zeros((test_signals.shape[1], test_signals.shape[-1] * S))
    for s in range(S):
        print('generating sample: ', s)
        signal = test_signals[s]
        for t in range(1, test_signals.shape[-1], 3):
            ffc_sample_t = exp.generator.forward_joint(
                signal[:, 0:t].unsqueeze(0))
            ffc_sample[:, s * test_signals.shape[-1] +
                       t] = ffc_sample_t.cpu().detach().numpy()[0]
            true_sample[:, s * test_signals.shape[-1] +
                        t] = true_generator.sample(signal[:, 0:t], t)

    for f in range(test_signals.shape[1]):
        ks_stat_f, p_value = stats.ks_2samp(ffc_sample[f, :],
                                            true_sample[f, :])
        print('feature: ', f, 'KS_stat: ', ks_stat_f, 'p_value: ', p_value)
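As a standalone illustration of the check above, the sketch below (independent of the experiment code) compares two synthetic sample sets with `scipy.stats.ks_2samp`; a small p-value indicates the learned generator's samples deviate from the reference distribution.

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
generated = rng.normal(loc=0.0, scale=1.0, size=5000)  # stand-in for generator samples
reference = rng.normal(loc=0.1, scale=1.0, size=5000)  # stand-in for true-process samples

ks_stat, p_value = stats.ks_2samp(generated, reference)
print('KS statistic: %.3f, p-value: %.3g' % (ks_stat, p_value))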
Example #3
def main(data, generator_type, all_samples, cv=0):
    print('********** Experiment with the %s data **********' % ("feature_generator_explainer"))
    with open('config.json') as config_file:
        configs = json.load(config_file)[data]["feature_generator_explainer"]

    if data == 'mimic':
        p_data, train_loader, valid_loader, test_loader = load_data(batch_size=configs['batch_size'],
                                                                    path='./data', cv=cv)
        feature_size = p_data.feature_size
        # samples_to_analyze = {'mimic':MIMIC_TEST_SAMPLES, 'simulation':SIMULATION_SAMPLES, 'ghg':[], 'simulation_spike':[]}
    elif data == 'ghg':
        p_data, train_loader, valid_loader, test_loader = load_ghg_data(configs['batch_size'], cv=cv)
        feature_size = p_data.feature_size
    elif data == 'simulation_spike':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(batch_size=configs['batch_size'],
                                                                              path='./data/simulated_spike_data',
                                                                              data_type='spike', cv=cv)
        feature_size = p_data.shape[1]

    elif data == 'simulation':
        percentage = 100.
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(batch_size=configs['batch_size'],
                                                                              path='./data/simulated_data',
                                                                              percentage=percentage / 100, cv=cv)
        # generator_type = generator_type+'_%d'%percentage
        feature_size = p_data.shape[1]

    exp = FeatureGeneratorExplainer(train_loader, valid_loader, test_loader, feature_size, patient_data=p_data,
                                    generator_hidden_size=configs['encoding_size'], prediction_size=1,
                                    historical=(configs['historical'] == 1),
                                    generator_type=generator_type, data=data,
                                    experiment='feature_generator_explainer_' + generator_type)

    if all_samples:
        print('Experiment on all test data')
        print('Number of test samples: ', len(exp.test_loader.dataset))
        exp.select_top_features(
            samples_to_analyze=range(0, len(exp.test_loader.dataset) // 2),
            sub_features=[[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]])
    else:
        imp = exp.select_top_features(
            samples_to_analyze[data],
            sub_features=[[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]])
        print(imp[1])
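The `sub_features` list above enumerates every non-empty subset of the three simulated features; if the feature count changes, the same list can be built programmatically, for example with this hypothetical helper:

from itertools import combinations

def all_feature_subsets(n_features):
    # Every non-empty subset of feature indices,
    # e.g. [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]] for n_features=3.
    subsets = []
    for k in range(1, n_features + 1):
        subsets.extend([list(c) for c in combinations(range(n_features), k)])
    return subsets

print(all_feature_subsets(3))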
Example #4
def main(args):
    if args.data == 'simulation':
        feature_size = 3
        data_path = './data/simulated_data'
        n_classes=2
    elif args.data == 'simulation_spike':
        feature_size = 3
        data_path = './data/simulated_spike_data'
        n_classes = 2 # use with state-classifier
        if args.explainer=='retain':
            args.multiclass=True
    elif args.data == 'mimic_int':
        timeseries_feature_size = len(feature_map_mimic)
        n_classes = 4
        task = 'intervention'
        data_path = '/scratch/gobi2/projects/tsx'
        args.multiclass=True
    elif args.data == 'simulation_l2x':
        feature_size = 3
        data_path = './data/simulated_data_l2x'
        n_classes = 2
    elif args.data == 'mimic':
        timeseries_feature_size = len(feature_map_mimic)
        n_classes = 2
        task = 'mortality'
        data_path = '/scratch/gobi2/projects/tsx'
        args.multiclass=True

    if args.data == 'mimic' or args.data=='mimic_int':
        p_data, train_loader, valid_loader, test_loader = load_data(batch_size=100, path=data_path,task=task,cv=0,train_pc=1.)
        feature_size = p_data.feature_size
        x_test = torch.stack(([x[0] for x in list(test_loader.dataset)])).cpu().numpy()
        y_test = torch.stack(([x[1] for x in list(test_loader.dataset)])).cpu().numpy()
        if args.explainer=='lime':
            x_test = x_test[:300]
            y_test = y_test[:300]
    else:
        if args.data=='simulation_l2x' or args.data=='simulation':
            file_name = 'state_dataset_'
        else:
            file_name = ''
        with open(os.path.join(data_path, file_name + 'x_test.pkl'), 'rb') as f:
            x_test = pkl.load(f)
        with open(os.path.join(data_path, file_name +'y_test.pkl'), 'rb') as f:
            y_test = pkl.load(f)

        ## Select patients with varying state
        # span = []
        # testset = list(test_loader.dataset)
        # model = StateClassifier(feature_size=feature_size, n_state=2, hidden_size=200)
        # model.load_state_dict(torch.load(os.path.join('./ckpt/%s/%s.pt' % (args.data, 'model'))))
        # for i,(signal,label) in enumerate(testset):
        #    model.to(device)
        #    model.eval()
        #    risk=[]
        #    for t in range(1,48):
        #         pred = torch.nn.Softmax(-1)(model(torch.Tensor(signal[:, 0:t]).unsqueeze(0).to(device)))[:, 1]
        #         risk.append(pred.item())
        #    span.append((i,max(risk) - min(risk)))
        # span.sort(key= lambda pair:pair[1], reverse=True)
        # print([xx[0] for xx in span[0:300]])
        # print([xx[1] for xx in span[0:300]])
        # top_patients = [xx[0] for xx in span[0:300]]

    # Restrict evaluation to a fixed subset of test patients
    # (all of them when args.percentile is set).
    testset = list(test_loader.dataset)
    if args.percentile:
        top_patients = list(range(len(testset)))
    else:
        top_patients = TOP_PATIENTS

    x_test = torch.stack([x[0] for x_ind, x in enumerate(testset)
                          if x_ind in top_patients]).cpu().numpy()
    y_test = torch.stack([x[1] for x_ind, x in enumerate(testset)
                          if x_ind in top_patients]).cpu().numpy()


    importance_path = os.path.join(args.path, args.data)
    # importance_path = '/scratch/gobi2/projects/tsx/new_results/%s' % args.data
    # importance_path = '/scratch/gobi1/shalmali/TSX_results/new_results/%s' % args.data

    if args.data=='simulation_spike':
        activation = torch.nn.Sigmoid()
        if args.explainer=='retain':
            activation = torch.nn.Softmax(-1)
    elif args.data=='mimic_int':
        activation = torch.nn.Sigmoid()
        if args.explainer=='retain':
            raise ValueError('%s explainer not defined for mimic-int!' % args.explainer)
    else:
        activation = torch.nn.Softmax(-1)

    auc_drop, aupr_drop = [], []
    for cv in [0, 1, 2]:
    #for cv in [0]:
        with open(os.path.join(importance_path, '%s_test_importance_scores_%s.pkl' % (args.explainer, str(cv))),
                  'rb') as f:
            importance_scores = pkl.load(f)

        if args.data=='simulation_spike':
            model = EncoderRNN(feature_size=feature_size, hidden_size=50, regres=True, return_all=False, data=args.data, rnn="GRU")
        elif args.data=='mimic_int':
            model = StateClassifierMIMIC(feature_size=feature_size, n_state=n_classes, hidden_size=128,rnn='LSTM')
        elif args.data=='mimic' or args.data=='simulation' or args.data=='simulation_l2x':
            model = StateClassifier(feature_size=feature_size, n_state=n_classes, hidden_size=200,rnn='GRU')
        model.load_state_dict(torch.load(os.path.join('./ckpt/%s/%s_%d.pt' % (args.data, 'model',cv))))

        #### Plotting
        plot_id = 10
        pred_batch_vec = []
        plot_path = './plots/mimic/'
        t_len = importance_scores[plot_id].shape[-1]
        t = np.arange(1, t_len)
        #model = StateClassifier(feature_size=feature_size, n_state=2, hidden_size=200)
        #model.load_state_dict(torch.load(os.path.join('./ckpt/%s/%s.pt' % (args.data, 'model'))))
        model.eval()
        for tt in t:
            pred_tt = model(torch.Tensor(x_test[plot_id, :, :tt + 1]).unsqueeze(0)).detach().cpu().numpy()
            pred_tt = pred_tt[:,-1]
            pred_batch_vec.append(pred_tt)
        f, axs = plt.subplots(2)

        for i, ref_ind in enumerate(range(x_test[plot_id].shape[0])):
            axs[0].plot(t, x_test[plot_id, ref_ind, 1:], linewidth=3, label='feature %d' % (i))
            axs[1].plot(t, importance_scores[plot_id, ref_ind, 1:], linewidth=3, label='importance %d' % (i))

        axs[0].plot(t, pred_batch_vec, '--', linewidth=3, c='black')
        # axs[0].plot(t, y[plot_id, 1:].cpu().numpy(), '--', linewidth=3, c='red')
        axs[0].tick_params(axis='both', labelsize=36)
        axs[1].tick_params(axis='both', labelsize=36)
        axs[0].margins(0.03)
        axs[1].margins(0.03)

        # axs[0].grid()
        f.set_figheight(80)
        f.set_figwidth(120)
        plt.subplots_adjust(hspace=.5)
        name = args.explainer + '_' + args.generator_type if args.explainer == 'fit' else args.explainer
        plt.savefig(os.path.join(plot_path, '%s_example.pdf' % name), dpi=300, orientation='landscape')
        fig_legend = plt.figure(figsize=(13, 1.2))
        handles, labels = axs[0].get_legend_handles_labels()
        plt.figlegend(handles, labels, loc='upper left', ncol=4, fancybox=True, handlelength=6, fontsize='xx-large')
        fig_legend.savefig(os.path.join(plot_path, '%s_example_legend.pdf' % name), dpi=300, bbox_inches='tight')

        if args.explainer == 'retain':
            if args.data=='mimic_int':
                model = RETAIN(dim_input=feature_size, dim_emb=32, dropout_emb=0.4, dim_alpha=16, dim_beta=16,
                       dropout_context=0.4, dim_output=n_classes)
            elif args.data=='simulation_spike':
                model = RETAIN(dim_input=feature_size, dim_emb=4, dropout_emb=0.4, dim_alpha=16, dim_beta=16,
                       dropout_context=0.4, dim_output=n_classes)
            else:
                model = RETAIN(dim_input=feature_size, dim_emb=128, dropout_emb=0.4, dim_alpha=8, dim_beta=8,
                           dropout_context=0.4, dim_output=2)
            model.load_state_dict(torch.load(os.path.join('./ckpt/%s/%s_%d.pt' % (args.data, 'retain', args.cv))))
        
        model.to(device)
        model.eval()

        if args.subpop:
            span = []
            testset = list(test_loader.dataset)
            for i,(signal,label) in enumerate(testset):
               model.to(device)
               model.eval()
               risk=[]
               if args.data=='mimic':
                    for t in range(1,signal.shape[-1]):
                        pred = activation(model(torch.Tensor(signal[:, 0:t]).unsqueeze(0).to(device)))[:, 1]
                        risk.append(pred.item())
                    span.append((i,max(risk) - min(risk)))
               elif args.data=='mimic_int':
                    for t in range(1,signal.shape[-1]):
                        pred = activation(model(torch.Tensor(signal[:, 0:t]).unsqueeze(0).to(device)))
                        risk.append(pred.detach().cpu().numpy().flatten())
                    span.append((i,np.mean(np.max(risk,0) - np.min(risk,0))))

            span.sort(key= lambda pair:pair[1], reverse=True)
            with open('shift_subsets.pkl','wb') as f:
                pkl.dump(span,f)

            with open('shift_subsets.pkl','rb') as f:
                span = pkl.load(f)

            if args.data=='mimic_int':
                top_patients = [xx[0] for xx in span if xx[1]>0.20]
            elif args.data=='mimic':
                top_patients = [xx[0] for xx in span if xx[1]>0.87]
                top_patients = [xx[0] for xx in span[0:300]]
            #print([xx[1] for xx in span[0:300]])

            if args.explainer =='lime' and args.data=='mimic_int':
                top_patients = [tt for tt in top_patients if tt<300]

            x_test = x_test[top_patients]
            y_test = y_test[top_patients]
            importance_scores = importance_scores[top_patients]

        min_t = 10  # 25
        max_t = 40
        n_drops = args.n_drops

        y1, y2, label = [], [], []
        q = np.percentile(importance_scores[:, :, min_t:], 95)

        # Gradient-based attribution maps can be negative; use their magnitudes.
        if args.explainer in ('deep_lift', 'integrated_gradient', 'gradient_shap'):
            importance_scores = np.abs(importance_scores)

        for i, x in enumerate(x_test):
            if args.data=='mimic':
                x_cf = x.copy()
                if args.time_imp:
                    for _ in range(n_drops):
                        imp = np.unravel_index(importance_scores[i, :, min_t:max_t].argmax(), importance_scores[i, :, min_t:max_t].shape)
                        importance_scores[i, :, imp[1] + min_t:] = -1
                        x_cf = x_cf[:,:imp[1] + min_t]
                else:
                    if args.percentile:
                        min_t_feat = [np.min(np.where(importance_scores[i, f, min_t:] >= q)[0]) if
                                      len(np.where(importance_scores[i, f, min_t:] >= q)[0]) > 0 else
                                      x.shape[-1] - min_t - 1 for f in range(p_data.feature_size)]
                        for f in range(importance_scores[i].shape[0]):
                            x_cf[f, min_t_feat[f] + min_t:] = x_cf[f, min_t_feat[f] + min_t - 1]
                    else:
                        for _ in range(n_drops):
                            imp = np.unravel_index(importance_scores[i, :, min_t:].argmax(), importance_scores[i, :, min_t:].shape)
                            importance_scores[i, imp[0], imp[1] + min_t:] = -1
                            x_cf[imp[0], imp[1] + min_t:] = x_cf[imp[0], imp[1] + min_t-1]
                label.append(y_test[i])
                if args.explainer=='retain':
                    x_t = torch.Tensor(x).unsqueeze(0).permute(0, 2, 1)
                    x_cf_t = torch.Tensor(x_cf).unsqueeze(0).permute(0, 2, 1)
                    y, _, _ = (model(x_t.to(device), torch.ones((1,)) * x_t.shape[1]))
                    y1.append(torch.nn.Softmax(-1)(y)[0,1].detach().cpu().numpy())
                    y, _, _ = (model(x_cf_t.to(device), torch.ones((1,)) * x_cf_t.shape[1]))
                    y2.append(torch.nn.Softmax(-1)(y)[0,1].detach().cpu().numpy())
                else:
                    y = torch.nn.Softmax(-1)(model(torch.Tensor(x).unsqueeze(0)))[:, 1]  # Be careful! This is fixed for class 1
                    y1.append(y.detach().cpu().numpy())
                    y = torch.nn.Softmax(-1)(model(torch.Tensor(x_cf).unsqueeze(0)))[:, 1]
                    y2.append(y.detach().cpu().numpy())

            elif args.data=='simulation_l2x' or args.data=='simulation':
                imp = np.unravel_index(importance_scores[i,:,min_t:].argmax(), importance_scores[i,:,min_t:].shape)
                if importance_scores[i,imp[0], imp[1]+ min_t]<0:
                    continue
                else:
                    sample = x[:, :imp[1] + min_t + 1]
                    x_cf = sample.copy()
                    x_cf[imp[0], -1] = x_cf[imp[0], -2]
                    label.append(y_test[i,imp[1]+min_t])
                    lengths = (torch.ones((1,)) * x_cf.shape[1])
                    if args.explainer == 'retain':
                        x_t = torch.Tensor(sample).unsqueeze(0).permute(0, 2, 1)
                        x_cf_t = torch.Tensor(x_cf).unsqueeze(0).permute(0, 2, 1)
                        y, _, _ = (model(x_t.to(device), lengths))
                        y1.append(torch.nn.Softmax(-1)(y)[0, 1].detach().cpu().numpy())
                        y, _, _ = (model(x_cf_t.to(device), lengths))
                        y2.append(torch.nn.Softmax(-1)(y)[0, 1].detach().cpu().numpy())
                    else:
                        y = torch.nn.Softmax(-1)(model(torch.Tensor(sample).unsqueeze(0)))[:, 1] # Be careful! This is fixed for class 1
                        y1.append(y.detach().cpu().numpy())

                        y = torch.nn.Softmax(-1)(model(torch.Tensor(x_cf).unsqueeze(0)))[:, 1]
                        y2.append(y.detach().cpu().numpy())
            elif args.data=='simulation_spike':
                imp = np.unravel_index(importance_scores[i,:,min_t:].argmax(), importance_scores[i,:,min_t:].shape)
                sample = x[:, :imp[1]+min_t+1]
                label.append(y_test[i,imp[1]+min_t])

                if args.explainer=='retain':
                    x_t = torch.Tensor(sample).unsqueeze(0).permute(0, 2, 1)                           
                    logit,_,_ = model(torch.Tensor(x_t).to(device), (torch.ones((1,)) * sample.shape[-1]))
                    y = activation(logit)[:,1]
                    y1.append(y.detach().cpu().numpy())
                    x_cf = sample.copy()
                    x_cf[imp[0],-1] = x_cf[imp[0],-2]
                    x_cf_t = torch.Tensor(x_cf).unsqueeze(0).permute(0, 2, 1)                           
                    logit,_,_ = model(torch.Tensor(x_cf_t).to(device), (torch.ones((1,)) * x_cf.shape[-1]))
                    y = activation(logit)[:,1]
                    y2.append(y.detach().cpu().numpy())
                else:
                    y = activation(model(torch.Tensor(sample).unsqueeze(0)))[0,0]
                    #print(y.shape)
                    y1.append(y.detach().cpu().numpy())
                    x_cf = sample.copy()
                    x_cf[imp[0],-1] = x_cf[imp[0],-2]
                    y = activation(model(torch.Tensor(x_cf).unsqueeze(0)))[0,0]
                    y2.append(y.detach().cpu().numpy())
            elif args.data=='mimic_int':
                x_cf = x.copy()
                if args.time_imp:
                    for _ in range(n_drops):
                        imp = np.unravel_index(importance_scores[i, :, min_t:max_t].argmax(), importance_scores[i, :, min_t:max_t].shape)
                        importance_scores[i, :, imp[1] + min_t:] = -1
                        x_cf = x_cf[:,:imp[1] + min_t]
                    lengths = (torch.ones((1,)) * x_cf.shape[1])
                else:
                    for _ in range(n_drops):
                        imp = np.unravel_index(importance_scores[i, :, min_t:].argmax(), importance_scores[i, :, min_t:].shape)
                        if importance_scores[i,imp[0], imp[1]+ min_t]<0:
                            continue
                        else:
                            importance_scores[i, imp[0], imp[1] + min_t:] = -1
                            x_cf[imp[0], imp[1] + min_t:] = x_cf[imp[0], imp[1] + min_t-1]
                label.append(y_test[i,:,x.shape[-1]-1])
                if args.explainer=='retain':
                    x_t = torch.Tensor(x).unsqueeze(0).permute(0, 2, 1)
                    x_cf_t = torch.Tensor(x_cf).unsqueeze(0).permute(0, 2, 1)
                    y, _, _ = (model(x_t.to(device), torch.ones((1,)) * x_t.shape[1]))
                    y1.append(activation(y).detach().cpu().numpy()[0,:])
                    y, _, _ = (model(x_cf_t.to(device), torch.ones((1,)) * x_cf_t.shape[1]))
                    y2.append(activation(y).detach().cpu().numpy()[0,:])
                else:
                    y = activation(model(torch.Tensor(x).unsqueeze(0)))[0,:]
                    y1.append(y.detach().cpu().numpy())
                    y = activation(model(torch.Tensor(x_cf).unsqueeze(0)))[0,:]
                    y2.append(y.detach().cpu().numpy())

        y1 = np.array(y1)#[:,0,:]
        y2 = np.array(y2)#[:,0,:]
        label = np.array(label)
        #print(y1.shape, y2.shape, label.shape)


        original_auc = metrics.roc_auc_score(label, y1,average='macro')
        modified_auc = metrics.roc_auc_score(label, y2,average='macro')

        original_aupr = metrics.average_precision_score(label, np.array(y1))
        modified_aupr = metrics.average_precision_score(label, np.array(y2))

        auc_drop.append(original_auc-modified_auc)
        aupr_drop.append(original_aupr-modified_aupr)

    print('obs drop' if not args.time_imp else 'time_drop')
    print(args.explainer, ' auc: %.3f$\\pm$%.3f'%(np.mean(auc_drop), np.std(auc_drop)), ' aupr: %.3f$\\pm$%.3f'%(np.mean(aupr_drop), np.std(aupr_drop)))
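The AUC/AUPR drop reported above can be illustrated in isolation; the sketch below uses made-up labels and predictions to show how the drop between the original and counterfactual model outputs is computed with scikit-learn.

import numpy as np
from sklearn import metrics

label = np.array([0, 1, 1, 0, 1])
y_original = np.array([0.2, 0.9, 0.8, 0.3, 0.7])  # predictions on unmodified inputs
y_modified = np.array([0.4, 0.6, 0.5, 0.4, 0.5])  # predictions after masking top-importance observations

auc_drop = metrics.roc_auc_score(label, y_original) - metrics.roc_auc_score(label, y_modified)
aupr_drop = (metrics.average_precision_score(label, y_original) -
             metrics.average_precision_score(label, y_modified))
print('AUC drop: %.3f, AUPR drop: %.3f' % (auc_drop, aupr_drop))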
def main(experiment, train, data, generator_type, predictor_model, all_samples,
         cv, output_path):
    print('********** Experiment with the %s data **********' % experiment)
    with open('config.json') as config_file:
        configs = json.load(config_file)[data][experiment]

    if not os.path.exists('./data'):
        os.mkdir('./data')
    ## Load the data
    if data == 'mimic':
        p_data, train_loader, valid_loader, test_loader = load_data(
            batch_size=configs['batch_size'], path='./data', cv=cv)
        feature_size = p_data.feature_size
    elif data == 'ghg':
        p_data, train_loader, valid_loader, test_loader = load_ghg_data(
            configs['batch_size'], cv=cv)
        feature_size = p_data.feature_size
    elif data == 'simulation_spike':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'],
            path='./data/simulated_spike_data',
            data_type='spike',
            cv=cv)
        feature_size = p_data.shape[1]

    elif data == 'simulation':
        percentage = 100.
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'],
            path='./data/simulated_data',
            percentage=percentage / 100,
            cv=cv)
        feature_size = p_data.shape[1]

    ## Create the experiment class
    if experiment == 'baseline':
        exp = Baseline(train_loader, valid_loader, test_loader,
                       p_data.feature_size)
    elif experiment == 'risk_predictor':
        exp = EncoderPredictor(train_loader,
                               valid_loader,
                               test_loader,
                               feature_size,
                               configs['encoding_size'],
                               rnn_type=configs['rnn_type'],
                               data=data,
                               model=predictor_model)
    elif experiment == 'feature_generator_explainer':
        exp = FeatureGeneratorExplainer(
            train_loader,
            valid_loader,
            test_loader,
            feature_size,
            patient_data=p_data,
            output_path=output_path,
            predictor_model=predictor_model,
            generator_hidden_size=configs['encoding_size'],
            prediction_size=1,
            generator_type=generator_type,
            data=data,
            experiment=experiment + '_' + generator_type)
    elif experiment == 'lime_explainer':
        exp = BaselineExplainer(train_loader,
                                valid_loader,
                                test_loader,
                                feature_size,
                                data_class=p_data,
                                data=data,
                                baseline_method='lime')

    if all_samples:
        print('Experiment on all test data')
        print('Number of test samples: ', len(exp.test_loader.dataset))
        exp.run(train=False,
                n_epochs=configs['n_epochs'],
                samples_to_analyze=list(range(0,
                                              len(exp.test_loader.dataset))),
                plot=False,
                cv=cv)
    else:
        exp.run(train=train,
                n_epochs=configs['n_epochs'],
                samples_to_analyze=samples_to_analyze[data])
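Each experiment reads its hyperparameters from a nested `config.json` indexed by dataset and experiment name. A minimal file covering the keys used in this file (`batch_size`, `encoding_size`, `rnn_type`, `n_epochs`, `historical`) could be generated as follows; the values are placeholders, not the ones used in the original experiments.

import json

placeholder_configs = {
    'simulation': {
        'risk_predictor': {'batch_size': 100, 'encoding_size': 50,
                           'rnn_type': 'GRU', 'n_epochs': 80},
        'feature_generator_explainer': {'batch_size': 100, 'encoding_size': 50,
                                        'rnn_type': 'GRU', 'n_epochs': 30,
                                        'historical': 1},
    }
}
with open('config.json', 'w') as f:
    json.dump(placeholder_configs, f, indent=4)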
def main(experiment, train, user, data, n_features_to_use=3):
    #sys.stdout = open('/scratch/gobi1/shalmali/global_importance_'+data+'.txt', 'w')
    filelist = glob.glob(
        os.path.join('/scratch/gobi1/%s/TSX_results' % user, data,
                     'results_*.pkl'))

    N = len(filelist)
    with open(filelist[0], 'rb') as f:
        arr = pkl.load(f)

    n_features = arr['FFC']['imp'].shape[0]
    Tt = arr['FFC']['imp'].shape[1]

    y_ffc = np.zeros((N, n_features))
    y_afo = np.zeros((N, n_features))
    y_suresh = np.zeros((N, n_features))
    y_sens = np.zeros((N, n_features))
    y_lime = np.zeros((N, n_features))

    for n, file in enumerate(filelist):
        with open(file, 'rb') as f:
            arr = pkl.load(f)

        y_ffc[n, :] = arr['FFC']['imp'].sum(1)
        y_afo[n, :] = arr['AFO']['imp'].sum(1)
        y_suresh[n, :] = arr['Suresh_et_al']['imp'].sum(1)
        y_sens[n, :] = arr['Sens']['imp'][:len(arr['FFC']['imp']), 1:].sum(1)
        y_lime[n, :] = parse_lime_results(arr, Tt, n_features,
                                          data=data).sum(1)

    y_rank_ffc = np.flip(np.argsort(
        y_ffc.sum(0)).flatten())  # sorted in order of relevance
    y_rank_afo = np.flip(np.argsort(
        y_afo.sum(0)).flatten())  # sorted in order of relevance
    y_rank_suresh = np.flip(np.argsort(
        y_suresh.sum(0)).flatten())  # sorted in order of relevance
    y_rank_sens = np.flip(np.argsort(
        y_sens.sum(0)).flatten())  # sorted in order of relevance
    y_rank_lime = np.flip(np.argsort(
        y_lime.sum(0)).flatten())  # sorted in order of relevance
    ranked_features = {
        'ffc': y_rank_ffc,
        'afo': y_rank_afo,
        'suresh': y_rank_suresh,
        'sens': y_rank_sens,
        'lime': y_rank_lime
    }

    with open('config.json') as config_file:
        configs = json.load(config_file)[data][experiment]

    methods = ranked_features.keys()

    for m in methods:
        print('Experiment with the %d most relevant features: ' % n_features_to_use, m)
        feature_rank = ranked_features[m]

        for ff in [n_features_to_use]:
            features = feature_rank[:ff]
            print('using features', features)

            if data == 'mimic':
                p_data, train_loader, valid_loader, test_loader = load_data(
                    batch_size=configs['batch_size'],
                    path='./data',
                    features=features)
                feature_size = p_data.feature_size
            elif data == 'ghg':
                p_data, train_loader, valid_loader, test_loader = load_ghg_data(
                    configs['batch_size'], features=features)
                feature_size = p_data.feature_size
                print(feature_size)
            elif data == 'simulation_spike':
                p_data, train_loader, valid_loader, test_loader = load_simulated_data(
                    batch_size=configs['batch_size'],
                    path='./data_generator/data/simulated_data',
                    data_type='spike',
                    features=features)
                feature_size = p_data.shape[1]

            elif data == 'simulation':
                p_data, train_loader, valid_loader, test_loader = load_simulated_data(
                    batch_size=configs['batch_size'],
                    path='./data/simulated_data',
                    features=features)
                feature_size = p_data.shape[1]

            if data == 'simulation_spike':
                data = 'simulation'
                spike_data = True
            else:
                spike_data = False

            print('training on ', feature_size, ' features!')

            exp = EncoderPredictor(train_loader,
                                   valid_loader,
                                   test_loader,
                                   feature_size,
                                   configs['encoding_size'],
                                   rnn_type=configs['rnn_type'],
                                   data=data)
            exp.run(train=train, n_epochs=configs['n_epochs'])

    n_features_to_remove = 10  #add/remove same number for now
    #Exp 1 remove and evaluate
    for m in methods:
        print('Experiment for removing features using method: ', m)
        feature_rank = ranked_features[m]

        #for ff in range(min(n_features-1,n_features_to_remove)):
        for ff in [n_features_to_remove]:
            features = [
                elem for elem in list(range(n_features))
                if elem not in feature_rank[:ff]
            ]
            #print('using features:', features)

            if data == 'mimic':
                p_data, train_loader, valid_loader, test_loader = load_data(
                    batch_size=configs['batch_size'],
                    path='./data',
                    features=features)
                feature_size = p_data.feature_size
            elif data == 'ghg':
                p_data, train_loader, valid_loader, test_loader = load_ghg_data(
                    configs['batch_size'], features=features)
                feature_size = p_data.feature_size
                print(feature_size)
            elif data == 'simulation_spike':
                p_data, train_loader, valid_loader, test_loader = load_simulated_data(
                    batch_size=configs['batch_size'],
                    path='./data_generator/data/simulated_data',
                    data_type='spike',
                    features=features)
                feature_size = p_data.shape[1]

            elif data == 'simulation':
                p_data, train_loader, valid_loader, test_loader = load_simulated_data(
                    batch_size=configs['batch_size'],
                    path='./data/simulated_data',
                    features=features)
                feature_size = p_data.shape[1]

            if data == 'simulation_spike':
                data = 'simulation'
                spike_data = True
            else:
                spike_data = False

            print('training on ', feature_size, ' features!')

            exp = EncoderPredictor(train_loader,
                                   valid_loader,
                                   test_loader,
                                   feature_size,
                                   configs['encoding_size'],
                                   rnn_type=configs['rnn_type'],
                                   data=data)
            exp.run(train=train, n_epochs=configs['n_epochs'])
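The global ranking above sums each method's per-feature importance over time and over test patients, then sorts features in descending order of that total. A small standalone illustration:

import numpy as np

# Toy importance matrix: (n_patients, n_features), already summed over time.
y_method = np.array([[0.2, 1.5, 0.7],
                     [0.1, 0.9, 1.1]])
# Feature indices sorted from most to least relevant overall.
ranking = np.flip(np.argsort(y_method.sum(0)))
print(ranking)  # -> [1 2 0]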
Example #7
def main(experiment, train, uncertainty_score, data, generator_type):
    print('********** Experiment with the %s data **********' % (experiment))
    with open('config.json') as config_file:
        configs = json.load(config_file)[data][experiment]

    if data == 'mimic':
        p_data, train_loader, valid_loader, test_loader = load_data(
            batch_size=configs['batch_size'], path='./data')
        feature_size = p_data.feature_size
    elif data == 'ghg':
        p_data, train_loader, valid_loader, test_loader = load_ghg_data(
            configs['batch_size'])
        feature_size = p_data.feature_size
    elif data == 'simulation_spike':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'],
            path='./data_generator/data/simulated_data',
            data_type='spike')
        feature_size = p_data.shape[1]

    elif data == 'simulation':
        p_data, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=configs['batch_size'], path='./data/simulated_data')
        feature_size = p_data.shape[1]

    if data == 'simulation_spike':
        data = 'simulation'
        spike_data = True
    else:
        spike_data = False

    if experiment == 'baseline':
        exp = Baseline(train_loader, valid_loader, test_loader,
                       p_data.feature_size)
    elif experiment == 'risk_predictor':
        exp = EncoderPredictor(train_loader,
                               valid_loader,
                               test_loader,
                               feature_size,
                               configs['encoding_size'],
                               rnn_type=configs['rnn_type'],
                               data=data)
    elif experiment == 'feature_generator_explainer':
        #print(spike_data)
        exp = FeatureGeneratorExplainer(
            train_loader,
            valid_loader,
            test_loader,
            feature_size,
            patient_data=p_data,
            generator_hidden_size=configs['encoding_size'],
            prediction_size=1,
            historical=(configs['historical'] == 1),
            generator_type=generator_type,
            data=data,
            experiment=experiment + '_' + generator_type,
            spike_data=spike_data)
    elif experiment == 'lime_explainer':
        exp = BaselineExplainer(train_loader,
                                valid_loader,
                                test_loader,
                                feature_size,
                                data_class=p_data,
                                data=data,
                                baseline_method='lime')

    exp.run(train=train,
            n_epochs=configs['n_epochs'],
            samples_to_analyze=samples_to_analyze[data])
    #exp.final_reported_plots(samples_to_analyze=samples_to_analyze[data])

    # For MIMIC experiment, extract population level importance for interventions
    # print('********** Extracting population level intervention statistics **********')
    # if data == 'mimic' and experiment == 'feature_generator_explainer':
    #     for id in range(len(intervention_list)):
    #         if not os.path.exists("./interventions/int_%d.pkl" % (id)):
    #             exp.summary_stat(id)
    #         exp.plot_summary_stat(id)

    if uncertainty_score:
        # Evaluate output uncertainty using deep KNN method
        print('\n********** Uncertainty Evaluation: **********')
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        sample_ind = 1
        n_nearest_neighbors = 10
        dknn = DeepKnn(exp.model,
                       p_data.train_data[0:int(0.8 * p_data.n_train), :, :],
                       p_data.train_label[0:int(0.8 * p_data.n_train)], device)
        knn_labels = dknn.evaluate_confidence(
            sample=p_data.test_data[sample_ind, :, :].reshape((1, -1, 48)),
            sample_label=p_data.test_label[sample_ind],
            _nearest_neighbors=n_nearest_neighbors,
            verbose=True)
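The `DeepKnn` evaluation above scores how confident the predictor is on a test sample by inspecting its nearest training neighbours. A rough standalone illustration of that idea, using raw features and scikit-learn rather than the repository's `DeepKnn` class:

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
train_repr = rng.normal(size=(500, 16))        # stand-in for learned representations
train_labels = rng.integers(0, 2, size=500)
test_repr = rng.normal(size=(1, 16))

nn = NearestNeighbors(n_neighbors=10).fit(train_repr)
_, idx = nn.kneighbors(test_repr)
neighbor_labels = train_labels[idx[0]]
# Fraction of nearest neighbours that agree on a label: a crude confidence proxy.
confidence = np.bincount(neighbor_labels).max() / len(neighbor_labels)
print('KNN-based confidence: %.2f' % confidence)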
Example #8
def main(args):
    if args.data == 'simulation':
        feature_size = 3
        data_path = './data/simulated_data'
        data_type = 'state'
    elif args.data == 'simulation_l2x':
        feature_size = 3
        data_path = './data/simulated_data_l2x'
        data_type = 'state'
    elif args.data == 'simulation_spike':
        feature_size = 3
        data_path = './data/simulated_spike_data'
        data_type = 'spike'
    elif args.data == 'mimic':
        data_type = 'mimic'
        timeseries_feature_size = len(feature_map_mimic)

    # Load data
    if args.data == 'mimic':
        p_data, train_loader, valid_loader, test_loader = load_data(
            batch_size=100, path='./data', cv=args.cv)
        feature_size = p_data.feature_size
    else:
        _, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=100,
            datapath=data_path,
            percentage=0.8,
            data_type=data_type)

    model = StateClassifier(feature_size=feature_size,
                            n_state=2,
                            hidden_size=200)

    if args.explainer == 'fit':
        generator = JointFeatureGenerator(feature_size,
                                          hidden_size=feature_size * 3,
                                          data=args.data)
        generator.load_state_dict(
            torch.load(
                os.path.join('./ckpt/%s/%s.pt' %
                             (args.data, 'joint_generator'))))

    testset = [smpl[0] for smpl in test_loader.dataset]
    samples = torch.stack(
        [testset[sample] for sample in samples_to_analyze[args.data]])

    model.load_state_dict(
        torch.load(os.path.join('./ckpt/%s/%s.pt' % (args.data, 'model'))))
    if args.explainer == 'fit':
        explainer = FITExplainer(model, generator)
    elif args.explainer == 'integrated_gradient':
        explainer = IGExplainer(model)
    elif args.explainer == 'deep_lift':
        explainer = DeepLiftExplainer(model)
    elif args.explainer == 'fo':
        explainer = FOExplainer(model)
    elif args.explainer == 'afo':
        explainer = AFOExplainer(model, train_loader)
    elif args.explainer == 'gradient_shap':
        explainer = GradientShapExplainer(model)
    elif args.explainer == 'retain':
        model = RETAIN(dim_input=feature_size,
                       dim_emb=128,
                       dropout_emb=0.4,
                       dim_alpha=8,
                       dim_beta=8,
                       dropout_context=0.4,
                       dim_output=2)
        explainer = RETAINexplainer(model, args.data)
        model.load_state_dict(
            torch.load(os.path.join('./ckpt/%s/%s.pt' %
                                    (args.data, 'retain'))))
    gt_importance = explainer.attribute(samples, torch.zeros(samples.shape))

    # Progressively randomize the model parameters and measure how much the
    # attribution maps change relative to the trained model (a model-randomization
    # sanity check). At each step, the first `ratio` fraction of every flattened
    # weight tensor is replaced with Gaussian noise.
    for ratio in [.2, .4, .6, .8, 1.]:
        for param in model.parameters():
            params = param.data.cpu().numpy().reshape(-1)
            n_randomized = int(ratio * len(params))
            params[:n_randomized] = torch.randn(n_randomized).numpy()
            param.data = torch.Tensor(params.reshape(param.data.shape))
        if args.explainer == 'fit':
            explainer = FITExplainer(model, generator)
        elif args.explainer == 'integrated_gradient':
            explainer = IGExplainer(model)
        elif args.explainer == 'deep_lift':
            explainer = DeepLiftExplainer(model)
        elif args.explainer == 'fo':
            explainer = FOExplainer(model)
        elif args.explainer == 'afo':
            explainer = AFOExplainer(model, train_loader)
        elif args.explainer == 'gradient_shap':
            explainer = GradientShapExplainer(model)
        elif args.explainer == 'retain':
            model = RETAIN(dim_input=feature_size,
                           dim_emb=128,
                           dropout_emb=0.4,
                           dim_alpha=8,
                           dim_beta=8,
                           dropout_context=0.4,
                           dim_output=2)
            explainer = RETAINexplainer(model, args.data)
            model.load_state_dict(
                torch.load(
                    os.path.join('./ckpt/%s/%s.pt' % (args.data, 'retain'))))

        score = explainer.attribute(samples, torch.zeros(samples.shape))
        corr = []
        for sig in range(len(score)):
            corr.append(
                abs(
                    spearmanr(score[sig].reshape(-1, ),
                              gt_importance[sig].reshape(-1, ),
                              nan_policy='omit')[0]))
        print("correlation for %d percent randomization: %.3f +- %.3f" %
              (100 * ratio, np.mean(corr), np.std(corr)))
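The randomization check reports the absolute Spearman rank correlation between the attribution maps of the trained and the progressively randomized model; computed in isolation it looks like this:

import numpy as np
from scipy.stats import spearmanr

original_map = np.array([[0.9, 0.1, 0.0], [0.2, 0.7, 0.1]])    # toy attribution maps
randomized_map = np.array([[0.5, 0.4, 0.1], [0.3, 0.3, 0.4]])

rho = abs(spearmanr(original_map.reshape(-1),
                    randomized_map.reshape(-1), nan_policy='omit')[0])
print('|Spearman rho| = %.3f' % rho)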
Example #9
def main(generator_type, data_path):
    print(
        '********** Finding Accordance score for baseline methods **********')

    experiment = 'feature_generator_explainer'
    data = 'mimic'
    with open('config.json') as config_file:
        configs = json.load(config_file)[data]['feature_generator_explainer']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    p_data, train_loader, valid_loader, test_loader = load_data(
        batch_size=configs['batch_size'], path='./data')
    feature_size = p_data.feature_size

    exp = FeatureGeneratorExplainer(
        train_loader,
        valid_loader,
        test_loader,
        feature_size,
        patient_data=p_data,
        generator_hidden_size=configs['encoding_size'],
        prediction_size=1,
        historical=(configs['historical'] == 1),
        generator_type=generator_type,
        data=data,
        experiment=experiment + '_' + generator_type)

    exp.generator.load_state_dict(
        torch.load(os.path.join('./ckpt/%s/%s.pt' %
                                ('mimic', generator_type))))
    baselines = ['FFC', 'AFO', 'FO']
    testset = list(exp.test_loader.dataset)
    test_signals = torch.stack(([x[0] for x in testset])).to(device)

    sum_matrix = np.zeros((3, 3))
    top_n = 6
    count = 0

    for sample_ID in range(len(testset)):
        signal = test_signals[sample_ID]
        matrix = np.zeros((3, 3))
        top_signals = np.zeros((3, signal.shape[0]))
        result_path = data_path + str(
            data) + '/results_%s.pkl' % str(sample_ID)
        if not os.path.exists(result_path):
            continue
        with open(result_path, 'rb') as f:
            arr = pkl.load(f)

        count += 1
        # Read importance scores generated by each methods and pick the top n
        ffc_importance = arr['FFC']['imp'].max(axis=1)
        afo_importance = arr['AFO']['imp'].max(axis=1)
        fo_importance = arr['Suresh_et_al']['imp'].max(axis=1)
        top = ffc_importance.argsort()[-1 * top_n:][::-1]
        top_signals[0, top] = 1
        top = afo_importance.argsort()[-1 * top_n:][::-1]
        top_signals[1, top] = 1
        top = fo_importance.argsort()[-1 * top_n:][::-1]
        top_signals[2, top] = 1

        for i in range(len(baselines)):
            for j in range(len(baselines)):
                matrix[i, j] = np.matmul(top_signals[i],
                                         top_signals[j]) / float(top_n)
        sum_matrix += matrix

    sum_matrix = sum_matrix / count

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(sum_matrix, interpolation='nearest')
    cb = fig.colorbar(cax, norm=mpl.colors.Normalize(vmin=0.0, vmax=1.))

    ax.set_xticklabels([''] + baselines, fontsize=24)
    ax.set_yticklabels([''] + baselines, fontsize=24)
    fig.savefig('/scratch/gobi1/sana/accordance.pdf')
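The accordance matrix above measures, for each pair of methods, the fraction of their top-n features that coincide. The overlap computation on its own:

import numpy as np

top_n = 6
n_features = 30
rng = np.random.default_rng(0)
scores_a = rng.random(n_features)  # stand-in for one method's per-feature importance
scores_b = rng.random(n_features)  # stand-in for another method's importance

mask_a = np.zeros(n_features)
mask_b = np.zeros(n_features)
mask_a[scores_a.argsort()[-top_n:]] = 1
mask_b[scores_b.argsort()[-top_n:]] = 1
accordance = np.matmul(mask_a, mask_b) / float(top_n)  # shared fraction of top-n features
print('accordance: %.2f' % accordance)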
Example #10
        #change this to softmax for suresh et al
        activation = torch.nn.Sigmoid()
        #activation = torch.nn.Softmax(-1)

    if not os.path.exists(output_path):
        os.mkdir(output_path)
    plot_path = os.path.join('./plots/%s' % args.data)
    if not os.path.exists(plot_path):
        os.mkdir(plot_path)

    # Load data
    if args.data == 'mimic' or args.data == 'mimic_int':
        if args.mimic_path is None:
            raise ValueError(
                'Specify the data directory containing processed mimic data')
        p_data, train_loader, valid_loader, test_loader = load_data(
            batch_size=batch_size, path=args.mimic_path, task=task, cv=args.cv)
        feature_size = p_data.feature_size
        class_weight = p_data.pos_weight
    else:
        _, train_loader, valid_loader, test_loader = load_simulated_data(
            batch_size=batch_size,
            datapath=data_path,
            percentage=0.8,
            data_type=data_type,
            cv=args.cv)

    # Prepare model to explain
    if args.explainer == 'retain':
        if args.data == 'mimic' or args.data == 'simulation' or args.data == 'simulation_l2x':
            model = RETAIN(dim_input=feature_size,
                           dim_emb=128,
Example #11
        data_path = '/scratch/gobi2/projects/tsx/'
        #change this to softmax for suresh et al
        activation = torch.nn.Sigmoid()
        #activation = torch.nn.Softmax(-1)

    output_path = os.path.join(args.path, 'mimic_submission')  #args.data)
    with open(
            os.path.join(
                output_path, '%s_test_importance_scores_%d.pkl' %
                (args.explainer, args.cv)), 'rb') as f:
        importance_scores = pkl.load(f)

    if args.data == 'mimic' or args.data == 'mimic_int':
        p_data, _, _, test_loader = load_data(batch_size=100,
                                              path=data_path,
                                              task=task,
                                              cv=0,
                                              train_pc=1.)
        feature_size = p_data.feature_size
        x_test = torch.stack(
            ([x[0] for x in list(test_loader.dataset)])).cpu().numpy()
        y_test = torch.stack(
            ([x[1] for x in list(test_loader.dataset)])).cpu().numpy()
    else:
        if args.data == 'simulation_l2x' or args.data == 'simulation':
            file_name = 'state_dataset_'
        else:
            file_name = ''
        with open(os.path.join(data_path, file_name + 'x_test.pkl'),
                  'rb') as f:
            x_test = pkl.load(f)