def main():
  """Re-save a TF checkpoint under the variable names of the current graph.

  Command line: <params_file> <in_model_tf> <out_model_tf>.

  Builds the SeqNN graph described by <params_file>, restores weights from
  <in_model_tf> while mapping each current variable name to the name stored
  in that checkpoint, then writes all global variables to <out_model_tf>.
  """
  parser = OptionParser(
      'usage: %prog [options] <params_file> <in_model_tf> <out_model_tf>')
  options, args = parser.parse_args()

  if len(args) != 3:
    # parser.error exits, so the unpacking below only runs with 3 args
    parser.error('Must provide parameters file and input and out model stems.')
  params_file, in_model_tf, out_model_tf = args

  # build the graph whose variable names define the output checkpoint
  job = params.read_job_params(params_file)
  model = seqnn.SeqNN()
  model.build(job)

  # checkpoint name -> graph variable
  restore_dict = {}
  for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    # variable names carry a ":0" output suffix that Saver rejects
    name = var.name.split(':')[0]
    if name == 'global_step':
      # not restored from the input checkpoint
      continue
    if name.startswith('final'):
      # 'final' layers are dense in this graph but stored under conv1d names
      name = name.replace('dense', 'conv1d')
    restore_dict[name] = var

  # reshape=True is what lets conv1d kernels load into dense-shaped variables
  saver_read = tf.train.Saver(restore_dict, reshape=True)
  saver_write = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver_read.restore(sess, in_model_tf)
    # write back under the graph's own (new) names
    saver_write.save(sess, out_model_tf)
# Example no. 2
# 0
def main():
    """Build the SeqNN graph from <params_file> and list its global variables.

    Prints one line per global variable: its name followed by its shape.
    """
    parser = OptionParser('usage: %prog [options] <params_file>')
    options, args = parser.parse_args()

    if len(args) != 1:
        # parser.error exits the program
        parser.error('Must provide parameters file.')
    params_file = args[0]

    # construct the model graph so its variables are registered
    job = params.read_job_params(params_file)
    model = seqnn.SeqNN()
    model.build(job)

    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    for var in global_vars:
        print(var.name, var.shape)
# Example no. 3
# 0
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=None,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-i',
      dest='ignore_bed',
      help='Ignore genes overlapping regions in this BED file')
  parser.add_option(
      '-l', dest='load_preds', help='Load tess_preds from file')
  parser.add_option(
      '--heat',
      dest='plot_heat',
      default=False,
      action='store_true',
      help='Plot big gene-target heatmaps [Default: %default]')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='genes_out',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-r',
      dest='tss_radius',
      default=0,
      type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '-s',
      dest='plot_scatter',
      default=False,
      action='store_true',
      help='Make time-consuming accuracy scatter plots [Default: %default]')
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '--rep',
      dest='replicate_labels_file',
      help=
      'Compare replicate experiments, aided by the given file with long labels')
  parser.add_option(
      '-t',
      dest='target_indexes',
      default=None,
      help=
      'File or Comma-separated list of target indexes to scatter plot true versus predicted values'
  )
  parser.add_option(
      '--table',
      dest='print_tables',
      default=False,
      action='store_true',
      help='Print big gene/TSS tables [Default: %default]')
  parser.add_option(
      '--tss',
      dest='tss_alt',
      default=False,
      action='store_true',
      help='Perform alternative TSS analysis [Default: %default]')
  parser.add_option(
      '-v',
      dest='gene_variance',
      default=False,
      action='store_true',
      help=
      'Study accuracy with respect to gene variance across targets [Default: %default]'
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters and model files, and genes HDF5 file')
  else:
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read in genes and targets

  gene_data = genedata.GeneData(genes_hdf5_file)


  #################################################################
  # TSS predictions

  if options.load_preds is not None:
    # load from file
    tss_preds = np.load(options.load_preds)

  else:

    #######################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width
    if not 'num_targets' in job:
      job['num_targets'] = gene_data.num_targets

    # build model
    model = seqnn.SeqNN()
    model.build(job)

    if options.batch_size is not None:
      model.hp.batch_size = options.batch_size


    #######################################################
    # predict TSSs

    t0 = time.time()
    print('Computing gene predictions.', end='')
    sys.stdout.flush()

    # initialize batcher
    gene_batcher = batcher.Batcher(
        gene_data.seqs_1hot, batch_size=model.hp.batch_size)

    # initialie saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
      # load variables into session
      saver.restore(sess, model_file)

      # predict
      tss_preds = model.predict_genes(sess, gene_batcher, gene_data.gene_seqs,
          rc=options.rc, shifts=options.shifts, tss_radius=options.tss_radius)

    # save to file
    np.save('%s/preds' % options.out_dir, tss_preds)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # convert to genes

  gene_targets, _ = gene.map_tss_genes(gene_data.tss_targets, gene_data.tss,
                                       tss_radius=options.tss_radius)
  gene_preds, _ = gene.map_tss_genes(tss_preds, gene_data.tss,
                                     tss_radius=options.tss_radius)


  #################################################################
  # determine targets

  # all targets
  if options.target_indexes is None:
    if gene_data.num_targets is None:
      print('No targets to test against')
      exit()
    else:
      options.target_indexes = np.arange(gene_data.num_targets)

  # file targets
  elif os.path.isfile(options.target_indexes):
    target_indexes_file = options.target_indexes
    options.target_indexes = []
    for line in open(target_indexes_file):
      options.target_indexes.append(int(line.split()[0]))

  # comma-separated targets
  else:
    options.target_indexes = [
        int(ti) for ti in options.target_indexes.split(',')
    ]

  options.target_indexes = np.array(options.target_indexes)

  #################################################################
  # correlation statistics

  t0 = time.time()
  print('Computing correlations.', end='')
  sys.stdout.flush()

  cor_table(gene_data.tss_targets, tss_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/tss_cors.txt' % options.out_dir)

  cor_table(gene_targets, gene_preds, gene_data.target_ids,
            gene_data.target_labels, options.target_indexes,
            '%s/gene_cors.txt' % options.out_dir, plots=True)

  print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene statistics

  if options.print_tables:
    t0 = time.time()
    print('Printing predictions.', end='')
    sys.stdout.flush()

    gene_table(gene_data.tss_targets, tss_preds, gene_data.tss_ids(),
               gene_data.target_labels, options.target_indexes,
               '%s/transcript' % options.out_dir, options.plot_scatter)

    gene_table(gene_targets, gene_preds,
               gene_data.gene_ids(), gene_data.target_labels,
               options.target_indexes, '%s/gene' % options.out_dir,
               options.plot_scatter)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # gene x target heatmaps

  if options.plot_heat or options.gene_variance:
    #########################################
    # normalize predictions across targets

    t0 = time.time()
    print('Normalizing values across targets.', end='')
    sys.stdout.flush()

    gene_targets_qn = normalize_targets(gene_targets[:, options.target_indexes], log_pseudo=1)
    gene_preds_qn = normalize_targets(gene_preds[:, options.target_indexes], log_pseudo=1)

    print(' Done in %ds.' % (time.time() - t0))

  if options.plot_heat:
    #########################################
    # plot genes by targets clustermap

    t0 = time.time()
    print('Plotting heat maps.', end='')
    sys.stdout.flush()

    sns.set(font_scale=1.3, style='ticks')
    plot_genes = 1600
    plot_targets = 800

    # choose a set of variable genes
    gene_vars = gene_preds_qn.var(axis=1)
    indexes_var = np.argsort(gene_vars)[::-1][:plot_genes]

    # choose a set of random genes
    if plot_genes < gene_preds_qn.shape[0]:
      indexes_rand = np.random.choice(
        np.arange(gene_preds_qn.shape[0]), plot_genes, replace=False)
    else:
      indexes_rand = np.arange(gene_preds_qn.shape[0])

    # choose a set of random targets
    if plot_targets < 0.8 * gene_preds_qn.shape[1]:
      indexes_targets = np.random.choice(
          np.arange(gene_preds_qn.shape[1]), plot_targets, replace=False)
    else:
      indexes_targets = np.arange(gene_preds_qn.shape[1])

    # variable gene predictions
    clustermap(gene_preds_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_heat_var.pdf' % options.out_dir)
    clustermap(
        gene_preds_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_heat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene predictions
    clustermap(gene_preds_qn[indexes_rand, :][:, indexes_targets],
               '%s/gene_heat_rand.pdf' % options.out_dir)

    # variable gene targets
    clustermap(gene_targets_qn[indexes_var, :][:, indexes_targets],
               '%s/gene_theat_var.pdf' % options.out_dir)
    clustermap(
        gene_targets_qn[indexes_var, :][:, indexes_targets],
        '%s/gene_theat_var_color.pdf' % options.out_dir,
        color='viridis',
        table=True)

    # random gene targets (crashes)
    # clustermap(gene_targets_qn[indexes_rand, :][:, indexes_targets],
    #            '%s/gene_theat_rand.pdf' % options.out_dir)

    print(' Done in %ds.' % (time.time() - t0))

  #################################################################
  # analyze replicates

  if options.replicate_labels_file is not None:
    # read long form labels, from which to infer replicates
    target_labels_long = []
    for line in open(options.replicate_labels_file):
      a = line.split('\t')
      a[-1] = a[-1].rstrip()
      target_labels_long.append(a[-1])

    # determine replicates
    replicate_lists = infer_replicates(target_labels_long)

    # compute correlations
    # replicate_correlations(replicate_lists, gene_data.tss_targets,
        # tss_preds, options.target_indexes, '%s/transcript_reps' % options.out_dir)
    replicate_correlations(
        replicate_lists, gene_targets, gene_preds, options.target_indexes,
        '%s/gene_reps' % options.out_dir)  # , scatter_plots=True)

  #################################################################
  # gene variance

  if options.gene_variance:
    variance_accuracy(gene_targets_qn, gene_preds_qn,
                      '%s/gene' % options.out_dir)

  #################################################################
  # alternative TSS

  if options.tss_alt:
    alternative_tss(gene_data.tss_targets[:,options.target_indexes],
                    tss_preds[:,options.target_indexes], gene_data,
                    options.out_dir, log_pseudo=1)
# Example no. 4
# 0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='mut_down',
        default=0,
        type='int',
        help=
        'Nucleotides downstream of center sequence to mutate [Default: %default]'
    )
    parser.add_option('-f',
                      dest='figure_width',
                      default=20,
                      type='float',
                      help='Figure width [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
        help='Genome FASTA which which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA which which minor allele sequences will be drawn')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=200,
        type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_vcf',
                      help='Output directory [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='sum',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='mut_up',
        default=0,
        type='int',
        help=
        'Nucleotides upstream of center sequence to mutate [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters and model files and VCF')
    else:
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = [
        sad_stat.lower() for sad_stat in options.sad_stats.split(',')
    ]

    if options.mut_up > 0 or options.mut_down > 0:
        options.mut_len = options.mut_up + options.mut_down
    else:
        assert (options.mut_len > 0)
        options.mut_up = options.mut_len // 2
        options.mut_down = options.mut_len - options.mut_up

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read targets
    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()

    #################################################################
    # SNP sequence dataset

    # load SNPs
    snps = vcf.vcf_snps(vcf_file)

    # get one hot coded input sequences
    if not options.genome2_fasta:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps_seq1(
            snps,
            params_model['seq_length'],
            options.genome1_fasta,
            return_seqs=True)
    else:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps2_seq1(
            snps,
            params_model['seq_length'],
            options.genome1_fasta,
            options.genome2_fasta,
            return_seqs=True)
    num_seqs = seqs_1hot.shape[0]

    # determine mutation region limits
    seq_mid = params_model['seq_length'] // 2
    mut_start = seq_mid - options.mut_up
    mut_end = mut_start + options.mut_len

    # make sequence generator
    seqs_gen = satmut_gen(seqs_dna, mut_start, mut_end)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('label', data=np.array(seq_headers, dtype='S'))
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))
    for sad_stat in options.sad_stats:
        scores_h5.create_dataset(sad_stat,
                                 dtype='float16',
                                 shape=(num_seqs, options.mut_len, 4,
                                        num_targets))

    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
    for i in range(1):
        sw = ScoreWorker(score_queue, scores_h5, options.sad_stats, mut_start,
                         mut_end)
        sw.start()
        score_threads.append(sw)

    #################################################################
    # predict scores and write output

    # find center
    preds_length = seqnn_model.target_lengths[0]
    center_start = preds_length // 2
    if preds_length % 2 == 0:
        center_end = center_start + 2
    else:
        center_end = center_start + 1

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, seqs_gen,
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_seqs):
        print('Predicting %d' % si, flush=True)

        # collect sequence predictions
        seq_preds_sum = []
        seq_preds_center = []
        seq_preds_scd = []
        preds_mut0 = preds_stream[pi]
        for spi in range(preds_per_seq):
            preds_mut = preds_stream[pi]
            preds_sum = preds_mut.sum(axis=0)
            seq_preds_sum.append(preds_sum)
            if 'center' in options.sad_stats:
                preds_center = preds_mut[center_start:center_end, :].sum(
                    axis=0)
                seq_preds_center.append(preds_center)
            if 'scd' in options.sad_stats:
                preds_scd = np.sqrt(((preds_mut - preds_mut0)**2).sum(axis=0))
                seq_preds_scd.append(preds_scd)
            pi += 1
        seq_preds_sum = np.array(seq_preds_sum)
        seq_preds_center = np.array(seq_preds_center)
        seq_preds_scd = np.array(seq_preds_scd)

        # wait for previous to finish
        score_queue.join()

        # queue sequence for scoring
        seq_pred_stats = (seq_preds_sum, seq_preds_center, seq_preds_scd)
        score_queue.put((seqs_dna[si], seq_pred_stats, si))

        gc.collect()

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
# Example no. 5
# 0
def run(params_file, train_files, test_files, train_epochs,
        train_epoch_batches, test_epoch_batches):

    # parse shifts
    augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')]
    ensemble_shifts = [
        int(shift) for shift in FLAGS.ensemble_shifts.split(',')
    ]

    # read parameters
    job = params.read_job_params(params_file)
    job['num_genomes'] = job.get('num_genomes', 1)
    if not isinstance(job['num_targets'], list):
        job['num_targets'] = [job['num_targets']]

    # load data
    data_ops, handle, train_dataseqs, test_dataseqs = make_data_ops(
        job, train_files, test_files)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job, data_ops, FLAGS.augment_rc, augment_shifts,
                              FLAGS.ensemble_rc, ensemble_shifts)

    # launch accuracy metrics compute thread
    if FLAGS.metrics_thread:
        metrics_queue = Queue()
        metrics_thread = MetricsWorker(metrics_queue)
        metrics_thread.start()

    # checkpoints
    saver = tf.train.Saver()

    # specify CPU parallelism
    session_conf = tf.ConfigProto(intra_op_parallelism_threads=2,
                                  inter_op_parallelism_threads=5)

    # with tf.Session(config=session_conf) as sess:
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(
            FLAGS.logdir + '/train', sess.graph) if FLAGS.logdir else None

        # generate handles
        for gi in range(job['num_genomes']):
            train_dataseqs[gi].make_handle(sess)
            test_dataseqs[gi].make_handle(sess)

        if FLAGS.restart:
            # load variables into session
            saver.restore(sess, FLAGS.restart)
        else:
            # initialize variables
            print('Initializing...')
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())

        train_loss = None
        best_loss = None
        early_stop_i = 0

        epoch = 0

        while (train_epochs is not None and epoch < train_epochs) or \
              (train_epochs is None and early_stop_i < FLAGS.early_stop):
            t0 = time.time()

            # initialize training data epochs
            for gi in range(job['num_genomes']):
                if train_dataseqs[gi].iterator is not None:
                    sess.run(train_dataseqs[gi].iterator.initializer)

            # train epoch
            train_losses, steps = model.train2_epoch_ops(
                sess, handle, train_dataseqs)

            if FLAGS.metrics_thread:
                # block for previous metrics compute
                metrics_queue.join()

            # test validation
            valid_accs = []
            valid_losses = []
            for gi in range(job['num_genomes']):
                if test_dataseqs[gi].iterator is None:
                    valid_accs.append(None)
                    valid_losses.append(np.nan)

                else:
                    # initialize
                    sess.run(test_dataseqs[gi].iterator.initializer)

                    # compute
                    valid_acc = model.test_tfr(sess, test_dataseqs[gi], handle,
                                               test_epoch_batches,
                                               FLAGS.metrics_sample)

                    # save
                    valid_accs.append(valid_acc)
                    valid_losses.append(valid_acc.loss)

            # summarize
            train_loss = np.nanmean(train_losses)
            valid_loss = np.nanmean(valid_losses)

            # consider as best
            best_str = ''
            if best_loss is None or valid_loss < best_loss:
                best_loss = valid_loss
                best_str = ', best!'
                early_stop_i = 0
                saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
            else:
                early_stop_i += 1

            # measure time
            et = time.time() - t0
            if et < 600:
                time_str = '%3ds' % et
            elif et < 6000:
                time_str = '%3dm' % (et / 60)
            else:
                time_str = '%3.1fh' % (et / 3600)

            # compute and write accuracy metrics update
            update_args = (epoch, steps, train_losses, valid_losses,
                           valid_accs, time_str, best_str)
            if FLAGS.metrics_thread:
                metrics_queue.put(update_args)
            else:
                metrics_update(*update_args)

            # checkpoint
            saver.save(sess, '%s/model_check.tf' % FLAGS.logdir)

            # update epoch
            epoch += 1

        # block for final metrics compute
        metrics_queue.join()

        if FLAGS.logdir:
            train_writer.close()
