Example #1
  def __getitem__(self, i):
    # acquire predictions, if needed
    if i >= self.stream_end:
      self.stream_start = self.stream_end
      self.stream_end = min(self.stream_start + self.stream_length,
                            self.seqs_1hot.shape[0])

      # subset sequences
      stream_seqs_1hot = self.seqs_1hot[self.stream_start:self.stream_end]

      # initialize batcher; use a distinct name to avoid shadowing
      # the imported batcher module (UnboundLocalError otherwise)
      stream_batcher = batcher.Batcher(
          stream_seqs_1hot, batch_size=self.model.hp.batch_size)

      # predict
      self.stream_preds = self.model.predict(self.sess, stream_batcher,
                                             rc_avg=False)

    return self.stream_preds[i - self.stream_start]
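Both of these `__getitem__` examples assume an enclosing stream class that buffers predictions one window at a time. A minimal sketch of that state, assuming only what the snippet references (the constructor itself is hypothetical):

class PredStream:
  # hypothetical constructor; attribute names are taken from the snippet
  def __init__(self, sess, model, seqs_1hot, stream_length):
    self.sess = sess                    # open tf.Session with restored weights
    self.model = model                  # built seqnn.SeqNN model
    self.seqs_1hot = seqs_1hot          # N x L x 4 one-hot sequences
    self.stream_length = stream_length  # sequences predicted per window
    self.stream_start = 0
    self.stream_end = 0                 # forces a buffer fill on first access
    self.stream_preds = None

Note that the buffer only moves forward, so indexes must be requested in increasing order.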
Example #2
  def __getitem__(self, i):
    # acquire predictions, if needed
    if i >= self.stream_end:
      self.stream_start = self.stream_end
      self.stream_end = min(self.stream_start + self.stream_length,
                            self.seqs_1hot.shape[0])

      # subset sequences
      stream_seqs_1hot = self.seqs_1hot[self.stream_start:self.stream_end]

      # initialize batcher; renamed to avoid shadowing the batcher module
      stream_batcher = batcher.Batcher(
          stream_seqs_1hot, batch_size=self.model.hp.batch_size)

      # predict
      self.stream_grads, self.stream_preds = self.model.gradients(
          self.sess, stream_batcher, layers=[0], return_preds=True)

      # take first layer
      self.stream_grads = self.stream_grads[0]

    return (self.stream_preds[i - self.stream_start],
            self.stream_grads[i - self.stream_start])
Example #3
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option("-l",
                      dest="gene_list",
                      help="Process only gene ids in the given file")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="grad_mapg",
        help="Output directory [Default: %default]",
    )
    parser.add_option("-t",
                      dest="target_indexes",
                      default=None,
                      help="Target indexes to plot")
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and genes HDF5 file")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # read in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # subset gene sequences
    genes_subset = set()
    if options.gene_list:
        for line in open(options.gene_list):
            genes_subset.add(line.rstrip())

        gene_data.subset_genes(genes_subset)
        print("Filtered to %d sequences" % gene_data.num_seqs)

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        exit(1)

    # set target indexes
    if options.target_indexes is not None:
        options.target_indexes = [
            int(ti) for ti in options.target_indexes.split(",")
        ]
        target_subset = options.target_indexes
    else:
        options.target_indexes = list(range(job["num_targets"]))
        target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    # determine latest pre-dilated layer
    cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
    dilated_mask = cnn_dilation > 1
    dilated_indexes = np.where(dilated_mask)[0]
    pre_dilated_layer = np.min(dilated_indexes)
    print("Pre-dilated layer: %d" % pre_dilated_layer)

    # build gradients ops
    t0 = time.time()
    print("Building target/position-specific gradient ops.", end="")
    model.build_grads_genes(gene_data.gene_seqs, layers=[pre_dilated_layer])
    print(" Done in %ds" % (time.time() - t0), flush=True)

    #######################################################
    # acquire gradients

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(gene_data.num_seqs):
            # initialize batcher
            batcher_si = batcher.Batcher(
                gene_data.seqs_1hot[si:si + 1],
                batch_size=model.hp.batch_size,
                pool_width=model.hp.target_pool,
            )

            # get layer representations
            t0 = time.time()
            print("Computing gradients.", end="", flush=True)
            batch_grads, batch_reprs = model.gradients_genes(
                sess, batcher_si, gene_data.gene_seqs[si:si + 1])
            print(" Done in %ds." % (time.time() - t0), flush=True)

            # only layer
            batch_reprs = batch_reprs[0]
            batch_grads = batch_grads[0]

            # G (TSSs) x T (targets) x P (seq position) x U (Units layer i)
            print("batch_grads", batch_grads.shape)
            pooled_length = batch_grads.shape[2]

            # S (sequences) x P (seq position) x U (Units layer i)
            print("batch_reprs", batch_reprs.shape)

            # write bigwigs
            t0 = time.time()
            print("Writing BigWigs.", end="", flush=True)

            # for each TSS
            for tss_i in range(batch_grads.shape[0]):
                tss = gene_data.gene_seqs[si].tss_list[tss_i]

                # for each target
                for tii in range(len(options.target_indexes)):
                    ti = options.target_indexes[tii]

                    # dot representation and gradient
                    batch_grads_score = np.multiply(
                        batch_reprs[0],
                        batch_grads[tss_i, tii, :, :]).sum(axis=1)

                    # open bigwig
                    bw_file = "%s/%s-%s_t%d.bw" % (
                        options.out_dir,
                        tss.gene_id,
                        tss.identifier,
                        ti,
                    )
                    bw_open = bigwig_open(bw_file, options.genome_file)

                    # access gene sequence information
                    seq_chrom = gene_data.gene_seqs[si].chrom
                    seq_start = gene_data.gene_seqs[si].start

                    # specify bigwig locations and values
                    bw_chroms = [seq_chrom] * pooled_length
                    bw_starts = [
                        int(seq_start + li * model.hp.target_pool)
                        for li in range(pooled_length)
                    ]
                    bw_ends = [
                        int(bws + model.hp.target_pool) for bws in bw_starts
                    ]
                    bw_values = [float(bgs) for bgs in batch_grads_score]

                    # write
                    bw_open.addEntries(bw_chroms,
                                       bw_starts,
                                       ends=bw_ends,
                                       values=bw_values)

                    # close
                    bw_open.close()

            print(" Done in %ds." % (time.time() - t0), flush=True)
            gc.collect()
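The value written per position is the dot product of the layer representation and the TSS/target gradient over the unit axis. A toy check of that reduction (the shapes here are illustrative, not the real P and U):

import numpy as np

P, U = 8, 4                    # toy positions x units
reprs = np.random.rand(P, U)   # layer representation for the sequence
grads = np.random.rand(P, U)   # gradient for one TSS/target at that layer

# elementwise product summed over units == per-position dot product
score = np.multiply(reprs, grads).sum(axis=1)
assert np.allclose(score, np.einsum('pu,pu->p', reprs, grads))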
Example #4
def main():
    usage = (
        'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
        ' <vcf_file>')
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='all_sed',
        default=False,
        action='store_true',
        help=
        'Print all variant-gene pairs, as opposed to only nonzero [Default: %default]'
    )
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sed',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=0.125,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '-r',
        dest='tss_radius',
        default=0,
        type='int',
        help=
        'Radius of bins considered to quantify TSS transcription [Default: %default]'
    )
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average the forward and reverse complement predictions when testing [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option(
        '-x',
        dest='tss_table',
        default=False,
        action='store_true',
        help='Print TSS table in addition to gene [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files, genes HDF5 file, and QTL VCF'
            ' file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #################################################################
    # read in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print('Intersecting gene sequences with SNPs...', end='')
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file,
                                         gene_data.gene_seqs,
                                         vision_p=0.5)
    print('%d sequences w/ SNPs' % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width

    if 'num_targets' not in job and gene_data.num_targets is not None:
        job['num_targets'] = gene_data.num_targets

    if 'num_targets' not in job:
        print("Must specify number of targets (num_targets) in the \
            parameters file.",
              file=sys.stderr)
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
            target_labels = [''] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split('\t')
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                   'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                   'target_label')
    if options.csv:
        sed_gene_out = open('%s/sed_gene.csv' % options.out_dir, 'w')
        print(','.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open('%s/sed_gene.txt' % options.out_dir, 'w')
        print(' '.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i],
                                          seq_snps)

                # initialize batcher
                batcher_gene = batcher.Batcher(aseqs_1hot,
                                               batch_size=model.hp.batch_size)

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    penultimate=options.penultimate,
                    tss_radius=options.tss_radius)

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1))

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                snp_tss_sed = alt_tss_preds - ref_tss_preds
                snp_tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
                                - np.log2(ref_tss_preds + options.log_pseudo)

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius)
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i],
                                                gene_seq.tss_list,
                                                options.tss_radius)
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
                            - np.log2(ref_gene_preds + options.log_pseudo)

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist)
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                        snp_tss_sed[snp_i, tss_i, ti],
                                        0, atol=1e-4):

                                    # print
                                    cols = (snp.rsid,
                                            bvcf.cap_allele(snp.ref_allele),
                                            bvcf.cap_allele(snp.alt_alleles[0]),
                                            tss.identifier, snp_dist,
                                            ref_tss_preds[tss_i, ti],
                                            alt_tss_preds[snp_i, tss_i, ti],
                                            snp_tss_sed[snp_i, tss_i, ti],
                                            snp_tss_ser[snp_i, tss_i, ti], ti,
                                            target_ids[ti], target_labels[ti])
                                    if options.csv:
                                        print(','.join([str(c) for c in cols]),
                                              file=sed_tss_out)
                                    else:
                                        print(
                                            '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                            % cols,
                                            file=sed_tss_out)

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = '%s_multi' % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                    gene_sed[snp_i, gi, ti], 0, atol=1e-4):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str, snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti], ti,
                                    target_ids[ti], target_labels[ti]
                                ]
                                if options.csv:
                                    print(','.join([str(c) for c in cols]),
                                          file=sed_gene_out)
                                else:
                                    print(
                                        '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                        % tuple(cols),
                                        file=sed_gene_out)

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
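SED is a raw difference between allele predictions, and SER is a log2 ratio stabilized by a pseudocount. A toy numpy rendering of the two formulas used above:

import numpy as np

log_pseudo = 0.125                 # the --pseudo default in this script

ref_pred = np.array([2.0, 0.5])    # reference-allele gene predictions
alt_pred = np.array([2.5, 0.25])   # alternate-allele gene predictions

sed = alt_pred - ref_pred          # [0.5, -0.25]
ser = np.log2(alt_pred + log_pseudo) - np.log2(ref_pred + log_pseudo)
# ser == log2((alt + p) / (ref + p)) ~= [0.305, -0.737]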
Example #5
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=256,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-f",
        dest="genome_fasta",
        default="%s/data/hg19.fa" % os.environ["BASENJIDIR"],
        help="Genome FASTA for sequences [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "--h5",
        dest="out_h5",
        default=False,
        action="store_true",
        help="Output stats to sad.h5 [Default: %default]",
    )
    parser.add_option(
        "--local",
        dest="local",
        default=1024,
        type="int",
        help="Local SAD score [Default: %default]",
    )
    parser.add_option("-n",
                      dest="norm_file",
                      default=None,
                      help="Normalize SAD scores")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sad",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=1,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help=
        "Average forward and reverse complement predictions [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        type="str",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--stats",
        dest="sad_stats",
        default="SAD,xSAR",
        help="Comma-separated list of stats to save. [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        default=None,
        type="str",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-z",
        dest="out_zarr",
        default=False,
        action="store_true",
        help="Output stats to sad.zarr [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files and QTL VCF file")

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(",")
        ]
        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]
    options.sad_stats = options.sad_stats.split(",")

    #################################################################
    # setup model

    job = params.read_job_params(
        params_file, require=["seq_length", "num_targets", "target_pool"])

    if options.targets_file is None:
        target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
        target_labels = [""] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(
        job,
        ensemble_rc=options.rc,
        ensemble_shifts=options.shifts,
        embed_penultimate=options.penultimate,
        target_subset=target_subset,
    )
    print("Model building time %f" % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0, len(snps), options.processes + 1,
                                    dtype="int")
        snps = snps[worker_bounds[worker_index]:
                    worker_bounds[worker_index + 1]]

    num_snps = len(snps)

    #################################################################
    # setup output

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "ref_pred",
        "alt_pred",
        "sad",
        "sar",
        "geo_sad",
        "ref_lpred",
        "alt_lpred",
        "lsad",
        "lsar",
        "ref_xpred",
        "alt_xpred",
        "xsad",
        "xsar",
        "target_index",
        "target_id",
        "target_label",
    )

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open("%s/sad_table.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sad_out)
        else:
            sad_out = open("%s/sad_table.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job["seq_length"], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # initialize batcher; renamed to avoid shadowing the batcher module
            snp_batcher = batcher.Batcher(batch_1hot,
                                          batch_size=model.hp.batch_size)

            # predict
            # batch_preds = model.predict(sess, snp_batcher,
            #                 rc=options.rc, shifts=options.shifts,
            #                 penultimate=options.penultimate)
            batch_preds = model.predict_h5(sess, snp_batcher)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype="float64")

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = "%s/tracks/%s_t%d_ref.bw" % (
                        options.out_dir,
                        snp.rsid,
                        ti,
                    )
                    bigwig_write(
                        snp,
                        job["seq_length"],
                        ref_preds[:, ti],
                        model,
                        ref_bw_file,
                        options.genome_file,
                    )

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype="float64")

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum +
                                  options.log_pseudo) - np.log2(
                                      ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(
                        alt_preds.astype("float64") +
                        options.log_pseudo) - np.log2(
                            ref_preds.astype("float64") + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc +
                                      options.log_pseudo) - np.log2(
                                          ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        sad_out["SAD"][szi, :] = sad.astype("float16")
                        sad_out["xSAR"][szi, :] = np.array(
                            [
                                sar_vec[max_li[ti], ti]
                                for ti in range(num_targets)
                            ],
                            dtype="float16",
                        )
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (
                                snp.rsid,
                                bvcf.cap_allele(snp.ref_allele),
                                bvcf.cap_allele(alt_al),
                                ref_preds_sum[ti],
                                alt_preds_sum[ti],
                                sad[ti],
                                sar[ti],
                                geo_sad[ti],
                                ref_preds_loc[ti],
                                alt_preds_loc[ti],
                                sad_loc[ti],
                                sar_loc[ti],
                                ref_preds[max_li[ti], ti],
                                alt_preds[max_li[ti], ti],
                                sad_vec[max_li[ti], ti],
                                sar_vec[max_li[ti], ti],
                                ti,
                                target_ids[ti],
                                target_labels[ti],
                            )
                            if options.csv:
                                print(",".join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    "%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s"
                                    % cols,
                                    file=sad_out,
                                )

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = "%s/tracks/%s_t%d_alt.bw" % (
                            options.out_dir,
                            snp.rsid,
                            ti,
                        )
                        bigwig_write(
                            snp,
                            job["seq_length"],
                            alt_preds[:, ti],
                            model,
                            alt_bw_file,
                            options.genome_file,
                        )

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job["seq_length"],
                genome_open)

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset("percentiles", data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = "%s_pct" % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype("float16")

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype="float16")

    if not options.out_zarr:
        sad_out.close()
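The percentile grid built at the end concentrates resolution in the distribution tails, where the largest variant effects sit. A standalone sketch of the same construction on toy data:

import numpy as np

d_fine, d_coarse = 0.001, 0.01
percentiles = np.concatenate([
    np.arange(d_fine, 0.1, d_fine),  # fine steps in the lower tail
    np.arange(0.1, 0.9, d_coarse),   # coarse steps through the bulk
    np.arange(0.9, 1, d_fine),       # fine steps in the upper tail
])

toy_sad = np.random.randn(1000, 3)   # variants x targets (toy data)
sad_pct = np.percentile(toy_sad, 100 * percentiles, axis=0).T
print(sad_pct.shape)                 # targets x percentile grid points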
Example #6
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option('-f', dest='figure_width',
      default=20, type='float',
      help='Figure width [Default: %default]')
  parser.add_option('--f1', dest='genome1_fasta',
      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
      help='Genome FASTA from which major allele sequences will be drawn')
  parser.add_option('--f2', dest='genome2_fasta',
      default=None,
      help='Genome FASTA from which minor allele sequences will be drawn')
  parser.add_option('-g', dest='gain',
      default=False, action='store_true',
      help='Draw a sequence logo for the gain score, too [Default: %default]')
  # parser.add_option('-k', dest='plot_k',
  #     default=None, type='int',
  #     help='Plot the top k targets at each end.')
  parser.add_option('-l', dest='satmut_len',
      default=200, type='int',
      help='Length of centered sequence to mutate [Default: %default]')
  parser.add_option('--mean', dest='mean_targets',
      default=False, action='store_true',
      help='Take the mean across targets for a single plot [Default: %default]')
  parser.add_option('-m', dest='mc_n',
      default=0, type='int',
      help='Monte carlo iterations [Default: %default]')
  parser.add_option('--min', dest='min_limit',
      default=0.01, type='float',
      help='Minimum heatmap limit [Default: %default]')
  parser.add_option('-n', dest='load_sat_npy',
      default=False, action='store_true',
      help='Load the predictions from .npy files [Default: %default]')
  parser.add_option('-o', dest='out_dir',
      default='sat_vcf',
      help='Output directory [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Ensemble forward and reverse complement predictions [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('-t', dest='targets_file',
      default=None, type='str',
      help='File specifying target indexes and labels in table format')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters and model files and VCF')
  else:
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # prep SNP sequences
  #################################################################

  # read parameters
  job = params.read_job_params(params_file, require=['seq_length', 'num_targets'])

  # load SNPs
  snps = vcf.vcf_snps(vcf_file)

  # get one hot coded input sequences
  if not options.genome2_fasta:
    seqs_1hot, seq_headers, snps, seqs = vcf.snps_seq1(
        snps, job['seq_length'], options.genome1_fasta, return_seqs=True)
  else:
    seqs_1hot, seq_headers, snps, seqs = vcf.snps2_seq1(
        snps, job['seq_length'], options.genome1_fasta,
        options.genome2_fasta, return_seqs=True)

  seqs_n = seqs_1hot.shape[0]

  #################################################################
  # setup model
  #################################################################

  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = ['']*len(target_ids)
    target_subset = None
    num_targets = job['num_targets']

  else:
    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
    target_ids = targets_df.identifier
    target_labels = targets_df.description
    target_subset = targets_df.index
    num_targets = len(target_subset)
    if num_targets == job['num_targets']:
      target_subset = None

  if not options.load_sat_npy:
    # build model
    model = seqnn.SeqNN()
    model.build_feed_sad(job, ensemble_rc=options.rc,
        ensemble_shifts=options.shifts, target_subset=target_subset)

    # initialize saver
    saver = tf.train.Saver()

  #################################################################
  # predict and process
  #################################################################

  with tf.Session() as sess:
    if not options.load_sat_npy:
      # load variables into session
      saver.restore(sess, model_file)

    for si in range(seqs_n):
      header = seq_headers[si]
      header_fs = fs_clean(header)

      print('Mutating sequence %d / %d' % (si + 1, seqs_n), flush=True)

      # write sequence
      fasta_out = open('%s/seq%d.fa' % (options.out_dir, si), 'w')
      end_len = (len(seqs[si]) - options.satmut_len) // 2
      print('>seq%d\n%s' % (si, seqs[si][end_len:-end_len]), file=fasta_out)
      fasta_out.close()

      #################################################################
      # predict modifications

      if options.load_sat_npy:
        sat_preds = np.load('%s/seq%d_preds.npy' % (options.out_dir, si))

      else:
        # supplement with saturated mutagenesis
        sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1], options.satmut_len)

        # initialize batcher
        batcher_sat = batcher.Batcher(
            sat_seqs_1hot, batch_size=model.hp.batch_size)

        # predict
        sat_preds = model.predict_h5(sess, batcher_sat)
        np.save('%s/seq%d_preds.npy' % (options.out_dir, si), sat_preds)

      if options.mean_targets:
        sat_preds = np.mean(sat_preds, axis=-1, keepdims=True)
        num_targets = 1

      #################################################################
      # compute delta, loss, and gain matrices

      # compute the matrix of prediction deltas: (4 x L_sm x T) array
      sat_delta = delta_matrix(seqs_1hot[si], sat_preds, options.satmut_len)

      # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len)
      sat_loss = sat_delta.min(axis=0)
      sat_gain = sat_delta.max(axis=0)

      ##############################################
      # plot

      for ti in range(num_targets):
        # setup plot
        sns.set(style='white', font_scale=1)
        spp = subplot_params(sat_delta.shape[1])

        if options.gain:
          plt.figure(figsize=(options.figure_width, 4))
          ax_logo_loss = plt.subplot2grid(
              (4, spp['heat_cols']), (0, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_logo_gain = plt.subplot2grid(
              (4, spp['heat_cols']), (1, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_sad = plt.subplot2grid(
              (4, spp['heat_cols']), (2, spp['sad_start']),
              colspan=spp['sad_span'])
          ax_heat = plt.subplot2grid(
              (4, spp['heat_cols']), (3, 0), colspan=spp['heat_cols'])
        else:
          plt.figure(figsize=(options.figure_width, 3))
          ax_logo_loss = plt.subplot2grid(
              (3, spp['heat_cols']), (0, spp['logo_start']),
              colspan=spp['logo_span'])
          ax_sad = plt.subplot2grid(
              (3, spp['heat_cols']), (1, spp['sad_start']),
              colspan=spp['sad_span'])
          ax_heat = plt.subplot2grid(
              (3, spp['heat_cols']), (2, 0), colspan=spp['heat_cols'])

        # plot sequence logo
        plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, ti])
        if options.gain:
          plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, ti])

        # plot SAD
        plot_sad(ax_sad, sat_loss[:, ti], sat_gain[:, ti])

        # plot heat map
        plot_heat(ax_heat, sat_delta[:, :, ti], options.min_limit)

        plt.tight_layout()
        plot_ti = ti if target_subset is None else target_subset[ti]
        plt.savefig('%s/%s_t%d.pdf' % (options.out_dir, header_fs, plot_ti),
                    dpi=600)
        plt.close()
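The loss and gain tracks plotted above are reductions of the 4 x L x T delta matrix over the nucleotide axis. A toy illustration of what sat_loss and sat_gain hold:

import numpy as np

sat_delta = np.random.randn(4, 10, 2)  # toy 4 nucleotides x L x T deltas

# strongest predicted decrease / increase from any single substitution
sat_loss = sat_delta.min(axis=0)       # L x T
sat_gain = sat_delta.max(axis=0)       # L x T
print(sat_loss.shape, sat_gain.shape)  # (10, 2) (10, 2)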
Example #7
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-i",
        dest="ignore_bed",
        help="Ignore genes overlapping regions in this BED file",
    )
    parser.add_option("-l", dest="load_preds", help="Load tess_preds from file")
    parser.add_option(
        "--heat",
        dest="plot_heat",
        default=False,
        action="store_true",
        help="Plot big gene-target heatmaps [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="genes_out",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "-s",
        dest="plot_scatter",
        default=False,
        action="store_true",
        help="Make time-consuming accuracy scatter plots [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--rep",
        dest="replicate_labels_file",
        help="Compare replicate experiments, aided by the given file with long labels",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--table",
        dest="print_tables",
        default=False,
        action="store_true",
        help="Print big gene/TSS tables [Default: %default]",
    )
    parser.add_option(
        "--tss",
        dest="tss_alt",
        default=False,
        action="store_true",
        help="Perform alternative TSS analysis [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="gene_variance",
        default=False,
        action="store_true",
        help="Study accuracy with respect to gene variance across targets [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters and model files, and genes HDF5 file")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # read in genes and targets

    gene_data = genedata.GeneData(genes_hdf5_file)

    #################################################################
    # TSS predictions

    if options.load_preds is not None:
        # load from file
        tss_preds = np.load(options.load_preds)

    else:

        #######################################################
        # setup model

        job = params.read_job_params(params_file)

        job["seq_length"] = gene_data.seq_length
        job["seq_depth"] = gene_data.seq_depth
        job["target_pool"] = gene_data.pool_width
        if not "num_targets" in job:
            job["num_targets"] = gene_data.num_targets

        # build model
        model = seqnn.SeqNN()
        model.build_feed_sad(job)

        if options.batch_size is not None:
            model.hp.batch_size = options.batch_size

        #######################################################
        # predict TSSs

        t0 = time.time()
        print("Computing gene predictions.", end="")
        sys.stdout.flush()

        # initialize batcher
        gene_batcher = batcher.Batcher(
            gene_data.seqs_1hot, batch_size=model.hp.batch_size
        )

        # initialize saver
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # load variables into session
            saver.restore(sess, model_file)

            # predict
            tss_preds = model.predict_genes(
                sess,
                gene_batcher,
                gene_data.gene_seqs,
                rc=options.rc,
                shifts=options.shifts,
                tss_radius=options.tss_radius,
            )

        # save to file
        np.save("%s/preds" % options.out_dir, tss_preds)

        print(" Done in %ds." % (time.time() - t0))

    # up-convert
    tss_preds = tss_preds.astype("float32")

    #################################################################
    # convert to genes

    gene_targets, _ = gene.map_tss_genes(
        gene_data.tss_targets, gene_data.tss, tss_radius=options.tss_radius
    )
    gene_preds, _ = gene.map_tss_genes(
        tss_preds, gene_data.tss, tss_radius=options.tss_radius
    )

    #################################################################
    # determine targets

    # read targets
    if options.targets_file is not None:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_indexes = targets_df.index
    else:
        if gene_data.num_targets is None:
            print("No targets to test against.")
            exit(1)
        else:
            target_indexes = np.arange(gene_data.num_targets)

    #################################################################
    # correlation statistics

    t0 = time.time()
    print("Computing correlations.", end="")
    sys.stdout.flush()

    cor_table(
        gene_data.tss_targets,
        tss_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/tss_cors.txt" % options.out_dir,
    )

    cor_table(
        gene_targets,
        gene_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/gene_cors.txt" % options.out_dir,
        draw_plots=True,
    )

    print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene statistics

    if options.print_tables:
        t0 = time.time()
        print("Printing predictions.", end="")
        sys.stdout.flush()

        gene_table(
            gene_data.tss_targets,
            tss_preds,
            gene_data.tss_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/transcript" % options.out_dir,
            options.plot_scatter,
        )

        gene_table(
            gene_targets,
            gene_preds,
            gene_data.gene_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/gene" % options.out_dir,
            options.plot_scatter,
        )

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene x target heatmaps

    if options.plot_heat or options.gene_variance:
        #########################################
        # normalize predictions across targets

        t0 = time.time()
        print("Normalizing values across targets.", end="")
        sys.stdout.flush()

        gene_targets_qn = normalize_targets(
            gene_targets[:, target_indexes], log_pseudo=1
        )
        gene_preds_qn = normalize_targets(gene_preds[:, target_indexes], log_pseudo=1)

        print(" Done in %ds." % (time.time() - t0))

    if options.plot_heat:
        #########################################
        # plot genes by targets clustermap

        t0 = time.time()
        print("Plotting heat maps.", end="")
        sys.stdout.flush()

        sns.set(font_scale=1.3, style="ticks")
        plot_genes = 1600
        plot_targets = 800

        # choose a set of variable genes
        gene_vars = gene_preds_qn.var(axis=1)
        indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

        # choose a set of random genes
        if plot_genes < gene_preds_qn.shape[0]:
            indexes_rand = np.random.choice(
                np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False
            )
        else:
            indexes_rand = np.arange(gene_preds_qn.shape[0])

        # choose a set of random targets
        if plot_targets < 0.8 * gene_preds_qn.shape[1]:
            indexes_targets = np.random.choice(
                np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False
            )
        else:
            indexes_targets = np.arange(gene_preds_qn.shape[1])

        # variable gene predictions
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )

        # random gene predictions
        clustermap(
            gene_preds_qn[indexes_rand, :][:, indexes_targets],
            "%s/gene_heat_rand.pdf" % options.out_dir,
        )

        # variable gene targets
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )

        # random gene targets (crashes)
        # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
        #            '%s/gene_theat_rand.pdf' % options.out_dir)

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # analyze replicates

    if options.replicate_labels_file is not None:
        # read long form labels, from which to infer replicates
        target_labels_long = []
        for line in open(options.replicate_labels_file):
            a = line.split("\t")
            a[-1] = a[-1].rstrip()
            target_labels_long.append(a[-1])

        # determine replicates
        replicate_lists = infer_replicates(target_labels_long)

        # compute correlations
        # replicate_correlations(replicate_lists, gene_data.tss_targets,
        # tss_preds, options.target_indexes, '%s/transcript_reps' % options.out_dir)
        replicate_correlations(
            replicate_lists,
            gene_targets,
            gene_preds,
            target_indexes,
            "%s/gene_reps" % options.out_dir,
        )  # , scatter_plots=True)

    #################################################################
    # gene variance

    if options.gene_variance:
        variance_accuracy(gene_targets_qn, gene_preds_qn, "%s/gene" % options.out_dir)

    #################################################################
    # alternative TSS

    if options.tss_alt:
        alternative_tss(
            gene_data.tss_targets[:, target_indexes],
            tss_preds[:, target_indexes],
            gene_data,
            options.out_dir,
            log_pseudo=1,
        )
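
A note on the normalization step above: the `_qn` suffix suggests normalize_targets quantile-normalizes the (genes x targets) matrices so every target shares the same value distribution before clustering and correlation. A minimal sketch of column-wise quantile normalization with a log2 pseudocount, assuming the helper follows the textbook procedure (the actual implementation may treat ties and scaling differently):

import numpy as np

def quantile_normalize(X, log_pseudo=1):
    # hedged sketch of a normalize_targets-style helper
    if log_pseudo is not None:
        X = np.log2(X + log_pseudo)

    # rank each entry within its column
    ranks = np.argsort(np.argsort(X, axis=0), axis=0)

    # reference distribution: mean across columns at each sorted rank
    reference = np.sort(X, axis=0).mean(axis=1)

    # assign every entry the reference value at its rank
    return reference[ranks]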
Ejemplo n.º 8
0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample test computation by taking uniformly spaced positions [Default: %default]'
    )
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('-s',
                      dest='scent_file',
                      help='Dimension reduction model file')
    parser.add_option('--sample',
                      dest='sample_pct',
                      default=1,
                      type='float',
                      help='Sample percentage')
    parser.add_option('--save',
                      dest='save',
                      default=False,
                      action='store_true')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--train',
                      dest='train',
                      default=False,
                      action='store_true',
                      help='Process the training set [Default: %default]')
    parser.add_option('-v',
                      dest='valid',
                      default=False,
                      action='store_true',
                      help='Process the validation set [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt predictions to predict peak calls [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(test_hdf5_file, 'r')

    if options.train:
        test_seqs = data_open['train_in']
        test_targets = data_open['train_out']
        test_na = None
        if 'train_na' in data_open:
            test_na = data_open['train_na']

    elif options.valid:
        test_seqs = data_open['valid_in']
        test_targets = data_open['valid_out']
        test_na = None
        if 'valid_na' in data_open:
            test_na = data_open['valid_na']

    else:
        test_seqs = data_open['test_in']
        test_targets = data_open['test_out']
        test_na = None
        if 'test_na' in data_open:
            test_na = data_open['test_na']

    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print('Sampling %d sequences' % sample_n)
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job['seq_length'] = test_seqs.shape[1]
    job['seq_depth'] = test_seqs.shape[2]
    job['num_targets'] = test_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build(job)
    print('Model building time %ds' % (time.time() - t0))

    # adjust for fourier
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        test_targets_imag = data_open['test_out_imag']
        if options.valid:
            test_targets_imag = data_open['valid_out_imag']

    # adjust for factors
    if options.scent_file is not None:
        t0 = time.time()
        test_targets_full = data_open['test_out_full']
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job['fourier']:
        batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                        test_targets_imag, test_na,
                                        dr.hp.batch_size, dr.hp.target_pool)
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test(sess,
                           batcher_test,
                           rc=options.rc,
                           shifts=options.shifts,
                           mc_n=options.mc_n)

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
            np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print('SeqNN test: %ds' % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        #test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open('%s/acc.txt' % options.out_dir, 'w')
        for ti in range(len(test_r2)):
            print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                  (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti],
                   test_log_pcor[ti], target_labels[ti]),
                  file=acc_out)
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open('%s/normalization.txt' % options.out_dir, 'w')
        print('\n'.join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(dr, model, test_preds, test_targets_full,
                                  options.out_dir, options.down_sample)

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
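        # predictions are cropped by the batch buffer while stored targets
        # span the full sequence, so shift target bins by the buffer width
        # in pooled units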
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open('%s/peaks.txt' % options.out_dir, 'w')
        for ti in range(test_targets.shape[2]):
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets]
                                    .flatten().astype('float32'))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti]
                                  .flatten().astype('float32'))

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
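            # upper-tail Poisson p-value: P(X >= x) = 1 - CDF(x - 1)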
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print('%4d  %6d  %.5f  %.5f' %
                  (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                  file=peaks_out)

        peaks_out.close()

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs')

        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        bed_set = 'test'
        if options.valid:
            bed_set = 'valid'

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         dr.hp.batch_buffer,
                         bed_set=bed_set)

        # make NA bigwig
        bw_file = '%s/tracks/na.bw' % options.out_dir
        bigwig_write(bw_file,
                     test_na,
                     options.track_bed,
                     options.genome_file,
                     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        if not os.path.isdir('%s/scatter' % options.out_dir):
            os.mkdir('%s/scatter' % options.out_dir)

        if not os.path.isdir('%s/violin' % options.out_dir):
            os.mkdir('%s/violin' % options.out_dir)

        if not os.path.isdir('%s/roc' % options.out_dir):
            os.mkdir('%s/roc' % options.out_dir)

        if not os.path.isdir('%s/pr' % options.out_dir):
            os.mkdir('%s/pr' % options.out_dir)

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust the stride to control how many points get plotted)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets]
                                    .flatten().astype('float32'))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti]
                                  .flatten().astype('float32'))

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                np.log2(test_preds_ti_flat + 1),
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()

    data_open.close()
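
The peak-calling logic above treats each target's coverage as Poisson with the track mean as lambda, converts counts to upper-tail p-values, and keeps bins whose Benjamini-Hochberg q-values fall below 0.01. A sketch of a ben_hoch-style step-up adjustment, assuming it returns BH-adjusted p-values in the original order (the repo's helper may differ in detail):

import numpy as np

def ben_hoch_sketch(p_values):
    p = np.asarray(p_values, dtype='float64')
    m = len(p)
    order = np.argsort(p)

    # raw Benjamini-Hochberg values: p_(i) * m / i over sorted p-values
    raw = p[order] * m / np.arange(1, m + 1)

    # enforce monotone non-decreasing q-values, capped at 1
    q_sorted = np.minimum(np.minimum.accumulate(raw[::-1])[::-1], 1.0)

    # restore the original order
    q = np.empty(m)
    q[order] = q_sorted
    return q

With that, ben_hoch_sketch(pvals) < 0.01 reproduces the thresholding used in both the peaks and violin sections.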
Ejemplo n.º 9
0
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <test_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "--ai",
        dest="accuracy_indexes",
        help=
        "Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values",
    )
    parser.add_option(
        "--clip",
        dest="target_clip",
        default=None,
        type="float",
        help=
        "Clip targets and predictions to a maximum value [Default: %default]",
    )
    parser.add_option(
        "-d",
        dest="down_sample",
        default=1,
        type="int",
        help=
        "Down sample test computation by taking uniformly spaced positions [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome length information [Default: %default]",
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "--peak",
        "--peaks",
        dest="peaks",
        default=False,
        action="store_true",
        help="Compute expensive peak accuracy [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option("-s",
                      dest="scent_file",
                      help="Dimension reduction model file")
    parser.add_option("--sample",
                      dest="sample_pct",
                      default=1,
                      type="float",
                      help="Sample percentage")
    parser.add_option("--save",
                      dest="save",
                      default=False,
                      action="store_true")
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="track_bed",
        help="BED file describing regions so we can output BigWig tracks",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "--train",
        dest="train",
        default=False,
        action="store_true",
        help="Process the training set [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="valid",
        default=False,
        action="store_true",
        help="Process the validation set [Default: %default]",
    )
    parser.add_option(
        "-w",
        dest="pool_width",
        default=1,
        type="int",
        help=
        "Max pool width for regressing nt predictions to predict peak calls [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and test data HDF5")
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(test_hdf5_file, "r")

    if options.train:
        test_seqs = data_open["train_in"]
        test_targets = data_open["train_out"]
        test_na = None
        if "train_na" in data_open:
            test_na = data_open["train_na"]

    elif options.valid:
        test_seqs = data_open["valid_in"]
        test_targets = data_open["valid_out"]
        test_na = None
        if "valid_na" in data_open:
            test_na = data_open["valid_na"]

    else:
        test_seqs = data_open["test_in"]
        test_targets = data_open["test_out"]
        test_na = None
        if "test_na" in data_open:
            test_na = data_open["test_na"]

    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print("Sampling %d sequences" % sample_n)
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode("UTF-8") for tl in data_open["target_labels"]]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)
    job["seq_length"] = test_seqs.shape[1]
    job["seq_depth"] = test_seqs.shape[2]
    job["num_targets"] = test_targets.shape[2]
    job["target_pool"] = int(np.array(data_open.get("pool_width", 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build_feed(job)
    print("Model building time %ds" % (time.time() - t0))

    # adjust for fourier
    job["fourier"] = "train_out_imag" in data_open
    if job["fourier"]:
        test_targets_imag = data_open["test_out_imag"]
        if options.valid:
            test_targets_imag = data_open["valid_out_imag"]

    # adjust for factors
    if options.scent_file is not None:
        t0 = time.time()
        test_targets_full = data_open["test_out_full"]
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job["fourier"]:
        batcher_test = batcher.BatcherF(
            test_seqs,
            test_targets,
            test_targets_imag,
            test_na,
            dr.hp.batch_size,
            dr.hp.target_pool,
        )
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test_h5_manual(sess,
                                     batcher_test,
                                     rc=options.rc,
                                     shifts=options.shifts,
                                     mc_n=options.mc_n)

        if options.save:
            np.save("%s/preds.npy" % options.out_dir, test_acc.preds)
            np.save("%s/targets.npy" % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print("SeqNN test: %ds" % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        # test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print("Compute stats: %ds" % (time.time() - t0))

        # print
        print("Test Loss:         %7.5f" % test_acc.loss)
        print("Test R2:           %7.5f" % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print("Test PearsonR:     %7.5f" % test_pcor.mean())
        print("Test log PearsonR: %7.5f" % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open("%s/acc.txt" % options.out_dir, "w")
        for ti in range(len(test_r2)):
            print(
                "%4d  %7.5f  %.5f  %.5f  %.5f  %s" % (
                    ti,
                    test_acc.target_losses[ti],
                    test_r2[ti],
                    test_pcor[ti],
                    test_log_pcor[ti],
                    target_labels[ti],
                ),
                file=acc_out,
            )
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype="float64")
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open("%s/normalization.txt" % options.out_dir, "w")
        print("\n".join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(
                dr,
                model,
                test_preds,
                test_targets_full,
                options.out_dir,
                options.down_sample,
            )

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open("%s/peaks.txt" % options.out_dir, "w")
        for ti in range(test_targets.shape[2]):
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets]
                                    .flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti]
                                  .flatten().astype("float32"))

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print(
                "%4d  %6d  %.5f  %.5f" %
                (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                file=peaks_out,
            )

        peaks_out.close()

        print("Test AUROC:     %7.5f" % np.mean(aurocs))
        print("Test AUPRC:     %7.5f" % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                "Must provide genome file in order to print valid BigWigs")

        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(",")
            ]

        bed_set = "test"
        if options.valid:
            bed_set = "valid"

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = "%s/tracks/t%d_true.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_targets_ti,
                options.track_bed,
                options.genome_file,
                bed_set=bed_set,
            )

            # make predictions bigwig
            bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_preds[:, :, ti],
                options.track_bed,
                options.genome_file,
                dr.hp.batch_buffer,
                bed_set=bed_set,
            )

        # make NA bigwig
        # bw_file = '%s/tracks/na.bw' % options.out_dir
        # bigwig_write(
        #     bw_file,
        #     test_na,
        #     options.track_bed,
        #     options.genome_file,
        #     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(",")
        ]

        if not os.path.isdir("%s/scatter" % options.out_dir):
            os.mkdir("%s/scatter" % options.out_dir)

        if not os.path.isdir("%s/violin" % options.out_dir):
            os.mkdir("%s/violin" % options.out_dir)

        if not os.path.isdir("%s/roc" % options.out_dir):
            os.mkdir("%s/roc" % options.out_dir)

        if not os.path.isdir("%s/pr" % options.out_dir):
            os.mkdir("%s/pr" % options.out_dir)

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust the stride to control how many points get plotted)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = (test_targets_ti[:, ds_indexes_targets]
                                    .flatten().astype("float32"))
            test_preds_ti_flat = (test_preds[:, ds_indexes_preds, ti]
                                  .flatten().astype("float32"))

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style="ticks")
            out_pdf = "%s/scatter/t%d.pdf" % (options.out_dir, ti)
            plots.regplot(
                test_targets_ti_log,
                test_preds_ti_log,
                out_pdf,
                poly_order=1,
                alpha=0.3,
                sample=500,
                figsize=(6, 6),
                x_label="log2 Experiment",
                y_label="log2 Prediction",
                table=True,
            )

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, "Peak",
                                              "Background")

            # violin plot
            sns.set(font_scale=1.3, style="ticks")
            plt.figure()
            df = pd.DataFrame({
                "log2 Prediction":
                np.log2(test_preds_ti_flat + 1),
                "Experimental coverage status":
                test_targets_peaks_str,
            })
            ax = sns.violinplot(x="Experimental coverage status",
                                y="log2 Prediction",
                                data=df)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/violin/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c="black",
                     linewidth=1,
                     linestyle="--",
                     alpha=0.7)
            plt.plot(fpr, tpr, c="black")
            ax = plt.gca()
            ax.set_xlabel("False positive rate")
            ax.set_ylabel("True positive rate")
            ax.text(0.99,
                    0.02,
                    "AUROC %.3f" % auroc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/roc/t%d.pdf" % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(
                y=test_targets_peaks.mean(),
                c="black",
                linewidth=1,
                linestyle="--",
                alpha=0.7,
            )
            plt.plot(recall, prec, c="black")
            ax = plt.gca()
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.text(0.99,
                    0.95,
                    "AUPRC %.3f" % auprc,
                    horizontalalignment="right")  # , fontsize=14)
            ax.grid(True, linestyle=":")
            plt.savefig("%s/pr/t%d.pdf" % (options.out_dir, ti))
            plt.close()

    data_open.close()
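
For the BigWig section, bigwig_write presumably wraps a library such as pyBigWig. A hedged sketch of the core write, with chrom, chrom_length, and start taken as given (the real helper also walks the BED regions and genome file, and trims the batch buffer for predictions):

import pyBigWig

def bigwig_write_sketch(bw_file, values, chrom, chrom_length, start, pool_width):
    bw = pyBigWig.open(bw_file, "w")
    bw.addHeader([(chrom, chrom_length)])

    # one value per pooled bin, written as a fixed-step track
    bw.addEntries(chrom,
                  start,
                  values=[float(v) for v in values],
                  span=pool_width,
                  step=pool_width)

    bw.close()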
Ejemplo n.º 10
0
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <input_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-a",
        dest="activity_enrich",
        default=1,
        type="float",
        help=
        "Enrich for the most active top % of sequences [Default: %default]",
    )
    parser.add_option("-b",
                      dest="batch_size",
                      default=None,
                      type="int",
                      help="Batch size")
    parser.add_option(
        "-f",
        dest="figure_width",
        default=20,
        type="float",
        help="Figure width [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="gain",
        default=False,
        action="store_true",
        help="Draw a sequence logo for the gain score, too [Default: %default]",
    )
    parser.add_option(
        "-l",
        dest="satmut_len",
        default=200,
        type="int",
        help="Length of centered sequence to mutate [Default: %default]",
    )
    parser.add_option(
        "-m",
        dest="min_limit",
        default=0.005,
        type="float",
        help="Minimum heatmap limit [Default: %default]",
    )
    parser.add_option(
        "-n",
        dest="load_sat_npy",
        default=False,
        action="store_true",
        help="Load the predictions from .npy files [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="heat",
        help="Output directory [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="rng_seed",
        default=1,
        type="float",
        help="Random number generator seed [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help=
        "Ensemble forward and reverse complement predictions [Default: %default]",
    )
    parser.add_option(
        "-s",
        dest="sample",
        default=None,
        type="int",
        help="Sample sequences from the test set [Default:%default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets",
        default="0",
        help=
        "Comma-separated target indexes (or -1 for all) [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error(
            "Must provide parameters and model files and input sequences (as a "
            "FASTA file or test data in an HDF5 file)")
    else:
        params_file = args[0]
        model_file = args[1]
        input_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    random.seed(options.rng_seed)

    #################################################################
    # parse input file
    #################################################################
    seqs, seqs_1hot, targets = parse_input(input_file, options.sample)

    # decide which targets to obtain
    if options.targets == "-1":
        target_indexes = range(targets.shape[2])
        target_subset = None
    else:
        target_indexes = [int(ti) for ti in options.targets.split(",")]
        target_subset = target_indexes

    # enrich for active sequences
    if targets is not None:
        seqs, seqs_1hot, targets = enrich_activity(seqs, seqs_1hot, targets,
                                                   options.activity_enrich,
                                                   target_indexes)

    seqs_n = seqs_1hot.shape[0]

    #################################################################
    # setup model
    #################################################################
    job = params.read_job_params(params_file)

    job["seq_length"] = seqs_1hot.shape[1]
    job["seq_depth"] = seqs_1hot.shape[2]

    if targets is None:
        if "num_targets" not in job or "target_pool" not in job:
            print(
                "Must provide num_targets and target_pool in parameters file",
                file=sys.stderr,
            )
            exit(1)
    else:
        job["num_targets"] = targets.shape[2]
        job["target_pool"] = job["seq_length"] // targets.shape[1]

    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(
        job,
        ensemble_rc=options.rc,
        ensemble_shifts=options.shifts,
        target_subset=target_subset,
    )

    if options.batch_size is not None:
        model.hp.batch_size = options.batch_size

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(seqs_n):
            print("Mutating sequence %d / %d" % (si + 1, seqs_n), flush=True)

            # write sequence
            fasta_out = open("%s/seq%d.fa" % (options.out_dir, si), "w")
            end_len = (len(seqs[si]) - options.satmut_len) // 2
            print(">seq%d\n%s" % (si, seqs[si][end_len:-end_len]),
                  file=fasta_out)
            fasta_out.close()

            #################################################################
            # predict modifications

            if options.load_sat_npy:
                sat_preds = np.load("%s/seq%d_preds.npy" %
                                    (options.out_dir, si))

            else:
                # supplement with saturated mutagenesis
                sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1],
                                            options.satmut_len)

                # initialize batcher
                batcher_sat = batcher.Batcher(sat_seqs_1hot,
                                              batch_size=model.hp.batch_size)

                # predict
                sat_preds = model.predict_h5(sess, batcher_sat)
                np.save("%s/seq%d_preds.npy" % (options.out_dir, si),
                        sat_preds)

            #################################################################
            # compute delta, loss, and gain matrices

            # compute the matrix of prediction deltas: (4 x L_sm x T) array
            sat_delta = delta_matrix(seqs_1hot[si], sat_preds,
                                     options.satmut_len)

            # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len)
            sat_loss = sat_delta.min(axis=0)
            sat_gain = sat_delta.max(axis=0)

            ##############################################
            # plot

            for tii in range(len(target_indexes)):
                # setup plot
                sns.set(style="white", font_scale=1)
                spp = subplot_params(sat_delta.shape[1])

                if options.gain:
                    plt.figure(figsize=(options.figure_width, 5))
                    ax_pred = plt.subplot2grid(
                        (5, spp["heat_cols"]),
                        (0, spp["pred_start"]),
                        colspan=spp["pred_span"],
                    )
                    ax_logo_loss = plt.subplot2grid(
                        (5, spp["heat_cols"]),
                        (1, spp["logo_start"]),
                        colspan=spp["logo_span"],
                    )
                    ax_logo_gain = plt.subplot2grid(
                        (5, spp["heat_cols"]),
                        (2, spp["logo_start"]),
                        colspan=spp["logo_span"],
                    )
                    ax_sad = plt.subplot2grid(
                        (5, spp["heat_cols"]),
                        (3, spp["sad_start"]),
                        colspan=spp["sad_span"],
                    )
                    ax_heat = plt.subplot2grid((5, spp["heat_cols"]), (4, 0),
                                               colspan=spp["heat_cols"])

                else:
                    plt.figure(figsize=(options.figure_width, 4))
                    ax_pred = plt.subplot2grid(
                        (4, spp["heat_cols"]),
                        (0, spp["pred_start"]),
                        colspan=spp["pred_span"],
                    )
                    ax_logo_loss = plt.subplot2grid(
                        (4, spp["heat_cols"]),
                        (1, spp["logo_start"]),
                        colspan=spp["logo_span"],
                    )
                    ax_sad = plt.subplot2grid(
                        (4, spp["heat_cols"]),
                        (2, spp["sad_start"]),
                        colspan=spp["sad_span"],
                    )
                    ax_heat = plt.subplot2grid((4, spp["heat_cols"]), (3, 0),
                                               colspan=spp["heat_cols"])

                # plot predictions
                plot_predictions(
                    ax_pred,
                    sat_preds[0, :, tii],
                    options.satmut_len,
                    model.hp.seq_length,
                    model.hp.batch_buffer,
                )

                # plot sequence logo
                plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, tii])
                if options.gain:
                    plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, tii])

                # plot SAD
                plot_sad(ax_sad, sat_loss[:, tii], sat_gain[:, tii])

                # plot heat map
                plot_heat(ax_heat, sat_delta[:, :, tii], options.min_limit)

                plt.tight_layout()
                plt.savefig(
                    "%s/seq%d_t%d.pdf" %
                    (options.out_dir, si, target_indexes[tii]),
                    dpi=600,
                )
                plt.close()
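
The saturated mutagenesis step hinges on satmut_seqs enumerating every single-nucleotide substitution in the centered window. A sketch under that assumption, returning the reference sequence followed by the three alternative bases at each position (so delta_matrix can subtract reference predictions):

import numpy as np

def satmut_seqs_sketch(seq_1hot, mut_len):
    # seq_1hot: (length, 4) one-hot; mutate the centered mut_len window
    seq_len = seq_1hot.shape[0]
    mut_start = (seq_len - mut_len) // 2

    sat_list = [seq_1hot.copy()]
    for pos in range(mut_start, mut_start + mut_len):
        for nt in range(4):
            # skip the reference base; emit the three alternatives
            if seq_1hot[pos, nt] == 0:
                seq_mut = seq_1hot.copy()
                seq_mut[pos, :] = 0
                seq_mut[pos, nt] = 1
                sat_list.append(seq_mut)

    # shape: (1 + 3*mut_len, length, 4)
    return np.array(sat_list)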
Ejemplo n.º 11
0
def run(params_file, data_file, num_train_epochs):
    shifts = [int(shift) for shift in FLAGS.shifts.split(',')]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(data_file, 'r')

    train_seqs = data_open['train_in']
    train_targets = data_open['train_out']
    train_na = None
    if 'train_na' in data_open:
        train_na = data_open['train_na']

    valid_seqs = data_open['valid_in']
    valid_targets = data_open['valid_out']
    valid_na = None
    if 'valid_na' in data_open:
        valid_na = data_open['valid_na']

    #######################################################
    # model parameters and placeholders
    #######################################################
    job = dna_io.read_job_params(params_file)

    job['batch_length'] = train_seqs.shape[1]
    job['seq_depth'] = train_seqs.shape[2]
    job['num_targets'] = train_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))
    job['early_stop'] = job.get('early_stop', 16)
    job['rate_drop'] = job.get('rate_drop', 3)

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build(job)
    print('Model building time %f' % (time.time() - t0))

    # adjust for fourier
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        train_targets_imag = data_open['train_out_imag']
        valid_targets_imag = data_open['valid_out_imag']

    #######################################################
    # train
    #######################################################
    # initialize batcher
    if job['fourier']:
        batcher_train = batcher.BatcherF(train_seqs,
                                         train_targets,
                                         train_targets_imag,
                                         train_na,
                                         dr.batch_size,
                                         dr.target_pool,
                                         shuffle=True)
        batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                         valid_targets_imag, valid_na,
                                         dr.batch_size, dr.target_pool)
    else:
        batcher_train = batcher.Batcher(train_seqs,
                                        train_targets,
                                        train_na,
                                        dr.batch_size,
                                        dr.target_pool,
                                        shuffle=True)
        batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na,
                                        dr.batch_size, dr.target_pool)
    print('Batcher initialized')

    # checkpoints
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    if FLAGS.log_device_placement:
        config.log_device_placement = True
    with tf.Session(config=config) as sess:
        t0 = time.time()

        # set seed
        tf.set_random_seed(FLAGS.seed)

        if FLAGS.logdir:
            train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                 sess.graph)
        else:
            train_writer = None

        if FLAGS.restart:
            # load variables into session
            saver.restore(sess, FLAGS.restart)
        else:
            # initialize variables
            print('Initializing...')
            sess.run(tf.global_variables_initializer())
            print('Initialization time %f' % (time.time() - t0))

        train_loss = None
        best_loss = None
        early_stop_i = 0
        undroppable_counter = 3
        max_drops = 8
        num_drops = 0

        for epoch in range(num_train_epochs):
            if early_stop_i < job['early_stop'] or epoch < FLAGS.min_epochs:
                t0 = time.time()

                # save previous
                train_loss_last = train_loss

                # alternate forward and reverse batches
                fwdrc = True
                if FLAGS.rc and epoch % 2 == 1:
                    fwdrc = False

                # cycle shifts
                shift_i = epoch % len(shifts)

                # train
                train_loss = dr.train_epoch(sess, batcher_train, fwdrc,
                                            shifts[shift_i], train_writer)

                # validate
                valid_acc = dr.test(sess,
                                    batcher_valid,
                                    mc_n=FLAGS.mc_n,
                                    rc=FLAGS.rc,
                                    shifts=shifts)
                valid_loss = valid_acc.loss
                valid_r2 = valid_acc.r2().mean()
                del valid_acc

                best_str = ''
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    best_str = ', best!'
                    early_stop_i = 0
                    saver.save(
                        sess,
                        '%s/%s_best.tf' % (FLAGS.logdir, FLAGS.save_prefix))
                else:
                    early_stop_i += 1

                # measure time
                et = time.time() - t0
                if et < 600:
                    time_str = '%3ds' % et
                elif et < 6000:
                    time_str = '%3dm' % (et / 60)
                else:
                    time_str = '%3.1fh' % (et / 3600)

                # print update
                print(
                    'Epoch %3d: Train loss: %7.5f, Valid loss: %7.5f, Valid R2: %7.5f, Time: %s%s'
                    % (epoch + 1, train_loss, valid_loss, valid_r2, time_str,
                       best_str),
                    end='')

                # if training stagnant
                if FLAGS.learn_rate_drop and num_drops < max_drops and undroppable_counter == 0 and (
                        train_loss_last -
                        train_loss) / train_loss_last < 0.0002:
                    print(', rate drop', end='')
                    dr.drop_rate(2 / 3)
                    undroppable_counter = 1
                    num_drops += 1
                else:
                    undroppable_counter = max(0, undroppable_counter - 1)

                print('')
                sys.stdout.flush()

        if FLAGS.logdir:
            train_writer.close()
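
The rate-drop condition in the loop above is dense; it fires when the relative training-loss improvement falls under 0.02% while drops remain available and the cooldown counter has expired. As a standalone predicate (names here are illustrative, not from the repo):

def training_stagnant(train_loss_last, train_loss, min_rel_improve=0.0002):
    # no previous epoch yet, so nothing to compare against
    if train_loss_last is None:
        return False
    return (train_loss_last - train_loss) / train_loss_last < min_rel_improve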
Ejemplo n.º 12
0
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <data_file>"
    parser = OptionParser(usage)
    parser.add_option("-l",
                      dest="layers",
                      default=None,
                      help="Comma-separated list of layers to plot")
    parser.add_option(
        "-n",
        dest="num_seqs",
        default=None,
        type="int",
        help="Number of sequences to process",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="hidden",
        help="Output directory [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="target_indexes",
        default=None,
        help="Paint 2D plots with these target index values.",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide paramters, model, and test data HDF5")
    else:
        params_file = args[0]
        model_file = args[1]
        data_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.layers is not None:
        options.layers = [int(li) for li in options.layers.split(",")]

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(data_file, "r")
    test_seqs = data_open["test_in"]
    test_targets = data_open["test_out"]

    if options.num_seqs is not None:
        test_seqs = test_seqs[:options.num_seqs]
        test_targets = test_targets[:options.num_seqs]

    #######################################################
    # model parameters and placeholders
    #######################################################
    job = params.read_job_params(params_file)

    job["seq_length"] = test_seqs.shape[1]
    job["seq_depth"] = test_seqs.shape[2]
    job["num_targets"] = test_targets.shape[2]
    job["target_pool"] = int(np.array(data_open.get("pool_width", 1)))

    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(job)

    if options.target_indexes is None:
        options.target_indexes = range(job["num_targets"])
    else:
        options.target_indexes = [
            int(ti) for ti in options.target_indexes.split(",")
        ]

    #######################################################
    # test
    #######################################################
    # initialize batcher
    batcher_test = batcher.Batcher(
        test_seqs,
        test_targets,
        batch_size=model.hp.batch_size,
        pool_width=model.hp.target_pool,
    )

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # get layer representations
        layer_reprs, _ = model.hidden(sess, batcher_test, options.layers)

        if options.layers is None:
            options.layers = range(len(layer_reprs))

        for li in options.layers:
            layer_repr = layer_reprs[li]
            try:
                print(layer_repr.shape)
            except Exception:
                print(layer_repr)

            # sample one position every 256 nt per sequence
            ds_indexes = np.arange(0, layer_repr.shape[1], 256)
            nt_reprs = layer_repr[:, ds_indexes, :].reshape(
                (-1, layer_repr.shape[2]))
            print("nt_reprs", nt_reprs.shape)

            ########################################################
            # plot raw
            sns.set(style="ticks", font_scale=1.2)
            plt.figure()
            g = sns.clustermap(nt_reprs,
                               cmap="RdBu_r",
                               xticklabels=False,
                               yticklabels=False)
            g.ax_heatmap.set_xlabel("Representation")
            g.ax_heatmap.set_ylabel("Sequences")
            plt.savefig("%s/l%d_reprs.pdf" % (options.out_dir, li))
            plt.close()

            ########################################################
            # plot variance explained ratios

            model_full = PCA()
            model_full.fit(nt_reprs)
            evr = model_full.explained_variance_ratio_

            pca_n = 40

            plt.figure()
            plt.scatter(range(1, pca_n + 1), evr[:pca_n], c="black")
            ax = plt.gca()
            ax.set_xlim(0, pca_n + 1)
            ax.set_xlabel("Principal components")
            ax.set_ylim(0, evr[:pca_n].max() * 1.05)
            ax.set_ylabel("Variance explained")
            ax.grid(True, linestyle=":")
            plt.savefig("%s/l%d_pca.pdf" % (options.out_dir, li))
            plt.close()

            ########################################################
            # visualize in 2D

            model2 = PCA(n_components=2)
            nt_2d = model2.fit_transform(nt_reprs)

            for ti in options.target_indexes:
                # slice for target
                test_targets_ti = test_targets[:, :, ti]

                # repeat to match layer_repr
                target_repeat = layer_repr.shape[1] // test_targets.shape[1]
                test_targets_ti = np.repeat(test_targets_ti,
                                            target_repeat,
                                            axis=1)

                # downsample indexes
                nt_targets = test_targets_ti[:, ds_indexes].flatten()

                # log transform
                nt_targets = np.log1p(nt_targets)

                plt.figure()
                plt.scatter(nt_2d[:, 0],
                            nt_2d[:, 1],
                            alpha=0.5,
                            c=nt_targets,
                            cmap="RdBu_r")
                plt.colorbar()
                ax = plt.gca()
                ax.grid(True, linestyle=":")
                plt.savefig("%s/l%d_nt2d_t%d.pdf" % (options.out_dir, li, ti))
                plt.close()

            ########################################################
            # plot neuron-neuron correlations

            # compute correlation matrix
            hidden_cov = np.corrcoef(nt_reprs.T)
            print("hidden_cov", hidden_cov.shape)

            plt.figure()
            g = sns.clustermap(hidden_cov,
                               cmap="RdBu_r",
                               xticklabels=False,
                               yticklabels=False)
            plt.savefig("%s/l%d_cov.pdf" % (options.out_dir, li))
            plt.close()

            ########################################################
            # plot neuron densities
            neuron_stats_out = open("%s/l%d_stats.txt" % (options.out_dir, li),
                                    "w")

            for ni in range(nt_reprs.shape[1]):
                # print stats
                nu = nt_reprs[:, ni].mean()
                nstd = nt_reprs[:, ni].std()
                print("%3d  %6.3f  %6.3f" % (ni, nu, nstd),
                      file=neuron_stats_out)

                # plot
                # plt.figure()
                # sns.distplot(nt_reprs[:,ni])
                # plt.savefig('%s/l%d_dist%d.pdf' % (options.out_dir,li,ni))
                # plt.close()

            neuron_stats_out.close()

            ########################################################
            # plot layer norms across length
            """
            layer_repr_norms = np.linalg.norm(layer_repr, axis=2)

            length_vec =
            list(range(layer_repr_norms.shape[1]))*layer_repr_norms.shape[0]
            length_vec = np.array(length_vec) +
            0.1*np.random.randn(len(length_vec))
            repr_norm_vec = layer_repr_norms.flatten()

            out_pdf = '%s/l%d_lnorm.pdf' % (options.out_dir,li)
            regplot(length_vec, repr_norm_vec, out_pdf, x_label='Position',
            y_label='Repr Norm')
            """

    data_open.close()
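
The per-layer analysis above avoids redundant neighboring positions by keeping one position every 256 before flattening to a (positions x units) matrix. A toy check of that slicing and reshape, with made-up dimensions:

import numpy as np

layer_repr = np.random.randn(8, 1024, 128)  # (sequences, positions, units), toy shape
ds_indexes = np.arange(0, layer_repr.shape[1], 256)
nt_reprs = layer_repr[:, ds_indexes, :].reshape((-1, layer_repr.shape[2]))
assert nt_reprs.shape == (8 * len(ds_indexes), 128)  # 32 rows x 128 units
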
Ejemplo n.º 13
0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <input_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='activity_enrich',
        default=1,
        type='float',
        help='Enrich for the most active top % of sequences [Default: %default]'
    )
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size')
    parser.add_option('-f',
                      dest='figure_width',
                      default=20,
                      type='float',
                      help='Figure width [Default: %default]')
    parser.add_option(
        '-g',
        dest='gain',
        default=False,
        action='store_true',
        help='Draw a sequence logo for the gain score, too [Default: %default]'
    )
    parser.add_option(
        '-l',
        dest='satmut_len',
        default=200,
        type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-m',
                      dest='min_limit',
                      default=0.005,
                      type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option(
        '-n',
        dest='load_sat_npy',
        default=False,
        action='store_true',
        help='Load the predictions from .npy files [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='heat',
                      help='Output directory [Default: %default]')
    parser.add_option('-r',
                      dest='rng_seed',
                      default=1,
                      type='float',
                      help='Random number generator seed [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option(
        '-s',
        dest='sample',
        default=None,
        type='int',
        help='Sample sequences from the test set [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets',
        default='0',
        help=
        'Comma-separated target indexes (or -1 for all) [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error(
            'Must provide parameters and model files and input sequences (as a '
            'FASTA file or test data in an HDF5 file)')
    else:
        params_file = args[0]
        model_file = args[1]
        input_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    random.seed(options.rng_seed)

    #################################################################
    # parse input file
    #################################################################
    seqs, seqs_1hot, targets = parse_input(input_file, options.sample)

    # decide which targets to obtain
    if options.targets == '-1':
        target_indexes = range(targets.shape[2])
        target_subset = None
    else:
        target_indexes = [int(ti) for ti in options.targets.split(',')]
        target_subset = target_indexes

    # enrich for active sequences
    if targets is not None:
        seqs, seqs_1hot, targets = enrich_activity(seqs, seqs_1hot, targets,
                                                   options.activity_enrich,
                                                   target_indexes)

    seqs_n = seqs_1hot.shape[0]

    #################################################################
    # setup model
    #################################################################
    job = params.read_job_params(params_file)

    job['seq_length'] = seqs_1hot.shape[1]
    job['seq_depth'] = seqs_1hot.shape[2]

    if targets is None:
        if 'num_targets' not in job or 'target_pool' not in job:
            print(
                'Must provide num_targets and target_pool in parameters file',
                file=sys.stderr)
            exit(1)
    else:
        job['num_targets'] = targets.shape[2]
        job['target_pool'] = job['seq_length'] // targets.shape[1]

    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(job,
                     ensemble_rc=options.rc,
                     ensemble_shifts=options.shifts,
                     target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.batch_size is not None:
        model.hp.batch_size = options.batch_size

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(seqs_n):
            print('Mutating sequence %d / %d' % (si + 1, seqs_n), flush=True)

            # write sequence
            fasta_out = open('%s/seq%d.fa' % (options.out_dir, si), 'w')
            end_len = (len(seqs[si]) - options.satmut_len) // 2
            satmut_seq = seqs[si][end_len:-end_len] if end_len > 0 else seqs[si]
            print('>seq%d\n%s' % (si, satmut_seq), file=fasta_out)
            fasta_out.close()

            #################################################################
            # predict modifications

            if options.load_sat_npy:
                sat_preds = np.load('%s/seq%d_preds.npy' %
                                    (options.out_dir, si))

            else:
                # supplement with saturated mutagenesis
                sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1],
                                            options.satmut_len)

                # initialize batcher
                batcher_sat = batcher.Batcher(sat_seqs_1hot,
                                              batch_size=model.hp.batch_size)

                # predict
                sat_preds = model.predict_h5(sess, batcher_sat)
                np.save('%s/seq%d_preds.npy' % (options.out_dir, si),
                        sat_preds)

            #################################################################
            # compute delta, loss, and gain matrices

            # compute the matrix of prediction deltas: (4 x L_sm x T) array
            sat_delta = delta_matrix(seqs_1hot[si], sat_preds,
                                     options.satmut_len)

            # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len)
            sat_loss = sat_delta.min(axis=0)
            sat_gain = sat_delta.max(axis=0)

            ##############################################
            # plot

            for tii in range(len(target_indexes)):
                # setup plot
                sns.set(style='white', font_scale=1)
                spp = subplot_params(sat_delta.shape[1])

                if options.gain:
                    plt.figure(figsize=(options.figure_width, 5))
                    ax_pred = plt.subplot2grid((5, spp['heat_cols']),
                                               (0, spp['pred_start']),
                                               colspan=spp['pred_span'])
                    ax_logo_loss = plt.subplot2grid((5, spp['heat_cols']),
                                                    (1, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_logo_gain = plt.subplot2grid((5, spp['heat_cols']),
                                                    (2, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_sad = plt.subplot2grid((5, spp['heat_cols']),
                                              (3, spp['sad_start']),
                                              colspan=spp['sad_span'])
                    ax_heat = plt.subplot2grid((5, spp['heat_cols']), (4, 0),
                                               colspan=spp['heat_cols'])

                else:
                    plt.figure(figsize=(options.figure_width, 4))
                    ax_pred = plt.subplot2grid((4, spp['heat_cols']),
                                               (0, spp['pred_start']),
                                               colspan=spp['pred_span'])
                    ax_logo_loss = plt.subplot2grid((4, spp['heat_cols']),
                                                    (1, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_sad = plt.subplot2grid((4, spp['heat_cols']),
                                              (2, spp['sad_start']),
                                              colspan=spp['sad_span'])
                    ax_heat = plt.subplot2grid((4, spp['heat_cols']), (3, 0),
                                               colspan=spp['heat_cols'])

                # plot predictions
                plot_predictions(ax_pred, sat_preds[0, :, tii],
                                 options.satmut_len, model.hp.seq_length,
                                 model.hp.batch_buffer)

                # plot sequence logo
                plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, tii])
                if options.gain:
                    plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, tii])

                # plot SAD
                plot_sad(ax_sad, sat_loss[:, tii], sat_gain[:, tii])

                # plot heat map
                plot_heat(ax_heat, sat_delta[:, :, tii], options.min_limit)

                plt.tight_layout()
                plt.savefig('%s/seq%d_t%d.pdf' %
                            (options.out_dir, si, target_indexes[tii]),
                            dpi=600)
                plt.close()
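
satmut_seqs and delta_matrix are defined elsewhere in the repository. As a hedged sketch of what the saturated mutagenesis input generation involves -- the reference sequence first, then every single-nucleotide substitution within the centered window -- assuming one-hot input of shape (1, length, 4); the real satmut_seqs may order mutants differently:

import numpy as np

def satmut_seqs_sketch(seqs_1hot, satmut_len):
    # enumerate all single-nt substitutions in the centered window
    seq_len = seqs_1hot.shape[1]
    mut_start = (seq_len - satmut_len) // 2
    mut_end = mut_start + satmut_len

    sat_list = [seqs_1hot[0]]  # reference sequence first
    for pos in range(mut_start, mut_end):
        for nt in range(4):
            if seqs_1hot[0, pos, nt] == 0:  # skip the reference base
                mut = np.copy(seqs_1hot[0])
                mut[pos, :] = 0
                mut[pos, nt] = 1
                sat_list.append(mut)
    return np.array(sat_list)
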
Ejemplo n.º 14
0
def main():
    usage = (
        "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
        " <vcf_file>"
    )
    parser = OptionParser(usage)
    parser.add_option(
        "-a",
        dest="all_sed",
        default=False,
        action="store_true",
        help="Print all variant-gene pairs, as opposed to only nonzero [Default: %default]",
    )
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sed",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=0.125,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-x",
        dest="tss_table",
        default=False,
        action="store_true",
        help="Print TSS table in addition to gene [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files, genes HDF5 file, and QTL VCF"
            " file"
        )

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print("Intersecting gene sequences with SNPs...", end="")
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs, vision_p=0.5)
    print("%d sequences w/ SNPs" % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job and gene_data.num_targets is not None:
        job["num_targets"] = gene_data.num_targets

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the \
            parameters file.",
            file=sys.stderr,
        )
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
            target_labels = [""] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split("\t")
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "gene",
        "tss_dist",
        "ref_pred",
        "alt_pred",
        "sed",
        "ser",
        "target_index",
        "target_id",
        "target_label",
    )
    if options.csv:
        sed_gene_out = open("%s/sed_gene.csv" % options.out_dir, "w")
        print(",".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open("%s/sed_gene.txt" % options.out_dir, "w")
        print(" ".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(
                    gene_seq, gene_data.seqs_1hot[seq_i], seq_snps
                )

                # initialize batcher
                batcher_gene = batcher.Batcher(
                    aseqs_1hot, batch_size=model.hp.batch_size
                )

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    tss_radius=options.tss_radius,
                )

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1)
                )

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                tss_sed = alt_tss_preds - ref_tss_preds
                tss_ser = np.log2(alt_tss_preds + options.log_pseudo) - np.log2(
                    ref_tss_preds + options.log_pseudo
                )

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius
                )
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(
                        alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius
                    )
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) - np.log2(
                    ref_gene_preds + options.log_pseudo
                )

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist
                            )
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                    tss_sed[snp_i, tss_i, ti], 0, atol=1e-4
                                ):

                                    # print
                                    cols = (
                                        snp.rsid,
                                        bvcf.cap_allele(snp.ref_allele),
                                        bvcf.cap_allele(snp.alt_alleles[0]),
                                        tss.identifier,
                                        snp_dist,
                                        ref_tss_preds[tss_i, ti],
                                        alt_tss_preds[snp_i, tss_i, ti],
                                        tss_sed[snp_i, tss_i, ti],
                                        tss_ser[snp_i, tss_i, ti],
                                        ti,
                                        target_ids[ti],
                                        target_labels[ti],
                                    )
                                    if options.csv:
                                        print(
                                            ",".join([str(c) for c in cols]),
                                            file=sed_tss_out,
                                        )
                                    else:
                                        print(
                                            "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                            % cols,
                                            file=sed_tss_out,
                                        )

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = "%s_multi" % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                gene_sed[snp_i, gi, ti], 0, atol=1e-4
                            ):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str,
                                    snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti],
                                    ti,
                                    target_ids[ti],
                                    target_labels[ti],
                                ]
                                if options.csv:
                                    print(
                                        ",".join([str(c) for c in cols]),
                                        file=sed_gene_out,
                                    )
                                else:
                                    print(
                                        "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                        % tuple(cols),
                                        file=sed_gene_out,
                                    )

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
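
For reference, the SED and SER scores used above reduce to a plain difference and a pseudocounted log2 ratio between alternate and reference predictions. A toy numeric example:

import numpy as np

ref_pred = np.array([2.0, 0.5])
alt_pred = np.array([3.0, 0.25])
log_pseudo = 0.125  # matches the --pseudo default above

sed = alt_pred - ref_pred  # difference: [1.0, -0.25]
ser = np.log2(alt_pred + log_pseudo) - np.log2(ref_pred + log_pseudo)
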
Ejemplo n.º 15
0
def run(params_file, data_file, train_epochs, train_epoch_batches,
        test_epoch_batches):

    #######################################################
    # load data
    #######################################################
    data_open = h5py.File(data_file, 'r')

    train_seqs = data_open['train_in']
    train_targets = data_open['train_out']
    train_na = None
    if 'train_na' in data_open:
        train_na = data_open['train_na']

    valid_seqs = data_open['valid_in']
    valid_targets = data_open['valid_out']
    valid_na = None
    if 'valid_na' in data_open:
        valid_na = data_open['valid_na']

    #######################################################
    # model parameters and placeholders
    #######################################################
    job = params.read_job_params(params_file)

    job['seq_length'] = train_seqs.shape[1]
    job['seq_depth'] = train_seqs.shape[2]
    job['num_targets'] = train_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

    t0 = time.time()
    model = seqnn.SeqNN()
    model.build(job)
    print('Model building time %f' % (time.time() - t0))

    # adjust for fourier
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        train_targets_imag = data_open['train_out_imag']
        valid_targets_imag = data_open['valid_out_imag']

    #######################################################
    # prepare batcher
    #######################################################
    if job['fourier']:
        batcher_train = batcher.BatcherF(train_seqs,
                                         train_targets,
                                         train_targets_imag,
                                         train_na,
                                         model.hp.batch_size,
                                         model.hp.target_pool,
                                         shuffle=True)
        batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                         valid_targets_imag, valid_na,
                                         model.hp.batch_size,
                                         model.hp.target_pool)
    else:
        batcher_train = batcher.Batcher(train_seqs,
                                        train_targets,
                                        train_na,
                                        model.hp.batch_size,
                                        model.hp.target_pool,
                                        shuffle=True)
        batcher_valid = batcher.Batcher(valid_seqs, valid_targets, valid_na,
                                        model.hp.batch_size,
                                        model.hp.target_pool)
    print('Batcher initialized')

    #######################################################
    # train
    #######################################################
    augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')]
    ensemble_shifts = [
        int(shift) for shift in FLAGS.ensemble_shifts.split(',')
    ]

    # checkpoints
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    if FLAGS.log_device_placement:
        config.log_device_placement = True
    with tf.Session(config=config) as sess:
        t0 = time.time()

        # set seed
        tf.set_random_seed(FLAGS.seed)

        if FLAGS.logdir:
            train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                 sess.graph)
        else:
            train_writer = None

        if FLAGS.restart:
            # load variables into session
            saver.restore(sess, FLAGS.restart)
        else:
            # initialize variables
            print('Initializing...')
            sess.run(tf.global_variables_initializer())
            print('Initialization time %f' % (time.time() - t0))

        train_loss = None
        best_loss = None
        early_stop_i = 0

        epoch = 0
        while (train_epochs is None
               or epoch < train_epochs) and early_stop_i < FLAGS.early_stop:
            t0 = time.time()

            # alternate forward and reverse batches
            fwdrc = True
            if FLAGS.augment_rc and epoch % 2 == 1:
                fwdrc = False

            # cycle shifts
            shift_i = epoch % len(augment_shifts)

            # train
            train_loss, steps = model.train_epoch(
                sess,
                batcher_train,
                fwdrc=fwdrc,
                shift=augment_shifts[shift_i],
                sum_writer=train_writer,
                epoch_batches=train_epoch_batches,
                no_steps=FLAGS.no_steps)

            # validate
            valid_acc = model.test(sess,
                                   batcher_valid,
                                   mc_n=FLAGS.ensemble_mc,
                                   rc=FLAGS.ensemble_rc,
                                   shifts=ensemble_shifts,
                                   test_batches=test_epoch_batches)
            valid_loss = valid_acc.loss
            valid_r2 = valid_acc.r2().mean()
            del valid_acc

            best_str = ''
            if best_loss is None or valid_loss < best_loss:
                best_loss = valid_loss
                best_str = ', best!'
                early_stop_i = 0
                saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
            else:
                early_stop_i += 1

            # measure time
            et = time.time() - t0
            if et < 600:
                time_str = '%3ds' % et
            elif et < 6000:
                time_str = '%3dm' % (et / 60)
            else:
                time_str = '%3.1fh' % (et / 3600)

            # print update
            print(
                'Epoch: %3d,  Steps: %7d,  Train loss: %7.5f,  Valid loss: %7.5f,  Valid R2: %7.5f,  Time: %s%s'
                % (epoch + 1, steps, train_loss, valid_loss, valid_r2,
                   time_str, best_str))
            sys.stdout.flush()

            if FLAGS.check_all:
                saver.save(sess, '%s/model_check%d.tf' % (FLAGS.logdir, epoch))

            # update epoch
            epoch += 1

        if FLAGS.logdir:
            train_writer.close()
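
The epoch loop above combines best-model checkpointing with early stopping: the counter resets whenever validation loss improves and training stops once it reaches FLAGS.early_stop. A compact sketch of just that bookkeeping, with toy losses:

best_loss = None
early_stop_i = 0
early_stop = 12  # stand-in for FLAGS.early_stop

for valid_loss in [0.90, 0.80, 0.85, 0.84, 0.79]:  # toy validation losses
    if best_loss is None or valid_loss < best_loss:
        best_loss, early_stop_i = valid_loss, 0  # would save a checkpoint here
    else:
        early_stop_i += 1
    if early_stop_i >= early_stop:
        break
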
Ejemplo n.º 16
0
def score_write(sess, model, options, seqs_1hot, seqs_chrom, seqs_start):
  ''' Compute scores and write them as BigWigs for a set of sequences. '''

  for si in range(seqs_1hot.shape[0]):
    # initialize batcher
    batcher_si = batcher.Batcher(seqs_1hot[si:si+1],
                                 batch_size=model.hp.batch_size,
                                 pool_width=model.hp.target_pool)

    # get layer representations
    t0 = time.time()
    print('Computing gradients.', end='', flush=True)
    _, _, _, batch_grads, batch_reprs, _ = model.gradients(sess, batcher_si,
      rc=options.rc, shifts=options.shifts, mc_n=options.mc_n, return_all=True)
    print(' Done in %ds.' % (time.time()-t0), flush=True)

    # only layer
    batch_reprs = batch_reprs[0]
    batch_grads = batch_grads[0]

    # cast up to float32 for more precise downstream arithmetic
    batch_reprs = batch_reprs.astype('float32')
    batch_grads = batch_grads.astype('float32')

    # S (sequences) x T (targets) x P (seq position) x U (units layer i) x E (ensembles)
    print('batch_grads', batch_grads.shape)
    pooled_length = batch_grads.shape[2]

    # S (sequences) x P (seq position) x U (Units layer i) x E (ensembles)
    print('batch_reprs', batch_reprs.shape)

    # write bigwigs
    t0 = time.time()
    print('Writing BigWigs.', end='', flush=True)

    # for each target
    for tii in range(len(options.target_indexes)):
      ti = options.target_indexes[tii]

      # compute scores
      if options.norm is None:
        batch_grads_scores = np.multiply(batch_reprs[0], batch_grads[0,tii,:,:,:]).sum(axis=1)
      else:
        batch_grads_scores = np.multiply(batch_reprs[0], batch_grads[0,tii,:,:,:])
        batch_grads_scores = np.power(np.abs(batch_grads_scores), options.norm)
        batch_grads_scores = batch_grads_scores.sum(axis=1)
        batch_grads_scores = np.power(batch_grads_scores, 1./options.norm)

      # compute score statistics
      batch_grads_mean = batch_grads_scores.mean(axis=1)

      # two-sided one-sample t-test p-values; halve for one-sided on norm scores
      batch_grads_pval = ttest_1samp(batch_grads_scores, 0, axis=1)[1]
      if options.norm is not None:
        batch_grads_pval /= 2

      # open bigwig
      bws_file = '%s/s%d_t%d_scores.bw' % (options.out_dir, si, ti)
      bwp_file = '%s/s%d_t%d_pvals.bw' % (options.out_dir, si, ti)
      bws_open = bigwig_open(bws_file, options.genome_file)
      # bwp_open = bigwig_open(bwp_file, options.genome_file)

      # specify bigwig locations and values
      bw_chroms = [seqs_chrom[si]]*pooled_length
      bw_starts = [int(seqs_start[si] + pi*model.hp.target_pool) for pi in range(pooled_length)]
      bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts]
      bws_values = [float(bgs) for bgs in batch_grads_mean]
      # bwp_values = [float(bgp) for bgp in batch_grads_pval]

      # write
      bws_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bws_values)
      # bwp_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bwp_values)

      # close
      bws_open.close()
      # bwp_open.close()

    print(' Done in %ds.' % (time.time()-t0), flush=True)
    gc.collect()
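
The score above is input-times-gradient reduced across units, either by a plain sum or by an L^p norm when options.norm is set, then averaged over the ensemble axis. A toy version with made-up shapes:

import numpy as np

reprs = np.random.randn(100, 32, 5)  # (positions, units, ensembles), toy shape
grads = np.random.randn(100, 32, 5)
p = 2.0  # stand-in for options.norm

contrib = reprs * grads                                   # per-unit contributions
scores = (np.abs(contrib) ** p).sum(axis=1) ** (1.0 / p)  # L^p over units
score_mean = scores.mean(axis=1)                          # average across the ensemble
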
Ejemplo n.º 17
0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-b',
                      dest='batch_size',
                      default=256,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option('--h5',
                      dest='out_h5',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.h5 [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD,xSAR',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option('-z',
                      dest='out_zarr',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # setup model

    job = params.read_job_params(
        params_file, require=['seq_length', 'num_targets', 'target_pool'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(job,
                     ensemble_rc=options.rc,
                     ensemble_shifts=options.shifts,
                     embed_penultimate=options.penultimate,
                     target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype='int')
        snps = snps[worker_bounds[worker_index]:
                    worker_bounds[worker_index + 1]]

    num_snps = len(snps)

    #################################################################
    # setup output

    header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad', 'sar',
                   'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                   'ref_xpred', 'alt_xpred', 'xsad', 'xsar', 'target_index',
                   'target_id', 'target_label')

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sad_out)
        else:
            sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job['seq_length'], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # initialize batcher (avoid shadowing the batcher module)
            batcher_snps = batcher.Batcher(batch_1hot,
                                           batch_size=model.hp.batch_size)

            # predict
            # batch_preds = model.predict(sess, batcher_snps,
            #                 rc=options.rc, shifts=options.shifts,
            #                 penultimate=options.penultimate)
            batch_preds = model.predict_h5(sess, batcher_snps)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype='float64')

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir,
                                                               snp.rsid, ti)
                    bigwig_write(snp, job['seq_length'], ref_preds[:, ti],
                                 model, ref_bw_file, options.genome_file)

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype='float64')

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum + options.log_pseudo) \
                            - np.log2(ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \
                                - np.log2(ref_preds.astype('float64') + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \
                                - np.log2(ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        sad_out['SAD'][szi, :] = sad.astype('float16')
                        sad_out['xSAR'][szi, :] = np.array(
                            [sar_vec[max_li[ti], ti] for ti in range(num_targets)],
                            dtype='float16')
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(alt_al), ref_preds_sum[ti],
                                    alt_preds_sum[ti], sad[ti], sar[ti],
                                    geo_sad[ti], ref_preds_loc[ti],
                                    alt_preds_loc[ti], sad_loc[ti],
                                    sar_loc[ti], ref_preds[max_li[ti], ti],
                                    alt_preds[max_li[ti],
                                              ti], sad_vec[max_li[ti], ti],
                                    sar_vec[max_li[ti], ti], ti,
                                    target_ids[ti], target_labels[ti])
                            if options.csv:
                                print(','.join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s'
                                    % cols,
                                    file=sad_out)

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % (
                            options.out_dir, snp.rsid, ti)
                        bigwig_write(snp, job['seq_length'], alt_preds[:, ti],
                                     model, alt_bw_file, options.genome_file)

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job['seq_length'],
                genome_open)

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset('percentiles', data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = '%s_pct' % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype('float16')

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    if not options.out_zarr:
        sad_out.close()
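
The local SAD window above is computed in pooled bins around the sequence center. A quick arithmetic check with plausible stand-in values for target_length, target_pool, and the --local default:

target_length, target_pool, local = 1024, 128, 1024  # assumed toy values

loc_mid = target_length // 2                       # 512
loc_start = loc_mid - (local // 2) // target_pool  # 512 - 4 = 508
loc_end = loc_start + local // target_pool         # 508 + 8 = 516
assert (loc_start, loc_end) == (508, 516)          # the central 8 bins
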
Ejemplo n.º 18
0
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=None,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-i',
      dest='ignore_bed',
      help='Ignore genes overlapping regions in this BED file')
  parser.add_option(
      '-l', dest='load_preds', help='Load tss_preds from file')
  parser.add_option(
      '--heat',
      dest='plot_heat',
      default=False,
      action='store_true',
      help='Plot big gene-target heatmaps [Default: %default]')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='genes_out',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-r',
      dest='tss_radius',
      default=0,
      type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '-s',
      dest='plot_scatter',
      default=False,
      action='store_true',
      help='Make time-consuming accuracy scatter plots [Default: %default]')
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '--rep',
      dest='replicate_labels_file',
      help=
      'Compare replicate experiments, aided by the given file with long labels')
  parser.add_option(
      '-t',
      dest='target_indexes',
      default=None,
      help=
      'File or Comma-separated list of target indexes to scatter plot true versus predicted values'
  )
  parser.add_option(
      '--table',
      dest='print_tables',
      default=False,
      action='store_true',
      help='Print big gene/TSS tables [Default: %default]')
  parser.add_option(
      '--tss',
      dest='tss_alt',
      default=False,
      action='store_true',
      help='Perform alternative TSS analysis [Default: %default]')
  parser.add_option(
      '-v',
      dest='gene_variance',
      default=False,
      action='store_true',
      help=
      'Study accuracy with respect to gene variance across targets [Default: %default]'
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters file, model file, and genes HDF5 file')
  else:
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes and targets

  gene_data = genedata.GeneData(genes_hdf5_file)


  #################################################################
  # TSS predictions

  if options.load_preds is not None:
    # load from file
    tss_preds = np.load(options.load_preds)

  else:

    #######################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width
    if 'num_targets' not in job:
      job['num_targets'] = gene_data.num_targets

    # build model
    model = seqnn.SeqNN()
    model.build(job)

    if options.batch_size is not None:
      model.hp.batch_size = options.batch_size


    #######################################################
    # predict TSSs

    t0 = time.time()
    print('Computing gene predictions.', end='')
    sys.stdout.flush()

    # initialize batcher
    gene_batcher = batcher.Batcher(
        gene_data.seqs_1hot, batch_size=model.hp.batch_size)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
      # load variables into session
      saver.restore(sess, model_file)

      # predict
      tss_preds = model.predict_genes(sess, gene_batcher, gene_data.gene_seqs,
          rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius)

    # save to file
    np.save('%s/preds' % options.out_dir, tss_preds)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # convert to genes

  gene_targets, _ = gene.map_tss_genes(gene_data.tss_targets, gene_data.tss,
                                       tss_radius=options.tss_radius)
  gene_preds, _ = gene.map_tss_genes(tss_preds, gene_data.tss,
                                     tss_radius=options.tss_radius)


  #################################################################
  # determine targets

  # all targets
  if options.target_indexes is None:
    if gene_data.num_targets is None:
      print('No targets to test against')
      exit()
    else:
      options.target_indexes = np.arange(gene_data.num_targets)

  # file targets
  elif os.path.isfile(options.target_indexes):
    target_indexes_file = options.target_indexes
    options.target_indexes = []
    for line in open(target_indexes_file):
      options.target_indexes.append(int(line.split()[0]))

  # comma-separated targets
  else:
    options.target_indexes = [
        int(ti) for ti in options.target_indexes.split(',')
    ]

  options.target_indexes = np.array(options.target_indexes)
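
  # for reference, the three accepted '-t' forms resolve as follows
  # (illustrative values, not from the source):
  #   unset          -> np.arange(num_targets), i.e. all targets
  #   targets.txt    -> first whitespace field of each line, parsed as int
  #   '0,5,12'       -> np.array([0, 5, 12])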

  #################################################################
  # correlation statistics

  t0 = time.time()
  print('Computing correlations.', end='')
  sys.stdout.flush()

  cor_table(gene_data.tss_targets, tss_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/tss_cors.txt' % options.out_dir)

  cor_table(gene_targets, gene_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/gene_cors.txt' % options.out_dir, plots=True)

  print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene statistics

  if options.print_tables:
    t0 = time.time()
    print('Printing predictions.', end='')
    sys.stdout.flush()

    gene_table(gene_data.tss_targets, tss_preds, gene_data.tss_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/transcript' % options.out_dir, options.plot_scatter)

    gene_table(gene_targets, gene_preds,
               gene_data.gene_ids(), gene_data.target_labels,
               options.target_indexes, '%s/gene' % options.out_dir,
               options.plot_scatter)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene x target heatmaps

  if options.plot_heat or options.gene_variance:
    #########################################
    # normalize predictions across targets

    t0 = time.time()
    print('Normalizing values across targets.', end='')
    sys.stdout.flush()

    gene_targets_qn = normalize_targets(
        gene_targets[:, options.target_indexes], log_pseudo=1)
    gene_preds_qn = normalize_targets(
        gene_preds[:, options.target_indexes], log_pseudo=1)

    print(' Done in %ds.' % (time.time() - t0))

  if options.plot_heat:
    #########################################
    # plot genes by targets clustermap

    t0 = time.time()
    print('Plotting heat maps.', end='')
    sys.stdout.flush()

    sns.set(font_scale=1.3, style='ticks')
    plot_genes = 1600
    plot_targets = 800

    # choose a set of variable genes
    gene_vars = gene_preds_qn.var(axis=1)
    indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

    # choose a set of random genes
    if plot_genes < gene_preds_qn.shape[0]:
      indexes_rand = np.random.choice(
        np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False)
    else:
      indexes_rand = np.arange(gene_preds_qn.shape[0])

    # choose a set of random targets
    if plot_targets < 0.8 * gene_preds_qn.shape[1]:
      indexes_targets = np.random.choice(
          np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False)
    else:
      indexes_targets = np.arange(gene_preds_qn.shape[1])

    # variable gene predictions
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var.pdf' % options.out_dir)
    clustermap(
        gene_preds_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_heat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene predictions
    clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
               '%s/gene_heat_rand.pdf' % options.out_dir)

    # variable gene targets
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var.pdf' % options.out_dir)
    clustermap(
        gene_targets_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_theat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene targets (crashes)
    # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
    #            '%s/gene_theat_rand.pdf' % options.out_dir)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # analyze replicates

  if options.replicate_labels_file is not None:
    # read long form labels, from which to infer replicates
    target_labels_long = []
    for line in open(options.replicate_labels_file):
      a = line.split('\t')
      a[-1] = a[-1].rstrip()
      target_labels_long.append(a[-1])

    # determine replicates
    replicate_lists = infer_replicates(target_labels_long)

    # compute correlations
    # replicate_correlations(replicate_lists, gene_data.tss_targets,
        # tss_preds, options.target_indexes, '%s/transcript_reps' % options.out_dir)
    replicate_correlations(
        replicate_lists, gene_targets, gene_preds, options.target_indexes,
        '%s/gene_reps' % options.out_dir)  # , scatter_plots=True)

  #################################################################
  # gene variance

  if options.gene_variance:
    variance_accuracy(gene_targets_qn, gene_preds_qn,
                      '%s/gene' % options.out_dir)

  #################################################################
  # alternative TSS

  if options.tss_alt:
    alternative_tss(gene_data.tss_targets[:, options.target_indexes],
                    tss_preds[:, options.target_indexes], gene_data,
                    options.out_dir, log_pseudo=1)
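
normalize_targets and the table/plot helpers are defined elsewhere in this repository. A minimal sketch of the normalization this script assumes (log transform with a pseudocount, then quantile normalization across targets), illustrative only:

import numpy as np

def normalize_targets_sketch(gene_values, log_pseudo=1):
    '''Quantile-normalize a genes x targets matrix after a log transform.'''
    # log transform with a pseudocount
    log_values = np.log2(gene_values + log_pseudo)

    # quantile normalize: give every target (column) the same distribution
    # by mapping each column's ranks onto the mean sorted profile
    ranks = log_values.argsort(axis=0).argsort(axis=0)
    mean_sorted = np.sort(log_values, axis=0).mean(axis=1)
    return mean_sorted[ranks]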
Example #19
def score_write(sess, model, options, target_indexes, seqs_1hot, seqs_chrom,
                seqs_start):
    ''' Compute scores and write them as BigWigs for a set of sequences. '''

    num_seqs = seqs_1hot.shape[0]
    num_targets = len(target_indexes)

    # initialize scores HDF5
    scores_h5_file = '%s/scores.h5' % options.out_dir
    scores_h5_out = h5py.File(scores_h5_file, 'w')

    for si in range(num_seqs):
        # initialize batcher
        batcher_si = batcher.Batcher(seqs_1hot[si:si + 1],
                                     batch_size=model.hp.batch_size,
                                     pool_width=model.hp.target_pool)

        # get layer representations
        t0 = time.time()
        print('Computing gradients.', end='', flush=True)
        _, _, _, batch_grads, batch_reprs, _ = model.gradients(
            sess,
            batcher_si,
            rc=options.rc,
            shifts=options.shifts,
            mc_n=options.mc_n,
            return_all=True)
        print(' Done in %ds.' % (time.time() - t0), flush=True)

        # only layer
        batch_reprs = batch_reprs[0]
        batch_grads = batch_grads[0]

        # cast up to float32 precision
        batch_reprs = batch_reprs.astype('float32')
        batch_grads = batch_grads.astype('float32')

        # S (sequences) x T (targets) x P (seq position) x U (units layer i) x E (ensembles)
        print('batch_grads', batch_grads.shape)

        # S (sequences) x P (seq position) x U (units layer i) x E (ensembles)
        print('batch_reprs', batch_reprs.shape)

        preds_length = batch_reprs.shape[1]
        if 'score' not in scores_h5_out:
            # initialize scores
            scores_h5_out.create_dataset('score',
                                         shape=(num_seqs, preds_length,
                                                num_targets),
                                         dtype='float16')
            scores_h5_out.create_dataset('pvalue',
                                         shape=(num_seqs, preds_length,
                                                num_targets),
                                         dtype='float16')

        # write bigwigs
        t0 = time.time()
        print('Computing and writing scores.', end='', flush=True)

        # for each target
        for tii, ti in enumerate(target_indexes):

            # representation x gradient
            batch_grads_scores = np.multiply(batch_reprs[0],
                                             batch_grads[0, tii, :, :, :])

            if options.norm is None:
                # sum across filters
                batch_grads_scores = batch_grads_scores.sum(axis=1)
            else:
                # raise to power
                batch_grads_scores = np.power(np.abs(batch_grads_scores),
                                              options.norm)
                # sum across filters
                batch_grads_scores = batch_grads_scores.sum(axis=1)
                # normalize w/ 1/power
                batch_grads_scores = np.power(batch_grads_scores,
                                              1. / options.norm)
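                # i.e. the p-norm across the U filter units at each
                # position: (sum_u |r_u * g_u|^p) ** (1 / p)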

            # mean across ensemble
            batch_grads_mean = batch_grads_scores.mean(axis=1)

            # compute p-values with a one-sample t-test across the ensemble
            batch_grads_pval = ttest_1samp(batch_grads_scores, 0, axis=1)[1]
            if options.norm is not None:
                # normed scores are nonnegative; halve for a one-sided test
                batch_grads_pval /= 2

            # write to HDF5
            scores_h5_out['score'][si, :,
                                   tii] = batch_grads_mean.astype('float16')
            scores_h5_out['pvalue'][si, :,
                                    tii] = batch_grads_pval.astype('float16')

            if options.bigwig:
                # open bigwig
                bws_file = '%s/s%d_t%d_scores.bw' % (options.out_dir, si, ti)
                bwp_file = '%s/s%d_t%d_pvals.bw' % (options.out_dir, si, ti)
                bws_open = bigwig_open(bws_file, options.genome_file)
                # bwp_open = bigwig_open(bwp_file, options.genome_file)

                # specify bigwig locations and values
                bw_chroms = [seqs_chrom[si]] * preds_length
                bw_starts = [
                    int(seqs_start[si] + pi * model.hp.target_pool)
                    for pi in range(preds_length)
                ]
                bw_ends = [
                    int(bws + model.hp.target_pool) for bws in bw_starts
                ]
                bws_values = [float(bgs) for bgs in batch_grads_mean]
                # bwp_values = [float(bgp) for bgp in batch_grads_pval]

                # write
                bws_open.addEntries(bw_chroms,
                                    bw_starts,
                                    ends=bw_ends,
                                    values=bws_values)
                # bwp_open.addEntries(bw_chroms, bw_starts, ends=bw_ends, values=bwp_values)

                # close bigwig
                bws_open.close()
                # bwp_open.close()

        print(' Done in %ds.' % (time.time() - t0), flush=True)
        gc.collect()

    scores_h5_out.close()
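
The datasets written above ('score' and 'pvalue', each num_seqs x preds_length x num_targets in float16) read back directly with h5py. A short sketch, using a hypothetical output directory grad_out:

import h5py

with h5py.File('grad_out/scores.h5', 'r') as scores_h5:
    scores = scores_h5['score'][:]    # (num_seqs, preds_length, num_targets)
    pvalues = scores_h5['pvalue'][:]

# e.g. the ten highest-scoring bins for sequence 0, target column 0
top_bins = scores[0, :, 0].argsort()[::-1][:10]
print(top_bins, pvalues[0, top_bins, 0])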