Пример #1
0
def mageckmle_main(pvargs=None, parsedargs=None, returndict=False):
    """Main entry for MAGeCK MLE.

    Parameters
    ----------
    pvargs
        Arguments for parsing; only used when ``parsedargs`` is None.
    parsedargs
        Already-parsed argument namespace; takes precedence over ``pvargs``.
    returndict
        If true, do not run the whole prediction process; return right after
        the mean-variance modeling step.

    Returns
    -------
    tuple
        ``(allgenedict, mrm, size_f)`` when ``returndict`` is true, otherwise
        ``(allgenedict, mrm)``. ``mrm`` is None when
        ``args.genes_varmodeling`` is not positive; ``size_f`` is None when
        ``args.norm_method == 'none'``.
    """
    # parsing arguments
    if parsedargs is not None:
        args = parsedargs
    else:
        args = mageckmle_parseargs(pvargs)
    args = mageckmle_postargs(args)
    # Bayes module: currently disabled -- the program exits before
    # mageck_bayes_main is reached. Remove the first sys.exit(0) once the
    # Bayes module is completed.
    # NOTE(review): unlike the imports below, this import is not prefixed with
    # ``mageck.``; confirm it resolves in the installed package.
    if hasattr(args, 'bayes') and args.bayes:
        from mlemageck_bayes import mageck_bayes_main
        sys.exit(0)  # comment this when you think bayes module is completed
        mageck_bayes_main(parsedargs=args)
        sys.exit(0)
    import scipy
    from scipy.stats import nbinom
    import numpy as np
    import numpy.linalg as linalg
    from mageck.mleinstanceio import read_gene_from_file, write_gene_to_file, write_sgrna_to_file
    from mageck.mleem import iteratenbem
    from mageck.mlemeanvar import MeanVarModel
    from mageck.mageckCount import normalizeCounts
    from mageck.mlesgeff import read_sgrna_eff, sgrna_eff_initial_guess
    from mageck.dispersion_characterization import sgrna_wide_dispersion_estimation_MAP_v2
    from mageck.mlemultiprocessing import runem_multiproc, iteratenbem_permutation, iteratenbem_permutation_by_nsg
    # highestCNVgenes is needed by the copy-number correction step below when
    # --cnv-est is used (it was previously never imported -> NameError).
    from mageck.cnv_normalization import read_CNVdata, match_sgrnaCN, betascore_piecewisenorm, highestCNVgenes
    from mageck.cnv_estimation import mageckmleCNVestimation

    # main process
    maxfittinggene = args.genes_varmodeling
    maxgene = np.inf
    # reading sgRNA efficiency
    read_sgrna_eff(args)
    # reading read count table
    allgenedict = read_gene_from_file(args.count_table,
                                      includesamples=args.include_samples)
    sgrna_eff_initial_guess(args, allgenedict)

    # calculate the size factor from the per-sgRNA read counts
    cttab_sel = {}
    for (geneid, gk) in allgenedict.items():
        sgid = gk.sgrnaid
        sgreadmat = gk.nb_count.getT().tolist()
        for i in range(len(sgid)):
            cttab_sel[sgid[i]] = sgreadmat[i]
    if hasattr(args, 'norm_method'):
        if args.norm_method != 'none':
            size_f = normalizeCounts(cttab_sel,
                                     method=args.norm_method,
                                     returnfactor=True,
                                     reversefactor=True,
                                     controlsgfile=args.control_sgrna)
        else:
            size_f = None
    else:
        size_f = normalizeCounts(cttab_sel,
                                 returnfactor=True,
                                 reversefactor=True)
    if size_f is not None:
        logging.info('size factor: ' + ','.join([str(x) for x in size_f]))

    # attach the design matrix to every gene instance
    desmat = args.design_matrix
    ngene = 0
    for tginst in allgenedict.values():
        tginst.design_mat = desmat

    # perform copy number estimation if option selected
    if args.cnv_est is not None:
        logging.info('Performing copy number estimation...')
        # organize sgRNA-gene pairing into dictionary
        sgrna2genelist = {
            sgrna: gene
            for gene in allgenedict for sgrna in allgenedict[gene].sgrnaid
        }
        # estimate CNV and write results to file
        mageckmleCNVestimation(args.cnv_est, cttab_sel, desmat, sgrna2genelist,
                               args.beta_labels[1:], args.output_prefix)

    # run the EM for a subset of genes to fit the mean-variance model
    meanvardict = {}
    for (tgid, tginst) in allgenedict.items():
        ngene += 1
        tginst.w_estimate = []
        meanvardict[tgid] = tginst
        if ngene > maxfittinggene:
            break
    argsdict = {
        'debug': False,
        'alpha_val': 0.01,
        'estimateeff': False,
        'size_factor': size_f
    }
    runem_multiproc(meanvardict, args, nproc=args.threads, argsdict=argsdict)
    for (tgid, tginst) in meanvardict.items():
        allgenedict[tgid] = tginst

    # model the mean and variance
    logging.info('Modeling the mean and variance ...')
    if maxfittinggene > 0:
        mrm = MeanVarModel()
        # linear model of the mean/variance residuals
        mrm.get_mean_var_residule(allgenedict)
        mrm.model_mean_var_by_lm()
    else:
        mrm = None

    if returndict:
        return (allgenedict, mrm, size_f)
    # run the EM for all genes, this time with the fitted mean-variance model
    logging.info('Run the algorithm for the second time ...')
    if hasattr(args, 'threads') and args.threads > 0:
        # multiple threads
        argsdict = {
            'debug': False,
            'estimateeff': True,
            'meanvarmodel': mrm,
            'restart': False,
            'removeoutliers': args.remove_outliers,
            'size_factor': size_f,
            'updateeff': args.update_efficiency,
            'logem': False
        }
        runem_multiproc(allgenedict,
                        args,
                        nproc=args.threads,
                        argsdict=argsdict)
    else:
        # only 1 thread
        # TODO: merge with the multi-threaded code path above
        ngene = 0
        for (tgid, tginst) in allgenedict.items():
            if ngene % 1000 == 1 or args.debug:
                logging.info('Calculating ' + tgid + ' (' + str(ngene) +
                             ') ... ')
            if hasattr(
                    args, 'debug_gene'
            ) and args.debug_gene is not None and tginst.prefix != args.debug_gene:
                continue
            iteratenbem(tginst,
                        debug=False,
                        estimateeff=True,
                        meanvarmodel=mrm,
                        restart=False,
                        removeoutliers=args.remove_outliers,
                        size_factor=size_f,
                        updateeff=args.update_efficiency)
            ngene += 1
            if ngene > maxgene:
                break
    # set up the w vector for genes that did not get one from the EM
    for tginst in allgenedict.values():
        if len(tginst.w_estimate) == 0:
            tginst.w_estimate = np.ones(len(tginst.sgrnaid))
    # permutation, either by group or together
    if args.no_permutation_by_group:
        iteratenbem_permutation(allgenedict,
                                args,
                                nround=args.permutation_round,
                                removeoutliers=args.remove_outliers,
                                size_factor=size_f)
    else:
        iteratenbem_permutation_by_nsg(allgenedict, args, size_f=size_f)
    # correct for FDR
    from mageck.mleclassdef import gene_fdr_correction
    gene_fdr_correction(allgenedict, args.adjust_method)
    # correct for CNV
    if args.cnv_norm is not None or args.cnv_est is not None:
        if args.cnv_norm is not None:  # get copy number data from external copy number dataset
            logging.info('Performing copy number normalization.')
            (CN_arr, CN_celldict,
             CN_genedict) = read_CNVdata(args.cnv_norm, args.beta_labels[1:])
            genes2correct = False  # do not select only subset of genes to correct (i.e. correct all genes)
        else:  # get copy number data from copy number estimates calculated earlier
            logging.info(
                'Performing copy number normalization using copy number estimates.'
            )
            (CN_arr, CN_celldict, CN_genedict) = read_CNVdata(
                str(args.output_prefix) + 'CNVestimates.txt',
                args.beta_labels[1:])
            genes2correct = highestCNVgenes(CN_arr, CN_genedict, percentile=98)
        for label in args.beta_labels[1:]:
            if label not in CN_celldict:
                logging.warning(
                    label +
                    ' is not represented in the inputted copy number variation data.'
                )
            else:
                logging.info('Normalizing by copy number with ' +
                             label +
                             ' as the reference cell line.')
        betascore_piecewisenorm(allgenedict,
                                args.beta_labels,
                                CN_arr,
                                CN_celldict,
                                CN_genedict,
                                selectGenes=genes2correct)

    # write to file
    genefile = args.output_prefix + '.gene_summary.txt'
    sgrnafile = args.output_prefix + '.sgrna_summary.txt'
    logging.info('Writing gene results to ' + genefile)
    logging.info('Writing sgRNA results to ' + sgrnafile)
    write_gene_to_file(allgenedict,
                       genefile,
                       args,
                       betalabels=args.beta_labels)
    write_sgrna_to_file(allgenedict, sgrnafile)
    return (allgenedict, mrm)