# Example no. 6
# 0
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy scatter plots.'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample by taking uniformly spaced positions [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/tutorials/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option(
        '--save',
        dest='save',
        default=False,
        action='store_true',
        help='Save targets and predictions numpy arrays [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--tfr',
        dest='tfr_pattern',
        default='test-*.tfr',
        help='TFR pattern string appended to data_dir [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt preds to peak calls [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    else:
        params_file = args[0]
        model_file = args[1]
        data_dir = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    # read targets
    targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_table(targets_file)

    # read model parameters
    job = params.read_job_params(params_file)

    # construct data ops
    tfr_pattern_path = '%s/tfrecords/%s' % (data_dir, options.tfr_pattern)
    data_ops, test_init_op = make_data_ops(job, tfr_pattern_path)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job,
                              data_ops,
                              ensemble_rc=options.rc,
                              ensemble_shifts=options.shifts)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        sess.run(test_init_op)
        test_acc = model.test_tfr(sess)

        test_preds = test_acc.preds
        print('SeqNN test: %ds' % (time.time() - t0))

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
            np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        # test_log_r2 = test_acc.r2(log=True, clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        #test_scor = test_acc.spearmanr()  # too slow; mostly driven by low values
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        # print('Test log R2:       %7.5f' % test_log_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())
        # print('Test SpearmanR:    %7.5f' % test_scor.mean())

        acc_out = open('%s/acc.txt' % options.out_dir, 'w')
        for ti in range(len(test_r2)):
            print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                  (ti, test_acc.target_losses[ti], test_r2[ti], test_pcor[ti],
                   test_log_pcor[ti], targets_df.description.iloc[ti]),
                  file=acc_out)
        acc_out.close()

        # print normalization factors
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        norm_out = open('%s/normalization.txt' % options.out_dir, 'w')
        print('\n'.join([str(tu) for tu in target_means]), file=norm_out)
        norm_out.close()

        # clean up
        del test_acc

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer //
                                                 model.hp.target_pool)

        aurocs = []
        auprcs = []

        peaks_out = open('%s/peaks.txt' % options.out_dir, 'w')
        for ti in range(test_targets.shape[2]):
            test_targets_ti = test_targets[:, :, ti]

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:,
                                                   ds_indexes_targets].flatten(
                                                   ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                            ti].flatten().astype('float32')

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01

            if test_targets_peaks.sum() == 0:
                aurocs.append(0.5)
                auprcs.append(0)

            else:
                # compute prediction accuracy
                aurocs.append(
                    roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                auprcs.append(
                    average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat))

            print('%4d  %6d  %.5f  %.5f' %
                  (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                  file=peaks_out)

        peaks_out.close()

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs')

        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        bed_set = 'test'
        if options.valid:
            bed_set = 'valid'

        for ti in track_indexes:
            test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         model.hp.batch_buffer,
                         bed_set=bed_set)

        # make NA bigwig
        # bw_file = '%s/tracks/na.bw' % options.out_dir
        # bigwig_write(
        #     bw_file,
        #     test_na,
        #     options.track_bed,
        #     options.genome_file,
        #     bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        if not os.path.isdir('%s/scatter' % options.out_dir):
            os.mkdir('%s/scatter' % options.out_dir)

        if not os.path.isdir('%s/violin' % options.out_dir):
            os.mkdir('%s/violin' % options.out_dir)

        if not os.path.isdir('%s/roc' % options.out_dir):
            os.mkdir('%s/roc' % options.out_dir)

        if not os.path.isdir('%s/pr' % options.out_dir):
            os.mkdir('%s/pr' % options.out_dir)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (model.hp.batch_buffer //
                                                     model.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:,
                                                   ds_indexes_targets].flatten(
                                                   ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                            ti].flatten().astype('float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                np.log2(test_preds_ti_flat + 1),
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')  # , fontsize=14)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()
Esempio n. 7
0
def run(params_file, data_file, train_epochs, train_epoch_batches,
        test_epoch_batches):
    """Train a SeqNN model from an HDF5 data file with early stopping.

    Args:
      params_file: job parameters file, read with ``params.read_job_params``.
      data_file: HDF5 file providing ``train_in``/``train_out`` and
        ``valid_in``/``valid_out`` datasets, optionally ``train_na``/``valid_na``,
        ``pool_width``, and ``train_out_imag``/``valid_out_imag`` (Fourier
        targets).
      train_epochs: maximum number of epochs, or None to run until early
        stopping (``FLAGS.early_stop`` epochs without validation improvement).
      train_epoch_batches: passed to ``model.train_epoch`` as ``epoch_batches``.
      test_epoch_batches: passed to ``model.test`` as ``test_batches``.

    Augmentation, ensembling, logging, seeding, restart and checkpointing
    options are read from the module-level ``FLAGS``. The best model (lowest
    validation loss) is saved to ``<FLAGS.logdir>/model_best.tf``.
    """
    #######################################################
    # load data
    #######################################################
    # This function only reads the data file: open it read-only explicitly
    # (newer h5py releases no longer default to append mode) and make sure
    # it is closed however training ends.
    data_open = h5py.File(data_file, 'r')
    try:
        train_seqs = data_open['train_in']
        train_targets = data_open['train_out']
        train_na = None
        if 'train_na' in data_open:
            train_na = data_open['train_na']

        valid_seqs = data_open['valid_in']
        valid_targets = data_open['valid_out']
        valid_na = None
        if 'valid_na' in data_open:
            valid_na = data_open['valid_na']

        #######################################################
        # model parameters and placeholders
        #######################################################
        job = params.read_job_params(params_file)

        job['seq_length'] = train_seqs.shape[1]
        job['seq_depth'] = train_seqs.shape[2]
        job['num_targets'] = train_targets.shape[2]
        job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

        t0 = time.time()
        model = seqnn.SeqNN()
        model.build(job)
        print('Model building time %f' % (time.time() - t0))

        # adjust for fourier
        job['fourier'] = 'train_out_imag' in data_open
        if job['fourier']:
            train_targets_imag = data_open['train_out_imag']
            valid_targets_imag = data_open['valid_out_imag']

        #######################################################
        # prepare batcher
        #######################################################
        if job['fourier']:
            batcher_train = batcher.BatcherF(train_seqs,
                                             train_targets,
                                             train_targets_imag,
                                             train_na,
                                             model.hp.batch_size,
                                             model.hp.target_pool,
                                             shuffle=True)
            # batch size and pooling come from model.hp, matching every
            # other batcher constructed in this function.
            batcher_valid = batcher.BatcherF(valid_seqs, valid_targets,
                                             valid_targets_imag, valid_na,
                                             model.hp.batch_size,
                                             model.hp.target_pool)
        else:
            batcher_train = batcher.Batcher(train_seqs,
                                            train_targets,
                                            train_na,
                                            model.hp.batch_size,
                                            model.hp.target_pool,
                                            shuffle=True)
            batcher_valid = batcher.Batcher(valid_seqs, valid_targets,
                                            valid_na,
                                            model.hp.batch_size,
                                            model.hp.target_pool)
        print('Batcher initialized')

        #######################################################
        # train
        #######################################################
        augment_shifts = [
            int(shift) for shift in FLAGS.augment_shifts.split(',')
        ]
        ensemble_shifts = [
            int(shift) for shift in FLAGS.ensemble_shifts.split(',')
        ]

        # checkpoints
        saver = tf.train.Saver()

        config = tf.ConfigProto()
        if FLAGS.log_device_placement:
            config.log_device_placement = True
        with tf.Session(config=config) as sess:
            t0 = time.time()

            # set seed
            tf.set_random_seed(FLAGS.seed)

            if FLAGS.logdir:
                train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                     sess.graph)
            else:
                train_writer = None

            if FLAGS.restart:
                # load variables into session
                saver.restore(sess, FLAGS.restart)
            else:
                # initialize variables
                print('Initializing...')
                sess.run(tf.global_variables_initializer())
                print('Initialization time %f' % (time.time() - t0))

            best_loss = None
            early_stop_i = 0

            epoch = 0
            # stop at train_epochs (if given) or after FLAGS.early_stop
            # consecutive epochs without a new best validation loss.
            while ((train_epochs is None or epoch < train_epochs)
                   and early_stop_i < FLAGS.early_stop):
                t0 = time.time()

                # alternate forward and reverse batches
                fwdrc = True
                if FLAGS.augment_rc and epoch % 2 == 1:
                    fwdrc = False

                # cycle shifts
                shift_i = epoch % len(augment_shifts)

                # train
                train_loss, steps = model.train_epoch(
                    sess,
                    batcher_train,
                    fwdrc=fwdrc,
                    shift=augment_shifts[shift_i],
                    sum_writer=train_writer,
                    epoch_batches=train_epoch_batches,
                    no_steps=FLAGS.no_steps)

                # validate
                valid_acc = model.test(sess,
                                       batcher_valid,
                                       mc_n=FLAGS.ensemble_mc,
                                       rc=FLAGS.ensemble_rc,
                                       shifts=ensemble_shifts,
                                       test_batches=test_epoch_batches)
                valid_loss = valid_acc.loss
                valid_r2 = valid_acc.r2().mean()
                del valid_acc

                best_str = ''
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    best_str = ', best!'
                    early_stop_i = 0
                    # NOTE(review): the checkpoint path is built from
                    # FLAGS.logdir even when it is unset; confirm callers
                    # always provide a log directory.
                    saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
                else:
                    early_stop_i += 1

                # measure time
                et = time.time() - t0
                if et < 600:
                    time_str = '%3ds' % et
                elif et < 6000:
                    time_str = '%3dm' % (et / 60)
                else:
                    time_str = '%3.1fh' % (et / 3600)

                # print update
                print(
                    'Epoch: %3d,  Steps: %7d,  Train loss: %7.5f,  Valid loss: %7.5f,  Valid R2: %7.5f,  Time: %s%s'
                    % (epoch + 1, steps, train_loss, valid_loss, valid_r2,
                       time_str, best_str))
                sys.stdout.flush()

                if FLAGS.check_all:
                    saver.save(sess,
                               '%s/model_check%d.tf' % (FLAGS.logdir, epoch))

                # update epoch
                epoch += 1

            if train_writer is not None:
                train_writer.close()
    finally:
        data_open.close()
Esempio n. 8
0
def main():
    """Write gradient x representation BigWig tracks for gene TSSs.

    Positional arguments: <params_file> <model_file> <genes_hdf5_file>.

    For each gene sequence in the genes HDF5 file (optionally restricted to
    the gene ids listed one per line in ``-l``), gradients are computed with
    ``model.gradients_genes`` for the layer index of the first CNN layer whose
    dilation is > 1. For every (TSS, target) pair, the gradient is multiplied
    elementwise with the layer representation, summed over units, and written
    as one BigWig file in the output directory (``-o``).
    """
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option("-l",
                      dest="gene_list",
                      help="Process only gene ids in the given file")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="grad_mapg",
        help="Output directory [Default: %default]",
    )
    parser.add_option("-t",
                      dest="target_indexes",
                      default=None,
                      help="Target indexes to plot")
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error exits the program.
        parser.error("Must provide parameters, model, and genes HDF5 file")
    params_file, model_file, genes_hdf5_file = args

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # subset gene sequences
    if options.gene_list:
        # one gene id per line; the context manager closes the file.
        with open(options.gene_list) as gene_list_open:
            genes_subset = {line.rstrip() for line in gene_list_open}

        gene_data.subset_genes(genes_subset)
        print("Filtered to %d sequences" % gene_data.num_seqs)

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        sys.exit(1)

    # set target indexes
    if options.target_indexes is not None:
        options.target_indexes = [
            int(ti) for ti in options.target_indexes.split(",")
        ]
        target_subset = options.target_indexes
    else:
        options.target_indexes = list(range(job["num_targets"]))
        target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    # index of the first CNN layer with dilation > 1
    cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
    dilated_indexes = np.where(cnn_dilation > 1)[0]
    if len(dilated_indexes) == 0:
        # np.min on an empty array would fail with an opaque ValueError.
        print("Model has no dilated convolution layers.", file=sys.stderr)
        sys.exit(1)
    pre_dilated_layer = np.min(dilated_indexes)
    print("Pre-dilated layer: %d" % pre_dilated_layer)

    # build gradients ops
    t0 = time.time()
    print("Building target/position-specific gradient ops.", end="")
    model.build_grads_genes(gene_data.gene_seqs, layers=[pre_dilated_layer])
    print(" Done in %ds" % (time.time() - t0), flush=True)

    #######################################################
    # acquire gradients

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[si]

            # initialize batcher
            batcher_si = batcher.Batcher(
                gene_data.seqs_1hot[si:si + 1],
                batch_size=model.hp.batch_size,
                pool_width=model.hp.target_pool,
            )

            # get layer representations
            t0 = time.time()
            print("Computing gradients.", end="", flush=True)
            batch_grads, batch_reprs = model.gradients_genes(
                sess, batcher_si, gene_data.gene_seqs[si:si + 1])
            print(" Done in %ds." % (time.time() - t0), flush=True)

            # only layer
            batch_reprs = batch_reprs[0]
            batch_grads = batch_grads[0]

            # G (TSSs) x T (targets) x P (seq position) x U (Units layer i)
            print("batch_grads", batch_grads.shape)
            pooled_length = batch_grads.shape[2]

            # S (sequences) x P (seq position) x U (Units layer i)
            print("batch_reprs", batch_reprs.shape)

            # BigWig intervals depend only on the sequence, not on the TSS
            # or target, so build them once per sequence.
            bw_chroms = [gene_seq.chrom] * pooled_length
            bw_starts = [
                int(gene_seq.start + li * model.hp.target_pool)
                for li in range(pooled_length)
            ]
            bw_ends = [int(bws + model.hp.target_pool) for bws in bw_starts]

            # write bigwigs
            t0 = time.time()
            print("Writing BigWigs.", end="", flush=True)

            # for each TSS
            for tss_i in range(batch_grads.shape[0]):
                tss = gene_seq.tss_list[tss_i]

                # for each target (tii indexes the gradient's target axis)
                for tii, ti in enumerate(options.target_indexes):
                    # dot representation and gradient over units
                    batch_grads_score = np.multiply(
                        batch_reprs[0], batch_grads[tss_i,
                                                    tii, :, :]).sum(axis=1)
                    bw_values = [float(bgs) for bgs in batch_grads_score]

                    bw_file = "%s/%s-%s_t%d.bw" % (
                        options.out_dir,
                        tss.gene_id,
                        tss.identifier,
                        ti,
                    )
                    bw_open = bigwig_open(bw_file, options.genome_file)
                    try:
                        bw_open.addEntries(bw_chroms,
                                           bw_starts,
                                           ends=bw_ends,
                                           values=bw_values)
                    finally:
                        # close even if writing fails
                        bw_open.close()

            print(" Done in %ds." % (time.time() - t0), flush=True)
            gc.collect()
