def __getitem__(self, i):
  # acquire predictions, if needed
  if i >= self.stream_end:
    self.stream_start = self.stream_end
    self.stream_end = min(self.stream_start + self.stream_length,
                          self.seqs_1hot.shape[0])

    # subset sequences
    stream_seqs_1hot = self.seqs_1hot[self.stream_start:self.stream_end]

    # initialize batcher; use a distinct local name so we don't shadow
    # the imported batcher module
    stream_batcher = batcher.Batcher(stream_seqs_1hot,
                                     batch_size=self.model.hp.batch_size)

    # predict
    self.stream_preds = self.model.predict(self.sess, stream_batcher,
                                           rc_avg=False)

  return self.stream_preds[i - self.stream_start]
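# A minimal usage sketch for the streaming accessor above, assuming it
# belongs to a wrapper class (here called PredStream, a hypothetical name)
# constructed with a session, model, one-hot sequences, and a chunk size.
# Predictions are computed lazily, one chunk per model call, so indexes
# should be requested in increasing order:
#
#   preds_stream = PredStream(sess, model, seqs_1hot, stream_length=64)
#   for si in range(seqs_1hot.shape[0]):
#     seq_preds = preds_stream[si]  # new chunk every stream_length sequences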
def __getitem__(self, i):
  # acquire predictions, if needed
  if i >= self.stream_end:
    self.stream_start = self.stream_end
    self.stream_end = min(self.stream_start + self.stream_length,
                          self.seqs_1hot.shape[0])

    # subset sequences
    stream_seqs_1hot = self.seqs_1hot[self.stream_start:self.stream_end]

    # initialize batcher; distinct local name to avoid shadowing the
    # imported batcher module
    stream_batcher = batcher.Batcher(stream_seqs_1hot,
                                     batch_size=self.model.hp.batch_size)

    # compute gradients and predictions
    self.stream_grads, self.stream_preds = self.model.gradients(
        self.sess, stream_batcher, layers=[0], return_preds=True)

    # take first layer
    self.stream_grads = self.stream_grads[0]

  return self.stream_preds[i - self.stream_start], self.stream_grads[
      i - self.stream_start]
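# A usage sketch paralleling the prediction stream above (same hypothetical
# wrapper-class assumption, here named GradStream): each access yields the
# prediction and the layer-0 gradient for one sequence, computed chunk by
# chunk:
#
#   grad_stream = GradStream(sess, model, seqs_1hot, stream_length=32)
#   for si in range(seqs_1hot.shape[0]):
#     seq_preds, seq_grads = grad_stream[si]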
def main(): usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>" parser = OptionParser(usage) parser.add_option( "-g", dest="genome_file", default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"], help="Chromosome lengths file [Default: %default]", ) parser.add_option("-l", dest="gene_list", help="Process only gene ids in the given file") parser.add_option( "-o", dest="out_dir", default="grad_mapg", help="Output directory [Default: %default]", ) parser.add_option("-t", dest="target_indexes", default=None, help="Target indexes to plot") (options, args) = parser.parse_args() if len(args) != 3: parser.error("Must provide parameters, model, and genomic position") else: params_file = args[0] model_file = args[1] genes_hdf5_file = args[2] if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) ################################################################# # reads in genes HDF5 gene_data = genedata.GeneData(genes_hdf5_file) # subset gene sequences genes_subset = set() if options.gene_list: for line in open(options.gene_list): genes_subset.add(line.rstrip()) gene_data.subset_genes(genes_subset) print("Filtered to %d sequences" % gene_data.num_seqs) ####################################################### # model parameters and placeholders job = params.read_job_params(params_file) job["seq_length"] = gene_data.seq_length job["seq_depth"] = gene_data.seq_depth job["target_pool"] = gene_data.pool_width if "num_targets" not in job: print( "Must specify number of targets (num_targets) in the parameters file.", file=sys.stderr, ) exit(1) # set target indexes if options.target_indexes is not None: options.target_indexes = [ int(ti) for ti in options.target_indexes.split(",") ] target_subset = options.target_indexes else: options.target_indexes = list(range(job["num_targets"])) target_subset = None # build model model = seqnn.SeqNN() model.build(job, target_subset=target_subset) # determine latest pre-dilated layer cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params]) dilated_mask = cnn_dilation > 1 dilated_indexes = np.where(dilated_mask)[0] pre_dilated_layer = np.min(dilated_indexes) print("Pre-dilated layer: %d" % pre_dilated_layer) # build gradients ops t0 = time.time() print("Building target/position-specific gradient ops.", end="") model.build_grads_genes(gene_data.gene_seqs, layers=[pre_dilated_layer]) print(" Done in %ds" % (time.time() - t0), flush=True) ####################################################### # acquire gradients # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) for si in range(gene_data.num_seqs): # initialize batcher batcher_si = batcher.Batcher( gene_data.seqs_1hot[si:si + 1], batch_size=model.hp.batch_size, pool_width=model.hp.target_pool, ) # get layer representations t0 = time.time() print("Computing gradients.", end="", flush=True) batch_grads, batch_reprs = model.gradients_genes( sess, batcher_si, gene_data.gene_seqs[si:si + 1]) print(" Done in %ds." 
% (time.time() - t0), flush=True) # only layer batch_reprs = batch_reprs[0] batch_grads = batch_grads[0] # G (TSSs) x T (targets) x P (seq position) x U (Units layer i) print("batch_grads", batch_grads.shape) pooled_length = batch_grads.shape[2] # S (sequences) x P (seq position) x U (Units layer i) print("batch_reprs", batch_reprs.shape) # write bigwigs t0 = time.time() print("Writing BigWigs.", end="", flush=True) # for each TSS for tss_i in range(batch_grads.shape[0]): tss = gene_data.gene_seqs[si].tss_list[tss_i] # for each target for tii in range(len(options.target_indexes)): ti = options.target_indexes[tii] # dot representation and gradient batch_grads_score = np.multiply( batch_reprs[0], batch_grads[tss_i, tii, :, :]).sum(axis=1) # open bigwig bw_file = "%s/%s-%s_t%d.bw" % ( options.out_dir, tss.gene_id, tss.identifier, ti, ) bw_open = bigwig_open(bw_file, options.genome_file) # access gene sequence information seq_chrom = gene_data.gene_seqs[si].chrom seq_start = gene_data.gene_seqs[si].start # specify bigwig locations and values bw_chroms = [seq_chrom] * pooled_length bw_starts = [ int(seq_start + li * model.hp.target_pool) for li in range(pooled_length) ] bw_ends = [ int(bws + model.hp.target_pool) for bws in bw_starts ] bw_values = [float(bgs) for bgs in batch_grads_score] # write bw_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bw_values) # close bw_open.close() print(" Done in %ds." % (time.time() - t0), flush=True) gc.collect()
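# A minimal numpy sketch of the gradient-times-representation score computed
# above (toy shapes; all names here are illustrative). For one TSS and one
# target, the saliency at each pooled position is the dot product of the
# layer representation and its gradient across units:
import numpy as np

pooled_length, units = 128, 64
reprs = np.random.randn(pooled_length, units)  # P x U layer representation
grads = np.random.randn(pooled_length, units)  # P x U gradient, one TSS/target

# positionwise dot product over units -> one score per pooled position
grads_score = np.multiply(reprs, grads).sum(axis=1)
assert grads_score.shape == (pooled_length,)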
def main():
  usage = ('usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
           ' <vcf_file>')
  parser = OptionParser(usage)
  parser.add_option(
      '-a', dest='all_sed', default=False, action='store_true',
      help='Print all variant-gene pairs, as opposed to only nonzero [Default: %default]')
  parser.add_option(
      '-b', dest='batch_size', default=None, type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-c', dest='csv', default=False, action='store_true',
      help='Print table as CSV [Default: %default]')
  parser.add_option(
      '-g', dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option(
      '-o', dest='out_dir', default='sed',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-p', dest='processes', default=None, type='int',
      help='Number of processes, passed by multi script')
  parser.add_option(
      '--pseudo', dest='log_pseudo', default=0.125, type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option(
      '-r', dest='tss_radius', default=0, type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc', dest='rc', default=False, action='store_true',
      help='Average the forward and reverse complement predictions when testing [Default: %default]')
  parser.add_option(
      '--shifts', dest='shifts', default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t', dest='targets_file', default=None,
      help='File specifying target indexes and labels in table format')
  parser.add_option(
      '-u', dest='penultimate', default=False, action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  parser.add_option(
      '-x', dest='tss_table', default=False, action='store_true',
      help='Print TSS table in addition to gene [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 4:
    # single worker
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]
    vcf_file = args[3]

  elif len(args) == 6:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    genes_hdf5_file = args[3]
    vcf_file = args[4]
    worker_index = int(args[5])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters and model files, genes HDF5 file,'
                 ' and QTL VCF file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes HDF5
  gene_data = genedata.GeneData(genes_hdf5_file)

  # filter for worker sequences
  if options.processes is not None:
    gene_data.worker(worker_index, options.processes)

  #################################################################
  # prep SNPs

  # load SNPs
  snps = bvcf.vcf_snps(vcf_file)

  # intersect w/ segments
  print('Intersecting gene sequences with SNPs...', end='')
  sys.stdout.flush()
  seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs,
                                       vision_p=0.5)
  print('%d sequences w/ SNPs' % len(seqs_snps))

  #################################################################
  # setup model
  job = params.read_job_params(params_file)

  job['seq_length'] = gene_data.seq_length
  job['seq_depth'] = gene_data.seq_depth
  job['target_pool'] = gene_data.pool_width

  if 'num_targets' not in job and gene_data.num_targets is not None:
    job['num_targets'] = gene_data.num_targets

  if 'num_targets' not in job:
    print('Must specify number of targets (num_targets) in the parameters file.',
          file=sys.stderr)
    exit(1)

  if options.targets_file is None:
    target_labels = gene_data.target_labels
    target_subset = None
    target_ids = gene_data.target_ids
    if target_ids is None:
      target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
      target_labels = [''] * len(target_ids)

  else:
    # Unfortunately, this target file differs from some others
    # in that it needs to specify the indexes from the original
    # set. In the future, I should standardize to this version.
    target_ids = []
    target_labels = []
    target_subset = []
    for line in open(options.targets_file):
      a = line.strip().split('\t')
      target_subset.append(int(a[0]))
      target_ids.append(a[1])
      target_labels.append(a[3])

    if len(target_subset) == job['num_targets']:
      target_subset = None

  # build model
  model = seqnn.SeqNN()
  model.build(job, target_subset=target_subset)

  if options.penultimate:
    # labels become inappropriate
    target_ids = [''] * model.hp.cnn_filters[-1]
    target_labels = target_ids

  #################################################################
  # compute, collect, and print SEDs

  header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                 'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                 'target_label')

  if options.csv:
    sed_gene_out = open('%s/sed_gene.csv' % options.out_dir, 'w')
    print(','.join(header_cols), file=sed_gene_out)
    if options.tss_table:
      sed_tss_out = open('%s/sed_tss.csv' % options.out_dir, 'w')
      print(','.join(header_cols), file=sed_tss_out)
  else:
    sed_gene_out = open('%s/sed_gene.txt' % options.out_dir, 'w')
    print(' '.join(header_cols), file=sed_gene_out)
    if options.tss_table:
      sed_tss_out = open('%s/sed_tss.txt' % options.out_dir, 'w')
      print(' '.join(header_cols), file=sed_tss_out)

  # helper variables
  pred_buffer = model.hp.batch_buffer // model.hp.target_pool

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # for each gene sequence
    for seq_i in range(gene_data.num_seqs):
      gene_seq = gene_data.gene_seqs[seq_i]
      print(gene_seq)

      # if it contains SNPs
      seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
      if len(seq_snps) > 0:

        # one hot code allele sequences
        aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i],
                                  seq_snps)

        # initialize batcher
        batcher_gene = batcher.Batcher(aseqs_1hot,
                                       batch_size=model.hp.batch_size)

        # construct allele gene_seq's
        allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

        # predict alleles
        allele_tss_preds = model.predict_genes(
            sess, batcher_gene, allele_gene_seqs,
            rc=options.rc, shifts=options.shifts,
            penultimate=options.penultimate, tss_radius=options.tss_radius)

        # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
        allele_tss_preds = allele_tss_preds.reshape(
            (aseqs_1hot.shape[0], gene_seq.num_tss, -1))

        # extract reference and SNP alt predictions
        ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
        alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

        # compute TSS SED scores (named tss_sed/tss_ser to match their
        # use in the table-writing loops below)
        tss_sed = alt_tss_preds - ref_tss_preds
        tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
            - np.log2(ref_tss_preds + options.log_pseudo)

        # compute gene-level predictions
        ref_gene_preds, gene_ids = gene.map_tss_genes(
            ref_tss_preds, gene_seq.tss_list, options.tss_radius)
        alt_gene_preds = []
        for snp_i in range(len(seq_snps)):
          agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i],
                                      gene_seq.tss_list, options.tss_radius)
          alt_gene_preds.append(agp)
        alt_gene_preds = np.array(alt_gene_preds)

        # compute gene SED scores
        gene_sed = alt_gene_preds - ref_gene_preds
        gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
            - np.log2(ref_gene_preds + options.log_pseudo)

        # for each SNP
        for snp_i in range(len(seq_snps)):
          snp = seq_snps[snp_i]

          # initialize gene data structures
          snp_dist_gene = {}

          # for each TSS
          for tss_i in range(gene_seq.num_tss):
            tss = gene_seq.tss_list[tss_i]

            # SNP distance to TSS
            snp_dist = abs(snp.pos - tss.pos)
            if tss.gene_id in snp_dist_gene:
              snp_dist_gene[tss.gene_id] = min(snp_dist_gene[tss.gene_id],
                                               snp_dist)
            else:
              snp_dist_gene[tss.gene_id] = snp_dist

            # for each target
            if options.tss_table:
              for ti in range(ref_tss_preds.shape[1]):

                # check if nonzero
                if options.all_sed or not np.isclose(
                    tss_sed[snp_i, tss_i, ti], 0, atol=1e-4):

                  # print
                  cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele),
                          bvcf.cap_allele(snp.alt_alleles[0]),
                          tss.identifier, snp_dist,
                          ref_tss_preds[tss_i, ti],
                          alt_tss_preds[snp_i, tss_i, ti],
                          tss_sed[snp_i, tss_i, ti],
                          tss_ser[snp_i, tss_i, ti], ti, target_ids[ti],
                          target_labels[ti])
                  if options.csv:
                    print(','.join([str(c) for c in cols]), file=sed_tss_out)
                  else:
                    print(
                        '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                        % cols, file=sed_tss_out)

          # for each gene
          for gi in range(len(gene_ids)):
            gene_str = gene_ids[gi]
            if gene_ids[gi] in gene_data.multi_seq_genes:
              gene_str = '%s_multi' % gene_ids[gi]

            # print rows to gene table
            for ti in range(ref_gene_preds.shape[1]):

              # check if nonzero
              if options.all_sed or not np.isclose(
                  gene_sed[snp_i, gi, ti], 0, atol=1e-4):

                # print
                cols = [
                    snp.rsid,
                    bvcf.cap_allele(snp.ref_allele),
                    bvcf.cap_allele(snp.alt_alleles[0]), gene_str,
                    snp_dist_gene[gene_ids[gi]], ref_gene_preds[gi, ti],
                    alt_gene_preds[snp_i, gi, ti], gene_sed[snp_i, gi, ti],
                    gene_ser[snp_i, gi, ti], ti, target_ids[ti],
                    target_labels[ti]
                ]
                if options.csv:
                  print(','.join([str(c) for c in cols]), file=sed_gene_out)
                else:
                  print(
                      '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                      % tuple(cols), file=sed_gene_out)

        # clean up
        gc.collect()

  sed_gene_out.close()
  if options.tss_table:
    sed_tss_out.close()
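# A toy numpy illustration of the two variant scores computed above (values
# are made up). SED is the difference between alternate- and reference-allele
# predictions; SER is the same contrast in log2 space with a pseudocount,
# which damps differences at high expression:
import numpy as np

log_pseudo = 0.125
ref_preds = np.array([1.0, 4.0])  # reference-allele predictions, 2 targets
alt_preds = np.array([2.0, 3.0])  # alternate-allele predictions

sed = alt_preds - ref_preds  # [ 1., -1.]
ser = np.log2(alt_preds + log_pseudo) - np.log2(ref_preds + log_pseudo)
# ser ~ [ 0.918, -0.400]: the same absolute change counts for less
# at higher expression levels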
def main(): usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>" parser = OptionParser(usage) parser.add_option( "-b", dest="batch_size", default=256, type="int", help="Batch size [Default: %default]", ) parser.add_option( "-c", dest="csv", default=False, action="store_true", help="Print table as CSV [Default: %default]", ) parser.add_option( "-f", dest="genome_fasta", default="%s/data/hg19.fa" % os.environ["BASENJIDIR"], help="Genome FASTA for sequences [Default: %default]", ) parser.add_option( "-g", dest="genome_file", default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"], help="Chromosome lengths file [Default: %default]", ) parser.add_option( "--h5", dest="out_h5", default=False, action="store_true", help="Output stats to sad.h5 [Default: %default]", ) parser.add_option( "--local", dest="local", default=1024, type="int", help="Local SAD score [Default: %default]", ) parser.add_option("-n", dest="norm_file", default=None, help="Normalize SAD scores") parser.add_option( "-o", dest="out_dir", default="sad", help="Output directory for tables and plots [Default: %default]", ) parser.add_option( "-p", dest="processes", default=None, type="int", help="Number of processes, passed by multi script", ) parser.add_option( "--pseudo", dest="log_pseudo", default=1, type="float", help="Log2 pseudocount [Default: %default]", ) parser.add_option( "--rc", dest="rc", default=False, action="store_true", help= "Average forward and reverse complement predictions [Default: %default]", ) parser.add_option( "--shifts", dest="shifts", default="0", type="str", help="Ensemble prediction shifts [Default: %default]", ) parser.add_option( "--stats", dest="sad_stats", default="SAD,xSAR", help="Comma-separated list of stats to save. [Default: %default]", ) parser.add_option( "-t", dest="targets_file", default=None, type="str", help="File specifying target indexes and labels in table format", ) parser.add_option( "--ti", dest="track_indexes", default=None, type="str", help="Comma-separated list of target indexes to output BigWig tracks", ) parser.add_option( "-u", dest="penultimate", default=False, action="store_true", help="Compute SED in the penultimate layer [Default: %default]", ) parser.add_option( "-z", dest="out_zarr", default=False, action="store_true", help="Output stats to sad.zarr [Default: %default]", ) (options, args) = parser.parse_args() if len(args) == 3: # single worker params_file = args[0] model_file = args[1] vcf_file = args[2] elif len(args) == 5: # multi worker options_pkl_file = args[0] params_file = args[1] model_file = args[2] vcf_file = args[3] worker_index = int(args[4]) # load options options_pkl = open(options_pkl_file, "rb") options = pickle.load(options_pkl) options_pkl.close() # update output directory options.out_dir = "%s/job%d" % (options.out_dir, worker_index) else: parser.error( "Must provide parameters and model files and QTL VCF file") if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) if options.track_indexes is None: options.track_indexes = [] else: options.track_indexes = [ int(ti) for ti in options.track_indexes.split(",") ] if not os.path.isdir("%s/tracks" % options.out_dir): os.mkdir("%s/tracks" % options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] options.sad_stats = options.sad_stats.split(",") ################################################################# # setup model job = params.read_job_params( params_file, require=["seq_length", "num_targets", "target_pool"]) if options.targets_file is 
None: target_ids = ["t%d" % ti for ti in range(job["num_targets"])] target_labels = [""] * len(target_ids) target_subset = None else: targets_df = pd.read_table(options.targets_file, index_col=0) target_ids = targets_df.identifier target_labels = targets_df.description target_subset = targets_df.index if len(target_subset) == job["num_targets"]: target_subset = None # build model t0 = time.time() model = seqnn.SeqNN() model.build_feed( job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, embed_penultimate=options.penultimate, target_subset=target_subset, ) print("Model building time %f" % (time.time() - t0), flush=True) if options.penultimate: # labels become inappropriate target_ids = [""] * model.hp.cnn_filters[-1] target_labels = target_ids # read target normalization factors target_norms = np.ones(len(target_labels)) if options.norm_file is not None: ti = 0 for line in open(options.norm_file): target_norms[ti] = float(line.strip()) ti += 1 num_targets = len(target_ids) ################################################################# # load SNPs snps = bvcf.vcf_snps(vcf_file) # filter for worker SNPs if options.processes is not None: worker_bounds = np.linspace(0, len(snps), options.processes + 1, dtype="int") snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index + 1]] num_snps = len(snps) ################################################################# # setup output header_cols = ( "rsid", "ref", "alt", "ref_pred", "alt_pred", "sad", "sar", "geo_sad", "ref_lpred", "alt_lpred", "lsad", "lsar", "ref_xpred", "alt_xpred", "xsad", "xsar", "target_index", "target_id", "target_label", ) if options.out_h5: sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps, target_ids, target_labels) elif options.out_zarr: sad_out = initialize_output_zarr(options.out_dir, options.sad_stats, snps, target_ids, target_labels) else: if options.csv: sad_out = open("%s/sad_table.csv" % options.out_dir, "w") print(",".join(header_cols), file=sad_out) else: sad_out = open("%s/sad_table.txt" % options.out_dir, "w") print(" ".join(header_cols), file=sad_out) ################################################################# # process # open genome FASTA genome_open = pysam.Fastafile(options.genome_fasta) # determine local start and end loc_mid = model.target_length // 2 loc_start = loc_mid - (options.local // 2) // model.hp.target_pool loc_end = loc_start + options.local // model.hp.target_pool snp_i = 0 szi = 0 # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # construct first batch batch_1hot, batch_snps, snp_i = snps_next_batch( snps, snp_i, options.batch_size, job["seq_length"], genome_open) while len(batch_snps) > 0: ################################################### # predict # initialize batcher batcher = batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) # predict # batch_preds = model.predict(sess, batcher, # rc=options.rc, shifts=options.shifts, # penultimate=options.penultimate) batch_preds = model.predict_h5(sess, batcher) # normalize batch_preds /= target_norms ################################################### # collect and print SADs pi = 0 for snp in batch_snps: # get reference prediction (LxT) ref_preds = batch_preds[pi] pi += 1 # sum across length ref_preds_sum = ref_preds.sum(axis=0, dtype="float64") # print tracks for ti in options.track_indexes: ref_bw_file = "%s/tracks/%s_t%d_ref.bw" % ( options.out_dir, snp.rsid, ti, ) bigwig_write( snp, job["seq_length"], 
ref_preds[:, ti], model, ref_bw_file, options.genome_file, ) for alt_al in snp.alt_alleles: # get alternate prediction (LxT) alt_preds = batch_preds[pi] pi += 1 # sum across length alt_preds_sum = alt_preds.sum(axis=0, dtype="float64") # compare reference to alternative via mean subtraction sad_vec = alt_preds - ref_preds sad = alt_preds_sum - ref_preds_sum # compare reference to alternative via mean log division sar = np.log2(alt_preds_sum + options.log_pseudo) - np.log2( ref_preds_sum + options.log_pseudo) # compare geometric means sar_vec = np.log2( alt_preds.astype("float64") + options.log_pseudo) - np.log2( ref_preds.astype("float64") + options.log_pseudo) geo_sad = sar_vec.sum(axis=0) # sum locally ref_preds_loc = ref_preds[loc_start:loc_end, :].sum( axis=0, dtype="float64") alt_preds_loc = alt_preds[loc_start:loc_end, :].sum( axis=0, dtype="float64") # compute SAD locally sad_loc = alt_preds_loc - ref_preds_loc sar_loc = np.log2(alt_preds_loc + options.log_pseudo) - np.log2( ref_preds_loc + options.log_pseudo) # compute max difference position max_li = np.argmax(np.abs(sar_vec), axis=0) if options.out_h5 or options.out_zarr: sad_out["SAD"][szi, :] = sad.astype("float16") sad_out["xSAR"][szi, :] = np.array( [ sar_vec[max_li[ti], ti] for ti in range(num_targets) ], dtype="float16", ) szi += 1 else: # print table lines for ti in range(len(sad)): # print line cols = ( snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(alt_al), ref_preds_sum[ti], alt_preds_sum[ti], sad[ti], sar[ti], geo_sad[ti], ref_preds_loc[ti], alt_preds_loc[ti], sad_loc[ti], sar_loc[ti], ref_preds[max_li[ti], ti], alt_preds[max_li[ti], ti], sad_vec[max_li[ti], ti], sar_vec[max_li[ti], ti], ti, target_ids[ti], target_labels[ti], ) if options.csv: print(",".join([str(c) for c in cols]), file=sad_out) else: print( "%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s" % cols, file=sad_out, ) # print tracks for ti in options.track_indexes: alt_bw_file = "%s/tracks/%s_t%d_alt.bw" % ( options.out_dir, snp.rsid, ti, ) bigwig_write( snp, job["seq_length"], alt_preds[:, ti], model, alt_bw_file, options.genome_file, ) ################################################### # construct next batch batch_1hot, batch_snps, snp_i = snps_next_batch( snps, snp_i, options.batch_size, job["seq_length"], genome_open) ################################################### # compute SAD distributions across variants if options.out_h5 or options.out_zarr: # define percentiles d_fine = 0.001 d_coarse = 0.01 percentiles_neg = np.arange(d_fine, 0.1, d_fine) percentiles_base = np.arange(0.1, 0.9, d_coarse) percentiles_pos = np.arange(0.9, 1, d_fine) percentiles = np.concatenate( [percentiles_neg, percentiles_base, percentiles_pos]) sad_out.create_dataset("percentiles", data=percentiles) pct_len = len(percentiles) for sad_stat in options.sad_stats: sad_stat_pct = "%s_pct" % sad_stat # compute sad_pct = np.percentile(sad_out[sad_stat], 100 * percentiles, axis=0).T sad_pct = sad_pct.astype("float16") # save sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype="float16") if not options.out_zarr: sad_out.close()
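# A small numpy sketch of the percentile summary written above (toy data,
# illustrative names). Each stat gets a (targets x percentiles) table so a
# downstream consumer can map a score to its quantile without reloading all
# variants:
import numpy as np

num_snps, num_targets = 1000, 3
sad = np.random.randn(num_snps, num_targets).astype("float16")

d_fine, d_coarse = 0.001, 0.01
percentiles = np.concatenate([
    np.arange(d_fine, 0.1, d_fine),    # fine resolution in the left tail
    np.arange(0.1, 0.9, d_coarse),     # coarse resolution in the body
    np.arange(0.9, 1, d_fine),         # fine resolution in the right tail
])

# np.percentile expects values in [0, 100]; transpose to targets x percentiles
sad_pct = np.percentile(sad.astype("float32"), 100 * percentiles, axis=0).T
assert sad_pct.shape == (num_targets, len(percentiles))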
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-f', dest='figure_width', default=20, type='float',
      help='Figure width [Default: %default]')
  parser.add_option(
      '--f1', dest='genome1_fasta',
      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
      help='Genome FASTA from which major allele sequences will be drawn')
  parser.add_option(
      '--f2', dest='genome2_fasta', default=None,
      help='Genome FASTA from which minor allele sequences will be drawn')
  parser.add_option(
      '-g', dest='gain', default=False, action='store_true',
      help='Draw a sequence logo for the gain score, too [Default: %default]')
  # parser.add_option('-k', dest='plot_k',
  #     default=None, type='int',
  #     help='Plot the top k targets at each end.')
  parser.add_option(
      '-l', dest='satmut_len', default=200, type='int',
      help='Length of centered sequence to mutate [Default: %default]')
  parser.add_option(
      '--mean', dest='mean_targets', default=False, action='store_true',
      help='Take the mean across targets for a single plot [Default: %default]')
  parser.add_option(
      '-m', dest='mc_n', default=0, type='int',
      help='Monte carlo iterations [Default: %default]')
  parser.add_option(
      '--min', dest='min_limit', default=0.01, type='float',
      help='Minimum heatmap limit [Default: %default]')
  parser.add_option(
      '-n', dest='load_sat_npy', default=False, action='store_true',
      help='Load the predictions from .npy files [Default: %default]')
  parser.add_option(
      '-o', dest='out_dir', default='sat_vcf',
      help='Output directory [Default: %default]')
  parser.add_option(
      '--rc', dest='rc', default=False, action='store_true',
      help='Ensemble forward and reverse complement predictions [Default: %default]')
  parser.add_option(
      '--shifts', dest='shifts', default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t', dest='targets_file', default=None, type='str',
      help='File specifying target indexes and labels in table format')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters and model files and VCF')
  else:
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # prep SNP sequences
  #################################################################
  # read parameters
  job = params.read_job_params(params_file,
                               require=['seq_length', 'num_targets'])

  # load SNPs
  snps = vcf.vcf_snps(vcf_file)

  # get one hot coded input sequences
  if not options.genome2_fasta:
    seqs_1hot, seq_headers, snps, seqs = vcf.snps_seq1(
        snps, job['seq_length'], options.genome1_fasta, return_seqs=True)
  else:
    seqs_1hot, seq_headers, snps, seqs = vcf.snps2_seq1(
        snps, job['seq_length'], options.genome1_fasta,
        options.genome2_fasta, return_seqs=True)

  seqs_n = seqs_1hot.shape[0]

  #################################################################
  # setup model
  #################################################################
  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = [''] * len(target_ids)
    target_subset = None
    num_targets = job['num_targets']
  else:
    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
    target_ids = targets_df.identifier
    target_labels = targets_df.description
    target_subset = targets_df.index
    # count targets before target_subset is possibly reset to None below
    num_targets = len(target_subset)
    if len(target_subset) == job['num_targets']:
      target_subset = None

  if not options.load_sat_npy:
    # build model
    model = seqnn.SeqNN()
    model.build_feed_sad(job, ensemble_rc=options.rc,
                         ensemble_shifts=options.shifts,
                         target_subset=target_subset)

    # initialize saver
    saver = tf.train.Saver()

  #################################################################
  # predict and process
  #################################################################
  with tf.Session() as sess:
    if not options.load_sat_npy:
      # load variables into session
      saver.restore(sess, model_file)

    for si in range(seqs_n):
      header = seq_headers[si]
      header_fs = fs_clean(header)

      print('Mutating sequence %d / %d' % (si + 1, seqs_n), flush=True)

      # write sequence
      fasta_out = open('%s/seq%d.fa' % (options.out_dir, si), 'w')
      end_len = (len(seqs[si]) - options.satmut_len) // 2
      print('>seq%d\n%s' % (si, seqs[si][end_len:-end_len]), file=fasta_out)
      fasta_out.close()

      #################################################################
      # predict modifications

      if options.load_sat_npy:
        sat_preds = np.load('%s/seq%d_preds.npy' % (options.out_dir, si))
      else:
        # supplement with saturated mutagenesis
        sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1],
                                    options.satmut_len)

        # initialize batcher
        batcher_sat = batcher.Batcher(sat_seqs_1hot,
                                      batch_size=model.hp.batch_size)

        # predict
        sat_preds = model.predict_h5(sess, batcher_sat)
        np.save('%s/seq%d_preds.npy' % (options.out_dir, si), sat_preds)

      if options.mean_targets:
        sat_preds = np.mean(sat_preds, axis=-1, keepdims=True)
        num_targets = 1

      #################################################################
      # compute delta, loss, and gain matrices

      # compute the matrix of prediction deltas: (4 x L_sm x T) array
      sat_delta = delta_matrix(seqs_1hot[si], sat_preds,
                               options.satmut_len)

      # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len)
      sat_loss = sat_delta.min(axis=0)
      sat_gain = sat_delta.max(axis=0)

      ##############################################
      # plot

      for ti in range(num_targets):
        # setup plot
        sns.set(style='white', font_scale=1)
        spp = subplot_params(sat_delta.shape[1])

        if options.gain:
          plt.figure(figsize=(options.figure_width, 4))
          ax_logo_loss = plt.subplot2grid(
              (4, spp['heat_cols']), (0, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_logo_gain = plt.subplot2grid(
              (4, spp['heat_cols']), (1, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_sad = plt.subplot2grid(
              (4, spp['heat_cols']), (2, spp['sad_start']),
              colspan=spp['sad_span'])
          ax_heat = plt.subplot2grid(
              (4, spp['heat_cols']), (3, 0), colspan=spp['heat_cols'])
        else:
          plt.figure(figsize=(options.figure_width, 3))
          ax_logo_loss = plt.subplot2grid(
              (3, spp['heat_cols']), (0, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_sad = plt.subplot2grid(
              (3, spp['heat_cols']), (1, spp['sad_start']),
              colspan=spp['sad_span'])
          ax_heat = plt.subplot2grid(
              (3, spp['heat_cols']), (2, 0), colspan=spp['heat_cols'])

        # plot sequence logo
        plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, ti])
        if options.gain:
          plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, ti])

        # plot SAD
        plot_sad(ax_sad, sat_loss[:, ti], sat_gain[:, ti])

        # plot heat map
        plot_heat(ax_heat, sat_delta[:, :, ti], options.min_limit)

        plt.tight_layout()

        # guard against target_subset being None (the default and --mean
        # paths), in which case ti itself indexes the plotted target
        plot_ti = ti if target_subset is None else target_subset[ti]
        plt.savefig('%s/%s_t%d.pdf' % (options.out_dir, header_fs, plot_ti),
                    dpi=600)
        plt.close()
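# A toy numpy sketch of the loss/gain reduction above. `sat_delta` holds the
# prediction change for mutating each position to each of the 4 nucleotides;
# the most negative mutation per position is the "loss" score and the most
# positive is the "gain" score:
import numpy as np

satmut_len, num_targets = 10, 2
sat_delta = np.random.randn(4, satmut_len, num_targets)  # 4 x L_sm x T

sat_loss = sat_delta.min(axis=0)  # L_sm x T: worst-case mutation per position
sat_gain = sat_delta.max(axis=0)  # L_sm x T: best-case mutation per position
assert sat_loss.shape == (satmut_len, num_targets)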
def main(): usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>" parser = OptionParser(usage) parser.add_option( "-b", dest="batch_size", default=None, type="int", help="Batch size [Default: %default]", ) parser.add_option( "-i", dest="ignore_bed", help="Ignore genes overlapping regions in this BED file", ) parser.add_option("-l", dest="load_preds", help="Load tess_preds from file") parser.add_option( "--heat", dest="plot_heat", default=False, action="store_true", help="Plot big gene-target heatmaps [Default: %default]", ) parser.add_option( "-o", dest="out_dir", default="genes_out", help="Output directory for tables and plots [Default: %default]", ) parser.add_option( "-r", dest="tss_radius", default=0, type="int", help="Radius of bins considered to quantify TSS transcription [Default: %default]", ) parser.add_option( "--rc", dest="rc", default=False, action="store_true", help="Average the forward and reverse complement predictions when testing [Default: %default]", ) parser.add_option( "-s", dest="plot_scatter", default=False, action="store_true", help="Make time-consuming accuracy scatter plots [Default: %default]", ) parser.add_option( "--shifts", dest="shifts", default="0", help="Ensemble prediction shifts [Default: %default]", ) parser.add_option( "--rep", dest="replicate_labels_file", help="Compare replicate experiments, aided by the given file with long labels", ) parser.add_option( "-t", dest="targets_file", default=None, type="str", help="File specifying target indexes and labels in table format", ) parser.add_option( "--table", dest="print_tables", default=False, action="store_true", help="Print big gene/TSS tables [Default: %default]", ) parser.add_option( "--tss", dest="tss_alt", default=False, action="store_true", help="Perform alternative TSS analysis [Default: %default]", ) parser.add_option( "-v", dest="gene_variance", default=False, action="store_true", help="Study accuracy with respect to gene variance across targets [Default: %default]", ) (options, args) = parser.parse_args() if len(args) != 3: parser.error("Must provide parameters and model files, and genes HDF5 file") else: params_file = args[0] model_file = args[1] genes_hdf5_file = args[2] if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] ################################################################# # read in genes and targets gene_data = genedata.GeneData(genes_hdf5_file) ################################################################# # TSS predictions if options.load_preds is not None: # load from file tss_preds = np.load(options.load_preds) else: ####################################################### # setup model job = params.read_job_params(params_file) job["seq_length"] = gene_data.seq_length job["seq_depth"] = gene_data.seq_depth job["target_pool"] = gene_data.pool_width if not "num_targets" in job: job["num_targets"] = gene_data.num_targets # build model model = seqnn.SeqNN() model.build_feed_sad(job) if options.batch_size is not None: model.hp.batch_size = options.batch_size ####################################################### # predict TSSs t0 = time.time() print("Computing gene predictions.", end="") sys.stdout.flush() # initialize batcher gene_batcher = batcher.Batcher( gene_data.seqs_1hot, batch_size=model.hp.batch_size ) # initialie saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # predict tss_preds = 
model.predict_genes( sess, gene_batcher, gene_data.gene_seqs, rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius, ) # save to file np.save("%s/preds" % options.out_dir, tss_preds) print(" Done in %ds." % (time.time() - t0)) # up-convert tss_preds = tss_preds.astype("float32") ################################################################# # convert to genes gene_targets, _ = gene.map_tss_genes( gene_data.tss_targets, gene_data.tss, tss_radius=options.tss_radius ) gene_preds, _ = gene.map_tss_genes( tss_preds, gene_data.tss, tss_radius=options.tss_radius ) ################################################################# # determine targets # read targets if options.targets_file is not None: targets_df = pd.read_table(options.targets_file, index_col=0) target_indexes = targets_df.index else: if gene_data.num_targets is None: print("No targets to test against.") exit(1) else: target_indexes = np.arange(gene_data.num_targets) ################################################################# # correlation statistics t0 = time.time() print("Computing correlations.", end="") sys.stdout.flush() cor_table( gene_data.tss_targets, tss_preds, gene_data.target_ids, gene_data.target_labels, target_indexes, "%s/tss_cors.txt" % options.out_dir, ) cor_table( gene_targets, gene_preds, gene_data.target_ids, gene_data.target_labels, target_indexes, "%s/gene_cors.txt" % options.out_dir, draw_plots=True, ) print(" Done in %ds." % (time.time() - t0)) ################################################################# # gene statistics if options.print_tables: t0 = time.time() print("Printing predictions.", end="") sys.stdout.flush() gene_table( gene_data.tss_targets, tss_preds, gene_data.tss_ids(), gene_data.target_labels, target_indexes, "%s/transcript" % options.out_dir, options.plot_scatter, ) gene_table( gene_targets, gene_preds, gene_data.gene_ids(), gene_data.target_labels, target_indexes, "%s/gene" % options.out_dir, options.plot_scatter, ) print(" Done in %ds." % (time.time() - t0)) ################################################################# # gene x target heatmaps if options.plot_heat or options.gene_variance: ######################################### # normalize predictions across targets t0 = time.time() print("Normalizing values across targets.", end="") sys.stdout.flush() gene_targets_qn = normalize_targets( gene_targets[:, target_indexes], log_pseudo=1 ) gene_preds_qn = normalize_targets(gene_preds[:, target_indexes], log_pseudo=1) print(" Done in %ds." 
% (time.time() - t0)) if options.plot_heat: ######################################### # plot genes by targets clustermap t0 = time.time() print("Plotting heat maps.", end="") sys.stdout.flush() sns.set(font_scale=1.3, style="ticks") plot_genes = 1600 plot_targets = 800 # choose a set of variable genes gene_vars = gene_preds_qn.var(axis=1) indexes_var = np.argsort(gene_vars)[::-1][:plot_genes] # choose a set of random genes if plot_genes < gene_preds_qn.shape[0]: indexes_rand = np.random.choice( np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False ) else: indexes_rand = np.arange(gene_preds_qn.shape[0]) # choose a set of random targets if plot_targets < 0.8 * gene_preds_qn.shape[1]: indexes_targets = np.random.choice( np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False ) else: indexes_targets = np.arange(gene_preds_qn.shape[1]) # variable gene predictions clustermap( gene_preds_qn[indexes_var, :][:, indexes_targets], "%s/gene_heat_var.pdf" % options.out_dir, ) clustermap( gene_preds_qn[indexes_var, :][:, indexes_targets], "%s/gene_heat_var_color.pdf" % options.out_dir, color="viridis", table=True, ) # random gene predictions clustermap( gene_preds_qn[indexes_rand, :][:, indexes_targets], "%s/gene_heat_rand.pdf" % options.out_dir, ) # variable gene targets clustermap( gene_targets_qn[indexes_var, :][:, indexes_targets], "%s/gene_theat_var.pdf" % options.out_dir, ) clustermap( gene_targets_qn[indexes_var, :][:, indexes_targets], "%s/gene_theat_var_color.pdf" % options.out_dir, color="viridis", table=True, ) # random gene targets (crashes) # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets], # '%s/gene_theat_rand.pdf' % options.out_dir) print(" Done in %ds." % (time.time() - t0)) ################################################################# # analyze replicates if options.replicate_labels_file is not None: # read long form labels, from which to infer replicates target_labels_long = [] for line in open(options.replicate_labels_file): a = line.split("\t") a[-1] = a[-1].rstrip() target_labels_long.append(a[-1]) # determine replicates replicate_lists = infer_replicates(target_labels_long) # compute correlations # replicate_correlations(replicate_lists, gene_data.tss_targets, # tss_preds, options.target_indexes, '%s/transcript_reps' % options.out_dir) replicate_correlations( replicate_lists, gene_targets, gene_preds, target_indexes, "%s/gene_reps" % options.out_dir, ) # , scatter_plots=True) ################################################################# # gene variance if options.gene_variance: variance_accuracy(gene_targets_qn, gene_preds_qn, "%s/gene" % options.out_dir) ################################################################# # alternative TSS if options.tss_alt: alternative_tss( gene_data.tss_targets[:, target_indexes], tss_preds[:, target_indexes], gene_data, options.out_dir, log_pseudo=1, )
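# `normalize_targets` is defined elsewhere in this script; below is a generic
# quantile-normalization sketch of the idea (an assumption about its behavior,
# not its actual implementation): log-transform, then replace each column's
# sorted values with the mean sorted profile across columns, so every target
# shares one value distribution.
import numpy as np

def quantile_normalize(x, log_pseudo=1):
  x = np.log2(x + log_pseudo)  # stabilize variance
  ranks = np.argsort(np.argsort(x, axis=0), axis=0)  # rank within column
  mean_profile = np.sort(x, axis=0).mean(axis=1)  # mean sorted column
  return mean_profile[ranks]  # map ranks back to shared values

genes_by_targets = np.abs(np.random.randn(100, 5))
gt_qn = quantile_normalize(genes_by_targets)
assert gt_qn.shape == genes_by_targets.shape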
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '--ai', dest='accuracy_indexes',
      help='Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values')
  parser.add_option(
      '--clip', dest='target_clip', default=None, type='float',
      help='Clip targets and predictions to a maximum value [Default: %default]')
  parser.add_option(
      '-d', dest='down_sample', default=1, type='int',
      help='Down sample test computation by taking uniformly spaced positions [Default: %default]')
  parser.add_option(
      '-g', dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome length information [Default: %default]')
  parser.add_option(
      '--mc', dest='mc_n', default=0, type='int',
      help='Monte carlo test iterations [Default: %default]')
  parser.add_option(
      '--peak', '--peaks', dest='peaks', default=False, action='store_true',
      help='Compute expensive peak accuracy [Default: %default]')
  parser.add_option(
      '-o', dest='out_dir', default='test_out',
      help='Output directory for test statistics [Default: %default]')
  parser.add_option(
      '--rc', dest='rc', default=False, action='store_true',
      help='Average the fwd and rc predictions [Default: %default]')
  parser.add_option(
      '-s', dest='scent_file',
      help='Dimension reduction model file')
  parser.add_option(
      '--sample', dest='sample_pct', default=1, type='float',
      help='Sample percentage')
  parser.add_option(
      '--save', dest='save', default=False, action='store_true')
  parser.add_option(
      '--shifts', dest='shifts', default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t', dest='track_bed',
      help='BED file describing regions so we can output BigWig tracks')
  parser.add_option(
      '--ti', dest='track_indexes',
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option(
      '--train', dest='train', default=False, action='store_true',
      help='Process the training set [Default: %default]')
  parser.add_option(
      '-v', dest='valid', default=False, action='store_true',
      help='Process the validation set [Default: %default]')
  parser.add_option(
      '-w', dest='pool_width', default=1, type='int',
      help='Max pool width for regressing nt predictions to predict peak calls [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters, model, and test data HDF5')
  else:
    params_file = args[0]
    model_file = args[1]
    test_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(test_hdf5_file, 'r')

  if options.train:
    test_seqs = data_open['train_in']
    test_targets = data_open['train_out']
    test_na = None
    if 'train_na' in data_open:
      test_na = data_open['train_na']

  elif options.valid:
    test_seqs = data_open['valid_in']
    test_targets = data_open['valid_out']
    test_na = None
    if 'valid_na' in data_open:
      test_na = data_open['valid_na']

  else:
    test_seqs = data_open['test_in']
    test_targets = data_open['test_out']
    test_na = None
    if 'test_na' in data_open:
      test_na = data_open['test_na']

  if options.sample_pct < 1:
    sample_n = int(test_seqs.shape[0] * options.sample_pct)
    print('Sampling %d sequences' % sample_n)
    sample_indexes = sorted(
        np.random.choice(np.arange(test_seqs.shape[0]),
                         size=sample_n, replace=False))
    test_seqs = test_seqs[sample_indexes]
    test_targets = test_targets[sample_indexes]
    if test_na is not None:
      test_na = test_na[sample_indexes]

  target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

  #######################################################
  # model parameters and placeholders

  job = params.read_job_params(params_file)

  job['seq_length'] = test_seqs.shape[1]
  job['seq_depth'] = test_seqs.shape[2]
  job['num_targets'] = test_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

  t0 = time.time()
  dr = seqnn.SeqNN()
  dr.build(job)
  print('Model building time %ds' % (time.time() - t0))

  # adjust for fourier
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    test_targets_imag = data_open['test_out_imag']
    if options.valid:
      test_targets_imag = data_open['valid_out_imag']

  # adjust for factors
  if options.scent_file is not None:
    t0 = time.time()
    test_targets_full = data_open['test_out_full']
    model = joblib.load(options.scent_file)

  #######################################################
  # test

  # initialize batcher
  if job['fourier']:
    batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                    test_targets_imag, test_na,
                                    dr.hp.batch_size, dr.hp.target_pool)
  else:
    batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                   dr.hp.batch_size, dr.hp.target_pool)

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # test
    t0 = time.time()
    test_acc = dr.test(sess, batcher_test, rc=options.rc,
                       shifts=options.shifts, mc_n=options.mc_n)

    if options.save:
      np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
      np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

    test_preds = test_acc.preds
    print('SeqNN test: %ds' % (time.time() - t0))

    # compute stats
    t0 = time.time()
    test_r2 = test_acc.r2(clip=options.target_clip)
    # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
    test_pcor = test_acc.pearsonr(clip=options.target_clip)
    test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
    # test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
    print('Compute stats: %ds' % (time.time() - t0))

    # print
    print('Test Loss: %7.5f' % test_acc.loss)
    print('Test R2: %7.5f' % test_r2.mean())
    # print('Test log R2: %7.5f' % test_log_r2.mean())
    print('Test PearsonR: %7.5f' % test_pcor.mean())
    print('Test log PearsonR: %7.5f' % test_log_pcor.mean())
    # print('Test SpearmanR: %7.5f' % test_scor.mean())

    acc_out = open('%s/acc.txt' % options.out_dir, 'w')
    for ti in range(len(test_r2)):
      print('%4d %7.5f %.5f %.5f %.5f %s' %
            (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti],
             test_log_pcor[ti], target_labels[ti]), file=acc_out)
    acc_out.close()

    # print normalization factors
    target_means = test_preds.mean(axis=(0, 1), dtype='float64')
    target_means_median = np.median(target_means)
    target_means /= target_means_median
    norm_out = open('%s/normalization.txt' % options.out_dir, 'w')
    print('\n'.join([str(tu) for tu in target_means]), file=norm_out)
    norm_out.close()

    # clean up
    del test_acc

  # if test targets are reconstructed, measure versus the truth
  if options.scent_file is not None:
    compute_full_accuracy(dr, model, test_preds, test_targets_full,
                          options.out_dir, options.down_sample)

  #######################################################
  # peak call accuracy

  if options.peaks:
    # sample every few bins to decrease correlations
    ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
    ds_indexes_targets = ds_indexes_preds + (
        dr.hp.batch_buffer // dr.hp.target_pool)

    aurocs = []
    auprcs = []

    peaks_out = open('%s/peaks.txt' % options.out_dir, 'w')
    for ti in range(test_targets.shape[2]):
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # subset and flatten
      test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten(
      ).astype('float32')
      test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten(
      ).astype('float32')

      # call peaks
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01

      if test_targets_peaks.sum() == 0:
        aurocs.append(0.5)
        auprcs.append(0)
      else:
        # compute prediction accuracy
        aurocs.append(roc_auc_score(test_targets_peaks, test_preds_ti_flat))
        auprcs.append(
            average_precision_score(test_targets_peaks, test_preds_ti_flat))

      print('%4d %6d %.5f %.5f' %
            (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
            file=peaks_out)

    peaks_out.close()

    print('Test AUROC: %7.5f' % np.mean(aurocs))
    print('Test AUPRC: %7.5f' % np.mean(auprcs))

  #######################################################
  # BigWig tracks

  # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

  # print bigwig tracks for visualization
  if options.track_bed:
    if options.genome_file is None:
      parser.error('Must provide genome file in order to print valid BigWigs')

    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

    track_indexes = range(test_preds.shape[2])
    if options.track_indexes:
      track_indexes = [int(ti) for ti in options.track_indexes.split(',')]

    bed_set = 'test'
    if options.valid:
      bed_set = 'valid'

    for ti in track_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # make true targets bigwig
      bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
      bigwig_write(bw_file, test_targets_ti, options.track_bed,
                   options.genome_file, bed_set=bed_set)

      # make predictions bigwig
      bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
      bigwig_write(bw_file, test_preds[:, :, ti], options.track_bed,
                   options.genome_file, dr.hp.batch_buffer, bed_set=bed_set)

    # make NA bigwig
    bw_file = '%s/tracks/na.bw' % options.out_dir
    bigwig_write(bw_file, test_na, options.track_bed, options.genome_file,
                 bed_set=bed_set)

  #######################################################
  # accuracy plots

  if options.accuracy_indexes is not None:
    accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')]

    if not os.path.isdir('%s/scatter' % options.out_dir):
      os.mkdir('%s/scatter' % options.out_dir)

    if not os.path.isdir('%s/violin' % options.out_dir):
      os.mkdir('%s/violin' % options.out_dir)

    if not os.path.isdir('%s/roc' % options.out_dir):
      os.mkdir('%s/roc' % options.out_dir)

    if not os.path.isdir('%s/pr' % options.out_dir):
      os.mkdir('%s/pr' % options.out_dir)

    for ti in accuracy_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      ############################################
      # scatter

      # sample every few bins (adjust to plot the points I want)
      ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
      ds_indexes_targets = ds_indexes_preds + (
          dr.hp.batch_buffer // dr.hp.target_pool)

      # subset and flatten
      test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten(
      ).astype('float32')
      test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten(
      ).astype('float32')

      # take log2
      test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
      test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

      # plot log2
      sns.set(font_scale=1.2, style='ticks')
      out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
      plots.regplot(test_targets_ti_log, test_preds_ti_log, out_pdf,
                    poly_order=1, alpha=0.3, sample=500, figsize=(6, 6),
                    x_label='log2 Experiment', y_label='log2 Prediction',
                    table=True)

      ############################################
      # violin

      # call peaks
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01
      test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                        'Background')

      # violin plot
      sns.set(font_scale=1.3, style='ticks')
      plt.figure()
      df = pd.DataFrame({
          'log2 Prediction': np.log2(test_preds_ti_flat + 1),
          'Experimental coverage status': test_targets_peaks_str
      })
      ax = sns.violinplot(x='Experimental coverage status',
                          y='log2 Prediction', data=df)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # ROC
      plt.figure()
      fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
      auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
      plt.plot([0, 1], [0, 1], c='black', linewidth=1, linestyle='--',
               alpha=0.7)
      plt.plot(fpr, tpr, c='black')
      ax = plt.gca()
      ax.set_xlabel('False positive rate')
      ax.set_ylabel('True positive rate')
      ax.text(0.99, 0.02, 'AUROC %.3f' % auroc,
              horizontalalignment='right')  # , fontsize=14
      ax.grid(True, linestyle=':')
      plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # PR
      plt.figure()
      prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                               test_preds_ti_flat)
      auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat)
      plt.axhline(y=test_targets_peaks.mean(), c='black', linewidth=1,
                  linestyle='--', alpha=0.7)
      plt.plot(recall, prec, c='black')
      ax = plt.gca()
      ax.set_xlabel('Recall')
      ax.set_ylabel('Precision')
      ax.text(0.99, 0.95, 'AUPRC %.3f' % auprc,
              horizontalalignment='right')  # , fontsize=14
      ax.grid(True, linestyle=':')
      plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
      plt.close()

  data_open.close()
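# A self-contained sketch of the peak-calling scheme used twice above. The
# script's `ben_hoch` helper is defined elsewhere; the version below is an
# assumed standard Benjamini-Hochberg implementation, not its actual code.
# Coverage bins are tested against a Poisson null with the mean coverage as
# lambda, and bins with q-value < 0.01 become positive "peak" labels:
import numpy as np
from scipy.stats import poisson

def ben_hoch(p_values):
  """Benjamini-Hochberg q-values (assumed equivalent to the script's helper)."""
  m = len(p_values)
  order = np.argsort(p_values)
  qvals = np.empty(m)
  prev = 1.0
  # walk from the largest p-value down, maintaining the running minimum
  for rank_from_end, i in enumerate(order[::-1]):
    rank = m - rank_from_end
    prev = min(prev, p_values[i] * m / rank)
    qvals[i] = prev
  return qvals

coverage = np.random.poisson(2.0, size=10000).astype('float32')
lam = coverage.mean()
pvals = 1 - poisson.cdf(np.round(coverage) - 1, mu=lam)  # P(X >= x)
peaks = ben_hoch(pvals) < 0.01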
def main(): usage = "usage: %prog [options] <params_file> <model_file> <test_hdf5_file>" parser = OptionParser(usage) parser.add_option( "--ai", dest="accuracy_indexes", help= "Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values", ) parser.add_option( "--clip", dest="target_clip", default=None, type="float", help= "Clip targets and predictions to a maximum value [Default: %default]", ) parser.add_option( "-d", dest="down_sample", default=1, type="int", help= "Down sample test computation by taking uniformly spaced positions [Default: %default]", ) parser.add_option( "-g", dest="genome_file", default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"], help="Chromosome length information [Default: %default]", ) parser.add_option( "--mc", dest="mc_n", default=0, type="int", help="Monte carlo test iterations [Default: %default]", ) parser.add_option( "--peak", "--peaks", dest="peaks", default=False, action="store_true", help="Compute expensive peak accuracy [Default: %default]", ) parser.add_option( "-o", dest="out_dir", default="test_out", help="Output directory for test statistics [Default: %default]", ) parser.add_option( "--rc", dest="rc", default=False, action="store_true", help="Average the fwd and rc predictions [Default: %default]", ) parser.add_option("-s", dest="scent_file", help="Dimension reduction model file") parser.add_option("--sample", dest="sample_pct", default=1, type="float", help="Sample percentage") parser.add_option("--save", dest="save", default=False, action="store_true") parser.add_option( "--shifts", dest="shifts", default="0", help="Ensemble prediction shifts [Default: %default]", ) parser.add_option( "-t", dest="track_bed", help="BED file describing regions so we can output BigWig tracks", ) parser.add_option( "--ti", dest="track_indexes", help="Comma-separated list of target indexes to output BigWig tracks", ) parser.add_option( "--train", dest="train", default=False, action="store_true", help="Process the training set [Default: %default]", ) parser.add_option( "-v", dest="valid", default=False, action="store_true", help="Process the validation set [Default: %default]", ) parser.add_option( "-w", dest="pool_width", default=1, type="int", help= "Max pool width for regressing nt predictions to predict peak calls [Default: %default]", ) (options, args) = parser.parse_args() if len(args) != 3: parser.error("Must provide parameters, model, and test data HDF5") else: params_file = args[0] model_file = args[1] test_hdf5_file = args[2] if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] ####################################################### # load data ####################################################### data_open = h5py.File(test_hdf5_file) if options.train: test_seqs = data_open["train_in"] test_targets = data_open["train_out"] if "train_na" in data_open: test_na = data_open["train_na"] elif options.valid: test_seqs = data_open["valid_in"] test_targets = data_open["valid_out"] test_na = None if "valid_na" in data_open: test_na = data_open["valid_na"] else: test_seqs = data_open["test_in"] test_targets = data_open["test_out"] test_na = None if "test_na" in data_open: test_na = data_open["test_na"] if options.sample_pct < 1: sample_n = int(test_seqs.shape[0] * options.sample_pct) print("Sampling %d sequences" % sample_n) sample_indexes = sorted( np.random.choice(np.arange(test_seqs.shape[0]), size=sample_n, replace=False)) test_seqs = 
test_seqs[sample_indexes] test_targets = test_targets[sample_indexes] if test_na is not None: test_na = test_na[sample_indexes] target_labels = [tl.decode("UTF-8") for tl in data_open["target_labels"]] ####################################################### # model parameters and placeholders job = params.read_job_params(params_file) job["seq_length"] = test_seqs.shape[1] job["seq_depth"] = test_seqs.shape[2] job["num_targets"] = test_targets.shape[2] job["target_pool"] = int(np.array(data_open.get("pool_width", 1))) t0 = time.time() dr = seqnn.SeqNN() dr.build_feed(job) print("Model building time %ds" % (time.time() - t0)) # adjust for fourier job["fourier"] = "train_out_imag" in data_open if job["fourier"]: test_targets_imag = data_open["test_out_imag"] if options.valid: test_targets_imag = data_open["valid_out_imag"] # adjust for factors if options.scent_file is not None: t0 = time.time() test_targets_full = data_open["test_out_full"] model = joblib.load(options.scent_file) ####################################################### # test # initialize batcher if job["fourier"]: batcher_test = batcher.BatcherF( test_seqs, test_targets, test_targets_imag, test_na, dr.hp.batch_size, dr.hp.target_pool, ) else: batcher_test = batcher.Batcher(test_seqs, test_targets, test_na, dr.hp.batch_size, dr.hp.target_pool) # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # test t0 = time.time() test_acc = dr.test_h5_manual(sess, batcher_test, rc=options.rc, shifts=options.shifts, mc_n=options.mc_n) if options.save: np.save("%s/preds.npy" % options.out_dir, test_acc.preds) np.save("%s/targets.npy" % options.out_dir, test_acc.targets) test_preds = test_acc.preds print("SeqNN test: %ds" % (time.time() - t0)) # compute stats t0 = time.time() test_r2 = test_acc.r2(clip=options.target_clip) # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip) test_pcor = test_acc.pearsonr(clip=options.target_clip) test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip) # test_scor = test_acc.spearmanr() # too slow; mostly driven by low values print("Compute stats: %ds" % (time.time() - t0)) # print print("Test Loss: %7.5f" % test_acc.loss) print("Test R2: %7.5f" % test_r2.mean()) # print('Test log R2: %7.5f' % test_log_r2.mean()) print("Test PearsonR: %7.5f" % test_pcor.mean()) print("Test log PearsonR: %7.5f" % test_log_pcor.mean()) # print('Test SpearmanR: %7.5f' % test_scor.mean()) acc_out = open("%s/acc.txt" % options.out_dir, "w") for ti in range(len(test_r2)): print( "%4d %7.5f %.5f %.5f %.5f %s" % ( ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti], test_log_pcor[ti], target_labels[ti], ), file=acc_out, ) acc_out.close() # print normalization factors target_means = test_preds.mean(axis=(0, 1), dtype="float64") target_means_median = np.median(target_means) target_means /= target_means_median norm_out = open("%s/normalization.txt" % options.out_dir, "w") print("\n".join([str(tu) for tu in target_means]), file=norm_out) norm_out.close() # clean up del test_acc # if test targets are reconstructed, measure versus the truth if options.scent_file is not None: compute_full_accuracy( dr, model, test_preds, test_targets_full, options.out_dir, options.down_sample, ) ####################################################### # peak call accuracy if options.peaks: # sample every few bins to decrease correlations ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) ds_indexes_targets = ds_indexes_preds + 
(dr.hp.batch_buffer // dr.hp.target_pool) aurocs = [] auprcs = [] peaks_out = open("%s/peaks.txt" % options.out_dir, "w") for ti in range(test_targets.shape[2]): if options.scent_file is not None: test_targets_ti = test_targets_full[:, :, ti] else: test_targets_ti = test_targets[:, :, ti] # subset and flatten test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets]. flatten().astype("float32")) test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti].flatten().astype("float32")) # call peaks test_targets_ti_lambda = np.mean(test_targets_ti_flat) test_targets_pvals = 1 - poisson.cdf( np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) test_targets_peaks = test_targets_qvals < 0.01 if test_targets_peaks.sum() == 0: aurocs.append(0.5) auprcs.append(0) else: # compute prediction accuracy aurocs.append( roc_auc_score(test_targets_peaks, test_preds_ti_flat)) auprcs.append( average_precision_score(test_targets_peaks, test_preds_ti_flat)) print( "%4d %6d %.5f %.5f" % (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]), file=peaks_out, ) peaks_out.close() print("Test AUROC: %7.5f" % np.mean(aurocs)) print("Test AUPRC: %7.5f" % np.mean(auprcs)) ####################################################### # BigWig tracks # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE # print bigwig tracks for visualization if options.track_bed: if options.genome_file is None: parser.error( "Must provide genome file in order to print valid BigWigs") if not os.path.isdir("%s/tracks" % options.out_dir): os.mkdir("%s/tracks" % options.out_dir) track_indexes = range(test_preds.shape[2]) if options.track_indexes: track_indexes = [ int(ti) for ti in options.track_indexes.split(",") ] bed_set = "test" if options.valid: bed_set = "valid" for ti in track_indexes: if options.scent_file is not None: test_targets_ti = test_targets_full[:, :, ti] else: test_targets_ti = test_targets[:, :, ti] # make true targets bigwig bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti) bigwig_write( bw_file, test_targets_ti, options.track_bed, options.genome_file, bed_set=bed_set, ) # make predictions bigwig bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti) bigwig_write( bw_file, test_preds[:, :, ti], options.track_bed, options.genome_file, dr.hp.batch_buffer, bed_set=bed_set, ) # make NA bigwig # bw_file = '%s/tracks/na.bw' % options.out_dir # bigwig_write( # bw_file, # test_na, # options.track_bed, # options.genome_file, # bed_set=bed_set) ####################################################### # accuracy plots if options.accuracy_indexes is not None: accuracy_indexes = [ int(ti) for ti in options.accuracy_indexes.split(",") ] if not os.path.isdir("%s/scatter" % options.out_dir): os.mkdir("%s/scatter" % options.out_dir) if not os.path.isdir("%s/violin" % options.out_dir): os.mkdir("%s/violin" % options.out_dir) if not os.path.isdir("%s/roc" % options.out_dir): os.mkdir("%s/roc" % options.out_dir) if not os.path.isdir("%s/pr" % options.out_dir): os.mkdir("%s/pr" % options.out_dir) for ti in accuracy_indexes: if options.scent_file is not None: test_targets_ti = test_targets_full[:, :, ti] else: test_targets_ti = test_targets[:, :, ti] ############################################ # scatter # sample every few bins (adjust to plot the # points I want) ds_indexes_preds = np.arange(0, test_preds.shape[1], 8) ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer // dr.hp.target_pool) # subset and flatten test_targets_ti_flat = (test_targets_ti[:, 
ds_indexes_targets]. flatten().astype("float32")) test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti].flatten().astype("float32")) # take log2 test_targets_ti_log = np.log2(test_targets_ti_flat + 1) test_preds_ti_log = np.log2(test_preds_ti_flat + 1) # plot log2 sns.set(font_scale=1.2, style="ticks") out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti) plots.regplot( test_targets_ti_log, test_preds_ti_log, out_pdf, poly_order=1, alpha=0.3, sample=500, figsize=(6, 6), x_label="log2 Experiment", y_label="log2 Prediction", table=True, ) ############################################ # violin # call peaks test_targets_ti_lambda = np.mean(test_targets_ti_flat) test_targets_pvals = 1 - poisson.cdf( np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda) test_targets_qvals = np.array(ben_hoch(test_targets_pvals)) test_targets_peaks = test_targets_qvals < 0.01 test_targets_peaks_str = np.where(test_targets_peaks, "Peak", "Background") # violin plot sns.set(font_scale=1.3, style="ticks") plt.figure() df = pd.DataFrame({ "log2 Prediction": np.log2(test_preds_ti_flat + 1), "Experimental coverage status": test_targets_peaks_str, }) ax = sns.violinplot(x="Experimental coverage status", y="log2 Prediction", data=df) ax.grid(True, linestyle=":") plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti)) plt.close() # ROC plt.figure() fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat) auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat) plt.plot([0, 1], [0, 1], c="black", linewidth=1, linestyle="--", alpha=0.7) plt.plot(fpr, tpr, c="black") ax = plt.gca() ax.set_xlabel("False positive rate") ax.set_ylabel("True positive rate") ax.text(0.99, 0.02, "AUROC %.3f" % auroc, horizontalalignment="right") # , fontsize=14) ax.grid(True, linestyle=":") plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti)) plt.close() # PR plt.figure() prec, recall, _ = precision_recall_curve(test_targets_peaks, test_preds_ti_flat) auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat) plt.axhline( y=test_targets_peaks.mean(), c="black", linewidth=1, linestyle="--", alpha=0.7, ) plt.plot(recall, prec, c="black") ax = plt.gca() ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.text(0.99, 0.95, "AUPRC %.3f" % auprc, horizontalalignment="right") # , fontsize=14) ax.grid(True, linestyle=":") plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti)) plt.close() data_open.close()
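# The peak-calling blocks above rely on a ben_hoch() helper that isn't
# defined in this file. Below is a minimal sketch of a Benjamini-Hochberg
# q-value computation consistent with how it's called here (array of
# p-values in, array of q-values out); the project's actual helper may
# differ, so treat this as an assumption rather than the canonical
# implementation.
import numpy as np

def ben_hoch(p_values):
    """Benjamini-Hochberg FDR q-values (sketch)."""
    p_values = np.asarray(p_values, dtype='float64')
    m = len(p_values)
    order = np.argsort(p_values)
    q_values = np.empty(m)
    q_min = 1.0
    # walk from the largest p-value down, enforcing monotone q-values
    for i in range(m - 1, -1, -1):
        q_min = min(q_min, p_values[order[i]] * m / (i + 1))
        q_values[order[i]] = q_min
    return q_values

# usage mirroring the script: test_targets_peaks = ben_hoch(pvals) < 0.01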
def main(): usage = "usage: %prog [options] <params_file> <model_file> <input_file>" parser = OptionParser(usage) parser.add_option( "-a", dest="activity_enrich", default=1, type="float", help= "Enrich for the most active top % of sequences [Default: %default]", ) parser.add_option("-b", dest="batch_size", default=None, type="int", help="Batch size") parser.add_option( "-f", dest="figure_width", default=20, type="float", help="Figure width [Default: %default]", ) parser.add_option( "-g", dest="gain", default=False, action="store_true", help="Draw a sequence logo for the gain score, too [Default: %default]", ) parser.add_option( "-l", dest="satmut_len", default=200, type="int", help="Length of centered sequence to mutate [Default: %default]", ) parser.add_option( "-m", dest="min_limit", default=0.005, type="float", help="Minimum heatmap limit [Default: %default]", ) parser.add_option( "-n", dest="load_sat_npy", default=False, action="store_true", help="Load the predictions from .npy files [Default: %default]", ) parser.add_option( "-o", dest="out_dir", default="heat", help="Output directory [Default: %default]", ) parser.add_option( "-r", dest="rng_seed", default=1, type="float", help="Random number generator seed [Default: %default]", ) parser.add_option( "--rc", dest="rc", default=False, action="store_true", help= "Ensemble forward and reverse complement predictions [Default: %default]", ) parser.add_option( "-s", dest="sample", default=None, type="int", help="Sample sequences from the test set [Default:%default]", ) parser.add_option( "--shifts", dest="shifts", default="0", help="Ensemble prediction shifts [Default: %default]", ) parser.add_option( "-t", dest="targets", default="0", help= "Comma-separated target indexes (or -1 for all) [Default: %default]", ) (options, args) = parser.parse_args() if len(args) != 3: parser.error( "Must provide parameters and model files and input sequences (as a " "FASTA file or test data in an HDF file") else: params_file = args[0] model_file = args[1] input_file = args[2] if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] random.seed(options.rng_seed) ################################################################# # parse input file ################################################################# seqs, seqs_1hot, targets = parse_input(input_file, options.sample) # decide which targets to obtain if options.targets == "-1": target_indexes = range(targets.shape[2]) target_subset = None else: target_indexes = [int(ti) for ti in options.targets.split(",")] target_subset = target_indexes # enrich for active sequences if targets is not None: seqs, seqs_1hot, targets = enrich_activity(seqs, seqs_1hot, targets, options.activity_enrich, target_indexes) seqs_n = seqs_1hot.shape[0] ################################################################# # setup model ################################################################# job = params.read_job_params(params_file) job["seq_length"] = seqs_1hot.shape[1] job["seq_depth"] = seqs_1hot.shape[2] if targets is None: if "num_targets" not in job or "target_pool" not in job: print( "Must provide num_targets and target_pool in parameters file", file=sys.stderr, ) exit(1) else: job["num_targets"] = targets.shape[2] job["target_pool"] = job["seq_length"] // targets.shape[1] t0 = time.time() model = seqnn.SeqNN() model.build_feed( job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, target_subset=target_subset, ) if options.batch_size 
is not None: model.hp.batch_size = options.batch_size # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) for si in range(seqs_n): print("Mutating sequence %d / %d" % (si + 1, seqs_n), flush=True) # write sequence fasta_out = open("%s/seq%d.fa" % (options.out_dir, si), "w") end_len = (len(seqs[si]) - options.satmut_len) // 2 print(">seq%d\n%s" % (si, seqs[si][end_len:-end_len]), file=fasta_out) fasta_out.close() ################################################################# # predict modifications if options.load_sat_npy: sat_preds = np.load("%s/seq%d_preds.npy" % (options.out_dir, si)) else: # supplement with saturated mutagenesis sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1], options.satmut_len) # initialize batcher batcher_sat = batcher.Batcher(sat_seqs_1hot, batch_size=model.hp.batch_size) # predict sat_preds = model.predict_h5(sess, batcher_sat) np.save("%s/seq%d_preds.npy" % (options.out_dir, si), sat_preds) ################################################################# # compute delta, loss, and gain matrices # compute the matrix of prediction deltas: (4 x L_sm x T) array sat_delta = delta_matrix(seqs_1hot[si], sat_preds, options.satmut_len) # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len) sat_loss = sat_delta.min(axis=0) sat_gain = sat_delta.max(axis=0) ############################################## # plot for tii in range(len(target_indexes)): # setup plot sns.set(style="white", font_scale=1) spp = subplot_params(sat_delta.shape[1]) if options.gain: plt.figure(figsize=(options.figure_width, 5)) ax_pred = plt.subplot2grid( (5, spp["heat_cols"]), (0, spp["pred_start"]), colspan=spp["pred_span"], ) ax_logo_loss = plt.subplot2grid( (5, spp["heat_cols"]), (1, spp["logo_start"]), colspan=spp["logo_span"], ) ax_logo_gain = plt.subplot2grid( (5, spp["heat_cols"]), (2, spp["logo_start"]), colspan=spp["logo_span"], ) ax_sad = plt.subplot2grid( (5, spp["heat_cols"]), (3, spp["sad_start"]), colspan=spp["sad_span"], ) ax_heat = plt.subplot2grid((5, spp["heat_cols"]), (4, 0), colspan=spp["heat_cols"]) else: plt.figure(figsize=(options.figure_width, 4)) ax_pred = plt.subplot2grid( (4, spp["heat_cols"]), (0, spp["pred_start"]), colspan=spp["pred_span"], ) ax_logo_loss = plt.subplot2grid( (4, spp["heat_cols"]), (1, spp["logo_start"]), colspan=spp["logo_span"], ) ax_sad = plt.subplot2grid( (4, spp["heat_cols"]), (2, spp["sad_start"]), colspan=spp["sad_span"], ) ax_heat = plt.subplot2grid((4, spp["heat_cols"]), (3, 0), colspan=spp["heat_cols"]) # plot predictions plot_predictions( ax_pred, sat_preds[0, :, tii], options.satmut_len, model.hp.seq_length, model.hp.batch_buffer, ) # plot sequence logo plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, tii]) if options.gain: plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, tii]) # plot SAD plot_sad(ax_sad, sat_loss[:, tii], sat_gain[:, tii]) # plot heat map plot_heat(ax_heat, sat_delta[:, :, tii], options.min_limit) plt.tight_layout() plt.savefig( "%s/seq%d_t%d.pdf" % (options.out_dir, si, target_indexes[tii]), dpi=600, ) plt.close()
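# satmut_seqs() and delta_matrix() are defined elsewhere in this script.
# A sketch of the saturated-mutagenesis enumeration under the assumptions
# visible above: the wild-type sequence is emitted first (so sat_preds[0]
# serves as the reference prediction), followed by the three alternative
# bases at each position of the centered satmut_len window. Names and
# ordering here are assumptions, not the project's exact code.
import numpy as np

def satmut_seqs(seqs_1hot, satmut_len):
    """Enumerate single-nucleotide mutants of the centered window (sketch)."""
    seq_len = seqs_1hot.shape[1]
    mut_start = (seq_len - satmut_len) // 2
    mut_end = mut_start + satmut_len

    sat_seqs = [seqs_1hot[0]]  # wild-type first
    for pos in range(mut_start, mut_end):
        for ni in range(4):
            if seqs_1hot[0, pos, ni] == 0:  # skip the reference base
                seq_mut = np.copy(seqs_1hot[0])
                seq_mut[pos, :] = 0
                seq_mut[pos, ni] = 1
                sat_seqs.append(seq_mut)
    return np.array(sat_seqs)  # (1 + 3*satmut_len) x L x 4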
def run(params_file, data_file, num_train_epochs): shifts = [int(shift) for shift in FLAGS.shifts.split(',')] ####################################################### # load data ####################################################### data_open = h5py.File(data_file) train_seqs = data_open['train_in'] train_targets = data_open['train_out'] train_na = None if 'train_na' in data_open: train_na = data_open['train_na'] valid_seqs = data_open['valid_in'] valid_targets = data_open['valid_out'] valid_na = None if 'valid_na' in data_open: valid_na = data_open['valid_na'] ####################################################### # model parameters and placeholders ####################################################### job = dna_io.read_job_params(params_file) job['batch_length'] = train_seqs.shape[1] job['seq_depth'] = train_seqs.shape[2] job['num_targets'] = train_targets.shape[2] job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) job['early_stop'] = job.get('early_stop', 16) job['rate_drop'] = job.get('rate_drop', 3) t0 = time.time() dr = seqnn.SeqNN() dr.build(job) print('Model building time %f' % (time.time() - t0)) # adjust for fourier job['fourier'] = 'train_out_imag' in data_open if job['fourier']: train_targets_imag = data_open['train_out_imag'] valid_targets_imag = data_open['valid_out_imag'] ####################################################### # train ####################################################### # initialize batcher if job['fourier']: batcher_train = batcher.BatcherF(train_seqs, train_targets, train_targets_imag, train_na, dr.batch_size, dr.target_pool, shuffle=True) batcher_valid = batcher.BatcherF(valid_seqs, valid_targets, valid_targets_imag, valid_na, dr.batch_size, dr.target_pool) else: batcher_train = batcher.Batcher(train_seqs, train_targets, train_na, dr.batch_size, dr.target_pool, shuffle=True) batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na, dr.batch_size, dr.target_pool) print('Batcher initialized') # checkpoints saver = tf.train.Saver() config = tf.ConfigProto() if FLAGS.log_device_placement: config.log_device_placement = True with tf.Session(config=config) as sess: t0 = time.time() # set seed tf.set_random_seed(FLAGS.seed) if FLAGS.logdir: train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train', sess.graph) else: train_writer = None if FLAGS.restart: # load variables into session saver.restore(sess, FLAGS.restart) else: # initialize variables print('Initializing...') sess.run(tf.global_variables_initializer()) print('Initialization time %f' % (time.time() - t0)) train_loss = None best_loss = None early_stop_i = 0 undroppable_counter = 3 max_drops = 8 num_drops = 0 for epoch in range(num_train_epochs): if early_stop_i < job['early_stop'] or epoch < FLAGS.min_epochs: t0 = time.time() # save previous train_loss_last = train_loss # alternate forward and reverse batches fwdrc = True if FLAGS.rc and epoch % 2 == 1: fwdrc = False # cycle shifts shift_i = epoch % len(shifts) # train train_loss = dr.train_epoch(sess, batcher_train, fwdrc, shifts[shift_i], train_writer) # validate valid_acc = dr.test(sess, batcher_valid, mc_n=FLAGS.mc_n, rc=FLAGS.rc, shifts=shifts) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc best_str = '' if best_loss is None or valid_loss < best_loss: best_loss = valid_loss best_str = ', best!' 
early_stop_i = 0 saver.save( sess, '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix)) else: early_stop_i += 1 # measure time et = time.time() - t0 if et < 600: time_str = '%3ds' % et elif et < 6000: time_str = '%3dm' % (et / 60) else: time_str = '%3.1fh' % (et / 3600) # print update print( 'Epoch %3d: Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s' % (epoch + 1, train_loss, valid_loss, valid_r2, time_str, best_str), end='') # if training stagnant if FLAGS.learn_rate_drop and num_drops < max_drops and undroppable_counter == 0 and ( train_loss_last - train_loss) / train_loss_last < 0.0002: print(', rate drop', end='') dr.drop_rate(2 / 3) undroppable_counter = 1 num_drops += 1 else: undroppable_counter = max(0, undroppable_counter - 1) print('') sys.stdout.flush() if FLAGS.logdir: train_writer.close()
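# The stagnation test above drops the learning rate by 2/3 whenever the
# relative train-loss improvement falls below 0.02%, at most max_drops=8
# times. A toy check of the resulting schedule; 0.002 is a made-up
# starting rate, not a value from this script.
initial_rate = 0.002
for num_drops in range(1, 9):
    print('after drop %d: rate %.2e' % (num_drops, initial_rate * (2.0 / 3.0) ** num_drops))
# after 8 drops the rate sits at (2/3)**8, roughly 3.9% of its initial value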
def main(): usage = "usage: %prog [options] <params_file> <model_file> <data_file>" parser = OptionParser(usage) parser.add_option("-l", dest="layers", default=None, help="Comma-separated list of layers to plot") parser.add_option( "-n", dest="num_seqs", default=None, type="int", help="Number of sequences to process", ) parser.add_option( "-o", dest="out_dir", default="hidden", help="Output directory [Default: %default]", ) parser.add_option( "-t", dest="target_indexes", default=None, help="Paint 2D plots with these target index values.", ) (options, args) = parser.parse_args() if len(args) != 3: parser.error("Must provide paramters, model, and test data HDF5") else: params_file = args[0] model_file = args[1] data_file = args[2] if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) if options.layers is not None: options.layers = [int(li) for li in options.layers.split(",")] ####################################################### # load data ####################################################### data_open = h5py.File(data_file) test_seqs = data_open["test_in"] test_targets = data_open["test_out"] if options.num_seqs is not None: test_seqs = test_seqs[:options.num_seqs] test_targets = test_targets[:options.num_seqs] ####################################################### # model parameters and placeholders ####################################################### job = params.read_job_params(params_file) job["seq_length"] = test_seqs.shape[1] job["seq_depth"] = test_seqs.shape[2] job["num_targets"] = test_targets.shape[2] job["target_pool"] = int(np.array(data_open.get("pool_width", 1))) t0 = time.time() model = seqnn.SeqNN() model.build_feed(job) if options.target_indexes is None: options.target_indexes = range(job["num_targets"]) else: options.target_indexes = [ int(ti) for ti in options.target_indexes.split(",") ] ####################################################### # test ####################################################### # initialize batcher batcher_test = batcher.Batcher( test_seqs, test_targets, batch_size=model.hp.batch_size, pool_width=model.hp.target_pool, ) # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # get layer representations layer_reprs, _ = model.hidden(sess, batcher_test, options.layers) if options.layers is None: options.layers = range(len(layer_reprs)) for li in options.layers: layer_repr = layer_reprs[li] try: print(layer_repr.shape) except: print(layer_repr) # sample one nt per sequence ds_indexes = np.arange(0, layer_repr.shape[1], 256) nt_reprs = layer_repr[:, ds_indexes, :].reshape( (-1, layer_repr.shape[2])) print("nt_reprs", nt_reprs.shape) ######################################################## # plot raw sns.set(style="ticks", font_scale=1.2) plt.figure() g = sns.clustermap(nt_reprs, cmap="RdBu_r", xticklabels=False, yticklabels=False) g.ax_heatmap.set_xlabel("Representation") g.ax_heatmap.set_ylabel("Sequences") plt.savefig("%s/l%d_reprs.pdf" % (options.out_dir, li)) plt.close() ######################################################## # plot variance explained ratios model_full = PCA() model_full.fit_transform(nt_reprs) evr = model_full.explained_variance_ratio_ pca_n = 40 plt.figure() plt.scatter(range(1, pca_n + 1), evr[:pca_n], c="black") ax = plt.gca() ax.set_xlim(0, pca_n + 1) ax.set_xlabel("Principal components") ax.set_ylim(0, evr[:pca_n].max() * 1.05) ax.set_ylabel("Variance explained") ax.grid(True, linestyle=":") plt.savefig("%s/l%d_pca.pdf" % 
(options.out_dir, li)) plt.close() ######################################################## # visualize in 2D model2 = PCA(n_components=2) nt_2d = model2.fit_transform(nt_reprs) for ti in options.target_indexes: # slice for target test_targets_ti = test_targets[:, :, ti] # repeat to match layer_repr target_repeat = layer_repr.shape[1] // test_targets.shape[1] test_targets_ti = np.repeat(test_targets_ti, target_repeat, axis=1) # downsample indexes nt_targets = test_targets_ti[:, ds_indexes].flatten() # log transform nt_targets = np.log1p(nt_targets) plt.figure() plt.scatter(nt_2d[:, 0], nt_2d[:, 1], alpha=0.5, c=nt_targets, cmap="RdBu_r") plt.colorbar() ax = plt.gca() ax.grid(True, linestyle=":") plt.savefig("%s/l%d_nt2d_t%d.pdf" % (options.out_dir, li, ti)) plt.close() ######################################################## # plot neuron-neuron correlations # compute correlation matrix hidden_cov = np.corrcoef(nt_reprs.T) print("hidden_cov", hidden_cov.shape) plt.figure() g = sns.clustermap(hidden_cov, cmap="RdBu_r", xticklabels=False, yticklabels=False) plt.savefig("%s/l%d_cov.pdf" % (options.out_dir, li)) plt.close() ######################################################## # plot neuron densities neuron_stats_out = open("%s/l%d_stats.txt" % (options.out_dir, li), "w") for ni in range(nt_reprs.shape[1]): # print stats nu = nt_reprs[:, ni].mean() nstd = nt_reprs[:, ni].std() print("%3d %6.3f %6.3f" % (ni, nu, nstd), file=neuron_stats_out) # plot # plt.figure() # sns.distplot(nt_reprs[:,ni]) # plt.savefig('%s/l%d_dist%d.pdf' % (options.out_dir,li,ni)) # plt.close() neuron_stats_out.close() ######################################################## # plot layer norms across length """ layer_repr_norms = np.linalg.norm(layer_repr, axis=2) length_vec = list(range(layer_repr_norms.shape[1]))*layer_repr_norms.shape[0] length_vec = np.array(length_vec) + 0.1*np.random.randn(len(length_vec)) repr_norm_vec = layer_repr_norms.flatten() out_pdf = '%s/l%d_lnorm.pdf' % (options.out_dir,li) regplot(length_vec, repr_norm_vec, out_pdf, x_label='Position', y_label='Repr Norm') """ data_open.close()
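# The 2D embedding above keeps one position every 256 bins to reduce the
# correlation between neighboring positions before PCA. A self-contained
# toy version of that subsample-and-project step; random data stands in
# for the layer representations.
import numpy as np
from sklearn.decomposition import PCA

layer_repr = np.random.randn(8, 1024, 32)  # sequences x positions x units (toy)

ds_indexes = np.arange(0, layer_repr.shape[1], 256)  # every 256th bin
nt_reprs = layer_repr[:, ds_indexes, :].reshape((-1, layer_repr.shape[2]))

model2 = PCA(n_components=2)
nt_2d = model2.fit_transform(nt_reprs)
print(nt_2d.shape, model2.explained_variance_ratio_)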
def main(): usage = ( "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>" " <vcf_file>" ) parser = OptionParser(usage) parser.add_option( "-a", dest="all_sed", default=False, action="store_true", help="Print all variant-gene pairs, as opposed to only nonzero [Default: %default]", ) parser.add_option( "-b", dest="batch_size", default=None, type="int", help="Batch size [Default: %default]", ) parser.add_option( "-c", dest="csv", default=False, action="store_true", help="Print table as CSV [Default: %default]", ) parser.add_option( "-g", dest="genome_file", default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"], help="Chromosome lengths file [Default: %default]", ) parser.add_option( "-o", dest="out_dir", default="sed", help="Output directory for tables and plots [Default: %default]", ) parser.add_option( "-p", dest="processes", default=None, type="int", help="Number of processes, passed by multi script", ) parser.add_option( "--pseudo", dest="log_pseudo", default=0.125, type="float", help="Log2 pseudocount [Default: %default]", ) parser.add_option( "-r", dest="tss_radius", default=0, type="int", help="Radius of bins considered to quantify TSS transcription [Default: %default]", ) parser.add_option( "--rc", dest="rc", default=False, action="store_true", help="Average the forward and reverse complement predictions when testing [Default: %default]", ) parser.add_option( "--shifts", dest="shifts", default="0", help="Ensemble prediction shifts [Default: %default]", ) parser.add_option( "-t", dest="targets_file", default=None, help="File specifying target indexes and labels in table format", ) parser.add_option( "-u", dest="penultimate", default=False, action="store_true", help="Compute SED in the penultimate layer [Default: %default]", ) parser.add_option( "-x", dest="tss_table", default=False, action="store_true", help="Print TSS table in addition to gene [Default: %default]", ) (options, args) = parser.parse_args() if len(args) == 4: # single worker params_file = args[0] model_file = args[1] genes_hdf5_file = args[2] vcf_file = args[3] elif len(args) == 6: # multi worker options_pkl_file = args[0] params_file = args[1] model_file = args[2] genes_hdf5_file = args[3] vcf_file = args[4] worker_index = int(args[5]) # load options options_pkl = open(options_pkl_file, "rb") options = pickle.load(options_pkl) options_pkl.close() # update output directory options.out_dir = "%s/job%d" % (options.out_dir, worker_index) else: parser.error( "Must provide parameters and model files, genes HDF5 file, and QTL VCF" " file" ) if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] ################################################################# # reads in genes HDF5 gene_data = genedata.GeneData(genes_hdf5_file) # filter for worker sequences if options.processes is not None: gene_data.worker(worker_index, options.processes) ################################################################# # prep SNPs # load SNPs snps = bvcf.vcf_snps(vcf_file) # intersect w/ segments print("Intersecting gene sequences with SNPs...", end="") sys.stdout.flush() seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs, vision_p=0.5) print("%d sequences w/ SNPs" % len(seqs_snps)) ################################################################# # setup model job = params.read_job_params(params_file) job["seq_length"] = gene_data.seq_length job["seq_depth"] = gene_data.seq_depth job["target_pool"] = gene_data.pool_width 
if "num_targets" not in job and gene_data.num_targets is not None: job["num_targets"] = gene_data.num_targets if "num_targets" not in job: print( "Must specify number of targets (num_targets) in the \ parameters file.", file=sys.stderr, ) exit(1) if options.targets_file is None: target_labels = gene_data.target_labels target_subset = None target_ids = gene_data.target_ids if target_ids is None: target_ids = ["t%d" % ti for ti in range(job["num_targets"])] target_labels = [""] * len(target_ids) else: # Unfortunately, this target file differs from some others # in that it needs to specify the indexes from the original # set. In the future, I should standardize to this version. target_ids = [] target_labels = [] target_subset = [] for line in open(options.targets_file): a = line.strip().split("\t") target_subset.append(int(a[0])) target_ids.append(a[1]) target_labels.append(a[3]) if len(target_subset) == job["num_targets"]: target_subset = None # build model model = seqnn.SeqNN() model.build_feed(job, target_subset=target_subset) if options.penultimate: # labels become inappropriate target_ids = [""] * model.hp.cnn_filters[-1] target_labels = target_ids ################################################################# # compute, collect, and print SEDs header_cols = ( "rsid", "ref", "alt", "gene", "tss_dist", "ref_pred", "alt_pred", "sed", "ser", "target_index", "target_id", "target_label", ) if options.csv: sed_gene_out = open("%s/sed_gene.csv" % options.out_dir, "w") print(",".join(header_cols), file=sed_gene_out) if options.tss_table: sed_tss_out = open("%s/sed_tss.csv" % options.out_dir, "w") print(",".join(header_cols), file=sed_tss_out) else: sed_gene_out = open("%s/sed_gene.txt" % options.out_dir, "w") print(" ".join(header_cols), file=sed_gene_out) if options.tss_table: sed_tss_out = open("%s/sed_tss.txt" % options.out_dir, "w") print(" ".join(header_cols), file=sed_tss_out) # helper variables pred_buffer = model.hp.batch_buffer // model.hp.target_pool # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # for each gene sequence for seq_i in range(gene_data.num_seqs): gene_seq = gene_data.gene_seqs[seq_i] print(gene_seq) # if it contains SNPs seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]] if len(seq_snps) > 0: # one hot code allele sequences aseqs_1hot = alleles_1hot( gene_seq, gene_data.seqs_1hot[seq_i], seq_snps ) # initialize batcher batcher_gene = batcher.Batcher( aseqs_1hot, batch_size=model.hp.batch_size ) # construct allele gene_seq's allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0] # predict alleles allele_tss_preds = model.predict_genes( sess, batcher_gene, allele_gene_seqs, rc=options.rc, shifts=options.shifts, embed_penultimate=options.penultimate, tss_radius=options.tss_radius, ) # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets allele_tss_preds = allele_tss_preds.reshape( (aseqs_1hot.shape[0], gene_seq.num_tss, -1) ) # extract reference and SNP alt predictions ref_tss_preds = allele_tss_preds[0] # TSSs x Targets alt_tss_preds = allele_tss_preds[1:] # SNPs x TSSs x Targets # compute TSS SED scores snp_tss_sed = alt_tss_preds - ref_tss_preds snp_tss_ser = np.log2(alt_tss_preds + options.log_pseudo) - np.log2( ref_tss_preds + options.log_pseudo ) # compute gene-level predictions ref_gene_preds, gene_ids = gene.map_tss_genes( ref_tss_preds, gene_seq.tss_list, options.tss_radius ) alt_gene_preds = [] for snp_i in range(len(seq_snps)): agp, _ = gene.map_tss_genes( 
alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius ) alt_gene_preds.append(agp) alt_gene_preds = np.array(alt_gene_preds) # compute gene SED scores gene_sed = alt_gene_preds - ref_gene_preds gene_ser = np.log2(alt_gene_preds + options.log_pseudo) - np.log2( ref_gene_preds + options.log_pseudo ) # for each SNP for snp_i in range(len(seq_snps)): snp = seq_snps[snp_i] # initialize gene data structures snp_dist_gene = {} # for each TSS for tss_i in range(gene_seq.num_tss): tss = gene_seq.tss_list[tss_i] # SNP distance to TSS snp_dist = abs(snp.pos - tss.pos) if tss.gene_id in snp_dist_gene: snp_dist_gene[tss.gene_id] = min( snp_dist_gene[tss.gene_id], snp_dist ) else: snp_dist_gene[tss.gene_id] = snp_dist # for each target if options.tss_table: for ti in range(ref_tss_preds.shape[1]): # check if nonzero if options.all_sed or not np.isclose( tss_sed[snp_i, tss_i, ti], 0, atol=1e-4 ): # print cols = ( snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(snp.alt_alleles[0]), tss.identifier, snp_dist, ref_tss_preds[tss_i, ti], alt_tss_preds[snp_i, tss_i, ti], tss_sed[snp_i, tss_i, ti], tss_ser[snp_i, tss_i, ti], ti, target_ids[ti], target_labels[ti], ) if options.csv: print( ",".join([str(c) for c in cols]), file=sed_tss_out, ) else: print( "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s" % cols, file=sed_tss_out, ) # for each gene for gi in range(len(gene_ids)): gene_str = gene_ids[gi] if gene_ids[gi] in gene_data.multi_seq_genes: gene_str = "%s_multi" % gene_ids[gi] # print rows to gene table for ti in range(ref_gene_preds.shape[1]): # check if nonzero if options.all_sed or not np.isclose( gene_sed[snp_i, gi, ti], 0, atol=1e-4 ): # print cols = [ snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(snp.alt_alleles[0]), gene_str, snp_dist_gene[gene_ids[gi]], ref_gene_preds[gi, ti], alt_gene_preds[snp_i, gi, ti], gene_sed[snp_i, gi, ti], gene_ser[snp_i, gi, ti], ti, target_ids[ti], target_labels[ti], ] if options.csv: print( ",".join([str(c) for c in cols]), file=sed_gene_out, ) else: print( "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s" % tuple(cols), file=sed_gene_out, ) # clean up gc.collect() sed_gene_out.close() if options.tss_table: sed_tss_out.close()
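# SED is the raw difference between alternate- and reference-allele
# predictions; SER is the log2 ratio stabilized by the pseudocount set
# with --pseudo. A worked example with made-up predictions:
import numpy as np

ref_pred, alt_pred = 4.0, 6.0  # made-up gene-level predictions
log_pseudo = 0.125             # the script's default pseudocount

sed = alt_pred - ref_pred  # 2.0
ser = np.log2(alt_pred + log_pseudo) - np.log2(ref_pred + log_pseudo)
print('SED %.3f, SER %.3f' % (sed, ser))  # SED 2.000, SER ~0.570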
def run(params_file, data_file, train_epochs, train_epoch_batches, test_epoch_batches): ####################################################### # load data ####################################################### data_open = h5py.File(data_file) train_seqs = data_open['train_in'] train_targets = data_open['train_out'] train_na = None if 'train_na' in data_open: train_na = data_open['train_na'] valid_seqs = data_open['valid_in'] valid_targets = data_open['valid_out'] valid_na = None if 'valid_na' in data_open: valid_na = data_open['valid_na'] ####################################################### # model parameters and placeholders ####################################################### job = params.read_job_params(params_file) job['seq_length'] = train_seqs.shape[1] job['seq_depth'] = train_seqs.shape[2] job['num_targets'] = train_targets.shape[2] job['target_pool'] = int(np.array(data_open.get('pool_width', 1))) t0 = time.time() model = seqnn.SeqNN() model.build(job) print('Model building time %f' % (time.time() - t0)) # adjust for fourier job['fourier'] = 'train_out_imag' in data_open if job['fourier']: train_targets_imag = data_open['train_out_imag'] valid_targets_imag = data_open['valid_out_imag'] ####################################################### # prepare batcher ####################################################### if job['fourier']: batcher_train = batcher.BatcherF(train_seqs, train_targets, train_targets_imag, train_na, model.hp.batch_size, model.hp.target_pool, shuffle=True) batcher_valid = batcher.BatcherF(valid_seqs, valid_targets, valid_targets_imag, valid_na, model.hp.batch_size, model.hp.target_pool) else: batcher_train = batcher.Batcher(train_seqs, train_targets, train_na, model.hp.batch_size, model.hp.target_pool, shuffle=True) batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na, model.hp.batch_size, model.hp.target_pool) print('Batcher initialized') ####################################################### # train ####################################################### augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')] ensemble_shifts = [ int(shift) for shift in FLAGS.ensemble_shifts.split(',') ] # checkpoints saver = tf.train.Saver() config = tf.ConfigProto() if FLAGS.log_device_placement: config.log_device_placement = True with tf.Session(config=config) as sess: t0 = time.time() # set seed tf.set_random_seed(FLAGS.seed) if FLAGS.logdir: train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train', sess.graph) else: train_writer = None if FLAGS.restart: # load variables into session saver.restore(sess, FLAGS.restart) else: # initialize variables print('Initializing...') sess.run(tf.global_variables_initializer()) print('Initialization time %f' % (time.time() - t0)) train_loss = None best_loss = None early_stop_i = 0 epoch = 0 while (train_epochs is None or epoch < train_epochs) and early_stop_i < FLAGS.early_stop: t0 = time.time() # alternate forward and reverse batches fwdrc = True if FLAGS.augment_rc and epoch % 2 == 1: fwdrc = False # cycle shifts shift_i = epoch % len(augment_shifts) # train train_loss, steps = model.train_epoch( sess, batcher_train, fwdrc=fwdrc, shift=augment_shifts[shift_i], sum_writer=train_writer, epoch_batches=train_epoch_batches, no_steps=FLAGS.no_steps) # validate valid_acc = model.test(sess, batcher_valid, mc_n=FLAGS.ensemble_mc, rc=FLAGS.ensemble_rc, shifts=ensemble_shifts, test_batches=test_epoch_batches) valid_loss = valid_acc.loss valid_r2 = valid_acc.r2().mean() del valid_acc best_str = '' if
best_loss is None or valid_loss < best_loss: best_loss = valid_loss best_str = ', best!' early_stop_i = 0 saver.save(sess, '%s/model_best.tf' % FLAGS.logdir) else: early_stop_i += 1 # measure time et = time.time() - t0 if et < 600: time_str = '%3ds' % et elif et < 6000: time_str = '%3dm' % (et / 60) else: time_str = '%3.1fh' % (et / 3600) # print update print( 'Epoch: %3d, Steps: %7d, Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s' % (epoch + 1, steps, train_loss, valid_loss, valid_r2, time_str, best_str)) sys.stdout.flush() if FLAGS.check_all: saver.save(sess, '%s/model_check%d.tf' % (FLAGS.logdir, epoch)) # update epoch epoch += 1 if FLAGS.logdir: train_writer.close()
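# Early stopping above counts epochs since the last validation-loss
# improvement and halts once the counter reaches FLAGS.early_stop. A
# stripped-down version of that bookkeeping; the losses are made up.
valid_losses = [0.50, 0.45, 0.46, 0.44, 0.44, 0.45, 0.46]
early_stop = 2  # stands in for FLAGS.early_stop

best_loss = None
early_stop_i = 0
for epoch, valid_loss in enumerate(valid_losses):
    if early_stop_i >= early_stop:
        print('early stop at epoch %d' % epoch)
        break
    if best_loss is None or valid_loss < best_loss:
        best_loss = valid_loss
        early_stop_i = 0  # reset on improvement
    else:
        early_stop_i += 1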
def score_write(sess, model, options, seqs_1hot, seqs_chrom, seqs_start): ''' Compute scores and write them as BigWigs for a set of sequences. ''' for si in range(seqs_1hot.shape[0]): # initialize batcher batcher_si = batcher.Batcher(seqs_1hot[si:si+1], batch_size=model.hp.batch_size, pool_width=model.hp.target_pool) # get layer representations t0 = time.time() print('Computing gradients.', end='', flush=True) _, _, _, batch_grads, batch_reprs, _ = model.gradients(sess, batcher_si, rc=options.rc, shifts=options.shifts, mc_n=options.mc_n, return_all=True) print(' Done in %ds.' % (time.time()-t0), flush=True) # only layer batch_reprs = batch_reprs[0] batch_grads = batch_grads[0] # increase resolution batch_reprs = batch_reprs.astype('float32') batch_grads = batch_grads.astype('float32') # S (sequences) x T (targets) x P (seq position) x U (units layer i) x E (ensembles) print('batch_grads', batch_grads.shape) pooled_length = batch_grads.shape[2] # S (sequences) x P (seq position) x U (Units layer i) x E (ensembles) print('batch_reprs', batch_reprs.shape) # write bigwigs t0 = time.time() print('Writing BigWigs.', end='', flush=True) # for each target for tii in range(len(options.target_indexes)): ti = options.target_indexes[tii] # compute scores if options.norm is None: batch_grads_scores = np.multiply(batch_reprs[0], batch_grads[0,tii,:,:,:]).sum(axis=1) else: batch_grads_scores = np.multiply(batch_reprs[0], batch_grads[0,tii,:,:,:]) batch_grads_scores = np.power(np.abs(batch_grads_scores), options.norm) batch_grads_scores = batch_grads_scores.sum(axis=1) batch_grads_scores = np.power(batch_grads_scores, 1./options.norm) # compute score statistics batch_grads_mean = batch_grads_scores.mean(axis=1) if options.norm is None: batch_grads_pval = ttest_1samp(batch_grads_scores, 0, axis=1)[1] else: batch_grads_pval = ttest_1samp(batch_grads_scores, 0, axis=1)[1] # batch_grads_pval = chi2(df=) batch_grads_pval /= 2 # open bigwig bws_file = '%s/s%d_t%d_scores.bw' % (options.out_dir, si, ti) bwp_file = '%s/s%d_t%d_pvals.bw' % (options.out_dir, si, ti) bws_open = bigwig_open(bws_file, options.genome_file) # bwp_open = bigwig_open(bwp_file, options.genome_file) # specify bigwig locations and values bw_chroms = [seqs_chrom[si]]*pooled_length bw_starts = [int(seqs_start[si] + pi*model.hp.target_pool) for pi in range(pooled_length)] bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts] bws_values = [float(bgs) for bgs in batch_grads_mean] # bwp_values = [float(bgp) for bgp in batch_grads_pval] # write bws_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bws_values) # bwp_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bwp_values) # close bws_open.close() # bwp_open.close() print(' Done in %ds.' % (time.time()-t0), flush=True) gc.collect()
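# With no --norm, score_write() reduces each position to the signed sum
# over units of gradient times representation; with a norm p it takes an
# L-p norm of those per-unit products instead. A toy comparison of the
# two branches; random arrays stand in for one sequence and one target.
import numpy as np

positions, units = 128, 64
reprs = np.random.randn(positions, units)  # toy layer representation
grads = np.random.randn(positions, units)  # toy gradients

# options.norm is None: signed dot product per position
scores_signed = np.multiply(reprs, grads).sum(axis=1)

# options.norm == p: L-p norm of the per-unit products
p = 2.0
scores_lp = np.power(np.power(np.abs(np.multiply(reprs, grads)), p).sum(axis=1), 1.0 / p)

print(scores_signed.shape, scores_lp.shape)  # (128,) each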
def main(): usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>' parser = OptionParser(usage) parser.add_option('-b', dest='batch_size', default=256, type='int', help='Batch size [Default: %default]') parser.add_option('-c', dest='csv', default=False, action='store_true', help='Print table as CSV [Default: %default]') parser.add_option('-f', dest='genome_fasta', default='%s/data/hg19.fa' % os.environ['BASENJIDIR'], help='Genome FASTA for sequences [Default: %default]') parser.add_option('-g', dest='genome_file', default='%s/data/human.hg19.genome' % os.environ['BASENJIDIR'], help='Chromosome lengths file [Default: %default]') parser.add_option('--h5', dest='out_h5', default=False, action='store_true', help='Output stats to sad.h5 [Default: %default]') parser.add_option('--local', dest='local', default=1024, type='int', help='Local SAD score [Default: %default]') parser.add_option('-n', dest='norm_file', default=None, help='Normalize SAD scores') parser.add_option( '-o', dest='out_dir', default='sad', help='Output directory for tables and plots [Default: %default]') parser.add_option('-p', dest='processes', default=None, type='int', help='Number of processes, passed by multi script') parser.add_option('--pseudo', dest='log_pseudo', default=1, type='float', help='Log2 pseudocount [Default: %default]') parser.add_option( '--rc', dest='rc', default=False, action='store_true', help= 'Average forward and reverse complement predictions [Default: %default]' ) parser.add_option('--shifts', dest='shifts', default='0', type='str', help='Ensemble prediction shifts [Default: %default]') parser.add_option( '--stats', dest='sad_stats', default='SAD,xSAR', help='Comma-separated list of stats to save. [Default: %default]') parser.add_option( '-t', dest='targets_file', default=None, type='str', help='File specifying target indexes and labels in table format') parser.add_option( '--ti', dest='track_indexes', default=None, type='str', help='Comma-separated list of target indexes to output BigWig tracks') parser.add_option( '-u', dest='penultimate', default=False, action='store_true', help='Compute SED in the penultimate layer [Default: %default]') parser.add_option('-z', dest='out_zarr', default=False, action='store_true', help='Output stats to sad.zarr [Default: %default]') (options, args) = parser.parse_args() if len(args) == 3: # single worker params_file = args[0] model_file = args[1] vcf_file = args[2] elif len(args) == 5: # multi worker options_pkl_file = args[0] params_file = args[1] model_file = args[2] vcf_file = args[3] worker_index = int(args[4]) # load options options_pkl = open(options_pkl_file, 'rb') options = pickle.load(options_pkl) options_pkl.close() # update output directory options.out_dir = '%s/job%d' % (options.out_dir, worker_index) else: parser.error( 'Must provide parameters and model files and QTL VCF file') if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) if options.track_indexes is None: options.track_indexes = [] else: options.track_indexes = [ int(ti) for ti in options.track_indexes.split(',') ] if not os.path.isdir('%s/tracks' % options.out_dir): os.mkdir('%s/tracks' % options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(',')] options.sad_stats = options.sad_stats.split(',') ################################################################# # setup model job = params.read_job_params( params_file, require=['seq_length', 'num_targets', 'target_pool']) if options.targets_file is None: target_ids = ['t%d' % ti for ti in 
range(job['num_targets'])] target_labels = [''] * len(target_ids) target_subset = None else: targets_df = pd.read_table(options.targets_file, index_col=0) target_ids = targets_df.identifier target_labels = targets_df.description target_subset = targets_df.index if len(target_subset) == job['num_targets']: target_subset = None # build model t0 = time.time() model = seqnn.SeqNN() model.build_feed(job, ensemble_rc=options.rc, ensemble_shifts=options.shifts, embed_penultimate=options.penultimate, target_subset=target_subset) print('Model building time %f' % (time.time() - t0), flush=True) if options.penultimate: # labels become inappropriate target_ids = [''] * model.hp.cnn_filters[-1] target_labels = target_ids # read target normalization factors target_norms = np.ones(len(target_labels)) if options.norm_file is not None: ti = 0 for line in open(options.norm_file): target_norms[ti] = float(line.strip()) ti += 1 num_targets = len(target_ids) ################################################################# # load SNPs snps = bvcf.vcf_snps(vcf_file) # filter for worker SNPs if options.processes is not None: worker_bounds = np.linspace(0, len(snps), options.processes + 1, dtype='int') snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index + 1]] num_snps = len(snps) ################################################################# # setup output header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad', 'lsar', 'ref_xpred', 'alt_xpred', 'xsad', 'xsar', 'target_index', 'target_id', 'target_label') if options.out_h5: sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps, target_ids, target_labels) elif options.out_zarr: sad_out = initialize_output_zarr(options.out_dir, options.sad_stats, snps, target_ids, target_labels) else: if options.csv: sad_out = open('%s/sad_table.csv' % options.out_dir, 'w') print(','.join(header_cols), file=sad_out) else: sad_out = open('%s/sad_table.txt' % options.out_dir, 'w') print(' '.join(header_cols), file=sad_out) ################################################################# # process # open genome FASTA genome_open = pysam.Fastafile(options.genome_fasta) # determine local start and end loc_mid = model.target_length // 2 loc_start = loc_mid - (options.local // 2) // model.hp.target_pool loc_end = loc_start + options.local // model.hp.target_pool snp_i = 0 szi = 0 # initialize saver saver = tf.train.Saver() with tf.Session() as sess: # load variables into session saver.restore(sess, model_file) # construct first batch batch_1hot, batch_snps, snp_i = snps_next_batch( snps, snp_i, options.batch_size, job['seq_length'], genome_open) while len(batch_snps) > 0: ################################################### # predict # initialize batcher; use a distinct name so the batcher module isn't shadowed on later iterations snp_batcher = batcher.Batcher(batch_1hot, batch_size=model.hp.batch_size) # predict # batch_preds = model.predict(sess, snp_batcher, # rc=options.rc, shifts=options.shifts, # penultimate=options.penultimate) batch_preds = model.predict_h5(sess, snp_batcher) # normalize batch_preds /= target_norms ################################################### # collect and print SADs pi = 0 for snp in batch_snps: # get reference prediction (LxT) ref_preds = batch_preds[pi] pi += 1 # sum across length ref_preds_sum = ref_preds.sum(axis=0, dtype='float64') # print tracks for ti in options.track_indexes: ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir, snp.rsid, ti) bigwig_write(snp, job['seq_length'], ref_preds[:, ti], model, ref_bw_file,
options.genome_file) for alt_al in snp.alt_alleles: # get alternate prediction (LxT) alt_preds = batch_preds[pi] pi += 1 # sum across length alt_preds_sum = alt_preds.sum(axis=0, dtype='float64') # compare reference to alternative via mean subtraction sad_vec = alt_preds - ref_preds sad = alt_preds_sum - ref_preds_sum # compare reference to alternative via mean log division sar = np.log2(alt_preds_sum + options.log_pseudo) \ - np.log2(ref_preds_sum + options.log_pseudo) # compare geometric means sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \ - np.log2(ref_preds.astype('float64') + options.log_pseudo) geo_sad = sar_vec.sum(axis=0) # sum locally ref_preds_loc = ref_preds[loc_start:loc_end, :].sum( axis=0, dtype='float64') alt_preds_loc = alt_preds[loc_start:loc_end, :].sum( axis=0, dtype='float64') # compute SAD locally sad_loc = alt_preds_loc - ref_preds_loc sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \ - np.log2(ref_preds_loc + options.log_pseudo) # compute max difference position max_li = np.argmax(np.abs(sar_vec), axis=0) if options.out_h5 or options.out_zarr: sad_out['SAD'][szi, :] = sad.astype('float16') sad_out['xSAR'][szi, :] = np.array([ sar_vec[max_li[ti], ti] for ti in range(num_targets) ], dtype='float16') szi += 1 else: # print table lines for ti in range(len(sad)): # print line cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(alt_al), ref_preds_sum[ti], alt_preds_sum[ti], sad[ti], sar[ti], geo_sad[ti], ref_preds_loc[ti], alt_preds_loc[ti], sad_loc[ti], sar_loc[ti], ref_preds[max_li[ti], ti], alt_preds[max_li[ti], ti], sad_vec[max_li[ti], ti], sar_vec[max_li[ti], ti], ti, target_ids[ti], target_labels[ti]) if options.csv: print(','.join([str(c) for c in cols]), file=sad_out) else: print( '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s' % cols, file=sad_out) # print tracks for ti in options.track_indexes: alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % ( options.out_dir, snp.rsid, ti) bigwig_write(snp, job['seq_length'], alt_preds[:, ti], model, alt_bw_file, options.genome_file) ################################################### # construct next batch batch_1hot, batch_snps, snp_i = snps_next_batch( snps, snp_i, options.batch_size, job['seq_length'], genome_open) ################################################### # compute SAD distributions across variants if options.out_h5 or options.out_zarr: # define percentiles d_fine = 0.001 d_coarse = 0.01 percentiles_neg = np.arange(d_fine, 0.1, d_fine) percentiles_base = np.arange(0.1, 0.9, d_coarse) percentiles_pos = np.arange(0.9, 1, d_fine) percentiles = np.concatenate( [percentiles_neg, percentiles_base, percentiles_pos]) sad_out.create_dataset('percentiles', data=percentiles) pct_len = len(percentiles) for sad_stat in options.sad_stats: sad_stat_pct = '%s_pct' % sad_stat # compute sad_pct = np.percentile(sad_out[sad_stat], 100 * percentiles, axis=0).T sad_pct = sad_pct.astype('float16') # save sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16') if not options.out_zarr: sad_out.close()
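# The percentile grid above is fine-grained (0.001 steps) in both tails
# and coarse (0.01 steps) in the middle, so extreme SAD scores are
# summarized precisely without storing a large table. Rebuilding it
# standalone shows it holds 279 percentiles.
import numpy as np

d_fine, d_coarse = 0.001, 0.01
percentiles_neg = np.arange(d_fine, 0.1, d_fine)    # 99 values
percentiles_base = np.arange(0.1, 0.9, d_coarse)    # 80 values
percentiles_pos = np.arange(0.9, 1, d_fine)         # 100 values
percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos])
print(len(percentiles))  # 279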
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option('-b', dest='batch_size', default=None, type='int',
      help='Batch size [Default: %default]')
  parser.add_option('-i', dest='ignore_bed',
      help='Ignore genes overlapping regions in this BED file')
  parser.add_option('-l', dest='load_preds',
      help='Load tss_preds from file')
  parser.add_option('--heat', dest='plot_heat',
      default=False, action='store_true',
      help='Plot big gene-target heatmaps [Default: %default]')
  parser.add_option('-o', dest='out_dir', default='genes_out',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option('-r', dest='tss_radius', default=0, type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option('--rc', dest='rc', default=False, action='store_true',
      help='Average the forward and reverse complement predictions when testing [Default: %default]')
  parser.add_option('-s', dest='plot_scatter',
      default=False, action='store_true',
      help='Make time-consuming accuracy scatter plots [Default: %default]')
  parser.add_option('--shifts', dest='shifts', default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('--rep', dest='replicate_labels_file',
      help='Compare replicate experiments, aided by the given file with long labels')
  parser.add_option('-t', dest='target_indexes', default=None,
      help='File or comma-separated list of target indexes to scatter plot true versus predicted values')
  parser.add_option('--table', dest='print_tables',
      default=False, action='store_true',
      help='Print big gene/TSS tables [Default: %default]')
  parser.add_option('--tss', dest='tss_alt',
      default=False, action='store_true',
      help='Perform alternative TSS analysis [Default: %default]')
  parser.add_option('-v', dest='gene_variance',
      default=False, action='store_true',
      help='Study accuracy with respect to gene variance across targets [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters and model files, and genes HDF5 file')
  else:
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes and targets

  gene_data = genedata.GeneData(genes_hdf5_file)

  #################################################################
  # TSS predictions

  if options.load_preds is not None:
    # load from file
    tss_preds = np.load(options.load_preds)

  else:
    #######################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width

    if 'num_targets' not in job:
      job['num_targets'] = gene_data.num_targets

    # build model
    model = seqnn.SeqNN()
    model.build(job)

    if options.batch_size is not None:
      model.hp.batch_size = options.batch_size

    #######################################################
    # predict TSSs

    t0 = time.time()
    print('Computing gene predictions.', end='')
    sys.stdout.flush()

    # initialize batcher
    gene_batcher = batcher.Batcher(gene_data.seqs_1hot,
                                   batch_size=model.hp.batch_size)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
      # load variables into session
      saver.restore(sess, model_file)

      # predict
      tss_preds = model.predict_genes(sess, gene_batcher, gene_data.gene_seqs,
          rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius)

    # save to file
    np.save('%s/preds' % options.out_dir, tss_preds)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # convert to genes

  gene_targets, _ = gene.map_tss_genes(gene_data.tss_targets, gene_data.tss,
                                       tss_radius=options.tss_radius)
  gene_preds, _ = gene.map_tss_genes(tss_preds, gene_data.tss,
                                     tss_radius=options.tss_radius)

  #################################################################
  # determine targets

  # all targets
  if options.target_indexes is None:
    if gene_data.num_targets is None:
      print('No targets to test against', file=sys.stderr)
      exit(1)
    else:
      options.target_indexes = np.arange(gene_data.num_targets)

  # file of targets
  elif os.path.isfile(options.target_indexes):
    target_indexes_file = options.target_indexes
    options.target_indexes = []
    for line in open(target_indexes_file):
      options.target_indexes.append(int(line.split()[0]))

  # comma-separated targets
  else:
    options.target_indexes = [
        int(ti) for ti in options.target_indexes.split(',')
    ]

  options.target_indexes = np.array(options.target_indexes)

  #################################################################
  # correlation statistics

  t0 = time.time()
  print('Computing correlations.', end='')
  sys.stdout.flush()

  cor_table(gene_data.tss_targets, tss_preds,
            gene_data.target_ids, gene_data.target_labels,
            options.target_indexes, '%s/tss_cors.txt' % options.out_dir)

  cor_table(gene_targets, gene_preds,
            gene_data.target_ids, gene_data.target_labels,
            options.target_indexes, '%s/gene_cors.txt' % options.out_dir,
            plots=True)

  print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene statistics

  if options.print_tables:
    t0 = time.time()
    print('Printing predictions.', end='')
    sys.stdout.flush()

    gene_table(gene_data.tss_targets, tss_preds, gene_data.tss_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/transcript' % options.out_dir, options.plot_scatter)

    gene_table(gene_targets, gene_preds, gene_data.gene_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/gene' % options.out_dir, options.plot_scatter)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene x target heatmaps

  if options.plot_heat or options.gene_variance:
    #########################################
    # normalize predictions across targets

    t0 = time.time()
    print('Normalizing values across targets.', end='')
    sys.stdout.flush()

    gene_targets_qn = normalize_targets(
        gene_targets[:, options.target_indexes], log_pseudo=1)
    gene_preds_qn = normalize_targets(
        gene_preds[:, options.target_indexes], log_pseudo=1)

    print(' Done in %ds.' % (time.time() - t0))

  if options.plot_heat:
    #########################################
    # plot genes by targets clustermap

    t0 = time.time()
    print('Plotting heat maps.', end='')
    sys.stdout.flush()

    sns.set(font_scale=1.3, style='ticks')
    plot_genes = 1600
    plot_targets = 800

    # choose a set of variable genes
    gene_vars = gene_preds_qn.var(axis=1)
    indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

    # choose a set of random genes
    if plot_genes < gene_preds_qn.shape[0]:
      indexes_rand = np.random.choice(np.arange(gene_preds_qn.shape[0]),
                                      plot_genes, replace=False)
    else:
      indexes_rand = np.arange(gene_preds_qn.shape[0])

    # choose a set of random targets
    if plot_targets < 0.8 * gene_preds_qn.shape[1]:
      indexes_targets = np.random.choice(np.arange(gene_preds_qn.shape[1]),
                                         plot_targets, replace=False)
    else:
      indexes_targets = np.arange(gene_preds_qn.shape[1])

    # variable gene predictions
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var.pdf' % options.out_dir)
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var_color.pdf' % options.out_dir,
               color='viridis', table=True)

    # random gene predictions
    clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
               '%s/gene_heat_rand.pdf' % options.out_dir)

    # variable gene targets
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var.pdf' % options.out_dir)
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var_color.pdf' % options.out_dir,
               color='viridis', table=True)

    # random gene targets (currently crashes)
    # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
    #            '%s/gene_theat_rand.pdf' % options.out_dir)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # analyze replicates

  if options.replicate_labels_file is not None:
    # read long form labels, from which to infer replicates
    target_labels_long = []
    for line in open(options.replicate_labels_file):
      a = line.split('\t')
      a[-1] = a[-1].rstrip()
      target_labels_long.append(a[-1])

    # determine replicates
    replicate_lists = infer_replicates(target_labels_long)

    # compute correlations
    # replicate_correlations(replicate_lists, gene_data.tss_targets, tss_preds,
    #     options.target_indexes, '%s/transcript_reps' % options.out_dir)

    replicate_correlations(replicate_lists, gene_targets, gene_preds,
        options.target_indexes, '%s/gene_reps' % options.out_dir)
        # , scatter_plots=True)

  #################################################################
  # gene variance

  if options.gene_variance:
    variance_accuracy(gene_targets_qn, gene_preds_qn,
                      '%s/gene' % options.out_dir)

  #################################################################
  # alternative TSS

  if options.tss_alt:
    alternative_tss(gene_data.tss_targets[:, options.target_indexes],
                    tss_preds[:, options.target_indexes],
                    gene_data, options.out_dir, log_pseudo=1)
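# normalize_targets is defined elsewhere in this script. For reference, a
# minimal sketch of one plausible implementation -- rank-based quantile
# normalization after a log transform -- which is an assumption about its
# behavior, not the actual code:
#
# import numpy as np
#
# def quantile_normalize_sketch(X, log_pseudo=1):
#   # X: (genes G x targets T); log transform stabilizes variance
#   Xl = np.log2(X + log_pseudo)
#
#   # rank each value within its target column
#   ranks = np.argsort(np.argsort(Xl, axis=0), axis=0)
#
#   # average the sorted columns at each rank to build a reference distribution
#   mean_sorted = np.sort(Xl, axis=0).mean(axis=1)
#
#   # map each value to the reference value at its rank
#   return mean_sorted[ranks]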
def score_write(sess, model, options, target_indexes,
                seqs_1hot, seqs_chrom, seqs_start):
  '''Compute gradient-based scores for a set of sequences and write them to
     HDF5, and optionally to BigWig tracks.'''
  num_seqs = seqs_1hot.shape[0]
  num_targets = len(target_indexes)

  # initialize scores HDF5
  scores_h5_file = '%s/scores.h5' % options.out_dir
  scores_h5_out = h5py.File(scores_h5_file, 'w')

  for si in range(num_seqs):
    # initialize batcher
    batcher_si = batcher.Batcher(seqs_1hot[si:si + 1],
                                 batch_size=model.hp.batch_size,
                                 pool_width=model.hp.target_pool)

    # get layer representations and gradients
    t0 = time.time()
    print('Computing gradients.', end='', flush=True)
    _, _, _, batch_grads, batch_reprs, _ = model.gradients(
        sess, batcher_si, rc=options.rc, shifts=options.shifts,
        mc_n=options.mc_n, return_all=True)
    print(' Done in %ds.' % (time.time() - t0), flush=True)

    # take the only requested layer
    batch_reprs = batch_reprs[0]
    batch_grads = batch_grads[0]

    # increase resolution
    batch_reprs = batch_reprs.astype('float32')
    batch_grads = batch_grads.astype('float32')

    # S (sequences) x T (targets) x P (seq position) x U (units layer i) x E (ensembles)
    print('batch_grads', batch_grads.shape)

    # S (sequences) x P (seq position) x U (units layer i) x E (ensembles)
    print('batch_reprs', batch_reprs.shape)

    preds_length = batch_reprs.shape[1]

    if 'score' not in scores_h5_out:
      # initialize score datasets
      scores_h5_out.create_dataset(
          'score', shape=(num_seqs, preds_length, num_targets),
          dtype='float16')
      scores_h5_out.create_dataset(
          'pvalue', shape=(num_seqs, preds_length, num_targets),
          dtype='float16')

    # write scores
    t0 = time.time()
    print('Computing and writing scores.', end='', flush=True)

    # for each target
    for tii in range(len(target_indexes)):
      ti = target_indexes[tii]

      # representation x gradient
      batch_grads_scores = np.multiply(batch_reprs[0],
                                       batch_grads[0, tii, :, :, :])

      if options.norm is None:
        # sum across filters
        batch_grads_scores = batch_grads_scores.sum(axis=1)
      else:
        # raise to power
        batch_grads_scores = np.power(np.abs(batch_grads_scores), options.norm)

        # sum across filters
        batch_grads_scores = batch_grads_scores.sum(axis=1)

        # normalize with 1/power
        batch_grads_scores = np.power(batch_grads_scores, 1. / options.norm)

      # mean across ensemble
      batch_grads_mean = batch_grads_scores.mean(axis=1)

      # compute two-sided p-values across the ensemble, then halve for a
      # one-sided test (a chi2 null might suit the norm case better)
      batch_grads_pval = ttest_1samp(batch_grads_scores, 0, axis=1)[1]
      batch_grads_pval /= 2

      # write to HDF5
      scores_h5_out['score'][si, :, tii] = batch_grads_mean.astype('float16')
      scores_h5_out['pvalue'][si, :, tii] = batch_grads_pval.astype('float16')

      if options.bigwig:
        # open bigwig
        bws_file = '%s/s%d_t%d_scores.bw' % (options.out_dir, si, ti)
        bwp_file = '%s/s%d_t%d_pvals.bw' % (options.out_dir, si, ti)
        bws_open = bigwig_open(bws_file, options.genome_file)
        # bwp_open = bigwig_open(bwp_file, options.genome_file)

        # specify bigwig locations and values
        bw_chroms = [seqs_chrom[si]] * preds_length
        bw_starts = [int(seqs_start[si] + pi * model.hp.target_pool)
                     for pi in range(preds_length)]
        bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts]
        bws_values = [float(bgs) for bgs in batch_grads_mean]
        # bwp_values = [float(bgp) for bgp in batch_grads_pval]

        # write
        bws_open.addEntries(bw_chroms, bw_starts, ends=bw_ends,
                            values=bws_values)
        # bwp_open.addEntries(bw_chroms, bw_starts, ends=bw_ends,
        #                     values=bwp_values)

        # close
        bws_open.close()
        # bwp_open.close()

    print(' Done in %ds.' % (time.time() - t0), flush=True)
    gc.collect()

  scores_h5_out.close()
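# The score in score_write is a gradient-times-input saliency: the layer
# representation is multiplied elementwise by the gradient of each target
# with respect to that layer, then collapsed across units. A minimal numpy
# sketch of just the aggregation step (toy shapes; the norm argument is a
# stand-in mirroring the options.norm logic above, not Basenji's API):
#
# import numpy as np
#
# def aggregate_scores_sketch(reprs, grads, norm=None):
#   # reprs: (positions P x units U x ensembles E) layer representation
#   # grads: (positions P x units U x ensembles E) gradient for one target
#   scores = reprs * grads
#
#   if norm is None:
#     # plain sum across units
#     scores = scores.sum(axis=1)
#   else:
#     # p-norm across units: (sum |x|^p)^(1/p)
#     scores = np.power(np.power(np.abs(scores), norm).sum(axis=1), 1. / norm)
#
#   # mean across the ensemble -> one score per position
#   return scores.mean(axis=1)
#
# # toy check: 4 positions x 3 units x 2 ensemble members
# rng = np.random.default_rng(0)
# r, g = rng.normal(size=(4, 3, 2)), rng.normal(size=(4, 3, 2))
# print(aggregate_scores_sketch(r, g, norm=2))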