Пример #2
0
def crispr_test(tab, ctrlg, testg, varg, destfile, sgrna2genelist, args):
    """Main function of the sgRNA-level crispr test.

    Parameters:
      tab
          Read count table: a dict mapping each sgRNA id to its per-sample
          counts (values are indexed by sample position)
      ctrlg
          Indices of control samples
      testg
          Indices of treatment samples
      varg
          Indices of samples used for variance estimation; if it has at most
          one element, the default variance estimation samples are used
      destfile
          Prefix for output files (.sgrna_summary.txt, plus .plow.txt and
          .phigh.txt when sgrna2genelist is given)
      sgrna2genelist
          {sgrna: gene} mapping, or None
      args
          Arguments
    Return value:
      (lowp, highp, lower_gene_lfc, higher_gene_lfc), or
      (None, None, None, None) when sgrna2genelist is None.
      lowp
          alpha cutoff for neg. selection: fraction of valid sgRNAs whose
          p.low is <= args.gene_test_fdr_threshold
      highp
          alpha cutoff for pos. selection: same for p.high (0.01 when no
          sgRNA passes the threshold)
      lower_gene_lfc
          {gene:lfc} dict. lfc is for neg. selection
      higher_gene_lfc
          {gene:lfc} dict. lfc is for pos. selection
    """
    n = len(tab)

    def _columns(indices):
        # For every sgRNA, keep only the sample columns whose index is in
        # `indices`, preserving the original column order.
        keep = set(indices)
        return {
            k: [x for i, x in enumerate(v) if i in keep]
            for (k, v) in tab.items()
        }

    # control and test matrix
    tabctrl = _columns(ctrlg)
    tabtest = _columns(testg)

    # perform copy number estimation if option selected
    if args.cnv_est is not None:
        from mageck.cnv_estimation import mageckCNVestimation
        logging.info('Performing copy number estimation...')
        mageckCNVestimation(args.cnv_est, tabctrl, tabtest, sgrna2genelist,
                            'CNVest', args.output_prefix)

    # control matrix for mean-var estimation
    if len(varg) > 1:
        # users specify the samples the variance is calculated from
        tabctrlmod = _columns(varg)
    elif len(ctrlg) > 1:
        # more than 1 control
        tabctrlmod = _columns(ctrlg)
    else:
        # only 1 control: use all the samples for estimation
        tabctrlmod = _columns(ctrlg + testg)
    # treatment matrix
    tabtreatmod = _columns(testg)

    # training using control samples; per the original author's note, this
    # yields the parameters (k, b) of a linear mean-variance relationship
    model = list(modelmeanvar(tabctrlmod, method='linear'))
    logging.debug('Adjusted model: ' + '\t'.join([str(x) for x in model]))

    tabctrl_mat = list(tabctrl.values())
    tabctrlmodel_mat = list(tabctrlmod.values())
    tabc_mean = getMeans(tabctrl_mat)

    # setup the valid sgRNA flag; this has to be done before the zero values
    # of tabc_mean are replaced below
    has_remove_zero = hasattr(args, "remove_zero")
    validsgrna1 = [1] * n
    if has_remove_zero:
        validsgrna1 = [
            1 if t > args.remove_zero_threshold else 0 for t in tabc_mean
        ]
    # if mean of the control samples is 0: raise it to the smallest positive
    # control mean
    positive_means = [x for x in tabc_mean if x > 0]
    if not positive_means:
        logging.error(
            'All sgRNAs have zero mean counts in the control samples. Quit..')
        sys.exit(-1)
    tabc_min = min(positive_means)
    tabc_mean = [t if t > tabc_min else tabc_min for t in tabc_mean]

    # raw variance: tabc_var is the final raw variance, tabc_adjvar the final
    # adjusted (model-regressed) variance
    t_var_c = getVars(tabctrlmodel_mat)
    t_var_t = getVars(list(tabtreatmod.values()))
    n_ctrl = len(ctrlg)
    n_test = len(testg)
    if len(varg) > 1 or n_ctrl > 1:
        # more than 1 control, or users specify that variance should be
        # calculated from certain samples; only control variances since 0.5.7
        t_var_mix = t_var_c
    elif n_ctrl == 1 and n_test == 1:
        t_var_mix = t_var_c  # older version
    else:
        # just 1 control: mix control and treatment variances, weighted by
        # their degrees of freedom
        frac_c = (n_ctrl - 1) * 1.0 / (n_ctrl - 1 + n_test - 1)
        frac_t = 1.0 - frac_c
        t_var_mix = [
            t_var_c[i] * frac_c + t_var_t[i] * frac_t
            for i in range(len(t_var_c))
        ]
        logging.info('Raw variance calculation: ' + str(frac_c) +
                     ' for control, ' + str(frac_t) + ' for treatment')
    tabc_var = t_var_mix
    # adjusted variance: predicted from the fitted (k, b) model at the control
    # means; raw variances are not mixed in
    tabc_adjvar = getadjustvar(model, tabc_mean, method='linear')

    # treatment means
    ttmat = list(tabtest.values())
    ttmean = getMeans(ttmat)
    # setup the valid sgRNA flag according to --remove-zero
    if has_remove_zero:
        validsgrna2 = [
            1 if t > args.remove_zero_threshold else 0 for t in ttmean
        ]
        if args.remove_zero == "control":
            validsgrna = validsgrna1
        elif args.remove_zero == "treatment":
            validsgrna = validsgrna2
        elif args.remove_zero == "any":
            # removed if zero in either group
            validsgrna = [validsgrna1[t] * validsgrna2[t] for t in range(n)]
        elif args.remove_zero == "both":
            # removed only if zero in both groups
            validsgrna = [
                1 - (1 - validsgrna1[t]) * (1 - validsgrna2[t])
                for t in range(n)
            ]
        else:
            validsgrna = [1] * n
    else:
        validsgrna = [1] * n
    # str(getattr(...)) so this log line cannot crash when the option is
    # absent or None
    logging.info("Before RRA, " + str(n - sum(validsgrna)) +
                 " sgRNAs are removed with zero counts in " +
                 str(getattr(args, 'remove_zero', None)) + " group(s).")
    if sum(validsgrna) == 0:
        logging.error(
            'No sgRNAs are left after --remove-zero filtering. Please double check with your data or --remove-zero associated parameters.'
        )
        sys.exit(-1)

    # calculate the p value (normal approximation, for consistency)
    try:
        tt_p_lower = getNormalPValue(tabc_mean,
                                     tabc_adjvar,
                                     ttmean,
                                     lower=True)
        tt_p_higher = getNormalPValue(tabc_mean,
                                      tabc_adjvar,
                                      ttmean,
                                      lower=False)
    except Exception:
        logging.exception(
            'An error occurs while trying to compute p values. Quit..')
        sys.exit(-1)

    # calculate tt_theta, used for RRA ranking: use qnorm to reversely
    # calculate theta from p.low for depleted sgRNAs; otherwise use the
    # standardized difference of means
    logging.info('Use qnorm to reversely calculate sgRNA scores ...')
    tt_theta_0 = [(ttmean[i] - tabc_mean[i]) / math.sqrt(tabc_adjvar[i])
                  for i in range(n)]
    # replace zero p values by a tenth of the smallest non-zero one so the
    # log is finite
    tt_p_lower_small_nz = min([x * 0.1 for x in tt_p_lower if x * 0.1 > 0])
    tt_p_lower_rpnz = [max(x, tt_p_lower_small_nz) for x in tt_p_lower]
    tt_p_lower_log = [math.log(x) for x in tt_p_lower_rpnz]  # natural log
    tt_theta_low_cvt = QNormConverter().get_qnorm(tt_p_lower_log, islog=True)
    # fall back to the original theta score where the conversion gives None
    tt_theta_low = [
        tt_theta_low_cvt[i]
        if tt_theta_low_cvt[i] is not None else tt_theta_0[i]
        for i in range(n)
    ]
    tt_theta = [
        tt_theta_low[i] if ttmean[i] < tabc_mean[i] else tt_theta_0[i]
        for i in range(n)
    ]
    tt_abstheta = [math.fabs(tt_theta[i]) for i in range(n)]

    tt_p_twosided = [
        2 * lo if lo < hi else 2 * hi
        for lo, hi in zip(tt_p_lower, tt_p_higher)
    ]
    tt_p_fdr = pFDR(tt_p_twosided, method=args.adjust_method)

    # map sgRNA to genes (needed for CNV normalization)
    sgrna_list = list(tabctrl.keys())
    if sgrna2genelist is not None:
        gene_list = [sgrna2genelist[sgrna] for sgrna in sgrna_list]
    else:
        gene_list = ['NA'] * len(sgrna_list)

    # normalize sgRNA scores by CNV if requested
    if (args.cnv_norm is not None
            and args.cell_line is not None) or args.cnv_est is not None:
        from mageck.cnv_normalization import read_CNVdata, sgRNAscore_piecewisenorm, highestCNVgenes
        logging.info('Performing copy number normalization.')
        if args.cnv_norm is not None and args.cell_line is not None:
            # reading CNV from known CNV profiles; correct all genes
            (CN_arr, CN_celldict,
             CN_genedict) = read_CNVdata(args.cnv_norm, [args.cell_line])
            genes2correct = False
        else:
            # use the CNV predicted earlier (args.cnv_est is set here)
            (CN_arr, CN_celldict, CN_genedict) = read_CNVdata(
                args.output_prefix + 'CNVestimates.txt', ['CNVest'])
            genes2correct = highestCNVgenes(CN_arr, CN_genedict, percentile=98)

        # correcting CNV effect
        if args.cell_line in CN_celldict or 'CNVest' in CN_celldict:
            if args.cell_line in CN_celldict:
                logging.info('Normalizing by copy number with ' +
                             args.cell_line + ' as the reference cell line.')
            else:
                logging.info(
                    'Normalizing by copy number using the estimated gene copy numbers.'
                )
            # replace the original values of tt_theta
            tt_theta = sgRNAscore_piecewisenorm(tt_theta,
                                                gene_list,
                                                CN_arr,
                                                CN_genedict,
                                                selectGenes=genes2correct)
            tt_abstheta = [math.fabs(tt_theta[i]) for i in range(n)]
        else:
            # str(): args.cell_line may be None on the cnv_est path
            logging.warning(
                str(args.cell_line) +
                ' is not represented in the inputted copy number variation data.'
            )
    # sort sgRNAs by decreasing |score| (stable, ties keep table order)
    sort_id = sorted(range(n), key=lambda i: tt_abstheta[i], reverse=True)

    # lower_score and higher_score are used to sort sgRNAs for RRA
    tt_p_lower_score = tt_theta
    tt_p_higher_score = [-1 * x for x in tt_theta]

    # sgRNA log2 fold change (pseudocount 1)
    sgrnalfc = [
        math.log(ttmean[i] + 1.0, 2) - math.log(tabc_mean[i] + 1.0, 2)
        for i in range(n)
    ]

    # write the sgRNA summary
    destkeys = list(tabctrl.keys())
    dfmt = "{:.5g}"
    header = [
        'sgrna', 'Gene', 'control_count', 'treatment_count', 'control_mean',
        'treat_mean', 'LFC', 'control_var', 'adj_var', 'score', 'p.low',
        'p.high', 'p.twosided', 'FDR', 'high_in_treatment'
    ]
    with open(destfile + '.sgrna_summary.txt', 'w') as destf:
        print('\t'.join(header), file=destf)
        for i in sort_id:
            if sgrna2genelist is not None:
                destkeygene = sgrna2genelist[destkeys[i]]
            else:
                destkeygene = 'None'
            report = [
                destkeys[i], destkeygene,
                '/'.join([dfmt.format(x) for x in tabctrl_mat[i]]),
                '/'.join([dfmt.format(x) for x in ttmat[i]])
            ]
            t_r = [
                tabc_mean[i], ttmean[i], sgrnalfc[i], tabc_var[i],
                tabc_adjvar[i], tt_abstheta[i], tt_p_lower[i], tt_p_higher[i],
                tt_p_twosided[i], tt_p_fdr[i]
            ]
            report += [dfmt.format(x) for x in t_r]
            report += [ttmean[i] > tabc_mean[i]]
            print('\t'.join([str(x) for x in report]), file=destf)

    if sgrna2genelist is None:
        return (None, None, None, None)

    def _write_rra_input(fname, score_col, scores):
        # Write valid sgRNAs in ascending score order as RRA input and
        # return that order (computed over all sgRNAs).
        order = sorted(range(n), key=lambda i: scores[i])
        with open(fname, 'w') as fh:
            print('\t'.join(
                ['sgrna', 'symbol', 'pool', score_col, 'prob', 'chosen']),
                  file=fh)
            for i in order:
                # new in 0.5.7: only print valid sgRNAs
                if validsgrna[i] == 1:
                    report = [
                        destkeys[i], sgrna2genelist[destkeys[i]], 'list',
                        scores[i], '1', validsgrna[i]
                    ]
                    print('\t'.join([str(x) for x in report]), file=fh)
        return order

    # negative selection
    sort_id = _write_rra_input(destfile + '.plow.txt', 'p.low',
                               tt_p_lower_score)
    n_lower = sum([1 for x in tt_p_lower if x <= args.gene_test_fdr_threshold])
    n_lower_valid = sum([
        1 for n_i in range(n)
        if tt_p_lower[n_i] <= args.gene_test_fdr_threshold
        and validsgrna[n_i] == 1
    ])
    n_lower_p = n_lower_valid * 1.0 / sum(validsgrna)
    logging.debug('lower test FDR cutoff: ' + str(n_lower_p))
    lower_gene_lfc = calculate_gene_lfc(args,
                                        sgrnalfc,
                                        sort_id,
                                        n_lower,
                                        sgrna2genelist,
                                        destkeys,
                                        validsgrna=validsgrna)

    # positive selection
    sort_id = _write_rra_input(destfile + '.phigh.txt', 'p.high',
                               tt_p_higher_score)
    n_higher = sum(
        [1 for x in tt_p_higher if x <= args.gene_test_fdr_threshold])
    n_higher_valid = sum([
        1 for n_i in range(n)
        if tt_p_higher[n_i] <= args.gene_test_fdr_threshold
        and validsgrna[n_i] == 1
    ])
    if n_higher > 0:
        n_higher_p = n_higher_valid * 1.0 / sum(validsgrna)
    else:
        n_higher_p = 0.01
    logging.debug('higher test FDR cutoff: ' + str(n_higher_p))
    higher_gene_lfc = calculate_gene_lfc(args,
                                         sgrnalfc,
                                         sort_id,
                                         n_higher,
                                         sgrna2genelist,
                                         destkeys,
                                         validsgrna=validsgrna,
                                         ispos=True)
    return (n_lower_p, n_higher_p, lower_gene_lfc, higher_gene_lfc)
