Ejemplo n.º 1
0
def plot_CellCnn_PR_curves(prec, recall, seq, seq_labels, nclust, plotdir, key):
    """Draw one precision-recall curve per entry of ``seq`` and save as EPS.

    Args:
        prec, recall: mappings indexed by the entries of ``seq``; ``recall[k]``
            is plotted on the x axis and ``prec[k]`` on the y axis.
        seq: keys to plot, one curve each.
        seq_labels: legend labels aligned with ``seq``. If it equals ``[]``
            or ``nclust`` is None, curves get default styling and no label.
        nclust: mapping from key to a 2-tuple. ``(1, 1)`` yields a solid
            line; any other value yields a dashed line whose label is
            suffixed with ``" (a/b)"``.
        plotdir: output directory, created via ``mkdir_p`` if needed.
        key: filename prefix; writes ``<key>_CellCnn_PRcurve.eps``.
    """
    sns.set(style="white")
    colors = sns.color_palette("Set1", n_colors=len(seq))

    plt.clf()
    fig, axes = plt.subplots()
    for pos, entry in enumerate(seq):
        use_defaults = (seq_labels == []) or (nclust is None)
        if use_defaults:
            plt.plot(recall[entry], prec[entry])
            continue

        if nclust[entry] == (1, 1):
            style = '-'
            legend_text = seq_labels[pos]
        else:
            style = '--'
            legend_text = seq_labels[pos] + ' (%d/%d)' % nclust[entry]
        plt.plot(recall[entry], prec[entry], c=colors[pos],
                 linestyle=style, label=legend_text)

    plt.xlabel('Recall', fontsize=28)
    plt.ylabel('Precision', fontsize=28)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.05])
    plt.legend(loc='center left', prop={'size': 20})

    # enlarge tick labels on both axes
    tick_labels = axes.get_xticklabels() + axes.get_yticklabels()
    for tick in tick_labels:
        tick.set_fontsize(24)
    plt.tight_layout()
    sns.despine()
    mkdir_p(plotdir)
    out_file = os.path.join(plotdir, key + '_CellCnn_PRcurve.eps')
    plt.savefig(out_file, format='eps')
    plt.close()
Ejemplo n.º 2
0
def plot_filters(results, labels, outdir):
    """Save heatmaps of the learned filter weights as PDFs in ``outdir``.

    Three figures are produced with ``plot_nn_weights``:
    ``best_net_weights.pdf`` (best network), ``clustered_filter_weights.pdf``
    (all clustered filters with their linkage / cluster assignments) and
    ``consensus_filter_weights.pdf`` (selected consensus filters).

    In every weight matrix the column at position ``len(labels)`` (the bias,
    per the original variable naming) is dropped. The first ``len(labels)``
    columns are labelled with ``labels``, the remaining ones ``out 0``,
    ``out 1``, ...

    Args:
        results: dict read for 'w_best_net', 'clustering_result' (with keys
            'w', 'cluster_linkage', 'cluster_assignments') and
            'selected_filters'.
        labels: marker names, one per input column.
        outdir: output directory, created if missing.

    Exits the process via ``sys.exit`` if ``results['selected_filters']`` is
    None (after the first two figures have been written).
    """
    mkdir_p(outdir)
    nmark = len(labels)
    w_best = results['w_best_net']
    ncol = w_best.shape[1]

    # integer index of every column except the bias at position ``nmark``;
    # np.arange works on Python 2 and 3 and stays integer-typed when empty
    idx_except_bias = np.concatenate([np.arange(nmark),
                                      np.arange(nmark + 1, ncol)])
    nc = ncol - (nmark + 1)
    labels_except_bias = list(labels) + ['out %d' % i for i in range(nc)]

    # plot the filter weights of the best network
    fig_path = os.path.join(outdir, 'best_net_weights.pdf')
    plot_nn_weights(w_best[:, idx_except_bias], labels_except_bias, fig_path,
                    fig_size=(10, 10))

    # plot the filter clustering
    cl = results['clustering_result']
    fig_path = os.path.join(outdir, 'clustered_filter_weights.pdf')
    plot_nn_weights(cl['w'][:, idx_except_bias],
                    labels_except_bias,
                    fig_path,
                    row_linkage=cl['cluster_linkage'],
                    y_labels=cl['cluster_assignments'],
                    fig_size=(10, 10))

    # plot the selected (consensus) filters
    if results['selected_filters'] is None:
        sys.exit('Consensus filters were not found.')
    w = results['selected_filters'][:, idx_except_bias]
    fig_path = os.path.join(outdir, 'consensus_filter_weights.pdf')
    plot_nn_weights(w, labels_except_bias, fig_path, fig_size=(10, 10))
Ejemplo n.º 3
0
def plot_benchmark_PR_curves(r_cnn, p_cnn, r_outlier, p_outlier, r_mean, p_mean,
                             r_sc, p_sc, nblast, plotdir, key):
    """Overlay the PR curves of CellCnn and three baselines; save as EPS.

    Each ``r_*`` / ``p_*`` pair is a recall / precision sequence. The figure
    is written to ``<plotdir>/<nblast>_PR_curve.eps``. ``key`` is accepted
    for interface compatibility but is not used.
    """
    sns.set(style="white")
    palette = sns.color_palette()
    # (x values, y values, colour, legend label), in drawing order
    curves = [
        (r_cnn, p_cnn, palette[2], 'CellCnn'),
        (r_outlier, p_outlier, palette[1], 'outlier'),
        (r_mean, p_mean, palette[0], 'mean'),
        (r_sc, p_sc, palette[3], 'sc'),
    ]

    plt.clf()
    fig, axes = plt.subplots()
    for recall_vals, prec_vals, colour, name in curves:
        plt.plot(recall_vals, prec_vals, c=colour, label=name)
    plt.xlabel('Recall', fontsize=28)
    plt.ylabel('Precision', fontsize=28)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.05])
    plt.legend(loc='center left', prop={'size': 24})
    for tick in axes.get_xticklabels() + axes.get_yticklabels():
        tick.set_fontsize(24)
    plt.tight_layout()
    sns.despine()
    mkdir_p(plotdir)
    out_file = os.path.join(plotdir, str(nblast) + '_PR_curve.eps')
    plt.savefig(out_file, format='eps')
    plt.close()
Ejemplo n.º 4
0
def plot_benchmark_PR_curves(r_cnn, p_cnn, r_outlier, p_outlier, r_mean,
                             p_mean, r_sc, p_sc, nblast, plotdir, key):
    """Plot CellCnn vs. outlier/mean/sc precision-recall curves to EPS.

    Output: ``<plotdir>/<nblast>_PR_curve.eps``. ``key`` is unused and kept
    only so the signature matches existing callers.
    """
    sns.set(style="white")
    pal = sns.color_palette()
    xs_list = [r_cnn, r_outlier, r_mean, r_sc]
    ys_list = [p_cnn, p_outlier, p_mean, p_sc]
    colours = [pal[2], pal[1], pal[0], pal[3]]
    names = ['CellCnn', 'outlier', 'mean', 'sc']

    plt.clf()
    _fig, ax = plt.subplots()
    for xs, ys, colour, name in zip(xs_list, ys_list, colours, names):
        plt.plot(xs, ys, c=colour, label=name)

    plt.xlabel('Recall', fontsize=28)
    plt.ylabel('Precision', fontsize=28)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.05])
    plt.legend(loc='center left', prop={'size': 24})

    ticks = ax.get_xticklabels() + ax.get_yticklabels()
    for tick in ticks:
        tick.set_fontsize(24)
    plt.tight_layout()
    sns.despine()
    mkdir_p(plotdir)
    target = os.path.join(plotdir, str(nblast) + '_PR_curve.eps')
    plt.savefig(target, format='eps')
    plt.close()
Ejemplo n.º 5
0
def main():
    """Pre-process the MRD / healthy / control CSV exports into ``ALL.pkl``.

    Reads three CSV files from a fixed directory, keeps the selected marker
    columns, transforms them with ``ftrans(..., 5)`` (described as an arcsinh
    transform), keeps the 500 MRD cells with the highest CD10 value and
    pickles everything to ``<cellCnn>/examples/data/ALL.pkl``.

    Returns:
        0
    """
    PATH = "/Volumes/biol_imsb_claassen_s1/eiriniar/Data/viSNE/mrd_debarcode"
    mrd_data = pd.read_csv(os.path.join(PATH, 'mrd_debarcoded.csv'), sep=',')
    healthy_data = pd.read_csv(os.path.join(PATH, 'healthy_debarcoded.csv'),
                               sep=',')
    control_data = pd.read_csv(os.path.join(PATH, 'visne_marrow1.csv'),
                               sep=',')

    # channel names as they appear in the control export
    channels = list(control_data.columns)

    # full channel names of the markers kept for the analysis
    full_labels = ['CD19(Nd142)Di', 'CD22(Nd143)Di', 'CD47(Nd145)Di',
                   'CD79b(Nd146)Di', 'CD20(Sm147)Di', 'CD34(Nd148)Di',
                   'CD179a(Sm149)Di', 'CD72(Eu151)Di', 'IgM-i(Eu153)Di',
                   'CD45(Sm154)Di', 'CD10(Gd156)Di', 'CD179b(Gd158)Di',
                   'CD11c(Tb159)Di', 'CD14(Gd160)Di', 'CD24(Dy161)Di',
                   'CD127(Dy162)Di', 'TdT(Dy163)Di', 'CD15(Dy164)Di',
                   'Pax5(Ho165)Di', 'CD38(Er168)Di', 'CD3(Er170)Di',
                   'CD117(Yb171)Di', 'CD49d(Yb172)Di', 'CD33(Yb173)Di',
                   'HLADR(Yb174)Di', 'IgM-s(Lu175)Di', 'CD7(Yb176)Di']

    # short names, e.g. 'CD19(Nd142)Di' -> 'CD19'
    labels = [name.split('(')[0] for name in full_labels]
    # column positions of the kept markers
    marker_idx = [channels.index(name) for name in full_labels]

    def _select_and_transform(frame):
        # keep only the chosen marker columns, then apply ftrans(..., 5)
        return ftrans(np.asarray(frame)[:, marker_idx], 5)

    x_mrd = _select_and_transform(mrd_data)
    x_healthy = _select_and_transform(healthy_data)
    x_control = _select_and_transform(control_data)

    # keep the 500 MRD cells with the largest CD10 value (CD10+ blasts)
    cd10_col = labels.index('CD10')
    x_mrd = x_mrd[np.argsort(x_mrd[:, cd10_col])[-500:]]

    # save the pre-processed dataset
    out_dir = os.path.join(cellCnn.__path__[0], 'examples', 'data')
    mkdir_p(out_dir)
    payload = {'control': x_control,
               'healthy': x_healthy,
               'ALL': x_mrd,
               'labels': labels}
    with open(os.path.join(out_dir, 'ALL.pkl'), 'wb') as handle:
        pickle.dump(payload, handle, -1)

    return 0