Esempio n. 9
0
def main():
    """Run in silico saturation mutagenesis over sequences from a BED file.

    Positional arguments take one of three forms:
      <params_file> <model_file> <bed_file>                        single worker
      <options_pkl> <params_file> <model_file> <bed_file>          options pickled
      <options_pkl> <params_file> <model_file> <bed_file> <index>  multi worker

    For each BED sequence, ``satmut_gen`` produces the mutated sequences for
    the central window of ``mut_len`` positions. The code consumes
    1 + 3 * mut_len predictions per sequence; the first is used as the
    reference for 'scd'. Per-sequence statistics ('sum', and optionally
    'center' and 'scd') are queued to a ``ScoreWorker`` that was given the
    ``<out_dir>/scores.h5`` handle. The mutated region's coordinates are
    stored in that same file.
    """
    import sys

    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='mut_down',
        default=0,
        type='int',
        help=
        'Nucleotides downstream of center sequence to mutate [Default: %default]'
    )
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=0,
        type='int',
        help='Length of center sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_mut',
                      help='Output directory [Default: %default]')
    parser.add_option('--plots',
                      dest='plots',
                      default=False,
                      action='store_true',
                      help='Make heatmap plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='sum',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='mut_up',
        default=0,
        type='int',
        help=
        'Nucleotides upstream of center sequence to mutate [Default: %default]'
    )
    (options, args) = parser.parse_args()

    # Only the multi-worker form supplies a worker index. Keep it None
    # otherwise so the sequence filter below never hits an undefined name.
    worker_index = None

    if len(args) == 3:
        # single worker
        params_file, model_file, bed_file = args

    elif len(args) in (4, 5):
        # 4 args: master script; 5 args: multi worker
        options_pkl_file, params_file, model_file, bed_file = args[:4]

        # load options
        # NOTE: pickle.load can execute arbitrary code; the options file is
        # expected to come from the trusted multi-process driver script.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        if len(args) == 5:
            worker_index = int(args[4])
            # update output directory
            options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide parameter and model files and BED file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = [
        sad_stat.lower() for sad_stat in options.sad_stats.split(',')
    ]

    if options.mut_up > 0 or options.mut_down > 0:
        options.mut_len = options.mut_up + options.mut_down
    else:
        # Validate user input explicitly; assert is stripped under python -O.
        if options.mut_len <= 0:
            parser.error(
                'Must specify a positive mutation length (-l) or -u/-d')
        options.mut_up = options.mut_len // 2
        options.mut_down = options.mut_len - options.mut_up

    if options.plots:
        # No plotting queue/worker is set up in this script, so heatmaps
        # cannot be produced. Warn up front instead of failing mid-run.
        print('Heatmap plots (--plots) are not supported here; ignoring.',
              file=sys.stderr)

    #################################################################
    # read parameters and targets

    # read model parameters (named to avoid shadowing the params module)
    with open(params_file) as params_open:
        params_all = json.load(params_open)
    params_model = params_all['model']
    params_train = params_all['train']

    # read targets
    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()

    #################################################################
    # sequence dataset

    # read sequences from BED
    seqs_dna, seqs_coords = bed.make_bed_seqs(bed_file,
                                              options.genome_fasta,
                                              params_model['seq_length'],
                                              stranded=True)

    # keep only this worker's contiguous share of the sequences
    if options.processes is not None and worker_index is not None:
        worker_bounds = np.linspace(0,
                                    len(seqs_dna),
                                    options.processes + 1,
                                    dtype='int')
        worker_start = worker_bounds[worker_index]
        worker_end = worker_bounds[worker_index + 1]
        seqs_dna = seqs_dna[worker_start:worker_end]
        seqs_coords = seqs_coords[worker_start:worker_end]

    num_seqs = len(seqs_dna)

    # determine mutation region limits
    seq_mid = params_model['seq_length'] // 2
    mut_start = seq_mid - options.mut_up
    mut_end = mut_start + options.mut_len

    # make sequence generator
    seqs_gen = satmut_gen(seqs_dna, mut_start, mut_end)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))
    for sad_stat in options.sad_stats:
        scores_h5.create_dataset(sad_stat,
                                 dtype='float16',
                                 shape=(num_seqs, options.mut_len, 4,
                                        num_targets))

    # store mutagenesis sequence coordinates
    seqs_chr, seqs_start, _, seqs_strand = zip(*seqs_coords)
    seqs_chr = np.array(seqs_chr, dtype='S')
    seqs_start = np.array(seqs_start) + mut_start
    seqs_end = seqs_start + options.mut_len
    seqs_strand = np.array(seqs_strand, dtype='S')
    scores_h5.create_dataset('chrom', data=seqs_chr)
    scores_h5.create_dataset('start', data=seqs_start)
    scores_h5.create_dataset('end', data=seqs_end)
    scores_h5.create_dataset('strand', data=seqs_strand)

    # one reference prediction plus three per mutated position
    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
    for _ in range(1):
        sw = ScoreWorker(score_queue, scores_h5, options.sad_stats, mut_start,
                         mut_end)
        sw.start()
        score_threads.append(sw)

    #################################################################
    # predict scores, write output

    # find center (two bins when the prediction length is even)
    preds_length = seqnn_model.target_lengths[0]
    center_start = preds_length // 2
    if preds_length % 2 == 0:
        center_end = center_start + 2
    else:
        center_end = center_start + 1

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, seqs_gen,
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_seqs):
        print('Predicting %d' % si, flush=True)

        # collect sequence predictions
        seq_preds_sum = []
        seq_preds_center = []
        seq_preds_scd = []
        # first prediction of each sequence is the reference for 'scd'
        preds_mut0 = preds_stream[pi]
        for spi in range(preds_per_seq):
            preds_mut = preds_stream[pi]
            preds_sum = preds_mut.sum(axis=0)
            seq_preds_sum.append(preds_sum)
            if 'center' in options.sad_stats:
                preds_center = preds_mut[center_start:center_end, :].sum(
                    axis=0)
                seq_preds_center.append(preds_center)
            if 'scd' in options.sad_stats:
                preds_scd = np.sqrt(((preds_mut - preds_mut0)**2).sum(axis=0))
                seq_preds_scd.append(preds_scd)
            pi += 1
        seq_preds_sum = np.array(seq_preds_sum)
        seq_preds_center = np.array(seq_preds_center)
        seq_preds_scd = np.array(seq_preds_scd)

        # wait for previous to finish
        score_queue.join()

        # queue sequence for scoring
        seq_pred_stats = (seq_preds_sum, seq_preds_center, seq_preds_scd)
        score_queue.put((seqs_dna[si], seq_pred_stats, si))

        gc.collect()

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
Esempio n. 10
0
def main():
    """Predict a TFRecord test set with a trained SeqNN model and save outputs.

    Command line: ``<params_file> <model_file> <data_dir>``.

    Writes into the ``-o`` output directory:
      * ``preds.h5`` -- dataset ``preds`` holding float16 predictions;
      * ``normalization.txt`` -- one line per target: index, mean prediction,
        and median-of-means / mean;
      * with ``-b``, one BigWig per selected target under ``tracks/``.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option(
        '-b',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/tutorials/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--tfr',
        dest='tfr_pattern',
        default='test-*.tfr',
        help='TFR pattern string appended to data_dir [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error exits, so the unpacking below only runs on valid input
        parser.error('Must provide parameters, model, and test data directory')
    params_file, model_file, data_dir = args

    os.makedirs(options.out_dir, exist_ok=True)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    # read targets; the table is not used further below, but reading it makes
    # a missing or malformed targets file fail before the expensive prediction
    if options.targets_file is None:
        options.targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read data parameters
    data_stats_file = '%s/statistics.json' % data_dir
    with open(data_stats_file) as data_stats_open:
        data_stats = json.load(data_stats_open)

    # construct data ops
    tfr_pattern_path = '%s/tfrecords/%s' % (data_dir, options.tfr_pattern)
    eval_data = dataset.SeqDataset(tfr_pattern_path,
                                   seq_length=params_model['seq_length'],
                                   target_length=data_stats['target_length'],
                                   batch_size=params_train['batch_size'],
                                   mode=tf.estimator.ModeKeys.EVAL)

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    # predict
    test_preds = seqnn_model.predict(eval_data, verbose=1).astype('float16')

    # save
    with h5py.File('%s/preds.h5' % options.out_dir, 'w') as preds_h5:
        preds_h5.create_dataset('preds', data=test_preds)

    # print normalization factors (accumulate in float64: the float16
    # predictions would lose precision when averaged directly)
    target_means = test_preds.mean(axis=(0, 1), dtype='float64')
    target_means_median = np.median(target_means)
    with open('%s/normalization.txt' % options.out_dir, 'w') as norm_out:
        for ti, target_mean in enumerate(target_means):
            print(ti,
                  target_mean,
                  target_means_median / target_mean,
                  file=norm_out)

    #######################################################
    # BigWig tracks

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs.')

        os.makedirs('%s/tracks' % options.out_dir, exist_ok=True)

        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]
        else:
            track_indexes = range(test_preds.shape[2])

        # There is no TF1 `model.hp` object in this script, so the buffer
        # cannot come from `model.hp.batch_buffer`.
        # NOTE(review): assumes statistics.json records the bp cropped from
        # each side of a sequence under 'crop_bp' (0 when absent) and that
        # this is the buffer bigwig_write expects -- confirm.
        track_buffer = data_stats.get('crop_bp', 0)

        for ti in track_indexes:
            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file, test_preds[:, :, ti], options.track_bed,
                         options.genome_file, track_buffer)
def run(dir):
    """Train a TF1 SeqNN model for the experiment stored in ``dir``.

    Reads ``<dir>/params.txt``, logs to ``<dir>/experiment.log``, writes
    training summaries under ``<dir>/train``, and saves the checkpoint with
    the lowest validation loss to ``<save_dir>/model/model_best.tf``
    (``save_dir`` is ``job['save_dir']`` when present, else ``dir``).

    Training runs for ``job['train_epochs']`` epochs. When that value is
    None, it runs until the validation loss fails to improve for
    ``FLAGS.early_stop`` consecutive epochs.

    Args:
      dir: experiment directory. The name shadows the builtin but is kept so
        keyword callers keep working.

    Raises:
      ValueError: if the job has no ``data_dir`` to read TFRecords from
        (a subclass of the ``Exception`` raised previously).
    """
    set_logger(path.join(dir, "experiment.log"))

    # read parameters
    job = params.read_job_params(path.join(dir, "params.txt"))
    tfr_dir = job["data_dir"]
    test_epoch_batches = job["test_epoch_batches"]
    train_epoch_batches = job["train_epoch_batches"]
    train_epochs = job["train_epochs"]

    # Only the TFRecord directory is supported here; explicit train/test file
    # paths are never configured by this function.
    if not tfr_dir:
        raise ValueError('train and/or test paths missing. Aborting.')
    data_ops, training_init_op, test_init_op = make_data_ops(
        job, tfr_dir=tfr_dir)

    save_dir = dir if "save_dir" not in job else job["save_dir"]
    model_dir = path.join(save_dir, "model")
    os.makedirs(model_dir, exist_ok=True)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job, data_ops)

    # launch accuracy compute thread
    acc_queue = Queue()
    acc_thread = AccuracyWorker(acc_queue)
    acc_thread.start()

    # checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(dir + '/train',
                                             sess.graph) if dir else None

        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        t0 = time.time()
        print('Initializing...')
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        print('Initialization time %f' % (time.time() - t0))

        if 'restart' in job:
            # restore every variable except these, which keep the fresh
            # initialization performed above
            excluded_tags = ("attention", "cnn_final", "decay_factor")
            restore_variables = [
                var for var in tf.global_variables()
                if not any(tag in var.name for tag in excluded_tags)
            ]
            restore_saver = tf.train.Saver(var_list=restore_variables)
            # load variables into session
            restore_saver.restore(sess, job['restart'])

        best_loss = None
        early_stop_i = 0

        epoch = 0
        while (train_epochs is not None and epoch < train_epochs) or \
              (train_epochs is None and early_stop_i < FLAGS.early_stop):
            t0 = time.time()

            # train epoch
            print("Training – epoch: {}".format(epoch))
            sess.run(training_init_op)
            train_loss, steps = model.train_epoch_tfr(sess, train_writer,
                                                      train_epoch_batches)

            # block for previous accuracy compute
            acc_queue.join()

            # test validation
            print("Validation – epoch: {}".format(epoch))
            sess.run(test_init_op)
            valid_acc = model.test_tfr(sess, test_epoch_batches)

            # consider as best
            best_str = ''
            if best_loss is None or valid_acc.loss < best_loss:
                best_loss = valid_acc.loss
                best_str = ', best!'
                early_stop_i = 0
                saver.save(sess, path.join(model_dir, "model_best.tf"))
            else:
                early_stop_i += 1

            # measure time
            epoch_time = time.time() - t0
            if epoch_time < 600:
                time_str = '%3ds' % epoch_time
            elif epoch_time < 6000:
                time_str = '%3dm' % (epoch_time / 60)
            else:
                time_str = '%3.1fh' % (epoch_time / 3600)

            # hand the accuracy update to the worker thread
            acc_queue.put((epoch, steps, train_loss, valid_acc, time_str,
                           best_str, train_writer))

            # update epoch
            epoch += 1

        # finish queue
        acc_queue.join()

        # Close the writer exactly when one was created. Keying this on
        # FLAGS.logdir could call close() on None, or leave an open writer.
        if train_writer is not None:
            train_writer.close()
Esempio n. 12
0
def main():
    """Score VCF variants by comparing reference and alternate predictions.

    Command lines:
      single worker: ``<params_file> <model_file> <vcf_file>``
      multi worker:  ``<options_pkl> <params_file> <model_file> <vcf_file>
                     <worker_index>``. Options are unpickled from the file,
                     output goes to ``<out_dir>/job<worker_index>``, and
                     ``-p`` shards the SNP list across workers.

    For every SNP and alternate allele, the reference and alternate sequences
    are predicted with a TF1 SeqNN model and compared per target:
      * SAD and SAR over the summed prediction;
      * geometric SAD (summed position-wise log2 ratios);
      * local SAD and SAR over the central ``--local`` bp;
      * the position of maximal absolute log2 ratio (xSAR and friends).

    Results go to a text or CSV table. With ``--h5``/``-z``, they go to
    sad.h5 or sad.zarr instead, and per-target percentiles are appended.
    """
    usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="batch_size",
        default=256,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-f",
        dest="genome_fasta",
        default="%s/data/hg19.fa" % os.environ["BASENJIDIR"],
        help="Genome FASTA for sequences [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "--h5",
        dest="out_h5",
        default=False,
        action="store_true",
        help="Output stats to sad.h5 [Default: %default]",
    )
    parser.add_option(
        "--local",
        dest="local",
        default=1024,
        type="int",
        help="Local SAD score [Default: %default]",
    )
    parser.add_option("-n",
                      dest="norm_file",
                      default=None,
                      help="Normalize SAD scores")
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sad",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=1,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help=
        "Average forward and reverse complement predictions [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        type="str",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "--stats",
        dest="sad_stats",
        default="SAD,xSAR",
        help="Comma-separated list of stats to save. [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        default=None,
        type="str",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-z",
        dest="out_zarr",
        default=False,
        action="store_true",
        help="Output stats to sad.zarr [Default: %default]",
    )
    (options, args) = parser.parse_args()

    worker_index = None
    if len(args) == 3:
        # single worker
        params_file, model_file, vcf_file = args

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        # NOTE: pickle can execute arbitrary code on load; only pass option
        # files written by the trusted multi-worker driver.
        with open(options_pkl_file, "rb") as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files and QTL VCF file")

    if options.processes is not None and worker_index is None:
        # SNP sharding below indexes worker_bounds by worker_index, which only
        # the multi-worker invocation supplies.
        parser.error("-p is only valid with the multi-worker arguments")

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(",")
        ]
        if not os.path.isdir("%s/tracks" % options.out_dir):
            os.mkdir("%s/tracks" % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]
    options.sad_stats = options.sad_stats.split(",")

    #################################################################
    # setup model

    job = params.read_job_params(
        params_file, require=["seq_length", "num_targets", "target_pool"])

    if options.targets_file is None:
        target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
        target_labels = [""] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_feed(
        job,
        ensemble_rc=options.rc,
        ensemble_shifts=options.shifts,
        embed_penultimate=options.penultimate,
        target_subset=target_subset,
    )
    print("Model building time %f" % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors (one float per line, in target order)
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        with open(options.norm_file) as norm_open:
            for ti, line in enumerate(norm_open):
                target_norms[ti] = float(line.strip())

    num_targets = len(target_ids)
    target_range = np.arange(num_targets)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype="int")
        snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index +
                                                              1]]

    #################################################################
    # setup output

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "ref_pred",
        "alt_pred",
        "sad",
        "sar",
        "geo_sad",
        "ref_lpred",
        "alt_lpred",
        "lsad",
        "lsar",
        "ref_xpred",
        "alt_xpred",
        "xsad",
        "xsar",
        "target_index",
        "target_id",
        "target_label",
    )

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open("%s/sad_table.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sad_out)
        else:
            sad_out = open("%s/sad_table.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end: a window of options.local bp, expressed
    # in prediction bins of target_pool bp, centred on the middle bin
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job["seq_length"], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # A distinct local name is required: assigning to `batcher` would
            # make the module name local to main(), and reading
            # batcher.Batcher would then raise UnboundLocalError.
            snp_batcher = batcher.Batcher(batch_1hot,
                                          batch_size=model.hp.batch_size)
            batch_preds = model.predict_h5(sess, snp_batcher)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            # batch_preds holds, per SNP, the reference prediction followed
            # by one prediction per alternate allele
            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype="float64")

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = "%s/tracks/%s_t%d_ref.bw" % (
                        options.out_dir,
                        snp.rsid,
                        ti,
                    )
                    bigwig_write(
                        snp,
                        job["seq_length"],
                        ref_preds[:, ti],
                        model,
                        ref_bw_file,
                        options.genome_file,
                    )

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype="float64")

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum +
                                  options.log_pseudo) - np.log2(
                                      ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(
                        alt_preds.astype("float64") +
                        options.log_pseudo) - np.log2(
                            ref_preds.astype("float64") + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype="float64")

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc +
                                      options.log_pseudo) - np.log2(
                                          ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        # only write stats that were requested (and hence
                        # presumably initialized in the output store)
                        if "SAD" in options.sad_stats:
                            sad_out["SAD"][szi, :] = sad.astype("float16")
                        if "xSAR" in options.sad_stats:
                            sad_out["xSAR"][szi, :] = sar_vec[
                                max_li[target_range],
                                target_range].astype("float16")
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (
                                snp.rsid,
                                bvcf.cap_allele(snp.ref_allele),
                                bvcf.cap_allele(alt_al),
                                ref_preds_sum[ti],
                                alt_preds_sum[ti],
                                sad[ti],
                                sar[ti],
                                geo_sad[ti],
                                ref_preds_loc[ti],
                                alt_preds_loc[ti],
                                sad_loc[ti],
                                sar_loc[ti],
                                ref_preds[max_li[ti], ti],
                                alt_preds[max_li[ti], ti],
                                sad_vec[max_li[ti], ti],
                                sar_vec[max_li[ti], ti],
                                ti,
                                target_ids[ti],
                                target_labels[ti],
                            )
                            if options.csv:
                                print(",".join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    "%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s"
                                    % cols,
                                    file=sad_out,
                                )

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = "%s/tracks/%s_t%d_alt.bw" % (
                            options.out_dir,
                            snp.rsid,
                            ti,
                        )
                        bigwig_write(
                            snp,
                            job["seq_length"],
                            alt_preds[:, ti],
                            model,
                            alt_bw_file,
                            options.genome_file,
                        )

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job["seq_length"],
                genome_open)

    # all sequences have been extracted
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles: fine steps in both tails, coarse in the middle
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset("percentiles", data=percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = "%s_pct" % sad_stat

            # compute (targets x percentiles)
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype("float16")

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype="float16")

    if not options.out_zarr:
        sad_out.close()
Esempio n. 13
0
def main():
  """Jointly train one SeqNN model on several datasets.

  Command line: ``<params_file> <data1_dir> [<data2_dir> ...]``.

  Copies the params file into the ``-o`` output directory. For every data
  directory, builds train/eval SeqDatasets from its TFRecords, sized by that
  directory's statistics.json. It optionally restores a checkpoint (the whole
  model, or only the trunk with ``--trunk``) and then trains with
  ``Trainer.fit2``. Multi-GPU training is rejected with exit status 1.
  """
  usage = 'usage: %prog [options] <params_file> <data1_dir> <data2_dir> ...'
  parser = OptionParser(usage)
  parser.add_option('-o', dest='out_dir',
      default='train2_out',
      help='Output directory for test statistics [Default: %default]')
  parser.add_option('--restore', dest='restore',
      help='Restore model and continue training [Default: %default]')
  parser.add_option('--trunk', dest='trunk',
      default=False, action='store_true',
      help='Restore only model trunk [Default: %default]')
  parser.add_option('--tfr_train', dest='tfr_train_pattern',
      default='train-*.tfr',
      help='Training TFRecord pattern string appended to data_dir [Default: %default]')
  parser.add_option('--tfr_eval', dest='tfr_eval_pattern',
      default='valid-*.tfr',
      help='Evaluation TFRecord pattern string appended to data_dir [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) < 2:
    # parser.error exits, so the assignments below only run on valid input
    parser.error('Must provide parameters and data directory.')
  params_file = args[0]
  data_dirs = args[1:]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)
  if params_file != '%s/params.json' % options.out_dir:
    shutil.copy(params_file, '%s/params.json' % options.out_dir)

  # read model parameters
  with open(params_file) as params_open:
    params = json.load(params_open)
  params_model = params['model']
  params_train = params['train']

  # Multi-GPU joint training is untested; refuse it up front.
  if params_train.get('num_gpu', 1) != 1:
    print('Multiple GPUs untested for joint genome training.', file=sys.stderr)
    sys.exit(1)

  # read datasets
  data_stats = []
  train_data = []
  eval_data = []

  for data_dir in data_dirs:
    # read data parameters
    data_stats_file = '%s/statistics.json' % data_dir
    with open(data_stats_file) as data_stats_open:
      dir_stats = json.load(data_stats_open)
    data_stats.append(dir_stats)

    # each dataset is parsed with its own directory's lengths, not the
    # first directory's
    # load train data
    tfr_train_full = '%s/tfrecords/%s' % (data_dir, options.tfr_train_pattern)
    train_data.append(dataset.SeqDataset(tfr_train_full,
      seq_length=dir_stats['seq_length'],
      target_length=dir_stats['target_length'],
      batch_size=params_train['batch_size'],
      mode=tf.estimator.ModeKeys.TRAIN))

    # load eval data
    tfr_eval_full = '%s/tfrecords/%s' % (data_dir, options.tfr_eval_pattern)
    eval_data.append(dataset.SeqDataset(tfr_eval_full,
      seq_length=dir_stats['seq_length'],
      target_length=dir_stats['target_length'],
      batch_size=params_train['batch_size'],
      mode=tf.estimator.ModeKeys.EVAL))

  ########################################
  # one GPU

  # initialize model
  seqnn_model = seqnn.SeqNN(params_model)

  # restore
  if options.restore:
    seqnn_model.restore(options.restore, options.trunk)

  # initialize trainer
  seqnn_trainer = trainer.Trainer(params_train, train_data,
                                  eval_data, options.out_dir)

  # compile model
  seqnn_trainer.compile(seqnn_model)

  # train model
  seqnn_trainer.fit2(seqnn_model)
Esempio n. 14
0
def main():
    """Evaluate a TF1 SeqNN model on a TFRecord test set and save outputs.

    Command line: ``<params_file> <model_file> <data_dir>``.

    Writes into the ``-o`` output directory:
      * ``preds.h5`` -- dataset ``preds`` with the test predictions;
      * ``normalization.txt`` -- one line per target: index, mean prediction,
        and median-of-means / mean;
      * with ``-b``, one BigWig per selected target under ``tracks/``
        (incompatible with ``-d`` down-sampling).
    """
    usage = "usage: %prog [options] <params_file> <model_file> <data_dir>"
    parser = OptionParser(usage)
    parser.add_option(
        "-b",
        dest="track_bed",
        help="BED file describing regions so we can output BigWig tracks",
    )
    parser.add_option(
        "-d",
        dest="sample_down",
        default=1.0,
        type="float",
        help="Sample sequence positions down. [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/tutorials/data/human.hg19.genome" %
        os.environ["BASENJIDIR"],
        help="Chromosome length information [Default: %default]",
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="test_out",
        help="Output directory for test statistics [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the fwd and rc predictions [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        type="str",
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "--ti",
        dest="track_indexes",
        help="Comma-separated list of target indexes to output BigWig tracks",
    )
    parser.add_option(
        "--tfr",
        dest="tfr_pattern",
        default="test-*.tfr",
        help="TFR pattern string appended to data_dir [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # parser.error exits, so the unpacking below only runs on valid input
        parser.error("Must provide parameters, model, and test data directory")
    params_file, model_file, data_dir = args

    if options.track_bed is not None and options.sample_down < 1:
        parser.error("Cannot down sample and write BigWigs.")

    os.makedirs(options.out_dir, exist_ok=True)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    # read targets; only an explicit -t file restricts the model's targets
    if options.targets_file is None:
        options.targets_file = "%s/targets.txt" % data_dir
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_subset = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_subset = targets_df.index

    # read model parameters
    job = params.read_job_params(params_file)

    # construct data ops
    tfr_pattern_path = "%s/tfrecords/%s" % (data_dir, options.tfr_pattern)
    data_ops, test_init_op = make_data_ops(job, tfr_pattern_path)

    # initialize model
    model = seqnn.SeqNN()
    model.build_sad(
        job,
        data_ops,
        ensemble_rc=options.rc,
        ensemble_shifts=options.shifts,
        target_subset=target_subset,
    )

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # start queue runners
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # test
        sess.run(test_init_op)
        test_preds = model.predict_tfr(sess, sample=options.sample_down)

    # save predictions
    with h5py.File("%s/preds.h5" % options.out_dir, "w") as preds_h5:
        preds_h5.create_dataset("preds", data=test_preds)

    # print normalization factors (accumulated in float64)
    target_means = test_preds.mean(axis=(0, 1), dtype="float64")
    target_means_median = np.median(target_means)
    with open("%s/normalization.txt" % options.out_dir, "w") as norm_out:
        for ti, target_mean in enumerate(target_means):
            print(
                ti,
                target_mean,
                target_means_median / target_mean,
                file=norm_out,
            )

    #######################################################
    # BigWig tracks

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                "Must provide genome file in order to print valid BigWigs.")

        os.makedirs("%s/tracks" % options.out_dir, exist_ok=True)

        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(",")
            ]
        else:
            track_indexes = range(test_preds.shape[2])

        for ti in track_indexes:
            # make predictions bigwig
            bw_file = "%s/tracks/t%d_preds.bw" % (options.out_dir, ti)
            bigwig_write(
                bw_file,
                test_preds[:, :, ti],
                options.track_bed,
                options.genome_file,
                model.hp.batch_buffer,
            )
Esempio n. 15
0
def main():
    """Copy TF1 SeqNN checkpoint weights into an existing TF2 Keras HDF5 file.

    Positional args: TF1 params file, TF1 checkpoint, and a TF2 .h5 file that
    is opened 'r+' and overwritten in place. TF1 variable names of the form
    '<scope>/<layer>/.../<param>' are mapped to
    'model_weights/<layer>[_N]/<layer>[_N]/<param>', where N is the integer
    taken from a 'cnnN' scope (no suffix for N == 0 or the 'final' scope).
    """
    usage = 'usage: %prog [options] <params_tf1_file> <model_tf1_file> <model_tf2_h5>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='final_slice',
                      default=None,
                      help='Final dense layer target slice, e.g. "0:1000"')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide TF1 params and .tf files and TF2 .h5 file.')
    else:
        params_tf1_file = args[0]
        model_tf1_file = args[1]
        model_tf2_h5_file = args[2]

    # Parse the optional final-layer slice once, before touching the output
    # file, so a malformed value cannot leave the .h5 half rewritten.
    final_slice = None
    if options.final_slice is not None:
        try:
            fs, fe = (int(x) for x in options.final_slice.split(':'))
        except ValueError:
            parser.error('Final slice must have the form "start:end", e.g. "0:1000".')
        final_slice = slice(fs, fe)

    ################################################################
    # params

    job = params.read_job_params(params_tf1_file,
                                 require=['seq_length', 'num_targets'])

    ################################################################
    # dummy data

    # The dataset below is only needed to build the graph; it is never
    # iterated, since weights come straight from the checkpoint.
    def dummy_gen():
        for i in range(16):
            yield i

    data_types = {'sequence': tf.float32}
    data_shapes = {
        'sequence':
        tf.TensorShape([tf.Dimension(job['seq_length']),
                        tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(dummy_gen,
                                             output_types=data_types,
                                             output_shapes=data_shapes)
    dataset = dataset.batch(job['batch_size'])
    iterator = dataset.make_one_shot_iterator()
    data_ops = iterator.get_next()

    ################################################################
    # setup model

    model = seqnn.SeqNN()
    model.build_sad(job, data_ops)

    ################################################################
    # restore model and extract weights

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_tf1_file)
        model1_vars = tf.global_variables()
        model1_weights = sess.run(model1_vars)

    ################################################################
    # write into tf2 hdf5

    # context manager guarantees the HDF5 file is closed even on error
    with h5py.File(model_tf2_h5_file, 'r+') as model_tf2_h5:
        for weight1_name, weight1_val in zip(model1_vars, model1_weights):
            print(weight1_name.name, weight1_val.shape)

            # skip step
            if weight1_name.name == 'global_step:0':
                continue

            weight1_split = weight1_name.name.split('/')
            weight2_split = ['model_weights']

            if weight1_split[0] == 'final':
                weight2_split += [weight1_split[1]] * 2

            else:
                # 'cnnN' scope -> Keras layer suffix '_N' (none for N == 0)
                li = int(weight1_split[0].replace('cnn', ''))
                if li == 0:
                    weight2_split += [weight1_split[1]] * 2
                else:
                    weight2_split += ['%s_%s' % (weight1_split[1], li)] * 2

            weight2_split.append(weight1_split[-1])

            weight2_name = '/'.join(weight2_split)
            print(weight2_name, model_tf2_h5[weight2_name].shape, '\n')

            if weight1_split[0] == 'final' and final_slice is not None:
                # keep only the requested targets along the last axis
                model_tf2_h5[weight2_name][...] = weight1_val[..., final_slice]
            else:
                model_tf2_h5[weight2_name][...] = weight1_val
Esempio n. 16
0
def main():
  """Score VCF SNPs with a TF1 SeqNN model and write SAD statistics to HDF5.

  Single-worker usage: <params_file> <model_file> <vcf_file>.
  Multi-worker usage (invoked by the multi script): <options_pkl> <params_file>
  <model_file> <vcf_file> <worker_index>. In that mode the options are
  unpickled from <options_pkl>, this worker's share of SNPs is selected using
  -p and worker_index, and output goes to <out_dir>/job<worker_index>.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option('-c', dest='center_pct',
      default=0.25, type='float',
      help='Require clustered SNPs lie in center region [Default: %default]')
  parser.add_option('-f', dest='genome_fasta',
      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
      help='Genome FASTA for sequences [Default: %default]')
  parser.add_option('--flip', dest='flip_ref',
      default=False, action='store_true',
      help='Flip reference/alternate alleles when simple [Default: %default]')
  parser.add_option('--local', dest='local',
      default=1024, type='int',
      help='Local SAD score [Default: %default]')
  parser.add_option('-n', dest='norm_file',
      default=None,
      help='Normalize SAD scores')
  parser.add_option('-o',dest='out_dir',
      default='sad',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option('-p', dest='processes',
      default=None, type='int',
      help='Number of processes, passed by multi script')
  parser.add_option('--pseudo', dest='log_pseudo',
      default=1, type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average forward and reverse complement predictions [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0', type='str',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('--stats', dest='sad_stats',
      default='SAD',
      help='Comma-separated list of stats to save. [Default: %default]')
  parser.add_option('-t', dest='targets_file',
      default=None, type='str',
      help='File specifying target indexes and labels in table format')
  parser.add_option('--ti', dest='track_indexes',
      default=None, type='str',
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option('-u', dest='penultimate',
      default=False, action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  (options, args) = parser.parse_args()

  # only set when launched by the multi-worker script
  worker_index = None

  if len(args) == 3:
    # single worker
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  elif len(args) == 5:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    vcf_file = args[3]
    worker_index = int(args[4])

    # load options pickled by the multi-worker script.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(options_pkl_file, 'rb') as options_pkl:
      options = pickle.load(options_pkl)

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters and model files and QTL VCF file')

  # -p selects one worker's share of SNPs and needs a worker index; without
  # this check the combination crashed later with a NameError.
  if options.processes is not None and worker_index is None:
    parser.error('-p is set by the multi-worker script and requires a worker index')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  if options.track_indexes is None:
    options.track_indexes = []
  else:
    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]
  options.sad_stats = options.sad_stats.split(',')


  #################################################################
  # read parameters and collect target information

  job = params.read_job_params(params_file, require=['seq_length','num_targets'])

  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = ['']*len(target_ids)
    target_subset = None

  else:
    targets_df = pd.read_table(options.targets_file, index_col=0)
    target_ids = targets_df.identifier
    target_labels = targets_df.description
    target_subset = targets_df.index
    # a subset covering every target is the same as no subset
    if len(target_subset) == job['num_targets']:
      target_subset = None


  #################################################################
  # load SNPs

  # read sorted SNPs from VCF
  snps = bvcf.vcf_snps(vcf_file, require_sorted=True, flip_ref=options.flip_ref,
                       validate_ref_fasta=options.genome_fasta)

  # filter for worker SNPs: contiguous, near-equal chunks per worker
  if options.processes is not None:
    worker_bounds = np.linspace(0, len(snps), options.processes+1, dtype='int')
    snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index+1]]

  # cluster SNPs by position
  snp_clusters = cluster_snps(snps, job['seq_length'], options.center_pct)

  # delimit sequence boundaries (called for its side effect on each cluster)
  for sc in snp_clusters:
    sc.delimit(job['seq_length'])

  # open genome FASTA
  genome_open = pysam.Fastafile(options.genome_fasta)

  # make SNP sequence generator: per cluster, yields each one-hot returned by
  # get_1hots (consumed below as one reference followed by one per SNP)
  def snp_gen():
    for sc in snp_clusters:
      snp_1hot_list = sc.get_1hots(genome_open)
      for snp_1hot in snp_1hot_list:
        yield {'sequence':snp_1hot}

  snp_types = {'sequence': tf.float32}
  snp_shapes = {'sequence': tf.TensorShape([tf.Dimension(job['seq_length']),
                                            tf.Dimension(4)])}

  dataset = tf.data.Dataset.from_generator(snp_gen,
                                          output_types=snp_types,
                                          output_shapes=snp_shapes)
  dataset = dataset.batch(job['batch_size'])
  dataset = dataset.prefetch(2*job['batch_size'])

  iterator = dataset.make_one_shot_iterator()
  data_ops = iterator.get_next()


  #################################################################
  # setup model

  # build model
  t0 = time.time()
  model = seqnn.SeqNN()
  model.build_sad(job, data_ops,
                  ensemble_rc=options.rc, ensemble_shifts=options.shifts,
                  embed_penultimate=options.penultimate, target_subset=target_subset)
  print('Model building time %f' % (time.time() - t0), flush=True)

  if options.penultimate:
    # labels become inappropriate
    target_ids = ['']*model.hp.cnn_filters[-1]
    target_labels = target_ids

  # read target normalization factors, one float per line
  # NOTE(review): target_norms is not used later in this function.
  target_norms = np.ones(len(target_labels))
  if options.norm_file is not None:
    with open(options.norm_file) as norm_open:
      for ti, line in enumerate(norm_open):
        target_norms[ti] = float(line.strip())

  #################################################################
  # setup output

  snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

  sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                 snps, target_ids, target_labels)

  snp_threads = []

  snp_queue = Queue()
  for i in range(1):
    sw = SNPWorker(snp_queue, sad_out, options.sad_stats, options.log_pseudo)
    sw.start()
    snp_threads.append(sw)

  #################################################################
  # predict SNP scores, write output

  # initialize saver
  saver = tf.train.Saver()
  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # initialize predictions stream
    preds_stream = PredStream(sess, model, 32)

    # predictions index
    pi = 0

    # SNP index
    si = 0

    for snp_cluster in snp_clusters:
      # first prediction of each cluster is the reference
      ref_preds = preds_stream[pi]
      pi += 1

      for snp in snp_cluster.snps:
        alt_preds = preds_stream[pi]
        pi += 1

        # queue SNP; flipped SNPs swap reference and alternate
        if snp_flips[si]:
          snp_queue.put((alt_preds, ref_preds, si))
        else:
          snp_queue.put((ref_preds, alt_preds, si))

        # update SNP index
        si += 1

  # finish queue
  print('Waiting for threads to finish.', flush=True)
  snp_queue.join()

  # close genome
  genome_open.close()

  ###################################################
  # compute SAD distributions across variants

  # define percentiles: fine resolution in both tails, coarse in the middle
  d_fine = 0.001
  d_coarse = 0.01
  percentiles_neg = np.arange(d_fine, 0.1, d_fine)
  percentiles_base = np.arange(0.1, 0.9, d_coarse)
  percentiles_pos = np.arange(0.9, 1, d_fine)

  percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos])
  sad_out.create_dataset('percentiles', data=percentiles)

  for sad_stat in options.sad_stats:
    sad_stat_pct = '%s_pct' % sad_stat

    # compute per-target percentiles across SNPs
    sad_pct = np.percentile(sad_out[sad_stat], 100*percentiles, axis=0).T
    sad_pct = sad_pct.astype('float16')

    # save
    sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

  sad_out.close()