Пример #3
0
def _argsort(values, reverse=False):
    """
  Return the indices of values ordered by value.
  sorted() is stable, so tied values keep their original relative order
  (also when reverse=True).
  """
    return [
        i for (i, _) in sorted(
            enumerate(values), key=lambda x: x[1], reverse=reverse)
    ]


def _write_gene_test_file(destfname, score_name, scores, destkeys,
                          sgrna2genelist, validsgrna):
    """
  Write one gene-test input table and return the row order used.
  Parameters:
    destfname
        Output file path
    score_name
        Header of the score column ('p.low' or 'p.high')
    scores
        Per-sgRNA scores, aligned with destkeys; rows are written in
        ascending score order
    destkeys
        sgRNA ids
    sgrna2genelist
        {sgrna:gene} mapping
    validsgrna
        Per-sgRNA 1/0 flags aligned with destkeys, written as 'chosen'
  Return value:
    list of indices into destkeys, in the order the rows were written
  """
    sort_id = _argsort(scores, reverse=False)
    with open(destfname, 'w') as destf:
        print('\t'.join(
            ['sgrna', 'symbol', 'pool', score_name, 'prob', 'chosen']),
              file=destf)
        for i in sort_id:
            report = [
                destkeys[i], sgrna2genelist[destkeys[i]], 'list', scores[i],
                '1', validsgrna[i]
            ]
            print('\t'.join([str(x) for x in report]), file=destf)
    return sort_id


