Example #1
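These Basenji variant-scoring excerpts omit their import preambles. A plausible set for Example #1 follows; the basenji module paths are best-effort assumptions and may differ by version:

from optparse import OptionParser
import gc
import os
import pickle
import sys

import numpy as np
import tensorflow as tf

# assumed Basenji modules; the listing does not show the real import lines
from basenji import batcher, gene, genedata, params, seqnn
import basenji.vcf as bvcf
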
def main():
    usage = (
        'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
        ' <vcf_file>')
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='all_sed',
        default=False,
        action='store_true',
        help=
        'Print all variant-gene pairs, as opposed to only nonzero [Default: %default]'
    )
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sed',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=0.125,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '-r',
        dest='tss_radius',
        default=0,
        type='int',
        help=
        'Radius of bins considered to quantify TSS transcription [Default: %default]'
    )
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average the forward and reverse complement predictions when testing [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option(
        '-x',
        dest='tss_table',
        default=False,
        action='store_true',
        help='Print TSS table in addition to gene [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files, genes HDF5 file, and QTL VCF'
            ' file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print('Intersecting gene sequences with SNPs...', end='')
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file,
                                         gene_data.gene_seqs,
                                         vision_p=0.5)
    print('%d sequences w/ SNPs' % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width

    if 'num_targets' not in job and gene_data.num_targets is not None:
        job['num_targets'] = gene_data.num_targets

    if 'num_targets' not in job:
        print("Must specify number of targets (num_targets) in the \
            parameters file.",
              file=sys.stderr)
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
            target_labels = [''] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split('\t')
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

        # a subset covering every target is no subset at all;
        # check after the loop so later lines cannot hit a None subset
        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                   'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                   'target_label')
    if options.csv:
        sed_gene_out = open('%s/sed_gene.csv' % options.out_dir, 'w')
        print(','.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open('%s/sed_gene.txt' % options.out_dir, 'w')
        print(' '.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i],
                                          seq_snps)

                # initialize batcher
                batcher_gene = batcher.Batcher(aseqs_1hot,
                                               batch_size=model.hp.batch_size)

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    penultimate=options.penultimate,
                    tss_radius=options.tss_radius)

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1))

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                snp_tss_sed = alt_tss_preds - ref_tss_preds
                snp_tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
                                - np.log2(ref_tss_preds + options.log_pseudo)

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius)
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i],
                                                gene_seq.tss_list,
                                                options.tss_radius)
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
                            - np.log2(ref_gene_preds + options.log_pseudo)

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist)
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                        snp_tss_sed[snp_i, tss_i, ti],
                                        0, atol=1e-4):

                                    # print
                                    cols = (snp.rsid,
                                            bvcf.cap_allele(snp.ref_allele),
                                            bvcf.cap_allele(snp.alt_alleles[0]),
                                            tss.identifier, snp_dist,
                                            ref_tss_preds[tss_i, ti],
                                            alt_tss_preds[snp_i, tss_i, ti],
                                            snp_tss_sed[snp_i, tss_i, ti],
                                            snp_tss_ser[snp_i, tss_i, ti],
                                            ti, target_ids[ti],
                                            target_labels[ti])
                                    if options.csv:
                                        print(','.join([str(c) for c in cols]),
                                              file=sed_tss_out)
                                    else:
                                        print(
                                            '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                            % cols,
                                            file=sed_tss_out)

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = '%s_multi' % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                    gene_sed[snp_i, gi, ti], 0, atol=1e-4):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str, snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti], ti,
                                    target_ids[ti], target_labels[ti]
                                ]
                                if options.csv:
                                    print(','.join([str(c) for c in cols]),
                                          file=sed_gene_out)
                                else:
                                    print(
                                        '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                        % tuple(cols),
                                        file=sed_gene_out)

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
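
Example #1 also calls an alleles_1hot helper that the listing does not show. Below is a minimal sketch of the behavior the call site implies, assuming single-nucleotide variants, an ACGT one-hot column order, 1-based VCF positions, and a gene_seq.start attribute (all assumptions, not the original implementation):

def alleles_1hot(gene_seq, seq_1hot, seq_snps):
    """Stack the reference 1-hot sequence with one copy per SNP in which
    the SNP position carries the alternate allele (hypothetical sketch)."""
    nt_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

    aseqs_1hot = np.zeros((1 + len(seq_snps),) + seq_1hot.shape,
                          dtype=seq_1hot.dtype)
    aseqs_1hot[0] = seq_1hot  # allele 0 is the reference

    for si, snp in enumerate(seq_snps):
        aseqs_1hot[si + 1] = seq_1hot

        # SNP position within this gene sequence
        spos = snp.pos - 1 - gene_seq.start

        # swap in the (first) alternate allele at that position
        aseqs_1hot[si + 1, spos, :] = 0
        aseqs_1hot[si + 1, spos, nt_index[snp.alt_alleles[0][0]]] = 1

    return aseqs_1hot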
Example #2
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=256,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-f",
        dest="genome_fasta",
        default="%s/data/hg19.fa" % os.environ["BASENJIDIR"],
        help="Genome FASTA for sequences [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "--h5",
        dest="out_h5",
        default=False,
        action="store_true",
        help="Output stats to sad.h5 [Default: %default]",
    )
    parser.add_option(
        "--local",
        dest="local",
        default=1024,
        type="int",
        help="Local SAD score [Default: %default]",
    )
    parser.add_option("-n",
                      dest="norm_file",
                      default=None,
                      help="Normalize SAD scores")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sad",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=1,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help=
        "Average forward and reverse complement predictions [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        type="str",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--stats",
        dest="sad_stats",
        default="SAD,xSAR",
        help="Comma-separated list of stats to save. [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        default=None,
        type="str",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-z",
        dest="out_zarr",
        default=False,
        action="store_true",
        help="Output stats to sad.zarr [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files and QTL VCF file")

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(",")
        ]
        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]
    options.sad_stats = options.sad_stats.split(",")

    #################################################################
    # setup model

    job = params.read_job_params(
        params_file, require=["seq_length", "num_targets", "target_pool"])

    if options.targets_file is None:
        target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
        target_labels = [""] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(
        job,
        ensemble_rc=options.rc,
        ensemble_shifts=options.shifts,
        embed_penultimate=options.penultimate,
        target_subset=target_subset,
    )
    print("Model building time %f" % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype="int")
        snps = snps[worker_bounds[worker_index]:
                    worker_bounds[worker_index + 1]]

    num_snps = len(snps)

    #################################################################
    # setup output

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "ref_pred",
        "alt_pred",
        "sad",
        "sar",
        "geo_sad",
        "ref_lpred",
        "alt_lpred",
        "lsad",
        "lsar",
        "ref_xpred",
        "alt_xpred",
        "xsad",
        "xsar",
        "target_index",
        "target_id",
        "target_label",
    )

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open("%s/sad_table.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sad_out)
        else:
            sad_out = open("%s/sad_table.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job["seq_length"], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # initialize batcher; use a distinct name so the basenji
            # batcher module is not shadowed on later loop iterations
            snp_batcher = batcher.Batcher(batch_1hot,
                                          batch_size=model.hp.batch_size)

            # predict
            # batch_preds = model.predict(sess, batcher,
            #                 rc=options.rc, shifts=options.shifts,
            #                 penultimate=options.penultimate)
            batch_preds = model.predict_h5(sess, snp_batcher)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype="float64")

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = "%s/tracks/%s_t%d_ref.bw" % (
                        options.out_dir,
                        snp.rsid,
                        ti,
                    )
                    bigwig_write(
                        snp,
                        job["seq_length"],
                        ref_preds[:, ti],
                        model,
                        ref_bw_file,
                        options.genome_file,
                    )

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype="float64")

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum +
                                  options.log_pseudo) - np.log2(
                                      ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(
                        alt_preds.astype("float64") +
                        options.log_pseudo) - np.log2(
                            ref_preds.astype("float64") + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc +
                                      options.log_pseudo) - np.log2(
                                          ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        sad_out["SAD"][szi, :] = sad.astype("float16")
                        sad_out["xSAR"][szi, :] = np.array(
                            [
                                sar_vec[max_li[ti], ti]
                                for ti in range(num_targets)
                            ],
                            dtype="float16",
                        )
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (
                                snp.rsid,
                                bvcf.cap_allele(snp.ref_allele),
                                bvcf.cap_allele(alt_al),
                                ref_preds_sum[ti],
                                alt_preds_sum[ti],
                                sad[ti],
                                sar[ti],
                                geo_sad[ti],
                                ref_preds_loc[ti],
                                alt_preds_loc[ti],
                                sad_loc[ti],
                                sar_loc[ti],
                                ref_preds[max_li[ti], ti],
                                alt_preds[max_li[ti], ti],
                                sad_vec[max_li[ti], ti],
                                sar_vec[max_li[ti], ti],
                                ti,
                                target_ids[ti],
                                target_labels[ti],
                            )
                            if options.csv:
                                print(",".join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    "%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s"
                                    % cols,
                                    file=sad_out,
                                )

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = "%s/tracks/%s_t%d_alt.bw" % (
                            options.out_dir,
                            snp.rsid,
                            ti,
                        )
                        bigwig_write(
                            snp,
                            job["seq_length"],
                            alt_preds[:, ti],
                            model,
                            alt_bw_file,
                            options.genome_file,
                        )

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job["seq_length"],
                genome_open)

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset("percentiles", data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = "%s_pct" % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype("float16")

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype="float16")

    if not options.out_zarr:
        sad_out.close()
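
Examples #2 through #4 drive prediction through a snps_next_batch helper that is not shown. Here is a sketch of the batching contract the call sites imply, assuming bvcf.snp_seq1 returns a (possibly empty) list of 1-hot coded allele sequences for one SNP — an assumption about the basenji.vcf API, not a confirmed signature:

def snps_next_batch(snps, snp_i, batch_size, seq_len, genome_open):
    """Accumulate 1-hot allele sequences for the next batch of SNPs
    (hypothetical sketch). Returns the stacked array, the SNPs covered,
    and the index of the next unprocessed SNP."""
    batch_1hot = []
    batch_snps = []

    while len(batch_1hot) < batch_size and snp_i < len(snps):
        # 1-hot code the reference and alternate alleles around this SNP
        snp_1hot = bvcf.snp_seq1(snps[snp_i], seq_len, genome_open)

        # skip SNPs whose sequences could not be constructed
        if len(snp_1hot) > 0:
            batch_1hot += snp_1hot
            batch_snps.append(snps[snp_i])

        snp_i += 1

    return np.array(batch_1hot), batch_snps, snp_i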
Example #3
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=256,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-c',
      dest='csv',
      default=False,
      action='store_true',
      help='Print table as CSV [Default: %default]')
  parser.add_option(
      '-f',
      dest='genome_fasta',
      default='%s/assembly/hg19.fa' % os.environ['HG19'],
      help='Genome FASTA from which sequences will be drawn [Default: %default]'
  )
  parser.add_option(
      '-g',
      dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option(
      '-l',
      dest='seq_len',
      type='int',
      default=131072,
      help='Sequence length provided to the model [Default: %default]')
  parser.add_option(
      '--local',
      dest='local',
      default=1024,
      type='int',
      help='Local SAD score [Default: %default]')
  parser.add_option(
      '-n',
      dest='norm_file',
      default=None,
      help='Normalize SAD scores')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='sad',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-p',
      dest='processes',
      default=None,
      type='int',
      help='Number of processes, passed by multi script')
  parser.add_option(
      '--pseudo',
      dest='log_pseudo',
      default=1,
      type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t',
      dest='targets_file',
      default=None,
      help='File specifying target indexes and labels in table format')
  parser.add_option(
      '--ti',
      dest='track_indexes',
      default=None,
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option(
      '-u',
      dest='penultimate',
      default=False,
      action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 3:
    # single worker
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  elif len(args) == 5:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    vcf_file = args[3]
    worker_index = int(args[4])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters and model files and QTL VCF file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  if options.track_indexes is None:
    options.track_indexes = []
  else:
    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # setup model

  job = basenji.dna_io.read_job_params(params_file)
  job['seq_length'] = options.seq_len

  if 'num_targets' not in job:
    print(
        "Must specify number of targets (num_targets) in the parameters file.",
        file=sys.stderr)
    exit(1)

  if 'target_pool' not in job:
    print(
        "Must specify target pooling (target_pool) in the parameters file.",
        file=sys.stderr)
    exit(1)

  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = ['']*len(target_ids)
    target_subset = None

  else:
    # Unfortunately, this target file differs from some others
    # in that it needs to specify the indexes from the original
    # set. In the future, I should standardize to this version.

    target_ids = []
    target_labels = []
    target_subset = []

    for line in open(options.targets_file):
      a = line.strip().split('\t')
      target_subset.append(int(a[0]))
      target_ids.append(a[1])
      target_labels.append(a[3])

    # a subset covering every target is no subset at all;
    # check after the loop so later lines cannot hit a None subset
    if len(target_subset) == job['num_targets']:
      target_subset = None

  # build model
  t0 = time.time()
  model = basenji.seqnn.SeqNN()
  model.build(job, target_subset=target_subset)
  print('Model building time %f' % (time.time() - t0), flush=True)

  if options.penultimate:
    # labels become inappropriate
    target_ids = ['']*model.cnn_filters[-1]
    target_labels = target_ids

  # read target normalization factors
  target_norms = np.ones(len(target_labels))
  if options.norm_file is not None:
    ti = 0
    for line in open(options.norm_file):
      target_norms[ti] = float(line.strip())
      ti += 1


  #################################################################
  # load SNPs

  snps = bvcf.vcf_snps(vcf_file)

  # filter for worker SNPs
  if options.processes is not None:
    snps = [
        snps[si] for si in range(len(snps))
        if si % options.processes == worker_index
    ]

  #################################################################
  # setup output

  header_cols = ('rsid', 'ref', 'alt',
                  'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad',
                  'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                  'ref_xpred', 'alt_xpred', 'xsad', 'xsar',
                  'target_index', 'target_id', 'target_label')

  if options.csv:
    sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
    print(','.join(header_cols), file=sad_out)
  else:
    sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
    print(' '.join(header_cols), file=sad_out)


  #################################################################
  # process

  # open genome FASTA
  genome_open = pysam.Fastafile(options.genome_fasta)

  # determine local start and end
  loc_mid = model.target_length // 2
  loc_start = loc_mid - (options.local//2) // model.target_pool
  loc_end = loc_start + options.local // model.target_pool

  snp_i = 0

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # construct first batch
    batch_1hot, batch_snps, snp_i = snps_next_batch(
        snps, snp_i, options.batch_size, options.seq_len, genome_open)

    while len(batch_snps) > 0:
      ###################################################
      # predict

      # initialize batcher
      batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.batch_size)

      # predict
      batch_preds = model.predict(sess, batcher,
                      rc=options.rc, shifts=options.shifts,
                      penultimate=options.penultimate)

      # normalize
      batch_preds /= target_norms


      ###################################################
      # collect and print SADs

      pi = 0
      for snp in batch_snps:
        # get reference prediction (LxT)
        ref_preds = batch_preds[pi]
        pi += 1

        # sum across length
        ref_preds_sum = ref_preds.sum(axis=0, dtype='float64')

        # print tracks
        for ti in options.track_indexes:
          ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir, snp.rsid,
                                                     ti)
          bigwig_write(snp, options.seq_len, ref_preds[:, ti], model,
                       ref_bw_file, options.genome_file)

        for alt_al in snp.alt_alleles:
          # get alternate prediction (LxT)
          alt_preds = batch_preds[pi]
          pi += 1

          # sum across length
          alt_preds_sum = alt_preds.sum(axis=0, dtype='float64')

          # compare reference to alternative via mean subtraction
          sad_vec = alt_preds - ref_preds
          sad = alt_preds_sum - ref_preds_sum

          # compare reference to alternative via mean log division
          sar = np.log2(alt_preds_sum + options.log_pseudo) \
                  - np.log2(ref_preds_sum + options.log_pseudo)

          # compare geometric means
          sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \
                      - np.log2(ref_preds.astype('float64') + options.log_pseudo)
          geo_sad = sar_vec.sum(axis=0)

          # sum locally
          ref_preds_loc = ref_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64')
          alt_preds_loc = alt_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64')

          # compute SAD locally
          sad_loc = alt_preds_loc - ref_preds_loc
          sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \
                      - np.log2(ref_preds_loc + options.log_pseudo)

          # compute max difference position
          max_li = np.argmax(np.abs(sar_vec), axis=0)

          # print table lines
          for ti in range(len(sad)):
            # print line
            cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(alt_al),
                    ref_preds_sum[ti], alt_preds_sum[ti], sad[ti], sar[ti], geo_sad[ti],
                    ref_preds_loc[ti], alt_preds_loc[ti], sad_loc[ti], sar_loc[ti],
                    ref_preds[max_li[ti], ti], alt_preds[max_li[ti], ti], sad_vec[max_li[ti],ti], sar_vec[max_li[ti],ti],
                    ti, target_ids[ti], target_labels[ti])
            if options.csv:
              print(','.join([str(c) for c in cols]), file=sad_out)
            else:
              print(
                  '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s'
                  % cols,
                  file=sad_out)

          # print tracks
          for ti in options.track_indexes:
            alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % (options.out_dir,
                                                       snp.rsid, ti)
            bigwig_write(snp, options.seq_len, alt_preds[:, ti], model,
                         alt_bw_file, options.genome_file)

      ###################################################
      # construct next batch

      batch_1hot, batch_snps, snp_i = snps_next_batch(
          snps, snp_i, options.batch_size, options.seq_len, genome_open)

  sad_out.close()
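
The bigwig_write helper behind the --ti track output is likewise omitted. Below is a plausible sketch using pyBigWig; the coordinate arithmetic and the model/SNP attribute names (model.target_pool, snp.chr, snp.pos) are assumptions rather than the original code:

import pyBigWig

def bigwig_write(snp, seq_len, preds, model, bw_file, genome_file):
  """Write one target's binned predictions around a SNP to a BigWig
  track (hypothetical sketch)."""
  # chromosome lengths for the BigWig header
  chrom_sizes = []
  for line in open(genome_file):
    a = line.split()
    chrom_sizes.append((a[0], int(a[1])))

  bw = pyBigWig.open(bw_file, 'w')
  bw.addHeader(chrom_sizes)

  # assume the predictions tile the cropped center of the input sequence
  pool = model.target_pool
  buf = (seq_len - len(preds) * pool) // 2
  start = max(0, snp.pos - seq_len // 2 + buf)

  bw.addEntries(snp.chr, start,
                values=[float(p) for p in preds],
                span=pool, step=pool)
  bw.close()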
Example #4
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-b',
                      dest='batch_size',
                      default=256,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option('--h5',
                      dest='out_h5',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.h5 [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD,xSAR',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option('-z',
                      dest='out_zarr',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # setup model

    job = params.read_job_params(
        params_file, require=['seq_length', 'num_targets', 'target_pool'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(job,
                     ensemble_rc=options.rc,
                     ensemble_shifts=options.shifts,
                     embed_penultimate=options.penultimate,
                     target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype='int')
        snps = snps[worker_bounds[worker_index]:
                    worker_bounds[worker_index + 1]]

    num_snps = len(snps)

    #################################################################
    # setup output

    header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad', 'sar',
                   'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                   'ref_xpred', 'alt_xpred', 'xsad', 'xsar', 'target_index',
                   'target_id', 'target_label')

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sad_out)
        else:
            sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job['seq_length'], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # initialize batcher; use a distinct name so the basenji
            # batcher module is not shadowed on later loop iterations
            snp_batcher = batcher.Batcher(batch_1hot,
                                          batch_size=model.hp.batch_size)

            # predict
            # batch_preds = model.predict(sess, batcher,
            #                 rc=options.rc, shifts=options.shifts,
            #                 penultimate=options.penultimate)
            batch_preds = model.predict_h5(sess, snp_batcher)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype='float64')

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir,
                                                               snp.rsid, ti)
                    bigwig_write(snp, job['seq_length'], ref_preds[:, ti],
                                 model, ref_bw_file, options.genome_file)

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype='float64')

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum + options.log_pseudo) \
                            - np.log2(ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \
                                - np.log2(ref_preds.astype('float64') + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \
                                - np.log2(ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        sad_out['SAD'][szi, :] = sad.astype('float16')
                        sad_out['xSAR'][szi, :] = np.array(
                            [sar_vec[max_li[ti], ti]
                             for ti in range(num_targets)],
                            dtype='float16')
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(alt_al), ref_preds_sum[ti],
                                    alt_preds_sum[ti], sad[ti], sar[ti],
                                    geo_sad[ti], ref_preds_loc[ti],
                                    alt_preds_loc[ti], sad_loc[ti],
                                    sar_loc[ti], ref_preds[max_li[ti], ti],
                                    alt_preds[max_li[ti], ti],
                                    sad_vec[max_li[ti], ti],
                                    sar_vec[max_li[ti], ti], ti,
                                    target_ids[ti], target_labels[ti])
                            if options.csv:
                                print(','.join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s'
                                    % cols,
                                    file=sad_out)

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % (
                            options.out_dir, snp.rsid, ti)
                        bigwig_write(snp, job['seq_length'], alt_preds[:, ti],
                                     model, alt_bw_file, options.genome_file)

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job['seq_length'],
                genome_open)

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset('percentiles', data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = '%s_pct' % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype('float16')

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    if not options.out_zarr:
        sad_out.close()
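
Examples #2 and #4 also route HDF5 output through initialize_output_h5 (with a zarr twin), neither shown in the listing. Here is a minimal sketch consistent with how the main loop indexes sad_out; the dataset names, metadata fields, and dtypes beyond the per-stat float16 matrices are assumptions:

import h5py

def initialize_output_h5(out_dir, sad_stats, snps, target_ids, target_labels):
    """Create sad.h5 with SNP and target metadata plus one
    (num_snps x num_targets) float16 matrix per stat (hypothetical sketch)."""
    sad_out = h5py.File('%s/sad.h5' % out_dir, 'w')

    # SNP identifiers
    sad_out.create_dataset('snp',
                           data=np.array([snp.rsid for snp in snps],
                                         dtype='S'))

    # target metadata
    sad_out.create_dataset('target_ids', data=np.array(target_ids, dtype='S'))
    sad_out.create_dataset('target_labels',
                           data=np.array(target_labels, dtype='S'))

    # per-stat matrices, filled in during prediction
    for sad_stat in sad_stats:
        sad_out.create_dataset(sad_stat,
                               shape=(len(snps), len(target_ids)),
                               dtype='float16')

    return sad_out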
Example #5
def main():
    usage = (
        "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
        " <vcf_file>"
    )
    parser = OptionParser(usage)
    parser.add_option(
        "-a",
        dest="all_sed",
        default=False,
        action="store_true",
        help="Print all variant-gene pairs, as opposed to only nonzero [Default: %default]",
    )
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sed",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=0.125,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-x",
        dest="tss_table",
        default=False,
        action="store_true",
        help="Print TSS table in addition to gene [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files, genes HDF5 file, and QTL VCF"
            " file"
        )

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # read in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print("Intersecting gene sequences with SNPs...", end="")
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs, vision_p=0.5)
    print("%d sequences w/ SNPs" % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job and gene_data.num_targets is not None:
        job["num_targets"] = gene_data.num_targets

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        sys.exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
            target_labels = [""] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split("\t")
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

        # if the file lists every target, subsetting is unnecessary
        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "gene",
        "tss_dist",
        "ref_pred",
        "alt_pred",
        "sed",
        "ser",
        "target_index",
        "target_id",
        "target_label",
    )
    if options.csv:
        sed_gene_out = open("%s/sed_gene.csv" % options.out_dir, "w")
        print(",".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open("%s/sed_gene.txt" % options.out_dir, "w")
        print(" ".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(
                    gene_seq, gene_data.seqs_1hot[seq_i], seq_snps
                )
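                # (row 0 of aseqs_1hot is the reference sequence; rows 1..N
                #  carry each SNP's alternate allele, matching the ref/alt
                #  indexing of allele_tss_preds below)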

                # initialize batcher
                batcher_gene = batcher.Batcher(
                    aseqs_1hot, batch_size=model.hp.batch_size
                )

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    tss_radius=options.tss_radius,
                )

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1)
                )

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                snp_tss_sed = alt_tss_preds - ref_tss_preds
                snp_tss_ser = np.log2(alt_tss_preds + options.log_pseudo) - np.log2(
                    ref_tss_preds + options.log_pseudo
                )
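                # (SED is the raw difference in predicted TSS activity;
                #  SER is the log2 fold-change, with the pseudocount guarding
                #  against unstable ratios near zero)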

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius
                )
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(
                        alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius
                    )
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) - np.log2(
                    ref_gene_preds + options.log_pseudo
                )

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
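                    # snp_dist_gene: gene_id -> minimum SNP-to-TSS distance
                    # across the gene's TSSs in this sequence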
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist
                            )
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                    snp_tss_sed[snp_i, tss_i, ti], 0, atol=1e-4
                                ):

                                    # print
                                    cols = (
                                        snp.rsid,
                                        bvcf.cap_allele(snp.ref_allele),
                                        bvcf.cap_allele(snp.alt_alleles[0]),
                                        tss.identifier,
                                        snp_dist,
                                        ref_tss_preds[tss_i, ti],
                                        alt_tss_preds[snp_i, tss_i, ti],
                                        snp_tss_sed[snp_i, tss_i, ti],
                                        snp_tss_ser[snp_i, tss_i, ti],
                                        ti,
                                        target_ids[ti],
                                        target_labels[ti],
                                    )
                                    if options.csv:
                                        print(
                                            ",".join([str(c) for c in cols]),
                                            file=sed_tss_out,
                                        )
                                    else:
                                        print(
                                            "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                            % cols,
                                            file=sed_tss_out,
                                        )

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = "%s_multi" % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                gene_sed[snp_i, gi, ti], 0, atol=1e-4
                            ):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str,
                                    snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti],
                                    ti,
                                    target_ids[ti],
                                    target_labels[ti],
                                ]
                                if options.csv:
                                    print(
                                        ",".join([str(c) for c in cols]),
                                        file=sed_gene_out,
                                    )
                                else:
                                    print(
                                        "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                        % tuple(cols),
                                        file=sed_gene_out,
                                    )

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
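Downstream analysis usually starts from the gene-level table this writes. A minimal loading sketch, assuming the default text output at "sed/sed_gene.txt" and target labels without internal whitespace (the pandas usage is illustrative, not part of the script):

import pandas as pd

# parse the whitespace-delimited gene table written by the script
sed_gene = pd.read_csv("sed/sed_gene.txt", sep=r"\s+")

# rank variant-gene pairs by absolute SED to surface the largest effects
sed_gene["abs_sed"] = sed_gene["sed"].abs()
print(sed_gene.sort_values("abs_sed", ascending=False).head())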