Esempio n. 17
0
def main():
    """Train a SeqNN model on one or more datasets.

    Positional arguments: a JSON parameters file with 'model' and 'train'
    sections, followed by one or more data directories. The parameters file
    is copied into the output directory. Training runs on a single device
    when params['train'].get('num_gpu', 1) == 1; any other value trains
    under a tf.distribute.MirroredStrategy.
    """
    parser = OptionParser(
        'usage: %prog [options] <params_file> <data1_dir> <data2_dir> ...')
    parser.add_option('-o', dest='out_dir', default='train2_out',
                      help='Output directory for test statistics [Default: %default]')
    parser.add_option('--restore', dest='restore',
                      help='Restore model and continue training [Default: %default]')
    parser.add_option('--trunk', dest='trunk', default=False, action='store_true',
                      help='Restore only model trunk [Default: %default]')
    parser.add_option('--tfr_train', dest='tfr_train_pattern', default=None,
                      help='Training TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    parser.add_option('--tfr_eval', dest='tfr_eval_pattern', default=None,
                      help='Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    options, args = parser.parse_args()

    if len(args) < 2:
        parser.error('Must provide parameters and data directory.')
    params_file, data_dirs = args[0], args[1:]

    # set up output directory and keep a copy of the parameters there
    copied_params_file = '%s/params.json' % options.out_dir
    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if params_file != copied_params_file:
        shutil.copy(params_file, copied_params_file)

    with open(params_file) as params_open:
        all_params = json.load(params_open)
    params_model = all_params['model']
    params_train = all_params['train']

    # one (train, eval) dataset pair per data directory, built in order
    train_data, eval_data = [], []
    for data_dir in data_dirs:
        train_data.append(
            dataset.SeqDataset(data_dir,
                               split_label='train',
                               batch_size=params_train['batch_size'],
                               mode=tf.estimator.ModeKeys.TRAIN,
                               tfr_pattern=options.tfr_train_pattern))
        eval_data.append(
            dataset.SeqDataset(data_dir,
                               split_label='valid',
                               batch_size=params_train['batch_size'],
                               mode=tf.estimator.ModeKeys.EVAL,
                               tfr_pattern=options.tfr_eval_pattern))

    if params_train.get('num_gpu', 1) != 1:
        # multi-device: data, model, and compilation live in the strategy scope
        strategy = tf.distribute.MirroredStrategy()
        with strategy.scope():
            for train_ds, eval_ds in zip(train_data, eval_data):
                train_ds.distribute(strategy)
                eval_ds.distribute(strategy)

            seqnn_model = seqnn.SeqNN(params_model)
            if options.restore:
                seqnn_model.restore(options.restore, options.trunk)

            seqnn_trainer = trainer.Trainer(params_train, train_data,
                                            eval_data, options.out_dir,
                                            strategy, params_train['num_gpu'])
            seqnn_trainer.compile(seqnn_model)
    else:
        # single device
        seqnn_model = seqnn.SeqNN(params_model)
        if options.restore:
            seqnn_model.restore(options.restore, options.trunk)

        seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data,
                                        options.out_dir)
        seqnn_trainer.compile(seqnn_model)

    # fitting happens outside any strategy scope in both cases
    seqnn_trainer.fit2(seqnn_model)