def crispr_test(tab, ctrlg, testg, destfile, sgrna2genelist, args):
    """
  main function of crispr test
  Parameters:
    tab
        Read count table, {sgrna: [count per sample]}
    ctrlg
        Column indices of the control samples in each tab row
    testg
        Column indices of the treatment samples in each tab row
    destfile
        Prefix for output files (.sgrna_summary.txt; plus .plow.txt and
        .phigh.txt when sgrna2genelist is given)
    sgrna2genelist
        {sgrna:gene} mapping, or None to skip the gene-test files
    args
        Arguments
  Return value:
    (lowp,highp,lower_gene_lfc,higher_gene_lfc), or
    (None,None,None,None) when sgrna2genelist is None
    lowp
        alpha cutoff for neg. selection
    highp
        alpha cutoff for pos. selection
    lower_gene_lfc
        {gene:lfc} dict. lfc is for neg. selection
    higher_gene_lfc
        {gene:lfc} dict. lfc is for pos. selection
  """
    n = len(tab)
    # One fixed sgRNA order: every per-sgRNA list below is aligned with it,
    # instead of relying on several dicts happening to iterate identically.
    destkeys = list(tab.keys())
    ctrl_idx = set(ctrlg)
    test_idx = set(testg)
    # control and test matrix
    tabctrl = {
        k: [v[i] for i in range(len(v)) if i in ctrl_idx]
        for (k, v) in tab.items()
    }
    tabtest = {
        k: [v[i] for i in range(len(v)) if i in test_idx]
        for (k, v) in tab.items()
    }
    # control matrix for mean-var estimation: with more than 1 control (and
    # variance_from_all_samples off) use the controls only, otherwise use all
    # the control + treatment samples
    if len(ctrlg) > 1 and not args.variance_from_all_samples:
        model_idx = ctrl_idx
    else:
        model_idx = ctrl_idx | test_idx
    tabctrlmod = {
        k: [v[i] for i in range(len(v)) if i in model_idx]
        for (k, v) in tab.items()
    }
    # training using control samples
    model = list(modelmeanvar(tabctrlmod, method='linear'))
    logging.debug('Adjusted model: ' + '\t'.join([str(x) for x in model]))

    tabctrl_mat = [tabctrl[k] for k in destkeys]
    tabctrlmodel_mat = [tabctrlmod[k] for k in destkeys]
    ttmat = [tabtest[k] for k in destkeys]
    tabc_mean = getMeans(tabctrl_mat)

    remove_zero = getattr(args, 'remove_zero', None)
    # setup the valid sgRNA flag (1 = valid, 0 = filtered by remove_zero)
    validsgrna = [1] * n
    if remove_zero in ('control', 'both'):
        validsgrna = [1 if x > 0 else 0 for x in tabc_mean]
    # if mean of the control samples is 0: raise it to the smallest positive
    # control mean (note: min() raises ValueError if no mean is positive)
    tabc_min = min([x for x in tabc_mean if x > 0])
    tabc_mean = [t if t > tabc_min else tabc_min for t in tabc_mean]
    tabc_var = getVars(tabctrlmodel_mat)
    tabc_adjvar = getadjustvar(model, tabc_mean, method='linear')

    # testing using the treatment samples
    ttmean = getMeans(ttmat)
    if remove_zero in ('treatment', 'both'):
        validsgrna = [
            flag * (1 if x > 0 else 0) for (flag, x) in zip(validsgrna, ttmean)
        ]
    # convert to standard normal values: difference of means scaled by the
    # model-adjusted control standard deviation
    tt_theta = [(ttmean[i] - tabc_mean[i]) / math.sqrt(tabc_adjvar[i])
                for i in range(n)]
    tt_abstheta = [math.fabs(x) for x in tt_theta]

    try:
        # for consistency, use normal p values
        tt_p_lower = getNormalPValue(tabc_mean,
                                     tabc_adjvar,
                                     ttmean,
                                     lower=True)
        tt_p_higher = getNormalPValue(tabc_mean,
                                      tabc_adjvar,
                                      ttmean,
                                      lower=False)
    except Exception:
        # logging.exception also records the traceback of the failure
        logging.exception(
            'An error occurs while trying to compute p values. Quit..')
        sys.exit(-1)

    tt_p_twosided = [
        2 * lo if lo < hi else 2 * hi
        for (lo, hi) in zip(tt_p_lower, tt_p_higher)
    ]
    tt_p_fdr = pFDR(tt_p_twosided, method=args.adjust_method)

    # map sgRNA to genes
    if sgrna2genelist is not None:
        gene_list = [sgrna2genelist[k] for k in destkeys]
    else:
        gene_list = ['NA'] * n
    # optionally normalize sgRNA scores by copy number
    if args.cnv_norm is not None and args.cell_line is not None:
        from mageck.cnv_normalization import read_CNVdata, sgRNAscore_piecewisenorm
        logging.info('Performing copy number normalization.')
        (CN_arr, CN_celldict,
         CN_genedict) = read_CNVdata(args.cnv_norm, [args.cell_line])
        if args.cell_line in CN_celldict:
            logging.info('Normalizing by copy number with ' + args.cell_line +
                         ' as the reference cell line.')
            # replace the original values of tt_theta
            tt_theta = sgRNAscore_piecewisenorm(tt_theta, gene_list, CN_arr,
                                                CN_genedict)
            tt_abstheta = [math.fabs(x) for x in tt_theta]
        else:
            logging.warning(
                args.cell_line +
                ' is not represented in the inputted copy number variation data.'
            )
    # sort sgRNAs by |score|, largest first
    sort_id = _argsort(tt_abstheta, reverse=True)

    # lower_score and higher_score are used to sort sgRNAs
    tt_p_lower_score = tt_theta
    tt_p_higher_score = [-1 * x for x in tt_theta]

    dfmt = "{:.5g}"
    # sgRNA log fold change, filled while writing the summary
    sgrnalfc = [0.0] * n
    header = [
        'sgrna', 'Gene', 'control_count', 'treatment_count', 'control_mean',
        'treat_mean', 'LFC', 'control_var', 'adj_var', 'score', 'p.low',
        'p.high', 'p.twosided', 'FDR', 'high_in_treatment'
    ]
    with open(destfile + '.sgrna_summary.txt', 'w') as destf:
        print('\t'.join(header), file=destf)
        for i in sort_id:
            if sgrna2genelist is not None:
                destkeygene = sgrna2genelist[destkeys[i]]
            else:
                destkeygene = 'None'
            report = [
                destkeys[i], destkeygene,
                '/'.join([dfmt.format(x) for x in tabctrl_mat[i]]),
                '/'.join([dfmt.format(x) for x in ttmat[i]])
            ]
            lfcval = math.log(ttmean[i] + 1.0, 2) - math.log(
                tabc_mean[i] + 1.0, 2)
            sgrnalfc[i] = lfcval  # save log fold change
            t_r = [
                tabc_mean[i], ttmean[i], lfcval, tabc_var[i], tabc_adjvar[i],
                tt_abstheta[i], tt_p_lower[i], tt_p_higher[i],
                tt_p_twosided[i], tt_p_fdr[i]
            ]
            report += [dfmt.format(x) for x in t_r]
            report += [ttmean[i] > tabc_mean[i]]
            print('\t'.join([str(x) for x in report]), file=destf)

    if sgrna2genelist is None:
        return (None, None, None, None)

    # prepare files for gene test
    # NOTE(review): the cutoffs below count raw p-values against
    # args.gene_test_fdr_threshold (no adjustment is applied); confirm intent.
    sort_id = _write_gene_test_file(destfile + '.plow.txt', 'p.low',
                                    tt_p_lower_score, destkeys,
                                    sgrna2genelist, validsgrna)
    n_lower = sum([1 for x in tt_p_lower if x <= args.gene_test_fdr_threshold])
    n_lower_p = n_lower * 1.0 / len(tt_p_lower)
    logging.debug('lower test FDR cutoff: ' + str(n_lower_p))
    # calculate gene lfc
    lower_gene_lfc = calculate_gene_lfc(args, sgrnalfc, sort_id, n_lower,
                                        sgrna2genelist, destkeys)

    sort_id = _write_gene_test_file(destfile + '.phigh.txt', 'p.high',
                                    tt_p_higher_score, destkeys,
                                    sgrna2genelist, validsgrna)
    n_higher = sum(
        [1 for x in tt_p_higher if x <= args.gene_test_fdr_threshold])
    if n_higher > 0:
        n_higher_p = n_higher * 1.0 / len(tt_p_higher)
    else:
        # NOTE(review): only the higher side falls back to 0.01 when nothing
        # passes; the lower side can return 0. Confirm the asymmetry is intended.
        n_higher_p = 0.01
    logging.debug('higher test FDR cutoff: ' + str(n_higher_p))
    # calculate gene lfc
    higher_gene_lfc = calculate_gene_lfc(args,
                                         sgrnalfc,
                                         sort_id,
                                         n_higher,
                                         sgrna2genelist,
                                         destkeys,
                                         ispos=True)
    return (n_lower_p, n_higher_p, lower_gene_lfc, higher_gene_lfc)