Example #1
def subset_array(input_pkl, cpg_pkl, output_pkl):
    """Only retain certain number of CpGs from methylation array."""
    import os, pickle
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    cpgs = pickle.load(open(cpg_pkl, 'rb'))
    MethylationArray.from_pickle(input_pkl).subset_cpgs(cpgs).write_pickle(
        output_pkl)
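A minimal usage sketch, assuming MethylationArray comes from pymethylprocess.MethylationDataTypes; the paths below are hypothetical:

# The CpG pickle is a pickled list/array of probe IDs such as 'cg00000029'.
subset_array('train_val_test_sets/train_methyl_array.pkl',
             'cpgs/selected_cpgs.pkl',
             'subset/train_methyl_array.pkl')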
Example #2
def ref_free_cell_deconv(train_pkl, test_pkl, cell_type_columns, n_cell_types,
                         analysis):
    """Reference free cell type deconvolution"""
    import rpy2.robjects as robjects
    from rpy2.robjects.packages import importr
    from rpy2.robjects import pandas2ri, numpy2ri
    pandas2ri.activate()
    train_methyl_array, test_methyl_array = MethylationArray.from_pickle(
        train_pkl), MethylationArray.from_pickle(test_pkl)
    #robjects.r('source("https://raw.githubusercontent.com/rtmag/refactor/master/R/refactor.R")')
    importr('RefFreeEWAS')  # add edec, medecom
    if cell_type_columns[0] != '':
        mean_cell_type = train_methyl_array.pheno[list(
            cell_type_columns)].mean(axis=0)
        n_cell_types = len(mean_cell_type)
    else:
        mean_cell_type = robjects.r('NULL')
    run_reffree_cell_mix = robjects.r("""function (train_beta,test_beta,k) {
                    train_beta = as.matrix(train_beta)
                    return(RefFreeCellMix(train_beta,mu0=RefFreeCellMixInitialize(train_beta, K = k, method = "ward"),K=k,iters=10,Yfinal=test_beta,verbose=TRUE)$mu)
                    }""")
    if analysis == 'reffreecellmix':
        results = run_reffree_cell_mix(train_methyl_array.beta.T,
                                       test_methyl_array.beta.T, n_cell_types)
    else:
        # FINISH: other methods (e.g., EDec, MeDeCom) are not yet implemented
        raise NotImplementedError(
            'Only the "reffreecellmix" analysis is currently supported.')
    print(results)
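A hedged invocation sketch; the paths and cell-type count are illustrative, and the RefFreeEWAS package must be installed in the R library that rpy2 binds to:

# Passing [''] for cell_type_columns skips the reference means and
# estimates n_cell_types components de novo with RefFreeCellMix.
ref_free_cell_deconv('train_methyl_array.pkl', 'test_methyl_array.pkl',
                     cell_type_columns=[''], n_cell_types=5,
                     analysis='reffreecellmix')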
Example #3
def create_external_validation_set(train_pkl, query_pkl, output_pkl,
                                   cpg_replace_method):
    """Create external validation set containing same CpGs as training set."""
    import os, numpy as np, pandas as pd
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    ref_methyl_array = MethylationArray.from_pickle(train_pkl)
    ref_cpgs = np.array(list(ref_methyl_array.beta))
    query_methyl_array = MethylationArray.from_pickle(query_pkl)
    query_cpgs = np.array(list(query_methyl_array.beta))
    cpg_diff = np.setdiff1d(ref_cpgs, query_cpgs)
    n_samples = query_methyl_array.beta.shape[0]
    if cpg_replace_method == 'mid':
        background = np.ones((n_samples, len(cpg_diff))) * 0.5
    elif cpg_replace_method == 'background':
        background = np.ones(
            (n_samples, len(cpg_diff))) * query_methyl_array.beta.mean().mean()
    else:
        raise ValueError(
            'Unknown cpg_replace_method: {}'.format(cpg_replace_method))
    concat_df = pd.DataFrame(background,
                             index=query_methyl_array.beta.index,
                             columns=cpg_diff)
    query_methyl_array.beta = pd.concat(
        [query_methyl_array.beta.loc[:, np.intersect1d(ref_cpgs, query_cpgs)],
         concat_df],
        axis=1).loc[:, ref_cpgs]
    query_methyl_array.write_pickle(output_pkl)
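A usage sketch with hypothetical paths; CpGs present in the training array but absent from the query are filled with 0.5 ('mid') or the query-wide mean beta ('background'):

create_external_validation_set('train_methyl_array.pkl',
                               'external_query_methyl_array.pkl',
                               'external_validation/methyl_array.pkl',
                               cpg_replace_method='background')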
Example #4
def counts(input_pkl, key):
    """Return categorical breakdown of phenotype column."""
    if input_pkl.endswith('.pkl'):
        MethylationArray.from_pickle(input_pkl).categorical_breakdown(key)
    else:
        for pkl in glob.glob(join(input_pkl, '*.pkl')):
            print(pkl)
            MethylationArray.from_pickle(pkl).categorical_breakdown(key)
Example #5
def run_svc(train_pkl, val_pkl, test_pkl, series=False,
            outcome_col='Disease_State', num_random_search=0):
    """Fit an SVC on UMAP embeddings of methylation data and report test accuracy."""
    train_methyl_array, val_methyl_array, test_methyl_array = MethylationArray.from_pickle(
        train_pkl), MethylationArray.from_pickle(
            val_pkl), MethylationArray.from_pickle(test_pkl)
    umap = UMAP(n_components=100)
    umap.fit(train_methyl_array.beta)
    train_methyl_array.beta = pd.DataFrame(
        umap.transform(train_methyl_array.beta.values),
        index=train_methyl_array.return_idx())
    val_methyl_array.beta = pd.DataFrame(
        umap.transform(val_methyl_array.beta),
        index=val_methyl_array.return_idx())
    test_methyl_array.beta = pd.DataFrame(
        umap.transform(test_methyl_array.beta),
        index=test_methyl_array.return_idx())

    model = SVC
    model = MachineLearning(model,
                            options={'penalty': 'l2', 'verbose': 3,
                                     'n_jobs': 35, 'class_weight': 'balanced'},
                            grid={'C': [1, 10, 100, 1000],
                                  'gamma': [1, 0.1, 0.001, 0.0001],
                                  'kernel': ['linear', 'rbf']},
                            n_eval=num_random_search,
                            series=series,
                            labelencode=True,
                            verbose=True)

    sklearn_model = model.fit(train_methyl_array, val_methyl_array, outcome_col)
    pickle.dump(sklearn_model, open('sklearn_model.p', 'wb'))

    y_pred = model.predict(test_methyl_array)
    pd.DataFrame(np.hstack((y_pred[:, np.newaxis],
                            test_methyl_array.pheno[outcome_col].values[:, np.newaxis])),
                 index=test_methyl_array.return_idx(),
                 columns=['y_pred', 'y_true']).to_csv('SklearnPredictions.csv')

    original, std_err, (low_ci, high_ci) = model.return_outcome_metric(
        test_methyl_array, outcome_col, accuracy_score, run_bootstrap=True)

    results = {'score': original, 'Standard Error': std_err,
               '0.95 CI Low': low_ci, '0.95 CI High': high_ci}

    print('\n'.join(['{}:{}'.format(k, v) for k, v in results.items()]))
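A hedged call sketch, assuming UMAP is umap-learn's class and MachineLearning is the pymethylprocess wrapper used throughout these examples; paths are hypothetical:

run_svc('train_val_test_sets/train_methyl_array.pkl',
        'train_val_test_sets/val_methyl_array.pkl',
        'train_val_test_sets/test_methyl_array.pkl',
        outcome_col='Disease_State', num_random_search=10)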
Example #6
def methy_array_from_csv(pheno_csv, beta_csv, transpose, output_pkl):
    """Build a MethylationArray from pheno and beta CSVs and write it to a pickle."""
    import pandas as pd
    beta_df = pd.read_csv(beta_csv, index_col=0)
    if transpose:
        beta_df = beta_df.T
    MethylationArray(pheno_df=pd.read_csv(pheno_csv, index_col=0),
                     beta_df=beta_df).write_pickle(output_pkl)
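A usage sketch with hypothetical CSVs. MethylationArray is assumed to store beta as samples-by-CpGs, so pass transpose=True when the CSV stores CpGs as rows (check this orientation against your own data):

# pheno.csv: one row per sample; beta.csv here is CpGs x samples, hence transpose=True.
methy_array_from_csv('pheno.csv', 'beta.csv', transpose=True,
                     output_pkl='methyl_array.pkl')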
Example #7
def write_cpgs(input_pkl, cpg_pkl):
    """Write CpGs in methylation array to file."""
    import os, pickle
    os.makedirs(cpg_pkl[:cpg_pkl.rfind('/')], exist_ok=True)
    pickle.dump(
        MethylationArray.from_pickle(input_pkl).return_cpgs(),
        open(cpg_pkl, 'wb'))
Example #8
def stratify(input_pkl, key, output_dir):
    """Split methylation array by key and store."""
    for name, methyl_array in MethylationArray.from_pickle(input_pkl).groupby(
            key):
        out_dir = os.path.join(output_dir,
                               name.replace('/', '-').replace(' ', ''))
        os.makedirs(out_dir, exist_ok=True)
        methyl_array.write_pickle(os.path.join(out_dir, 'methyl_array.pkl'))
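For instance, stratifying on a disease column writes one methyl_array.pkl per group; the paths below are hypothetical:

# Creates e.g. stratified/Control/methyl_array.pkl and stratified/Case/methyl_array.pkl
stratify('methyl_array.pkl', 'Disease_State', 'stratified')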
Example #9
def run_rand_forest(train_pkl,
                    val_pkl,
                    test_pkl,
                    classify=True,
                    outcome_col='Disease_State',
                    num_random_search=0,
                    series=False):
    """Random-search a random forest on methylation data and report test accuracy or R2."""
    train_methyl_array, val_methyl_array, test_methyl_array = MethylationArray.from_pickle(
        train_pkl), MethylationArray.from_pickle(
            val_pkl), MethylationArray.from_pickle(test_pkl)
    model = RandomForestClassifier if classify else RandomForestRegressor
    model = MachineLearning(
        model,
        options={},
        grid=dict(n_estimators=[10, 25, 50, 75, 100, 125, 150, 175, 200],
                  criterion=['gini', 'entropy'],
                  max_features=['auto', 'sqrt'],
                  max_depth=[None, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
                  min_samples_split=[2, 5, 10],
                  min_samples_leaf=[1, 2, 4],
                  bootstrap=[True, False]),
        labelencode=True,
        n_eval=num_random_search,
        series=series)

    sklearn_model = model.fit(train_methyl_array, val_methyl_array,
                              outcome_col)

    y_pred = model.predict(test_methyl_array)

    original, std_err, (low_ci, high_ci) = model.return_outcome_metric(
        test_methyl_array,
        outcome_col,
        accuracy_score if classify else r2_score,
        run_bootstrap=True)

    results = {
        'score': original,
        'Standard Error': std_err,
        '0.95 CI Low': low_ci,
        '0.95 CI High': high_ci
    }

    print('\n'.join(['{}:{}'.format(k, v) for k, v in results.items()]))

    pickle.dump(sklearn_model, open('sklearn_model.p', 'wb'))
Example #10
def modify_pheno_data(input_pkl, input_formatted_sample_sheet, output_pkl):
    """Use another spreadsheet to add more descriptive data to methylarray."""
    import pandas as pd
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    methyl_array = MethylationArray.from_pickle(input_pkl)
    methyl_array.merge_preprocess_sheet(
        pd.read_csv(input_formatted_sample_sheet, header=0))
    methyl_array.write_pickle(output_pkl)
Example #11
def fix_key(input_pkl, key, disease_only, subtype_delimiter, output_pkl):
    """Format certain column of phenotype array in MethylationArray."""
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    methyl_array = MethylationArray.from_pickle(input_pkl)
    methyl_array.remove_whitespace(key)
    if disease_only:
        methyl_array.split_key(key, subtype_delimiter)
    methyl_array.write_pickle(output_pkl)
Example #12
    def to_methyl_array(self):
        """Convert torch dataset back into methylation array, useful because turning into torch dataset can cause the original MethylationArray beta matrix to turn into numpy array, when needs turn back into pandas dataframe.

        Returns
        -------
        MethylationArray

        """
        return MethylationArray(
            self.methylation_array.pheno,
            pd.DataFrame(self.methylation_array.beta,
                         index=self.samples,
                         columns=self.features), '')
Example #13
def overwrite_pheno_data(input_pkl, input_formatted_sample_sheet, output_pkl,
                         index_col):
    """Use another spreadsheet to add more descriptive data to methylarray."""
    import pandas as pd
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    methyl_array = MethylationArray.from_pickle(input_pkl)
    methyl_array.overwrite_pheno_data(
        pd.read_csv(input_formatted_sample_sheet,
                    index_col=(index_col if index_col != -1 else None),
                    header=0))
    methyl_array.write_pickle(output_pkl)
Example #14
def train_test_val_split(input_pkl, output_dir, train_percent, val_percent,
                         categorical, disease_only, key, subtype_delimiter):
    """Split methylation array into train, test, val."""
    os.makedirs(output_dir, exist_ok=True)
    methyl_array = MethylationArray.from_pickle(input_pkl)
    train_arr, test_arr, val_arr = methyl_array.split_train_test(
        train_percent, categorical, disease_only, key, subtype_delimiter,
        val_percent)
    train_arr.write_pickle(join(output_dir, 'train_methyl_array.pkl'))
    test_arr.write_pickle(join(output_dir, 'test_methyl_array.pkl'))
    val_arr.write_pickle(join(output_dir, 'val_methyl_array.pkl'))
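A usage sketch with illustrative percentages; the argument order mirrors split_train_test as called above:

train_test_val_split('methyl_array.pkl', 'train_val_test_sets',
                     train_percent=0.7, val_percent=0.1, categorical=True,
                     disease_only=False, key='Disease_State',
                     subtype_delimiter=',')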
Example #15
def set_part_array_background(input_pkl, cpg_pkl, output_pkl):
    """Set subset of CpGs from beta matrix to background values."""
    import os, numpy as np, pickle
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    cpgs = pickle.load(open(cpg_pkl, 'rb'))
    methyl_array = MethylationArray.from_pickle(input_pkl)
    # Replace the selected CpGs with the mean beta value of all remaining CpGs.
    methyl_array.beta.loc[:, cpgs] = methyl_array.beta.loc[
        :, np.setdiff1d(list(methyl_array.beta), cpgs)].mean().mean()
    methyl_array.write_pickle(output_pkl)
Example #16
def feature_select_train_val_test(input_pkl_dir,
                                  output_dir,
                                  n_top_cpgs=300000,
                                  feature_selection_method='mad',
                                  metric='correlation',
                                  n_neighbors=10,
                                  mad_top_cpgs=0):
    """Filter CpGs by taking x top CpGs with highest mean absolute deviation scores or via spectral feature selection."""
    os.makedirs(output_dir, exist_ok=True)
    train_pkl = join(input_pkl_dir, 'train_methyl_array.pkl')
    val_pkl = join(input_pkl_dir, 'val_methyl_array.pkl')
    test_pkl = join(input_pkl_dir, 'test_methyl_array.pkl')
    train_methyl_array, val_methyl_array, test_methyl_array = MethylationArray.from_pickle(
        train_pkl), MethylationArray.from_pickle(
            val_pkl), MethylationArray.from_pickle(test_pkl)

    methyl_array = MethylationArrays([train_methyl_array,
                                      val_methyl_array]).combine()

    if mad_top_cpgs and feature_selection_method != 'mad':
        methyl_array.feature_select(mad_top_cpgs, 'mad')

    methyl_array.feature_select(n_top_cpgs,
                                feature_selection_method,
                                metric,
                                nn=n_neighbors)

    cpgs = methyl_array.return_cpgs()

    train_methyl_array.subset_cpgs(cpgs).write_pickle(
        join(output_dir, 'train_methyl_array.pkl'))
    test_methyl_array.subset_cpgs(cpgs).write_pickle(
        join(output_dir, 'test_methyl_array.pkl'))
    val_methyl_array.subset_cpgs(cpgs).write_pickle(
        join(output_dir, 'val_methyl_array.pkl'))
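A usage sketch with hypothetical directories; train and val arrays are combined for selection, then the chosen CpGs are applied to all three splits:

feature_select_train_val_test('train_val_test_sets',
                              'train_val_test_sets_fs',
                              n_top_cpgs=300000,
                              feature_selection_method='mad')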
Example #17
def remove_snps(input_pkl, output_pkl, array_type):
    """Remove SNPs from methylation array."""
    import os, numpy as np
    #from rpy2.robjects import pandas2ri
    from pymethylprocess.meffil_functions import r_snp_cpgs
    #pandas2ri.activate()
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    snp_cpgs = r_snp_cpgs(array_type)  #pandas2ri.ri2py()
    methyl_array = MethylationArray.from_pickle(input_pkl)
    methyl_array.beta = methyl_array.beta.loc[
        :, np.setdiff1d(list(methyl_array.beta), snp_cpgs)]
    methyl_array.write_pickle(output_pkl)
Example #18
def remove_sex(input_pkl, output_pkl, array_type):
    """Remove non-autosomal CpGs."""
    import os, numpy as np
    #from rpy2.robjects import pandas2ri
    from pymethylprocess.meffil_functions import r_autosomal_cpgs
    #pandas2ri.activate()
    os.makedirs(output_pkl[:output_pkl.rfind('/')], exist_ok=True)
    autosomal_cpgs = r_autosomal_cpgs(array_type)  #pandas2ri.ri2py()
    methyl_array = MethylationArray.from_pickle(input_pkl)
    methyl_array.beta = methyl_array.beta.loc[
        :, np.intersect1d(list(methyl_array.beta), autosomal_cpgs)]
    methyl_array.write_pickle(output_pkl)
Example #19
def make_new_predictions(test_pkl, model_pickle, batch_size, n_workers, interest_cols, categorical, cuda, categorical_encoder, output_dir):
	"""Run prediction model again to further assess outcome. Only evaluate prediction model."""
	os.makedirs(output_dir,exist_ok=True)
	test_methyl_array = MethylationArray.from_pickle(test_pkl) # generate results pickle to run through classification/regression report
	if cuda:
		model = torch.load(model_pickle)
		model.vae.cuda_on=True
	else:
		model = torch.load(model_pickle,map_location='cpu')
		model.vae.cuda_on=False
	if not categorical:
		test_methyl_array.remove_na_samples(interest_cols if len(interest_cols)>1 else interest_cols[0])
	if os.path.exists(categorical_encoder):
		categorical_encoder=pickle.load(open(categorical_encoder,'rb'))
	else:
		categorical_encoder=None
	test_methyl_dataset = get_methylation_dataset(test_methyl_array,interest_cols,categorical=categorical, predict=True, categorical_encoder=categorical_encoder)
	test_methyl_dataloader = DataLoader(
		dataset=test_methyl_dataset,
		num_workers=n_workers,
		batch_size=min(batch_size,len(test_methyl_dataset)),
		shuffle=False)
	vae_mlp=MLPFinetuneVAE(mlp_model=model,categorical=categorical,cuda=cuda)

	Y_pred, Y_true, latent_projection, _ = vae_mlp.predict(test_methyl_dataloader)

	results = dict(test={})
	results['test']['y_pred'], results['test']['y_true'] = copy.deepcopy(Y_pred), copy.deepcopy(Y_true)

	if categorical:
		Y_true=Y_true.argmax(axis=1)[:,np.newaxis]
		Y_pred=Y_pred.argmax(axis=1)[:,np.newaxis]
	test_methyl_array = test_methyl_dataset.to_methyl_array()

	Y_pred=pd.DataFrame(Y_pred.flatten() if (np.array(Y_pred.shape)==1).any() else Y_pred,index=test_methyl_array.beta.index,columns=(['y_pred'] if categorical else interest_cols))
	Y_true=pd.DataFrame(Y_true.flatten() if (np.array(Y_true.shape)==1).any() else Y_true,index=test_methyl_array.beta.index,columns=(['y_true'] if categorical else interest_cols))
	results_df = pd.concat([Y_pred,Y_true],axis=1) if categorical else pd.concat([Y_pred.rename(columns={name:name+'_pred' for name in list(Y_pred)}),Y_true.rename(columns={name:name+'_true' for name in list(Y_pred)})],axis=1)  # FIXME
	latent_projection=pd.DataFrame(latent_projection,index=test_methyl_array.beta.index)
	test_methyl_array.beta=latent_projection

	output_file = join(output_dir,'results.csv')
	results_file = join(output_dir,'results.p')
	output_file_latent = join(output_dir,'latent.csv')
	output_pkl = join(output_dir, 'vae_mlp_methyl_arr.pkl')

	test_methyl_array.write_pickle(output_pkl)
	pickle.dump(results,open(results_file,'wb'))
	latent_projection.to_csv(output_file_latent)
	results_df.to_csv(output_file)
Example #20
def print_number_sex_cpgs(input_pkl, array_type):
    """Print number of non-autosomal CpGs."""
    import numpy as np
    #from rpy2.robjects import pandas2ri
    from pymethylprocess.meffil_functions import r_autosomal_cpgs
    #pandas2ri.activate()
    autosomal_cpgs = r_autosomal_cpgs(array_type)  #pandas2ri.ri2py()
    methyl_array = MethylationArray.from_pickle(input_pkl)
    n_autosomal = len(np.intersect1d(list(methyl_array.beta), autosomal_cpgs))
    n_cpgs = len(list(methyl_array.beta))
    n_sex = n_cpgs - n_autosomal
    percent_sex = round(float(n_sex) / n_cpgs * 100, 2)
    print(
        "There are {} autosomal cpgs in your methyl array and {} sex cpgs. Sex CpGs make up {}% of {} total cpgs."
        .format(n_autosomal, n_sex, percent_sex, n_cpgs))
Example #21
def get_correlation_network_(
        train_methyl_array='train_val_test_sets/train_methyl_array.pkl',
        val_methyl_array='train_val_test_sets/val_methyl_array.pkl',
        test_methyl_array='train_val_test_sets/test_methyl_array.pkl',
        min_capsule_len=5,
        capsule_choice=['gene'],
        n_jobs=20,
        output_file='corr_mat.pkl'):
    """Compute a capsule-by-capsule Pearson correlation matrix from the median methylation per capsule."""

    # if torch.cuda.is_available():
    # 	torch.set_default_tensor_type('torch.cuda.FloatTensor')

    datasets = dict(train=train_methyl_array,
                    val=val_methyl_array,
                    test=test_methyl_array)

    # LogisticRegression = lambda ne, lr: net = NeuralNetClassifier(LogisticRegressionModel,max_epochs=ne,lr=lr,iterator_train__shuffle=True, callbacks=[EpochScoring(LASSO)])

    X = dict()
    for k in ['train', 'val', 'test']:
        X[k] = MethylationArray.from_pickle(datasets[k]).beta

    capsules, _, names, cpg_arr = return_final_capsules(
        datasets['train'],
        capsule_choice,
        min_capsule_len,
        None,
        None,
        0,
        '',
        '',
        return_original_capsule_assignments=True)

    caps = {
        names[i]: X['train'].loc[:, cpgs].apply(np.median, axis=1).values
        for i, cpgs in enumerate(capsules)
    }  # maybe add more values

    df = pd.DataFrame(1., index=names, columns=names)
    for c1, c2 in combinations(names, r=2):
        # pearsonr returns (r, p-value); keep only the correlation coefficient
        df.loc[c1, c2] = pearsonr(caps[c1], caps[c2])[0]
        df.loc[c2, c1] = df.loc[c1, c2]

    df.to_pickle(output_file)
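A hedged call sketch; return_final_capsules is assumed to come from the methylcapsnet package, and the default paths point at hypothetical train/val/test pickles:

# Builds gene-level capsules (>= 5 CpGs each) and writes a capsule-by-capsule
# Pearson correlation matrix to corr_mat.pkl.
get_correlation_network_(min_capsule_len=5, capsule_choice=['gene'],
                         output_file='corr_mat.pkl')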
Example #22
def generate_embed(input_pkl, output_generate_pkl, output_embed_pkl, cuda,
                   input_vae_pkl, stratify_column, n_workers, batch_size):
    """Run a trained VAE over a methylation array, writing reconstructed betas and latent embeddings."""
    import copy, os
    import pandas as pd
    from methylnet.models import AutoEncoder, TybaltTitusVAE
    from methylnet.datasets import get_methylation_dataset
    import torch
    from torch.utils.data import DataLoader
    os.makedirs(os.path.dirname(output_generate_pkl), exist_ok=True)
    os.makedirs(os.path.dirname(output_embed_pkl), exist_ok=True)

    methyl_array = MethylationArray.from_pickle(
        input_pkl
    )  # generate results pickle to run through classification/regression report
    if cuda:
        model = torch.load(input_vae_pkl)
    else:
        model = torch.load(input_vae_pkl, map_location='cpu')
    test_methyl_dataset = get_methylation_dataset(copy.deepcopy(methyl_array),
                                                  stratify_column)
    test_methyl_dataloader = DataLoader(dataset=test_methyl_dataset,
                                        num_workers=n_workers,
                                        batch_size=min(
                                            batch_size,
                                            len(test_methyl_dataset)),
                                        shuffle=False)

    auto_encoder = AutoEncoder(autoencoder_model=model,
                               n_epochs=0,
                               loss_fn=None,
                               optimizer=None,
                               cuda=cuda,
                               kl_warm_up=None,
                               beta=None,
                               scheduler_opts={})
    Z, _, _ = auto_encoder.transform(test_methyl_dataloader)
    X_hat = auto_encoder.generate(test_methyl_dataloader)
    methyl_array.beta.iloc[:, :] = X_hat
    methyl_array.write_pickle(output_generate_pkl)
    methyl_array.beta = pd.DataFrame(Z, index=methyl_array.beta.index)
    methyl_array.write_pickle(output_embed_pkl)
Example #23
def est_age(input_pkl, age_column, analyses, output_csv):
    """Estimate age using cgAgeR"""
    import pandas as pd
    import rpy2.robjects as robjects
    from rpy2.robjects.packages import importr
    from rpy2.robjects import pandas2ri, numpy2ri
    pandas2ri.activate()
    os.makedirs(output_csv[:output_csv.rfind('/')], exist_ok=True)
    methyl_array = MethylationArray.from_pickle(input_pkl)
    if age_column:
        age_column = pandas2ri.py2ri(methyl_array.pheno[age_column])
    else:
        age_column = robjects.r('NULL')
    run_analyses = dict(epitoc=False, horvath=False, hannum=False)
    for analysis in analyses:
        run_analyses[analysis] = True
    importr('cgageR')
    returned_ages = robjects.r(
        """function (beta, hannum, horvath, epitoc, age) {
                return(getAgeR(beta, epitoc=epitoc, horvath=horvath, hannum=hannum, chrage=age))
                }""")(methyl_array.beta.T, run_analyses['hannum'],
                      run_analyses['horvath'], run_analyses['epitoc'],
                      age_column)
    result_dfs = []
    return_data = lambda data_str: robjects.r(
        """function (results) results{}""".format(data_str))(returned_ages)
    if 'hannum' in analyses:
        result_dfs.append(
            pandas2ri.ri2py(
                return_data('$HannumClock.output$Hannum.Clock.Est')))
    if 'epitoc' in analyses:
        result_dfs.append(
            pandas2ri.ri2py(return_data('$EpiTOC.output$EpiTOC.Est')))
    if 'horvath' in analyses:
        result_dfs.append(
            pandas2ri.ri2py(return_data('$HorvathClock.output$Horvath.Est')))
    df = pd.concat(result_dfs, axis=1)
    df.index = methyl_array.pheno.index
    df.to_csv(output_csv)
    print(df)
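A usage sketch with a hypothetical array and age column; the cgageR package must be installed in the R environment rpy2 binds to:

est_age('train_methyl_array.pkl', age_column='Age',
        analyses=['horvath', 'hannum', 'epitoc'], output_csv='ages.csv')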
Example #24
def return_spw_importances_(train_methyl_array,
                            val_methyl_array,
                            interest_col,
                            select_subtypes,
                            capsules_pickle,
                            include_last,
                            n_bins,
                            spw_config,
                            model_state_dict_pkl,
                            batch_size,
                            by_subtype=False):
    ma = MethylationArray.from_pickle(train_methyl_array)
    ma_v = MethylationArray.from_pickle(val_methyl_array)

    try:
        ma.remove_na_samples(interest_col)
        ma_v.remove_na_samples(interest_col)
    except Exception:  # tolerate outcome columns where NA removal is unsupported
        pass

    if select_subtypes:
        ma.pheno = ma.pheno.loc[ma.pheno[interest_col].isin(select_subtypes)]
        ma.beta = ma.beta.loc[ma.pheno.index]
        ma_v.pheno = ma_v.pheno.loc[ma_v.pheno[interest_col].isin(
            select_subtypes)]
        ma_v.beta = ma_v.beta.loc[ma_v.pheno.index]

    capsules_dict = torch.load(capsules_pickle)

    final_modules, modulecpgs, module_names = capsules_dict[
        'final_modules'], capsules_dict['modulecpgs'], capsules_dict[
            'module_names']

    if not include_last:
        ma.beta = ma.beta.loc[:, modulecpgs]
        ma_v.beta = ma_v.beta.loc[:, modulecpgs]

    original_interest_col = interest_col

    if n_bins:
        new_interest_col = interest_col + '_binned'
        ma.pheno.loc[:, new_interest_col], bins = pd.cut(
            ma.pheno[interest_col], bins=n_bins, retbins=True)
        ma_v.pheno.loc[:, new_interest_col], _ = pd.cut(
            ma_v.pheno[interest_col], bins=bins, retbins=True)
        interest_col = new_interest_col

    datasets = dict()
    datasets['train'] = MethylationDataset(
        ma,
        interest_col,
        modules=final_modules,
        module_names=module_names,
        original_interest_col=original_interest_col,
        run_spw=True)
    datasets['val'] = MethylationDataset(
        ma_v,
        interest_col,
        modules=final_modules,
        module_names=module_names,
        original_interest_col=original_interest_col,
        run_spw=True)

    y_val = datasets['val'].y_label
    y_val_uniq = np.unique(y_val)

    dataloaders = dict()
    dataloaders['train'] = DataLoader(datasets['train'],
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True,
                                      drop_last=True)
    dataloaders['val'] = DataLoader(datasets['val'],
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    drop_last=False)
    n_primary = len(final_modules)

    spw_config = torch.load(spw_config)
    spw_config.pop('module_names')

    model = MethylSPWNet(**spw_config)
    model.load_state_dict(torch.load(model_state_dict_pkl))

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    pathway_extractor = model.pathways

    #extract_pathways = lambda modules_x:torch.cat([pathway_extractor[i](module_x) for i,module_x in enumerate(modules_x)],dim=1)

    tensor_data = dict(train=dict(X=[], y=[]), val=dict(X=[], y=[]))

    for k in tensor_data:
        for i, (batch) in enumerate(dataloaders[k]):
            x = batch[0]
            y_true = batch[-1].argmax(1)  #[-2]
            modules_x = batch[1:-1]  #2]
            if torch.cuda.is_available():
                x = x.cuda()
                modules_x = modules_x[0].cuda(
                )  #[module.cuda() for module in modules_x]
            tensor_data[k]['X'].append(
                pathway_extractor(x, modules_x).detach().cpu()
            )  #extract_pathways(modules_x).detach().cpu())
            tensor_data[k]['y'].append(y_true.flatten().view(-1, 1))
        tensor_data[k]['X'] = torch.cat(tensor_data[k]['X'], dim=0)
        tensor_data[k]['y'] = torch.cat(tensor_data[k]['y'], dim=0)
        print(tensor_data[k]['X'].size(), tensor_data[k]['y'].size())
        tensor_data[k] = TensorDataset(tensor_data[k]['X'],
                                       tensor_data[k]['y'])
        dataloaders[k] = DataLoader(tensor_data[k],
                                    batch_size=32,
                                    sampler=ImbalancedDatasetSampler(
                                        tensor_data[k]))

    model = model.output_net
    to_cuda = lambda x: x.cuda() if torch.cuda.is_available() else x
    y = np.unique(tensor_data['train'].tensors[1].numpy().flatten())
    gs = GradientShap(model)
    X_train = torch.cat(
        [next(iter(dataloaders['train']))[0] for i in range(2)], dim=0)
    if torch.cuda.is_available():
        X_train = X_train.cuda()

    #val_loader=iter(dataloaders['val'])

    def return_importances(dataloaders, X_train):
        attributions = []
        for i in range(20):
            batch = next(iter(dataloaders['val']))
            X_test = to_cuda(batch[0])
            y_test = to_cuda(batch[1].flatten())
            attributions.append(
                torch.abs(
                    gs.attribute(
                        X_test,
                        stdevs=0.03,
                        n_samples=200,
                        baselines=X_train,
                        target=y_test,
                        return_convergence_delta=False)))  #torch.tensor(y_i)
        attributions = torch.sum(torch.cat(attributions, dim=0), dim=0)
        importances = pd.DataFrame(
            pd.Series(attributions.detach().cpu().numpy(),
                      index=module_names).sort_values(ascending=False),
            columns=['importances'])
        return importances

    if by_subtype:
        importances = []
        for k in y_val_uniq:
            idx = np.where(y_val == k)[0]
            if len(idx) > 2:
                val_dataset = Subset(tensor_data['val'], idx)
                n_concat = int(np.ceil(64. / len(idx)))
                if n_concat > 1:
                    val_dataset = ConcatDataset([val_dataset] * n_concat)
                #sampler=SubsetRandomSampler(idx)
                dataloaders['val'] = DataLoader(val_dataset,
                                                batch_size=32,
                                                shuffle=True)
                df = return_importances(dataloaders, X_train)
                df['subtype'] = k
                importances.append(df)
        importances = pd.concat(importances)
    else:
        importances = return_importances(dataloaders, X_train)

    return importances
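A hedged invocation sketch; every path and pickle below is hypothetical, and GradientShap is assumed to be captum's implementation as imported by the surrounding module:

importances = return_spw_importances_(
    'train_val_test_sets/train_methyl_array.pkl',
    'train_val_test_sets/val_methyl_array.pkl',
    interest_col='Disease_State', select_subtypes=[],
    capsules_pickle='capsules.pkl', include_last=False, n_bins=0,
    spw_config='spw_config.pkl', model_state_dict_pkl='model_state.pkl',
    batch_size=32)
print(importances.head())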
Example #25
def bin_column(test_pkl, col, n_bins, output_test_pkl):
    """Convert continuous phenotype column into categorical by binning."""
    os.makedirs(output_test_pkl[:output_test_pkl.rfind('/')], exist_ok=True)
    test_methyl_array = MethylationArray.from_pickle(test_pkl)
    new_col_name = test_methyl_array.bin_column(col, n_bins)
    test_methyl_array.write_pickle(output_test_pkl)
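A usage sketch with hypothetical paths, binning a continuous Age column into 10 categories:

bin_column('test_methyl_array.pkl', 'Age', 10,
           'binned/test_methyl_array.pkl')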
Example #26
def main():
    p = argparse.ArgumentParser()
    p.add_argument('--interest_col', type=str)
    p.add_argument('--n_bins', type=int)
    args = p.parse_args()
    bin_len = 1000000
    min_capsule_len = 350
    interest_col = args.interest_col
    n_bins = args.n_bins

    primary_caps_out_len = 40
    caps_out_len = 20
    n_epochs = 500
    hidden_topology = [30, 80, 50]
    gamma = 1e-2
    decoder_top = [100, 300]
    lr = 1e-3
    routing_iterations = 3

    if not os.path.exists('hg19.{}.bed'.format(bin_len)):
        BedTool('hg19.genome').makewindows(g='hg19.genome', w=bin_len).saveas(
            'hg19.{}.bed'.format(bin_len))  #.to_dataframe().shape

    ma = MethylationArray.from_pickle(
        'train_val_test_sets/train_methyl_array.pkl')
    ma_v = MethylationArray.from_pickle(
        'train_val_test_sets/val_methyl_array.pkl')

    include_last = False

    @pysnooper.snoop('get_mod.log')
    def get_final_modules(ma=ma,
                          a='450kannotations.bed',
                          b='lola_vignette_data/activeDHS_universe.bed',
                          include_last=False,
                          min_capsule_len=2000):
        allcpgs = ma.beta.columns.values
        df = BedTool(a).to_dataframe()
        df.iloc[:, 0] = df.iloc[:, 0].astype(str).map(
            lambda x: 'chr' + x.split('.')[0])
        df = df.set_index('name').loc[list(
            ma.beta)].reset_index().iloc[:, [1, 2, 3, 0]]
        df_bed = pd.read_table(b, header=None)
        df_bed['features'] = np.arange(df_bed.shape[0])
        df_bed = df_bed.iloc[:, [0, 1, 2, -1]]
        b = BedTool.from_dataframe(df)
        a = BedTool.from_dataframe(
            df_bed)  #('lola_vignette_data/activeDHS_universe.bed')
        c = a.intersect(b, wa=True, wb=True).sort()
        d = c.groupby(g=[1, 2, 3, 4], c=(8, 8), o=('count', 'distinct'))
        df2 = d.to_dataframe()
        df3 = df2.loc[df2.iloc[:, -2] > min_capsule_len]
        modules = [cpgs.split(',') for cpgs in df3.iloc[:, -1].values]
        modulecpgs = np.array(
            list(set(list(reduce(lambda x, y: x + y, modules)))))
        if include_last:
            missing_cpgs = np.setdiff1d(allcpgs, modulecpgs).tolist()
        final_modules = modules + ([missing_cpgs] if include_last else [])
        module_names = (df3.iloc[:, 0] + '_' + df3.iloc[:, 1].astype(str) +
                        '_' + df3.iloc[:, 2].astype(str)).tolist()
        return final_modules, modulecpgs, module_names

    final_modules, modulecpgs, module_names = get_final_modules(
        b='hg19.{}.bed'.format(bin_len),
        include_last=include_last,
        min_capsule_len=min_capsule_len)
    print('LEN_MODULES', len(final_modules))

    if not include_last:
        ma.beta = ma.beta.loc[:, modulecpgs]
        ma_v.beta = ma_v.beta.loc[:, modulecpgs]
    # https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb

    def softmax(input_tensor, dim=1):
        # transpose input
        transposed_input = input_tensor.transpose(dim,
                                                  len(input_tensor.size()) - 1)
        # calculate softmax
        softmaxed_output = F.softmax(transposed_input.contiguous().view(
            -1, transposed_input.size(-1)),
                                     dim=-1)
        # un-transpose result
        return softmaxed_output.view(*transposed_input.size()).transpose(
            dim,
            len(input_tensor.size()) - 1)

    class MLP(nn.Module):  # add latent space extraction; spits out csv line of SQL as text for UMAP
        def __init__(self,
                     n_input,
                     hidden_topology,
                     dropout_p,
                     n_outputs=1,
                     binary=False,
                     softmax=False):
            super(MLP, self).__init__()
            self.hidden_topology = hidden_topology
            self.topology = [n_input] + hidden_topology + [n_outputs]
            layers = [
                nn.Linear(self.topology[i], self.topology[i + 1])
                for i in range(len(self.topology) - 2)
            ]
            for layer in layers:
                torch.nn.init.xavier_uniform_(layer.weight)
            self.layers = [
                nn.Sequential(layer, nn.ReLU(), nn.Dropout(p=dropout_p))
                for layer in layers
            ]
            self.output_layer = nn.Linear(self.topology[-2], self.topology[-1])
            torch.nn.init.xavier_uniform_(self.output_layer.weight)
            if binary:
                output_transform = nn.Sigmoid()
            elif softmax:
                output_transform = nn.Softmax()
            else:
                output_transform = nn.Dropout(p=0.)
            self.layers.append(
                nn.Sequential(self.output_layer, output_transform))
            self.mlp = nn.Sequential(*self.layers)

        def forward(self, x):
            #print(x.shape)
            return self.mlp(x)

    class MethylationDataset(Dataset):
        def __init__(self,
                     methyl_arr,
                     outcome_col,
                     binarizer=None,
                     modules=[]):
            if binarizer is None:
                binarizer = LabelBinarizer()
                binarizer.fit(methyl_arr.pheno[outcome_col].astype(str).values)
            self.y = binarizer.transform(
                methyl_arr.pheno[outcome_col].astype(str).values)
            self.y_unique = np.unique(np.argmax(self.y, 1))
            self.binarizer = binarizer
            if not modules:
                modules = [list(methyl_arr.beta)]
            self.modules = modules
            self.X = methyl_arr.beta
            self.length = methyl_arr.beta.shape[0]

        def __len__(self):
            return self.length

        def __getitem__(self, i):
            return tuple([torch.FloatTensor(self.X.iloc[i].values)] + [
                torch.FloatTensor(self.X.iloc[i].loc[module].values)
                for module in self.modules
            ] + [torch.FloatTensor(self.y[i])])

    class PrimaryCaps(nn.Module):
        def __init__(self, modules, hidden_topology, n_output):
            super(PrimaryCaps, self).__init__()
            self.capsules = nn.ModuleList([
                MLP(len(module), hidden_topology, 0., n_outputs=n_output)
                for module in modules
            ])

        def forward(self, x):
            #print(self.capsules)
            u = [self.capsules[i](x[i]) for i in range(len(self.capsules))]
            u = torch.stack(u, dim=1)
            #print(u.size())
            return self.squash(u)

        def squash(self, x):
            squared_norm = (x**2).sum(-1, keepdim=True)
            #print('prim_norm',squared_norm.size())
            output_tensor = squared_norm * x / (
                (1. + squared_norm) * torch.sqrt(squared_norm))
            #print('z_init',output_tensor.size())
            return output_tensor

        def get_weights(self):
            return list(
                self.capsules[0].parameters()
            )[0].data  #self.state_dict()#[self.capsules[i].state_dict() for i in range(len(self.capsules))]

    class CapsLayer(nn.Module):
        def __init__(self,
                     n_capsules,
                     n_routes,
                     n_input,
                     n_output,
                     routing_iterations=3):
            super(CapsLayer, self).__init__()
            self.n_capsules = n_capsules
            self.num_routes = n_routes
            self.W = nn.Parameter(
                torch.randn(1, n_routes, n_capsules, n_output, n_input))
            self.routing_iterations = routing_iterations
            self.c_ij = None

        def forward(self, x):
            batch_size = x.size(0)
            x = torch.stack([x] * self.n_capsules, dim=2).unsqueeze(4)

            W = torch.cat([self.W] * batch_size, dim=0)
            #print('affine',W.size(),x.size())
            u_hat = torch.matmul(W, x)
            #print('affine_trans',u_hat.size())

            b_ij = Variable(torch.zeros(1, self.num_routes, self.n_capsules,
                                        1))

            if torch.cuda.is_available():
                b_ij = b_ij.cuda()

            for iteration in range(self.routing_iterations):
                self.c_ij = softmax(b_ij)
                #print(c_ij)
                c_ij = torch.cat([self.c_ij] * batch_size, dim=0).unsqueeze(4)
                #print('coeff',c_ij.size())#[0,:,0,:])#.size())

                s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
                v_j = self.squash(s_j)
                #print('z',v_j.size())

                if iteration < self.routing_iterations - 1:
                    a_ij = torch.matmul(
                        u_hat.transpose(3, 4),
                        torch.cat([v_j] * self.num_routes, dim=1))
                    b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

            return v_j.squeeze(1)

        def return_routing_coef(self):
            return self.c_ij

        def squash(self, x):
            #print(x.size())
            squared_norm = (x**2).sum(-1, keepdim=True)
            #print('norm',squared_norm.size())
            output_tensor = squared_norm * x / (
                (1. + squared_norm) * torch.sqrt(squared_norm))
            return output_tensor

    class Decoder(nn.Module):
        def __init__(self, n_input, n_output, hidden_topology):
            super(Decoder, self).__init__()
            self.decoder = MLP(n_input,
                               hidden_topology,
                               0.,
                               n_outputs=n_output,
                               binary=True)

        def forward(self, x):
            return self.decoder(x)

    class CapsNet(nn.Module):
        def __init__(self,
                     primary_caps,
                     caps_hidden_layers,
                     caps_output_layer,
                     decoder,
                     lr_balance=0.5,
                     gamma=0.005):
            super(CapsNet, self).__init__()
            self.primary_caps = primary_caps
            self.caps_hidden_layers = caps_hidden_layers
            self.caps_output_layer = caps_output_layer
            self.decoder = decoder
            self.recon_loss_fn = nn.BCELoss()
            self.lr_balance = lr_balance
            self.gamma = gamma

        def forward(self, x_orig, modules_input):
            x = self.primary_caps(modules_input)
            primary_caps_out = x  #.view(x.size(0),x.size(1)*x.size(2))
            #print(x.size())
            for layer in self.caps_hidden_layers:
                x = layer(x)

            y_pred = self.caps_output_layer(x)  #.squeeze(-1)
            #print(y_pred.shape)

            classes = torch.sqrt((y_pred**2).sum(2))
            classes = F.softmax(classes)

            max_length_indices = classes.argmax(dim=1)
            masked = torch.sparse.torch.eye(self.caps_output_layer.n_capsules)
            if torch.cuda.is_available():
                masked = masked.cuda()
            masked = masked.index_select(
                dim=0, index=max_length_indices.squeeze(1).data)

            embedding = (y_pred * masked[:, :, None, None]).view(
                y_pred.size(0), -1)

            #print(y_pred.size())
            x_hat = self.decoder(embedding)  #.reshape(y_pred.size(0),-1))
            return x_orig, x_hat, y_pred, embedding, primary_caps_out

        def recon_loss(self, x_orig, x_hat):
            return self.recon_loss_fn(x_hat, x_orig)

        def margin_loss(self, x, labels):
            batch_size = x.size(0)

            v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

            #print(v_c)

            left = (F.relu(0.9 - v_c)**2).view(batch_size, -1)
            right = (F.relu(v_c - 0.1)**2).view(batch_size, -1)
            #print(left)
            #print(right)
            #print(labels)

            loss = labels * left + self.lr_balance * (1.0 - labels) * right
            #print(loss.shape)
            loss = loss.sum(dim=1).mean()
            return loss

        def calculate_loss(self, x_orig, x_hat, y_pred, y_true):
            margin_loss = self.margin_loss(y_pred, y_true)
            recon_loss = self.gamma * self.recon_loss(x_orig, x_hat)
            loss = margin_loss + recon_loss
            return loss, margin_loss, recon_loss

    if n_bins:
        ma.pheno.loc[:, interest_col], bins = pd.cut(ma.pheno[interest_col],
                                                     bins=n_bins,
                                                     retbins=True)
        ma_v.pheno.loc[:, interest_col], bins = pd.cut(
            ma_v.pheno[interest_col],
            bins=bins,
            retbins=True,
        )

    dataset = MethylationDataset(ma, interest_col, modules=final_modules)
    dataset_v = MethylationDataset(ma_v, interest_col, modules=final_modules)

    dataloader = DataLoader(dataset,
                            batch_size=16,
                            shuffle=True,
                            num_workers=8,
                            drop_last=True)
    dataloader_v = DataLoader(dataset_v,
                              batch_size=16,
                              shuffle=False,
                              num_workers=8,
                              drop_last=False)

    n_inputs = list(map(len, final_modules))
    n_primary = len(final_modules)

    primary_caps = PrimaryCaps(modules=final_modules,
                               hidden_topology=hidden_topology,
                               n_output=primary_caps_out_len)
    hidden_caps = []
    n_out_caps = len(dataset.y_unique)
    output_caps = CapsLayer(n_out_caps,
                            n_primary,
                            primary_caps_out_len,
                            caps_out_len,
                            routing_iterations=routing_iterations)
    decoder = Decoder(n_out_caps * caps_out_len, len(list(ma.beta)),
                      decoder_top)
    capsnet = CapsNet(primary_caps,
                      hidden_caps,
                      output_caps,
                      decoder,
                      gamma=gamma)

    if torch.cuda.is_available():
        capsnet = capsnet.cuda()

    for d in ['figures/embeddings' + x for x in ['', '2', '3']]:
        os.makedirs(d, exist_ok=True)
    os.makedirs('results/routing_weights', exist_ok=True)
    # extract all c_ij for all layers across all batches, or just last batch
    optimizer = Adam(capsnet.parameters(), lr)
    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=10,
                                  eta_min=0,
                                  last_epoch=-1)
    for epoch in range(n_epochs):
        print(epoch)
        capsnet.train(True)
        running_loss = 0.
        Y = {'true': [], 'pred': []}
        for i, batch in enumerate(dataloader):
            x_orig = batch[0]
            #print(x_orig)
            y_true = batch[-1]
            module_x = batch[1:-1]
            if torch.cuda.is_available():
                x_orig = x_orig.cuda()
                y_true = y_true.cuda()
                module_x = [mod.cuda() for mod in module_x]
            x_orig, x_hat, y_pred, embedding, primary_caps_out = capsnet(
                x_orig, module_x)
            loss, margin_loss, recon_loss = capsnet.calculate_loss(
                x_orig, x_hat, y_pred, y_true)
            Y['true'].extend(y_true.argmax(1).detach().cpu().numpy().tolist())
            Y['pred'].extend(
                F.softmax(torch.sqrt(
                    (y_pred**2
                     ).sum(2))).argmax(1).detach().cpu().numpy().tolist())
            train_loss = margin_loss.item()  #print(loss)
            running_loss += train_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        #print(capsnet.primary_caps.get_weights())
        running_loss /= (i + 1)
        print('Epoch {}: Train Loss {}, Train R2: {}, Train MAE: {}'.format(
            epoch, running_loss, r2_score(Y['true'], Y['pred']),
            mean_absolute_error(Y['true'], Y['pred'])))
        print(classification_report(Y['true'], Y['pred']))
        scheduler.step()
        capsnet.train(False)
        running_loss = np.zeros((3, )).astype(float)
        Y = {
            'true': [],
            'pred': [],
            'embeddings': [],
            'embeddings2': [],
            'embeddings3': [],
            'routing_weights': []
        }
        with torch.no_grad():
            for i, batch in enumerate(dataloader_v):
                x_orig = batch[0]
                y_true = batch[-1]
                module_x = batch[1:-1]
                if torch.cuda.is_available():
                    x_orig = x_orig.cuda()
                    y_true = y_true.cuda()
                    module_x = [mod.cuda() for mod in module_x]
                x_orig, x_hat, y_pred, embedding, primary_caps_out = capsnet(
                    x_orig, module_x)
                #print(primary_caps_out.size())
                routing_coefs = capsnet.caps_output_layer.return_routing_coef(
                ).detach().cpu().numpy()
                if not i:
                    Y['routing_weights'] = pd.DataFrame(
                        routing_coefs[0, ..., 0].T,
                        index=dataset.binarizer.classes_,
                        columns=module_names)
                else:
                    Y['routing_weights'] += pd.DataFrame(
                        routing_coefs[0, ..., 0].T,
                        index=dataset.binarizer.classes_,
                        columns=module_names)
                Y['embeddings3'].append(
                    torch.cat(
                        [primary_caps_out[i] for i in range(x_orig.size(0))],
                        dim=0).detach().cpu().numpy())
                primary_caps_out = primary_caps_out.view(
                    primary_caps_out.size(0),
                    primary_caps_out.size(1) * primary_caps_out.size(2))
                Y['embeddings'].append(embedding.detach().cpu().numpy())
                Y['embeddings2'].append(
                    primary_caps_out.detach().cpu().numpy())
                loss, margin_loss, recon_loss = capsnet.calculate_loss(
                    x_orig, x_hat, y_pred, y_true)
                val_loss = margin_loss.item()  #print(loss)
                running_loss = running_loss + np.array(
                    [loss.item(), margin_loss.item(),
                     recon_loss.item()])
                Y['true'].extend(
                    y_true.argmax(1).detach().cpu().numpy().tolist())
                Y['pred'].extend(
                    (y_pred**2
                     ).sum(2).argmax(1).detach().cpu().numpy().tolist())
            running_loss /= (i + 1)
            Y['routing_weights'].iloc[:, :] = Y['routing_weights'].values / (
                i + 1)

        Y['pred'] = np.array(Y['pred']).astype(str)
        Y['true'] = np.array(Y['true']).astype(str)
        #np.save('results/routing_weights/routing_weights.{}.npy'.format(epoch),Y['routing_weights'])
        pickle.dump(
            Y['routing_weights'],
            open('results/routing_weights/routing_weights.{}.p'.format(epoch),
                 'wb'))
        Y['embeddings'] = pd.DataFrame(PCA(n_components=2).fit_transform(
            np.vstack(Y['embeddings'])),
                                       columns=['x', 'y'])
        Y['embeddings2'] = pd.DataFrame(PCA(n_components=2).fit_transform(
            np.vstack(Y['embeddings2'])),
                                        columns=['x', 'y'])
        #print(list(map(lambda x: x.shape,Y['embeddings3'])))
        Y['embeddings3'] = pd.DataFrame(PCA(n_components=2).fit_transform(
            np.vstack(Y['embeddings3'])),
                                        columns=['x', 'y'])  #'z'
        Y['embeddings']['color'] = Y['true']
        Y['embeddings2']['color'] = Y['true']
        Y['embeddings3']['color'] = module_names * ma_v.beta.shape[
            0]  #Y['true']
        Y['embeddings3']['name'] = list(
            reduce(lambda x, y: x + y, [[i] * n_primary for i in Y['true']]))
        fig = px.scatter(Y['embeddings3'],
                         x="x",
                         y="y",
                         color="color",
                         symbol='name')  #, text='name')
        py.plot(fig,
                filename='figures/embeddings3/embeddings3.{}.pos.html'.format(
                    epoch),
                auto_open=False)
        #Y['embeddings3']['color']=list(reduce(lambda x,y:x+y,[[i]*n_primary for i in Y['true']]))
        fig = px.scatter(Y['embeddings3'], x="x", y="y",
                         color="name")  #, text='color')
        py.plot(fig,
                filename='figures/embeddings3/embeddings3.{}.true.html'.format(
                    epoch),
                auto_open=False)
        fig = px.scatter(Y['embeddings'], x="x", y="y", color="color")
        py.plot(fig,
                filename='figures/embeddings/embeddings.{}.true.html'.format(
                    epoch),
                auto_open=False)
        fig = px.scatter(Y['embeddings2'], x="x", y="y", color="color")
        py.plot(fig,
                filename='figures/embeddings2/embeddings2.{}.true.html'.format(
                    epoch),
                auto_open=False)
        Y['embeddings'].loc[:, 'color'] = Y['pred']
        Y['embeddings2'].loc[:, 'color'] = Y['pred']
        fig = px.scatter(Y['embeddings'], x="x", y="y", color="color")
        py.plot(fig,
                filename='figures/embeddings/embeddings.{}.pred.html'.format(
                    epoch),
                auto_open=False)
        fig = px.scatter(Y['embeddings2'], x="x", y="y", color="color")
        py.plot(fig,
                filename='figures/embeddings2/embeddings2.{}.pred.html'.format(
                    epoch),
                auto_open=False)
        print(
            'Epoch {}: Val Loss {}, Margin Loss {}, Recon Loss {}, Val R2: {}, Val MAE: {}'
            .format(
                epoch, running_loss[0], running_loss[1], running_loss[2],
                r2_score(Y['true'].astype(int), Y['pred'].astype(int)),
                mean_absolute_error(Y['true'].astype(int),
                                    Y['pred'].astype(int))))
        print(classification_report(Y['true'], Y['pred']))
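The squash nonlinearity used by both PrimaryCaps and CapsLayer above rescales each capsule vector to a length below 1 while preserving its direction. A minimal standalone sketch with a made-up vector:

import torch

def squash(x):
    # ||v||^2 / (1 + ||v||^2) * v / ||v||, computed along the last axis
    squared_norm = (x ** 2).sum(-1, keepdim=True)
    return squared_norm * x / ((1. + squared_norm) * torch.sqrt(squared_norm))

v = torch.tensor([[3., 4.]])  # norm 5
print(squash(v))              # ~tensor([[0.5769, 0.7692]]); norm 25/26 ~ 0.96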
Example #27
def train_predict(train_pkl,
                  test_pkl,
                  input_vae_pkl,
                  output_dir,
                  cuda,
                  interest_cols,
                  categorical,
                  disease_only,
                  hidden_layer_topology,
                  learning_rate_vae,
                  learning_rate_mlp,
                  weight_decay,
                  dropout_p,
                  n_epochs,
                  scheduler='null',
                  decay=0.5,
                  t_max=10,
                  eta_min=1e-6,
                  t_mult=2,
                  batch_size=50,
                  val_pkl='val_methyl_array.pkl',
                  n_workers=8,
                  add_validation_set=False,
                  loss_reduction='sum',
                  add_softmax=False):
    """Fine-tune a pretrained VAE with an MLP head to predict the phenotype columns of interest."""
    os.makedirs(output_dir, exist_ok=True)

    output_file = join(output_dir, 'results.csv')
    training_curve_file = join(output_dir, 'training_val_curve.p')
    results_file = join(output_dir, 'results.p')
    output_file_latent = join(output_dir, 'latent.csv')
    output_model = join(output_dir, 'output_model.p')
    output_pkl = join(output_dir, 'vae_mlp_methyl_arr.pkl')
    output_onehot_encoder = join(output_dir, 'one_hot_encoder.p')

    #input_dict = pickle.load(open(input_pkl,'rb'))
    if cuda:
        vae_model = torch.load(input_vae_pkl)
        vae_model.cuda_on = True
    else:
        vae_model = torch.load(input_vae_pkl, map_location='cpu')
        vae_model.cuda_on = False

    train_methyl_array, val_methyl_array, test_methyl_array = MethylationArray.from_pickle(
        train_pkl
    ), MethylationArray.from_pickle(val_pkl), MethylationArray.from_pickle(
        test_pkl
    )  #methyl_array.split_train_test(train_p=train_percent, stratified=(True if categorical else False), disease_only=disease_only, key=interest_cols[0], subtype_delimiter=',')

    if not categorical:
        train_methyl_array.remove_na_samples(interest_cols)
        val_methyl_array.remove_na_samples(interest_cols)
        test_methyl_array.remove_na_samples(interest_cols)

    print(train_methyl_array.beta.shape)
    print(val_methyl_array.beta.shape)
    print(test_methyl_array.beta.shape)

    if (len(interest_cols) == 1 and disease_only
            and not interest_cols[0].endswith('_only')):
        print(interest_cols)
        interest_cols[0] += '_only'
        print(train_methyl_array.pheno[interest_cols[0]].unique())
        print(test_methyl_array.pheno[interest_cols[0]].unique())

    train_methyl_dataset = get_methylation_dataset(
        train_methyl_array,
        interest_cols,
        categorical=categorical,
        predict=True)  # train, test split? Add val set?
    #print(list(train_methyl_dataset.encoder.get_feature_names()))
    val_methyl_dataset = get_methylation_dataset(
        val_methyl_array,
        interest_cols,
        categorical=categorical,
        predict=True,
        categorical_encoder=train_methyl_dataset.encoder)
    test_methyl_dataset = get_methylation_dataset(
        test_methyl_array,
        interest_cols,
        categorical=categorical,
        predict=True,
        categorical_encoder=train_methyl_dataset.encoder)

    if not batch_size:
        batch_size = len(train_methyl_dataset)
    train_batch_size = min(batch_size, len(train_methyl_dataset))
    val_batch_size = min(batch_size, len(val_methyl_dataset))

    train_methyl_dataloader = DataLoader(dataset=train_methyl_dataset,
                                         num_workers=n_workers,
                                         batch_size=train_batch_size,
                                         shuffle=True)

    val_methyl_dataloader = DataLoader(dataset=val_methyl_dataset,
                                       num_workers=n_workers,
                                       batch_size=val_batch_size,
                                       shuffle=True)  # False

    test_methyl_dataloader = DataLoader(dataset=test_methyl_dataset,
                                        num_workers=n_workers,
                                        batch_size=min(
                                            batch_size,
                                            len(test_methyl_dataset)),
                                        shuffle=False)

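    # Rescales summed validation losses to account for samples dropped when
    # the dataset size is not a multiple of the batch size (assumed intent).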
    scaling_factors = dict(
        val=float(len(val_methyl_dataset)) /
        ((len(val_methyl_dataset) // val_batch_size) * val_batch_size),
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size)

    model = VAE_MLP(vae_model=vae_model,
                    categorical=categorical,
                    hidden_layer_topology=hidden_layer_topology,
                    n_output=train_methyl_dataset.outcome_col.shape[1],
                    dropout_p=dropout_p,
                    add_softmax=add_softmax)

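    # Inverse-frequency class weights, normalized to sum to 1, to counter
    # class imbalance in the cross-entropy loss configured below.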
    class_weights = []
    if categorical:
        out_weight = Counter(
            np.argmax(train_methyl_dataset.outcome_col, axis=1))
        #total_samples=sum(out_weight.values())
        for k in sorted(list(out_weight.keys())):
            class_weights.append(1. / float(out_weight[k]))  # total_samples
        class_weights = np.array(class_weights)
        class_weights = (class_weights / class_weights.sum()).tolist()
        print(class_weights)

    if class_weights:
        class_weights = torch.FloatTensor(class_weights)
        if cuda:
            class_weights = class_weights.cuda()
    else:
        class_weights = None

    optimizer_vae = torch.optim.Adam(model.vae.parameters(),
                                     lr=learning_rate_vae,
                                     weight_decay=weight_decay)
    optimizer_mlp = torch.optim.Adam(model.mlp.parameters(),
                                     lr=learning_rate_mlp,
                                     weight_decay=weight_decay)
    loss_fn = CrossEntropyLoss(
        reduction=loss_reduction,
        weight=class_weights) if categorical else MSELoss(
            reduction=loss_reduction)  # 'sum'
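    # Learning-rate schedule configuration; T_max/eta_min/T_mult suggest
    # cosine annealing with warm restarts when scheduler != 'null' (assumed).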
    scheduler_opts = dict(scheduler=scheduler,
                          lr_scheduler_decay=decay,
                          T_max=t_max,
                          eta_min=eta_min,
                          T_mult=t_mult)
    vae_mlp = MLPFinetuneVAE(mlp_model=model,
                             n_epochs=n_epochs,
                             categorical=categorical,
                             loss_fn=loss_fn,
                             optimizer_vae=optimizer_vae,
                             optimizer_mlp=optimizer_mlp,
                             cuda=cuda,
                             scheduler_opts=scheduler_opts)
    if add_validation_set:
        vae_mlp.add_validation_set(val_methyl_dataloader)
    vae_mlp = vae_mlp.fit(train_methyl_dataloader)
    if 'encoder' in dir(train_methyl_dataset):
        pickle.dump(train_methyl_dataset.encoder,
                    open(output_onehot_encoder, 'wb'))
    results = dict(test={}, train={}, val={})
    results['train']['y_pred'], results['train'][
        'y_true'], _, _ = vae_mlp.predict(train_methyl_dataloader)
    results['val']['y_pred'], results['val']['y_true'], _, _ = vae_mlp.predict(
        val_methyl_dataloader)
    del train_methyl_dataloader, train_methyl_dataset
    """methyl_dataset=get_methylation_dataset(methyl_array,interest_cols,predict=True)
    methyl_dataset_loader = DataLoader(
        dataset=methyl_dataset,
        num_workers=9,
        batch_size=1,
        shuffle=False)"""
    Y_pred, Y_true, latent_projection, _ = vae_mlp.predict(
        test_methyl_dataloader
    )  # FIXME change to include predictions for all classes for AUC
    results['test']['y_pred'], results['test']['y_true'] = copy.deepcopy(
        Y_pred), copy.deepcopy(Y_true)
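    # For categorical outcomes, collapse one-hot/probability outputs to class
    # indices before building the results table (assumes one-hot encoding).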
    if categorical:
        Y_true = Y_true.argmax(axis=1)[:, np.newaxis]
        Y_pred = Y_pred.argmax(axis=1)[:, np.newaxis]
    test_methyl_array = test_methyl_dataset.to_methyl_array()
    """if categorical:
        Y_true=test_methyl_dataset.encoder.inverse_transform(Y_true)[:,np.newaxis]
        Y_pred=test_methyl_dataset.encoder.inverse_transform(Y_pred)[:,np.newaxis]"""
    #sample_names = np.array(list(test_methyl_array.beta.index)) # FIXME
    #outcomes = np.array([outcome[0] for outcome in outcomes]) # FIXME
    Y_pred = pd.DataFrame(
        Y_pred.flatten() if (np.array(Y_pred.shape) == 1).any() else Y_pred,
        index=test_methyl_array.beta.index,
        columns=(['y_pred'] if categorical else
                 interest_cols))  #dict(zip(sample_names,outcomes))
    Y_true = pd.DataFrame(
        Y_true.flatten() if (np.array(Y_true.shape) == 1).any() else Y_true,
        index=test_methyl_array.beta.index,
        columns=(['y_true'] if categorical else interest_cols))
    if categorical:
        results_df = pd.concat([Y_pred, Y_true], axis=1)
    else:
        results_df = pd.concat([
            Y_pred.rename(columns={name: name + '_pred' for name in list(Y_pred)}),
            Y_true.rename(columns={name: name + '_true' for name in list(Y_true)})
        ], axis=1)
    latent_projection = pd.DataFrame(latent_projection,
                                     index=test_methyl_array.beta.index)
    test_methyl_array.beta = latent_projection
    test_methyl_array.write_pickle(output_pkl)
    pickle.dump(results, open(results_file, 'wb'))
    pickle.dump(vae_mlp.training_plot_data, open(training_curve_file, 'wb'))
    latent_projection.to_csv(output_file_latent)
    torch.save(vae_mlp.model, output_model)
    results_df.to_csv(
        output_file)  #pickle.dump(outcome_dict, open(outcome_dict_file,'wb'))
    return latent_projection, Y_pred, Y_true, vae_mlp, scaling_factors
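A minimal invocation sketch (hedged: every path, column name, and hyperparameter below is an assumption for illustration; the pickled arrays and the pre-trained VAE from the embedding step must already exist):

# Hypothetical usage of train_predict; all argument values are assumptions.
latent, y_pred, y_true, vae_mlp, scales = train_predict(
    train_pkl='train_val_test_sets/train_methyl_array.pkl',
    test_pkl='train_val_test_sets/test_methyl_array.pkl',
    input_vae_pkl='embeddings/output_model.p',
    output_dir='predictions',
    cuda=False,
    interest_cols=['disease'],
    categorical=True,
    disease_only=False,
    hidden_layer_topology=[100, 100],
    learning_rate_vae=1e-5,
    learning_rate_mlp=1e-3,
    weight_decay=1e-4,
    dropout_p=0.5,
    n_epochs=50,
    val_pkl='train_val_test_sets/val_methyl_array.pkl',
    add_validation_set=True)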
Example No. 28
def embed_vae(train_pkl,
              output_dir,
              cuda,
              n_latent,
              lr,
              weight_decay,
              n_epochs,
              hidden_layer_encoder_topology,
              kl_warm_up=0,
              beta=1.,
              scheduler='null',
              decay=0.5,
              t_max=10,
              eta_min=1e-6,
              t_mult=2,
              bce_loss=False,
              batch_size=50,
              val_pkl='val_methyl_array.pkl',
              n_workers=9,
              convolutional=False,
              height_kernel_sizes=[],
              width_kernel_sizes=[],
              add_validation_set=False,
              loss_reduction='sum',
              stratify_column='disease'):
    from methylnet.models import AutoEncoder, TybaltTitusVAE, CVAE  # CVAE is used by the convolutional branch below
    from methylnet.datasets import get_methylation_dataset
    import torch
    from torch.utils.data import DataLoader
    from torch.nn import MSELoss, BCELoss
    os.makedirs(output_dir, exist_ok=True)

    output_file = join(output_dir, 'output_latent.csv')
    output_model = join(output_dir, 'output_model.p')
    training_curve_file = join(output_dir, 'training_val_curve.p')
    outcome_dict_file = join(output_dir, 'output_outcomes.p')
    output_pkl = join(output_dir, 'vae_methyl_arr.pkl')

    #input_dict = pickle.load(open(input_pkl,'rb'))
    #methyl_array=MethylationArray(*extract_pheno_beta_df_from_pickle_dict(input_dict))
    #print(methyl_array.beta)
    train_methyl_array = MethylationArray.from_pickle(train_pkl)
    val_methyl_array = MethylationArray.from_pickle(val_pkl)
    #methyl_array.split_train_test(train_p=train_percent, stratified=True, disease_only=True, key='disease', subtype_delimiter=',')

    train_methyl_dataset = get_methylation_dataset(
        train_methyl_array, stratify_column)  # train, test split? Add val set?

    val_methyl_dataset = get_methylation_dataset(val_methyl_array,
                                                 stratify_column)

    if not batch_size:
        batch_size = len(train_methyl_dataset)

    train_batch_size = min(batch_size, len(train_methyl_dataset))
    val_batch_size = min(batch_size, len(val_methyl_dataset))

    train_methyl_dataloader = DataLoader(
        dataset=train_methyl_dataset,
        num_workers=n_workers,  #n_workers
        batch_size=train_batch_size,
        shuffle=True,
        pin_memory=False)

    val_methyl_dataloader = DataLoader(dataset=val_methyl_dataset,
                                       num_workers=n_workers,
                                       batch_size=val_batch_size,
                                       shuffle=True,
                                       pin_memory=False)

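    # Rescale summed losses for samples dropped when the dataset size is not
    # a multiple of the batch size (assumed intent; mirrors train_predict).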
    scaling_factors = dict(
        train=float(len(train_methyl_dataset)) /
        ((len(train_methyl_dataset) // train_batch_size) * train_batch_size),
        val=float(len(val_methyl_dataset)) /
        ((len(val_methyl_dataset) // val_batch_size) * val_batch_size),
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size)
    print('SCALE', len(train_methyl_dataset), len(val_methyl_dataset),
          train_batch_size, val_batch_size, scaling_factors)
    n_input = train_methyl_array.return_shape()[1]
    if not convolutional:
        model = TybaltTitusVAE(
            n_input=n_input,
            n_latent=n_latent,
            hidden_layer_encoder_topology=hidden_layer_encoder_topology,
            cuda=cuda)
    else:
        model = CVAE(n_latent=n_latent,
                     in_shape=train_methyl_dataset.new_shape,
                     kernel_heights=height_kernel_sizes,
                     kernel_widths=width_kernel_sizes,
                     n_pre_latent=n_latent * 2)  # change soon

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    loss_fn = BCELoss(reduction=loss_reduction) if bce_loss else MSELoss(
        reduction=loss_reduction)  # 'sum'
    scheduler_opts = dict(scheduler=scheduler,
                          lr_scheduler_decay=decay,
                          T_max=t_max,
                          eta_min=eta_min,
                          T_mult=t_mult)
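    # kl_warm_up delays the KL-divergence term and beta weights it,
    # i.e. beta-VAE-style annealing (inferred from the parameter names).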
    auto_encoder = AutoEncoder(autoencoder_model=model,
                               n_epochs=n_epochs,
                               loss_fn=loss_fn,
                               optimizer=optimizer,
                               cuda=cuda,
                               kl_warm_up=kl_warm_up,
                               beta=beta,
                               scheduler_opts=scheduler_opts)
    if add_validation_set:
        auto_encoder.add_validation_set(val_methyl_dataloader)
    auto_encoder = auto_encoder.fit(train_methyl_dataloader)
    train_methyl_array = train_methyl_dataset.to_methyl_array()
    val_methyl_array = val_methyl_dataset.to_methyl_array()
    del val_methyl_dataloader, train_methyl_dataloader, val_methyl_dataset, train_methyl_dataset

    methyl_dataset = get_methylation_dataset(
        MethylationArrays([train_methyl_array, val_methyl_array]).combine(),
        stratify_column)
    methyl_dataset_loader = DataLoader(dataset=methyl_dataset,
                                       num_workers=n_workers,
                                       batch_size=1,
                                       shuffle=False)
    latent_projection, _, _ = auto_encoder.transform(methyl_dataset_loader)
    #print(latent_projection.shape)
    methyl_array = methyl_dataset.to_methyl_array()
    #sample_names = np.array([sample_name[0] for sample_name in sample_names]) # FIXME
    #outcomes = np.array([outcome[0] for outcome in outcomes]) # FIXME
    #outcome_dict=dict(zip(sample_names,outcomes))
    #print(methyl_array.beta)
    latent_projection = pd.DataFrame(latent_projection,
                                     index=methyl_array.beta.index)
    methyl_array.beta = latent_projection
    methyl_array.write_pickle(output_pkl)
    latent_projection.to_csv(output_file)
    pickle.dump(auto_encoder.training_plot_data, open(training_curve_file,
                                                      'wb'))
    torch.save(auto_encoder.model, output_model)
    #pickle.dump(outcome_dict, open(outcome_dict_file,'wb'))
    return latent_projection, None, scaling_factors, n_input, auto_encoder
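A hedged usage sketch for embed_vae (paths and hyperparameters are placeholders; passing hidden_layer_encoder_topology as a list of layer widths is an assumption):

# Hypothetical usage of embed_vae; all argument values are assumptions.
latent, _, scales, n_input, autoencoder = embed_vae(
    train_pkl='train_val_test_sets/train_methyl_array.pkl',
    output_dir='embeddings',
    cuda=False,
    n_latent=64,
    lr=5e-5,
    weight_decay=1e-4,
    n_epochs=50,
    hidden_layer_encoder_topology=[300, 200],
    kl_warm_up=20,
    val_pkl='train_val_test_sets/val_methyl_array.pkl',
    add_validation_set=True)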
Example No. 29
def model_capsnet_(
        train_methyl_array='train_val_test_sets/train_methyl_array.pkl',
        val_methyl_array='train_val_test_sets/val_methyl_array.pkl',
        interest_col='disease',
        n_epochs=10,
        n_bins=0,
        bin_len=1000000,
        min_capsule_len=300,
        primary_caps_out_len=45,
        caps_out_len=45,
        hidden_topology='30,80,50',
        gamma=1e-2,
        decoder_topology='100,300',
        learning_rate=1e-2,
        routing_iterations=3,
        overlap=0.,
        custom_loss='none',
        gamma2=1e-2,
        job=0,
        capsule_choice=['genomic_binned'],
        custom_capsule_file='',
        test_methyl_array='',
        predict=False,
        batch_size=16,
        limited_capsule_names_file='',
        gsea_superset='',
        tissue='',
        number_sets=25,
        use_set=False,
        gene_context=False,
        select_subtypes=[],
        fit_spw=False,
        l1_l2='',
        custom_capsule_file2='',
        min_capsules=5):

    capsule_choice = list(capsule_choice)
    #custom_capsule_file=list(custom_capsule_file)
    # Note: in Python 3 a filter object is always truthy, so the original
    # `if hlt_list:` check never took the else branch; materialize the list first.
    hlt_list = list(filter(None, hidden_topology.split(',')))
    hidden_topology = list(map(int, hlt_list)) if hlt_list else []
    dlt_list = list(filter(None, decoder_topology.split(',')))
    decoder_topology = list(map(int, dlt_list)) if dlt_list else []

    hidden_caps_layers = []
    include_last = False

    ma = MethylationArray.from_pickle(train_methyl_array)
    ma_v = MethylationArray.from_pickle(val_methyl_array)
    if test_methyl_array and predict:
        ma_t = MethylationArray.from_pickle(test_methyl_array)

    try:
        ma.remove_na_samples(interest_col)
        ma_v.remove_na_samples(interest_col)
        if test_methyl_array and predict:
            ma_t.remove_na_samples(interest_col)
    except Exception:
        pass  # tolerate arrays where the outcome column has no NA samples to drop

    if select_subtypes:
        print(ma.pheno[interest_col].unique())
        ma.pheno = ma.pheno.loc[ma.pheno[interest_col].isin(select_subtypes)]
        ma.beta = ma.beta.loc[ma.pheno.index]
        ma_v.pheno = ma_v.pheno.loc[ma_v.pheno[interest_col].isin(
            select_subtypes)]
        ma_v.beta = ma_v.beta.loc[ma_v.pheno.index]
        print(ma.pheno[interest_col].unique())

        if test_methyl_array and predict:
            ma_t.pheno = ma_t.pheno.loc[ma_t.pheno[interest_col].isin(
                select_subtypes)]
            ma_t.beta = ma_t.beta.loc[ma_t.pheno.index]

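    # Reuse cached capsules when available; otherwise build them and, if a
    # cache path was given, save them so later runs skip capsule construction.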
    if custom_capsule_file2 and os.path.exists(custom_capsule_file2):
        capsules_dict = torch.load(custom_capsule_file2)
        final_modules, modulecpgs, module_names = capsules_dict[
            'final_modules'], capsules_dict['modulecpgs'], capsules_dict[
                'module_names']
        if min_capsule_len > 1:
            include_capsules = [
                len(x) > min_capsule_len for x in final_modules
            ]
            final_modules = [
                final_modules[i] for i in range(len(final_modules))
                if include_capsules[i]
            ]
            module_names = [
                module_names[i] for i in range(len(module_names))
                if include_capsules[i]
            ]
            modulecpgs = (reduce(np.union1d, final_modules)).tolist()

    else:
        final_modules, modulecpgs, module_names = build_capsules(
            capsule_choice, overlap, bin_len, ma, include_last,
            min_capsule_len, custom_capsule_file, gsea_superset, tissue,
            gene_context, use_set, number_sets, limited_capsule_names_file)
        if custom_capsule_file2:
            torch.save(
                dict(final_modules=final_modules,
                     modulecpgs=modulecpgs,
                     module_names=module_names), custom_capsule_file2)

    assert len(final_modules) >= min_capsules, \
        "Only {} capsules were built; at least {} are required.".format(
            len(final_modules), min_capsules)

    if fit_spw:
        modulecpgs = list(reduce(lambda x, y: np.hstack((x, y)),
                                 final_modules))

    if not include_last:  # subset the beta matrices to the CpGs covered by the capsules
        ma.beta = ma.beta.loc[:, modulecpgs]
        ma_v.beta = ma_v.beta.loc[:, modulecpgs]
        if test_methyl_array and predict:
            ma_t.beta = ma_t.beta.loc[:, modulecpgs]
    # https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb
    original_interest_col = interest_col
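    # Optionally discretize a continuous outcome into n_bins categories;
    # val/test reuse the bin edges learned on the training set.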
    if n_bins:
        new_interest_col = interest_col + '_binned'
        ma.pheno.loc[:,
                     new_interest_col], bins = pd.cut(ma.pheno[interest_col],
                                                      bins=n_bins,
                                                      retbins=True)
        ma_v.pheno.loc[:,
                       new_interest_col], _ = pd.cut(ma_v.pheno[interest_col],
                                                     bins=bins,
                                                     retbins=True)
        if test_methyl_array and predict:
            ma_t.pheno.loc[:, new_interest_col], _ = pd.cut(
                ma_t.pheno[interest_col], bins=bins, retbins=True)
        interest_col = new_interest_col

    datasets = dict()

    datasets['train'] = MethylationDataset(
        ma,
        interest_col,
        modules=final_modules,
        module_names=module_names,
        original_interest_col=original_interest_col,
        run_spw=fit_spw)
    print(datasets['train'].X.isnull().sum().sum())
    datasets['val'] = MethylationDataset(
        ma_v,
        interest_col,
        modules=final_modules,
        module_names=module_names,
        original_interest_col=original_interest_col,
        run_spw=fit_spw)
    if test_methyl_array and predict:
        datasets['test'] = MethylationDataset(
            ma_t,
            interest_col,
            modules=final_modules,
            module_names=module_names,
            original_interest_col=original_interest_col,
            run_spw=fit_spw)

    dataloaders = dict()

    dataloaders['train'] = DataLoader(datasets['train'],
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True,
                                      drop_last=True)
    dataloaders['val'] = DataLoader(datasets['val'],
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    drop_last=False)
    n_primary = len(final_modules)
    if test_methyl_array and predict:
        dataloaders['test'] = DataLoader(datasets['test'],
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True,
                                         drop_last=False)

    n_inputs = list(map(len, final_modules))

    n_out_caps = len(datasets['train'].y_unique)

    if not fit_spw:
        print("Not fitting MethylSPWNet")
        primary_caps = PrimaryCaps(modules=final_modules,
                                   hidden_topology=hidden_topology,
                                   n_output=primary_caps_out_len)
        hidden_caps = []
        output_caps = CapsLayer(n_out_caps,
                                n_primary,
                                primary_caps_out_len,
                                caps_out_len,
                                routing_iterations=routing_iterations)
        decoder = Decoder(n_out_caps * caps_out_len, len(list(ma.beta)),
                          decoder_topology)
        model = CapsNet(primary_caps,
                        hidden_caps,
                        output_caps,
                        decoder,
                        gamma=gamma)

        if test_methyl_array and predict:
            model.load_state_dict(torch.load('capsnet_model.pkl'))

    else:
        print("Fitting MethylSPWNet")
        module_lens = [len(x) for x in final_modules]
        model = MethylSPWNet(module_lens,
                             hidden_topology,
                             dropout_p=0.2,
                             n_output=n_out_caps)
        if test_methyl_array and predict:
            model.load_state_dict(torch.load('spwnet_model.pkl'))

    if torch.cuda.is_available():
        model = model.cuda()

    # extract all c_ij for all layers across all batches, or just last batch

    # l1_l2 is a comma-separated "L1,L2" pair of regularization strengths, used only when fitting MethylSPWNet.
    if l1_l2 and fit_spw:
        l1, l2 = list(map(float, l1_l2.split(',')))
    elif fit_spw:
        l1, l2 = 0., 0.

    trainer = Trainer(model=model,
                      validation_dataloader=dataloaders['val'],
                      n_epochs=n_epochs,
                      lr=learning_rate,
                      n_primary=n_primary,
                      custom_loss=custom_loss,
                      gamma2=gamma2,
                      spw_mode=fit_spw,
                      l1=l1 if fit_spw else 0.,
                      l2=l2 if fit_spw else 0.)

    if not predict:
        try:
            #assert 1==2
            trainer.fit(dataloader=dataloaders['train'])
            val_loss = min(trainer.val_losses)
            torch.save(
                trainer.model.state_dict(),
                'capsnet_model.pkl' if not fit_spw else 'spwnet_model.pkl')
            if fit_spw:
                torch.save(
                    dict(final_modules=final_modules,
                         modulecpgs=modulecpgs,
                         module_names=module_names), 'spwnet_capsules.pkl')
                torch.save(
                    dict(module_names=module_names,
                         module_lens=module_lens,
                         dropout_p=0.2,
                         hidden_topology=hidden_topology,
                         n_output=n_out_caps), 'spwnet_config.pkl')
        except Exception as e:
            print(e)
            val_loss = -2  # sentinel: training failed

        with sqlite3.connect('jobs.db', check_same_thread=False) as conn:
            pd.DataFrame([job, val_loss],
                         index=['job', 'val_loss'],
                         columns=[0]).T.to_sql('val_loss',
                                               conn,
                                               if_exists='append')
    else:
        if test_methyl_array:
            trainer.weights = 1.
            Y = trainer.predict(dataloaders['test'])
            pickle.dump(Y, open('predictions.pkl', 'wb'))
            val_loss = -1  # sentinel: prediction-only run, no validation loss
    #print(val_loss)
    # print([min(trainer.val_losses),n_epochs,
    # 		n_bins,
    # 		bin_len,
    # 		min_capsule_len,
    # 		primary_caps_out_len,
    # 		caps_out_len,
    # 		hidden_topology,
    # 		gamma,
    # 		decoder_topology,
    # 		learning_rate,
    # 		routing_iterations])

    return val_loss
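A usage sketch for model_capsnet_ (the pickle paths mirror the function's own defaults; the remaining values are assumptions for illustration):

# Hypothetical MethylCapsNet training run; hyperparameters are assumptions.
val_loss = model_capsnet_(
    train_methyl_array='train_val_test_sets/train_methyl_array.pkl',
    val_methyl_array='train_val_test_sets/val_methyl_array.pkl',
    interest_col='disease',
    n_epochs=10,
    capsule_choice=['genomic_binned'],
    bin_len=1000000,
    min_capsule_len=300,
    batch_size=16)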
Example No. 30
def print_shape(input_pkl):
    """Print dimensions of beta matrix."""
    print(MethylationArray.from_pickle(input_pkl).beta.shape)
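For example (the path is a placeholder):

# Prints the (n_samples, n_cpgs) dimensions of the stored beta matrix.
print_shape('train_val_test_sets/train_methyl_array.pkl')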