Esempio n. 18
0
def main():
    """Compute gradient maps for gene sequences from a genes HDF5 file.

    Positional args: params file, TF1 model checkpoint, and genes HDF5 file.
    Builds the model with gradient ops at the first dilated convolution layer,
    restores weights, and delegates scoring/writing to score_write().
    """
    usage = "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
    parser = OptionParser(usage)
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg38.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-l", dest="gene_list", help="Process only gene ids in the given file"
    )
    parser.add_option(
        "--mc",
        dest="mc_n",
        default=0,
        type="int",
        help="Monte carlo test iterations [Default: %default]",
    )
    parser.add_option(
        "-n",
        dest="norm",
        default=None,
        type="int",
        # fixed: "[Default% default]" was never expanded by optparse
        help="Compute saliency norm [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="grad_map",
        help="Output directory [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t", dest="target_indexes", default=None, help="Target indexes to plot"
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        # message now matches the documented usage (genes HDF5, not a position)
        parser.error("Must provide parameters, model, and genes HDF5 file")
    else:
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # subset gene sequences, one gene id per line
    if options.gene_list:
        with open(options.gene_list) as gene_list_open:
            genes_subset = {line.rstrip() for line in gene_list_open}

        gene_data.subset_genes(genes_subset)
        print("Filtered to %d sequences" % gene_data.num_seqs)

    # extract sequence chrom and start
    seqs_chrom = [gene_data.gene_seqs[si].chrom for si in range(gene_data.num_seqs)]
    seqs_start = [gene_data.gene_seqs[si].start for si in range(gene_data.num_seqs)]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    # input geometry comes from the genes HDF5, overriding the params file
    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the parameters file.",
            file=sys.stderr,
        )
        sys.exit(1)

    # set target indexes
    if options.target_indexes is not None:
        options.target_indexes = [int(ti) for ti in options.target_indexes.split(",")]
        target_subset = options.target_indexes
    else:
        options.target_indexes = list(range(job["num_targets"]))
        target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    # index of the first CNN layer with dilation > 1
    cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
    dilated_indexes = np.where(cnn_dilation > 1)[0]
    if len(dilated_indexes) == 0:
        # np.min would otherwise raise an opaque "zero-size array" ValueError
        print(
            "Model has no dilated convolution layers; cannot choose a pre-dilated layer.",
            file=sys.stderr,
        )
        sys.exit(1)
    pre_dilated_layer = np.min(dilated_indexes)
    print("Pre-dilated layer: %d" % pre_dilated_layer)

    # build gradients ops
    t0 = time.time()
    print("Building target/position-specific gradient ops.", end="")
    model.build_grads(layers=[pre_dilated_layer])
    print(" Done in %ds" % (time.time() - t0), flush=True)

    #######################################################
    # acquire gradients

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # score sequences and write bigwigs
        score_write(sess, model, options, gene_data.seqs_1hot, seqs_chrom, seqs_start)
Esempio n. 19
0
def main():
    """Run in-silico saturation mutagenesis over BED regions; save scores to HDF5.

    Single-worker usage: <params_file> <model_file> <bed_file>.
    Multi-worker usage (invoked by the multi script): <options_pkl>
    <params_file> <model_file> <bed_file> <worker_index>; options are
    unpickled and output goes to <out_dir>/job<worker_index>.

    The central `-l` bases of each sequence are mutated; results are written
    to <out_dir>/scores.h5 ('scores', 'seqs', 'chrom', 'start', 'end').
    """
    import sys

    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=200,
        type='int',
        help='Length of center sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_mut',
                      help='Output directory [Default: %default]')
    parser.add_option('--plots',
                      dest='plots',
                      default=False,
                      action='store_true',
                      help='Make heatmap plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    # only set when launched by the multi-worker script
    worker_index = None

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        bed_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        bed_file = args[3]
        worker_index = int(args[4])

        # load options pickled by the multi-worker script.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL BED file')

    # -p selects one worker's share of sequences and needs a worker index;
    # without this check the combination crashed later with a NameError.
    if options.processes is not None and worker_index is None:
        parser.error('-p is set by the multi-worker script and requires a worker index')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    if options.plots:
        # NOTE(review): no plot queue/worker is created in this function, so
        # the former plot_queue.put(...) raised NameError. Warn instead.
        print('Warning: --plots is not supported by this script; no heatmaps will be made.',
              file=sys.stderr, flush=True)

    #################################################################
    # read parameters and collect target information

    job = params.read_job_params(params_file)

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        # a subset covering every target is the same as no subset
        if len(target_subset) == job['num_targets']:
            target_subset = None

    num_targets = len(target_ids)

    #################################################################
    # sequence dataset

    # read sequences from BED
    seqs_dna, seqs_coords = bed_seqs(bed_file, options.genome_fasta,
                                     job['seq_length'])

    # filter for worker sequences: contiguous, near-equal chunks per worker
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(seqs_dna),
                                    options.processes + 1,
                                    dtype='int')
        seqs_dna = seqs_dna[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]
        seqs_coords = seqs_coords[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]

    num_seqs = len(seqs_dna)

    # determine mutation region limits (centered in the sequence)
    seq_mid = job['seq_length'] // 2
    mut_start = seq_mid - options.mut_len // 2
    mut_end = mut_start + options.mut_len

    # make data ops
    data_ops = satmut_data_ops(seqs_dna, mut_start, mut_end, job['batch_size'])

    #################################################################
    # setup model

    # build model
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    target_subset=target_subset,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    # explicit 'w': h5py >= 3 defaults to read-only mode, which fails on the
    # file that was just removed
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('scores',
                             dtype='float16',
                             shape=(num_seqs, options.mut_len, 4, num_targets))
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))

    # store mutagenesis sequence coordinates
    seqs_chr, seqs_start, _ = zip(*seqs_coords)
    seqs_chr = np.array(seqs_chr, dtype='S')
    seqs_start = np.array(seqs_start) + mut_start
    seqs_end = seqs_start + options.mut_len
    scores_h5.create_dataset('chrom', data=seqs_chr)
    scores_h5.create_dataset('start', data=seqs_start)
    scores_h5.create_dataset('end', data=seqs_end)

    # predictions consumed per sequence: 1 + 3 per mutated position
    # (presumably the reference plus three substitutions per base)
    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
    for i in range(1):
        sw = ScoreWorker(score_queue, scores_h5)
        sw.start()
        score_threads.append(sw)

    #################################################################
    # predict scores, write output

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # coordinator
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # initialize predictions stream
        preds_stream = PredStream(sess, model, 32)

        # predictions index
        pi = 0

        for si in range(num_seqs):
            print('Predicting %d' % si, flush=True)

            # collect this sequence's predictions, in stream order
            seq_preds = [preds_stream[pi + spi] for spi in range(preds_per_seq)]
            pi += preds_per_seq

            # wait for previous to finish
            score_queue.join()

            # queue sequence for scoring
            score_queue.put((seqs_dna[si], seq_preds, si))

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
Esempio n. 20
0
def main():
    """Predict targets for BED regions and write them to HDF5 (optionally BigWig).

    Single-worker usage: <params_file> <model_file> <bed_file>.
    Multi-worker usage (invoked by the multi script): <options_pkl>
    <params_file> <model_file> <bed_file> <worker_index>; options are
    unpickled and output goes to <out_dir>/job<worker_index>.

    For each sequence, predictions at the central position (odd prediction
    length) or central two positions (even length) are summed and stored in
    the 'preds' dataset of <out_dir>/predict.h5. Full-length predictions for
    the -b target indexes are written as BigWigs under <out_dir>/bigwig.
    """
    parser = OptionParser('usage: %prog [options] <params_file> <model_file> <bed_file>')
    parser.add_option('-b', dest='bigwig_indexes', default=None,
                      help='Comma-separated list of target indexes to write BigWigs')
    parser.add_option('-f', dest='genome_fasta', default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g', dest='genome_file', default=None,
                      help='Chromosome length information [Default: %default]')
    parser.add_option('-o', dest='out_dir', default='pred_out',
                      help='Output directory [Default: %default]')
    parser.add_option('-p', dest='processes', default=None, type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--rc', dest='rc', default=False, action='store_true',
                      help='Ensemble forward and reverse complement predictions [Default: %default]')
    parser.add_option('--shifts', dest='shifts', default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('-t', dest='targets_file', default=None, type='str',
                      help='File specifying target indexes and labels in table format')
    options, args = parser.parse_args()

    if len(args) == 3:
        params_file, model_file, bed_file = args
    elif len(args) == 5:
        # launched by the multi-worker script; options arrive pickled.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        options_pkl_file, params_file, model_file, bed_file, worker_arg = args
        worker_index = int(worker_arg)
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)
    else:
        parser.error('Must provide parameter and model files and BED file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(s) for s in options.shifts.split(',')]
    options.bigwig_indexes = (
        [] if options.bigwig_indexes is None
        else [int(b) for b in options.bigwig_indexes.split(',')])

    # bigwig_dir is only defined (and only used) when BigWigs were requested
    if options.bigwig_indexes:
        bigwig_dir = '%s/bigwig' % options.out_dir
        if not os.path.isdir(bigwig_dir):
            os.mkdir(bigwig_dir)

    # parameters and optional target subset
    job = params.read_job_params(params_file,
                                 require=['num_targets', 'seq_length'])

    num_targets = np.sum(job['num_targets'])
    target_subset = None
    if options.targets_file is not None:
        target_subset = pd.read_table(options.targets_file, index_col=0).index
        if len(target_subset) == num_targets:
            # a subset covering every target is the same as no subset
            target_subset = None
        else:
            num_targets = len(target_subset)

    # sequences from BED, restricted to this worker's contiguous chunk
    seqs_dna, seqs_coords = bed_seqs(bed_file, options.genome_fasta,
                                     job['seq_length'])
    if options.processes is not None:
        bounds = np.linspace(0, len(seqs_dna), options.processes + 1, dtype='int')
        worker_slice = slice(bounds[worker_index], bounds[worker_index + 1])
        seqs_dna = seqs_dna[worker_slice]
        seqs_coords = seqs_coords[worker_slice]

    num_seqs = len(seqs_dna)
    data_ops = seq_data_ops(seqs_dna, job['batch_size'])

    # model
    model = seqnn.SeqNN()
    model.build_sad(job, data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    target_subset=target_subset)

    # output HDF5, recreated from scratch
    out_h5_file = '%s/predict.h5' % options.out_dir
    if os.path.isfile(out_h5_file):
        os.remove(out_h5_file)
    out_h5 = h5py.File(out_h5_file, 'w')
    out_h5.create_dataset('preds', shape=(num_seqs, num_targets), dtype='float16')

    chroms, starts, _ = zip(*seqs_coords)
    chroms = np.array(chroms, dtype='S')
    starts = np.array(starts)
    out_h5.create_dataset('chrom', data=chroms)
    out_h5.create_dataset('start', data=starts)
    out_h5.create_dataset('end', data=starts + job['seq_length'])

    # center window: two positions for even lengths, one for odd
    even_length = model.preds_length % 2 == 0
    mid_start = model.preds_length // 2 - (1 if even_length else 0)
    mid_end = mid_start + (2 if even_length else 1)

    # predict and write
    saver = tf.train.Saver()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        saver.restore(sess, model_file)
        preds_stream = PredStream(sess, model, 64)

        for si in range(num_seqs):
            print('Predicting %d' % si, flush=True)

            preds_full = preds_stream[si]
            out_h5['preds'][si] = preds_full[mid_start:mid_end, :].sum(axis=0)

            for ti in options.bigwig_indexes:
                bigwig_write(preds_full[:, ti], seqs_coords[si],
                             '%s/s%d_t%d.bw' % (bigwig_dir, si, ti),
                             options.genome_file, model.hp.batch_buffer)

    out_h5.close()
Esempio n. 21
0
def main():
    """Extract and annotate the motifs learned by one convolution layer.

    usage: %prog [options] <params_file> <model_file> <bed_file>

    Pipeline:
      1. Unless -p names an existing predictions HDF5, run
         basenji_predict_bed.py on the BED regions, producing
         ``<out_dir>/predict.h5``.
      2. Read the activations with ``read_preds``, clip negatives to zero,
         widen the seqlet intervals to ``-s`` bp, fetch their DNA from the
         genome FASTA and write negative seqlet sequences for the motif
         analyses.
      3. For each filter kept by ``filter_features``, run ``process_feature``
         into ``<out_dir>/features`` and annotate the motifs against the MEME
         database.
      4. Factorize the activations with ``nmf.compute_rnmf`` (rank = half the
         number of kept filters), run ``process_factor`` per factor into
         ``<out_dir>/factors`` and annotate those motifs.

    With ``-t N`` (N > 1) the feature and factor jobs share one
    multiprocessing pool of N workers; otherwise they run sequentially.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='align_seqlets_shift',
        default=0,
        type='int',
        help='Align seqlets, expecting the specified shift [Default: %default]'
    )
    parser.add_option('-b',
                      dest='background_fasta',
                      default=None,
                      help='Homer background FASTA.')
    parser.add_option('-d',
                      dest='meme_db',
                      default='%s/data/motifs/Homo_sapiens.meme' %
                      os.environ['BASSETDIR'],
                      help='MEME database used to annotate motifs')
    parser.add_option('-e',
                      dest='embed_layer',
                      default=None,
                      type='int',
                      help='Embed sequences using the specified layer index.')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-l',
        dest='site_length',
        default=None,
        type='int',
        help='Prediction site length. [Default: params.seq_length]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='motifs_out',
                      help='Output directory [Default: %default]')
    parser.add_option('-p',
                      dest='predict_h5_file',
                      default=None,
                      help='basenji_predict output HDF5.')
    parser.add_option(
        '-r',
        dest='range_step',
        default=1,
        type='int',
        help='Range step for using activation values [Default: %default]')
    parser.add_option(
        '-s',
        dest='seqlet_length',
        default=20,
        type='int',
        help='Seqlet length to extract for motif analysis [Default: %default]')
    parser.add_option('-t',
                      dest='threads',
                      default=1,
                      type='int',
                      help='Number of threads [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        params_file, model_file, bed_file = args
    else:
        parser.error('Must provide parameter and model files and BED file')

    # The genome FASTA is read unconditionally below (seqlet filtering and DNA
    # extraction), so fail early with a usage message instead of inside pysam.
    if options.genome_fasta is None:
        parser.error('Must provide a genome FASTA (-f)')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    ################################################################
    # predict

    if options.predict_h5_file is None:
        # basenji_predict_bed.py requires the params, model and BED files as
        # positional arguments. An argument list (no shell) keeps paths with
        # spaces or shell metacharacters intact; optional integer options are
        # forwarded only when set, since '%d' % None would raise.
        cmd = ['basenji_predict_bed.py']
        if options.embed_layer is not None:
            cmd += ['-e', str(options.embed_layer)]
        cmd += ['-f', options.genome_fasta]
        if options.site_length is not None:
            cmd += ['-l', str(options.site_length)]
        cmd += ['-o', options.out_dir, params_file, model_file, bed_file]
        # Stop here if prediction fails rather than reading a missing HDF5.
        subprocess.check_call(cmd)
        options.predict_h5_file = '%s/predict.h5' % options.out_dir

    ################################################################
    # read model, sequences, and predictions

    # read params
    with open(params_file) as params_open:
        params = json.load(params_open)

    # build model
    seqnn_model = seqnn.SeqNN(params['model'])
    seqnn_model.restore(model_file)

    # read predictions; negative activations are clipped to zero
    seqlet_acts, seqlet_intervals = read_preds(options.predict_h5_file,
                                               range_step=options.range_step)
    seqlet_acts = np.clip(seqlet_acts, 0, np.inf)

    # transform seqlets w/ options.seqlet_length
    seqlet_intervals = extend_intervals(seqlet_intervals,
                                        options.seqlet_length)

    # drop seqlets rejected by filter_seqlets against the genome FASTA
    # (presumably those running past chromosome ends -- confirm)
    seqlet_acts, seqlet_intervals = filter_seqlets(seqlet_acts,
                                                   seqlet_intervals,
                                                   options.genome_fasta)

    # extract seqlet DNA
    fasta_open = pysam.Fastafile(options.genome_fasta)
    seqlet_dna = [
        fasta_open.fetch(sint.chr, sint.start, sint.end)
        for sint in seqlet_intervals
    ]
    fasta_open.close()

    # construct negative seqlets for motif analysis
    negatives_fasta_file = '%s/seqlets_neg.fa' % options.out_dir
    make_negative_fasta(seqlet_intervals, seqlet_dna, negatives_fasta_file)

    # remove uninformative filters; feature_mask flags the kept columns
    seqlet_acts, feature_mask = filter_features(seqlet_acts, return_mask=True)

    ################################################################
    # features

    features_out_dir = '%s/features' % options.out_dir
    if not os.path.isdir(features_out_dir):
        os.mkdir(features_out_dir)

    # read weights
    kernel_weights = seqnn_model.get_conv_weights(options.embed_layer)

    # sfi indexes the filtered activation columns; fi indexes the original
    # filters (and therefore kernel_weights and the output directory name).
    feature_args = []
    sfi = 0
    for fi, keep_feature in enumerate(feature_mask):
        if keep_feature:
            fa = (seqlet_acts[:, sfi], seqlet_dna, kernel_weights[fi],
                  '%s/f%d' % (features_out_dir, fi), negatives_fasta_file,
                  options.align_seqlets_shift, options.meme_db)
            feature_args.append(fa)
            sfi += 1

    # One pool serves both the feature and factor stages and is always shut
    # down, even if a stage raises.
    pool = None
    if options.threads > 1:
        pool = multiprocessing.Pool(options.threads)

    def run_jobs(job_func, job_args):
        """Call job_func(*args) for every tuple, in the pool if one exists."""
        if pool is None:
            for ja in job_args:
                job_func(*ja)
        else:
            pool.starmap(job_func, job_args)

    try:
        run_jobs(process_feature, feature_args)

        annotate_motifs(features_out_dir, negatives_fasta_file,
                        options.meme_db)

        ################################################################
        # factorized

        factors_out_dir = '%s/factors' % options.out_dir
        if not os.path.isdir(factors_out_dir):
            os.mkdir(factors_out_dir)

        num_factors = seqlet_acts.shape[1] // 2
        t0 = time.time()
        print('Computing NMF...', end='')
        seqlet_W, seqlet_H = nmf.compute_rnmf(seqlet_acts, rank=num_factors)
        print('done in %ds' % (time.time() - t0))

        factor_args = []
        for fi in range(num_factors):
            fa = (seqlet_H[fi, :], seqlet_W[:, fi], seqlet_dna, feature_mask,
                  '%s/f%d' % (factors_out_dir, fi), negatives_fasta_file,
                  options.align_seqlets_shift, options.meme_db)
            factor_args.append(fa)

        run_jobs(process_factor, factor_args)

        annotate_motifs(factors_out_dir, negatives_fasta_file, options.meme_db)
    finally:
        if pool is not None:
            pool.close()
            pool.join()
Esempio n. 22
0
def main():
    """Score VCF variants with a SeqNN model and write SAD statistics.

    usage: %prog [options] <params_file> <model_file> <vcf_file>

    Runs directly with three arguments, or with five arguments
    (<options_pkl> <params_file> <model_file> <vcf_file> <worker_index>) when
    launched by the multi-worker script. In that case the pickled options
    replace the command-line ones, output goes to ``<out_dir>/job<index>``
    and, when ``options.processes`` is set, only this worker's evenly split
    slice of the VCF is scored.

    SNPs are clustered by position (``cluster_snps``). Each cluster yields one
    reference prediction followed by one alternate prediction per SNP. Each
    reference/alternate pair is written with ``write_snp`` (or queued to a
    ``SNPWorker`` thread with --threads) into the HDF5 from
    ``initialize_output_h5``, and ``write_pct`` summarizes the distributions.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs lie in center region [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '--flip',
        dest='flip_ref',
        default=False,
        action='store_true',
        help='Flip reference/alternate alleles when simple [Default: %default]'
    )
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--threads',
        dest='threads',
        default=False,
        action='store_true',
        help='Run CPU math and output in a separate thread [Default: %default]'
    )
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    (options, args) = parser.parse_args()

    worker_index = None
    if len(args) == 3:
        # single worker
        params_file, model_file, vcf_file = args

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options written by the multi-worker launcher
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at option files produced by a trusted launcher.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    # Splitting the VCF by -p needs the worker index, which only the
    # five-argument (multi-worker) form supplies.
    if options.processes is not None and worker_index is None:
        parser.error('-p is only valid with the multi-worker invocation')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    if options.penultimate:
        parser.error('Not implemented for TF2')

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()
    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])
    else:
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, params_model['seq_length'],
                                options.center_pct)

    # delimit sequence boundaries
    for sc in snp_clusters:
        sc.delimit(params_model['seq_length'])

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator: per cluster, whatever get_1hots returns,
    # consumed below as one reference followed by one alternate per SNP
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    if options.threads:
        # a single background worker consumes (ref, alt, index) tuples
        snp_queue = Queue()
        sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                       options.log_pseudo)
        sw.start()
        snp_threads = [sw]

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    # SNP index, aligned with snps / snp_flips
    si = 0

    for snp_cluster in snp_clusters:
        # reference prediction shared by every SNP in this cluster
        ref_preds = preds_stream[pi]
        pi += 1

        for _ in snp_cluster.snps:
            alt_preds = preds_stream[pi]
            pi += 1

            # Swap into per-SNP names: reassigning ref_preds here would hand
            # this SNP's alternate prediction to the next SNP of the cluster
            # as its reference.
            if snp_flips[si]:
                snp_ref_preds, snp_alt_preds = alt_preds, ref_preds
            else:
                snp_ref_preds, snp_alt_preds = ref_preds, alt_preds

            if options.threads:
                # queue SNP
                snp_queue.put((snp_ref_preds, snp_alt_preds, si))
            else:
                # process SNP
                write_snp(snp_ref_preds, snp_alt_preds, sad_out, si,
                          options.sad_stats, options.log_pseudo)

            # update SNP index
            si += 1

    # finish queue
    if options.threads:
        print('Waiting for threads to finish.', flush=True)
        snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
