Beispiel #1
0
def _build_option_parser():
    """Build the command-line parser for the model test script.

    Returns:
        OptionParser: parser configured with every option that `main` reads.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
    parser = OptionParser(usage)

    # Derive the default genome file from $HG19 only when the variable is
    # set. Indexing os.environ directly raised KeyError before any argument
    # was parsed, even if the user passed -g explicitly.
    hg19_dir = os.environ.get('HG19')
    if hg19_dir is None:
        default_genome_file = None
    else:
        default_genome_file = '%s/assembly/human.hg19.genome' % hg19_dir

    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample test computation by taking uniformly spaced positions [Default: %default]'
    )
    parser.add_option('-g',
                      dest='genome_file',
                      default=default_genome_file,
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('-s',
                      dest='scent_file',
                      help='Dimension reduction model file')
    parser.add_option('--sample',
                      dest='sample_pct',
                      default=1,
                      type='float',
                      help='Sample percentage')
    parser.add_option('--save',
                      dest='save',
                      default=False,
                      action='store_true')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--train',
                      dest='train',
                      default=False,
                      action='store_true',
                      help='Process the training set [Default: %default]')
    parser.add_option('-v',
                      dest='valid',
                      default=False,
                      action='store_true',
                      help='Process the validation set [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt predictions to predict peak calls [Default: %default]'
    )
    return parser


def _ensure_dir(path):
    """Create directory `path` (with missing parents) unless it exists."""
    if not os.path.isdir(path):
        os.makedirs(path)


def _call_peaks(targets_flat):
    """Return a boolean mask of the entries of `targets_flat` called as peaks.

    Every value is scored against a Poisson distribution whose rate is the
    mean of `targets_flat`; the resulting p-values are passed through
    `ben_hoch`, and entries whose adjusted value is below 0.01 are peaks.
    """
    targets_lambda = np.mean(targets_flat)
    # upper tail P(X >= round(x)) under the Poisson background
    targets_pvals = 1 - poisson.cdf(np.round(targets_flat) - 1,
                                    mu=targets_lambda)
    targets_qvals = np.array(ben_hoch(targets_pvals))
    return targets_qvals < 0.01


def _evaluate_model(options, params_file, model_file, data_open):
    """Test a saved model on one split of an open HDF5 file and write reports.

    Writes acc.txt and normalization.txt into `options.out_dir`, plus the
    optional prediction dumps, peak statistics, BigWig tracks and accuracy
    plots selected by `options`.

    Args:
        options: parsed option values from `_build_option_parser`, with
            `options.shifts` already converted to a list of ints.
        params_file: path handed to `params.read_job_params`.
        model_file: checkpoint path handed to `saver.restore`.
        data_open: open HDF5 file providing `<split>_in`, `<split>_out`,
            `target_labels` and the optional auxiliary datasets.
    """
    #######################################################
    # load data
    #######################################################
    # --train takes precedence over -v; the test split is the fallback
    if options.train:
        split = 'train'
    elif options.valid:
        split = 'valid'
    else:
        split = 'test'

    test_seqs = data_open['%s_in' % split]
    test_targets = data_open['%s_out' % split]

    # The NA dataset is optional. Default to None for every split; the
    # training branch used to leave the name undefined when 'train_na' was
    # missing, which crashed with NameError further down.
    na_key = '%s_na' % split
    test_na = data_open[na_key] if na_key in data_open else None

    sample_indexes = None
    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print('Sampling %d sequences' % sample_n)
        # kept sorted because the list is used to index HDF5 datasets
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job['seq_length'] = test_seqs.shape[1]
    job['seq_depth'] = test_seqs.shape[2]
    job['num_targets'] = test_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build(job)
    print('Model building time %ds' % (time.time() - t0))

    # adjust for fourier
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        # read the imaginary targets of the split being evaluated (the
        # training split used to get the test split's values) and keep them
        # aligned with any sub-sampling applied above
        test_targets_imag = data_open['%s_out_imag' % split]
        if sample_indexes is not None:
            test_targets_imag = test_targets_imag[sample_indexes]

    # adjust for factors
    test_targets_full = None
    model = None
    if options.scent_file is not None:
        # NOTE(review): the full targets are always read from the test split,
        # also with --train / -v -- confirm this is intended.
        test_targets_full = data_open['test_out_full']
        if sample_indexes is not None:
            # keep rows aligned with the sub-sampled predictions
            test_targets_full = test_targets_full[sample_indexes]
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job['fourier']:
        batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                        test_targets_imag, test_na,
                                        dr.hp.batch_size, dr.hp.target_pool)
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test(sess,
                           batcher_test,
                           rc=options.rc,
                           shifts=options.shifts,
                           mc_n=options.mc_n)

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
            np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print('SeqNN test: %ds' % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())

        # per-target table; the context manager closes the file on errors too
        with open('%s/acc.txt' % options.out_dir, 'w') as acc_out:
            for ti in range(len(test_r2)):
                print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                      (ti, test_acc.target_losses[ti], test_r2[ti],
                       test_pcor[ti], test_log_pcor[ti], target_labels[ti]),
                      file=acc_out)

        # print normalization factors: per-target mean prediction divided by
        # the median of those means
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        with open('%s/normalization.txt' % options.out_dir, 'w') as norm_out:
            print('\n'.join([str(tu) for tu in target_means]), file=norm_out)

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(dr, model, test_preds, test_targets_full,
                                  options.out_dir, options.down_sample)

    # true values the per-target sections below compare against
    if options.scent_file is not None:
        targets_source = test_targets_full
    else:
        targets_source = test_targets

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        with open('%s/peaks.txt' % options.out_dir, 'w') as peaks_out:
            for ti in range(test_targets.shape[2]):
                test_targets_ti = targets_source[:, :, ti]

                # subset and flatten
                test_targets_ti_flat = test_targets_ti[
                    :, ds_indexes_targets].flatten().astype('float32')
                test_preds_ti_flat = test_preds[
                    :, ds_indexes_preds, ti].flatten().astype('float32')

                # call peaks
                test_targets_peaks = _call_peaks(test_targets_ti_flat)

                if test_targets_peaks.sum() == 0:
                    # no positives: record placeholder scores
                    aurocs.append(0.5)
                    auprcs.append(0)
                else:
                    # compute prediction accuracy
                    aurocs.append(
                        roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                    auprcs.append(
                        average_precision_score(test_targets_peaks,
                                                test_preds_ti_flat))

                print('%4d  %6d  %.5f  %.5f' %
                      (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                      file=peaks_out)

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        _ensure_dir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        # NOTE(review): --train still labels the regions as 'test' here;
        # confirm which bed_set values bigwig_write accepts.
        bed_set = 'test'
        if options.valid:
            bed_set = 'valid'

        for ti in track_indexes:
            test_targets_ti = targets_source[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         dr.hp.batch_buffer,
                         bed_set=bed_set)

        # make NA bigwig, only when the file actually carries NA data
        if test_na is not None:
            bw_file = '%s/tracks/na.bw' % options.out_dir
            bigwig_write(bw_file,
                         test_na,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        for plot_dir in ('scatter', 'violin', 'roc', 'pr'):
            _ensure_dir('%s/%s' % (options.out_dir, plot_dir))

        # sample every few bins (adjust to plot the # points I want);
        # identical for every target, so computed once outside the loop
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        for ti in accuracy_indexes:
            test_targets_ti = targets_source[:, :, ti]

            ############################################
            # scatter

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[
                :, ds_indexes_targets].flatten().astype('float32')
            test_preds_ti_flat = test_preds[
                :, ds_indexes_preds, ti].flatten().astype('float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks
            test_targets_peaks = _call_peaks(test_targets_ti_flat)
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction': test_preds_ti_log,
                'Experimental coverage status': test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            # dashed diagonal for reference
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            # dashed line at the fraction of positions called as peaks
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()


def main():
    """Parse the command line, then test a saved model on an HDF5 dataset.

    Expects three positional arguments: the parameters file, the model
    checkpoint and the HDF5 data file. Exits through `parser.error` when the
    arguments are invalid.
    """
    parser = _build_option_parser()
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    params_file, model_file, test_hdf5_file = args

    # validate before any expensive work: BigWig output needs a genome file,
    # and the default is None when $HG19 is unset
    if options.track_bed and options.genome_file is None:
        parser.error(
            'Must provide genome file in order to print valid BigWigs')

    _ensure_dir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    # open read-only (the data file is never written) and make sure the
    # handle is released even when evaluation fails
    data_open = h5py.File(test_hdf5_file, 'r')
    try:
        _evaluate_model(options, params_file, model_file, data_open)
    finally:
        data_open.close()
Beispiel #2
0
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <test_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "--ai",
        dest="accuracy_indexes",
        help=
        "Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values",
    )
    parser.add_option(
        "--clip",
        dest="target_clip",
        default=None,
        type="float",
        help=
        "Clip targets and predictions to a maximum value [Default: %default]",
    )
    parser.add_option(
        "-d",
        dest="down_sample",
        default=1,
        type="int",
        help=
        "Down sample test computation by taking uniformly spaced positions [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome length information [Default: %default]",
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "--peak",
        "--peaks",
        dest="peaks",
        default=False,
        action="store_true",
        help="Compute expensive peak accuracy [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option("-s",
                      dest="scent_file",
                      help="Dimension reduction model file")
    parser.add_option("--sample",
                      dest="sample_pct",
                      default=1,
                      type="float",
                      help="Sample percentage")
    parser.add_option("--save",
                      dest="save",
                      default=False,
                      action="store_true")
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="track_bed",
        help="BED file describing regions so we can output BigWig tracks",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "--train",
        dest="train",
        default=False,
        action="store_true",
        help="Process the training set [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="valid",
        default=False,
        action="store_true",
        help="Process the validation set [Default: %default]",
    )
    parser.add_option(
        "-w",
        dest="pool_width",
        default=1,
        type="int",
        help=
        "Max pool width for regressing nt predictions to predict peak calls [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and test data HDF5")
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(test_hdf5_file)

    if options.train:
        test_seqs = data_open["train_in"]
        test_targets = data_open["train_out"]
        if "train_na" in data_open:
            test_na = data_open["train_na"]

    elif options.valid:
        test_seqs = data_open["valid_in"]
        test_targets = data_open["valid_out"]
        test_na = None
        if "valid_na" in data_open:
            test_na = data_open["valid_na"]

    else:
        test_seqs = data_open["test_in"]
        test_targets = data_open["test_out"]
        test_na = None
        if "test_na" in data_open:
            test_na = data_open["test_na"]

    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print("Sampling %d sequences" % sample_n)
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode("UTF-8") for tl in data_open["target_labels"]]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)
    job["seq_length"] = test_seqs.shape[1]
    job["seq_depth"] = test_seqs.shape[2]
    job["num_targets"] = test_targets.shape[2]
    job["target_pool"] = int(np.array(data_open.get("pool_width", 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build_feed(job)
    print("Model building time %ds" % (time.time() - t0))

    # adjust for fourier
    job["fourier"] = "train_out_imag" in data_open
    if job["fourier"]:
        test_targets_imag = data_open["test_out_imag"]
        if options.valid:
            test_targets_imag = data_open["valid_out_imag"]

    # adjust for factors
    if options.scent_file is not None:
        t0 = time.time()
        test_targets_full = data_open["test_out_full"]
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job["fourier"]:
        batcher_test = batcher.BatcherF(
            test_seqs,
            test_targets,
            test_targets_imag,
            test_na,
            dr.hp.batch_size,
            dr.hp.target_pool,
        )
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test_h5_manual(sess,
                                     batcher_test,
                                     rc=options.rc,
                                     shifts=options.shifts,
                                     mc_n=options.mc_n)

        if options.save:
            np.save("%s/preds.npy" % options.out_dir, test_acc.preds)
            np.save("%s/targets.npy" % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print("SeqNN test: %ds" % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        # test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print("Compute stats: %ds" % (time.time() - t0))

        # print
        print("Test Loss:         %7.5f" % test_acc.loss)
        print("Test R2:           %7.5f" % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print("Test PearsonR:     %7.5f" % test_pcor.mean())
        print("Test log PearsonR: %7.5f" % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open("%s/acc.txt" % options.out_dir, "w")
        for ti in range(len(test_r2)):
            print(
                "%4d  %7.5f  %.5f  %.5f  %.5f  %s" % (
                    ti,
                    test_acc.target_losses[ti],
                    test_r2[ti],
                    test_pcor[ti],
                    test_log_pcor[ti],
                    target_labels[ti],
                ),
                file=acc_out,
            )
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype="float64")
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open("%s/normalization.txt" % options.out_dir, "w")
        print("\n".join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(
                dr,
                model,
                test_preds,
                test_targets_full,
                options.out_dir,
                options.down_sample,
            )

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open("%s/peaks.txt" % options.out_dir, "w")
        for ti in range(test_targets.shape[2]):
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets].
                                    flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds,
                                             ti].flatten().astype("float32"))

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print(
                "%4d  %6d  %.5f  %.5f" %
                (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                file=peaks_out,
            )

        peaks_out.close()

        print("Test AUROC:     %7.5f" % np.mean(aurocs))
        print("Test AUPRC:     %7.5f" % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                "Must provide genome file in order to print valid BigWigs")

        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(",")
            ]

        bed_set = "test"
        if options.valid:
            bed_set = "valid"

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_targets_ti,
                options.track_bed,
                options.genome_file,
                bed_set=bed_set,
            )

            # make predictions bigwig
            bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_preds[:, :, ti],
                options.track_bed,
                options.genome_file,
                dr.hp.batch_buffer,
                bed_set=bed_set,
            )

        # make NA bigwig
        # bw_file = '%s/tracks/na.bw' % options.out_dir
        # bigwig_write(
        #     bw_file,
        #     test_na,
        #     options.track_bed,
        #     options.genome_file,
        #     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(",")
        ]

        if not os.path.isdir("%s/scatter" % options.out_dir):
            os.mkdir("%s/scatter" % options.out_dir)

        if not os.path.isdir("%s/violin" % options.out_dir):
            os.mkdir("%s/violin" % options.out_dir)

        if not os.path.isdir("%s/roc" % options.out_dir):
            os.mkdir("%s/roc" % options.out_dir)

        if not os.path.isdir("%s/pr" % options.out_dir):
            os.mkdir("%s/pr" % options.out_dir)

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets].
                                    flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds,
                                             ti].flatten().astype("float32"))

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style="ticks")
            out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti)
            plots.regplot(
                test_targets_ti_log,
                test_preds_ti_log,
                out_pdf,
                poly_order=1,
                alpha=0.3,
                sample=500,
                figsize=(6, 6),
                x_label="log2 Experiment",
                y_label="log2 Prediction",
                table=True,
            )

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, "Peak",
                                              "Background")

            # violin plot
            sns.set(font_scale=1.3, style="ticks")
            plt.figure()
            df = pd.DataFrame({
                "log2 Prediction":
                np.log2(test_preds_ti_flat + 1),
                "Experimental coverage status":
                test_targets_peaks_str,
            })
            ax = sns.violinplot(x="Experimental coverage status",
                                y="log2 Prediction",
                                data=df)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c="black",
                     linewidth=1,
                     linestyle="--",
                     alpha=0.7)
            plt.plot(fpr, tpr, c="black")
            ax = plt.gca()
            ax.set_xlabel("False positive rate")
            ax.set_ylabel("True positive rate")
            ax.text(0.99,
                    0.02,
                    "AUROC %.3f" % auroc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(
                y=test_targets_peaks.mean(),
                c="black",
                linewidth=1,
                linestyle="--",
                alpha=0.7,
            )
            plt.plot(recall, prec, c="black")
            ax = plt.gca()
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.text(0.99,
                    0.95,
                    "AUPRC %.3f" % auprc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti))
            plt.close()

    data_open.close()
Beispiel #3
0
def run(params_file, data_file, train_epochs, train_epoch_batches,
        test_epoch_batches):
    """Train a SeqNN model, validating and checkpointing after each epoch.

    Training continues until `train_epochs` epochs have completed (when it is
    not None) or until the validation loss has failed to improve for
    FLAGS.early_stop consecutive epochs, whichever happens first. The model
    with the lowest validation loss so far is saved to
    '<FLAGS.logdir>/model_best.tf'.

    Args:
        params_file: job parameters file, read by params.read_job_params.
        data_file: HDF5 file providing 'train_in', 'train_out', 'valid_in'
            and 'valid_out', and optionally 'train_na', 'valid_na',
            'pool_width', 'train_out_imag' and 'valid_out_imag'.
        train_epochs: maximum number of epochs to train, or None for no
            limit (training then ends only through early stopping).
        train_epoch_batches: forwarded to model.train_epoch as epoch_batches.
        test_epoch_batches: forwarded to model.test as test_batches.
    """

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(data_file)

    # the batchers are handed the HDF5 datasets themselves, so the file must
    # stay open for the whole of training; close it on every exit path
    try:
        train_seqs = data_open['train_in']
        train_targets = data_open['train_out']
        train_na = None
        if 'train_na' in data_open:
            train_na = data_open['train_na']

        valid_seqs = data_open['valid_in']
        valid_targets = data_open['valid_out']
        valid_na = None
        if 'valid_na' in data_open:
            valid_na = data_open['valid_na']

        #######################################################
        # model parameters and placeholders
        #######################################################
        job = params.read_job_params(params_file)

        # dimensions come from the data rather than the parameters file
        job['seq_length'] = train_seqs.shape[1]
        job['seq_depth'] = train_seqs.shape[2]
        job['num_targets'] = train_targets.shape[2]
        job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

        t0 = time.time()
        model = seqnn.SeqNN()
        model.build(job)
        print('Model building time %f' % (time.time() - t0))

        # adjust for fourier
        job['fourier'] = 'train_out_imag' in data_open
        if job['fourier']:
            train_targets_imag = data_open['train_out_imag']
            valid_targets_imag = data_open['valid_out_imag']

        #######################################################
        # prepare batcher
        #######################################################
        if job['fourier']:
            batcher_train = batcher.BatcherF(train_seqs,
                                             train_targets,
                                             train_targets_imag,
                                             train_na,
                                             model.hp.batch_size,
                                             model.hp.target_pool,
                                             shuffle=True)
            # read the hyper-parameters from model.hp, exactly as every
            # other batcher in this function does
            batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                             valid_targets_imag, valid_na,
                                             model.hp.batch_size,
                                             model.hp.target_pool)
        else:
            batcher_train = batcher.Batcher(train_seqs,
                                            train_targets,
                                            train_na,
                                            model.hp.batch_size,
                                            model.hp.target_pool,
                                            shuffle=True)
            batcher_valid = batcher.Batcher(valid_seqs, valid_targets,
                                            valid_na, model.hp.batch_size,
                                            model.hp.target_pool)
        print('Batcher initialized')

        #######################################################
        # train
        #######################################################
        augment_shifts = [
            int(shift) for shift in FLAGS.augment_shifts.split(',')
        ]
        ensemble_shifts = [
            int(shift) for shift in FLAGS.ensemble_shifts.split(',')
        ]

        # checkpoints
        saver = tf.train.Saver()

        config = tf.ConfigProto()
        if FLAGS.log_device_placement:
            config.log_device_placement = True
        with tf.Session(config=config) as sess:
            t0 = time.time()

            # set seed
            tf.set_random_seed(FLAGS.seed)

            if FLAGS.logdir:
                train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                     sess.graph)
            else:
                train_writer = None

            if FLAGS.restart:
                # load variables into session
                saver.restore(sess, FLAGS.restart)
            else:
                # initialize variables
                print('Initializing...')
                sess.run(tf.global_variables_initializer())
                print('Initialization time %f' % (time.time() - t0))

            best_loss = None
            # consecutive epochs without a validation loss improvement
            early_stop_i = 0

            epoch = 0
            while ((train_epochs is None or epoch < train_epochs)
                   and early_stop_i < FLAGS.early_stop):
                t0 = time.time()

                # alternate forward and reverse batches
                fwdrc = True
                if FLAGS.augment_rc and epoch % 2 == 1:
                    fwdrc = False

                # cycle shifts
                shift_i = epoch % len(augment_shifts)

                # train
                train_loss, steps = model.train_epoch(
                    sess,
                    batcher_train,
                    fwdrc=fwdrc,
                    shift=augment_shifts[shift_i],
                    sum_writer=train_writer,
                    epoch_batches=train_epoch_batches,
                    no_steps=FLAGS.no_steps)

                # validate
                valid_acc = model.test(sess,
                                       batcher_valid,
                                       mc_n=FLAGS.ensemble_mc,
                                       rc=FLAGS.ensemble_rc,
                                       shifts=ensemble_shifts,
                                       test_batches=test_epoch_batches)
                valid_loss = valid_acc.loss
                valid_r2 = valid_acc.r2().mean()
                del valid_acc

                best_str = ''
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    best_str = ', best!'
                    early_stop_i = 0
                    saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
                else:
                    early_stop_i += 1

                # measure time, switching units as the epoch gets longer
                et = time.time() - t0
                if et < 600:
                    time_str = '%3ds' % et
                elif et < 6000:
                    time_str = '%3dm' % (et / 60)
                else:
                    time_str = '%3.1fh' % (et / 3600)

                # print update
                print(
                    'Epoch: %3d,  Steps: %7d,  Train loss: %7.5f,  Valid loss: %7.5f,  Valid R2: %7.5f,  Time: %s%s'
                    % (epoch + 1, steps, train_loss, valid_loss, valid_r2,
                       time_str, best_str))
                sys.stdout.flush()

                if FLAGS.check_all:
                    saver.save(sess,
                               '%s/model_check%d.tf' % (FLAGS.logdir, epoch))

                # update epoch
                epoch += 1

            if train_writer is not None:
                train_writer.close()
    finally:
        data_open.close()
Beispiel #4
0
def run(params_file, data_file, num_train_epochs):
    """Train a SeqNN model with early stopping and learning rate drops.

    Runs at most `num_train_epochs` epochs. Training ends early once the
    validation loss has failed to improve for job['early_stop'] consecutive
    epochs and at least FLAGS.min_epochs epochs have run. The model with the
    lowest validation loss so far is saved to
    '<FLAGS.logdir>/<FLAGS.save_prefix>_best.tf'.

    Args:
        params_file: job parameters file, read by dna_io.read_job_params.
        data_file: HDF5 file providing 'train_in', 'train_out', 'valid_in'
            and 'valid_out', and optionally 'train_na', 'valid_na',
            'pool_width', 'train_out_imag' and 'valid_out_imag'.
        num_train_epochs: maximum number of epochs to train.
    """
    shifts = [int(shift) for shift in FLAGS.shifts.split(',')]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(data_file)

    # the batchers are handed the HDF5 datasets themselves, so the file must
    # stay open for the whole of training; close it on every exit path
    try:
        train_seqs = data_open['train_in']
        train_targets = data_open['train_out']
        train_na = None
        if 'train_na' in data_open:
            train_na = data_open['train_na']

        valid_seqs = data_open['valid_in']
        valid_targets = data_open['valid_out']
        valid_na = None
        if 'valid_na' in data_open:
            valid_na = data_open['valid_na']

        #######################################################
        # model parameters and placeholders
        #######################################################
        job = dna_io.read_job_params(params_file)

        # dimensions come from the data rather than the parameters file
        job['batch_length'] = train_seqs.shape[1]
        job['seq_depth'] = train_seqs.shape[2]
        job['num_targets'] = train_targets.shape[2]
        job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))
        # fall back to defaults when the parameters file is silent
        job['early_stop'] = job.get('early_stop', 16)
        job['rate_drop'] = job.get('rate_drop', 3)

        t0 = time.time()
        dr = seqnn.SeqNN()
        dr.build(job)
        print('Model building time %f' % (time.time() - t0))

        # adjust for fourier
        job['fourier'] = 'train_out_imag' in data_open
        if job['fourier']:
            train_targets_imag = data_open['train_out_imag']
            valid_targets_imag = data_open['valid_out_imag']

        #######################################################
        # train
        #######################################################
        # initialize batcher
        if job['fourier']:
            batcher_train = batcher.BatcherF(train_seqs,
                                             train_targets,
                                             train_targets_imag,
                                             train_na,
                                             dr.batch_size,
                                             dr.target_pool,
                                             shuffle=True)
            batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                             valid_targets_imag, valid_na,
                                             dr.batch_size, dr.target_pool)
        else:
            batcher_train = batcher.Batcher(train_seqs,
                                            train_targets,
                                            train_na,
                                            dr.batch_size,
                                            dr.target_pool,
                                            shuffle=True)
            batcher_valid = batcher.Batcher(valid_seqs, valid_targets,
                                            valid_na, dr.batch_size,
                                            dr.target_pool)
        print('Batcher initialized')

        # checkpoints
        saver = tf.train.Saver()

        config = tf.ConfigProto()
        if FLAGS.log_device_placement:
            config.log_device_placement = True
        with tf.Session(config=config) as sess:
            t0 = time.time()

            # set seed
            tf.set_random_seed(FLAGS.seed)

            if FLAGS.logdir:
                train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                     sess.graph)
            else:
                train_writer = None

            if FLAGS.restart:
                # load variables into session
                saver.restore(sess, FLAGS.restart)
            else:
                # initialize variables
                print('Initializing...')
                sess.run(tf.global_variables_initializer())
                print('Initialization time %f' % (time.time() - t0))

            train_loss = None
            best_loss = None
            # consecutive epochs without a validation loss improvement
            early_stop_i = 0
            # epochs that must still pass before a rate drop is allowed
            undroppable_counter = 3
            max_drops = 8
            num_drops = 0

            for epoch in range(num_train_epochs):
                # early stopping is final: early_stop_i only changes while
                # training and epoch only grows, so no later epoch could
                # train either; leave instead of idling through the rest
                if (early_stop_i >= job['early_stop']
                        and epoch >= FLAGS.min_epochs):
                    break

                t0 = time.time()

                # save previous
                train_loss_last = train_loss

                # alternate forward and reverse batches
                fwdrc = True
                if FLAGS.rc and epoch % 2 == 1:
                    fwdrc = False

                # cycle shifts
                shift_i = epoch % len(shifts)

                # train
                train_loss = dr.train_epoch(sess, batcher_train, fwdrc,
                                            shifts[shift_i], train_writer)

                # validate
                valid_acc = dr.test(sess,
                                    batcher_valid,
                                    mc_n=FLAGS.mc_n,
                                    rc=FLAGS.rc,
                                    shifts=shifts)
                valid_loss = valid_acc.loss
                valid_r2 = valid_acc.r2().mean()
                del valid_acc

                best_str = ''
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    best_str = ', best!'
                    early_stop_i = 0
                    saver.save(
                        sess,
                        '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix))
                else:
                    early_stop_i += 1

                # measure time, switching units as the epoch gets longer
                et = time.time() - t0
                if et < 600:
                    time_str = '%3ds' % et
                elif et < 6000:
                    time_str = '%3dm' % (et / 60)
                else:
                    time_str = '%3.1fh' % (et / 3600)

                # print update; the line is finished further down so that
                # a rate drop notice can be appended to it
                print(
                    'Epoch %3d: Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s'
                    % (epoch + 1, train_loss, valid_loss, valid_r2, time_str,
                       best_str),
                    end='')

                # if training stagnant (relative train loss improvement
                # below 0.02%); the counter test comes before the division
                # so train_loss_last is only read after earlier epochs set it
                if FLAGS.learn_rate_drop and num_drops < max_drops and undroppable_counter == 0 and (
                        train_loss_last -
                        train_loss) / train_loss_last < 0.0002:
                    print(', rate drop', end='')
                    dr.drop_rate(2 / 3)
                    undroppable_counter = 1
                    num_drops += 1
                else:
                    undroppable_counter = max(0, undroppable_counter - 1)

                print('')
                sys.stdout.flush()

            if train_writer is not None:
                train_writer.close()
    finally:
        data_open.close()