Beispiel #1
0
def gene_table(
    gene_targets,
    gene_preds,
    gene_iter,
    target_labels,
    target_indexes,
    out_prefix,
    plot_scatter,
):
    """Write a gzipped per-gene table of log2 targets vs. predictions.

    For every target index in ``target_indexes`` one line per gene is
    written to ``<out_prefix>_table.txt.gz`` with the columns: gene id,
    log2(target + 1), log2(prediction + 1), target index, target label.
    Optionally a scatter plot per target is written to
    ``<out_prefix>_scatter<ti>.pdf``.

    Args:
        gene_targets: 2D array indexed as [gene, target] of measured values.
        gene_preds: 2D array indexed as [gene, target] of predicted values.
        gene_iter: iterable of gene identifiers; row ``i`` of the arrays is
            reported under the ``i``-th identifier. May be a one-shot
            iterator.
        target_labels: sequence of labels indexed by target index.
        target_indexes: iterable of target (column) indexes to report.
        out_prefix: path prefix for all output files.
        plot_scatter: if true, also draw a scatter plot per target.
    """
    # Local import keeps this function self-contained; compressing in-process
    # avoids building a shell command from out_prefix (which broke on spaces
    # or shell metacharacters) and avoids depending on a gzip binary.
    import gzip

    num_genes = gene_targets.shape[0]

    # Materialize once: a one-shot iterator would otherwise be exhausted after
    # the first target and every later target would get no rows.
    gene_ids = list(gene_iter)

    table_file = "%s_table.txt.gz" % out_prefix
    with gzip.open(table_file, "wt") as table_out:
        for ti in target_indexes:
            gti = np.log2(gene_targets[:, ti].astype("float32") + 1)
            gpi = np.log2(gene_preds[:, ti].astype("float32") + 1)

            # plot scatter
            if plot_scatter:
                sns.set(font_scale=1.3, style="ticks")
                out_pdf = "%s_scatter%d.pdf" % (out_prefix, ti)
                # plot at most 2000 randomly chosen genes
                if num_genes < 2000:
                    ri = np.arange(num_genes)
                else:
                    ri = np.random.choice(range(num_genes), 2000, replace=False)
                plots.regplot(
                    gti[ri],
                    gpi[ri],
                    out_pdf,
                    poly_order=3,
                    alpha=0.3,
                    x_label="log2 Experiment",
                    y_label="log2 Prediction",
                )

            # print table lines, one per gene
            for tx_i, gid in enumerate(gene_ids):
                cols = (gid, gti[tx_i], gpi[tx_i], ti, target_labels[ti])
                print("%-20s  %.3f  %.3f  %4d  %20s" % cols, file=table_out)
