Example #1
0
def main():
    """Evaluate a trained model's transcript and gene predictions.

    Command-line entry point. Usage:
        %prog [options] <params_file> <model_file> <genes_hdf5_file>

    Predictions are either loaded with ``np.load`` (``-l``) or computed by
    building a ``basenji.seqnn.SeqNN`` from ``params_file``, restoring the
    ``model_file`` weights through ``tf.train.Saver`` and calling
    ``predict_genes``. Computed predictions are saved to
    ``<out_dir>/preds.npy``.

    Transcript values are mapped to genes with ``map_transcripts_genes``, and
    correlation tables are written to ``out_dir``. Depending on the flags, the
    function also writes prediction tables, clustered heatmaps, replicate
    correlations and a gene-variance analysis.

    Note: the ``-i`` (ignore_bed) option is parsed but not used by this
    function.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
    parser = OptionParser(usage)
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-i',
                      dest='ignore_bed',
                      help='Ignore genes overlapping regions in this BED file')
    parser.add_option('-l',
                      dest='load_preds',
                      help='Load transcript_preds from file')
    parser.add_option('--heat',
                      dest='plot_heat',
                      default=False,
                      action='store_true',
                      help='Plot big gene-target heatmaps [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='genes_out',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average the forward and reverse complement predictions when testing [Default: %default]'
    )
    parser.add_option(
        '-s',
        dest='plot_scatter',
        default=False,
        action='store_true',
        help='Make time-consuming accuracy scatter plots [Default: %default]')
    parser.add_option(
        '--rep',
        dest='replicate_labels_file',
        help=
        'Compare replicate experiments, aided by the given file with long labels'
    )
    parser.add_option(
        '-t',
        dest='target_indexes',
        default=None,
        help=
        'File or Comma-separated list of target indexes to scatter plot true versus predicted values'
    )
    parser.add_option(
        '--table',
        dest='print_tables',
        default=False,
        action='store_true',
        help='Print big gene/transcript tables [Default: %default]')
    parser.add_option(
        '-v',
        dest='gene_variance',
        default=False,
        action='store_true',
        help=
        'Study accuracy with respect to gene variance across targets [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error prints the message and exits, so nothing below runs
        parser.error(
            'Must provide parameters and model files, and genes HDF5 file')
    params_file, model_file, genes_hdf5_file = args

    # makedirs creates missing parents and does not race with an existing dir
    os.makedirs(options.out_dir, exist_ok=True)

    #################################################################
    # read in genes and targets

    gene_data = basenji.genes.GeneData(genes_hdf5_file)

    if options.target_indexes is None:
        # all targets (one per column of the transcript target matrix)
        options.target_indexes = np.arange(
            gene_data.transcript_targets.shape[1])

    elif os.path.isfile(options.target_indexes):
        # file targets: first whitespace-delimited field of each non-blank line
        with open(options.target_indexes) as targets_in:
            options.target_indexes = [
                int(line.split()[0]) for line in targets_in if line.strip()
            ]

    else:
        # comma-separated targets
        options.target_indexes = [
            int(ti) for ti in options.target_indexes.split(',')
        ]

    options.target_indexes = np.array(options.target_indexes)

    #################################################################
    # transcript predictions

    if options.load_preds is not None:
        # load from file
        transcript_preds = np.load(options.load_preds)

    else:
        #######################################################
        # setup model

        t0 = time.time()
        print('Constructing model.', end='', flush=True)

        job = basenji.dna_io.read_job_params(params_file)

        job['batch_length'] = gene_data.seq_length
        job['seq_depth'] = gene_data.seq_depth
        job['target_pool'] = gene_data.pool_width
        job['num_targets'] = gene_data.num_targets

        # build model
        dr = basenji.seqnn.SeqNN()
        dr.build(job)

        if options.batch_size is not None:
            dr.batch_size = options.batch_size

        print(' Done in %ds' % (time.time() - t0), flush=True)

        #######################################################
        # predict transcripts

        # restart the clock so the reported time covers prediction only
        t0 = time.time()
        print('Computing gene predictions.', end='', flush=True)

        batcher = basenji.batcher.Batcher(gene_data.seqs_1hot,
                                          batch_size=dr.batch_size)

        # initialize saver
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # load variables into session
            saver.restore(sess, model_file)

            transcript_preds = dr.predict_genes(sess,
                                                batcher,
                                                gene_data.transcript_map,
                                                rc_avg=options.rc)

        # np.save appends the .npy extension
        np.save('%s/preds' % options.out_dir, transcript_preds)

        print(' Done in %ds.' % (time.time() - t0), flush=True)

    #################################################################
    # convert to genes

    gene_targets = map_transcripts_genes(gene_data.transcript_targets,
                                         gene_data.transcript_map,
                                         gene_data.transcript_gene_indexes)
    gene_preds = map_transcripts_genes(transcript_preds,
                                       gene_data.transcript_map,
                                       gene_data.transcript_gene_indexes)

    #################################################################
    # correlation statistics

    t0 = time.time()
    print('Computing correlations.', end='', flush=True)

    cor_table(gene_data.transcript_targets, transcript_preds,
              gene_data.target_labels, options.target_indexes,
              '%s/transcript_cors.txt' % options.out_dir)
    cor_table(gene_targets,
              gene_preds,
              gene_data.target_labels,
              options.target_indexes,
              '%s/gene_cors.txt' % options.out_dir,
              plots=True)

    print(' Done in %ds.' % (time.time() - t0), flush=True)

    #################################################################
    # gene statistics

    if options.print_tables:
        t0 = time.time()
        print('Printing predictions.', end='', flush=True)

        gene_table(gene_data.transcript_targets, transcript_preds,
                   gene_data.transcript_map.keys(), gene_data.target_labels,
                   options.target_indexes, '%s/transcript' % options.out_dir,
                   options.plot_scatter)

        gene_table(gene_targets, gene_preds, set(gene_data.genes),
                   gene_data.target_labels, options.target_indexes,
                   '%s/gene' % options.out_dir, options.plot_scatter)

        print(' Done in %ds.' % (time.time() - t0), flush=True)

    #################################################################
    # gene x target heatmaps

    if options.plot_heat or options.gene_variance:
        # normalize predictions across targets (selected targets only)
        t0 = time.time()
        print('Normalizing values across targets.', end='', flush=True)

        gene_targets_qn = normalize_targets(
            gene_targets[:, options.target_indexes])
        gene_preds_qn = normalize_targets(
            gene_preds[:, options.target_indexes])

        print(' Done in %ds.' % (time.time() - t0), flush=True)

    if options.plot_heat:
        # plot genes by targets clustermap
        t0 = time.time()
        print('Plotting heat maps.', end='', flush=True)

        sns.set(font_scale=1.3, style='ticks')
        plot_genes = 2000
        plot_targets = 1000

        num_genes, num_targets = gene_preds_qn.shape

        # choose a set of variable genes
        gene_vars = gene_preds_qn.var(axis=1)
        indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

        # choose a set of random genes; sampling without replacement
        # requires plot_genes <= num_genes, otherwise use all genes
        if plot_genes < num_genes:
            indexes_rand = np.random.choice(np.arange(num_genes),
                                            plot_genes,
                                            replace=False)
        else:
            indexes_rand = np.arange(num_genes)

        # choose a set of random targets
        if plot_targets < 0.8 * num_targets:
            indexes_targets = np.random.choice(np.arange(num_targets),
                                               plot_targets,
                                               replace=False)
        else:
            indexes_targets = np.arange(num_targets)

        # variable gene predictions
        clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
                   '%s/gene_heat_var.pdf' % options.out_dir)
        clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
                   '%s/gene_heat_var_color.pdf' % options.out_dir,
                   color='viridis',
                   table=True)

        # random gene predictions
        clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
                   '%s/gene_heat_rand.pdf' % options.out_dir)

        # variable gene targets
        clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
                   '%s/gene_theat_var.pdf' % options.out_dir)
        clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
                   '%s/gene_theat_var_color.pdf' % options.out_dir,
                   color='viridis',
                   table=True)

        # random gene targets
        clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
                   '%s/gene_theat_rand.pdf' % options.out_dir)

        print(' Done in %ds.' % (time.time() - t0), flush=True)

    #################################################################
    # analyze replicates

    if options.replicate_labels_file is not None:
        # long form labels (last tab-separated field), used to infer replicates
        with open(options.replicate_labels_file) as labels_in:
            target_labels_long = [
                line.split('\t')[-1].rstrip() for line in labels_in
            ]

        replicate_lists = infer_replicates(target_labels_long)

        replicate_correlations(replicate_lists, gene_targets, gene_preds,
                               options.target_indexes,
                               '%s/gene_reps' % options.out_dir)

    #################################################################
    # gene variance

    if options.gene_variance:
        variance_accuracy(gene_targets_qn, gene_preds_qn,
                          '%s/gene' % options.out_dir)
Example #2
0
def main():
  """Evaluate a trained model's TSS and gene predictions.

  Command-line entry point. Usage:
    %prog [options] <params_file> <model_file> <genes_hdf5_file>

  Predictions are either loaded with ``np.load`` (``-l``) or computed by
  building a ``seqnn.SeqNN`` from ``params_file``, restoring ``model_file``
  with ``tf.train.Saver`` and calling ``predict_genes``. Computed predictions
  are saved to ``<out_dir>/preds.npy``.

  TSS values are aggregated to genes with ``gene.map_tss_genes``, and
  correlation tables are written to ``out_dir``. Depending on the flags, the
  function also writes prediction tables, heatmaps, replicate correlations,
  gene-variance and alternative-TSS analyses.

  If ``-t`` is not given and the HDF5 file reports no targets, the script
  exits with status 0 after any predictions have been saved.

  Note: the ``-i`` (ignore_bed) option is parsed but not used here.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=None,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-i',
      dest='ignore_bed',
      help='Ignore genes overlapping regions in this BED file')
  parser.add_option(
      '-l', dest='load_preds', help='Load tss_preds from file')
  parser.add_option(
      '--heat',
      dest='plot_heat',
      default=False,
      action='store_true',
      help='Plot big gene-target heatmaps [Default: %default]')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='genes_out',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-r',
      dest='tss_radius',
      default=0,
      type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '-s',
      dest='plot_scatter',
      default=False,
      action='store_true',
      help='Make time-consuming accuracy scatter plots [Default: %default]')
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '--rep',
      dest='replicate_labels_file',
      help=
      'Compare replicate experiments, aided by the given file with long labels')
  parser.add_option(
      '-t',
      dest='target_indexes',
      default=None,
      help=
      'File or Comma-separated list of target indexes to scatter plot true versus predicted values'
  )
  parser.add_option(
      '--table',
      dest='print_tables',
      default=False,
      action='store_true',
      help='Print big gene/TSS tables [Default: %default]')
  parser.add_option(
      '--tss',
      dest='tss_alt',
      default=False,
      action='store_true',
      help='Perform alternative TSS analysis [Default: %default]')
  parser.add_option(
      '-v',
      dest='gene_variance',
      default=False,
      action='store_true',
      help=
      'Study accuracy with respect to gene variance across targets [Default: %default]'
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    # parser.error prints the message and exits
    parser.error('Must provide parameters and model files, and genes HDF5 file')
  params_file, model_file, genes_hdf5_file = args

  # makedirs creates missing parents and tolerates an existing directory
  os.makedirs(options.out_dir, exist_ok=True)

  # comma-separated shifts -> list of ints
  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes and targets

  gene_data = genedata.GeneData(genes_hdf5_file)

  #################################################################
  # TSS predictions

  if options.load_preds is not None:
    # load from file
    tss_preds = np.load(options.load_preds)

  else:
    #######################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width
    # keep an explicit num_targets from the params file if present
    if 'num_targets' not in job:
      job['num_targets'] = gene_data.num_targets

    # build model
    model = seqnn.SeqNN()
    model.build(job)

    if options.batch_size is not None:
      model.hp.batch_size = options.batch_size

    #######################################################
    # predict TSSs

    t0 = time.time()
    print('Computing gene predictions.', end='')
    sys.stdout.flush()

    gene_batcher = batcher.Batcher(
        gene_data.seqs_1hot, batch_size=model.hp.batch_size)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
      # load variables into session
      saver.restore(sess, model_file)

      tss_preds = model.predict_genes(
          sess, gene_batcher, gene_data.gene_seqs,
          rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius)

    # np.save appends the .npy extension
    np.save('%s/preds' % options.out_dir, tss_preds)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # convert to genes

  gene_targets, _ = gene.map_tss_genes(gene_data.tss_targets, gene_data.tss,
                                       tss_radius=options.tss_radius)
  gene_preds, _ = gene.map_tss_genes(tss_preds, gene_data.tss,
                                     tss_radius=options.tss_radius)

  #################################################################
  # determine targets

  if options.target_indexes is None:
    # all targets
    if gene_data.num_targets is None:
      print('No targets to test against')
      sys.exit()
    options.target_indexes = np.arange(gene_data.num_targets)

  elif os.path.isfile(options.target_indexes):
    # file targets: first whitespace-delimited field of each non-blank line
    with open(options.target_indexes) as targets_in:
      options.target_indexes = [
          int(line.split()[0]) for line in targets_in if line.strip()
      ]

  else:
    # comma-separated targets
    options.target_indexes = [
        int(ti) for ti in options.target_indexes.split(',')
    ]

  options.target_indexes = np.array(options.target_indexes)

  #################################################################
  # correlation statistics

  t0 = time.time()
  print('Computing correlations.', end='')
  sys.stdout.flush()

  cor_table(gene_data.tss_targets, tss_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/tss_cors.txt' % options.out_dir)

  cor_table(gene_targets, gene_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/gene_cors.txt' % options.out_dir, plots=True)

  print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene statistics

  if options.print_tables:
    t0 = time.time()
    print('Printing predictions.', end='')
    sys.stdout.flush()

    gene_table(gene_data.tss_targets, tss_preds, gene_data.tss_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/transcript' % options.out_dir, options.plot_scatter)

    gene_table(gene_targets, gene_preds,
               gene_data.gene_ids(), gene_data.target_labels,
               options.target_indexes, '%s/gene' % options.out_dir,
               options.plot_scatter)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene x target heatmaps

  if options.plot_heat or options.gene_variance:
    # normalize predictions across targets (selected targets only)
    t0 = time.time()
    print('Normalizing values across targets.', end='')
    sys.stdout.flush()

    gene_targets_qn = normalize_targets(
        gene_targets[:, options.target_indexes], log_pseudo=1)
    gene_preds_qn = normalize_targets(
        gene_preds[:, options.target_indexes], log_pseudo=1)

    print(' Done in %ds.' % (time.time() - t0))

  if options.plot_heat:
    # plot genes by targets clustermap
    t0 = time.time()
    print('Plotting heat maps.', end='')
    sys.stdout.flush()

    sns.set(font_scale=1.3, style='ticks')
    plot_genes = 1600
    plot_targets = 800

    # choose a set of variable genes
    gene_vars = gene_preds_qn.var(axis=1)
    indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

    # choose a set of random genes
    if plot_genes < gene_preds_qn.shape[0]:
      indexes_rand = np.random.choice(
          np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False)
    else:
      indexes_rand = np.arange(gene_preds_qn.shape[0])

    # choose a set of random targets
    if plot_targets < 0.8 * gene_preds_qn.shape[1]:
      indexes_targets = np.random.choice(
          np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False)
    else:
      indexes_targets = np.arange(gene_preds_qn.shape[1])

    # variable gene predictions
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var.pdf' % options.out_dir)
    clustermap(
        gene_preds_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_heat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene predictions
    clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
               '%s/gene_heat_rand.pdf' % options.out_dir)

    # variable gene targets
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var.pdf' % options.out_dir)
    clustermap(
        gene_targets_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_theat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # NOTE: the random-gene target heatmap is disabled; it was reported to crash.

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # analyze replicates

  if options.replicate_labels_file is not None:
    # long form labels (last tab-separated field), used to infer replicates
    with open(options.replicate_labels_file) as labels_in:
      target_labels_long = [line.split('\t')[-1].rstrip() for line in labels_in]

    replicate_lists = infer_replicates(target_labels_long)

    replicate_correlations(
        replicate_lists, gene_targets, gene_preds, options.target_indexes,
        '%s/gene_reps' % options.out_dir)

  #################################################################
  # gene variance

  if options.gene_variance:
    variance_accuracy(gene_targets_qn, gene_preds_qn,
                      '%s/gene' % options.out_dir)

  #################################################################
  # alternative TSS

  if options.tss_alt:
    alternative_tss(gene_data.tss_targets[:, options.target_indexes],
                    tss_preds[:, options.target_indexes], gene_data,
                    options.out_dir, log_pseudo=1)
def main():
    """Evaluate a trained model's TSS and gene predictions.

    Command-line entry point. Usage:
        %prog [options] <params_file> <model_file> <genes_hdf5_file>

    Predictions are either loaded with ``np.load`` (``-l``) or computed by
    building a ``seqnn.SeqNN`` (``build_feed_sad``) from ``params_file``,
    restoring ``model_file`` with ``tf.train.Saver`` and calling
    ``predict_genes``. Computed predictions are saved to
    ``<out_dir>/preds.npy``. Predictions are then cast to float32.

    Target indexes come from the index column of the ``-t`` table (read with
    ``pd.read_table(..., index_col=0)``); otherwise all targets reported by
    the HDF5 file are used. If neither is available, the script exits with
    status 1.

    TSS values are aggregated to genes with ``gene.map_tss_genes``, and
    correlation tables are written to ``out_dir``. Depending on the flags, the
    function also writes prediction tables, heatmaps, replicate correlations,
    gene-variance and alternative-TSS analyses.

    Note: the ``-i`` (ignore_bed) option is parsed but not used here.
    """
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-i",
        dest="ignore_bed",
        help="Ignore genes overlapping regions in this BED file",
    )
    parser.add_option("-l", dest="load_preds", help="Load tss_preds from file")
    parser.add_option(
        "--heat",
        dest="plot_heat",
        default=False,
        action="store_true",
        help="Plot big gene-target heatmaps [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="genes_out",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "-s",
        dest="plot_scatter",
        default=False,
        action="store_true",
        help="Make time-consuming accuracy scatter plots [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--rep",
        dest="replicate_labels_file",
        help="Compare replicate experiments, aided by the given file with long labels",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--table",
        dest="print_tables",
        default=False,
        action="store_true",
        help="Print big gene/TSS tables [Default: %default]",
    )
    parser.add_option(
        "--tss",
        dest="tss_alt",
        default=False,
        action="store_true",
        help="Perform alternative TSS analysis [Default: %default]",
    )
    parser.add_option(
        "-v",
        dest="gene_variance",
        default=False,
        action="store_true",
        help="Study accuracy with respect to gene variance across targets [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error prints the message and exits
        parser.error("Must provide parameters and model files, and genes HDF5 file")
    params_file, model_file, genes_hdf5_file = args

    # makedirs creates missing parents and tolerates an existing directory
    os.makedirs(options.out_dir, exist_ok=True)

    # comma-separated shifts -> list of ints
    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # read in genes and targets

    gene_data = genedata.GeneData(genes_hdf5_file)

    #################################################################
    # TSS predictions

    if options.load_preds is not None:
        # load from file
        tss_preds = np.load(options.load_preds)

    else:
        #######################################################
        # setup model

        job = params.read_job_params(params_file)

        job["seq_length"] = gene_data.seq_length
        job["seq_depth"] = gene_data.seq_depth
        job["target_pool"] = gene_data.pool_width
        # keep an explicit num_targets from the params file if present
        if "num_targets" not in job:
            job["num_targets"] = gene_data.num_targets

        # build model
        model = seqnn.SeqNN()
        model.build_feed_sad(job)

        if options.batch_size is not None:
            model.hp.batch_size = options.batch_size

        #######################################################
        # predict TSSs

        t0 = time.time()
        print("Computing gene predictions.", end="")
        sys.stdout.flush()

        gene_batcher = batcher.Batcher(
            gene_data.seqs_1hot, batch_size=model.hp.batch_size
        )

        # initialize saver
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # load variables into session
            saver.restore(sess, model_file)

            tss_preds = model.predict_genes(
                sess,
                gene_batcher,
                gene_data.gene_seqs,
                rc=options.rc,
                shifts=options.shifts,
                tss_radius=options.tss_radius,
            )

        # np.save appends the .npy extension
        np.save("%s/preds" % options.out_dir, tss_preds)

        print(" Done in %ds." % (time.time() - t0))

    # up-convert (applies to both loaded and freshly computed predictions)
    tss_preds = tss_preds.astype("float32")

    #################################################################
    # convert to genes

    gene_targets, _ = gene.map_tss_genes(
        gene_data.tss_targets, gene_data.tss, tss_radius=options.tss_radius
    )
    gene_preds, _ = gene.map_tss_genes(
        tss_preds, gene_data.tss, tss_radius=options.tss_radius
    )

    #################################################################
    # determine targets

    if options.targets_file is not None:
        # the first column of the targets table holds the target indexes
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_indexes = targets_df.index
    elif gene_data.num_targets is None:
        print("No targets to test against.")
        sys.exit(1)
    else:
        target_indexes = np.arange(gene_data.num_targets)

    #################################################################
    # correlation statistics

    t0 = time.time()
    print("Computing correlations.", end="")
    sys.stdout.flush()

    cor_table(
        gene_data.tss_targets,
        tss_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/tss_cors.txt" % options.out_dir,
    )

    cor_table(
        gene_targets,
        gene_preds,
        gene_data.target_ids,
        gene_data.target_labels,
        target_indexes,
        "%s/gene_cors.txt" % options.out_dir,
        draw_plots=True,
    )

    print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene statistics

    if options.print_tables:
        t0 = time.time()
        print("Printing predictions.", end="")
        sys.stdout.flush()

        gene_table(
            gene_data.tss_targets,
            tss_preds,
            gene_data.tss_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/transcript" % options.out_dir,
            options.plot_scatter,
        )

        gene_table(
            gene_targets,
            gene_preds,
            gene_data.gene_ids(),
            gene_data.target_labels,
            target_indexes,
            "%s/gene" % options.out_dir,
            options.plot_scatter,
        )

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # gene x target heatmaps

    if options.plot_heat or options.gene_variance:
        # normalize predictions across targets (selected targets only)
        t0 = time.time()
        print("Normalizing values across targets.", end="")
        sys.stdout.flush()

        gene_targets_qn = normalize_targets(
            gene_targets[:, target_indexes], log_pseudo=1
        )
        gene_preds_qn = normalize_targets(gene_preds[:, target_indexes], log_pseudo=1)

        print(" Done in %ds." % (time.time() - t0))

    if options.plot_heat:
        # plot genes by targets clustermap
        t0 = time.time()
        print("Plotting heat maps.", end="")
        sys.stdout.flush()

        sns.set(font_scale=1.3, style="ticks")
        plot_genes = 1600
        plot_targets = 800

        # choose a set of variable genes
        gene_vars = gene_preds_qn.var(axis=1)
        indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

        # choose a set of random genes
        if plot_genes < gene_preds_qn.shape[0]:
            indexes_rand = np.random.choice(
                np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False
            )
        else:
            indexes_rand = np.arange(gene_preds_qn.shape[0])

        # choose a set of random targets
        if plot_targets < 0.8 * gene_preds_qn.shape[1]:
            indexes_targets = np.random.choice(
                np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False
            )
        else:
            indexes_targets = np.arange(gene_preds_qn.shape[1])

        # variable gene predictions
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_preds_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_heat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )

        # random gene predictions
        clustermap(
            gene_preds_qn[indexes_rand, :][:, indexes_targets],
            "%s/gene_heat_rand.pdf" % options.out_dir,
        )

        # variable gene targets
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var.pdf" % options.out_dir,
        )
        clustermap(
            gene_targets_qn[indexes_var, :][:, indexes_targets],
            "%s/gene_theat_var_color.pdf" % options.out_dir,
            color="viridis",
            table=True,
        )

        # NOTE: the random-gene target heatmap is disabled; it was reported to crash.

        print(" Done in %ds." % (time.time() - t0))

    #################################################################
    # analyze replicates

    if options.replicate_labels_file is not None:
        # long form labels (last tab-separated field), used to infer replicates
        with open(options.replicate_labels_file) as labels_in:
            target_labels_long = [line.split("\t")[-1].rstrip() for line in labels_in]

        replicate_lists = infer_replicates(target_labels_long)

        replicate_correlations(
            replicate_lists,
            gene_targets,
            gene_preds,
            target_indexes,
            "%s/gene_reps" % options.out_dir,
        )

    #################################################################
    # gene variance

    if options.gene_variance:
        variance_accuracy(gene_targets_qn, gene_preds_qn, "%s/gene" % options.out_dir)

    #################################################################
    # alternative TSS

    if options.tss_alt:
        alternative_tss(
            gene_data.tss_targets[:, target_indexes],
            tss_preds[:, target_indexes],
            gene_data,
            options.out_dir,
            log_pseudo=1,
        )