Ejemplo n.º 6
0
def main():
    """Write one ``<stimulus>_vs_control.pkl`` lookup per stimulus.

    For every stimulus code in ``STIMULI`` the result of
    ``no_inhibitor_lookup_full`` for stimuli ``['05', code]`` (``'05'`` is
    presumably the control condition, given the output file name — TODO
    confirm) is pickled to ``<cellCnn>/examples/data``.

    Returns:
        0
    """
    # manually gated cell types
    CTYPES = ['cd4+', 'cd8+', 'cd14+hladrmid', 'cd14-hladrmid', 'cd14+surf-',
              'cd14-surf-', 'dendritic', 'igm+', 'igm-', 'nk']

    # channels measured in this experiment (assumed to be in file column
    # order, since positions in this list are used as marker indices)
    CH = ['Time', 'Cell_length', 'CD3', 'CD45', 'BC1', 'BC2', 'pNFkB',
          'pp38', 'CD4', 'BC3', 'CD20', 'CD33', 'pStat5', 'CD123',
          'pAkt', 'pStat1', 'pSHP2', 'pZap70', 'pStat3', 'BC4', 'CD14',
          'pSlp76', 'BC5', 'pBtk', 'pPlcg2', 'pErk', 'BC6', 'pLat',
          'IgM', 'pS6', 'HLA-DR', 'BC7', 'CD7', 'DNA-1', 'DNA-2']

    # intracellular markers
    PH_LABELS = ['pStat1', 'pStat3', 'pStat5', 'pNFkB', 'pp38', 'pAkt',
                 'pSHP2', 'pZap70', 'pSlp76', 'pBtk', 'pPlcg2', 'pErk',
                 'pLat', 'pS6']

    # cell surface markers
    CD_LABELS = ['CD45', 'CD3', 'CD4', 'CD7', 'CD20', 'IgM', 'CD33',
                 'CD14', 'HLA-DR', 'CD123']

    # all markers to read: surface markers first, then intracellular ones
    labels = CD_LABELS + PH_LABELS
    marker_idx = [CH.index(name) for name in labels]

    # stimulus codes and their names, matched by position
    STIMULI = ['02', '03', '04', '06', '07', '08', '09', '10', '11', '12', '01']
    STIM_NAMES = ['IL-3', 'IL-2', 'IL-12', 'G-CSF', 'GM-CSF',
                  'BCR', 'IFN-g', 'IFN-a', 'LPS', 'PMA', 'Vanadate']

    # store the pre-processed datasets
    pickle_dir = os.path.join(cellCnn.__path__[0], 'examples', 'data')
    mkdir_p(pickle_dir)

    for s_code, s_name in zip(STIMULI, STIM_NAMES):
        lookup = no_inhibitor_lookup_full(data_path=FCS_DATA_PATH,
                                          stimuli=['05', s_code],
                                          ctypes=CTYPES,
                                          marker_idx=marker_idx)
        out_path = os.path.join(pickle_dir, s_name + '_vs_control.pkl')
        with open(out_path, 'wb') as fh:
            pickle.dump(lookup, fh, -1)

    return 0
Ejemplo n.º 7
0
def plot_CellCnn_PR_curves(prec, recall, seq, seq_labels, nclust, plotdir,
                           key):
    """Plot CellCnn precision-recall curves and write them to an EPS file.

    Args:
        prec, recall: mappings keyed by the entries of ``seq`` (y / x data).
        seq: keys to draw, one curve each, coloured from the "Set1" palette.
        seq_labels: legend labels aligned with ``seq``; when equal to ``[]``
            (or when ``nclust`` is None) curves are drawn unstyled.
        nclust: mapping key -> 2-tuple; ``(1, 1)`` draws a solid line, other
            values a dashed line labelled ``"<label> (a/b)"``.
        plotdir: destination directory (created via ``mkdir_p``).
        key: file prefix of ``<key>_CellCnn_PRcurve.eps``.
    """
    sns.set(style="white")
    palette = sns.color_palette("Set1", n_colors=len(seq))

    plt.clf()
    _, ax = plt.subplots()
    for i, nb in enumerate(seq):
        if (seq_labels == []) or (nclust is None):
            plt.plot(recall[nb], prec[nb])
        else:
            solid = nclust[nb] == (1, 1)
            if solid:
                name = seq_labels[i]
            else:
                name = seq_labels[i] + ' (%d/%d)' % nclust[nb]
            plt.plot(recall[nb],
                     prec[nb],
                     c=palette[i],
                     linestyle='-' if solid else '--',
                     label=name)

    plt.xlabel('Recall', fontsize=28)
    plt.ylabel('Precision', fontsize=28)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.05])
    plt.legend(loc='center left', prop={'size': 20})

    for lbl in ax.get_xticklabels() + ax.get_yticklabels():
        lbl.set_fontsize(24)
    plt.tight_layout()
    sns.despine()
    mkdir_p(plotdir)
    plt.savefig(os.path.join(plotdir, key + '_CellCnn_PRcurve.eps'),
                format='eps')
    plt.close()
Ejemplo n.º 8
0
def plot_barcharts(y_true, cnn_pred, outlier_pred, mean_pred, sc_pred,
                   nblast, label, plotdir, key, at_recall=0.8,
                   include_citrus=False):
    """Stacked bar chart of the control ("healthy") share per method; EPS.

    With ``ntop = int(at_recall * nblast)`` and
    ``fp = return_FP(y_true, pred, ntop)`` for each method, the bar height
    of the lower ("healthy") part is ``fp / (ntop + fp)`` and the upper
    ("<label>") bar is the full normalised total of 1. ``return_FP`` is
    presumably the number of control cells selected alongside ``ntop``
    positives — TODO confirm.

    When ``include_citrus`` is true a hard-coded "Citrus" bar is appended.
    The figure is saved to ``<plotdir>/<nblast>_barcharts_top_cells.eps``;
    ``key`` is not used.
    """
    ntop = int(at_recall * nblast)
    predictions = (cnn_pred, outlier_pred, mean_pred, sc_pred)
    ctrl_counts = [return_FP(y_true, pred, ntop) for pred in predictions]

    methods = ['CellCnn', 'outlier', 'mean', 'sc']
    if include_citrus:
        methods.append('Citrus')

    # column 0: control cells, column 1: total selected cells
    arr = np.zeros((len(methods), 2), dtype=float)
    for row, n_ctrl in enumerate(ctrl_counts):
        arr[row] = np.array([n_ctrl, ntop + n_ctrl])

    # only include if we already have results from a Citrus run in R
    if include_citrus:
        no_ctrl = (label == 'AML') and (nblast > 3000)
        arr[4] = np.array([0, 1]) if no_ctrl else np.array([1, 1])

    # normalise each row by its total so the 'total' column becomes 1
    for row in range(len(methods)):
        arr[row] /= arr[row, -1]

    df_count = pd.DataFrame(arr, columns=['ctrl', 'total'])
    df_count['names'] = methods

    # now plot the results from the dataframe
    sns.set(style="white")
    f, ax = plt.subplots(figsize=(6, 8))
    palette = sns.color_palette()
    top_color = palette[2]
    bottom_color = palette[0]

    sns.barplot(x="names", y='total', data=df_count, label=label,
                color=top_color)
    bottom_plot = sns.barplot(x="names", y='ctrl', data=df_count,
                              label='healthy', color=bottom_color)

    # proxy patches so the legend shows both bar colours
    top_patch = plt.Rectangle((0, 0), 1, 1, fc=top_color, edgecolor='none')
    bottom_patch = plt.Rectangle((0, 0), 1, 1, fc=bottom_color,
                                 edgecolor='none')
    legend = plt.legend([bottom_patch, top_patch], ['healthy', label],
                        loc=1, ncol=2, prop={'size': 18})
    legend.draw_frame(False)

    ax.set(ylim=(0, 1.1), xlabel="",
           ylabel="frequency of selected cell states")
    plt.yticks([.25, 0.5, .75, 1])
    plt.xticks(rotation=45)
    sns.despine(left=True, bottom=True)

    # uniform font size for axis labels and tick labels
    texts = [bottom_plot.xaxis.label, bottom_plot.yaxis.label]
    texts += bottom_plot.get_xticklabels()
    texts += bottom_plot.get_yticklabels()
    for txt in texts:
        txt.set_fontsize(22)
    plt.tight_layout()
    mkdir_p(plotdir)
    out_file = os.path.join(plotdir, str(nblast) + '_barcharts_top_cells.eps')
    plt.savefig(out_file, format='eps')
    plt.close()
Ejemplo n.º 9
0
def plot_barcharts(y_true,
                   cnn_pred,
                   outlier_pred,
                   mean_pred,
                   sc_pred,
                   nblast,
                   label,
                   plotdir,
                   key,
                   at_recall=0.8,
                   include_citrus=False,
                   cutoff=None):
    """Stacked bar chart of the control ("healthy") share per method; EPS.

    Counts per method come from one of two sources:

    * ``cutoff is None``: ``ntop = int(at_recall * nblast)``,
      ``ctrl = return_FP(y_true, pred, ntop)`` and ``total = ntop + ctrl``;
    * otherwise: ``ctrl, total = return_FP_cutoff(y_true, pred, cutoff)``.

    Each row is normalised by ``total``. When ``include_citrus`` is true a
    hard-coded "Citrus" bar is appended. Output:
    ``<plotdir>/<nblast>_barcharts_top_cells.eps``; ``key`` is unused.
    """
    predictions = [cnn_pred, outlier_pred, mean_pred, sc_pred]
    counts = []
    if cutoff is None:
        ntop = int(at_recall * nblast)
        for pred in predictions:
            n_ctrl = return_FP(y_true, pred, ntop)
            counts.append((n_ctrl, ntop + n_ctrl))
    else:
        for pred in predictions:
            n_ctrl, n_total = return_FP_cutoff(y_true, pred, cutoff)
            counts.append((n_ctrl, n_total))

    methods = ['CellCnn', 'outlier', 'mean', 'sc']
    if include_citrus:
        methods.append('Citrus')

    # column 0: control cells, column 1: total selected cells
    arr = np.zeros((len(methods), 2), dtype=float)
    for row, (n_ctrl, n_total) in enumerate(counts):
        arr[row] = np.array([n_ctrl, n_total])

    # only include if we already have results from a Citrus run in R
    if include_citrus:
        if (label == 'AML') and (nblast > 3000):
            arr[4] = np.array([0, 1])
        else:
            arr[4] = np.array([1, 1])

    # normalise each row by its total
    for row in range(len(methods)):
        arr[row] /= arr[row, -1]

    df_count = pd.DataFrame(arr, columns=['ctrl', 'total'])
    df_count['names'] = methods

    # now plot the results from the dataframe
    sns.set(style="white")
    f, ax = plt.subplots(figsize=(6, 8))
    palette = sns.color_palette()
    upper, lower = palette[2], palette[0]

    sns.barplot(x="names", y='total', data=df_count, label=label, color=upper)
    bottom_plot = sns.barplot(x="names", y='ctrl', data=df_count,
                              label='healthy', color=lower)

    # legend proxies for the two stacked bar colours
    upper_patch = plt.Rectangle((0, 0), 1, 1, fc=upper, edgecolor='none')
    lower_patch = plt.Rectangle((0, 0), 1, 1, fc=lower, edgecolor='none')
    legend = plt.legend([lower_patch, upper_patch], ['healthy', label],
                        loc=1, ncol=2, prop={'size': 18})
    legend.draw_frame(False)

    ax.set(ylim=(0, 1.1), xlabel="",
           ylabel="frequency of selected cell states")
    plt.yticks([.25, 0.5, .75, 1])
    plt.xticks(rotation=45)
    sns.despine(left=True, bottom=True)

    # set fonts to a consistent size
    texts = [bottom_plot.xaxis.label, bottom_plot.yaxis.label]
    texts += bottom_plot.get_xticklabels()
    texts += bottom_plot.get_yticklabels()
    for txt in texts:
        txt.set_fontsize(22)
    plt.tight_layout()
    mkdir_p(plotdir)
    target = os.path.join(plotdir, str(nblast) + '_barcharts_top_cells.eps')
    plt.savefig(target, format='eps')
    plt.close()