Beispiel #2
0
def main():
    """Evaluate a trained SeqNN model on a TFRecords data directory.

    Command line: ``[options] <params_file> <model_file> <data_dir>``.

    Steps, in order:
      1. Read the targets table, the JSON model/training parameters, and
         ``<data_dir>/statistics.json``.
      2. Build the evaluation dataset and restore the model, building the
         ensemble from the --rc and --shifts options.
      3. Evaluate; print loss, R2 and PearsonR; write per-target statistics
         to ``<out_dir>/acc.txt``.
      4. If --save, --peaks or --ai is given, compute predictions and then
         save the arrays, compute peak accuracy, and/or draw per-target
         scatter, violin, ROC and PR plots.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy scatter plots.'
    )
    # NOTE: --mc is parsed but not read anywhere in this function.
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option(
        '--save',
        dest='save',
        default=False,
        action='store_true',
        help='Save targets and predictions numpy arrays [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--tfr',
        dest='tfr_pattern',
        default='test-*.tfr',
        help='TFR pattern string appended to data_dir [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error() prints the message and exits; the third positional
        # argument is a data directory (see usage), not an HDF5 file.
        parser.error('Must provide parameters, model, and test data directory')
    params_file, model_file, data_dir = args

    # makedirs also handles a nested -o path and an already existing directory
    os.makedirs(options.out_dir, exist_ok=True)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # inputs

    # read targets (defaults to the table shipped inside data_dir)
    if options.targets_file is None:
        options.targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read data parameters
    data_stats_file = '%s/statistics.json' % data_dir
    with open(data_stats_file) as data_stats_open:
        data_stats = json.load(data_stats_open)

    # construct data ops
    tfr_pattern_path = '%s/tfrecords/%s' % (data_dir, options.tfr_pattern)
    eval_data = dataset.SeqDataset(tfr_pattern_path,
                                   seq_length=data_stats['seq_length'],
                                   target_length=data_stats['target_length'],
                                   batch_size=params_train['batch_size'],
                                   mode=tf.estimator.ModeKeys.EVAL)

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #######################################################
    # evaluate

    # fall back to the poisson loss when the training params name none
    eval_loss = params_train.get('loss', 'poisson')

    test_loss, test_pr, test_r2 = seqnn_model.evaluate(eval_data,
                                                       loss=eval_loss)
    print('')

    # print summary statistics
    print('Test Loss:         %7.5f' % test_loss)
    print('Test R2:           %7.5f' % test_r2.mean())
    print('Test PearsonR:     %7.5f' % test_pr.mean())

    # write target-level statistics
    targets_acc_df = pd.DataFrame({
        'index': targets_df.index,
        'r2': test_r2,
        'pearsonr': test_pr,
        'identifier': targets_df.identifier,
        'description': targets_df.description
    })
    targets_acc_df.to_csv('%s/acc.txt' % options.out_dir,
                          sep='\t',
                          index=False,
                          float_format='%.5f')

    #######################################################
    # predict?

    # predictions are only needed by the optional outputs below
    if options.save or options.peaks or options.accuracy_indexes is not None:
        # compute predictions
        test_preds = seqnn_model.predict(eval_data).astype('float16')

        # read targets
        test_targets = eval_data.numpy(return_inputs=False)

    if options.save:
        # context managers close the HDF5 files even if a write fails
        with h5py.File('%s/preds.h5' % options.out_dir, 'w') as preds_h5:
            preds_h5.create_dataset('preds', data=test_preds)
        with h5py.File('%s/targets.h5' % options.out_dir, 'w') as targets_h5:
            targets_h5.create_dataset('targets', data=test_targets)

    #######################################################
    # peak call accuracy

    if options.peaks:
        peaks_out_file = '%s/peaks.txt' % options.out_dir
        test_peaks(test_preds, test_targets, peaks_out_file)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        # one sub-directory per plot type
        for plot_kind in ('scatter', 'violin', 'roc', 'pr'):
            os.makedirs('%s/%s' % (options.out_dir, plot_kind), exist_ok=True)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every 8th bin (adjust to plot the desired # of points)
            ds_indexes = np.arange(0, test_preds.shape[1], 8)

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:, ds_indexes].flatten(
            ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes,
                                            ti].flatten().astype('float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks: Poisson upper-tail p-value of each rounded target
            # value against the mean as rate, then q-values from ben_hoch
            # (defined elsewhere; presumably Benjamini-Hochberg - confirm)
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                test_preds_ti_log,
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            # diagonal = chance level
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            # horizontal line = fraction of positions called as peaks
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()
Beispiel #3
0
def replicate_correlations(replicate_lists,
                           gene_targets,
                           gene_preds,
                           target_indexes,
                           out_prefix,
                           scatter_plots=False):
  """Compare replicate-vs-replicate to prediction-vs-experiment correlation.

  For every replicate group that has more than one member and shares at
  least one index with ``target_indexes``, the first two members are used:
  the Pearson correlation between their log2(x + 1) targets and the mean
  Pearson correlation of each replicate's targets with its predictions are
  computed. One line per group is written to ``<out_prefix>.txt`` and a
  joint plot of both correlations to ``<out_prefix>_scatter.pdf``.

  Args:
    replicate_lists: dict mapping a label to a list of target indexes.
    gene_targets: 2D array indexed as [gene, target] of measured values.
    gene_preds: 2D array indexed as [gene, target] of predicted values.
    target_indexes: iterable of target indexes of interest.
    out_prefix: path prefix for all output files.
    scatter_plots: if true, also write per-group scatter plots to
      ``<out_prefix>_s<i>.pdf``, ``..._rep1.pdf`` and ``..._rep2.pdf``.
  """

  # for intersections
  target_set = set(target_indexes)

  rep_cors = []
  pred_cors = []

  sns.set(style='ticks', font_scale=1.3)
  num_genes = gene_targets.shape[0]

  # index of the current accepted replicate group (used in file names)
  li = 0

  with open('%s.txt' % out_prefix, 'w') as table_out:
    for label in sorted(replicate_lists.keys()):
      rep_indexes = replicate_lists[label]
      if len(rep_indexes) <= 1 or not (target_set & set(rep_indexes)):
        continue

      # only the first two replicates of a group are compared
      ti1 = rep_indexes[0]
      ti2 = rep_indexes[1]

      # retrieve targets
      gene_targets_rep1 = np.log2(gene_targets[:, ti1].astype('float32') + 1)
      gene_targets_rep2 = np.log2(gene_targets[:, ti2].astype('float32') + 1)

      # retrieve predictions
      gene_preds_rep1 = np.log2(gene_preds[:, ti1].astype('float32') + 1)
      gene_preds_rep2 = np.log2(gene_preds[:, ti2].astype('float32') + 1)

      #####################################
      # replicate

      # compute replicate correlation
      rcor, _ = pearsonr(gene_targets_rep1, gene_targets_rep2)
      rep_cors.append(rcor)

      # scatter plot rep vs rep
      if scatter_plots:
        out_pdf = '%s_s%d.pdf' % (out_prefix, li)
        # plot at most 1000 genes; sampling 1000 without replacement from
        # fewer genes would raise ValueError, so use all of them instead
        if num_genes < 1000:
          gene_indexes = np.arange(num_genes)
        else:
          gene_indexes = np.random.choice(
              range(num_genes), 1000, replace=False)
        plots.regplot(
            gene_targets_rep1[gene_indexes],
            gene_targets_rep2[gene_indexes],
            out_pdf,
            poly_order=3,
            alpha=0.3,
            x_label='log2 Replicate 1',
            y_label='log2 Replicate 2')

      #####################################
      # prediction

      # compute prediction correlation, averaged over both replicates
      pcor1, _ = pearsonr(gene_targets_rep1, gene_preds_rep1)
      pcor2, _ = pearsonr(gene_targets_rep2, gene_preds_rep2)
      pcor = 0.5 * pcor1 + 0.5 * pcor2
      pred_cors.append(pcor)

      # scatter plots vs pred reuse the gene sample drawn above
      if scatter_plots:
        # scatter plot rep1 vs pred
        out_pdf = '%s_s%d_rep1.pdf' % (out_prefix, li)
        plots.regplot(
            gene_targets_rep1[gene_indexes],
            gene_preds_rep1[gene_indexes],
            out_pdf,
            poly_order=3,
            alpha=0.3,
            x_label='log2 Experiment',
            y_label='log2 Prediction')

        # scatter plot rep2 vs pred
        out_pdf = '%s_s%d_rep2.pdf' % (out_prefix, li)
        plots.regplot(
            gene_targets_rep2[gene_indexes],
            gene_preds_rep2[gene_indexes],
            out_pdf,
            poly_order=3,
            alpha=0.3,
            x_label='log2 Experiment',
            y_label='log2 Prediction')

      #####################################
      # table

      print(
          '%4d  %4d  %4d  %7.4f  %7.4f  %s' % (li, ti1, ti2, rcor, pcor, label),
          file=table_out)

      # update counter
      li += 1

  #######################################################
  # scatter plot replicate versus prediction correlation

  rep_cors = np.array(rep_cors)
  pred_cors = np.array(pred_cors)

  out_pdf = '%s_scatter.pdf' % out_prefix
  plots.jointplot(
      rep_cors,
      pred_cors,
      out_pdf,
      square=True,
      x_label='Replicate R',
      y_label='Prediction R')
Beispiel #4
0
def alternative_tss(tss_targets, tss_preds, gene_data, out_base, log_pseudo=1, tss_var_t=1, scatter_pct=0.02):
  ''' Compare predicted to experimental log2 TSS1 to TSS2 ratio.

  For every gene in gene_data.gene_tss, the gene's TSSs are ranked by the
  variance of their normalized targets; TSSs at or below tss_var_t are
  dropped, as are TSSs within 500 of the top TSS's position. If at least
  two remain, the difference between the top two normalized rows is taken
  for targets and for predictions, and the two are correlated.

  Outputs written under out_base:
    tss12_cor.txt           gene id and PearsonR per gene
    tss12_cor.pdf           distribution of the PearsonR values
    gene_tss12_targets.npy  per-gene target differences (float16)
    gene_tss12_preds.npy    per-gene prediction differences (float16)
    tss12_qgenes.txt        genes chosen at 11 evenly spaced correlation ranks
    tss12_q<i>.pdf          scatter plot for each of those genes

  Args:
    tss_targets: 2D array of experimental values, rows indexed by TSS.
    tss_preds: 2D array of predicted values, rows indexed by TSS.
    gene_data: object with gene_tss (gene id -> list of TSS row indexes)
      and tss (indexable by TSS index, each item having a .pos attribute).
    out_base: output directory.
    log_pseudo: forwarded to normalize_targets.
    tss_var_t: variance threshold a TSS must exceed to be considered.
    scatter_pct: accepted for interface compatibility; not read here.
  '''

  sns.set(style='ticks', font_scale=1.2)

  # normalize TSS (normalize_targets is defined elsewhere; the arithmetic
  # below assumes its output is already in log2 space - TODO confirm)
  tss_targets_qn = normalize_targets(tss_targets, log_pseudo=log_pseudo)
  tss_preds_qn = normalize_targets(tss_preds, log_pseudo=log_pseudo)

  # per-TSS variance across the second axis
  tss_targets_var = tss_targets_qn.var(axis=1, dtype='float64')

  # save genes for later plotting
  gene_tss12_targets = []
  gene_tss12_preds = []
  gene_ids = []

  # output correlations
  gene_tss12_cors = []

  with open('%s/tss12_cor.txt' % out_base, 'w') as table_out:
    # iterate the genes of gene_data (the bare name gene_tss was undefined)
    for gene_id in gene_data.gene_tss:
      # sort TSS by variance, highest first
      var_tss_list = [(tss_targets_var[tss_i],tss_i) for tss_i in gene_data.gene_tss[gene_id]]
      var_tss_list.sort(reverse=True)

      # filter for high variance
      tss_list = [tss_i for (tss_var,tss_i) in var_tss_list if tss_var > tss_var_t]

      # filter for sufficient distance from TSS1
      if len(tss_list) > 1:
        tss1_pos = gene_data.tss[tss_list[0]].pos
        tss_list = [tss_list[0]] + [tss_i for tss_i in tss_list[1:] if abs(gene_data.tss[tss_i].pos - tss1_pos) > 500]

      if len(tss_list) > 1:
        tss_i1 = tss_list[0]
        tss_i2 = tss_list[1]

        # compute log2 ratio as a difference (values presumed already log2)
        tss12_targets = tss_targets_qn[tss_i1,:] - tss_targets_qn[tss_i2,:]
        tss12_preds = tss_preds_qn[tss_i1,:] - tss_preds_qn[tss_i2,:]

        # convert
        tss12_targets = tss12_targets.astype('float32')
        tss12_preds = tss12_preds.astype('float32')

        # save values
        gene_tss12_targets.append(tss12_targets)
        gene_tss12_preds.append(tss12_preds)
        gene_ids.append(gene_id)

        # compute correlation
        pcor, _ = pearsonr(tss12_targets, tss12_preds)
        gene_tss12_cors.append(pcor)

        print('%-20s  %7.4f' % (gene_id, pcor), file=table_out)

  gene_tss12_cors = np.array(gene_tss12_cors)

  # one-sample t-test of the PearsonR values against a mean of 0
  _, tp = ttest_1samp(gene_tss12_cors, 0)

  # plot PearsonR distribution
  plt.figure(figsize=(6.5,4))
  sns.distplot(gene_tss12_cors, axlabel='TSS1/TSS2 PearsonR') # , color='black')
  ax = plt.gca()
  ax.axvline(0, linestyle='--', color='black')
  xmin, xmax = ax.get_xlim()
  ymin, ymax = ax.get_ylim()
  # annotate with the t-test p-value, not the last gene's correlation p-value
  ax.text(xmax*0.98, ymax*0.92, 'p-val < %.2e' % tp, horizontalalignment='right')
  plt.tight_layout()
  plt.savefig('%s/tss12_cor.pdf' % out_base)
  plt.close()

  # save gene values for later plotting (gene_ids's in the table)
  np.save('%s/gene_tss12_targets.npy' % out_base, np.array(gene_tss12_targets, dtype='float16'))
  np.save('%s/gene_tss12_preds.npy' % out_base, np.array(gene_tss12_preds, dtype='float16'))

  # choose a range of percentiles: 11 evenly spaced ranks from the lowest to
  # the highest correlation
  cor_indexes = np.argsort(gene_tss12_cors)
  pct_indexes = np.linspace(0, len(cor_indexes)-1, 10+1).astype('int')
  with open('%s/tss12_qgenes.txt' % out_base, 'w') as genes_out:
    for qi, pct_i in enumerate(pct_indexes):
      cor_i = cor_indexes[pct_i]

      out_pdf = '%s/tss12_q%d.pdf' % (out_base, qi)
      plots.regplot(gene_tss12_targets[cor_i], gene_tss12_preds[cor_i],
                    out_pdf, poly_order=1, alpha=0.8, point_size=8,
                    square=False, figsize=(4,4),
                    x_label='log2 Experiment TSS1/TSS2',
                    y_label='log2 Prediction TSS1/TSS2',
                    title=gene_ids[cor_i])

      print(qi, gene_ids[cor_i], file=genes_out)
Beispiel #5
0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample test computation by taking uniformly spaced positions [Default: %default]'
    )
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('-s',
                      dest='scent_file',
                      help='Dimension reduction model file')
    parser.add_option('--sample',
                      dest='sample_pct',
                      default=1,
                      type='float',
                      help='Sample percentage')
    parser.add_option('--save',
                      dest='save',
                      default=False,
                      action='store_true')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--train',
                      dest='train',
                      default=False,
                      action='store_true',
                      help='Process the training set [Default: %default]')
    parser.add_option('-v',
                      dest='valid',
                      default=False,
                      action='store_true',
                      help='Process the validation set [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt predictions to predict peak calls [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(test_hdf5_file)

    if options.train:
        test_seqs = data_open['train_in']
        test_targets = data_open['train_out']
        if 'train_na' in data_open:
            test_na = data_open['train_na']

    elif options.valid:
        test_seqs = data_open['valid_in']
        test_targets = data_open['valid_out']
        test_na = None
        if 'valid_na' in data_open:
            test_na = data_open['valid_na']

    else:
        test_seqs = data_open['test_in']
        test_targets = data_open['test_out']
        test_na = None
        if 'test_na' in data_open:
            test_na = data_open['test_na']

    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print('Sampling %d sequences' % sample_n)
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job['seq_length'] = test_seqs.shape[1]
    job['seq_depth'] = test_seqs.shape[2]
    job['num_targets'] = test_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build(job)
    print('Model building time %ds' % (time.time() - t0))

    # adjust for fourier
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        test_targets_imag = data_open['test_out_imag']
        if options.valid:
            test_targets_imag = data_open['valid_out_imag']

    # adjust for factors
    if options.scent_file is not None:
        t0 = time.time()
        test_targets_full = data_open['test_out_full']
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job['fourier']:
        batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                        test_targets_imag, test_na,
                                        dr.hp.batch_size, dr.hp.target_pool)
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test(sess,
                           batcher_test,
                           rc=options.rc,
                           shifts=options.shifts,
                           mc_n=options.mc_n)

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
            np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print('SeqNN test: %ds' % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        #test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open('%s/acc.txt' % options.out_dir, 'w')
        for ti in range(len(test_r2)):
            print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                  (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti],
                   test_log_pcor[ti], target_labels[ti]),
                  file=acc_out)
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open('%s/normalization.txt' % options.out_dir, 'w')
        print('\n'.join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(dr, model, test_preds, test_targets_full,
                                  options.out_dir, options.down_sample)

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open('%s/peaks.txt' % options.out_dir, 'w')
        for ti in range(test_targets.shape[2]):
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:,
                                                   ds_indexes_targets].flatten(
                                                   ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                            ti].flatten().astype('float32')

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print('%4d  %6d  %.5f  %.5f' %
                  (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                  file=peaks_out)

        peaks_out.close()

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs')

        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        bed_set = 'test'
        if options.valid:
            bed_set = 'valid'

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         dr.hp.batch_buffer,
                         bed_set=bed_set)

        # make NA bigwig
        bw_file = '%s/tracks/na.bw' % options.out_dir
        bigwig_write(bw_file,
                     test_na,
                     options.track_bed,
                     options.genome_file,
                     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        if not os.path.isdir('%s/scatter' % options.out_dir):
            os.mkdir('%s/scatter' % options.out_dir)

        if not os.path.isdir('%s/violin' % options.out_dir):
            os.mkdir('%s/violin' % options.out_dir)

        if not os.path.isdir('%s/roc' % options.out_dir):
            os.mkdir('%s/roc' % options.out_dir)

        if not os.path.isdir('%s/pr' % options.out_dir):
            os.mkdir('%s/pr' % options.out_dir)

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:,
                                                   ds_indexes_targets].flatten(
                                                   ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                            ti].flatten().astype('float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                np.log2(test_preds_ti_flat + 1),
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()

    data_open.close()
Beispiel #6
0
def _call_peaks(values, max_qval=0.01):
    """Label each value in a flat coverage array as peak or background.

    Every value is tested against a Poisson null whose rate is the mean of
    `values`. The resulting p-values are adjusted by `ben_hoch` (defined
    elsewhere in this file; presumably Benjamini-Hochberg -- confirm) and
    thresholded.

    Args:
      values: 1D numeric array of coverage values.
      max_qval: adjusted values strictly below this are called peaks.

    Returns:
      Boolean numpy array, True where a peak was called.
    """
    null_lambda = np.mean(values)
    pvals = 1 - poisson.cdf(np.round(values) - 1, mu=null_lambda)
    qvals = np.array(ben_hoch(pvals))
    return qvals < max_qval


def main():
    """Test a trained SeqNN model on TFRecord data and write accuracy reports.

    Command line: [options] <params_file> <model_file> <data_dir>

    Always written to the output directory: acc.txt (per-target loss, R2 and
    Pearson correlations) and normalization.txt (per-target mean prediction
    divided by the median of those means). Optional outputs:
      --save   preds.npy and targets.npy
      --peaks  peaks.txt with per-target AUROC/AUPRC against called peaks
      -t       tracks/t<i>_true.bw and tracks/t<i>_preds.bw
      --ai     scatter/, violin/, roc/ and pr/ PDFs for the listed targets
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy scatter plots.'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample by taking uniformly spaced positions [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/tutorials/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option(
        '--save',
        dest='save',
        default=False,
        action='store_true',
        help='Save targets and predictions numpy arrays [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--tfr',
        dest='tfr_pattern',
        default='test-*.tfr',
        help='TFR pattern string appended to data_dir [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt preds to peak calls [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error() prints the usage and exits
        parser.error(
            'Must provide parameters file, model file, and data directory')
    params_file, model_file, data_dir = args

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    # read targets
    targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_table(targets_file)

    # read model parameters
    job = params.read_job_params(params_file)

    # construct data ops
    tfr_pattern_path = '%s/tfrecords/%s' % (data_dir, options.tfr_pattern)
    data_ops, test_init_op = make_data_ops(job, tfr_pattern_path)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job,
                              data_ops,
                              ensemble_rc=options.rc,
                              ensemble_shifts=options.shifts)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        sess.run(test_init_op)
        test_acc = model.test_tfr(sess)

        # Keep references to both arrays: test_acc is deleted below, but the
        # peak, BigWig and plotting sections still need predictions AND
        # targets.
        # NOTE(review): the target indexing below adds a
        # batch_buffer // target_pool offset, which assumes the targets span
        # the buffered sequence. Confirm that test_acc.targets is not already
        # cropped to the prediction window.
        test_preds = test_acc.preds
        test_targets = test_acc.targets
        print('SeqNN test: %ds' % (time.time() - t0))

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_preds)
            np.save('%s/targets.npy' % options.out_dir, test_targets)

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        # spearmanr is skipped: too slow; mostly driven by low values
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())

        # per-target accuracy table
        with open('%s/acc.txt' % options.out_dir, 'w') as acc_out:
            for ti in range(len(test_r2)):
                print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                      (ti, test_acc.target_losses[ti], test_r2[ti],
                       test_pcor[ti], test_log_pcor[ti],
                       targets_df.description.iloc[ti]),
                      file=acc_out)

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        with open('%s/normalization.txt' % options.out_dir, 'w') as norm_out:
            print('\n'.join([str(tu) for tu in target_means]), file=norm_out)

        # clean up
        del test_acc

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer //
                                                 model.hp.target_pool)

        aurocs = []
        auprcs = []

        with open('%s/peaks.txt' % options.out_dir, 'w') as peaks_out:
            for ti in range(test_targets.shape[2]):
                test_targets_ti = test_targets[:, :, ti]

                # subset and flatten
                test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets]
                test_targets_ti_flat = test_targets_ti_flat.flatten().astype(
                    'float32')
                test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti]
                test_preds_ti_flat = test_preds_ti_flat.flatten().astype(
                    'float32')

                # call peaks
                test_targets_peaks = _call_peaks(test_targets_ti_flat)
                num_peaks = test_targets_peaks.sum()

                if num_peaks == 0:
                    # no positives: record chance-level placeholders
                    aurocs.append(0.5)
                    auprcs.append(0)

                else:
                    # compute prediction accuracy
                    aurocs.append(
                        roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                    auprcs.append(
                        average_precision_score(test_targets_peaks,
                                                test_preds_ti_flat))

                print('%4d  %6d  %.5f  %.5f' %
                      (ti, num_peaks, aurocs[-1], auprcs[-1]),
                      file=peaks_out)

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs')

        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        # This parser defines no option with dest='valid', so reading
        # options.valid directly raises AttributeError; fall back to the test
        # set unless such an option is added.
        bed_set = 'test'
        if getattr(options, 'valid', False):
            bed_set = 'valid'

        for ti in track_indexes:
            test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         model.hp.batch_buffer,
                         bed_set=bed_set)

        # No NA bigwig is written: this function has no NA mask in scope.

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        for sub_dir in ('scatter', 'violin', 'roc', 'pr'):
            plot_dir = '%s/%s' % (options.out_dir, sub_dir)
            if not os.path.isdir(plot_dir):
                os.mkdir(plot_dir)

        # sample every few bins (adjust to plot the # points I want);
        # identical for every target, so computed once
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer //
                                                 model.hp.target_pool)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets]
            test_targets_ti_flat = test_targets_ti_flat.flatten().astype(
                'float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds, ti]
            test_preds_ti_flat = test_preds_ti_flat.flatten().astype(
                'float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks
            test_targets_peaks = _call_peaks(test_targets_ti_flat)
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                test_preds_ti_log,
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # roc_auc_score raises ValueError when only one class is present,
            # so skip the classification plots rather than abort the run.
            num_peaks = test_targets_peaks.sum()
            if num_peaks == 0 or num_peaks == test_targets_peaks.size:
                print('Target %d has only one peak class; skipping ROC/PR.' %
                      ti)
                continue

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            # chance diagonal
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            # baseline: fraction of bins called as peaks
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()
Beispiel #7
0
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <test_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "--ai",
        dest="accuracy_indexes",
        help=
        "Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values",
    )
    parser.add_option(
        "--clip",
        dest="target_clip",
        default=None,
        type="float",
        help=
        "Clip targets and predictions to a maximum value [Default: %default]",
    )
    parser.add_option(
        "-d",
        dest="down_sample",
        default=1,
        type="int",
        help=
        "Down sample test computation by taking uniformly spaced positions [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome length information [Default: %default]",
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "--peak",
        "--peaks",
        dest="peaks",
        default=False,
        action="store_true",
        help="Compute expensive peak accuracy [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option("-s",
                      dest="scent_file",
                      help="Dimension reduction model file")
    parser.add_option("--sample",
                      dest="sample_pct",
                      default=1,
                      type="float",
                      help="Sample percentage")
    parser.add_option("--save",
                      dest="save",
                      default=False,
                      action="store_true")
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="track_bed",
        help="BED file describing regions so we can output BigWig tracks",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "--train",
        dest="train",
        default=False,
        action="store_true",
        help="Process the training set [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="valid",
        default=False,
        action="store_true",
        help="Process the validation set [Default: %default]",
    )
    parser.add_option(
        "-w",
        dest="pool_width",
        default=1,
        type="int",
        help=
        "Max pool width for regressing nt predictions to predict peak calls [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and test data HDF5")
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(test_hdf5_file)

    if options.train:
        test_seqs = data_open["train_in"]
        test_targets = data_open["train_out"]
        if "train_na" in data_open:
            test_na = data_open["train_na"]

    elif options.valid:
        test_seqs = data_open["valid_in"]
        test_targets = data_open["valid_out"]
        test_na = None
        if "valid_na" in data_open:
            test_na = data_open["valid_na"]

    else:
        test_seqs = data_open["test_in"]
        test_targets = data_open["test_out"]
        test_na = None
        if "test_na" in data_open:
            test_na = data_open["test_na"]

    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print("Sampling %d sequences" % sample_n)
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode("UTF-8") for tl in data_open["target_labels"]]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)
    job["seq_length"] = test_seqs.shape[1]
    job["seq_depth"] = test_seqs.shape[2]
    job["num_targets"] = test_targets.shape[2]
    job["target_pool"] = int(np.array(data_open.get("pool_width", 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build_feed(job)
    print("Model building time %ds" % (time.time() - t0))

    # adjust for fourier
    job["fourier"] = "train_out_imag" in data_open
    if job["fourier"]:
        test_targets_imag = data_open["test_out_imag"]
        if options.valid:
            test_targets_imag = data_open["valid_out_imag"]

    # adjust for factors
    if options.scent_file is not None:
        t0 = time.time()
        test_targets_full = data_open["test_out_full"]
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job["fourier"]:
        batcher_test = batcher.BatcherF(
            test_seqs,
            test_targets,
            test_targets_imag,
            test_na,
            dr.hp.batch_size,
            dr.hp.target_pool,
        )
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test_h5_manual(sess,
                                     batcher_test,
                                     rc=options.rc,
                                     shifts=options.shifts,
                                     mc_n=options.mc_n)

        if options.save:
            np.save("%s/preds.npy" % options.out_dir, test_acc.preds)
            np.save("%s/targets.npy" % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print("SeqNN test: %ds" % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        # test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print("Compute stats: %ds" % (time.time() - t0))

        # print
        print("Test Loss:         %7.5f" % test_acc.loss)
        print("Test R2:           %7.5f" % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print("Test PearsonR:     %7.5f" % test_pcor.mean())
        print("Test log PearsonR: %7.5f" % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open("%s/acc.txt" % options.out_dir, "w")
        for ti in range(len(test_r2)):
            print(
                "%4d  %7.5f  %.5f  %.5f  %.5f  %s" % (
                    ti,
                    test_acc.target_losses[ti],
                    test_r2[ti],
                    test_pcor[ti],
                    test_log_pcor[ti],
                    target_labels[ti],
                ),
                file=acc_out,
            )
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype="float64")
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open("%s/normalization.txt" % options.out_dir, "w")
        print("\n".join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(
                dr,
                model,
                test_preds,
                test_targets_full,
                options.out_dir,
                options.down_sample,
            )

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open("%s/peaks.txt" % options.out_dir, "w")
        for ti in range(test_targets.shape[2]):
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets].
                                    flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds,
                                             ti].flatten().astype("float32"))

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print(
                "%4d  %6d  %.5f  %.5f" %
                (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                file=peaks_out,
            )

        peaks_out.close()

        print("Test AUROC:     %7.5f" % np.mean(aurocs))
        print("Test AUPRC:     %7.5f" % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                "Must provide genome file in order to print valid BigWigs")

        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(",")
            ]

        bed_set = "test"
        if options.valid:
            bed_set = "valid"

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_targets_ti,
                options.track_bed,
                options.genome_file,
                bed_set=bed_set,
            )

            # make predictions bigwig
            bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_preds[:, :, ti],
                options.track_bed,
                options.genome_file,
                dr.hp.batch_buffer,
                bed_set=bed_set,
            )

        # make NA bigwig
        # bw_file = '%s/tracks/na.bw' % options.out_dir
        # bigwig_write(
        #     bw_file,
        #     test_na,
        #     options.track_bed,
        #     options.genome_file,
        #     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(",")
        ]

        if not os.path.isdir("%s/scatter" % options.out_dir):
            os.mkdir("%s/scatter" % options.out_dir)

        if not os.path.isdir("%s/violin" % options.out_dir):
            os.mkdir("%s/violin" % options.out_dir)

        if not os.path.isdir("%s/roc" % options.out_dir):
            os.mkdir("%s/roc" % options.out_dir)

        if not os.path.isdir("%s/pr" % options.out_dir):
            os.mkdir("%s/pr" % options.out_dir)

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets].
                                    flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds,
                                             ti].flatten().astype("float32"))

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style="ticks")
            out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti)
            plots.regplot(
                test_targets_ti_log,
                test_preds_ti_log,
                out_pdf,
                poly_order=1,
                alpha=0.3,
                sample=500,
                figsize=(6, 6),
                x_label="log2 Experiment",
                y_label="log2 Prediction",
                table=True,
            )

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, "Peak",
                                              "Background")

            # violin plot
            sns.set(font_scale=1.3, style="ticks")
            plt.figure()
            df = pd.DataFrame({
                "log2 Prediction":
                np.log2(test_preds_ti_flat + 1),
                "Experimental coverage status":
                test_targets_peaks_str,
            })
            ax = sns.violinplot(x="Experimental coverage status",
                                y="log2 Prediction",
                                data=df)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c="black",
                     linewidth=1,
                     linestyle="--",
                     alpha=0.7)
            plt.plot(fpr, tpr, c="black")
            ax = plt.gca()
            ax.set_xlabel("False positive rate")
            ax.set_ylabel("True positive rate")
            ax.text(0.99,
                    0.02,
                    "AUROC %.3f" % auroc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(
                y=test_targets_peaks.mean(),
                c="black",
                linewidth=1,
                linestyle="--",
                alpha=0.7,
            )
            plt.plot(recall, prec, c="black")
            ax = plt.gca()
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.text(0.99,
                    0.95,
                    "AUPRC %.3f" % auprc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti))
            plt.close()

    data_open.close()