Esempio n. 23
0
def main():
  """Compute gradient maps for gene sequences with a TF1 SeqNN model.

  usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>

  Loads gene sequences from a GeneData HDF5 file, optionally restricted to the
  gene ids listed in -l. It then builds a feed model for the requested targets
  and builds gradient ops at layer index ``pre_dilated_layer`` (the lowest
  CNN layer index whose dilation exceeds 1). Finally it restores the
  checkpoint and hands off to ``score_write`` to score the sequences and write
  bigwigs.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
  parser = OptionParser(usage)
  parser.add_option('-g', dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option('-l', dest='gene_list',
      help='Process only gene ids in the given file')
  parser.add_option('--mc', dest='mc_n',
      default=0, type='int',
      help='Monte carlo test iterations [Default: %default]')
  parser.add_option('-n', dest='norm',
      default=None, type='int',
      help='Compute saliency norm [Default: %default]')
  parser.add_option('-o', dest='out_dir',
      default='grad_map',
      help='Output directory [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average the forward and reverse complement predictions when testing [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('-t', dest='target_indexes',
      default=None,
      help='Target indexes to plot')
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error('Must provide parameters, model, and genes HDF5 file')
  else:
    params_file, model_file, genes_hdf5_file = args

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # reads in genes HDF5

  gene_data = genedata.GeneData(genes_hdf5_file)

  # subset gene sequences to the ids listed one per line in -l
  if options.gene_list:
    with open(options.gene_list) as gene_list_open:
      genes_subset = {line.rstrip() for line in gene_list_open}

    gene_data.subset_genes(genes_subset)
    print('Filtered to %d sequences' % gene_data.num_seqs)

  # extract sequence chrom and start
  seqs_chrom = [gene_data.gene_seqs[si].chrom for si in range(gene_data.num_seqs)]
  seqs_start = [gene_data.gene_seqs[si].start for si in range(gene_data.num_seqs)]


  #######################################################
  # model parameters and placeholders

  job = params.read_job_params(params_file)

  # geometry comes from the genes HDF5, not the params file
  job['seq_length'] = gene_data.seq_length
  job['seq_depth'] = gene_data.seq_depth
  job['target_pool'] = gene_data.pool_width

  if 'num_targets' not in job:
    print(
        "Must specify number of targets (num_targets) in the parameters file.",
        file=sys.stderr)
    sys.exit(1)

  # set target indexes; target_subset=None means all targets
  if options.target_indexes is not None:
    options.target_indexes = [int(ti) for ti in options.target_indexes.split(',')]
    target_subset = options.target_indexes
  else:
    options.target_indexes = list(range(job['num_targets']))
    target_subset = None

  # build model
  model = seqnn.SeqNN()
  model.build_feed(job, target_subset=target_subset)

  # lowest CNN layer index with dilation > 1
  cnn_dilation = np.array([cp.dilation for cp in model.hp.cnn_params])
  dilated_indexes = np.where(cnn_dilation > 1)[0]
  if dilated_indexes.size == 0:
    # np.min on an empty array would fail with an opaque ValueError
    print('Model has no dilated convolution layers.', file=sys.stderr)
    sys.exit(1)
  pre_dilated_layer = np.min(dilated_indexes)
  print('Pre-dilated layer: %d' % pre_dilated_layer)

  # build gradients ops
  t0 = time.time()
  print('Building target/position-specific gradient ops.', end='')
  model.build_grads(layers=[pre_dilated_layer])
  print(' Done in %ds' % (time.time()-t0), flush=True)


  #######################################################
  # acquire gradients

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # score sequences and write bigwigs
    score_write(sess, model, options, gene_data.seqs_1hot, seqs_chrom, seqs_start)
Esempio n. 24
0
def main():
    """Predict model outputs over BED sites and write them to HDF5.

    usage: %prog [options] <params_file> <model_file> <bed_file>

    Runs directly with three arguments, or with five arguments
    (<options_pkl> <params_file> <model_file> <bed_file> <worker_index>) when
    launched by the multi-worker script. In that case the pickled options are
    used and output goes to ``<out_dir>/job<index>``.

    For every BED region, the model prediction is sliced to the centered site
    of ``-l`` bp (default: the model sequence length). The slice is stored in
    ``<out_dir>/predict.h5`` under 'preds', summed over positions with -s,
    together with the site 'chrom'/'start'/'end'. Tracks listed with -b are
    additionally written as per-sequence BigWigs.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-b',
        dest='bigwig_indexes',
        default=None,
        help='Comma-separated list of target indexes to write BigWigs')
    parser.add_option('-e',
                      dest='embed_layer',
                      default=None,
                      type='int',
                      help='Embed sequences using the specified layer index.')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default=None,
                      help='Chromosome length information [Default: %default]')
    parser.add_option(
        '-l',
        dest='site_length',
        default=None,
        type='int',
        help='Prediction site length. [Default: params.seq_length]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='pred_out',
                      help='Output directory [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('-s',
                      dest='sum',
                      default=False,
                      action='store_true',
                      help='Sum site predictions [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    worker_index = None
    if len(args) == 3:
        params_file, model_file, bed_file = args

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        bed_file = args[3]
        worker_index = int(args[4])

        # load options written by the multi-worker launcher
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at option files produced by a trusted launcher.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)
    else:
        parser.error('Must provide parameter and model files and BED file')

    # Splitting the BED by -p needs the worker index, which only the
    # five-argument (multi-worker) form supplies.
    if options.processes is not None and worker_index is None:
        parser.error('-p is only valid with the multi-worker invocation')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    if options.bigwig_indexes is not None:
        options.bigwig_indexes = [
            int(bi) for bi in options.bigwig_indexes.split(',')
        ]
    else:
        options.bigwig_indexes = []

    bigwig_dir = '%s/bigwig' % options.out_dir
    if options.bigwig_indexes and not os.path.isdir(bigwig_dir):
        os.mkdir(bigwig_dir)

    #################################################################
    # read parameters and collect target information

    job = params.read_job_params(params_file,
                                 require=['num_targets', 'seq_length'])

    if job.get('batch_buffer', 0) > 0:
        print('Turn off batch_buffer.', file=sys.stderr)
        sys.exit(1)

    num_targets = np.sum(job['num_targets'])
    if options.targets_file is None:
        target_subset = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_subset = targets_df.index
        if len(target_subset) == num_targets:
            # the file lists every target, so no slicing is needed
            target_subset = None
        else:
            num_targets = len(target_subset)

    # default site: the whole model sequence (`params` is the module here,
    # so the length must come from the job dict)
    if options.site_length is None:
        options.site_length = job['seq_length']

    #################################################################
    # sequence dataset

    # construct model sequences
    model_seqs_dna, model_seqs_coords = make_bed_data(bed_file,
                                                      options.genome_fasta,
                                                      job['seq_length'])

    # construct site coordinates
    site_seqs_coords = read_bed(bed_file, options.site_length)

    # keep only this worker's slice of the sequences
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(model_seqs_dna),
                                    options.processes + 1,
                                    dtype='int')
        worker_start = worker_bounds[worker_index]
        worker_end = worker_bounds[worker_index + 1]
        model_seqs_dna = model_seqs_dna[worker_start:worker_end]
        model_seqs_coords = model_seqs_coords[worker_start:worker_end]
        site_seqs_coords = site_seqs_coords[worker_start:worker_end]

    num_seqs = len(model_seqs_dna)

    # make data ops
    data_ops = seq_data_ops(model_seqs_dna, job['batch_size'])

    #################################################################
    # setup model

    # build model
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    embed_layer=options.embed_layer,
                    target_subset=target_subset)

    #################################################################
    # setup output

    # determine site boundaries in predictions space
    assert (job['seq_length'] % model.preds_length == 0)
    preds_window = job['seq_length'] // model.preds_length

    assert (model.preds_length % 2 == 0)
    preds_mid = model.preds_length // 2

    # -l is user input, so report a bad value as a usage error
    if options.site_length % preds_window != 0:
        parser.error('Site length %d must be a multiple of %d' %
                     (options.site_length, preds_window))
    site_preds_length = options.site_length // preds_window

    if site_preds_length % 2 != 0:
        parser.error('Site length %d must span an even number of %d bp bins'
                     % (options.site_length, preds_window))
    site_preds_start = preds_mid - site_preds_length // 2
    site_preds_end = site_preds_start + site_preds_length

    # initialize HDF5
    out_h5_file = '%s/predict.h5' % options.out_dir
    if os.path.isfile(out_h5_file):
        os.remove(out_h5_file)
    out_h5 = h5py.File(out_h5_file, 'w')

    # create predictions
    if options.sum:
        out_h5.create_dataset('preds',
                              shape=(num_seqs, model.preds_depth),
                              dtype='float16')
    else:
        out_h5.create_dataset('preds',
                              shape=(num_seqs, site_preds_length,
                                     model.preds_depth),
                              dtype='float16')

    # store site coordinates
    site_seqs_chr, site_seqs_start, site_seqs_end = zip(*site_seqs_coords)
    out_h5.create_dataset('chrom', data=np.array(site_seqs_chr, dtype='S'))
    out_h5.create_dataset('start', data=np.array(site_seqs_start))
    out_h5.create_dataset('end', data=np.array(site_seqs_end))

    #################################################################
    # predict scores, write output

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # initialize predictions stream
        preds_stream = PredStream(sess, model, 64)

        for si in range(num_seqs):
            print('Predicting %d' % si, flush=True)

            # predict
            preds_full = preds_stream[si]

            # slice site
            preds_site = preds_full[site_preds_start:site_preds_end, :]

            # write
            if options.sum:
                out_h5['preds'][si] = preds_site.sum(axis=0)
            else:
                out_h5['preds'][si] = preds_site

            # write bigwig
            for ti in options.bigwig_indexes:
                bw_file = '%s/s%d_t%d.bw' % (bigwig_dir, si, ti)
                bigwig_write(preds_full[:, ti], model_seqs_coords[si], bw_file,
                             options.genome_file, model.hp.batch_buffer)

    # close output HDF5
    out_h5.close()
Esempio n. 25
0
def run(params_file, train_file, test_file, train_epochs, train_epoch_batches,
        test_epoch_batches):
    """Train a TF1 SeqNN model with per-epoch validation and early stopping.

    Args:
      params_file: Job parameters file read by ``params.read_job_params``.
      train_file: Training data passed to ``make_data_ops``.
      test_file: Validation data passed to ``make_data_ops``.
      train_epochs: Number of epochs to run, or None to train until
        ``FLAGS.early_stop`` consecutive epochs fail to lower the validation
        loss.
      train_epoch_batches: Forwarded to ``model.train_epoch_tfr``.
      test_epoch_batches: Forwarded to ``model.test_tfr``.

    After each epoch the model is checkpointed to
    ``FLAGS.logdir/model_check.tf``. Whenever the validation loss improves it
    is also saved to ``FLAGS.logdir/model_best.tf``. Epoch summaries are
    handed to a background ``AccuracyWorker`` through a queue.
    """
    # parse shifts
    augment_shifts = [int(shift) for shift in FLAGS.augment_shifts.split(',')]
    ensemble_shifts = [
        int(shift) for shift in FLAGS.ensemble_shifts.split(',')
    ]

    # read parameters
    job = params.read_job_params(params_file)

    # load data
    data_ops, training_init_op, test_init_op = make_data_ops(
        job, train_file, test_file)

    # initialize model
    model = seqnn.SeqNN()
    model.build_from_data_ops(job, data_ops, FLAGS.augment_rc, augment_shifts,
                              FLAGS.ensemble_rc, ensemble_shifts)

    # launch accuracy compute thread
    acc_queue = Queue()
    acc_thread = AccuracyWorker(acc_queue)
    acc_thread.start()

    # checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        train_writer = None
        if FLAGS.logdir:
            train_writer = tf.summary.FileWriter(FLAGS.logdir + '/train',
                                                 sess.graph)

        # Close the summary writer even if training raises.
        try:
            if FLAGS.restart:
                # load variables into session
                saver.restore(sess, FLAGS.restart)
            else:
                # initialize variables
                t0 = time.time()
                print('Initializing...')
                sess.run(tf.local_variables_initializer())
                sess.run(tf.global_variables_initializer())
                print('Initialization time %f' % (time.time() - t0))

            best_loss = None
            early_stop_i = 0

            epoch = 0
            while (train_epochs is not None and epoch < train_epochs) or \
                  (train_epochs is None and early_stop_i < FLAGS.early_stop):
                t0 = time.time()

                # train epoch
                sess.run(training_init_op)
                train_loss, steps = model.train_epoch_tfr(
                    sess, train_writer, train_epoch_batches)

                # block for previous accuracy compute
                acc_queue.join()

                # test validation
                sess.run(test_init_op)
                valid_acc = model.test_tfr(sess, test_epoch_batches)

                # consider as best; otherwise count toward early stopping
                best_str = ''
                if best_loss is None or valid_acc.loss < best_loss:
                    best_loss = valid_acc.loss
                    best_str = ', best!'
                    early_stop_i = 0
                    saver.save(sess, '%s/model_best.tf' % FLAGS.logdir)
                else:
                    early_stop_i += 1

                # measure time: seconds below 10 min, minutes below 100 min,
                # hours beyond
                epoch_time = time.time() - t0
                if epoch_time < 600:
                    time_str = '%3ds' % epoch_time
                elif epoch_time < 6000:
                    time_str = '%3dm' % (epoch_time / 60)
                else:
                    time_str = '%3.1fh' % (epoch_time / 3600)

                # compute and write accuracy update in the background thread
                acc_queue.put(
                    (epoch, steps, train_loss, valid_acc, time_str, best_str))

                # checkpoint latest
                saver.save(sess, '%s/model_check.tf' % FLAGS.logdir)

                # update epoch
                epoch += 1

            # finish queue
            acc_queue.join()
        finally:
            if train_writer is not None:
                train_writer.close()