Ejemplo n.º 10
0
def train_model(train_samples, train_phenotypes, outdir,
                valid_samples=None, valid_phenotypes=None, generate_valid_set=True,
                scale=True, quant_normed=False, nrun=20, regression=False,
                ncell=200, nsubset=1000, per_sample=False, subset_selection='random',
                maxpool_percentages=(0.01, 1., 5., 20., 100.), nfilter_choice=range(3, 10),
                learning_rate=None, coeff_l1=0, coeff_l2=1e-4, dropout='auto', dropout_p=.5,
                max_epochs=20, patience=5,
                dendrogram_cutoff=0.4, accur_thres=.95, verbose=1):
    """ Performs a CellCnn analysis.

    Trains ``nrun`` networks (built with ``build_model``) on multi-cell inputs
    drawn from the samples, keeps the weights of every run that finished,
    clusters the learned filters (``cluster_profiles``) and returns a results
    dictionary.

    Args:
        train_samples: list of per-sample cell matrices; the number of
            columns is used as the number of markers.
        train_phenotypes: one phenotype per training sample. With
            ``subset_selection='outlier'`` class 0 is treated as control.
        outdir: directory for per-run checkpoints ``nnet_run_<i>.hdf5``.
        valid_samples, valid_phenotypes: optional explicit validation set.
        generate_valid_set: when no validation samples are given, hold out the
            first of 5 folds of the pooled training cells (stratified by
            sample id) as validation cells.
        scale: standardise markers with mean/std fitted on training cells.
        quant_normed: takes precedence over ``scale``; subtracts 0.5 from
            every marker instead.
        nrun: number of networks to train (raised to 3 if smaller).
        regression: train regression instead of classification networks.
        ncell: cells per multi-cell input.
        nsubset: multi-cell inputs per sample (``per_sample``) or per class.
        per_sample: see ``nsubset``.
        subset_selection: 'random' or 'outlier' (biased subsets).
        maxpool_percentages: percentages of pooled cells, cycled over runs.
        nfilter_choice: candidate filter counts, sampled per run.
        learning_rate: fixed learning rate; when None it is sampled per run
            as ``10 ** U(-3, -2)``.
        coeff_l1, coeff_l2, dropout, dropout_p: forwarded to ``build_model``.
        max_epochs, patience: epochs and early-stopping patience.
        dendrogram_cutoff, accur_thres: forwarded to ``cluster_profiles``.
        verbose: verbosity flag, also forwarded to Keras.

    Returns:
        dict with keys 'clustering_result', 'selected_filters', 'best_net',
        'best_3_nets', 'w_best_net', 'accuracies', 'best_model_index',
        'config', 'scaler', 'n_classes' and, when ``valid_samples`` were
        given and consensus filters exist, 'filter_tau' (regression) or
        'filter_diff' (classification).

    Raises:
        RuntimeError: if no training run finished successfully.
    """
    mkdir_p(outdir)

    if nrun < 3:
        print('The nrun argument should be >= 3, setting it to 3.')
        nrun = 3

    # copy the list of samples so that they are not modified in place
    train_samples = copy.deepcopy(train_samples)
    if valid_samples is not None:
        valid_samples = copy.deepcopy(valid_samples)

    # normalize extreme values
    # we assume that 0 corresponds to the control class
    if subset_selection == 'outlier':
        ctrl_list = [train_samples[i] for i in np.where(np.array(train_phenotypes) == 0)[0]]
        test_list = [train_samples[i] for i in np.where(np.array(train_phenotypes) != 0)[0]]
        train_samples = normalize_outliers_to_control(ctrl_list, test_list)

        if valid_samples is not None:
            ctrl_list = [valid_samples[i] for i in np.where(np.array(valid_phenotypes) == 0)[0]]
            test_list = [valid_samples[i] for i in np.where(np.array(valid_phenotypes) != 0)[0]]
            valid_samples = normalize_outliers_to_control(ctrl_list, test_list)

    # merge all input samples (X_train, X_valid)
    # and generate an identifier for each of them (train_id, valid_id)
    if (valid_samples is None) and (not generate_valid_set):
        sample_ids = range(len(train_phenotypes))
        X_train, id_train = combine_samples(train_samples, sample_ids)

    elif (valid_samples is None) and generate_valid_set:
        sample_ids = range(len(train_phenotypes))
        X, sample_id = combine_samples(train_samples, sample_ids)
        valid_phenotypes = train_phenotypes

        # split into train-validation partitions (first of 5 folds,
        # stratified by the sample each cell comes from)
        eval_folds = 5
        kf = StratifiedKFold(n_splits=eval_folds)
        train_indices, valid_indices = next(kf.split(X, sample_id))
        X_train, id_train = X[train_indices], sample_id[train_indices]
        X_valid, id_valid = X[valid_indices], sample_id[valid_indices]

    else:
        sample_ids = range(len(train_phenotypes))
        X_train, id_train = combine_samples(train_samples, sample_ids)
        sample_ids = range(len(valid_phenotypes))
        X_valid, id_valid = combine_samples(valid_samples, sample_ids)

    if quant_normed:
        z_scaler = StandardScaler(with_mean=True, with_std=False)
        z_scaler.fit(0.5 * np.ones((1, X_train.shape[1])))
        X_train = z_scaler.transform(X_train)
    elif scale:
        z_scaler = StandardScaler(with_mean=True, with_std=True)
        z_scaler.fit(X_train)
        X_train = z_scaler.transform(X_train)
    else:
        z_scaler = None

    X_train, id_train = shuffle(X_train, id_train)
    train_phenotypes = np.asarray(train_phenotypes)

    # an array containing the phenotype for each single cell
    y_train = train_phenotypes[id_train]

    # True whenever separate validation cells exist
    has_valid = (valid_samples is not None) or generate_valid_set
    if has_valid:
        # apply exactly the transformation that was applied to X_train
        if z_scaler is not None:
            X_valid = z_scaler.transform(X_valid)

        X_valid, id_valid = shuffle(X_valid, id_valid)
        valid_phenotypes = np.asarray(valid_phenotypes)
        y_valid = valid_phenotypes[id_valid]

    # number of measured markers
    nmark = X_train.shape[1]

    # generate multi-cell inputs
    # (``//`` keeps subset counts integral on Python 2 and 3 alike)
    print('Generating multi-cell inputs...')

    if subset_selection == 'outlier':
        # here we assume that class 0 is always the control class
        x_ctrl_train = X_train[y_train == 0]
        to_keep = int(0.1 * (X_train.shape[0] // len(train_phenotypes)))
        nsubset_ctrl = nsubset // np.sum(train_phenotypes == 0)

        # generate a fixed number of subsets per class
        nsubset_biased = [0]
        for pheno in range(1, len(np.unique(train_phenotypes))):
            nsubset_biased.append(nsubset // np.sum(train_phenotypes == pheno))

        X_tr, y_tr = generate_biased_subsets(X_train, train_phenotypes, id_train, x_ctrl_train,
                                             nsubset_ctrl, nsubset_biased, ncell, to_keep,
                                             id_ctrl=np.where(train_phenotypes == 0)[0],
                                             id_biased=np.where(train_phenotypes != 0)[0])

        if has_valid:
            x_ctrl_valid = X_valid[y_valid == 0]
            nsubset_ctrl = nsubset // np.sum(valid_phenotypes == 0)

            # generate a fixed number of subsets per class
            nsubset_biased = [0]
            for pheno in range(1, len(np.unique(valid_phenotypes))):
                nsubset_biased.append(nsubset // np.sum(valid_phenotypes == pheno))

            to_keep = int(0.1 * (X_valid.shape[0] // len(valid_phenotypes)))
            X_v, y_v = generate_biased_subsets(X_valid, valid_phenotypes, id_valid, x_ctrl_valid,
                                               nsubset_ctrl, nsubset_biased, ncell, to_keep,
                                               id_ctrl=np.where(valid_phenotypes == 0)[0],
                                               id_biased=np.where(valid_phenotypes != 0)[0])
    else:
        if per_sample:
            # generate 'nsubset' multi-cell inputs per input sample
            X_tr, y_tr = generate_subsets(X_train, train_phenotypes, id_train,
                                          nsubset, ncell, per_sample)
            if has_valid:
                X_v, y_v = generate_subsets(X_valid, valid_phenotypes, id_valid,
                                            nsubset, ncell, per_sample)
        else:
            # generate 'nsubset' multi-cell inputs per class
            nsubset_list = [nsubset // np.sum(train_phenotypes == pheno)
                            for pheno in range(len(np.unique(train_phenotypes)))]
            X_tr, y_tr = generate_subsets(X_train, train_phenotypes, id_train,
                                          nsubset_list, ncell, per_sample)

            if has_valid:
                nsubset_list = [nsubset // np.sum(valid_phenotypes == pheno)
                                for pheno in range(len(np.unique(valid_phenotypes)))]
                X_v, y_v = generate_subsets(X_valid, valid_phenotypes, id_valid,
                                            nsubset_list, ncell, per_sample)

    if not has_valid:
        # no validation cells at all: hold out part of the generated inputs
        X_tr, y_tr, X_v, y_v = _holdout_validation_split(X_tr, y_tr)
    print('Done.')

    ## neural network configuration ##
    # batch size
    bs = 200

    # keras needs (nbatch, ncell, nmark)
    X_tr = np.swapaxes(X_tr, 2, 1)
    X_v = np.swapaxes(X_v, 2, 1)
    n_classes = 1

    if not regression:
        n_classes = len(np.unique(train_phenotypes))
        y_tr = to_categorical(y_tr, n_classes)
        y_v = to_categorical(y_v, n_classes)

    # integer targets for classification, float targets for regression
    target_cast = float32 if regression else int32

    # train some neural networks with different parameter configurations
    accuracies = np.zeros(nrun)
    w_store = dict()
    config = dict()
    config['nfilter'] = []
    config['learning_rate'] = []
    config['maxpool_percentage'] = []
    lr = learning_rate

    for irun in range(nrun):
        if verbose:
            print('training network: %d' % (irun + 1))
        if learning_rate is None:
            lr = 10 ** np.random.uniform(-3, -2)
            config['learning_rate'].append(lr)

        # choose number of filters for this run
        nfilter = np.random.choice(nfilter_choice)
        config['nfilter'].append(nfilter)
        print('Number of filters: %d' % nfilter)

        # choose number of cells pooled for this run
        mp = maxpool_percentages[irun % len(maxpool_percentages)]
        config['maxpool_percentage'].append(mp)
        k = max(1, int(mp / 100. * ncell))
        print('Cells pooled: %d' % k)

        # build the neural network
        model = build_model(ncell, nmark, nfilter,
                            coeff_l1, coeff_l2, k,
                            dropout, dropout_p, regression, n_classes, lr)

        filepath = os.path.join(outdir, 'nnet_run_%d.hdf5' % irun)
        try:
            check = ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True,
                                    mode='auto')
            earlyStopping = EarlyStopping(monitor='val_loss', patience=patience, mode='auto')
            model.fit(float32(X_tr), target_cast(y_tr),
                      nb_epoch=max_epochs, batch_size=bs, callbacks=[check, earlyStopping],
                      validation_data=(float32(X_v), target_cast(y_v)), verbose=verbose)

            # reload the checkpoint with the lowest validation loss
            model.load_weights(filepath)

            if not regression:
                valid_metric = model.evaluate(float32(X_v), int32(y_v))[-1]
                print('Best validation accuracy: %.2f' % valid_metric)
                accuracies[irun] = valid_metric
            else:
                train_metric = model.evaluate(float32(X_tr), float32(y_tr), batch_size=bs)
                print('Best train loss: %.2f' % train_metric)
                valid_metric = model.evaluate(float32(X_v), float32(y_v), batch_size=bs)
                print('Best validation loss: %.2f' % valid_metric)
                # negate the loss so that "larger is better" for ranking
                accuracies[irun] = - valid_metric

            # extract the network parameters
            w_store[irun] = model.get_weights()

        except Exception as e:
            # best effort: a failed run is reported and skipped
            sys.stderr.write('An exception was raised during training the network.\n')
            sys.stderr.write(str(e) + '\n')

    # the top 3 performing networks, considering only runs that finished
    # (failed runs keep a score of 0, which may outrank real scores)
    ranked = [i for i in np.argsort(accuracies)[::-1] if i in w_store]
    if not ranked:
        raise RuntimeError('All %d training runs failed; no network weights '
                           'are available.' % nrun)
    model_sorted_idx = ranked[:3]
    best_3_nets = [w_store[i] for i in model_sorted_idx]
    best_net = best_3_nets[0]
    best_accuracy_idx = model_sorted_idx[0]

    # weights from the best-performing network
    w_best_net = keras_param_vector(best_net)

    # post-process the learned filters
    # cluster weights from all networks that achieved accuracy above the specified threshold
    w_cons, cluster_res = cluster_profiles(w_store, nmark, accuracies, accur_thres,
                                           dendrogram_cutoff=dendrogram_cutoff)

    # output w_store as a .csv file (in the current working directory)
    with open("w_store.csv", "w") as csv_file:
        writer = csv.writer(csv_file)
        for key, val in w_store.items():
            writer.writerow([key, val])

    results = {
        'clustering_result': cluster_res,
        'selected_filters': w_cons,
        'best_net': best_net,
        'best_3_nets': best_3_nets,
        'w_best_net': w_best_net,
        'accuracies': accuracies,
        'best_model_index': best_accuracy_idx,
        'config': config,
        'scaler': z_scaler,
        'n_classes': n_classes
    }

    if (valid_samples is not None) and (w_cons is not None):
        maxpool_percentage = config['maxpool_percentage'][best_accuracy_idx]
        if regression:
            tau = get_filters_regression(w_cons, z_scaler, valid_samples, valid_phenotypes,
                                         maxpool_percentage)
            results['filter_tau'] = tau

        else:
            filter_diff = get_filters_classification(w_cons, z_scaler, valid_samples,
                                                     valid_phenotypes, maxpool_percentage)
            results['filter_diff'] = filter_diff
    return results


def _holdout_validation_split(X, y):
    """Use the first fifth of generated inputs as validation data.

    Returns ``(X_train, y_train, X_valid, y_valid)``.
    """
    cut = X.shape[0] // 5
    return X[cut:], y[cut:], X[:cut], y[:cut]
Ejemplo n.º 11
0
def discriminative_filters(results,
                           outdir,
                           filter_diff_thres,
                           positive_filters_only=False,
                           show_filters=True):
    """Select the most discriminative consensus filters.

    If ``results`` holds validation-based filter scores ('filter_diff' for
    classification, 'filter_tau' for regression), filters are ranked by
    decreasing score. The top filter is always kept, and the next one is kept
    while ``score[i] - score[i + 1] < filter_diff_thres * score[i]``.
    Selection stops at the first larger drop. Without such scores, every row
    of ``results['selected_filters']`` is kept.

    Args:
        results: dict produced by ``train_model``.
        outdir: directory for the optional plot (created if missing).
        filter_diff_thres: relative-drop threshold described above.
        positive_filters_only: multiply each score by the sign of the last
            column of ``results['selected_filters']`` before ranking.
        show_filters: save the sorted scores as
            ``filter_response_differences.pdf`` in ``outdir``.

    Returns:
        list of indices of the kept filters.
    """
    mkdir_p(outdir)
    # select the discriminative filters based on the validation set
    if 'filter_diff' in results:
        scores = np.asarray(results['filter_diff'])
        ylabel = 'average cell filter response difference between classes'
    elif 'filter_tau' in results:
        scores = np.asarray(results['filter_tau'])
        ylabel = 'Kendalls tau'
    else:
        # if no validation samples were provided, keep all consensus filters
        filters = results['selected_filters']
        return list(range(filters.shape[0]))

    # do we want to consider negative filters?
    if positive_filters_only:
        filters = results['selected_filters']
        scores = scores * np.sign(filters[:, -1])

    sorted_idx = np.argsort(scores)[::-1]
    sorted_scores = scores[sorted_idx]
    keep_idx = _leading_filters(sorted_idx, sorted_scores, filter_diff_thres)

    if show_filters:
        _plot_filter_scores(sorted_scores, sorted_idx, ylabel, outdir)
    return keep_idx


def _leading_filters(sorted_idx, sorted_scores, thres):
    """Return the top-ranked filter indices kept by the relative-drop rule."""
    if len(sorted_idx) == 0:
        return []
    keep_idx = [sorted_idx[0]]
    for i in range(len(sorted_scores) - 1):
        if (sorted_scores[i] - sorted_scores[i + 1]) < thres * sorted_scores[i]:
            keep_idx.append(sorted_idx[i + 1])
        else:
            break
    return keep_idx


def _plot_filter_scores(sorted_scores, sorted_idx, ylabel, outdir):
    """Save sorted filter scores to ``filter_response_differences.pdf``."""
    plt.figure()
    sns.set_style('whitegrid')
    positions = range(len(sorted_scores))
    plt.plot(positions, sorted_scores, '--')
    plt.xticks(positions,
               ['filter %d' % i for i in sorted_idx],
               rotation='vertical')
    plt.ylabel(ylabel)
    sns.despine()
    plt.savefig(os.path.join(outdir, 'filter_response_differences.pdf'),
                format='pdf')
    plt.clf()
    plt.close()
Ejemplo n.º 12
0
def plot_results(results,
                 samples,
                 phenotypes,
                 labels,
                 outdir,
                 filter_diff_thres=.2,
                 filter_response_thres=0,
                 response_grad_cutoff=None,
                 stat_test=None,
                 positive_filters_only=False,
                 log_yscale=False,
                 group_a='group A',
                 group_b='group B',
                 group_names=None,
                 tsne_ncell=10000,
                 regression=False,
                 clustering=None,
                 add_filter_response=False,
                 percentage_drop_cluster=.1,
                 min_cluster_freq=0.2,
                 show_filters=True):
    """ Plots the results of a CellCnn analysis.

    Args:
        - results :
            Dictionary containing the results of a CellCnn analysis.
        - samples :
            Samples from which to visualize the selected cell populations.
        - phenotypes :
            List of phenotypes corresponding to the provided `samples`.
        - labels :
            Names of measured markers.
        - outdir :
            Output directory where the generated plots will be stored.
        - filter_diff_thres :
            Threshold that defines which filters are most discriminative. Given an array
            ``filter_diff`` of average cell filter response differences between classes,
            sorted in decreasing order, keep a filter ``i, i > 0`` if it holds that
            ``filter_diff[i-1] - filter_diff[i] < filter_diff_thres * filter_diff[i-1]``.
            For regression problems, the array ``filter_diff`` contains Kendall's tau
            values for each filter.
        - filter_response_thres :
            Threshold for choosing a responding cell population. Default is 0.
        - response_grad_cutoff :
            Threshold on the gradient of the cell filter response CDF, might be useful for defining
            the selected cell population.
        - stat_test: None | 'ttest' | 'mannwhitneyu'
            Optionally, perform a statistical test on selected cell population frequencies between
            two groups and report the corresponding p-value on the boxplot figure
            (see plots description below). Default is None. Currently only used for binary
            classification problems.
        - group_a :
            Name of the first class.
        - group_b :
            Name of the second class.
        - group_names :
            List of names for the different phenotype classes.
        - positive_filters_only :
            If True, only consider filters associated with higher cell population frequency in the
            positive class.
        - log_yscale :
            If True, display the y-axis of the boxplot figure (see plots description below) in
            logarithmic scale.
        - clustering: None | 'dbscan' | 'louvain'
            Post-processing option for selected cell populations. Default is None.
            Any value other than None and 'louvain' falls back to dbscan.
        - add_filter_response :
            Forwarded to ``create_graph`` when ``clustering='louvain'`` (presumably adds the
            cell filter response to the k-NN graph construction; see ``create_graph``).
        - percentage_drop_cluster :
            Threshold that defines which cell clusters are kept. Given the mean cell filter
            response ``scores`` of the clusters sorted in decreasing order, keep cluster
            ``i, i > 0`` if ``scores[i-1] - scores[i] < percentage_drop_cluster * scores[i-1]``.
            The best-scoring cluster is always kept.
        - min_cluster_freq :
            Clusters containing at most ``int(min_cluster_freq * n)`` of the ``n`` selected
            cells are discarded as outliers, as are DBSCAN noise points (label -1). If no
            cluster survives, no population plots are produced for that filter.
        - tsne_ncell :
            Maximum number of cells to include in t-SNE calculations and plots. Distinct
            cells are sampled without replacement; if fewer cells are available, all are used.
        - regression :
            Whether it is a regression problem.
        - show_filters :
            Whether to plot learned filter weights.

    Returns:
        A list with the indices and corresponding cell filter response thresholds of selected
        discriminative filters. \
        This function also produces a collection of plots for model interpretation.
        These plots are stored in `outdir`. They comprise the following:

        - clustered_filter_weights.pdf :
            Filter weight vectors from all trained networks that pass a validation accuracy
            threshold, grouped in clusters via hierarchical clustering. Each row corresponds to
            a filter. The last column(s) indicate the weight(s) connecting each filter to the output
            class(es). Indices on the y-axis indicate the filter cluster memberships, as a
            result of the hierarchical clustering procedure.
        - consensus_filter_weights.pdf :
            One representative filter per cluster is chosen (the filter with minimum distance to all
            other members of the cluster). We call these selected filters "consensus filters".
        - best_net_weights.pdf :
            Filter weight vectors of the network that achieved the highest validation accuracy.
        - filter_response_differences.pdf :
            Difference in cell filter response between classes for each consensus filter.
            To compute this difference for a filter, we first choose a filter-specific class, that's
            the class with highest output weight connection to the filter. Then we compute the
            average cell filter response (value after the pooling layer) for validation samples
            belonging to the filter-specific class (``v1``) and the average cell filter response
            for validation samples not belonging to the filter-specific class (``v0``).
            The difference is computed as ``v1 - v0``. For regression problems, we cannot compute
            a difference between classes. Instead we compute Kendall's rank correlation coefficient
            between the predictions of each individual filter (value after the pooling layer) and
            the true response values.
            This plot helps decide on a cutoff (``filter_diff_thres`` parameter) for selecting
            discriminative filters.
        - tsne_all_cells.png :
            Marker distribution overlaid on t-SNE map.

        In addition, the following plots are produced for each selected filter (e.g. filter ``i``):

        - cdf_filter_i.pdf :
            Cumulative distribution function of cell filter response for filter ``i``. This plot
            helps decide on a cutoff (``filter_response_thres`` parameter) for selecting the
            responding cell population.

        - selected_population_distribution_filter_i.pdf :
            Histograms of univariate marker expression profiles for the cell population selected by
            filter ``i`` vs all cells.

        - selected_population_frequencies_filter_i.pdf :
            Boxplot of selected cell population frequencies in samples of the different classes,
            if running a classification problem. For regression settings, a scatter plot of selected
            cell population frequencies vs response variable is generated.

        - tsne_cell_response_filter_i.png :
            Cell filter response overlaid on t-SNE map.

        - tsne_selected_cells_filter_i.png :
            Marker distribution of selected cell population overlaid on t-SNE map.
    """

    # create the output directory
    mkdir_p(outdir)

    # number of measured markers
    nmark = samples[0].shape[1]

    if results['selected_filters'] is not None:
        print('Loading the weights of consensus filters.')
        filters = results['selected_filters']
    else:
        sys.exit('Consensus filters were not found.')

    if show_filters:
        plot_filters(results, labels, outdir)
    # get discriminative filter indices in consensus matrix
    keep_idx = discriminative_filters(
        results,
        outdir,
        filter_diff_thres,
        positive_filters_only=positive_filters_only,
        show_filters=show_filters)

    # encode the sample of origin of each cell (z[k] = index of the sample of cell k)
    sample_sizes = []
    per_cell_ids = []
    for i, x in enumerate(samples):
        sample_sizes.append(x.shape[0])
        per_cell_ids.append(i * np.ones(x.shape[0]))
    x = np.vstack(samples)
    z = np.hstack(per_cell_ids)

    if results['scaler'] is not None:
        x = results['scaler'].transform(x)

    print('Computing t-SNE projection...')
    # Sample distinct cells: sampling with replacement would feed duplicated
    # points to t-SNE. Use all cells when fewer than `tsne_ncell` are available.
    n_tsne = min(tsne_ncell, x.shape[0])
    tsne_idx = np.random.choice(x.shape[0], n_tsne, replace=False)
    x_for_tsne = x[tsne_idx].copy()
    x_tsne = TSNE(n_components=2).fit_transform(x_for_tsne)
    # per-marker 1st/99th percentiles, passed as vmin/vmax to the selection grid plots
    vmin = np.percentile(x, 1, axis=0)
    vmax = np.percentile(x, 99, axis=0)
    fig_path = os.path.join(outdir, 'tsne_all_cells')
    plot_tsne_grid(x_tsne,
                   x_for_tsne,
                   fig_path,
                   labels=labels,
                   fig_size=(20, 20),
                   point_size=5)

    return_filters = []
    for i_filter in keep_idx:
        # cell filter response: marker weights + bias (column `nmark`), then ReLU
        w = filters[i_filter, :nmark]
        b = filters[i_filter, nmark]
        g = np.sum(w.reshape(1, -1) * x, axis=1) + b
        g = g * (g > 0)

        # skip a filter if it does not select any cell
        if np.max(g) <= 0:
            continue

        ecdf = sm.distributions.ECDF(g)
        gx = np.linspace(np.min(g), np.max(g))
        gy = ecdf(gx)
        plt.figure()
        sns.set_style('whitegrid')
        a = plt.step(gx, gy)
        t = filter_response_thres
        # set a threshold to the CDF gradient?
        if response_grad_cutoff is not None:
            # walk the step curve from high to low response and cut at the
            # first jump in the CDF that reaches `response_grad_cutoff`
            by = np.array(a[0].get_ydata())[::-1]
            bx = np.array(a[0].get_xdata())[::-1]
            b_diff_idx = np.where(by[:-1] - by[1:] >= response_grad_cutoff)[0]
            if len(b_diff_idx) > 0:
                t = bx[b_diff_idx[0] + 1]
        plt.plot((t, t), (np.min(gy), 1.), 'r--')
        plt.xlabel('Cell filter response')
        plt.ylabel('Cumulative distribution function (CDF)')
        sns.despine()
        plt.savefig(os.path.join(outdir, 'cdf_filter_%d.pdf' % i_filter),
                    format='pdf')
        plt.clf()
        plt.close()

        condition = g > t
        x1 = x[condition]
        z1 = z[condition]
        g1 = g[condition]

        # skip a filter if it does not select any cell with the new cutoff threshold
        if x1.shape[0] == 0:
            continue

        # else add the filters to selected filters
        return_filters.append((i_filter, t))
        # t-SNE plots for characterizing the selected cell population
        fig_path = os.path.join(outdir,
                                'tsne_cell_response_filter_%d.png' % i_filter)
        plot_2D_map(x_tsne, g[tsne_idx], fig_path, s=5)
        # overlay marker values on TSNE map for selected cells
        fig_path = os.path.join(outdir,
                                'tsne_selected_cells_filter_%d' % i_filter)
        g_tsne = g[tsne_idx]
        x_pos = x_for_tsne[g_tsne > t]
        x_tsne_pos = x_tsne[g_tsne > t]
        plot_tsne_selection_grid(x_tsne_pos,
                                 x_pos,
                                 x_tsne,
                                 vmin,
                                 vmax,
                                 fig_path=fig_path,
                                 labels=labels,
                                 fig_size=(20, 20),
                                 s=5,
                                 suffix='png')

        if clustering is None:
            suffix = 'filter_%d' % i_filter
            plot_selected_subset(x1, z1, x, labels, sample_sizes, phenotypes,
                                 outdir, suffix, stat_test, log_yscale,
                                 group_a, group_b, group_names, regression)
            continue

        if clustering == 'louvain':
            print('Creating a k-NN graph with %d/%d cells...' % (
                x1.shape[0], x.shape[0]))
            k = 10
            G = create_graph(x1, k, g1, add_filter_response)
            print('Identifying cell communities...')
            cl = G.community_fastgreedy()
            clusters = np.array(cl.as_clustering().membership)
        else:
            print('Clustering using the dbscan algorithm...')
            eps = set_dbscan_eps(x1,
                                 os.path.join(outdir, 'kNN_distances.png'))
            cl = DBSCAN(eps=eps, min_samples=5, metric='l1')
            clusters = cl.fit_predict(x1)

        # discard outliers: DBSCAN noise (label -1) and clusters with very few cells
        c = Counter(clusters)
        min_cells = int(min_cluster_freq * x1.shape[0])
        cluster_ids = [key for key, val in c.items()
                       if key != -1 and val > min_cells]
        if not cluster_ids:
            # Without this guard, `sorted_idx[0]` below raises IndexError.
            print('No cell cluster passed the size threshold for filter %d.'
                  % i_filter)
            continue

        # mean cell filter response of each remaining cluster
        scores = np.array([np.mean(g1[clusters == cl_id])
                           for cl_id in cluster_ids])

        # keep the communities with high cell filter response
        sorted_idx = np.argsort(scores)[::-1]
        scores = scores[sorted_idx]
        keep_idx_comm = [sorted_idx[0]]
        for rank in range(1, len(cluster_ids)):
            if (scores[rank - 1] -
                    scores[rank]) < percentage_drop_cluster * scores[rank - 1]:
                keep_idx_comm.append(sorted_idx[rank])
            else:
                break

        for j in keep_idx_comm:
            cl_id = cluster_ids[j]
            xc = x1[clusters == cl_id]
            zc = z1[clusters == cl_id]
            suffix = 'filter_%d_cluster_%d' % (i_filter, cl_id)
            plot_selected_subset(xc, zc, x, labels, sample_sizes,
                                 phenotypes, outdir, suffix, stat_test,
                                 log_yscale, group_a, group_b, group_names,
                                 regression)
    print('Done.\n')
    return return_filters
Ejemplo n.º 13
0
def plot_results_2class(results,
                        samples,
                        phenotypes,
                        labels,
                        outdir,
                        percentage_drop_filter=.2,
                        filter_response_thres=0,
                        response_grad_cutoff=None,
                        group_a='group a',
                        group_b='group b',
                        stat_test=None,
                        positive_filters_only=False,
                        log_yscale=False,
                        clustering=None,
                        add_filter_response=False,
                        percentage_drop_cluster=.1,
                        min_cluster_freq=0.2,
                        plot_tsne=False,
                        tsne_ncell=3000):
    """ Plots the results of a CellCnn analysis for a 2-class classification problem.

    Args:
        - results :
            Dictionary containing the results of a CellCnn analysis.
        - samples :
            Samples from which to visualize the selected cell populations.
        - phenotypes :
            List of phenotypes corresponding to the provided `samples`.
        - labels :
            Names of measured markers.
        - outdir :
            Output directory where the generated plots will be stored.
        - percentage_drop_filter :
            Threshold that defines which filters are most discriminative. Given an array ``diff``
            of cell filter response differences sorted in decreasing order, keep a filter
            ``i, i > 0`` if it holds that
            ``diff[i-1] - diff[i] < percentage_drop_filter * diff[i-1]``.
            The top filter (``i = 0``) is always kept. If ``results`` has no ``'dist'`` entry
            (no validation set), ``diff`` holds the absolute output weights of the filters.
        - filter_response_thres :
            Threshold for choosing a responding cell population. Default is 0.
        - response_grad_cutoff :
            Threshold on the gradient of the cell filter response CDF, might be useful for defining
            the selected cell population.
        - group_a :
            Name of the first class.
        - group_b :
            Name of the second class.
        - stat_test: None | 'ttest' | 'mannwhitneyu'
            Optionally, perform a statistical test on selected cell population frequencies between
            the two groups and report the corresponding p-value on the boxplot figure
            (see plots description below). Default is None.
        - positive_filters_only :
            If True, only consider filters associated with higher cell population frequency in the
            positive class.
        - log_yscale :
            If True, display the y-axis of the boxplot figure (see plots description below) in
            logarithmic scale.
        - clustering: None | 'dbscan' | 'louvain'
            Post-processing option for selected cell populations. Default is None.
            Any value other than None and 'louvain' falls back to dbscan.
        - add_filter_response :
            Forwarded to ``create_graph`` when ``clustering='louvain'`` (presumably adds the
            cell filter response to the k-NN graph construction; see ``create_graph``).
        - percentage_drop_cluster :
            Threshold that defines which cell clusters are kept. Given the mean cell filter
            response ``scores`` of the clusters sorted in decreasing order, keep cluster
            ``i, i > 0`` if ``scores[i-1] - scores[i] < percentage_drop_cluster * scores[i-1]``.
        - min_cluster_freq :
            Clusters containing at most ``int(min_cluster_freq * n)`` of the ``n`` selected
            cells are discarded as outliers, as are DBSCAN noise points (label -1).
        - plot_tsne :
            If True, compute a 2D t-SNE map of the first ``tsne_ncell`` cells (once) and,
            for each selected filter, save the cell filter response overlaid on it.
        - tsne_ncell :
            Number of cells (the first ones after stacking ``samples``) used for the t-SNE map.

    Returns:
        A list with the indices and corresponding cell filter response thresholds of selected
        discriminative filters. \
        This function also produces a collection of plots for model interpretation.
        These plots are stored in `outdir`. They comprise the following:

        - clustered_filter_weights.pdf :
            Filter weight vectors from all trained networks that pass a validation accuracy
            threshold, grouped in clusters via hierarchical clustering. Each row corresponds to
            a filter. The last column indicates the weight connecting each filter to the output
            positive class. Indices on the y-axis indicate the filter cluster memberships, as a
            result of the hierarchical clustering procedure.
        - consensus_filter_weights.pdf :
            One representative filter per cluster is chosen (the filter with minimum distance to all
            other members of the cluster). We call these selected filters "consensus filters".
        - best_net_weights.pdf :
            Filter weight vectors of the network that achieved the highest validation accuracy.
        - filter_response_differences.pdf :
            Difference in cell filter response between the two classes for each consensus filter.
            This plot helps decide on a cutoff (``percentage_drop_filter`` parameter) for selecting
            discriminative filters.

        In addition, the following plots are produced for each selected filter (e.g. filter ``i``):

        - cdf_filter_i.pdf :
            Cumulative distribution function of cell filter response for filter ``i``. This plot
            helps decide on a cutoff (``filter_response_thres`` parameter) for selecting the
            responding cell population.

        - selected_population_distribution_filter_i.pdf :
            Histograms of univariate marker expression profiles for the cell population selected by
            filter ``i`` vs all cells.

        - selected_population_boxplot_filter_i.pdf :
            Boxplot of selected cell population frequencies in samples of the two classes.

        - cell_filter_response_i.png (only if ``plot_tsne``) :
            Min-max scaled cell filter response overlaid on the t-SNE map.
    """

    # create the output directory
    mkdir_p(outdir)

    # number of measured markers
    nmark = samples[0].shape[1]
    # marker weight columns plus the output weight (last column); the bias
    # column at index `nmark` is left out of the weight plots
    weight_cols = list(range(nmark)) + [-1]

    # plot the filter weights of the best network
    w_best = results['w_best_net'][:, weight_cols]
    fig_path = os.path.join(outdir, 'best_net_weights.pdf')
    plot_nn_weights(w_best, labels + ['output'], fig_path, fig_size=(10, 10))

    # plot the selected filters
    if results['selected_filters'] is not None:
        print('Loading the weights of consensus filters.')
        w = results['selected_filters'][:, weight_cols]
        fig_path = os.path.join(outdir, 'consensus_filter_weights.pdf')
        plot_nn_weights(w, labels + ['output'], fig_path, fig_size=(10, 10))
        filters = results['selected_filters']
    else:
        print('Consensus filters were not found, using the weights of the best network instead.')
        filters = results['w_best_net']

    # plot the filter clustering
    cl = results['clustering_result']
    cl_w = cl['w'][:, weight_cols]
    fig_path = os.path.join(outdir, 'clustered_filter_weights.pdf')
    plot_nn_weights(cl_w,
                    labels + ['output'],
                    fig_path,
                    row_linkage=cl['cluster_linkage'],
                    y_labels=cl['cluster_assignments'],
                    fig_size=(10, 10))

    # select the discriminative filters based on the validation set
    if 'dist' in results:
        dist = np.max(results['dist'], axis=1)
    # if no validation set was provided,
    # select filters based on the magnitude of their output weight
    else:
        dist = abs(filters[:, -1])
    # do we want to consider negative filters?
    if positive_filters_only:
        dist = dist * np.sign(filters[:, -1])
    sorted_idx = np.argsort(dist)[::-1]
    dist = dist[sorted_idx]
    keep_idx = [sorted_idx[0]]
    for i in range(1, dist.shape[0]):
        if (dist[i - 1] - dist[i]) < percentage_drop_filter * dist[i - 1]:
            keep_idx.append(sorted_idx[i])
        else:
            break
    plt.figure()
    sns.set_style('whitegrid')
    plt.plot(range(len(dist)), dist, '--')
    plt.xticks(range(len(dist)), ['filter %d' % i for i in sorted_idx],
               rotation='vertical')
    sns.despine()
    plt.savefig(os.path.join(outdir, 'filter_response_differences.pdf'),
                format='pdf')
    plt.clf()
    plt.close()

    # encode the sample of origin of each cell (z[k] = index of the sample of cell k)
    sample_sizes = []
    per_cell_ids = []
    for i, x in enumerate(samples):
        sample_sizes.append(x.shape[0])
        per_cell_ids.append(i * np.ones(x.shape[0]))
    x = np.vstack(samples)
    z = np.hstack(per_cell_ids)

    if results['scaler'] is not None:
        x = results['scaler'].transform(x)

    return_filters = []
    # The t-SNE map depends only on x[:tsne_ncell] and the fixed random_state,
    # so it is computed lazily once rather than once per selected filter.
    x_2D = None
    for i_filter in keep_idx:
        # cell filter response: marker weights + bias (column `nmark`), then ReLU
        w = filters[i_filter, :nmark]
        b = filters[i_filter, nmark]
        g = np.sum(w.reshape(1, -1) * x, axis=1) + b
        g = g * (g > 0)

        ecdf = sm.distributions.ECDF(g)
        gx = np.linspace(np.min(g), np.max(g))
        gy = ecdf(gx)
        plt.figure()
        sns.set_style('whitegrid')
        a = plt.step(gx, gy)
        t = filter_response_thres
        # set a threshold to the CDF gradient?
        if response_grad_cutoff is not None:
            # walk the step curve from high to low response and cut at the
            # first jump in the CDF that reaches `response_grad_cutoff`
            by = np.array(a[0].get_ydata())[::-1]
            bx = np.array(a[0].get_xdata())[::-1]
            b_diff_idx = np.where(by[:-1] - by[1:] >= response_grad_cutoff)[0]
            if len(b_diff_idx) > 0:
                t = bx[b_diff_idx[0] + 1]
        plt.plot((t, t), (np.min(gy), 1.), 'r--')
        plt.xlabel('Cell filter response')
        plt.ylabel('Cumulative distribution function (CDF)')
        sns.despine()
        plt.savefig(os.path.join(outdir, 'cdf_filter_%d.pdf' % i_filter),
                    format='pdf')
        plt.clf()
        plt.close()

        condition = g > t
        x1 = x[condition]
        z1 = z[condition]
        g1 = g[condition]

        # skip a filter if it does not select any cell
        if x1.shape[0] == 0:
            continue

        return_filters.append((i_filter, t))
        # plot a cell filter response map for the filter
        # do it on a subset of the cells, so that it is relatively fast
        if plot_tsne:
            if x_2D is None:
                proj = TSNE(n_components=2, random_state=0)
                x_2D = proj.fit_transform(x[:tsne_ncell])
            fig_path = os.path.join(
                outdir, 'cell_filter_response_%d.png' % i_filter)
            plot_2D_map(
                x_2D,
                MinMaxScaler().fit_transform(g.reshape(-1, 1))[:tsne_ncell],
                fig_path)

        # NOTE(review): positional order (stat_test, group_a, group_b, log_yscale)
        # must match the plot_selected_subset signature in use; confirm.
        if clustering is None:
            suffix = 'filter_%d' % i_filter
            plot_selected_subset(x1, z1, x, labels, sample_sizes, phenotypes,
                                 outdir, suffix, stat_test, group_a, group_b,
                                 log_yscale)
            continue

        if clustering == 'louvain':
            print('Creating a k-NN graph with %d/%d cells...' % (
                x1.shape[0], x.shape[0]))
            k = 10
            G = create_graph(x1, k, g1, add_filter_response)
            print('Identifying cell communities...')
            cl = G.community_fastgreedy()
            clusters = np.array(cl.as_clustering().membership)
        else:
            print('Clustering using the dbscan algorithm...')
            eps = set_dbscan_eps(x1,
                                 os.path.join(outdir, 'kNN_distances.png'))
            cl = DBSCAN(eps=eps, min_samples=5, metric='l1')
            clusters = cl.fit_predict(x1)

        # discard outliers: DBSCAN noise (label -1) and clusters with very few cells
        c = Counter(clusters)
        min_cells = int(min_cluster_freq * x1.shape[0])
        cluster_ids = [key for key, val in c.items()
                       if key != -1 and val > min_cells]
        if not cluster_ids:
            # Without this guard, `sorted_idx[0]` below raises IndexError.
            print('No cell cluster passed the size threshold for filter %d.'
                  % i_filter)
            continue

        # mean cell filter response of each remaining cluster
        scores = np.array([np.mean(g1[clusters == cl_id])
                           for cl_id in cluster_ids])

        # keep the communities with high cell filter response
        sorted_idx = np.argsort(scores)[::-1]
        scores = scores[sorted_idx]
        keep_idx_comm = [sorted_idx[0]]
        for rank in range(1, len(cluster_ids)):
            if (scores[rank - 1] -
                    scores[rank]) < percentage_drop_cluster * scores[rank - 1]:
                keep_idx_comm.append(sorted_idx[rank])
            else:
                break

        for j in keep_idx_comm:
            cl_id = cluster_ids[j]
            xc = x1[clusters == cl_id]
            zc = z1[clusters == cl_id]
            suffix = 'filter_%d_cluster_%d' % (i_filter, cl_id)
            plot_selected_subset(xc, zc, x, labels, sample_sizes,
                                 phenotypes, outdir, suffix, stat_test,
                                 group_a, group_b, log_yscale)

    # `zip(*return_filters)[0]` raised IndexError when nothing was selected
    selected_ids = tuple(idx for idx, _ in return_filters)
    print('Found %d discriminative filter(s): %s' % (len(return_filters),
                                                     selected_ids))
    return return_filters
Ejemplo n.º 14
0
def main(fcs_data_path='/Volumes/biol_imsb_claassen_s1/eiriniar/Data/phenograph_data'):
    """Preprocess healthy and AML bone-marrow samples and pickle them as ``AML.pkl``.

    The pickled dictionary contains:
        - ``'healthy_BM'``: list of ``(donor_id, data)`` pairs for donors
          H1, H2, H3, H5, H4 (in that order), as returned by ``read_healthy_data``.
        - one entry per AML patient id (e.g. ``'SJ10'``) holding the
          ``ftrans``-transformed (cofactor 5) blast data restricted to ``labels``.
        - ``'labels'``: names of the selected markers.

    Args:
        - fcs_data_path :
            Data directory. It is passed to ``read_healthy_data`` and must contain an
            ``AML_blasts`` sub-directory with the gated blast populations exported as
            tab-separated .txt files. Defaults to the original hard-coded location.

    Returns:
        0 on completion.
    """

    # stimulation conditions in this experiment
    STIM = [
        'Basal1', 'Basal2', 'AICAR', 'Flt3L', 'G-CSF', 'GM-CSF', 'IFNa',
        'IFNg', 'IL-10', 'IL-27', 'IL-3', 'IL-6', 'PMAiono', 'PVO4', 'SCF',
        'TNFa', 'TPO'
    ]
    full_stim_names = ['_'.join(['NoDrug', stim])
                       for stim in STIM] + ['BEZ-235_Basal1']

    # all available channels in this experiment
    channels = [
        'Time', 'Cell_length', 'DNA1', 'DNA2', 'BC1', 'BC2', 'BC3', 'BC4',
        'BC5', 'BC6', 'pPLCg2', 'CD19', 'p4EBP1', 'CD11b', 'pAMPK', 'pSTAT3',
        'CD34', 'pSTAT5', 'pS6', 'pCREB', 'pc-Cbl', 'CD45', 'CD123', 'pSTAT1',
        'pZap70-Syk', 'CD33', 'CD47', 'pAKT', 'CD7', 'CD15', 'pRb', 'CD44',
        'CD38', 'pErk1-2', 'CD3', 'pP38', 'CD117', 'cCaspase3', 'HLA-DR',
        'CD64', 'CD41', 'Viability', 'PhenoGraph'
    ]

    # which markers should be kept for further analysis
    labels = [
        'CD19', 'CD11b', 'CD34', 'CD45', 'CD123', 'CD33', 'CD47', 'CD7',
        'CD15', 'CD44', 'CD38', 'CD3', 'CD117', 'HLA-DR', 'CD64', 'CD41'
    ]

    # which columns correspond to the interesting markers
    marker_idx = [channels.index(label) for label in labels]

    # read the data from healthy samples
    healthy_keys = ['H' + str(i) for i in range(1, 6)]
    D = read_healthy_data(fcs_data_path, healthy_keys, full_stim_names,
                          marker_idx)
    # NOTE(review): H4 is deliberately placed last, presumably so consumers
    # can treat healthy_BM[-1] as a held-out donor; confirm before reordering.
    aml_dict = {
        'healthy_BM': [(key, D[key]) for key in ['H1', 'H2', 'H3', 'H5', 'H4']]
    }

    # map .txt files (by position in the sorted file listing) back to patient
    # identifiers; only patients with sufficiently high blast counts
    # (>10% of total cell counts) are listed here
    mapping = {
        0: 'SJ10',
        2: 'SJ12',
        3: 'SJ13',
        4: 'SJ14',
        5: 'SJ15',
        6: 'SJ16',
        8: 'SJ1',
        9: 'SJ1',
        10: 'SJ2',
        11: 'SJ2',
        12: 'SJ3',
        13: 'SJ3',
        14: 'SJ4',
        15: 'SJ5',
        17: 'SJ7'
    }

    # read the data from AML samples
    # gated blast populations were downloaded as .txt files from Cytobank
    # CAREFUL when reading these .txt files:
    # they are tab-separated and include an extra first column (cell index)
    # glob.glob returns files in filesystem-dependent order, but `mapping` is
    # keyed by list position, so the listing must be sorted to be reproducible.
    # NOTE(review): assumes `mapping` was built from an alphabetical listing.
    AML_files = sorted(
        glob.glob(os.path.join(fcs_data_path, 'AML_blasts', '*.txt')))

    for sj in sorted(mapping):
        fname = AML_files[sj]
        t = pd.read_csv(fname, skiprows=1, sep='\t', index_col=0)
        # show the column names picked by marker_idx, as a sanity check
        print([list(t.columns)[ii] for ii in marker_idx])
        data_blasts = ftrans(np.asarray(t)[:, marker_idx], 5)
        # patients with two files (SJ1, SJ2, SJ3) keep only the first one
        if mapping[sj] not in aml_dict:
            aml_dict[mapping[sj]] = data_blasts

    # save the pre-processed dataset
    pickle_dir = os.path.join(cellCnn.__path__[0], 'examples', 'data')
    mkdir_p(pickle_dir)
    pickle_file = os.path.join(pickle_dir, 'AML.pkl')
    aml_dict['labels'] = labels
    with open(pickle_file, 'wb') as f:
        pickle.dump(aml_dict, f, -1)

    return 0
Ejemplo n.º 15
0
from cellCnn.run_CellCnn import train_model
from cellCnn.plotting import visualize_results
from numpy.random import RandomState
from lasagne.random import set_rng as set_lasagne_rng


# AML.pkl can be downloaded from
# http://www.imsb.ethz.ch/research/claassen/Software/cellcnn.html

# Every path of this example is resolved relative to the installed cellCnn
# package: the input pickle under examples/data, outputs under examples/output.
WDIR = os.path.join(cellCnn.__path__[0], 'examples')
LOOKUP_PATH, OUTDIR = (os.path.join(WDIR, 'data', 'AML.pkl'),
                       os.path.join(WDIR, 'output', 'AML'))

# Side effect at import time: ensure the output directory exists.
mkdir_p(OUTDIR)

def main():
    
    # set random seed for reproducible results
    seed = 12345
    np.random.seed(seed)
    set_lasagne_rng(RandomState(seed))

    lookup =  pickle.load(open(LOOKUP_PATH, 'rb'))
    labels = lookup['labels']
    healthy_BM = lookup['healthy_BM']
    control_list = [x for (key, x) in healthy_BM[:-1]]
    x_healthy = healthy_BM[-1][1]
    
    # how many AML blast cells to spike-in
Ejemplo n.º 16
0
def _build_parser():
    """Return the argparse parser describing every CellCnn command-line option."""
    parser = argparse.ArgumentParser()
    # IO-specific
    parser.add_argument('-f', '--fcs', required=True,
                        help='file specifying the FCS file names and corresponding labels')
    parser.add_argument('-m', '--markers', required=True,
                        help='file specifying the names of markers to be used for analysis')
    parser.add_argument('-i', '--indir', default='./',
                        help='directory where input FCS files are located')
    parser.add_argument('-o', '--outdir', default='output',
                        help='directory where output will be generated')
    # Plotting is on by default, so `--plot` alone cannot change anything;
    # `--no_plot` is the switch that turns it off (`--plot` kept for compatibility).
    parser.add_argument('-p', '--plot', dest='plot', action='store_true',
                        help='whether to plot results ')
    parser.add_argument('--no_plot', dest='plot', action='store_false',
                        help='do not plot results')
    parser.set_defaults(plot=True)
    parser.add_argument('--export_selected_cells', action='store_true', default=False,
                        help='whether to export selected cell populations')
    parser.add_argument('--export_csv', action='store_true', default=False,
                        help='whether to export network weights as csv files')
    parser.add_argument('-l', '--load_results', action='store_true', default=False,
                        help='whether to load precomputed results')

    # data preprocessing
    parser.add_argument('--train_perc', type=float, default=0.75,
                        help='percentage of samples to be used for training')
    parser.add_argument('--arcsinh', dest='arcsinh', action='store_true',
                        help='preprocess the data with arcsinh')
    parser.add_argument('--no_arcsinh', dest='arcsinh', action='store_false',
                        help='do not preprocess the data with arcsinh')
    parser.set_defaults(arcsinh=True)
    parser.add_argument('--cofactor', type=int, default=5,
                        help='cofactor for the arcsinh transform')
    parser.add_argument('--scale', dest='scale', action='store_true',
                        help='z-transform features (mean=0, std=1) prior to training')
    parser.add_argument('--no_scale', dest='scale', action='store_false',
                        help='do not z-transform features (mean=0, std=1) prior to training')
    parser.set_defaults(scale=True)
    parser.add_argument('--quant_normed', action='store_true', default=False,
                        help='input data has been pre-processed via quantile normalization')

    # multi-cell input specific
    parser.add_argument('--ncell', type=int, help='number of cells per multi-cell input',
                        default=200)
    parser.add_argument('--nsubset', type=int, help='number of multi-cell inputs',
                        default=1000)
    parser.add_argument('--per_sample', action='store_true', default=False,
                        help='whether nsubset refers to each class or each sample')
    parser.add_argument('--subset_selection', choices=['random', 'outlier'], default='random',
                        help='generate random or outlier-enriched multi-cell inputs')

    # neural network specific
    parser.add_argument('--maxpool_percentages', nargs='+', type=float,
                        help='list of choices (percentage of multi-cell input) for top-k max pooling',
                        default=[0.01, 1, 5, 20, 100])
    # a list (not a lazy range) so the default has the same type as user input
    parser.add_argument('--nfilter_choice', nargs='+', type=int,
                        help='list of choices for number of filters', default=list(range(3, 10)))
    parser.add_argument('--learning_rate', type=float, default=0.005,
                        help='learning rate for the Adam optimization algorithm')
    parser.add_argument('--coeff_l1', type=float, default=0,
                        help='coefficient for L1 weight regularization')
    parser.add_argument('--coeff_l2', type=float, default=0.0001,
                        help='coefficient for L2 weight regularization')
    parser.add_argument('--coeff_activity', type=float, default=0,
                        help='coefficient for regularizing the activity at each filter')
    parser.add_argument('--max_epochs', type=int, default=20,
                        help='maximum number of iterations through the data')
    parser.add_argument('--patience', type=int, default=5,
                        help='number of epochs before early stopping')

    # analysis specific
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    parser.add_argument('--nrun', type=int, default=15,
                        help='number of neural network configurations to try (should be >= 3)')
    parser.add_argument('--regression', action='store_true', default=False,
                        help='whether it is a regression problem (default is classification)')
    parser.add_argument('--dendrogram_cutoff', type=float, default=.4,
                        help='cutoff for hierarchical clustering of filter weights')
    parser.add_argument('--accur_thres', type=float, default=.9,
                        help='keep filters from models achieving at least this accuracy ' \
                             ' (or at least from the best 3 models)')
    parser.add_argument('-v', '--verbose', type=int, choices=[0, 1], default=1,
                        help='output verbosity')

    # plot specific
    parser.add_argument('--filter_diff_thres', type=float, default=0.2,
                        help='threshold that defines which filters are discriminative')
    parser.add_argument('--filter_response_thres', type=float, default=0,
                        help='threshold that defines the selected cell population per filter')
    parser.add_argument('--positive_filters_only', action='store_true', default=False,
                        help='whether to only consider filters associated with higher cell ' \
                             'population frequencies in the positive class')
    parser.add_argument('--stat_test', choices=[None, 'ttest', 'mannwhitneyu'],
                        help='statistical test for comparing cell population frequencies of two ' \
                             'groups of samples')
    parser.add_argument('--group_a', default='group A',
                        help='name of the first class')
    parser.add_argument('--group_b', default='group B',
                        help='name of the second class')
    return parser


def _export_selected_cells(results, samples, fcs_info, filter_info, outdir):
    """Write one CSV per sample with per-filter cell responses to outdir/selected_cells.

    For every ``(filter_idx, thres)`` in ``filter_info`` two columns are written,
    ``filter_<idx>_continuous`` and ``filter_<idx>_binary``, filled from
    ``get_selected_cells(..., True)`` (presumably the raw response and its
    thresholded selection). File names are the FCS names (first column of
    ``fcs_info``) with ``.fcs`` stripped.
    """
    csv_dir = os.path.join(outdir, 'selected_cells')
    mkdir_p(csv_dir)
    nfilter = len(filter_info)
    sample_names = [name.split('.fcs')[0] for name in list(fcs_info[:, 0])]
    # for each sample
    for x, x_name in zip(samples, sample_names):
        flags = np.zeros((x.shape[0], 2*nfilter))
        columns = []
        # for each filter
        for i, (filter_idx, thres) in enumerate(filter_info):
            flags[:, 2*i:2*(i+1)] = get_selected_cells(
                results['selected_filters'][filter_idx], x, results['scaler'], thres, True)
            columns += ['filter_%d_continuous' % filter_idx, 'filter_%d_binary' % filter_idx]
        df = pd.DataFrame(flags, columns=columns)
        df.to_csv(os.path.join(csv_dir, x_name+'_selected_cells.csv'), index=False)


def main():
    """Command-line entry point for a CellCnn analysis.

    Reads the FCS sample table (``--fcs``) and marker list (``--markers``), then
    either trains CellCnn on a stratified train/validation split and saves the
    results to ``<outdir>/results.pkl``, or loads that file (``--load_results``).
    Afterwards it optionally exports network weights as CSV, plots the results
    and exports the per-cell filter responses of each sample.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate the split up front (before the potentially slow data loading).
    # The validation fraction determines the number of stratified folds, of
    # which only the first split is used.
    if not args.load_results:
        if not 0. < args.train_perc < 1.:
            parser.error('--train_perc must be strictly between 0 and 1')
        val_perc = 1 - args.train_perc
        # the epsilon guards against float truncation, e.g. train_perc=0.95
        # gives 1/val_perc = 19.99999999999998, which int() would turn into 19
        n_splits = int(1. / val_perc + 1e-9)
        if n_splits < 2:
            parser.error('--train_perc must be at least 0.5 so that the data '
                         'can be split into at least 2 stratified folds')

    # read in the data
    fcs_info = np.array(pd.read_csv(args.fcs, sep=','))
    marker_names = list(pd.read_csv(args.markers, sep=',').columns)
    # if the samples have already been pre-processed via quantile normalization
    # we should not perform arcsinh transformation
    if args.quant_normed:
        args.arcsinh = False
    samples, phenotypes = get_data(args.indir, fcs_info, marker_names,
                                   args.arcsinh, args.cofactor)

    results_path = os.path.join(args.outdir, 'results.pkl')
    if not args.load_results:
        # generate training/validation sets
        np.random.seed(args.seed)
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True)
        train, val = next(skf.split(np.zeros((len(phenotypes), 1)), phenotypes))
        train_samples = [samples[i] for i in train]
        valid_samples = [samples[i] for i in val]
        train_phenotypes = [phenotypes[i] for i in train]
        valid_phenotypes = [phenotypes[i] for i in val]

        # run CellCnn
        model = CellCnn(ncell=args.ncell,
                        nsubset=args.nsubset,
                        per_sample=args.per_sample,
                        subset_selection=args.subset_selection,
                        scale=args.scale,
                        quant_normed=args.quant_normed,
                        maxpool_percentages=args.maxpool_percentages,
                        nfilter_choice=args.nfilter_choice,
                        nrun=args.nrun,
                        regression=args.regression,
                        learning_rate=args.learning_rate,
                        coeff_l1=args.coeff_l1,
                        coeff_l2=args.coeff_l2,
                        coeff_activity=args.coeff_activity,
                        max_epochs=args.max_epochs,
                        patience=args.patience,
                        dendrogram_cutoff=args.dendrogram_cutoff,
                        accur_thres=args.accur_thres,
                        verbose=args.verbose)
        model.fit(train_samples=train_samples, train_phenotypes=train_phenotypes,
                  valid_samples=valid_samples, valid_phenotypes=valid_phenotypes,
                  outdir=args.outdir)
        # save results for subsequent analysis; pickle produces bytes, so the
        # file must be opened in binary mode (text mode fails on Python 3)
        results = model.results
        mkdir_p(args.outdir)
        with open(results_path, 'wb') as f:
            pickle.dump(results, f)
    else:
        # NOTE: pickle.load can execute arbitrary code; only load results
        # files produced by a previous run of this script.
        with open(results_path, 'rb') as f:
            results = pickle.load(f)

    if args.export_csv:
        save_results(results, args.outdir, marker_names)

    # plot results (plot_results also yields filter_info, which the
    # selected-cell export needs)
    if args.plot or args.export_selected_cells:
        plot_dir = os.path.join(args.outdir, 'plots')
        mkdir_p(plot_dir)
        filter_info = plot_results(results, samples, phenotypes,
                                   marker_names, plot_dir,
                                   filter_diff_thres=args.filter_diff_thres,
                                   filter_response_thres=args.filter_response_thres,
                                   positive_filters_only=args.positive_filters_only,
                                   stat_test=args.stat_test,
                                   group_a=args.group_a, group_b=args.group_b)
        if args.export_selected_cells:
            _export_selected_cells(results, samples, fcs_info, filter_info, args.outdir)
Ejemplo n.º 17
0
import cellCnn
from cellCnn.utils import mkdir_p
from cellCnn.run_CellCnn import train_model
from cellCnn.plotting import visualize_results
from numpy.random import RandomState
from lasagne.random import set_rng as set_lasagne_rng
''' 
    ALL.pkl can be downloaded from 
    http://www.imsb.ethz.ch/research/claassen/Software/cellcnn.html
'''

# Root of the examples directory shipped inside the installed cellCnn package.
WDIR = os.path.join(cellCnn.__path__[0], 'examples')
# Pre-processed ALL dataset (pickle) that main() loads; see the download note above.
LOOKUP_PATH = os.path.join(WDIR, 'data', 'ALL.pkl')

# Output directory for this example. mkdir_p runs at import time and
# presumably creates the directory if it is missing (like `mkdir -p`).
OUTDIR = os.path.join(WDIR, 'output', 'ALL')
mkdir_p(OUTDIR)


def main():
    """Set up the ALL example: seed the RNGs and load the lookup data.

    NOTE(review): this copy of the example appears truncated -- it ends right
    after loading the data, so ``labels``, ``x_control``, ``x_healthy`` and
    ``x_ALL`` are not used here.
    """

    # set random seed for reproducible results
    seed = 12345
    np.random.seed(seed)
    # lasagne keeps its own RNG, separate from numpy's global one; seed it too
    set_lasagne_rng(RandomState(seed))

    # Use a context manager so the file handle is closed after loading.
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load an ALL.pkl obtained from a trusted source.
    with open(LOOKUP_PATH, 'rb') as f:
        lookup = pickle.load(f)
    labels = lookup['labels']
    x_control = lookup['control']
    x_healthy = lookup['healthy']
    # NOTE(review): assumes `shuffle` returns the shuffled data (e.g.
    # sklearn.utils.shuffle); random.shuffle would return None here.
    x_ALL = shuffle(lookup['ALL'])
Ejemplo n.º 18
0
def main(args):
    """Train CellCnn on a training cohort and evaluate it on a test cohort.

    ``args`` (presumably an ``argparse.Namespace``) must provide ``data_root``,
    ``data_name``, ``ncell``, ``nsubset``, ``max_epochs``, ``seed`` and ``pkl``.
    When ``args.pkl`` is set, samples are read from ``train_HIV.pkl`` and
    ``test_HIV.pkl`` under ``data_root``. Otherwise they are read via
    ``load_fcs_dataset`` from ``train/train_labels.csv``,
    ``test/test_labels.csv`` and ``marker.csv``.

    Test predictions, accuracy and ROC AUC are printed. Accuracy, AUC and the
    ROC curve (fpr/tpr) are pickled to ``test_result.pkl`` in
    ``./output_<data_name>_<ncell>_<seed>``.
    """
    data_root = args.data_root
    run_dir = os.path.join(
        '.', f'output_{args.data_name}_{args.ncell}_{args.seed}')
    mkdir_p(run_dir)

    # cofactor handed to load_fcs_dataset (presumably for an arcsinh transform)
    cofactor = 5.0
    # seed numpy's global RNG for reproducible results
    np.random.seed(args.seed)

    def read_split(pkl_name):
        """Return the (samples, phenotypes) pair stored in data_root/pkl_name."""
        # NOTE(review): pickle.load can run arbitrary code; use trusted files only.
        with open(os.path.join(data_root, pkl_name), 'rb') as handle:
            payload = pickle.load(handle)
        return payload['sample'], payload['phenotype']

    if args.pkl:
        train_samples, train_phenotypes = read_split('train_HIV.pkl')
        test_samples, test_phenotypes = read_split('test_HIV.pkl')
    else:
        marker_csv = os.path.join(data_root, 'marker.csv')
        train_samples, train_phenotypes = load_fcs_dataset(
            os.path.join(data_root, 'train', 'train_labels.csv'),
            marker_csv, cofactor)
        test_samples, test_phenotypes = load_fcs_dataset(
            os.path.join(data_root, 'test', 'test_labels.csv'),
            marker_csv, cofactor)

    print("data io finished")

    # a small CellCnn run (3 network configurations, silent)
    model = CellCnn(ncell=args.ncell, nsubset=args.nsubset,
                    max_epochs=args.max_epochs, nrun=3, verbose=0)
    model.fit(train_samples=train_samples,
              train_phenotypes=train_phenotypes,
              outdir=run_dir)

    # column 1 of the prediction matrix is treated as the positive-class
    # score; hard labels come from a 0.5 threshold
    probabilities = model.predict(test_samples)
    positive_scores = probabilities[:, 1]
    predicted = [1 if score > 0.5 else 0 for score in positive_scores]

    print('\nPred phenotypes:\n', predicted)
    print('\nTrue phenotypes:\n', test_phenotypes)

    # accuracy of the thresholded labels, then ROC curve and AUC of the scores
    hits = np.array(predicted) == np.array(test_phenotypes)
    test_acc = sum(hits) / len(test_phenotypes)
    fpr, tpr, _ = roc_curve(test_phenotypes, positive_scores, pos_label=1)
    test_auc = roc_auc_score(test_phenotypes, positive_scores)
    print("test acc of cellcnn: ", test_acc)
    print("test auc of cellcnn: ", test_auc)

    summary = {'test_acc': test_acc, 'test_auc': test_auc,
               'fpr': fpr, 'tpr': tpr}
    with open(os.path.join(run_dir, 'test_result.pkl'), 'wb') as f:
        pickle.dump(summary, f)
Ejemplo n.º 19
0
	AML_files = glob.glob(os.path.join(FCS_DATA_PATH, 'AML_blasts', '*.txt'))
   
	# only include patients with sufficiently high blast counts 
	# (>10% of total cell counts)
	for sj in [0,2,3,4,5,6,8,9,10,11,12,13,14,15,17]:
		fname = AML_files[sj]
		t = pd.read_csv(fname, skiprows=1, sep='\t', index_col=0)
		print [list(t.columns)[ii] for ii in marker_idx]
		data_blasts = ftrans(np.asarray(t)[:, marker_idx], 5)
		if mapping[sj] not in aml_dict:
			aml_dict[mapping[sj]] = data_blasts

	
	# save the pre-processed dataset
	pickle_dir = os.path.join(cellCnn.__path__[0], 'examples', 'data')
	mkdir_p(pickle_dir)
	pickle_file = os.path.join(pickle_dir, 'AML.pkl')
	aml_dict['labels'] = labels
	with open(pickle_file, 'wb') as f:
		pickle.dump(aml_dict, f, -1)

	return 0


if __name__ == '__main__':
	try:
		# propagate main()'s return value as the process exit status
		# instead of silently discarding it
		sys.exit(main())
	except KeyboardInterrupt:
		# Ctrl-C: report on stderr and exit non-zero (-1 is seen as 255 on POSIX)
		sys.stderr.write("User interrupt!\n")
		sys.exit(-1)