Example No. 1
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option('-g', dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option('-l', dest='gene_list',
      help='Process only gene ids in the given file')
  parser.add_option('--mc', dest='mc_n',
      default=0, type='int',
      help='Monte carlo test iterations [Default: %default]')
  parser.add_option('-n', dest='norm',
      default=None, type='int',
      help='Compute saliency norm [Default: %default]')
  parser.add_option('-o', dest='out_dir',
      default='grad_map',
      help='Output directory [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average the forward and reverse complement predictions when testing [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('-t', dest='target_indexes',
      default=None,
      help='Target indexes to plot')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters, model, and genes HDF5 file')
  else:
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # reads in genes HDF5

  gene_data = genedata.GeneData(genes_hdf5_file)

  # subset gene sequences
  genes_subset = set()
  if options.gene_list:
    for line in open(options.gene_list):
      genes_subset.add(line.rstrip())

    gene_data.subset_genes(genes_subset)
    print('Filtered to %d sequences' % gene_data.num_seqs)

  # extract sequence chrom and start
  seqs_chrom = [gene_data.gene_seqs[si].chrom for si in range(gene_data.num_seqs)]
  seqs_start = [gene_data.gene_seqs[si].start for si in range(gene_data.num_seqs)]


  #######################################################
  # model parameters and placeholders

  job = params.read_job_params(params_file)

  job['seq_length'] = gene_data.seq_length
  job['seq_depth'] = gene_data.seq_depth
  job['target_pool'] = gene_data.pool_width

  if 'num_targets' not in job:
    print(
        "Must specify number of targets (num_targets) in the parameters file.",
        file=sys.stderr)
    exit(1)

  # set target indexes
  if options.target_indexes is not None:
    options.target_indexes = [int(ti) for ti in options.target_indexes.split(',')]
    target_subset = options.target_indexes
  else:
    options.target_indexes = list(range(job['num_targets']))
    target_subset = None

  # build model
  model = seqnn.SeqNN()
  model.build_feed(job, target_subset=target_subset)

  # determine latest pre-dilated layer
  cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
  dilated_mask = cnn_dilation > 1
  dilated_indexes = np.where(dilated_mask)[0]
  pre_dilated_layer = np.min(dilated_indexes)
  print('Pre-dilated layer: %d' % pre_dilated_layer)

  # build gradients ops
  t0 = time.time()
  print('Building target/position-specific gradient ops.', end='')
  model.build_grads(layers=[pre_dilated_layer])
  print(' Done in %ds' % (time.time()-t0), flush=True)


  #######################################################
  # acquire gradients

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # score sequences and write bigwigs
    score_write(sess, model, options, gene_data.seqs_1hot, seqs_chrom, seqs_start)
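
The script above turns the comma-separated --shifts option into a list of integer shifts before building the prediction ensemble. A minimal, self-contained sketch of that parsing step follows; the helper name parse_shifts is hypothetical and not part of Basenji.

def parse_shifts(shifts_str):
  """Convert a comma-separated shift string such as '0,-1,1' into ints."""
  try:
    return [int(s) for s in shifts_str.split(',')]
  except ValueError as e:
    raise ValueError('Invalid --shifts value %r: %s' % (shifts_str, e))

assert parse_shifts('0') == [0]
assert parse_shifts('0,-1,1') == [0, -1, 1]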
Example No. 2
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=None,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-i',
      dest='ignore_bed',
      help='Ignore genes overlapping regions in this BED file')
  parser.add_option(
      '-l', dest='load_preds', help='Load tss_preds from file')
  parser.add_option(
      '--heat',
      dest='plot_heat',
      default=False,
      action='store_true',
      help='Plot big gene-target heatmaps [Default: %default]')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='genes_out',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-r',
      dest='tss_radius',
      default=0,
      type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '-s',
      dest='plot_scatter',
      default=False,
      action='store_true',
      help='Make time-consuming accuracy scatter plots [Default: %default]')
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '--rep',
      dest='replicate_labels_file',
      help=
      'Compare replicate experiments, aided by the given file with long labels')
  parser.add_option(
      '-t',
      dest='target_indexes',
      default=None,
      help=
      'File or Comma-separated list of target indexes to scatter plot true versus predicted values'
  )
  parser.add_option(
      '--table',
      dest='print_tables',
      default=False,
      action='store_true',
      help='Print big gene/TSS tables [Default: %default]')
  parser.add_option(
      '--tss',
      dest='tss_alt',
      default=False,
      action='store_true',
      help='Perform alternative TSS analysis [Default: %default]')
  parser.add_option(
      '-v',
      dest='gene_variance',
      default=False,
      action='store_true',
      help=
      'Study accuracy with respect to gene variance across targets [Default: %default]'
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters and model files, and genes HDF5 file')
  else:
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes and targets

  gene_data = genedata.GeneData(genes_hdf5_file)


  #################################################################
  # TSS predictions

  if options.load_preds is not None:
    # load from file
    tss_preds = np.load(options.load_preds)

  else:

    #######################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width
    if 'num_targets' not in job:
      job['num_targets'] = gene_data.num_targets

    # build model
    model = seqnn.SeqNN()
    model.build(job)

    if options.batch_size is not None:
      model.hp.batch_size = options.batch_size


    #######################################################
    # predict TSSs

    t0 = time.time()
    print('Computing gene predictions.', end='')
    sys.stdout.flush()

    # initialize batcher
    gene_batcher = batcher.Batcher(
        gene_data.seqs_1hot, batch_size=model.hp.batch_size)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
      # load variables into session
      saver.restore(sess, model_file)

      # predict
      tss_preds = model.predict_genes(sess, gene_batcher, gene_data.gene_seqs,
          rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius)

    # save to file
    np.save('%s/preds' % options.out_dir, tss_preds)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # convert to genes

  gene_targets, _ = gene.map_tss_genes(gene_data.tss_targets, gene_data.tss,
                                       tss_radius=options.tss_radius)
  gene_preds, _ = gene.map_tss_genes(tss_preds, gene_data.tss,
                                     tss_radius=options.tss_radius)


  #################################################################
  # determine targets

  # all targets
  if options.target_indexes is None:
    if gene_data.num_targets is None:
      print('No targets to test against')
      exit()
    else:
      options.target_indexes = np.arange(gene_data.num_targets)

  # file targets
  elif os.path.isfile(options.target_indexes):
    target_indexes_file = options.target_indexes
    options.target_indexes = []
    for line in open(target_indexes_file):
      options.target_indexes.append(int(line.split()[0]))

  # comma-separated targets
  else:
    options.target_indexes = [
        int(ti) for ti in options.target_indexes.split(',')
    ]

  options.target_indexes = np.array(options.target_indexes)

  #################################################################
  # correlation statistics

  t0 = time.time()
  print('Computing correlations.', end='')
  sys.stdout.flush()

  cor_table(gene_data.tss_targets, tss_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/tss_cors.txt' % options.out_dir)

  cor_table(gene_targets, gene_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/gene_cors.txt' % options.out_dir, plots=True)

  print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene statistics

  if options.print_tables:
    t0 = time.time()
    print('Printing predictions.', end='')
    sys.stdout.flush()

    gene_table(gene_data.tss_targets, tss_preds, gene_data.tss_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/transcript' % options.out_dir, options.plot_scatter)

    gene_table(gene_targets, gene_preds,
               gene_data.gene_ids(), gene_data.target_labels,
               options.target_indexes, '%s/gene' % options.out_dir,
               options.plot_scatter)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene x target heatmaps

  if options.plot_heat or options.gene_variance:
    #########################################
    # normalize predictions across targets

    t0 = time.time()
    print('Normalizing values across targets.', end='')
    sys.stdout.flush()

    gene_targets_qn = normalize_targets(gene_targets[:, options.target_indexes], log_pseudo=1)
    gene_preds_qn = normalize_targets(gene_preds[:, options.target_indexes], log_pseudo=1)

    print(' Done in %ds.' % (time.time() - t0))

  if options.plot_heat:
    #########################################
    # plot genes by targets clustermap

    t0 = time.time()
    print('Plotting heat maps.', end='')
    sys.stdout.flush()

    sns.set(font_scale=1.3, style='ticks')
    plot_genes = 1600
    plot_targets = 800

    # choose a set of variable genes
    gene_vars = gene_preds_qn.var(axis=1)
    indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

    # choose a set of random genes
    if plot_genes < gene_preds_qn.shape[0]:
      indexes_rand = np.random.choice(
        np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False)
    else:
      indexes_rand = np.arange(gene_preds_qn.shape[0])

    # choose a set of random targets
    if plot_targets < 0.8 * gene_preds_qn.shape[1]:
      indexes_targets = np.random.choice(
          np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False)
    else:
      indexes_targets = np.arange(gene_preds_qn.shape[1])

    # variable gene predictions
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var.pdf' % options.out_dir)
    clustermap(
        gene_preds_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_heat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene predictions
    clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
               '%s/gene_heat_rand.pdf' % options.out_dir)

    # variable gene targets
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var.pdf' % options.out_dir)
    clustermap(
        gene_targets_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_theat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene targets (crashes)
    # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
    #            '%s/gene_theat_rand.pdf' % options.out_dir)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # analyze replicates

  if options.replicate_labels_file is not None:
    # read long form labels, from which to infer replicates
    target_labels_long = []
    for line in open(options.replicate_labels_file):
      a = line.split('\t')
      a[-1] = a[-1].rstrip()
      target_labels_long.append(a[-1])

    # determine replicates
    replicate_lists = infer_replicates(target_labels_long)

    # compute correlations
    # replicate_correlations(replicate_lists, gene_data.tss_targets,
        # tss_preds, options.target_indexes, '%s/transcript_reps' % options.out_dir)
    replicate_correlations(
        replicate_lists, gene_targets, gene_preds, options.target_indexes,
        '%s/gene_reps' % options.out_dir)  # , scatter_plots=True)

  #################################################################
  # gene variance

  if options.gene_variance:
    variance_accuracy(gene_targets_qn, gene_preds_qn,
                      '%s/gene' % options.out_dir)

  #################################################################
  # alternative TSS

  if options.tss_alt:
    alternative_tss(gene_data.tss_targets[:,options.target_indexes],
                    tss_preds[:,options.target_indexes], gene_data,
                    options.out_dir, log_pseudo=1)
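
cor_table above reports per-target agreement between measured and predicted expression. A minimal sketch of one such statistic, assuming genes x targets NumPy arrays; per_target_correlation is a hypothetical helper, and the exact transform cor_table applies may differ.

import numpy as np
from scipy.stats import pearsonr

def per_target_correlation(targets, preds, log_pseudo=1):
  # Pearson R between log-transformed targets and predictions, per target column.
  cors = []
  for ti in range(targets.shape[1]):
    r, _ = pearsonr(np.log2(targets[:, ti] + log_pseudo),
                    np.log2(preds[:, ti] + log_pseudo))
    cors.append(r)
  return np.array(cors)

targets = np.random.rand(100, 3)
preds = targets + 0.1 * np.random.rand(100, 3)
print(per_target_correlation(targets, preds))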
Example No. 3
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option("-l",
                      dest="gene_list",
                      help="Process only gene ids in the given file")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="grad_mapg",
        help="Output directory [Default: %default]",
    )
    parser.add_option("-t",
                      dest="target_indexes",
                      default=None,
                      help="Target indexes to plot")
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and genomic position")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # subset gene sequences
    genes_subset = set()
    if options.gene_list:
        for line in open(options.gene_list):
            genes_subset.add(line.rstrip())

        gene_data.subset_genes(genes_subset)
        print("Filtered to %d sequences" % gene_data.num_seqs)

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        exit(1)

    # set target indexes
    if options.target_indexes is not None:
        options.target_indexes = [
            int(ti) for ti in options.target_indexes.split(",")
        ]
        target_subset = options.target_indexes
    else:
        options.target_indexes = list(range(job["num_targets"]))
        target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    # determine latest pre-dilated layer
    cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
    dilated_mask = cnn_dilation > 1
    dilated_indexes = np.where(dilated_mask)[0]
    pre_dilated_layer = np.min(dilated_indexes)
    print("Pre-dilated layer: %d" % pre_dilated_layer)

    # build gradients ops
    t0 = time.time()
    print("Building target/position-specific gradient ops.", end="")
    model.build_grads_genes(gene_data.gene_seqs, layers=[pre_dilated_layer])
    print(" Done in %ds" % (time.time() - t0), flush=True)

    #######################################################
    # acquire gradients

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(gene_data.num_seqs):
            # initialize batcher
            batcher_si = batcher.Batcher(
                gene_data.seqs_1hot[si:si + 1],
                batch_size=model.hp.batch_size,
                pool_width=model.hp.target_pool,
            )

            # get layer representations
            t0 = time.time()
            print("Computing gradients.", end="", flush=True)
            batch_grads, batch_reprs = model.gradients_genes(
                sess, batcher_si, gene_data.gene_seqs[si:si + 1])
            print(" Done in %ds." % (time.time() - t0), flush=True)

            # only layer
            batch_reprs = batch_reprs[0]
            batch_grads = batch_grads[0]

            # G (TSSs) x T (targets) x P (seq position) x U (Units layer i)
            print("batch_grads", batch_grads.shape)
            pooled_length = batch_grads.shape[2]

            # S (sequences) x P (seq position) x U (Units layer i)
            print("batch_reprs", batch_reprs.shape)

            # write bigwigs
            t0 = time.time()
            print("Writing BigWigs.", end="", flush=True)

            # for each TSS
            for tss_i in range(batch_grads.shape[0]):
                tss = gene_data.gene_seqs[si].tss_list[tss_i]

                # for each target
                for tii in range(len(options.target_indexes)):
                    ti = options.target_indexes[tii]

                    # dot representation and gradient
                    batch_grads_score = np.multiply(
                        batch_reprs[0], batch_grads[tss_i,
                                                    tii, :, :]).sum(axis=1)

                    # open bigwig
                    bw_file = "%s/%s-%s_t%d.bw" % (
                        options.out_dir,
                        tss.gene_id,
                        tss.identifier,
                        ti,
                    )
                    bw_open = bigwig_open(bw_file, options.genome_file)

                    # access gene sequence information
                    seq_chrom = gene_data.gene_seqs[si].chrom
                    seq_start = gene_data.gene_seqs[si].start

                    # specify bigwig locations and values
                    bw_chroms = [seq_chrom] * pooled_length
                    bw_starts = [
                        int(seq_start + li * model.hp.target_pool)
                        for li in range(pooled_length)
                    ]
                    bw_ends = [
                        int(bws + model.hp.target_pool) for bws in bw_starts
                    ]
                    bw_values = [float(bgs) for bgs in batch_grads_score]

                    # write
                    bw_open.addEntries(bw_chroms,
                                       bw_starts,
                                       ends=bw_ends,
                                       values=bw_values)

                    # close
                    bw_open.close()

            print(" Done in %ds." % (time.time() - t0), flush=True)
            gc.collect()
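
The inner loop above scores each pooled position by multiplying the layer representation with the TSS- and target-specific gradient and summing over units, i.e. a gradient-times-input saliency. A minimal NumPy sketch of that step, assuming P x U arrays for one TSS/target pair; saliency_scores is a hypothetical helper.

import numpy as np

def saliency_scores(reprs, grads):
    """Per-position score: elementwise product of representation and gradient, summed over units."""
    return np.multiply(reprs, grads).sum(axis=1)

reprs = np.random.rand(8, 4)  # P positions x U units
grads = np.random.rand(8, 4)
print(saliency_scores(reprs, grads).shape)  # (8,)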
Example No. 4
def main():
    usage = (
        'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
        ' <vcf_file>')
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='all_sed',
        default=False,
        action='store_true',
        help=
        'Print all variant-gene pairs, as opposed to only nonzero [Default: %default]'
    )
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sed',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=0.125,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '-r',
        dest='tss_radius',
        default=0,
        type='int',
        help=
        'Radius of bins considered to quantify TSS transcription [Default: %default]'
    )
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average the forward and reverse complement predictions when testing [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option(
        '-x',
        dest='tss_table',
        default=False,
        action='store_true',
        help='Print TSS table in addition to gene [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files, genes HDF5 file, and QTL VCF'
            ' file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print('Intersecting gene sequences with SNPs...', end='')
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file,
                                         gene_data.gene_seqs,
                                         vision_p=0.5)
    print('%d sequences w/ SNPs' % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width

    if 'num_targets' not in job and gene_data.num_targets is not None:
        job['num_targets'] = gene_data.num_targets

    if 'num_targets' not in job:
        print("Must specify number of targets (num_targets) in the \
            parameters file.",
              file=sys.stderr)
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
            target_labels = [''] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split('\t')
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

            if len(target_subset) == job['num_targets']:
                target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                   'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                   'target_label')
    if options.csv:
        sed_gene_out = open('%s/sed_gene.csv' % options.out_dir, 'w')
        print(','.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open('%s/sed_gene.txt' % options.out_dir, 'w')
        print(' '.join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open('%s/sed_tss.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i],
                                          seq_snps)

                # initialize batcher
                batcher_gene = batcher.Batcher(aseqs_1hot,
                                               batch_size=model.hp.batch_size)

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    penultimate=options.penultimate,
                    tss_radius=options.tss_radius)

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1))

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                tss_sed = alt_tss_preds - ref_tss_preds
                tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
                            - np.log2(ref_tss_preds + options.log_pseudo)

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius)
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i],
                                                gene_seq.tss_list,
                                                options.tss_radius)
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
                            - np.log2(ref_gene_preds + options.log_pseudo)

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist)
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                        tss_sed[snp_i, tss_i,
                                                ti], 0, atol=1e-4):

                                    # print
                                    cols = (snp.rsid,
                                            bvcf.cap_allele(snp.ref_allele),
                                            bvcf.cap_allele(
                                                snp.alt_alleles[0]),
                                            tss.identifier, snp_dist,
                                            ref_tss_preds[tss_i, ti],
                                            alt_tss_preds[snp_i, tss_i, ti],
                                            tss_sed[snp_i, tss_i,
                                                    ti], tss_ser[snp_i, tss_i,
                                                                 ti], ti,
                                            target_ids[ti], target_labels[ti])
                                    if options.csv:
                                        print(','.join([str(c) for c in cols]),
                                              file=sed_tss_out)
                                    else:
                                        print(
                                            '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                            % cols,
                                            file=sed_tss_out)

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = '%s_multi' % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                    gene_sed[snp_i, gi, ti], 0, atol=1e-4):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str, snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi,
                                                   ti], alt_gene_preds[snp_i,
                                                                       gi, ti],
                                    gene_sed[snp_i, gi,
                                             ti], gene_ser[snp_i, gi, ti], ti,
                                    target_ids[ti], target_labels[ti]
                                ]
                                if options.csv:
                                    print(','.join([str(c) for c in cols]),
                                          file=sed_gene_out)
                                else:
                                    print(
                                        '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                                        % tuple(cols),
                                        file=sed_gene_out)

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
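
Per SNP and target, the SED and SER scores written above reduce to a difference and a log2 ratio of allele predictions. A minimal sketch, assuming per-target NumPy arrays; sed_ser is a hypothetical helper.

import numpy as np

def sed_ser(ref_pred, alt_pred, log_pseudo=0.125):
    '''SNP expression difference (SED) and log2 expression ratio (SER) per target.'''
    sed = alt_pred - ref_pred
    ser = np.log2(alt_pred + log_pseudo) - np.log2(ref_pred + log_pseudo)
    return sed, ser

ref_pred = np.array([1.0, 2.0, 0.5])
alt_pred = np.array([1.5, 2.0, 0.25])
print(sed_ser(ref_pred, alt_pred))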
Example No. 5
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg38.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-l", dest="gene_list", help="Process only gene ids in the given file"
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "-n",
        dest="norm",
        default=None,
        type="int",
        help="Compute saliency norm [Default% default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="grad_map",
        help="Output directory [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t", dest="target_indexes", default=None, help="Target indexes to plot"
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters, model, and genomic position")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # subset gene sequences
    genes_subset = set()
    if options.gene_list:
        for line in open(options.gene_list):
            genes_subset.add(line.rstrip())

        gene_data.subset_genes(genes_subset)
        print("Filtered to %d sequences" % gene_data.num_seqs)

    # extract sequence chrom and start
    seqs_chrom = [gene_data.gene_seqs[si].chrom for si in range(gene_data.num_seqs)]
    seqs_start = [gene_data.gene_seqs[si].start for si in range(gene_data.num_seqs)]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        exit(1)

    # set target indexes
    if options.target_indexes is not None:
        options.target_indexes = [int(ti) for ti in options.target_indexes.split(",")]
        target_subset = options.target_indexes
    else:
        options.target_indexes = list(range(job["num_targets"]))
        target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    # determine latest pre-dilated layer
    cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
    dilated_mask = cnn_dilation > 1
    dilated_indexes = np.where(dilated_mask)[0]
    pre_dilated_layer = np.min(dilated_indexes)
    print("Pre-dilated layer: %d" % pre_dilated_layer)

    # build gradients ops
    t0 = time.time()
    print("Building target/position-specific gradient ops.", end="")
    model.build_grads(layers=[pre_dilated_layer])
    print(" Done in %ds" % (time.time() - t0), flush=True)

    #######################################################
    # acquire gradients

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # score sequences and write bigwigs
        score_write(sess, model, options, gene_data.seqs_1hot, seqs_chrom, seqs_start)
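
The layer selection above takes the smallest index among convolutional layers whose dilation factor exceeds 1 and attaches the gradient ops there. A minimal NumPy sketch with assumed dilation factors for illustration.

import numpy as np

cnn_dilation = np.array([1, 1, 1, 2, 4, 8])  # assumed per-layer dilation factors
dilated_indexes = np.where(cnn_dilation > 1)[0]
pre_dilated_layer = int(np.min(dilated_indexes))
print(pre_dilated_layer)  # 3: smallest index with dilation > 1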
Example No. 6
def main():
    usage = (
        "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
        " <vcf_file>"
    )
    parser = OptionParser(usage)
    parser.add_option(
        "-a",
        dest="all_sed",
        default=False,
        action="store_true",
        help="Print all variant-gene pairs, as opposed to only nonzero [Default: %default]",
    )
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sed",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=0.125,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-x",
        dest="tss_table",
        default=False,
        action="store_true",
        help="Print TSS table in addition to gene [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files, genes HDF5 file, and QTL VCF"
            " file"
        )

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print("Intersecting gene sequences with SNPs...", end="")
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs, vision_p=0.5)
    print("%d sequences w/ SNPs" % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job and gene_data.num_targets is not None:
        job["num_targets"] = gene_data.num_targets

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the \
            parameters file.",
            file=sys.stderr,
        )
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
            target_labels = [""] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split("\t")
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

            if len(target_subset) == job["num_targets"]:
                target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "gene",
        "tss_dist",
        "ref_pred",
        "alt_pred",
        "sed",
        "ser",
        "target_index",
        "target_id",
        "target_label",
    )
    if options.csv:
        sed_gene_out = open("%s/sed_gene.csv" % options.out_dir, "w")
        print(",".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open("%s/sed_gene.txt" % options.out_dir, "w")
        print(" ".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(
                    gene_seq, gene_data.seqs_1hot[seq_i], seq_snps
                )

                # initialize batcher
                batcher_gene = batcher.Batcher(
                    aseqs_1hot, batch_size=model.hp.batch_size
                )

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    tss_radius=options.tss_radius,
                )

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1)
                )

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                tss_sed = alt_tss_preds - ref_tss_preds
                tss_ser = np.log2(alt_tss_preds + options.log_pseudo) - np.log2(
                    ref_tss_preds + options.log_pseudo
                )

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius
                )
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(
                        alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius
                    )
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) - np.log2(
                    ref_gene_preds + options.log_pseudo
                )

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist
                            )
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                    tss_sed[snp_i, tss_i, ti], 0, atol=1e-4
                                ):

                                    # print
                                    cols = (
                                        snp.rsid,
                                        bvcf.cap_allele(snp.ref_allele),
                                        bvcf.cap_allele(snp.alt_alleles[0]),
                                        tss.identifier,
                                        snp_dist,
                                        ref_tss_preds[tss_i, ti],
                                        alt_tss_preds[snp_i, tss_i, ti],
                                        tss_sed[snp_i, tss_i, ti],
                                        tss_ser[snp_i, tss_i, ti],
                                        ti,
                                        target_ids[ti],
                                        target_labels[ti],
                                    )
                                    if options.csv:
                                        print(
                                            ",".join([str(c) for c in cols]),
                                            file=sed_tss_out,
                                        )
                                    else:
                                        print(
                                            "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                            % cols,
                                            file=sed_tss_out,
                                        )

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = "%s_multi" % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                gene_sed[snp_i, gi, ti], 0, atol=1e-4
                            ):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str,
                                    snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti],
                                    ti,
                                    target_ids[ti],
                                    target_labels[ti],
                                ]
                                if options.csv:
                                    print(
                                        ",".join([str(c) for c in cols]),
                                        file=sed_gene_out,
                                    )
                                else:
                                    print(
                                        "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                        % tuple(cols),
                                        file=sed_gene_out,
                                    )

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
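
alleles_1hot above builds one-hot sequences for the reference and each alternate allele before batching them for prediction. A minimal sketch of a single-nucleotide substitution on a one-hot array, assuming ACGT column order; apply_snp_1hot is a hypothetical helper, not the Basenji implementation.

import numpy as np

NT_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}  # assumed column order

def apply_snp_1hot(seq_1hot, pos, alt_nt):
    """Return a copy of a one-hot sequence with a single-base substitution at pos."""
    alt_1hot = seq_1hot.copy()
    alt_1hot[pos, :] = 0
    alt_1hot[pos, NT_INDEX[alt_nt]] = 1
    return alt_1hot

seq_1hot = np.zeros((5, 4))
seq_1hot[:, 0] = 1  # all A's
print(apply_snp_1hot(seq_1hot, 2, "G")[2])  # [0. 0. 1. 0.]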
Example No. 7
def main():
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-i",
        dest="ignore_bed",
        help="Ignore genes overlapping regions in this BED file",
    )
    parser.add_option("-l", dest="load_preds", help="Load tess_preds from file")
    parser.add_option(
        "--heat",
        dest="plot_heat",
        default=False,
        action="store_true",
        help="Plot big gene-target heatmaps [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="genes_out",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "-s",
        dest="plot_scatter",
        default=False,
        action="store_true",
        help="Make time-consuming accuracy scatter plots [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--rep",
        dest="replicate_labels_file",
        help="Compare replicate experiments, aided by the given file with long labels",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--table",
        dest="print_tables",
        default=False,
        action="store_true",
        help="Print big gene/TSS tables [Default: %default]",
    )
    parser.add_option(
        "--tss",
        dest="tss_alt",
        default=False,
        action="store_true",
        help="Perform alternative TSS analysis [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="gene_variance",
        default=False,
        action="store_true",
        help="Study accuracy with respect to gene variance across targets [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error("Must provide parameters and model files, and genes HDF5 file")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # read in genes and targets

    gene_data = genedata.GeneData(genes_hdf5_file)

    #################################################################
    # TSS predictions

    if options.load_preds is not None:
        # load from file
        tss_preds = np.load(options.load_preds)

    else:

        #######################################################
        # setup model

        job = params.read_job_params(params_file)

        job["seq_length"] = gene_data.seq_length
        job["seq_depth"] = gene_data.seq_depth
        job["target_pool"] = gene_data.pool_width
        if not "num_targets" in job:
            job["num_targets"] = gene_data.num_targets

        # build model
        model = seqnn.SeqNN()
        model.build_feed_sad(job)

        if options.batch_size is not None:
            model.hp.batch_size = options.batch_size

        #######################################################
        # predict TSSs

        t0 = time.time()
        print("Computing gene predictions.", end="")
        sys.stdout.flush()

        # initialize batcher
        gene_batcher = batcher.Batcher(
            gene_data.seqs_1hot, batch_size=model.hp.batch_size
        )

        # initialize saver
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # load variables into session
            saver.restore(sess, model_file)

            # predict
            tss_preds = model.predict_genes(
                sess,
                gene_batcher,
                gene_data.gene_seqs,
                rc=options.rc,
                shifts=options.shifts,
                tss_radius=options.tss_radius,
            )

        # save to file
        np.save("%s/preds" % options.out_dir, tss_preds)
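        # note: np.save appends a ".npy" suffix, so the file written here is
        # "<out_dir>/preds.npy"; pass that path to -l on later runs to skip
        # recomputing the TSS predictions.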

        print(" Done in %ds." % (time.time() - t0))

    # up-convert
    tss_preds = tss_preds.astype("float32")
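    # tss_preds is presumably a (num_TSS x num_targets) matrix stored in reduced
    # precision (e.g. float16) to keep the saved file small, hence the
    # up-conversion to float32 before the correlation and normalization math
    # below (an assumption, not stated in this excerpt).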

    #################################################################
    # convert to genes

    gene_targets, _ = gene.map_tss_genes(
        gene_data.tss_targets, gene_data.tss, tss_radius=options.tss_radius
    )
    gene_preds, _ = gene.map_tss_genes(
        tss_preds, gene_data.tss, tss_radius=options.tss_radius
    )
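    # map_tss_genes presumably aggregates TSS-level rows into one row per gene.
    # A minimal sketch of that idea, assuming each TSS object carries a gene id
    # (illustrative only; not the actual library internals):
    #   gene_vals = {}
    #   for tss_i, tss in enumerate(gene_data.tss):
    #       gene_vals.setdefault(tss.gene_id, []).append(tss_preds[tss_i])
    #   gene_preds = np.array([np.sum(rows, axis=0) for rows in gene_vals.values()])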

    #################################################################
    # determine targets

    # read targets
    if options.targets_file is not None:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_indexes = targets_df.index
    else:
        if gene_data.num_targets is None:
            print("No targets to test against.")
            exit(1)
        else:
            target_indexes = np.arange(gene_data.num_targets)

    #################################################################
    # correlation statistics

    t0 = time.time()
    print("Computing correlations.", end="")
    sys.stdout.flush()

    cor_table(
        gene_data.tss_targets,
        tss_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/tss_cors.txt" % options.out_dir,
    )

    cor_table(
        gene_targets,
        gene_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/gene_cors.txt" % options.out_dir,
        draw_plots=True,
    )
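    # cor_table is presumably a helper defined earlier in this script that writes
    # one correlation per target column; roughly (sketch, assuming a log2
    # pseudo-count transform and scipy.stats, neither confirmed by this excerpt):
    #   for ti in target_indexes:
    #       r, _ = scipy.stats.pearsonr(np.log2(targets[:, ti] + 1),
    #                                   np.log2(preds[:, ti] + 1))
    #       print(ti, target_ids[ti], r, file=out)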

    print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene statistics

    if options.print_tables:
        t0 = time.time()
        print("Printing predictions.", end="")
        sys.stdout.flush()

        gene_table(
            gene_data.tss_targets,
            tss_preds,
            gene_data.tss_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/transcript" % options.out_dir,
            options.plot_scatter,
        )

        gene_table(
            gene_targets,
            gene_preds,
            gene_data.gene_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/gene" % options.out_dir,
            options.plot_scatter,
        )

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene x target heatmaps

    if options.plot_heat or options.gene_variance:
        #########################################
        # normalize predictions across targets

        t0 = time.time()
        print("Normalizing values across targets.", end="")
        sys.stdout.flush()

        gene_targets_qn = normalize_targets(
            gene_targets[:, target_indexes], log_pseudo=1
        )
        gene_preds_qn = normalize_targets(gene_preds[:, target_indexes], log_pseudo=1)
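        # normalize_targets presumably log-transforms with the given pseudo-count
        # and then quantile-normalizes the target columns onto a shared
        # distribution; a rough sketch of that idea (not the exact helper):
        #   vals = np.log2(gene_vals + log_pseudo)
        #   ranks = vals.argsort(axis=0).argsort(axis=0)
        #   mean_sorted = np.sort(vals, axis=0).mean(axis=1)
        #   vals_qn = mean_sorted[ranks]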

        print(" Done in %ds." % (time.time() - t0))

    if options.plot_heat:
        #########################################
        # plot genes by targets clustermap

        t0 = time.time()
        print("Plotting heat maps.", end="")
        sys.stdout.flush()

        sns.set(font_scale=1.3, style="ticks")
        plot_genes = 1600
        plot_targets = 800

        # choose a set of variable genes
        gene_vars = gene_preds_qn.var(axis=1)
        indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

        # choose a set of random genes
        if plot_genes < gene_preds_qn.shape[0]:
            indexes_rand = np.random.choice(
                np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False
            )
        else:
            indexes_rand = np.arange(gene_preds_qn.shape[0])

        # choose a set of random targets
        if plot_targets < 0.8 * gene_preds_qn.shape[1]:
            indexes_targets = np.random.choice(
                np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False
            )
        else:
            indexes_targets = np.arange(gene_preds_qn.shape[1])

        # variable gene predictions
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )

        # random gene predictions
        clustermap(
            gene_preds_qn[indexes_rand, :][:, indexes_targets],
            "%s/gene_heat_rand.pdf" % options.out_dir,
        )

        # variable gene targets
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )
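        # clustermap is presumably a thin wrapper around seaborn defined earlier
        # in this script; roughly (sketch, parameter names assumed):
        #   g = sns.clustermap(matrix, cmap=color, xticklabels=False, yticklabels=False)
        #   g.savefig(out_pdf)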

        # random gene targets (crashes)
        # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
        #            '%s/gene_theat_rand.pdf' % options.out_dir)

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # analyze replicates

    if options.replicate_labels_file is not None:
        # read long form labels, from which to infer replicates
        target_labels_long = []
        for line in open(options.replicate_labels_file):
            a = line.split("\t")
            a[-1] = a[-1].rstrip()
            target_labels_long.append(a[-1])

        # determine replicates
        replicate_lists = infer_replicates(target_labels_long)
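        # infer_replicates presumably groups target indexes whose long-form labels
        # match, treating each group as replicate experiments; roughly (sketch):
        #   groups = {}
        #   for ti, label in enumerate(target_labels_long):
        #       groups.setdefault(label, []).append(ti)
        #   replicate_lists = [g for g in groups.values() if len(g) > 1]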

        # compute correlations
        # replicate_correlations(replicate_lists, gene_data.tss_targets, tss_preds,
        #     target_indexes, '%s/transcript_reps' % options.out_dir)
        replicate_correlations(
            replicate_lists,
            gene_targets,
            gene_preds,
            target_indexes,
            "%s/gene_reps" % options.out_dir,
        )  # , scatter_plots=True)

    #################################################################
    # gene variance

    if options.gene_variance:
        variance_accuracy(gene_targets_qn, gene_preds_qn, "%s/gene" % options.out_dir)
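        # variance_accuracy presumably relates per-gene prediction accuracy to how
        # variable each gene is across targets, using the quantile-normalized
        # matrices computed above (which is why the normalization block is guarded
        # by options.plot_heat or options.gene_variance).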

    #################################################################
    # alternative TSS

    if options.tss_alt:
        alternative_tss(
            gene_data.tss_targets[:, target_indexes],
            tss_preds[:, target_indexes],
            gene_data,
            options.out_dir,
            log_pseudo=1,
        )