def main():
  """Evaluate a trained Basenji model on an HDF5 dataset split.

  usage: %prog [options] <params_file> <model_file> <test_hdf5_file>

  Restores the model from <model_file>, predicts on the chosen split
  (test by default, train/valid via --train/-v), and writes accuracy
  tables, normalization factors, and optionally peak accuracy, BigWig
  tracks, and per-target accuracy plots into the output directory.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '--ai', dest='accuracy_indexes',
      help='Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values')
  parser.add_option(
      '--clip', dest='target_clip', default=None, type='float',
      help='Clip targets and predictions to a maximum value [Default: %default]')
  parser.add_option(
      '-d', dest='down_sample', default=1, type='int',
      help='Down sample test computation by taking uniformly spaced positions [Default: %default]')
  parser.add_option(
      '-g', dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome length information [Default: %default]')
  parser.add_option(
      '--mc', dest='mc_n', default=0, type='int',
      help='Monte carlo test iterations [Default: %default]')
  parser.add_option(
      '--peak', '--peaks', dest='peaks', default=False, action='store_true',
      help='Compute expensive peak accuracy [Default: %default]')
  parser.add_option(
      '-o', dest='out_dir', default='test_out',
      help='Output directory for test statistics [Default: %default]')
  parser.add_option(
      '--rc', dest='rc', default=False, action='store_true',
      help='Average the fwd and rc predictions [Default: %default]')
  parser.add_option('-s', dest='scent_file',
                    help='Dimension reduction model file')
  parser.add_option('--sample', dest='sample_pct', default=1, type='float',
                    help='Sample percentage')
  parser.add_option('--save', dest='save', default=False, action='store_true')
  parser.add_option('--shifts', dest='shifts', default='0',
                    help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t', dest='track_bed',
      help='BED file describing regions so we can output BigWig tracks')
  parser.add_option(
      '--ti', dest='track_indexes',
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option('--train', dest='train', default=False,
                    action='store_true',
                    help='Process the training set [Default: %default]')
  parser.add_option('-v', dest='valid', default=False, action='store_true',
                    help='Process the validation set [Default: %default]')
  parser.add_option(
      '-w', dest='pool_width', default=1, type='int',
      help='Max pool width for regressing nt predictions to predict peak calls [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters, model, and test data HDF5')
  else:
    params_file = args[0]
    model_file = args[1]
    test_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  # parse comma-separated shifts into ints for the prediction ensemble
  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(test_hdf5_file)

  if options.train:
    test_seqs = data_open['train_in']
    test_targets = data_open['train_out']
    # BUGFIX: initialize test_na before the membership check, matching the
    # valid/test branches; previously test_na was unbound downstream when
    # the HDF5 lacked a 'train_na' mask.
    test_na = None
    if 'train_na' in data_open:
      test_na = data_open['train_na']

  elif options.valid:
    test_seqs = data_open['valid_in']
    test_targets = data_open['valid_out']
    test_na = None
    if 'valid_na' in data_open:
      test_na = data_open['valid_na']

  else:
    test_seqs = data_open['test_in']
    test_targets = data_open['test_out']
    test_na = None
    if 'test_na' in data_open:
      test_na = data_open['test_na']

  # optionally subsample sequences uniformly at random
  if options.sample_pct < 1:
    sample_n = int(test_seqs.shape[0] * options.sample_pct)
    print('Sampling %d sequences' % sample_n)

    # sorted so h5py fancy indexing stays valid
    sample_indexes = sorted(
        np.random.choice(
            np.arange(test_seqs.shape[0]), size=sample_n, replace=False))

    test_seqs = test_seqs[sample_indexes]
    test_targets = test_targets[sample_indexes]
    if test_na is not None:
      test_na = test_na[sample_indexes]

  target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

  #######################################################
  # model parameters and placeholders
  job = params.read_job_params(params_file)

  job['seq_length'] = test_seqs.shape[1]
  job['seq_depth'] = test_seqs.shape[2]
  job['num_targets'] = test_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

  t0 = time.time()
  dr = seqnn.SeqNN()
  dr.build(job)
  print('Model building time %ds' % (time.time() - t0))

  # adjust for fourier
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    test_targets_imag = data_open['test_out_imag']
    if options.valid:
      test_targets_imag = data_open['valid_out_imag']
    # NOTE(review): there is no --train case here, so the train split would
    # be paired with the test imaginary targets — confirm that is intended.

  # adjust for factors
  if options.scent_file is not None:
    t0 = time.time()
    test_targets_full = data_open['test_out_full']
    model = joblib.load(options.scent_file)

  #######################################################
  # test

  # initialize batcher
  if job['fourier']:
    batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                    test_targets_imag, test_na,
                                    dr.hp.batch_size, dr.hp.target_pool)
  else:
    batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                   dr.hp.batch_size, dr.hp.target_pool)

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # test
    t0 = time.time()
    test_acc = dr.test(sess, batcher_test, rc=options.rc,
                       shifts=options.shifts, mc_n=options.mc_n)

    if options.save:
      np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
      np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

    test_preds = test_acc.preds
    print('SeqNN test: %ds' % (time.time() - t0))

    # compute stats
    t0 = time.time()
    test_r2 = test_acc.r2(clip=options.target_clip)
    # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
    test_pcor = test_acc.pearsonr(clip=options.target_clip)
    test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
    #test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
    print('Compute stats: %ds' % (time.time() - t0))

    # print summary statistics averaged over targets
    print('Test Loss: %7.5f' % test_acc.loss)
    print('Test R2: %7.5f' % test_r2.mean())
    # print('Test log R2: %7.5f' % test_log_r2.mean())
    print('Test PearsonR: %7.5f' % test_pcor.mean())
    print('Test log PearsonR: %7.5f' % test_log_pcor.mean())
    # print('Test SpearmanR: %7.5f' % test_scor.mean())

    # per-target accuracy table
    acc_out = open('%s/acc.txt' % options.out_dir, 'w')
    for ti in range(len(test_r2)):
      print('%4d %7.5f %.5f %.5f %.5f %s' %
            (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti],
             test_log_pcor[ti], target_labels[ti]), file=acc_out)
    acc_out.close()

    # print normalization factors (per-target mean over median of means)
    target_means = test_preds.mean(axis=(0, 1), dtype='float64')
    target_means_median = np.median(target_means)
    target_means /= target_means_median
    norm_out = open('%s/normalization.txt' % options.out_dir, 'w')
    print('\n'.join([str(tu) for tu in target_means]), file=norm_out)
    norm_out.close()

    # clean up
    del test_acc

  # if test targets are reconstructed, measure versus the truth
  if options.scent_file is not None:
    compute_full_accuracy(dr, model, test_preds, test_targets_full,
                          options.out_dir, options.down_sample)

  #######################################################
  # peak call accuracy
  if options.peaks:
    # sample every few bins to decrease correlations
    ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
    ds_indexes_targets = ds_indexes_preds + (
        dr.hp.batch_buffer // dr.hp.target_pool)

    aurocs = []
    auprcs = []

    peaks_out = open('%s/peaks.txt' % options.out_dir, 'w')
    for ti in range(test_targets.shape[2]):
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # subset and flatten
      test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten(
      ).astype('float32')
      test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten(
      ).astype('float32')

      # call peaks: Poisson null with lambda = mean coverage,
      # Benjamini-Hochberg FDR < 1%
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01

      if test_targets_peaks.sum() == 0:
        # degenerate target with no peaks: record chance-level scores
        aurocs.append(0.5)
        auprcs.append(0)
      else:
        # compute prediction accuracy
        aurocs.append(roc_auc_score(test_targets_peaks, test_preds_ti_flat))
        auprcs.append(
            average_precision_score(test_targets_peaks, test_preds_ti_flat))

      print('%4d %6d %.5f %.5f' %
            (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
            file=peaks_out)
    peaks_out.close()

    print('Test AUROC: %7.5f' % np.mean(aurocs))
    print('Test AUPRC: %7.5f' % np.mean(auprcs))

  #######################################################
  # BigWig tracks
  # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

  # print bigwig tracks for visualization
  if options.track_bed:
    if options.genome_file is None:
      parser.error('Must provide genome file in order to print valid BigWigs')

    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

    track_indexes = range(test_preds.shape[2])
    if options.track_indexes:
      track_indexes = [int(ti) for ti in options.track_indexes.split(',')]

    bed_set = 'test'
    if options.valid:
      bed_set = 'valid'

    for ti in track_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # make true targets bigwig
      bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
      bigwig_write(bw_file, test_targets_ti, options.track_bed,
                   options.genome_file, bed_set=bed_set)

      # make predictions bigwig
      bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
      bigwig_write(bw_file, test_preds[:, :, ti], options.track_bed,
                   options.genome_file, dr.hp.batch_buffer, bed_set=bed_set)

    # make NA bigwig
    # BUGFIX: guard on the NA mask; test_na is None when the dataset has no
    # *_na table, and writing it unconditionally would fail.
    if test_na is not None:
      bw_file = '%s/tracks/na.bw' % options.out_dir
      bigwig_write(bw_file, test_na, options.track_bed, options.genome_file,
                   bed_set=bed_set)

  #######################################################
  # accuracy plots
  if options.accuracy_indexes is not None:
    accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(',')]

    if not os.path.isdir('%s/scatter' % options.out_dir):
      os.mkdir('%s/scatter' % options.out_dir)
    if not os.path.isdir('%s/violin' % options.out_dir):
      os.mkdir('%s/violin' % options.out_dir)
    if not os.path.isdir('%s/roc' % options.out_dir):
      os.mkdir('%s/roc' % options.out_dir)
    if not os.path.isdir('%s/pr' % options.out_dir):
      os.mkdir('%s/pr' % options.out_dir)

    for ti in accuracy_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      ############################################
      # scatter

      # sample every few bins (adjust to plot the # points I want)
      ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
      ds_indexes_targets = ds_indexes_preds + (
          dr.hp.batch_buffer // dr.hp.target_pool)

      # subset and flatten
      test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets].flatten(
      ).astype('float32')
      test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti].flatten(
      ).astype('float32')

      # take log2
      test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
      test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

      # plot log2
      sns.set(font_scale=1.2, style='ticks')
      out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
      plots.regplot(test_targets_ti_log, test_preds_ti_log, out_pdf,
                    poly_order=1, alpha=0.3, sample=500, figsize=(6, 6),
                    x_label='log2 Experiment', y_label='log2 Prediction',
                    table=True)

      ############################################
      # violin

      # call peaks (same Poisson/BH procedure as the --peaks section)
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01
      test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                        'Background')

      # violin plot
      sns.set(font_scale=1.3, style='ticks')
      plt.figure()
      df = pd.DataFrame({
          'log2 Prediction': np.log2(test_preds_ti_flat + 1),
          'Experimental coverage status': test_targets_peaks_str
      })
      ax = sns.violinplot(x='Experimental coverage status',
                          y='log2 Prediction', data=df)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # ROC
      plt.figure()
      fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
      auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
      plt.plot([0, 1], [0, 1], c='black', linewidth=1, linestyle='--',
               alpha=0.7)
      plt.plot(fpr, tpr, c='black')
      ax = plt.gca()
      ax.set_xlabel('False positive rate')
      ax.set_ylabel('True positive rate')
      ax.text(0.99, 0.02, 'AUROC %.3f' % auroc,
              horizontalalignment='right')  # , fontsize=14)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
      plt.close()

      # PR
      plt.figure()
      prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                               test_preds_ti_flat)
      auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat)
      plt.axhline(y=test_targets_peaks.mean(), c='black', linewidth=1,
                  linestyle='--', alpha=0.7)
      plt.plot(recall, prec, c='black')
      ax = plt.gca()
      ax.set_xlabel('Recall')
      ax.set_ylabel('Precision')
      ax.text(0.99, 0.95, 'AUPRC %.3f' % auprc,
              horizontalalignment='right')  # , fontsize=14)
      ax.grid(True, linestyle=':')
      plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
      plt.close()

  data_open.close()
def main():
  """Evaluate a trained Basenji model on an HDF5 dataset split.

  usage: %prog [options] <params_file> <model_file> <test_hdf5_file>

  Restores the model from <model_file>, predicts on the chosen split
  (test by default, train/valid via --train/-v), and writes accuracy
  tables, normalization factors, and optionally peak accuracy, BigWig
  tracks, and per-target accuracy plots into the output directory.
  """
  usage = "usage: %prog [options] <params_file> <model_file> <test_hdf5_file>"
  parser = OptionParser(usage)
  parser.add_option(
      "--ai",
      dest="accuracy_indexes",
      help=
      "Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values",
  )
  parser.add_option(
      "--clip",
      dest="target_clip",
      default=None,
      type="float",
      help="Clip targets and predictions to a maximum value [Default: %default]",
  )
  parser.add_option(
      "-d",
      dest="down_sample",
      default=1,
      type="int",
      help=
      "Down sample test computation by taking uniformly spaced positions [Default: %default]",
  )
  parser.add_option(
      "-g",
      dest="genome_file",
      default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
      help="Chromosome length information [Default: %default]",
  )
  parser.add_option(
      "--mc",
      dest="mc_n",
      default=0,
      type="int",
      help="Monte carlo test iterations [Default: %default]",
  )
  parser.add_option(
      "--peak",
      "--peaks",
      dest="peaks",
      default=False,
      action="store_true",
      help="Compute expensive peak accuracy [Default: %default]",
  )
  parser.add_option(
      "-o",
      dest="out_dir",
      default="test_out",
      help="Output directory for test statistics [Default: %default]",
  )
  parser.add_option(
      "--rc",
      dest="rc",
      default=False,
      action="store_true",
      help="Average the fwd and rc predictions [Default: %default]",
  )
  parser.add_option("-s", dest="scent_file",
                    help="Dimension reduction model file")
  parser.add_option("--sample", dest="sample_pct", default=1, type="float",
                    help="Sample percentage")
  parser.add_option("--save", dest="save", default=False, action="store_true")
  parser.add_option(
      "--shifts",
      dest="shifts",
      default="0",
      help="Ensemble prediction shifts [Default: %default]",
  )
  parser.add_option(
      "-t",
      dest="track_bed",
      help="BED file describing regions so we can output BigWig tracks",
  )
  parser.add_option(
      "--ti",
      dest="track_indexes",
      help="Comma-separated list of target indexes to output BigWig tracks",
  )
  parser.add_option(
      "--train",
      dest="train",
      default=False,
      action="store_true",
      help="Process the training set [Default: %default]",
  )
  parser.add_option(
      "-v",
      dest="valid",
      default=False,
      action="store_true",
      help="Process the validation set [Default: %default]",
  )
  parser.add_option(
      "-w",
      dest="pool_width",
      default=1,
      type="int",
      help=
      "Max pool width for regressing nt predictions to predict peak calls [Default: %default]",
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error("Must provide parameters, model, and test data HDF5")
  else:
    params_file = args[0]
    model_file = args[1]
    test_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  # parse comma-separated shifts into ints for the prediction ensemble
  options.shifts = [int(shift) for shift in options.shifts.split(",")]

  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(test_hdf5_file)

  if options.train:
    test_seqs = data_open["train_in"]
    test_targets = data_open["train_out"]
    # BUGFIX: initialize test_na before the membership check, matching the
    # valid/test branches; previously test_na was unbound downstream when
    # the HDF5 lacked a 'train_na' mask.
    test_na = None
    if "train_na" in data_open:
      test_na = data_open["train_na"]

  elif options.valid:
    test_seqs = data_open["valid_in"]
    test_targets = data_open["valid_out"]
    test_na = None
    if "valid_na" in data_open:
      test_na = data_open["valid_na"]

  else:
    test_seqs = data_open["test_in"]
    test_targets = data_open["test_out"]
    test_na = None
    if "test_na" in data_open:
      test_na = data_open["test_na"]

  # optionally subsample sequences uniformly at random
  if options.sample_pct < 1:
    sample_n = int(test_seqs.shape[0] * options.sample_pct)
    print("Sampling %d sequences" % sample_n)

    # sorted so h5py fancy indexing stays valid
    sample_indexes = sorted(
        np.random.choice(
            np.arange(test_seqs.shape[0]), size=sample_n, replace=False))

    test_seqs = test_seqs[sample_indexes]
    test_targets = test_targets[sample_indexes]
    if test_na is not None:
      test_na = test_na[sample_indexes]

  target_labels = [tl.decode("UTF-8") for tl in data_open["target_labels"]]

  #######################################################
  # model parameters and placeholders
  job = params.read_job_params(params_file)

  job["seq_length"] = test_seqs.shape[1]
  job["seq_depth"] = test_seqs.shape[2]
  job["num_targets"] = test_targets.shape[2]
  job["target_pool"] = int(np.array(data_open.get("pool_width", 1)))

  t0 = time.time()
  dr = seqnn.SeqNN()
  dr.build_feed(job)
  print("Model building time %ds" % (time.time() - t0))

  # adjust for fourier
  job["fourier"] = "train_out_imag" in data_open
  if job["fourier"]:
    test_targets_imag = data_open["test_out_imag"]
    if options.valid:
      test_targets_imag = data_open["valid_out_imag"]
    # NOTE(review): there is no --train case here, so the train split would
    # be paired with the test imaginary targets — confirm that is intended.

  # adjust for factors
  if options.scent_file is not None:
    t0 = time.time()
    test_targets_full = data_open["test_out_full"]
    model = joblib.load(options.scent_file)

  #######################################################
  # test

  # initialize batcher
  if job["fourier"]:
    batcher_test = batcher.BatcherF(
        test_seqs,
        test_targets,
        test_targets_imag,
        test_na,
        dr.hp.batch_size,
        dr.hp.target_pool,
    )
  else:
    batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                   dr.hp.batch_size, dr.hp.target_pool)

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # test
    t0 = time.time()
    test_acc = dr.test_h5_manual(sess, batcher_test, rc=options.rc,
                                 shifts=options.shifts, mc_n=options.mc_n)

    if options.save:
      np.save("%s/preds.npy" % options.out_dir, test_acc.preds)
      np.save("%s/targets.npy" % options.out_dir, test_acc.targets)

    test_preds = test_acc.preds
    print("SeqNN test: %ds" % (time.time() - t0))

    # compute stats
    t0 = time.time()
    test_r2 = test_acc.r2(clip=options.target_clip)
    # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
    test_pcor = test_acc.pearsonr(clip=options.target_clip)
    test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
    # test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
    print("Compute stats: %ds" % (time.time() - t0))

    # print summary statistics averaged over targets
    print("Test Loss: %7.5f" % test_acc.loss)
    print("Test R2: %7.5f" % test_r2.mean())
    # print('Test log R2: %7.5f' % test_log_r2.mean())
    print("Test PearsonR: %7.5f" % test_pcor.mean())
    print("Test log PearsonR: %7.5f" % test_log_pcor.mean())
    # print('Test SpearmanR: %7.5f' % test_scor.mean())

    # per-target accuracy table
    acc_out = open("%s/acc.txt" % options.out_dir, "w")
    for ti in range(len(test_r2)):
      print(
          "%4d %7.5f %.5f %.5f %.5f %s" % (
              ti,
              test_acc.target_losses[ti],
              test_r2[ti],
              test_pcor[ti],
              test_log_pcor[ti],
              target_labels[ti],
          ),
          file=acc_out,
      )
    acc_out.close()

    # print normalization factors (per-target mean over median of means)
    target_means = test_preds.mean(axis=(0, 1), dtype="float64")
    target_means_median = np.median(target_means)
    target_means /= target_means_median
    norm_out = open("%s/normalization.txt" % options.out_dir, "w")
    print("\n".join([str(tu) for tu in target_means]), file=norm_out)
    norm_out.close()

    # clean up
    del test_acc

  # if test targets are reconstructed, measure versus the truth
  if options.scent_file is not None:
    compute_full_accuracy(
        dr,
        model,
        test_preds,
        test_targets_full,
        options.out_dir,
        options.down_sample,
    )

  #######################################################
  # peak call accuracy
  if options.peaks:
    # sample every few bins to decrease correlations
    ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
    ds_indexes_targets = ds_indexes_preds + (
        dr.hp.batch_buffer // dr.hp.target_pool)

    aurocs = []
    auprcs = []

    peaks_out = open("%s/peaks.txt" % options.out_dir, "w")
    for ti in range(test_targets.shape[2]):
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # subset and flatten
      test_targets_ti_flat = (
          test_targets_ti[:, ds_indexes_targets].flatten().astype("float32"))
      test_preds_ti_flat = (
          test_preds[:, ds_indexes_preds, ti].flatten().astype("float32"))

      # call peaks: Poisson null with lambda = mean coverage,
      # Benjamini-Hochberg FDR < 1%
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01

      if test_targets_peaks.sum() == 0:
        # degenerate target with no peaks: record chance-level scores
        aurocs.append(0.5)
        auprcs.append(0)
      else:
        # compute prediction accuracy
        aurocs.append(roc_auc_score(test_targets_peaks, test_preds_ti_flat))
        auprcs.append(
            average_precision_score(test_targets_peaks, test_preds_ti_flat))

      print(
          "%4d %6d %.5f %.5f" %
          (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
          file=peaks_out,
      )
    peaks_out.close()

    print("Test AUROC: %7.5f" % np.mean(aurocs))
    print("Test AUPRC: %7.5f" % np.mean(auprcs))

  #######################################################
  # BigWig tracks
  # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

  # print bigwig tracks for visualization
  if options.track_bed:
    if options.genome_file is None:
      parser.error("Must provide genome file in order to print valid BigWigs")

    if not os.path.isdir("%s/tracks" % options.out_dir):
      os.mkdir("%s/tracks" % options.out_dir)

    track_indexes = range(test_preds.shape[2])
    if options.track_indexes:
      track_indexes = [int(ti) for ti in options.track_indexes.split(",")]

    bed_set = "test"
    if options.valid:
      bed_set = "valid"

    for ti in track_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      # make true targets bigwig
      bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti)
      bigwig_write(
          bw_file,
          test_targets_ti,
          options.track_bed,
          options.genome_file,
          bed_set=bed_set,
      )

      # make predictions bigwig
      bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
      bigwig_write(
          bw_file,
          test_preds[:, :, ti],
          options.track_bed,
          options.genome_file,
          dr.hp.batch_buffer,
          bed_set=bed_set,
      )

    # make NA bigwig
    # bw_file = '%s/tracks/na.bw' % options.out_dir
    # bigwig_write(
    #     bw_file,
    #     test_na,
    #     options.track_bed,
    #     options.genome_file,
    #     bed_set=bed_set)

  #######################################################
  # accuracy plots
  if options.accuracy_indexes is not None:
    accuracy_indexes = [int(ti) for ti in options.accuracy_indexes.split(",")]

    if not os.path.isdir("%s/scatter" % options.out_dir):
      os.mkdir("%s/scatter" % options.out_dir)
    if not os.path.isdir("%s/violin" % options.out_dir):
      os.mkdir("%s/violin" % options.out_dir)
    if not os.path.isdir("%s/roc" % options.out_dir):
      os.mkdir("%s/roc" % options.out_dir)
    if not os.path.isdir("%s/pr" % options.out_dir):
      os.mkdir("%s/pr" % options.out_dir)

    for ti in accuracy_indexes:
      if options.scent_file is not None:
        test_targets_ti = test_targets_full[:, :, ti]
      else:
        test_targets_ti = test_targets[:, :, ti]

      ############################################
      # scatter

      # sample every few bins (adjust to plot the # points I want)
      ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
      ds_indexes_targets = ds_indexes_preds + (
          dr.hp.batch_buffer // dr.hp.target_pool)

      # subset and flatten
      test_targets_ti_flat = (
          test_targets_ti[:, ds_indexes_targets].flatten().astype("float32"))
      test_preds_ti_flat = (
          test_preds[:, ds_indexes_preds, ti].flatten().astype("float32"))

      # take log2
      test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
      test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

      # plot log2
      sns.set(font_scale=1.2, style="ticks")
      out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti)
      plots.regplot(
          test_targets_ti_log,
          test_preds_ti_log,
          out_pdf,
          poly_order=1,
          alpha=0.3,
          sample=500,
          figsize=(6, 6),
          x_label="log2 Experiment",
          y_label="log2 Prediction",
          table=True,
      )

      ############################################
      # violin

      # call peaks (same Poisson/BH procedure as the --peaks section)
      test_targets_ti_lambda = np.mean(test_targets_ti_flat)
      test_targets_pvals = 1 - poisson.cdf(
          np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
      test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
      test_targets_peaks = test_targets_qvals < 0.01
      test_targets_peaks_str = np.where(test_targets_peaks, "Peak",
                                        "Background")

      # violin plot
      sns.set(font_scale=1.3, style="ticks")
      plt.figure()
      df = pd.DataFrame({
          "log2 Prediction": np.log2(test_preds_ti_flat + 1),
          "Experimental coverage status": test_targets_peaks_str,
      })
      ax = sns.violinplot(x="Experimental coverage status",
                          y="log2 Prediction", data=df)
      ax.grid(True, linestyle=":")
      plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti))
      plt.close()

      # ROC
      plt.figure()
      fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
      auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
      plt.plot([0, 1], [0, 1], c="black", linewidth=1, linestyle="--",
               alpha=0.7)
      plt.plot(fpr, tpr, c="black")
      ax = plt.gca()
      ax.set_xlabel("False positive rate")
      ax.set_ylabel("True positive rate")
      ax.text(0.99, 0.02, "AUROC %.3f" % auroc,
              horizontalalignment="right")  # , fontsize=14)
      ax.grid(True, linestyle=":")
      plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti))
      plt.close()

      # PR
      plt.figure()
      prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                               test_preds_ti_flat)
      auprc = average_precision_score(test_targets_peaks, test_preds_ti_flat)
      plt.axhline(
          y=test_targets_peaks.mean(),
          c="black",
          linewidth=1,
          linestyle="--",
          alpha=0.7,
      )
      plt.plot(recall, prec, c="black")
      ax = plt.gca()
      ax.set_xlabel("Recall")
      ax.set_ylabel("Precision")
      ax.text(0.99, 0.95, "AUPRC %.3f" % auprc,
              horizontalalignment="right")  # , fontsize=14)
      ax.grid(True, linestyle=":")
      plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti))
      plt.close()

  data_open.close()
def run(params_file, data_file, train_epochs, train_epoch_batches,
        test_epoch_batches):
  """Train a SeqNN model on HDF5 train/valid data with early stopping.

  Args:
    params_file: model/job parameters file path.
    data_file: HDF5 path holding train_in/train_out and valid_in/valid_out
      (plus optional *_na masks and *_out_imag fourier targets).
    train_epochs: maximum number of epochs, or None to rely on early stop.
    train_epoch_batches: batches per training epoch (passed to train_epoch).
    test_epoch_batches: batches per validation pass (passed to test).

  Checkpoints the best model (by validation loss) to FLAGS.logdir.
  """
  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(data_file)

  train_seqs = data_open['train_in']
  train_targets = data_open['train_out']
  train_na = None
  if 'train_na' in data_open:
    train_na = data_open['train_na']

  valid_seqs = data_open['valid_in']
  valid_targets = data_open['valid_out']
  valid_na = None
  if 'valid_na' in data_open:
    valid_na = data_open['valid_na']

  #######################################################
  # model parameters and placeholders
  #######################################################
  job = params.read_job_params(params_file)

  job['seq_length'] = train_seqs.shape[1]
  job['seq_depth'] = train_seqs.shape[2]
  job['num_targets'] = train_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

  t0 = time.time()
  model = seqnn.SeqNN()
  model.build(job)
  print('Model building time %f' % (time.time() - t0))

  # adjust for fourier
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    train_targets_imag = data_open['train_out_imag']
    valid_targets_imag = data_open['valid_out_imag']

  #######################################################
  # prepare batcher
  #######################################################
  if job['fourier']:
    batcher_train = batcher.BatcherF(train_seqs, train_targets,
                                     train_targets_imag, train_na,
                                     model.hp.batch_size,
                                     model.hp.target_pool, shuffle=True)
    # BUGFIX: hyper-parameters live on model.hp, as in the other three
    # batcher constructions; model.batch_size/model.target_pool don't exist.
    batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                     valid_targets_imag, valid_na,
                                     model.hp.batch_size, model.hp.target_pool)
  else:
    batcher_train = batcher.Batcher(train_seqs, train_targets, train_na,
                                    model.hp.batch_size, model.hp.target_pool,
                                    shuffle=True)
    batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na,
                                    model.hp.batch_size, model.hp.target_pool)
  print('Batcher initialized')

  #######################################################
  # train
  #######################################################
  augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')]
  ensemble_shifts = [int(shift) for shift in FLAGS.ensemble_shifts.split(',')]

  # checkpoints
  saver = tf.train.Saver()

  config = tf.ConfigProto()
  if FLAGS.log_device_placement:
    config.log_device_placement = True
  with tf.Session(config=config) as sess:
    t0 = time.time()

    # set seed
    tf.set_random_seed(FLAGS.seed)

    if FLAGS.logdir:
      train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                           sess.graph)
    else:
      train_writer = None

    if FLAGS.restart:
      # load variables into session
      saver.restore(sess, FLAGS.restart)
    else:
      # initialize variables
      print('Initializing...')
      sess.run(tf.global_variables_initializer())
      print('Initialization time %f' % (time.time() - t0))

    train_loss = None
    best_loss = None
    early_stop_i = 0

    epoch = 0
    # BUGFIX: the epoch-count condition referenced undefined 'epochs';
    # the loop counter is 'epoch'.
    while (train_epochs is None or
           epoch < train_epochs) and early_stop_i < FLAGS.early_stop:
      t0 = time.time()

      # alternate forward and reverse batches
      fwdrc = True
      if FLAGS.augment_rc and epoch % 2 == 1:
        fwdrc = False

      # cycle shifts
      shift_i = epoch % len(augment_shifts)

      # train
      train_loss, steps = model.train_epoch(
          sess,
          batcher_train,
          fwdrc=fwdrc,
          shift=augment_shifts[shift_i],
          sum_writer=train_writer,
          epoch_batches=train_epoch_batches,
          no_steps=FLAGS.no_steps)

      # validate
      valid_acc = model.test(
          sess,
          batcher_valid,
          mc_n=FLAGS.ensemble_mc,
          rc=FLAGS.ensemble_rc,
          shifts=ensemble_shifts,
          test_batches=test_epoch_batches)
      valid_loss = valid_acc.loss
      valid_r2 = valid_acc.r2().mean()
      del valid_acc

      best_str = ''
      if best_loss is None or valid_loss < best_loss:
        # checkpoint the new best model and reset the early-stop counter
        best_loss = valid_loss
        best_str = ', best!'
        early_stop_i = 0
        saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
      else:
        early_stop_i += 1

      # measure time
      et = time.time() - t0
      if et < 600:
        time_str = '%3ds' % et
      elif et < 6000:
        time_str = '%3dm' % (et / 60)
      else:
        time_str = '%3.1fh' % (et / 3600)

      # print update
      print(
          'Epoch: %3d, Steps: %7d, Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s'
          % (epoch + 1, steps, train_loss, valid_loss, valid_r2, time_str,
             best_str))
      sys.stdout.flush()

      if FLAGS.check_all:
        saver.save(sess, '%s/model_check%d.tf' % (FLAGS.logdir, epoch))

      # update epoch
      epoch += 1

    if FLAGS.logdir:
      train_writer.close()
def run(params_file, data_file, num_train_epochs):
  """Train a SeqNN model for up to num_train_epochs with early stopping
  and optional learning-rate drops when training loss stagnates.

  Reads train/valid splits from the HDF5 data_file, builds the model from
  params_file, and checkpoints the best model (by validation loss) to
  FLAGS.logdir.
  """
  # parse comma-separated augmentation/ensemble shifts into ints
  shifts = [int(shift) for shift in FLAGS.shifts.split(',')]

  #######################################################
  # load data
  #######################################################
  data_open = h5py.File(data_file)

  train_seqs = data_open['train_in']
  train_targets = data_open['train_out']
  train_na = None  # optional NA mask; stays None when absent from the file
  if 'train_na' in data_open:
    train_na = data_open['train_na']

  valid_seqs = data_open['valid_in']
  valid_targets = data_open['valid_out']
  valid_na = None
  if 'valid_na' in data_open:
    valid_na = data_open['valid_na']

  #######################################################
  # model parameters and placeholders
  #######################################################
  job = dna_io.read_job_params(params_file)

  # derive data-dependent dimensions from the training arrays
  job['batch_length'] = train_seqs.shape[1]
  job['seq_depth'] = train_seqs.shape[2]
  job['num_targets'] = train_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))
  # defaults when the params file omits them
  job['early_stop'] = job.get('early_stop', 16)
  # NOTE(review): 'rate_drop' appears unused in this function (the drop
  # cadence below uses undroppable_counter/max_drops) — confirm.
  job['rate_drop'] = job.get('rate_drop', 3)

  t0 = time.time()
  dr = seqnn.SeqNN()
  dr.build(job)
  print('Model building time %f' % (time.time() - t0))

  # adjust for fourier: imaginary targets present implies fourier mode
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    train_targets_imag = data_open['train_out_imag']
    valid_targets_imag = data_open['valid_out_imag']

  #######################################################
  # train
  #######################################################
  # initialize batcher (fourier variant also carries imaginary targets)
  if job['fourier']:
    batcher_train = batcher.BatcherF(train_seqs, train_targets,
                                     train_targets_imag, train_na,
                                     dr.batch_size, dr.target_pool,
                                     shuffle=True)
    batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                     valid_targets_imag, valid_na,
                                     dr.batch_size, dr.target_pool)
  else:
    batcher_train = batcher.Batcher(train_seqs, train_targets, train_na,
                                    dr.batch_size, dr.target_pool,
                                    shuffle=True)
    batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na,
                                    dr.batch_size, dr.target_pool)
  print('Batcher initialized')

  # checkpoints
  saver = tf.train.Saver()

  config = tf.ConfigProto()
  if FLAGS.log_device_placement:
    config.log_device_placement = True
  with tf.Session(config=config) as sess:
    t0 = time.time()

    # set seed
    tf.set_random_seed(FLAGS.seed)

    # summary writer only when a log directory was given
    if FLAGS.logdir:
      train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                           sess.graph)
    else:
      train_writer = None

    if FLAGS.restart:
      # load variables into session
      saver.restore(sess, FLAGS.restart)
    else:
      # initialize variables
      print('Initializing...')
      sess.run(tf.global_variables_initializer())
      print('Initialization time %f' % (time.time() - t0))

    train_loss = None
    best_loss = None
    early_stop_i = 0
    # rate-drop state: wait `undroppable_counter` epochs between drops,
    # at most `max_drops` drops total
    undroppable_counter = 3
    max_drops = 8
    num_drops = 0

    for epoch in range(num_train_epochs):
      # keep training until early stop triggers (but always run at least
      # FLAGS.min_epochs); once stopped, remaining iterations are no-ops
      if early_stop_i < job['early_stop'] or epoch < FLAGS.min_epochs:
        t0 = time.time()

        # save previous loss to measure stagnation for the rate drop below
        train_loss_last = train_loss

        # alternate forward and reverse batches
        fwdrc = True
        if FLAGS.rc and epoch % 2 == 1:
          fwdrc = False

        # cycle shifts
        shift_i = epoch % len(shifts)

        # train
        train_loss = dr.train_epoch(sess, batcher_train, fwdrc,
                                    shifts[shift_i], train_writer)

        # validate
        valid_acc = dr.test(sess, batcher_valid, mc_n=FLAGS.mc_n,
                            rc=FLAGS.rc, shifts=shifts)
        valid_loss = valid_acc.loss
        valid_r2 = valid_acc.r2().mean()
        del valid_acc

        best_str = ''
        if best_loss is None or valid_loss < best_loss:
          # new best validation loss: checkpoint and reset early stopping
          best_loss = valid_loss
          best_str = ', best!'
          early_stop_i = 0
          saver.save(sess,
                     '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix))
        else:
          early_stop_i += 1

        # measure time
        et = time.time() - t0
        if et < 600:
          time_str = '%3ds' % et
        elif et < 6000:
          time_str = '%3dm' % (et / 60)
        else:
          time_str = '%3.1fh' % (et / 3600)

        # print update (no newline; the rate-drop note may be appended)
        print(
            'Epoch %3d: Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s'
            % (epoch + 1, train_loss, valid_loss, valid_r2, time_str,
               best_str),
            end='')

        # if training stagnant (relative improvement < 0.02%), drop the
        # learning rate by a third, then hold off one epoch before the
        # next possible drop
        if FLAGS.learn_rate_drop and num_drops < max_drops and undroppable_counter == 0 and (
            train_loss_last - train_loss) / train_loss_last < 0.0002:
          print(', rate drop', end='')
          dr.drop_rate(2 / 3)
          undroppable_counter = 1
          num_drops += 1
        else:
          undroppable_counter = max(0, undroppable_counter - 1)

        print('')
        sys.stdout.flush()

    if FLAGS.logdir:
      train_writer.close()