Esempio n. 26
0
def main():
    """Draw saturated-mutagenesis plots for the SNPs listed in a VCF file.

    Positional arguments: <params_file> <model_file> <vcf_file>.

    Each SNP sequence returned by ``bvcf.snps_seq1`` has its centred ``-l``
    nucleotides mutated via ``satmut_seqs`` and scored with the model.
    With ``-n``, predictions are instead reloaded from
    ``<out_dir>/seq<i>_preds.npy``.  For every index along the last axis of
    the score matrix, one PDF ``<out_dir>/<header>_t<ti>.pdf`` is written.
    It combines a sequence logo, a per-position max track and a heat map.
    """
    parser = OptionParser(
        'usage: %prog [options] <params_file> <model_file> <vcf_file>')
    parser.add_option('-f', dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g', dest='gain', default=False, action='store_true',
                      help='Draw a sequence logo for the gain score, too '
                      '[Default: %default]')
    parser.add_option('-l', dest='satmut_len', default=200, type='int',
                      help='Length of centered sequence to mutate '
                      '[Default: %default]')
    parser.add_option('-m', dest='min_limit', default=0.1, type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option('-n', dest='load_sat_npy', default=False,
                      action='store_true',
                      help='Load the predictions from .npy files '
                      '[Default: %default]')
    parser.add_option('-o', dest='out_dir', default='sat_vcf',
                      help='Output directory [Default: %default]')
    parser.add_option('--rc', dest='rc', default=False, action='store_true',
                      help='Ensemble forward and reverse complement '
                      'predictions [Default: %default]')
    parser.add_option('--shifts', dest='shifts', default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('-t', dest='targets_file', default=None, type='str',
                      help='File specifying target indexes and labels in '
                      'table format')
    parser.add_option('-w', dest='figure_width', default=20, type='float',
                      help='Figure width [Default: %default]')
    options, args = parser.parse_args()

    # parser.error() exits, so the unpacking below only runs with 3 args
    if len(args) != 3:
        parser.error('Must provide parameters and model files and VCF')
    params_file, model_file, vcf_file = args

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = list(map(int, options.shifts.split(',')))

    #################################################################
    # model parameters

    with open(params_file) as params_open:
        params = json.load(params_open)
    # 'train' is required to be present even though it is not used below
    params_train = params['train']
    params_model = params['model']
    # NOTE: -t (targets_file) is parsed but not used; target ids/labels
    # are not loaded by this script.

    #################################################################
    # SNP sequences

    snps = bvcf.vcf_snps(vcf_file)
    seqs_1hot, seq_headers, snps, seqs = bvcf.snps_seq1(
        snps, params_model['seq_length'], options.genome_fasta,
        return_seqs=True)
    seqs_n = seqs_1hot.shape[0]

    #################################################################
    # model (only needed when predictions are not reloaded from disk)

    if not options.load_sat_npy:
        seqnn_model = seqnn.SeqNN(params_model)
        seqnn_model.restore(model_file)
        seqnn_model.build_ensemble(options.rc, options.shifts)

    #################################################################
    # predict, score and plot each sequence

    for si in range(seqs_n):
        header_fs = fs_clean(seq_headers[si])
        preds_npy = '%s/seq%d_preds.npy' % (options.out_dir, si)

        print('Mutating sequence %d / %d' % (si + 1, seqs_n), flush=True)

        if options.load_sat_npy:
            sat_preds = np.load(preds_npy)
        else:
            mutated_1hot = satmut_seqs(seqs_1hot[si:si + 1],
                                       options.satmut_len)
            raw_preds = seqnn_model.predict(mutated_1hot, batch_size=2)
            # NOTE(review): this averages over the LAST prediction axis
            # (kept as a size-1 axis); confirm that collapsing this axis,
            # rather than e.g. the length axis, is intended.
            sat_preds = raw_preds.mean(axis=-1, dtype='float32',
                                       keepdims=True)
            np.save(preds_npy, sat_preds)

        # prediction deltas from score_matrix, commented as (L_sm x 4 x T)
        sat_scores = score_matrix(seqs_1hot[si], sat_preds)
        # max over axis 1 for every position / target
        sat_max = sat_scores.max(axis=1)

        for ti in range(sat_scores.shape[-1]):
            sns.set(style='white', font_scale=1)
            spp = subplot_params(sat_scores.shape[0])
            grid_shape = (3, spp['heat_cols'])

            plt.figure(figsize=(options.figure_width, 5))
            ax_logo = plt.subplot2grid(grid_shape, (0, spp['logo_start']),
                                       colspan=spp['logo_span'])
            ax_sad = plt.subplot2grid(grid_shape, (1, spp['sad_start']),
                                      colspan=spp['sad_span'])
            ax_heat = plt.subplot2grid(grid_shape, (2, 0),
                                       colspan=spp['heat_cols'])

            plot_seqlogo(ax_logo, seqs_1hot[si], sat_max[:, ti])
            plot_scd(ax_sad, sat_max[:, ti])
            plot_heat(ax_heat, sat_scores[:, :, ti], options.min_limit)

            out_pdf = '%s/%s_t%d.pdf' % (options.out_dir, header_fs, ti)
            plt.savefig(out_pdf, dpi=600)
            plt.close()
Esempio n. 27
0
def main():
    """Evaluate a trained SeqNN model on one split of an HDF5 data file.

    Positional arguments: <params_file> <model_file> <test_hdf5_file>.

    The split is chosen as follows: ``train_*`` datasets with --train,
    ``valid_*`` with -v, otherwise ``test_*``.  --train wins if both flags
    are given.  All auxiliary arrays (NA mask, Fourier imaginary targets,
    full scent targets) are read from the same split and subsampled
    consistently.  Written to the output directory:

    * ``acc.txt``: one row per target with index, loss, R2, PearsonR,
      log PearsonR and label.
    * ``normalization.txt``: per-target mean prediction divided by the
      median of those means.
    * Optional outputs: ``preds.npy``/``targets.npy`` (--save),
      ``peaks.txt`` (--peaks), BigWig tracks (-t/--ti) and accuracy plots
      (--ai).
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <test_hdf5_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy plots comparing true versus predicted values'
    )
    parser.add_option(
        '--clip',
        dest='target_clip',
        default=None,
        type='float',
        help=
        'Clip targets and predictions to a maximum value [Default: %default]')
    parser.add_option(
        '-d',
        dest='down_sample',
        default=1,
        type='int',
        help=
        'Down sample test computation by taking uniformly spaced positions [Default: %default]'
    )
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome length information [Default: %default]')
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option('-s',
                      dest='scent_file',
                      help='Dimension reduction model file')
    parser.add_option('--sample',
                      dest='sample_pct',
                      default=1,
                      type='float',
                      help='Sample percentage')
    parser.add_option('--save',
                      dest='save',
                      default=False,
                      action='store_true')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='track_bed',
        help='BED file describing regions so we can output BigWig tracks')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--train',
                      dest='train',
                      default=False,
                      action='store_true',
                      help='Process the training set [Default: %default]')
    parser.add_option('-v',
                      dest='valid',
                      default=False,
                      action='store_true',
                      help='Process the validation set [Default: %default]')
    parser.add_option(
        '-w',
        dest='pool_width',
        default=1,
        type='int',
        help=
        'Max pool width for regressing nt predictions to predict peak calls [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data HDF5')
    else:
        params_file = args[0]
        model_file = args[1]
        test_hdf5_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # load data
    #######################################################
    # read-only: older h5py versions default to append mode otherwise
    data_open = h5py.File(test_hdf5_file, 'r')

    # dataset split prefix; --train takes precedence over -v
    if options.train:
        data_set = 'train'
    elif options.valid:
        data_set = 'valid'
    else:
        data_set = 'test'

    test_seqs = data_open['%s_in' % data_set]
    test_targets = data_open['%s_out' % data_set]

    # the NA mask is optional for every split (previously it was left
    # undefined for --train when absent, crashing later)
    na_key = '%s_na' % data_set
    test_na = data_open[na_key] if na_key in data_open else None

    # optional random subset of sequences, applied to every per-sequence
    # array so they stay aligned
    sample_indexes = None
    if options.sample_pct < 1:
        sample_n = int(test_seqs.shape[0] * options.sample_pct)
        print('Sampling %d sequences' % sample_n)
        # h5py fancy indexing requires increasing indexes
        sample_indexes = sorted(
            np.random.choice(np.arange(test_seqs.shape[0]),
                             size=sample_n,
                             replace=False))
        test_seqs = test_seqs[sample_indexes]
        test_targets = test_targets[sample_indexes]
        if test_na is not None:
            test_na = test_na[sample_indexes]

    target_labels = [tl.decode('UTF-8') for tl in data_open['target_labels']]

    #######################################################
    # model parameters and placeholders

    job = params.read_job_params(params_file)

    job['seq_length'] = test_seqs.shape[1]
    job['seq_depth'] = test_seqs.shape[2]
    job['num_targets'] = test_targets.shape[2]
    job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

    t0 = time.time()
    dr = seqnn.SeqNN()
    dr.build(job)
    print('Model building time %ds' % (time.time() - t0))

    # adjust for fourier: imaginary targets come from the same split
    job['fourier'] = 'train_out_imag' in data_open
    if job['fourier']:
        test_targets_imag = data_open['%s_out_imag' % data_set]
        if sample_indexes is not None:
            test_targets_imag = test_targets_imag[sample_indexes]

    # adjust for factors: full (reconstructed) targets from the same split
    if options.scent_file is not None:
        test_targets_full = data_open['%s_out_full' % data_set]
        if sample_indexes is not None:
            test_targets_full = test_targets_full[sample_indexes]
        model = joblib.load(options.scent_file)

    #######################################################
    # test

    # initialize batcher
    if job['fourier']:
        batcher_test = batcher.BatcherF(test_seqs, test_targets,
                                        test_targets_imag, test_na,
                                        dr.hp.batch_size, dr.hp.target_pool)
    else:
        batcher_test = batcher.Batcher(test_seqs, test_targets, test_na,
                                       dr.hp.batch_size, dr.hp.target_pool)

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # test
        t0 = time.time()
        test_acc = dr.test(sess,
                           batcher_test,
                           rc=options.rc,
                           shifts=options.shifts,
                           mc_n=options.mc_n)

        if options.save:
            np.save('%s/preds.npy' % options.out_dir, test_acc.preds)
            np.save('%s/targets.npy' % options.out_dir, test_acc.targets)

        test_preds = test_acc.preds
        print('SeqNN test: %ds' % (time.time() - t0))

        # compute stats
        t0 = time.time()
        test_r2 = test_acc.r2(clip=options.target_clip)
        test_pcor = test_acc.pearsonr(clip=options.target_clip)
        test_log_pcor = test_acc.pearsonr(log=True, clip=options.target_clip)
        print('Compute stats: %ds' % (time.time() - t0))

        # print
        print('Test Loss:         %7.5f' % test_acc.loss)
        print('Test R2:           %7.5f' % test_r2.mean())
        print('Test PearsonR:     %7.5f' % test_pcor.mean())
        print('Test log PearsonR: %7.5f' % test_log_pcor.mean())

        with open('%s/acc.txt' % options.out_dir, 'w') as acc_out:
            for ti in range(len(test_r2)):
                print('%4d  %7.5f  %.5f  %.5f  %.5f  %s' %
                      (ti, test_acc.target_losses[ti], test_r2[ti],
                       test_pcor[ti], test_log_pcor[ti], target_labels[ti]),
                      file=acc_out)

        # print normalization factors: mean prediction per target relative
        # to the median across targets
        target_means = test_preds.mean(axis=(0, 1), dtype='float64')
        target_means_median = np.median(target_means)
        target_means /= target_means_median
        with open('%s/normalization.txt' % options.out_dir, 'w') as norm_out:
            print('\n'.join([str(tu) for tu in target_means]), file=norm_out)

        # clean up
        del test_acc

        # if test targets are reconstructed, measure versus the truth
        if options.scent_file is not None:
            compute_full_accuracy(dr, model, test_preds, test_targets_full,
                                  options.out_dir, options.down_sample)

    #######################################################
    # peak call accuracy

    if options.peaks:
        # sample every few bins to decrease correlations
        ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
        ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                 dr.hp.target_pool)

        aurocs = []
        auprcs = []

        with open('%s/peaks.txt' % options.out_dir, 'w') as peaks_out:
            for ti in range(test_targets.shape[2]):
                if options.scent_file is not None:
                    test_targets_ti = test_targets_full[:, :, ti]
                else:
                    test_targets_ti = test_targets[:, :, ti]

                # subset and flatten
                test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets
                                                       ].flatten().astype(
                                                           'float32')
                test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                                ti].flatten().astype('float32')

                # call peaks: Poisson p-values vs the mean coverage,
                # Benjamini-Hochberg corrected, q < 0.01
                test_targets_ti_lambda = np.mean(test_targets_ti_flat)
                test_targets_pvals = 1 - poisson.cdf(
                    np.round(test_targets_ti_flat) - 1,
                    mu=test_targets_ti_lambda)
                test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
                test_targets_peaks = test_targets_qvals < 0.01

                if test_targets_peaks.sum() == 0:
                    aurocs.append(0.5)
                    auprcs.append(0)

                else:
                    # compute prediction accuracy
                    aurocs.append(
                        roc_auc_score(test_targets_peaks, test_preds_ti_flat))
                    auprcs.append(
                        average_precision_score(test_targets_peaks,
                                                test_preds_ti_flat))

                print('%4d  %6d  %.5f  %.5f' %
                      (ti, test_targets_peaks.sum(), aurocs[-1], auprcs[-1]),
                      file=peaks_out)

        print('Test AUROC:     %7.5f' % np.mean(aurocs))
        print('Test AUPRC:     %7.5f' % np.mean(auprcs))

    #######################################################
    # BigWig tracks

    # NOTE: THESE ASSUME THERE WAS NO DOWN-SAMPLING ABOVE

    # print bigwig tracks for visualization
    if options.track_bed:
        if options.genome_file is None:
            parser.error(
                'Must provide genome file in order to print valid BigWigs')

        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

        track_indexes = range(test_preds.shape[2])
        if options.track_indexes:
            track_indexes = [
                int(ti) for ti in options.track_indexes.split(',')
            ]

        # BED regions of the split actually evaluated (was 'test' for
        # --train)
        bed_set = data_set

        for ti in track_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            # make true targets bigwig
            bw_file = '%s/tracks/t%d_true.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_targets_ti,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

            # make predictions bigwig
            bw_file = '%s/tracks/t%d_preds.bw' % (options.out_dir, ti)
            bigwig_write(bw_file,
                         test_preds[:, :, ti],
                         options.track_bed,
                         options.genome_file,
                         dr.hp.batch_buffer,
                         bed_set=bed_set)

        # make NA bigwig, only when the split has an NA mask
        if test_na is not None:
            bw_file = '%s/tracks/na.bw' % options.out_dir
            bigwig_write(bw_file,
                         test_na,
                         options.track_bed,
                         options.genome_file,
                         bed_set=bed_set)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        for plot_dir in ('scatter', 'violin', 'roc', 'pr'):
            if not os.path.isdir('%s/%s' % (options.out_dir, plot_dir)):
                os.mkdir('%s/%s' % (options.out_dir, plot_dir))

        for ti in accuracy_indexes:
            if options.scent_file is not None:
                test_targets_ti = test_targets_full[:, :, ti]
            else:
                test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every few bins (adjust to plot the # points I want)
            ds_indexes_preds = np.arange(0, test_preds.shape[1], 8)
            ds_indexes_targets = ds_indexes_preds + (dr.hp.batch_buffer //
                                                     dr.hp.target_pool)

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:, ds_indexes_targets
                                                   ].flatten().astype(
                                                       'float32')
            test_preds_ti_flat = test_preds[:, ds_indexes_preds,
                                            ti].flatten().astype('float32')

            # take log2
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks (same procedure as the peak accuracy section)
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction':
                np.log2(test_preds_ti_flat + 1),
                'Experimental coverage status':
                test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()

    data_open.close()
Esempio n. 28
0
def main():
    """Compute SNP activity difference (SAD) statistics for a VCF file.

    Single worker:  <params_file> <model_file> <vcf_file>
    Multi worker:   <options_pkl> <params_file> <model_file> <vcf_file>
                    <worker_index>

    In multi-worker mode the options are restored from the pickle, and the
    output goes to ``<out_dir>/job<worker_index>``.  This worker handles its
    share of the VCF SNPs, split evenly across ``options.processes`` workers.
    Predictions are divided by per-target factors read one per line from
    ``-n``.  The summary threads write them through ``summarize_write``.

    Output formats:
    * zarr (-z);
    * a text table (--txt, CSV with -c);
    * otherwise the output of ``initialize_output_h5``.

    For non-text output, ``percentiles`` and ``<stat>_pct`` datasets are
    added for every requested statistic.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('--cpu',
                      dest='cpu',
                      default=False,
                      action='store_true',
                      help='Run without a GPU [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--txt',
                      dest='out_txt',
                      default=False,
                      action='store_true',
                      help='Output stats to text table [Default: %default]')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option('-z',
                      dest='out_zarr',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    worker_index = None
    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        # NOTE: pickle.load can execute arbitrary code; this file must come
        # from the trusted multi-worker launcher, never from external input.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    # splitting SNPs across processes needs this worker's index; without
    # this check the single-worker form with -p crashed with a NameError
    if options.processes is not None and worker_index is None:
        parser.error('Option -p requires the multi worker arguments '
                     '(options pickle and worker index)')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters

    job = params.read_job_params(params_file,
                                 require=['seq_length', 'num_targets'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        # a subset covering every target is equivalent to no subset
        if len(target_subset) == job['num_targets']:
            target_subset = None

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs: contiguous, evenly sized chunks
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype='int')
        snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index +
                                                              1]]

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    def snp_gen():
        """Yield one-hot sequences for each SNP, in VCF order."""
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, job['seq_length'], genome_open)

            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape([tf.Dimension(job['seq_length']),
                        tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(job['batch_size'])
    dataset = dataset.prefetch(2 * job['batch_size'])
    if not options.cpu:
        dataset = dataset.apply(
            tf.contrib.data.prefetch_to_device('/device:GPU:0'))

    iterator = dataset.make_one_shot_iterator()
    data_ops = iterator.get_next()

    #################################################################
    # setup model

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors, one float per line
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        with open(options.norm_file) as norm_open:
            for ti, line in enumerate(norm_open):
                target_norms[ti] = float(line.strip())

    #################################################################
    # setup output

    if options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    elif options.out_txt:
        header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad',
                       'sar', 'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad',
                       'lsar', 'ref_xpred', 'alt_xpred', 'xsad', 'xsar',
                       'target_index', 'target_id', 'target_label')

        if options.csv:
            sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sad_out)
        else:
            sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sad_out)

    else:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    #################################################################
    # process

    szi = 0
    sum_write_thread = None
    sw_batch_size = 32 // job['batch_size']

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # predict first
        batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

        while batch_preds.shape[0] > 0:
            # each SNP contributes two sequences to the batch
            batch_snps = batch_preds.shape[0] // 2

            # normalize
            batch_preds /= target_norms

            # block for last thread so writes stay ordered
            if sum_write_thread is not None:
                sum_write_thread.join()

            # summarize and write in the background while predicting next
            sum_write_thread = threading.Thread(target=summarize_write,
                                                args=(batch_preds, sad_out,
                                                      szi, options.sad_stats,
                                                      options.log_pseudo))
            sum_write_thread.start()

            # update SNP index
            szi += batch_snps

            # predict next
            batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

    # no thread was started if there were no predictions at all
    if sum_write_thread is not None:
        sum_write_thread.join()

    ###################################################
    # compute SAD distributions across variants

    if not options.out_txt:
        # define percentiles: finer steps in both tails
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset('percentiles', data=percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = '%s_pct' % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype('float16')

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    if not options.out_zarr:
        sad_out.close()
Esempio n. 29
0
def main():
    """Evaluate a trained SeqNN model on the TFRecords of a data directory.

    Positional arguments: <params_file> <model_file> <data_dir>.

    Reads ``<data_dir>/statistics.json`` and the targets table (``-t`` or
    ``<data_dir>/targets.txt``), evaluates the model on
    ``<data_dir>/tfrecords/<--tfr pattern>``, prints summary loss / R2 /
    PearsonR and writes per-target statistics to ``<out_dir>/acc.txt``.
    Optionally saves predictions and targets to HDF5 (``--save``), computes
    peak accuracy (``--peaks``), and draws scatter / violin / ROC / PR plots
    for the target indexes listed in ``--ai``.
    """
    usage = 'usage: %prog [options] <params_file> <model_file> <data_dir>'
    parser = OptionParser(usage)
    parser.add_option(
        '--ai',
        dest='accuracy_indexes',
        help=
        'Comma-separated list of target indexes to make accuracy scatter plots.'
    )
    # NOTE(review): mc_n is parsed but never read in this function.
    parser.add_option('--mc',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo test iterations [Default: %default]')
    parser.add_option(
        '--peak',
        '--peaks',
        dest='peaks',
        default=False,
        action='store_true',
        help='Compute expensive peak accuracy [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='test_out',
        help='Output directory for test statistics [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average the fwd and rc predictions [Default: %default]')
    parser.add_option(
        '--save',
        dest='save',
        default=False,
        action='store_true',
        help='Save targets and predictions numpy arrays [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--tfr',
        dest='tfr_pattern',
        default='test-*.tfr',
        help='TFR pattern string appended to data_dir [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters, model, and test data directory')
    else:
        params_file = args[0]
        model_file = args[1]
        data_dir = args[2]

    # makedirs also creates missing parents and tolerates an existing dir
    os.makedirs(options.out_dir, exist_ok=True)

    # parse shifts to integers
    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #######################################################
    # inputs

    # read targets
    if options.targets_file is None:
        options.targets_file = '%s/targets.txt' % data_dir
    targets_df = pd.read_csv(options.targets_file, index_col=0, sep='\t')

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read data parameters
    data_stats_file = '%s/statistics.json' % data_dir
    with open(data_stats_file) as data_stats_open:
        data_stats = json.load(data_stats_open)

    # construct data ops
    tfr_pattern_path = '%s/tfrecords/%s' % (data_dir, options.tfr_pattern)
    eval_data = dataset.SeqDataset(tfr_pattern_path,
                                   seq_length=data_stats['seq_length'],
                                   target_length=data_stats['target_length'],
                                   batch_size=params_train['batch_size'],
                                   mode=tf.estimator.ModeKeys.EVAL)

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #######################################################
    # evaluate

    # the loss defaults to poisson when the training params do not name one
    eval_loss = params_train.get('loss', 'poisson')

    test_loss, test_pr, test_r2 = seqnn_model.evaluate(eval_data,
                                                       loss=eval_loss)
    print('')

    # print summary statistics
    print('Test Loss:         %7.5f' % test_loss)
    print('Test R2:           %7.5f' % test_r2.mean())
    print('Test PearsonR:     %7.5f' % test_pr.mean())

    # write target-level statistics
    targets_acc_df = pd.DataFrame({
        'index': targets_df.index,
        'r2': test_r2,
        'pearsonr': test_pr,
        'identifier': targets_df.identifier,
        'description': targets_df.description
    })
    targets_acc_df.to_csv('%s/acc.txt' % options.out_dir,
                          sep='\t',
                          index=False,
                          float_format='%.5f')

    #######################################################
    # predict?

    if options.save or options.peaks or options.accuracy_indexes is not None:
        # compute predictions
        test_preds = seqnn_model.predict(eval_data).astype('float16')

        # read targets
        test_targets = eval_data.numpy(return_inputs=False)

    if options.save:
        # context managers close the HDF5 files even if a write fails
        with h5py.File('%s/preds.h5' % options.out_dir, 'w') as preds_h5:
            preds_h5.create_dataset('preds', data=test_preds)
        with h5py.File('%s/targets.h5' % options.out_dir, 'w') as targets_h5:
            targets_h5.create_dataset('targets', data=test_targets)

    #######################################################
    # peak call accuracy

    if options.peaks:
        peaks_out_file = '%s/peaks.txt' % options.out_dir
        test_peaks(test_preds, test_targets, peaks_out_file)

    #######################################################
    # accuracy plots

    if options.accuracy_indexes is not None:
        accuracy_indexes = [
            int(ti) for ti in options.accuracy_indexes.split(',')
        ]

        for plot_dir in ('scatter', 'violin', 'roc', 'pr'):
            os.makedirs('%s/%s' % (options.out_dir, plot_dir), exist_ok=True)

        for ti in accuracy_indexes:
            test_targets_ti = test_targets[:, :, ti]

            ############################################
            # scatter

            # sample every 8th bin to keep the number of plotted points small
            ds_indexes = np.arange(0, test_preds.shape[1], 8)

            # subset and flatten
            test_targets_ti_flat = test_targets_ti[:, ds_indexes].flatten(
            ).astype('float32')
            test_preds_ti_flat = test_preds[:, ds_indexes,
                                            ti].flatten().astype('float32')

            # take log2 (pseudocount 1)
            test_targets_ti_log = np.log2(test_targets_ti_flat + 1)
            test_preds_ti_log = np.log2(test_preds_ti_flat + 1)

            # plot log2
            sns.set(font_scale=1.2, style='ticks')
            out_pdf = '%s/scatter/t%d.pdf' % (options.out_dir, ti)
            plots.regplot(test_targets_ti_log,
                          test_preds_ti_log,
                          out_pdf,
                          poly_order=1,
                          alpha=0.3,
                          sample=500,
                          figsize=(6, 6),
                          x_label='log2 Experiment',
                          y_label='log2 Prediction',
                          table=True)

            ############################################
            # violin

            # call peaks: Poisson upper-tail p-value P(X >= x) against the
            # target's mean coverage, then Benjamini-Hochberg q < 0.01
            test_targets_ti_lambda = np.mean(test_targets_ti_flat)
            test_targets_pvals = 1 - poisson.cdf(
                np.round(test_targets_ti_flat) - 1, mu=test_targets_ti_lambda)
            test_targets_qvals = np.array(ben_hoch(test_targets_pvals))
            test_targets_peaks = test_targets_qvals < 0.01
            test_targets_peaks_str = np.where(test_targets_peaks, 'Peak',
                                              'Background')

            # violin plot
            sns.set(font_scale=1.3, style='ticks')
            plt.figure()
            df = pd.DataFrame({
                'log2 Prediction': test_preds_ti_log,
                'Experimental coverage status': test_targets_peaks_str
            })
            ax = sns.violinplot(x='Experimental coverage status',
                                y='log2 Prediction',
                                data=df)
            ax.grid(True, linestyle=':')
            plt.savefig('%s/violin/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # ROC/PR need both peak and background bins; roc_auc_score raises
            # ValueError when y_true contains a single class.
            if not test_targets_peaks.any() or test_targets_peaks.all():
                print('Target %d: peak calls contain a single class; '
                      'skipping ROC and PR plots.' % ti)
                continue

            # ROC
            plt.figure()
            fpr, tpr, _ = roc_curve(test_targets_peaks, test_preds_ti_flat)
            auroc = roc_auc_score(test_targets_peaks, test_preds_ti_flat)
            plt.plot([0, 1], [0, 1],
                     c='black',
                     linewidth=1,
                     linestyle='--',
                     alpha=0.7)
            plt.plot(fpr, tpr, c='black')
            ax = plt.gca()
            ax.set_xlabel('False positive rate')
            ax.set_ylabel('True positive rate')
            ax.text(0.99,
                    0.02,
                    'AUROC %.3f' % auroc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/roc/t%d.pdf' % (options.out_dir, ti))
            plt.close()

            # PR
            plt.figure()
            prec, recall, _ = precision_recall_curve(test_targets_peaks,
                                                     test_preds_ti_flat)
            auprc = average_precision_score(test_targets_peaks,
                                            test_preds_ti_flat)
            # dashed baseline at the peak fraction (random-classifier precision)
            plt.axhline(y=test_targets_peaks.mean(),
                        c='black',
                        linewidth=1,
                        linestyle='--',
                        alpha=0.7)
            plt.plot(recall, prec, c='black')
            ax = plt.gca()
            ax.set_xlabel('Recall')
            ax.set_ylabel('Precision')
            ax.text(0.99,
                    0.95,
                    'AUPRC %.3f' % auprc,
                    horizontalalignment='right')
            ax.grid(True, linestyle=':')
            plt.savefig('%s/pr/t%d.pdf' % (options.out_dir, ti))
            plt.close()
Esempio n. 30
0
def _write_sed_row(out_fh, cols, csv):
    """Write one SED result row to an open table file.

    Args:
      out_fh: writable text file handle.
      cols: sequence of 12 values ordered like the table header (rsid, ref,
        alt, gene/tss id, tss_dist, ref_pred, alt_pred, sed, ser,
        target_index, target_id, target_label).
      csv: if True, write comma-separated ``str`` values; otherwise use the
        fixed-width whitespace layout.
    """
    if csv:
        print(','.join([str(c) for c in cols]), file=out_fh)
    else:
        print('%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s' %
              tuple(cols),
              file=out_fh)


def main():
    """Compute SNP expression difference (SED) scores for variants in a VCF.

    Single-worker form:
      <params_file> <model_file> <genes_hdf5_file> <vcf_file>
    Multi-worker form:
      <options_pkl> <params_file> <model_file> <genes_hdf5_file> <vcf_file>
      <worker_index>
    In the multi-worker form the options are loaded from the pickle and the
    output goes to ``<out_dir>/job<worker_index>``.

    For every gene sequence that overlaps SNPs, the reference and alternative
    alleles are predicted, and a gene table (plus a TSS table with ``-x``) of
    SED (alt - ref) and SER (log2 ratio with the ``--pseudo`` pseudocount)
    is written to the output directory. Rows whose SED is within 1e-4 of
    zero are omitted unless ``-a`` is given.
    """
    usage = (
        'usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
        ' <vcf_file>')
    parser = OptionParser(usage)
    parser.add_option(
        '-a',
        dest='all_sed',
        default=False,
        action='store_true',
        help=
        'Print all variant-gene pairs, as opposed to only nonzero [Default: %default]'
    )
    parser.add_option('-b',
                      dest='batch_size',
                      default=None,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sed',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=0.125,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '-r',
        dest='tss_radius',
        default=0,
        type='int',
        help=
        'Radius of bins considered to quantify TSS transcription [Default: %default]'
    )
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average the forward and reverse complement predictions when testing [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option(
        '-x',
        dest='tss_table',
        default=False,
        action='store_true',
        help='Print TSS table in addition to gene [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        # NOTE(review): pickle.load can execute arbitrary code; this file must
        # only ever come from the trusted multi-worker driver script.
        with open(options_pkl_file, 'rb') as options_pkl:
            options = pickle.load(options_pkl)

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files, genes HDF5 file, and QTL VCF'
            ' file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments; seqs_snps[seq_i] holds indexes into snps
    print('Intersecting gene sequences with SNPs...', end='')
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file,
                                         gene_data.gene_seqs,
                                         vision_p=0.5)
    print('%d sequences w/ SNPs' % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job['seq_length'] = gene_data.seq_length
    job['seq_depth'] = gene_data.seq_depth
    job['target_pool'] = gene_data.pool_width

    if 'num_targets' not in job and gene_data.num_targets is not None:
        job['num_targets'] = gene_data.num_targets

    if 'num_targets' not in job:
        print('Must specify number of targets (num_targets) in the '
              'parameters file.',
              file=sys.stderr)
        sys.exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
            target_labels = [''] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.
        # Columns used: 0 = original target index, 1 = id, 3 = label.

        target_ids = []
        target_labels = []
        target_subset = []

        with open(options.targets_file) as targets_open:
            for line in targets_open:
                a = line.strip().split('\t')
                target_subset.append(int(a[0]))
                target_ids.append(a[1])
                target_labels.append(a[3])

        # A subset listing every target is no subset at all. This check must
        # run after the loop: resetting to None mid-loop broke the next
        # append whenever the file had more rows than num_targets.
        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                   'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                   'target_label')
    if options.csv:
        header_sep, table_ext = ',', 'csv'
    else:
        header_sep, table_ext = ' ', 'txt'

    sed_gene_out = open('%s/sed_gene.%s' % (options.out_dir, table_ext), 'w')
    print(header_sep.join(header_cols), file=sed_gene_out)
    sed_tss_out = None
    if options.tss_table:
        sed_tss_out = open('%s/sed_tss.%s' % (options.out_dir, table_ext),
                           'w')
        print(header_sep.join(header_cols), file=sed_tss_out)

    try:
        # initialize saver
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # load variables into session
            saver.restore(sess, model_file)

            # for each gene sequence
            for seq_i in range(gene_data.num_seqs):
                gene_seq = gene_data.gene_seqs[seq_i]
                print(gene_seq)

                # only sequences containing SNPs need scoring
                seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
                if len(seq_snps) == 0:
                    continue

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i],
                                          seq_snps)

                # initialize batcher
                batcher_gene = batcher.Batcher(aseqs_1hot,
                                               batch_size=model.hp.batch_size)

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    penultimate=options.penultimate,
                    tss_radius=options.tss_radius)

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1))

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                tss_sed = alt_tss_preds - ref_tss_preds
                tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
                    - np.log2(ref_tss_preds + options.log_pseudo)

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius)
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i],
                                                gene_seq.tss_list,
                                                options.tss_radius)
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
                    - np.log2(ref_gene_preds + options.log_pseudo)

                # for each SNP
                for snp_i, snp in enumerate(seq_snps):

                    # minimum SNP distance to any TSS of each gene
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist)
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # print rows to TSS table
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):
                                # skip near-zero effects unless -a
                                if options.all_sed or not np.isclose(
                                        tss_sed[snp_i, tss_i, ti],
                                        0,
                                        atol=1e-4):
                                    cols = (snp.rsid,
                                            bvcf.cap_allele(snp.ref_allele),
                                            bvcf.cap_allele(
                                                snp.alt_alleles[0]),
                                            tss.identifier, snp_dist,
                                            ref_tss_preds[tss_i, ti],
                                            alt_tss_preds[snp_i, tss_i, ti],
                                            tss_sed[snp_i, tss_i, ti],
                                            tss_ser[snp_i, tss_i, ti], ti,
                                            target_ids[ti], target_labels[ti])
                                    _write_sed_row(sed_tss_out, cols,
                                                   options.csv)

                    # for each gene
                    for gi, gene_id in enumerate(gene_ids):
                        gene_str = gene_id
                        if gene_id in gene_data.multi_seq_genes:
                            gene_str = '%s_multi' % gene_id

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):
                            # skip near-zero effects unless -a
                            if options.all_sed or not np.isclose(
                                    gene_sed[snp_i, gi, ti], 0, atol=1e-4):
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str, snp_dist_gene[gene_id],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti], ti,
                                    target_ids[ti], target_labels[ti]
                                ]
                                _write_sed_row(sed_gene_out, cols,
                                               options.csv)

                # clean up
                gc.collect()

    finally:
        # close the tables even if prediction fails partway through
        sed_gene_out.close()
        if sed_tss_out is not None:
            sed_tss_out.close()