Beispiel #8
0
def _call_peaks(targets_flat, fdr=0.01):
    """Call peaks in a flat vector of experimental coverage values.

    Every value is tested against a Poisson null whose rate is the mean of
    ``targets_flat``. The p-values are adjusted with ``ben_hoch`` (defined
    elsewhere in this file; presumably Benjamini-Hochberg) and positions
    whose adjusted value falls below ``fdr`` are labeled as peaks.

    Args:
      targets_flat: 1D numpy array of coverage values.
      fdr: Threshold on the adjusted p-value for calling a peak.

    Returns:
      Boolean numpy array that is True where a peak was called.
    """
    targets_lambda = np.mean(targets_flat)
    # P(X >= round(value)) for X ~ Poisson(lambda)
    pvals = 1 - poisson.cdf(np.round(targets_flat) - 1, mu=targets_lambda)
    qvals = np.array(ben_hoch(pvals))
    return qvals < fdr


def main():
    """Test a saved SeqNN model on TFRecord data and report its accuracy.

    Command line: ``[options] <params_file> <model_file> <data_dir>``.

    The model is restored inside a ``tf.Session`` and run over the records
    matching ``<data_dir>/tfrecords/<tfr_pattern>``. Results are written
    below the output directory (``-o``):

      * ``acc.txt``: per-target loss, R2, Pearson and log Pearson.
      * ``normalization.txt``: per-target mean prediction and the ratio of
        the median target mean to it.
      * ``preds.h5`` / ``targets.h5``: raw arrays, with ``--save``.
      * ``peaks.txt``: per-target peak count, AUROC and AUPRC, with
        ``--peaks``.
      * ``tracks/t<i>_true.bw`` / ``tracks/t<i>_preds.bw``: with ``-b``
        (requires ``-g``).
      * ``scatter/``, ``violin/``, ``roc/``, ``pr/`` PDFs for every index
        given to ``--ai``.

    Invalid command lines terminate the program through ``parser.error``.
    """
    usage = "usage: %prog [options] <params_file> <model_file> <data_dir>"
    parser = OptionParser(usage)
    parser.add_option(
        "--ai",
        dest="accuracy_indexes",
        help=
        "Comma-separated list of target indexes to make accuracy scatter plots.",
    )
    parser.add_option(
        "-b",
        dest="track_bed",
        help="BED file describing regions so we can output BigWig tracks",
    )
    parser.add_option(
        "--clip",
        dest="target_clip",
        default=None,
        type="float",
        help=
        "Clip targets and predictions to a maximum value [Default: %default]",
    )
    parser.add_option(
        "-d",
        dest="down_sample",
        default=1,
        type="int",
        help=
        "Down sample by taking uniformly spaced positions [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default=None,
        help="Chromosome length information [Default: %default]",
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "--peak",
        "--peaks",
        dest="peaks",
        default=False,
        action="store_true",
        help="Compute expensive peak accuracy [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option(
        "--save",
        dest="save",
        default=False,
        action="store_true",
        help="Save targets and predictions numpy arrays [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "--tfr",
        dest="tfr_pattern",
        default="test-*.tfr",
        help="TFR pattern string appended to data_dir [Default: %default]",
    )
    parser.add_option(
        "-w",
        dest="pool_width",
        default=1,
        type="int",
        help=
        "Max pool width for regressing nt preds to peak calls [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error(
            "Must provide parameters file, model file, and data directory")
    params_file, model_file, data_dir = args

    # fail fast: check BigWig prerequisites before the expensive model run
    if options.track_bed and options.genome_file is None:
        parser.error(
            "Must provide genome file in order to print valid BigWigs")

    os.makedirs(options.out_dir, exist_ok=True)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    # read targets
    if options.targets_file is None:
        options.targets_file = "%s/targets.txt" % data_dir
    targets_df = pd.read_table(options.targets_file, index_col=0)

    # read model parameters
    job = params.read_job_params(params_file)

    # construct data ops
    tfr_pattern_path = "%s/tfrecords/%s" % (data_dir, options.tfr_pattern)
    data_ops, test_init_op, test_dataseq = make_data_ops(job, tfr_pattern_path)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job,
                              data_ops,
                              ensemble_rc=options.rc,
                              ensemble_shifts=options.shifts)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        sess.run(test_init_op)
        test_acc = model.test_tfr(sess, test_dataseq)

        test_preds = test_acc.preds
        test_targets = test_acc.targets
        print("SeqNN test: %ds" % (time.time() - t0))

        if options.save:
            with h5py.File("%s/preds.h5" % options.out_dir, "w") as preds_h5:
                preds_h5.create_dataset("preds", data=test_preds)
            with h5py.File("%s/targets.h5" % options.out_dir,
                           "w") as targets_h5:
                targets_h5.create_dataset("targets", data=test_targets)

        # compute stats
        # (Spearman correlation is skipped: too slow and mostly driven by
        # low values)
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        print("Compute stats: %ds" % (time.time() - t0))

        # print
        print("Test Loss:         %7.5f" % test_acc.loss)
        print("Test R2:           %7.5f" % test_r2.mean())
        print("Test PearsonR:     %7.5f" % test_pcor.mean())
        print("Test log PearsonR: %7.5f" % test_log_pcor.mean())

        # per-target accuracy table
        with open("%s/acc.txt" % options.out_dir, "w") as acc_out:
            for ti in range(len(test_r2)):
                print(
                    "%4d  %7.5f  %.5f  %.5f  %.5f  %10s  %s" % (
                        ti,
                        test_acc.target_losses[ti],
                        test_r2[ti],
                        test_pcor[ti],
                        test_log_pcor[ti],
                        targets_df.identifier.iloc[ti],
                        targets_df.description.iloc[ti],
                    ),
                    file=acc_out,
                )

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype="float64")
        target_means_median = np.median(target_means)
        with open("%s/normalization.txt" % options.out_dir, "w") as norm_out:
            for ti, target_mean in enumerate(target_means):
                print(
                    ti,
                    target_mean,
                    target_means_median / target_mean,
                    file=norm_out,
                )

        # clean up
        del test_acc

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes = np.arange(0, test_preds.shape[1], 8)

        aurocs = []
        auprcs = []

        with open("%s/peaks.txt" % options.out_dir, "w") as peaks_out:
            for ti in range(test_targets.shape[2]):
                test_targets_ti = test_targets[:, :, ti]

                # subset and flatten
                test_targets_ti_flat = (
                    test_targets_ti[:, ds_indexes].flatten().astype("float32"))
                test_preds_ti_flat = (
                    test_preds[:, ds_indexes, ti].flatten().astype("float32"))

                # call peaks
                test_targets_peaks = _call_peaks(test_targets_ti_flat)

                if test_targets_peaks.sum() == 0:
                    # no positives: the metrics are undefined, so report
                    # chance-level placeholders
                    aurocs.append(0.5)
                    auprcs.append(0)

                else:
                    # compute prediction accuracy
                    aurocs.append(
                        roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                    auprcs.append(
                        average_precision_score(test_targets_peaks,
                                                test_preds_ti_flat))

                print(
                    "%4d  %6d  %.5f  %.5f" %
                    (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                    file=peaks_out,
                )

        print("Test AUROC:     %7.5f" % np.mean(aurocs))
        print("Test AUPRC:     %7.5f" % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    # (the genome file requirement was already validated after parsing)
    if options.track_bed:
        os.makedirs("%s/tracks" % options.out_dir, exist_ok=True)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(",")
            ]

        for ti in track_indexes:
            test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            # buffer unnecessary, but there are overlaps without it
            bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_targets_ti,
                options.track_bed,
                options.genome_file,
                model.hp.batch_buffer,
            )

            # make predictions bigwig
            bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_preds[:, :, ti],
                options.track_bed,
                options.genome_file,
                model.hp.batch_buffer,
            )

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(",")
        ]

        for plot_kind in ("scatter", "violin", "roc", "pr"):
            os.makedirs("%s/%s" % (options.out_dir, plot_kind),
                        exist_ok=True)

        # sample every few bins (adjust to plot the # points I want);
        # identical for every target, so computed once
        ds_indexes = np.arange(0, test_preds.shape[1], 8)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # subset and flatten
            test_targets_ti_flat = (
                test_targets_ti[:, ds_indexes].flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes,
                                             ti].flatten().astype("float32"))

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style="ticks")
            out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti)
            plots.regplot(
                test_targets_ti_log,
                test_preds_ti_log,
                out_pdf,
                poly_order=1,
                alpha=0.3,
                sample=500,
                figsize=(6, 6),
                x_label="log2 Experiment",
                y_label="log2 Prediction",
                table=True,
            )

            ############################################
            # violin

            # call peaks
            test_targets_peaks = _call_peaks(test_targets_ti_flat)
            test_targets_peaks_str = np.where(test_targets_peaks, "Peak",
                                              "Background")

            # violin plot
            sns.set(font_scale=1.3, style="ticks")
            plt.figure()
            df = pd.DataFrame({
                "log2 Prediction": test_preds_ti_log,
                "Experimental coverage status": test_targets_peaks_str,
            })
            ax = sns.violinplot(x="Experimental coverage status",
                                y="log2 Prediction",
                                data=df)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # ROC and PR need both classes; roc_auc_score raises
            # ValueError when the labels contain only one of them
            has_both_classes = (test_targets_peaks.any()
                                and not test_targets_peaks.all())
            if not has_both_classes:
                print("Skipping ROC/PR plots for target %d: "
                      "peak calls contain a single class" % ti)
            else:
                # ROC
                plt.figure()
                fpr, tpr, _ = roc_curve(test_targets_peaks,
                                        test_preds_ti_flat)
                auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
                # chance diagonal
                plt.plot([0, 1], [0, 1],
                         c="black",
                         linewidth=1,
                         linestyle="--",
                         alpha=0.7)
                plt.plot(fpr, tpr, c="black")
                ax = plt.gca()
                ax.set_xlabel("False positive rate")
                ax.set_ylabel("True positive rate")
                ax.text(0.99,
                        0.02,
                        "AUROC %.3f" % auroc,
                        horizontalalignment="right")
                ax.grid(True, linestyle=":")
                plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti))
                plt.close()

                # PR
                plt.figure()
                prec, recall, _ = precision_recall_curve(
                    test_targets_peaks, test_preds_ti_flat)
                auprc = average_precision_score(test_targets_peaks,
                                                test_preds_ti_flat)
                # baseline: fraction of positions called as peaks
                plt.axhline(
                    y=test_targets_peaks.mean(),
                    c="black",
                    linewidth=1,
                    linestyle="--",
                    alpha=0.7,
                )
                plt.plot(recall, prec, c="black")
                ax = plt.gca()
                ax.set_xlabel("Recall")
                ax.set_ylabel("Precision")
                ax.text(0.99,
                        0.95,
                        "AUPRC %.3f" % auprc,
                        horizontalalignment="right")
                ax.grid(True, linestyle=":")
                plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti))
                plt.close()