Example 1
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option('-c', dest='center_pct',
      default=0.25, type='float',
      help='Require clustered SNPs to lie in the center region [Default: %default]')
  parser.add_option('-f', dest='genome_fasta',
      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
      help='Genome FASTA for sequences [Default: %default]')
  parser.add_option('--flip', dest='flip_ref',
      default=False, action='store_true',
      help='Flip reference/alternate alleles when simple [Default: %default]')
  parser.add_option('-n', dest='norm_file',
      default=None,
      help='Normalize SAD scores')
  parser.add_option('-o', dest='out_dir',
      default='sad',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option('-p', dest='processes',
      default=None, type='int',
      help='Number of processes, passed by multi script')
  parser.add_option('--pseudo', dest='log_pseudo',
      default=1, type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average forward and reverse complement predictions [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0', type='str',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('--stats', dest='sad_stats',
      default='SAD',
      help='Comma-separated list of stats to save. [Default: %default]')
  parser.add_option('-t', dest='targets_file',
      default=None, type='str',
      help='File specifying target indexes and labels in table format')
  parser.add_option('--ti', dest='track_indexes',
      default=None, type='str',
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option('--threads', dest='threads',
      default=False, action='store_true',
      help='Run CPU math and output in a separate thread [Default: %default]')
  parser.add_option('-u', dest='penultimate',
      default=False, action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 3:
    # single worker
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  elif len(args) == 5:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    vcf_file = args[3]
    worker_index = int(args[4])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters file, model file, and QTL VCF file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  if options.track_indexes is None:
    options.track_indexes = []
  else:
    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]
  options.sad_stats = options.sad_stats.split(',')


  #################################################################
  # read parameters and targets

  # read model parameters
  with open(params_file) as params_open:
    params = json.load(params_open)
  params_model = params['model']
  params_train = params['train']

  if options.targets_file is None:
    target_slice = None
  else:
    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
    target_ids = targets_df.identifier
    target_labels = targets_df.description
    target_slice = targets_df.index

  if options.penultimate:
    parser.error('Not implemented for TF2')

  #################################################################
  # setup model

  seqnn_model = seqnn.SeqNN(params_model)
  seqnn_model.restore(model_file)
  seqnn_model.build_slice(target_slice)
  seqnn_model.build_ensemble(options.rc, options.shifts)

  num_targets = seqnn_model.num_targets()
  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(num_targets)]
    target_labels = ['']*len(target_ids)

  #################################################################
  # load SNPs

  # filter for worker SNPs
  if options.processes is not None:
    # determine boundaries
    num_snps = bvcf.vcf_count(vcf_file)
    worker_bounds = np.linspace(0, num_snps, options.processes+1, dtype='int')

    # read sorted SNPs from VCF
    snps = bvcf.vcf_snps(vcf_file, require_sorted=True, flip_ref=options.flip_ref,
                         validate_ref_fasta=options.genome_fasta,
                         start_i=worker_bounds[worker_index],
                         end_i=worker_bounds[worker_index+1])
  else:
    # read sorted SNPs from VCF
    snps = bvcf.vcf_snps(vcf_file, require_sorted=True, flip_ref=options.flip_ref,
                         validate_ref_fasta=options.genome_fasta)

  # cluster SNPs by position
  snp_clusters = cluster_snps(snps, params_model['seq_length'], options.center_pct)

  # delimit sequence boundaries
  for sc in snp_clusters:
    sc.delimit(params_model['seq_length'])

  # open genome FASTA
  genome_open = pysam.Fastafile(options.genome_fasta)

  # make SNP sequence generator
  def snp_gen():
    for sc in snp_clusters:
      snp_1hot_list = sc.get_1hots(genome_open)
      for snp_1hot in snp_1hot_list:
        yield snp_1hot


  #################################################################
  # setup output

  snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

  sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                 snps, target_ids, target_labels)

  if options.threads:
    snp_threads = []
    snp_queue = Queue()
    # a single worker thread; widen the range to add more
    for i in range(1):
      sw = SNPWorker(snp_queue, sad_out, options.sad_stats, options.log_pseudo)
      sw.start()
      snp_threads.append(sw)


  #################################################################
  # predict SNP scores, write output

  # initialize predictions stream
  preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(), params_train['batch_size'])

  # predictions index
  pi = 0

  # SNP index
  si = 0

  for snp_cluster in snp_clusters:
    ref_preds = preds_stream[pi]
    pi += 1

    for snp in snp_cluster.snps:
      # print(snp, flush=True)

      alt_preds = preds_stream[pi]
      pi += 1

      if snp_flips[si]:
        ref_preds, alt_preds = alt_preds, ref_preds

      if options.threads:
        # queue SNP
        snp_queue.put((ref_preds, alt_preds, si))
      else:
        # process SNP
        write_snp(ref_preds, alt_preds, sad_out, si,
                  options.sad_stats, options.log_pseudo)

      # update SNP index
      si += 1

  # finish queue
  if options.threads:
    print('Waiting for threads to finish.', flush=True)
    snp_queue.join()

  # close genome
  genome_open.close()

  ###################################################
  # compute SAD distributions across variants

  write_pct(sad_out, options.sad_stats)
  sad_out.close()
Example 2
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option('-c', dest='csv',
      default=False, action='store_true',
      help='Print table as CSV [Default: %default]')
  parser.add_option('--cpu', dest='cpu',
      default=False, action='store_true',
      help='Run without a GPU [Default: %default]')
  parser.add_option('-f', dest='genome_fasta',
      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
      help='Genome FASTA for sequences [Default: %default]')
  parser.add_option('-g', dest='genome_file',
      default='%s/data/human.hg19.genome' % os.environ['BASENJIDIR'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option('--h5', dest='out_h5',
      default=False, action='store_true',
      help='Output stats to sad.h5 [Default: %default]')
  parser.add_option('--local', dest='local',
      default=1024, type='int',
      help='Local SAD score [Default: %default]')
  parser.add_option('-n', dest='norm_file',
      default=None,
      help='Normalize SAD scores')
  parser.add_option('-o', dest='out_dir',
      default='sad',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option('-p', dest='processes',
      default=None, type='int',
      help='Number of processes, passed by multi script')
  parser.add_option('--pseudo', dest='log_pseudo',
      default=1, type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option('--rc', dest='rc',
      default=False, action='store_true',
      help='Average forward and reverse complement predictions [Default: %default]')
  parser.add_option('--shifts', dest='shifts',
      default='0', type='str',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option('--stats', dest='sad_stats',
      default='SAD',
      help='Comma-separated list of stats to save. [Default: %default]')
  parser.add_option('-t', dest='targets_file',
      default=None, type='str',
      help='File specifying target indexes and labels in table format')
  parser.add_option('--ti', dest='track_indexes',
      default=None, type='str',
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option('-u', dest='penultimate',
      default=False, action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  parser.add_option('-z', dest='out_zarr',
      default=False, action='store_true',
      help='Output stats to sad.zarr [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 3:
    # single worker
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  elif len(args) == 5:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    vcf_file = args[3]
    worker_index = int(args[4])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters file, model file, and QTL VCF file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  if options.track_indexes is None:
    options.track_indexes = []
  else:
    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]
  options.sad_stats = options.sad_stats.split(',')


  #################################################################
  # read parameters

  job = params.read_job_params(params_file, require=['seq_length','num_targets'])

  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = ['']*len(target_ids)
    target_subset = None

  else:
    targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
    target_ids = targets_df.identifier
    target_labels = targets_df.description
    target_subset = targets_df.index
    if len(target_subset) == job['num_targets']:
        target_subset = None


  #################################################################
  # load SNPs

  snps = bvcf.vcf_snps(vcf_file)

  # filter for worker SNPs
  if options.processes is not None:
    worker_bounds = np.linspace(0, len(snps), options.processes+1, dtype='int')
    snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index+1]]

  num_snps = len(snps)

  # open genome FASTA
  genome_open = pysam.Fastafile(options.genome_fasta)

  def snp_gen():
    for snp in snps:
      # get SNP sequences
      snp_1hot_list = bvcf.snp_seq1(snp, job['seq_length'], genome_open)

      for snp_1hot in snp_1hot_list:
        yield {'sequence':snp_1hot}

  snp_types = {'sequence': tf.float32}
  snp_shapes = {'sequence': tf.TensorShape([tf.Dimension(job['seq_length']),
                                            tf.Dimension(4)])}

  dataset = tf.data.Dataset.from_generator(snp_gen,
                                           output_types=snp_types,
                                           output_shapes=snp_shapes)
  dataset = dataset.batch(job['batch_size'])
  dataset = dataset.prefetch(2*job['batch_size'])
  if not options.cpu:
    dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/device:GPU:0'))

  iterator = dataset.make_one_shot_iterator()
  data_ops = iterator.get_next()


  #################################################################
  # setup model

  # build model
  t0 = time.time()
  model = seqnn.SeqNN()
  model.build_sad(job, data_ops,
                  ensemble_rc=options.rc, ensemble_shifts=options.shifts,
                  embed_penultimate=options.penultimate, target_subset=target_subset)
  print('Model building time %f' % (time.time() - t0), flush=True)

  if options.penultimate:
    # labels become inappropriate
    target_ids = ['']*model.hp.cnn_filters[-1]
    target_labels = target_ids

  # read target normalization factors
  target_norms = np.ones(len(target_labels))
  if options.norm_file is not None:
    ti = 0
    for line in open(options.norm_file):
      target_norms[ti] = float(line.strip())
      ti += 1

  num_targets = len(target_ids)

  #################################################################
  # setup output

  if options.out_h5:
    sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                   snps, target_ids, target_labels)

  elif options.out_zarr:
    sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                     snps, target_ids, target_labels)

  else:
    header_cols = ('rsid', 'ref', 'alt',
                  'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad',
                  'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                  'ref_xpred', 'alt_xpred', 'xsad', 'xsar',
                  'target_index', 'target_id', 'target_label')

    if options.csv:
      sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
      print(','.join(header_cols), file=sad_out)
    else:
      sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
      print(' '.join(header_cols), file=sad_out)


  #################################################################
  # process

  szi = 0
  sum_write_thread = None
  sw_batch_size = 32 // job['batch_size']  # model batches per summarize/write chunk

  # initialize saver
  saver = tf.train.Saver()
  with tf.Session() as sess:
    # coordinator
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(coord=coord)

    # load variables into session
    saver.restore(sess, model_file)

    # predict first
    batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

    while batch_preds.shape[0] > 0:
      # count predicted SNPs
      num_snps = batch_preds.shape[0] // 2

      # normalize
      batch_preds /= target_norms

      # block for last thread
      if sum_write_thread is not None:
        sum_write_thread.join()

      # summarize and write
      sum_write_thread = threading.Thread(target=summarize_write,
            args=(batch_preds, sad_out, szi, options.sad_stats, options.log_pseudo))
      sum_write_thread.start()

      # update SNP index
      szi += num_snps

      # predict next
      batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

  if sum_write_thread is not None:
    sum_write_thread.join()

  ###################################################
  # compute SAD distributions across variants

  if options.out_h5 or options.out_zarr:
    # define percentiles
    d_fine = 0.001
    d_coarse = 0.01
    percentiles_neg = np.arange(d_fine, 0.1, d_fine)
    percentiles_base = np.arange(0.1, 0.9, d_coarse)
    percentiles_pos = np.arange(0.9, 1, d_fine)

    percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos])
    sad_out.create_dataset('percentiles', data=percentiles)
    pct_len = len(percentiles)

    for sad_stat in options.sad_stats:
      sad_stat_pct = '%s_pct' % sad_stat

      # compute
      sad_pct = np.percentile(sad_out[sad_stat], 100*percentiles, axis=0).T
      sad_pct = sad_pct.astype('float16')

      # save
      sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

  if not options.out_zarr:
    sad_out.close()
Example 3
def main():
    usage = 'usage: %prog [options] <model> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('--species', dest='species', default='human')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    (options, args) = parser.parse_args()

    if len(args) == 2:
        # single worker
        model_file = args[0]
        vcf_file = args[1]

    elif len(args) == 3:
        # multi separate
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

        # save out dir
        out_dir = options.out_dir

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = out_dir

    elif len(args) == 4:
        # multi worker
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]
        worker_index = int(args[3])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide model and VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    # '-u/--penultimate' is not defined by this parser; it can only arrive
    # via the pickled options in multi-worker mode, so guard the lookup
    if getattr(options, 'penultimate', False):
        parser.error('Not implemented for TF2')

    #################################################################
    # setup model

    seqnn_model = tf.saved_model.load(model_file).model

    if options.targets_file is None:
        num_targets = 5313  # assumes the human model head (5313 targets)
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    seq_length = seqnn_model.predict_on_batch.input_signature[0].shape[1]

    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, seq_length, genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = PredStreamGen(seqnn_model,
                                 snp_gen(),
                                 rc=options.rc,
                                 shifts=options.shifts,
                                 species=options.species)

    # predictions index
    pi = 0

    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                  options.log_pseudo)

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
Example 4
def main():
    usage = 'usage: %prog [options] <model> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='mut_down',
        default=0,
        type='int',
        help=
        'Nucleotides downstream of center sequence to mutate [Default: %default]'
    )
    parser.add_option('-f',
                      dest='figure_width',
                      default=20,
                      type='float',
                      help='Figure width [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
        help='Genome FASTA from which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA from which minor allele sequences will be drawn')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=200,
        type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_vcf',
                      help='Output directory [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('--species', dest='species', default='human')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='sum',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='mut_up',
        default=0,
        type='int',
        help=
        'Nucleotides upstream of center sequence to mutate [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 2:
        parser.error('Must provide model and VCF')
    else:
        model_file = args[0]
        vcf_file = args[1]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = [
        sad_stat.lower() for sad_stat in options.sad_stats.split(',')
    ]

    if options.mut_up > 0 or options.mut_down > 0:
        options.mut_len = options.mut_up + options.mut_down
    else:
        assert (options.mut_len > 0)
        options.mut_up = options.mut_len // 2
        options.mut_down = options.mut_len - options.mut_up

    #################################################################
    # read parameters and targets

    # read targets
    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = tf.saved_model.load(model_file).model

    # query num model targets
    seq_length = seqnn_model.predict_on_batch.input_signature[0].shape[1]
    null_1hot = np.zeros((1, seq_length, 4))
    null_preds = seqnn_model.predict_on_batch(null_1hot)
    null_preds = null_preds[options.species].numpy()
    _, preds_length, num_targets = null_preds.shape

    #################################################################
    # SNP sequence dataset

    # load SNPs
    snps = vcf.vcf_snps(vcf_file)

    # get one hot coded input sequences
    if not options.genome2_fasta:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps_seq1(
            snps, seq_length, options.genome1_fasta, return_seqs=True)
    else:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps2_seq1(
            snps,
            seq_length,
            options.genome1_fasta,
            options.genome2_fasta,
            return_seqs=True)
    num_seqs = seqs_1hot.shape[0]

    # determine mutation region limits
    seq_mid = seq_length // 2
    mut_start = seq_mid - options.mut_up
    mut_end = mut_start + options.mut_len

    # make sequence generator
    seqs_gen = satmut_gen(seqs_dna, mut_start, mut_end)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('label', data=np.array(seq_headers, dtype='S'))
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))
    for sad_stat in options.sad_stats:
        scores_h5.create_dataset(sad_stat,
                                 dtype='float16',
                                 shape=(num_seqs, options.mut_len, 4,
                                        num_targets))

    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
    # a single scoring thread; widen the range to add more
    for i in range(1):
        sw = ScoreWorker(score_queue, scores_h5, options.sad_stats, mut_start,
                         mut_end)
        sw.start()
        score_threads.append(sw)

    #################################################################
    # predict scores and write output

    # find center
    center_start = preds_length // 2
    if preds_length % 2 == 0:
        center_end = center_start + 2
    else:
        center_end = center_start + 1

    # initialize predictions stream
    preds_stream = PredStreamGen(seqnn_model,
                                 seqs_gen,
                                 rc=options.rc,
                                 shifts=options.shifts,
                                 species=options.species)

    # predictions index
    pi = 0

    for si in range(num_seqs):
        print('Predicting %d' % si, flush=True)

        # collect sequence predictions
        seq_preds_sum = []
        seq_preds_center = []
        seq_preds_scd = []
        # the first stream entry is the unmutated sequence; keep it for SCD
        preds_mut0 = preds_stream[pi]
        for spi in range(preds_per_seq):
            preds_mut = preds_stream[pi]
            preds_sum = preds_mut.sum(axis=0)
            seq_preds_sum.append(preds_sum)
            if 'center' in options.sad_stats:
                preds_center = preds_mut[center_start:center_end, :].sum(
                    axis=0)
                seq_preds_center.append(preds_center)
            if 'scd' in options.sad_stats:
                preds_scd = np.sqrt(((preds_mut - preds_mut0)**2).sum(axis=0))
                seq_preds_scd.append(preds_scd)
            pi += 1
        seq_preds_sum = np.array(seq_preds_sum)
        seq_preds_center = np.array(seq_preds_center)
        seq_preds_scd = np.array(seq_preds_scd)

        # wait for previous to finish
        score_queue.join()

        # queue sequence for scoring
        seq_pred_stats = (seq_preds_sum, seq_preds_center, seq_preds_scd)
        score_queue.put((seqs_dna[si], seq_pred_stats, si))

        gc.collect()

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
Example 5
    def test_get_1hots(self):
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(self.vcf_file, require_sorted=True)

        # cluster SNPs by position
        snp_clusters = basenji_sad_ref.cluster_snps(snps,
                                                    self.params["seq_length"],
                                                    0.25)

        # delimit sequence boundaries
        for sc in snp_clusters:
            sc.delimit(self.params["seq_length"])

        # open genome FASTA
        genome_open = pysam.Fastafile(self.genome_fasta)

        ########################################
        # verify single SNP

        # get 1 hot coded sequences
        snp_1hot_list = snp_clusters[0].get_1hots(genome_open)

        self.assertEqual(len(snp_1hot_list), 2)
        self.assertEqual(snp_1hot_list[1].shape,
                         (self.params["seq_length"], 4))

        mid_i = self.params["seq_length"] // 2 - 1
        self.assertEqual(mid_i, snps[0].seq_pos)

        ref_nt = dna_io.hot1_get(snp_1hot_list[0], mid_i)
        self.assertEqual(ref_nt, snps[0].ref_allele)

        alt_nt = dna_io.hot1_get(snp_1hot_list[1], mid_i)
        self.assertEqual(alt_nt, snps[0].alt_alleles[0])

        ########################################
        # verify multiple SNPs

        # get 1 hot coded sequences
        snp_1hot_list = snp_clusters[6].get_1hots(genome_open)

        self.assertEqual(len(snp_1hot_list), 3)

        snp1, snp2 = snps[6:8]

        # verify position 1 changes between 0 and 1
        nt = dna_io.hot1_get(snp_1hot_list[0], snp1.seq_pos)
        self.assertEqual(nt, snp1.ref_allele)

        nt = dna_io.hot1_get(snp_1hot_list[1], snp1.seq_pos)
        self.assertEqual(nt, snp1.alt_alleles[0])

        # verify position 2 is unchanged between 0 and 1
        nt = dna_io.hot1_get(snp_1hot_list[0], snp2.seq_pos)
        self.assertEqual(nt, snp2.ref_allele)

        nt = dna_io.hot1_get(snp_1hot_list[1], snp2.seq_pos)
        self.assertEqual(nt, snp2.ref_allele)

        # verify position 1 is unchanged between 0 and 2
        nt = dna_io.hot1_get(snp_1hot_list[0], snp1.seq_pos)
        self.assertEqual(nt, snp1.ref_allele)

        nt = dna_io.hot1_get(snp_1hot_list[2], snp1.seq_pos)
        self.assertEqual(nt, snp1.ref_allele)

        # verify position 2 changes between 0 and 2
        nt = dna_io.hot1_get(snp_1hot_list[0], snp2.seq_pos)
        self.assertEqual(nt, snp2.ref_allele)

        nt = dna_io.hot1_get(snp_1hot_list[2], snp2.seq_pos)
        self.assertEqual(nt, snp2.alt_alleles[0])
Example 6
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs to lie in the center region [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/data/human.hg19.genome' %
                      os.environ['BASENJIDIR'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option('--h5',
                      dest='out_h5',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.h5 [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    # parser.add_option('-z', dest='out_zarr',
    #     default=False, action='store_true',
    #     help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters file, model file, and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and collect target information

    job = params.read_job_params(params_file,
                                 require=['seq_length', 'num_targets'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    #################################################################
    # load SNPs

    # read sorted SNPs from VCF
    snps = bvcf.vcf_snps(vcf_file,
                         require_sorted=True,
                         flip_ref=True,
                         validate_ref_fasta=options.genome_fasta)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype='int')
        snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index +
                                                              1]]

    num_snps = len(snps)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, job['seq_length'], options.center_pct)

    # delimit sequence boundaries
    for sc in snp_clusters:
        sc.delimit(job['seq_length'])

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape([tf.Dimension(job['seq_length']),
                        tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(job['batch_size'])
    dataset = dataset.prefetch(2 * job['batch_size'])
    # dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/device:GPU:0'))

    iterator = dataset.make_one_shot_iterator()
    data_ops = iterator.get_next()

    #################################################################
    # setup model

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # setup output

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    snp_threads = []

    snp_queue = Queue()
    # a single worker thread; widen the range to add more
    for i in range(1):
        sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                       options.log_pseudo)
        sw.start()
        snp_threads.append(sw)

    #################################################################
    # predict SNP scores, write output

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # coordinator
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord)

        # load variables into session
        saver.restore(sess, model_file)

        # initialize predictions stream
        preds_stream = PredStream(sess, model, 32)

        # predictions index
        pi = 0

        # SNP index
        si = 0

        for snp_cluster in snp_clusters:
            ref_preds = preds_stream[pi]
            pi += 1

            for snp in snp_cluster.snps:
                # print(snp, flush=True)

                alt_preds = preds_stream[pi]
                pi += 1

                # queue SNP
                snp_queue.put((ref_preds, alt_preds, si))

                # update SNP index
                si += 1

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    # define percentiles
    d_fine = 0.001
    d_coarse = 0.01
    percentiles_neg = np.arange(d_fine, 0.1, d_fine)
    percentiles_base = np.arange(0.1, 0.9, d_coarse)
    percentiles_pos = np.arange(0.9, 1, d_fine)

    percentiles = np.concatenate(
        [percentiles_neg, percentiles_base, percentiles_pos])
    sad_out.create_dataset('percentiles', data=percentiles)
    pct_len = len(percentiles)

    for sad_stat in options.sad_stats:
        sad_stat_pct = '%s_pct' % sad_stat

        # compute
        sad_pct = np.percentile(sad_out[sad_stat], 100 * percentiles, axis=0).T
        sad_pct = sad_pct.astype('float16')

        # save
        sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    sad_out.close()
Example 7
def main():
    usage = (
        "usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>"
        " <vcf_file>"
    )
    parser = OptionParser(usage)
    parser.add_option(
        "-a",
        dest="all_sed",
        default=False,
        action="store_true",
        help="Print all variant-gene pairs, as opposed to only nonzero [Default: %default]",
    )
    parser.add_option(
        "-b",
        dest="batch_size",
        default=None,
        type="int",
        help="Batch size [Default: %default]",
    )
    parser.add_option(
        "-c",
        dest="csv",
        default=False,
        action="store_true",
        help="Print table as CSV [Default: %default]",
    )
    parser.add_option(
        "-g",
        dest="genome_file",
        default="%s/data/human.hg19.genome" % os.environ["BASENJIDIR"],
        help="Chromosome lengths file [Default: %default]",
    )
    parser.add_option(
        "-o",
        dest="out_dir",
        default="sed",
        help="Output directory for tables and plots [Default: %default]",
    )
    parser.add_option(
        "-p",
        dest="processes",
        default=None,
        type="int",
        help="Number of processes, passed by multi script",
    )
    parser.add_option(
        "--pseudo",
        dest="log_pseudo",
        default=0.125,
        type="float",
        help="Log2 pseudocount [Default: %default]",
    )
    parser.add_option(
        "-r",
        dest="tss_radius",
        default=0,
        type="int",
        help="Radius of bins considered to quantify TSS transcription [Default: %default]",
    )
    parser.add_option(
        "--rc",
        dest="rc",
        default=False,
        action="store_true",
        help="Average the forward and reverse complement predictions when testing [Default: %default]",
    )
    parser.add_option(
        "--shifts",
        dest="shifts",
        default="0",
        help="Ensemble prediction shifts [Default: %default]",
    )
    parser.add_option(
        "-t",
        dest="targets_file",
        default=None,
        help="File specifying target indexes and labels in table format",
    )
    parser.add_option(
        "-u",
        dest="penultimate",
        default=False,
        action="store_true",
        help="Compute SED in the penultimate layer [Default: %default]",
    )
    parser.add_option(
        "-x",
        dest="tss_table",
        default=False,
        action="store_true",
        help="Print TSS table in addition to gene [Default: %default]",
    )
    (options, args) = parser.parse_args()

    if len(args) == 4:
        # single worker
        params_file = args[0]
        model_file = args[1]
        genes_hdf5_file = args[2]
        vcf_file = args[3]

    elif len(args) == 6:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        genes_hdf5_file = args[3]
        vcf_file = args[4]
        worker_index = int(args[5])

        # load options
        options_pkl = open(options_pkl_file, "rb")
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = "%s/job%d" % (options.out_dir, worker_index)

    else:
        parser.error(
            "Must provide parameters and model files, genes HDF5 file, and QTL VCF"
            " file"
        )

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(",")]

    #################################################################
    # reads in genes HDF5

    gene_data = genedata.GeneData(genes_hdf5_file)

    # filter for worker sequences
    if options.processes is not None:
        gene_data.worker(worker_index, options.processes)

    #################################################################
    # prep SNPs

    # load SNPs
    snps = bvcf.vcf_snps(vcf_file)

    # intersect w/ segments
    print("Intersecting gene sequences with SNPs...", end="")
    sys.stdout.flush()
    seqs_snps = bvcf.intersect_seqs_snps(vcf_file, gene_data.gene_seqs, vision_p=0.5)
    print("%d sequences w/ SNPs" % len(seqs_snps))

    #################################################################
    # setup model

    job = params.read_job_params(params_file)

    job["seq_length"] = gene_data.seq_length
    job["seq_depth"] = gene_data.seq_depth
    job["target_pool"] = gene_data.pool_width

    if "num_targets" not in job and gene_data.num_targets is not None:
        job["num_targets"] = gene_data.num_targets

    if "num_targets" not in job:
        print(
            "Must specify number of targets (num_targets) in the"
            " parameters file.",
            file=sys.stderr,
        )
        exit(1)

    if options.targets_file is None:
        target_labels = gene_data.target_labels
        target_subset = None
        target_ids = gene_data.target_ids
        if target_ids is None:
            target_ids = ["t%d" % ti for ti in range(job["num_targets"])]
            target_labels = [""] * len(target_ids)

    else:
        # Unfortunately, this target file differs from some others
        # in that it needs to specify the indexes from the original
        # set. In the future, I should standardize to this version.

        target_ids = []
        target_labels = []
        target_subset = []

        for line in open(options.targets_file):
            a = line.strip().split("\t")
            target_subset.append(int(a[0]))
            target_ids.append(a[1])
            target_labels.append(a[3])

        if len(target_subset) == job["num_targets"]:
            target_subset = None

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job, target_subset=target_subset)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [""] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    #################################################################
    # compute, collect, and print SEDs

    header_cols = (
        "rsid",
        "ref",
        "alt",
        "gene",
        "tss_dist",
        "ref_pred",
        "alt_pred",
        "sed",
        "ser",
        "target_index",
        "target_id",
        "target_label",
    )
    if options.csv:
        sed_gene_out = open("%s/sed_gene.csv" % options.out_dir, "w")
        print(",".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.csv" % options.out_dir, "w")
            print(",".join(header_cols), file=sed_tss_out)

    else:
        sed_gene_out = open("%s/sed_gene.txt" % options.out_dir, "w")
        print(" ".join(header_cols), file=sed_gene_out)
        if options.tss_table:
            sed_tss_out = open("%s/sed_tss.txt" % options.out_dir, "w")
            print(" ".join(header_cols), file=sed_tss_out)

    # helper variables
    pred_buffer = model.hp.batch_buffer // model.hp.target_pool

    # initialize saver
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # for each gene sequence
        for seq_i in range(gene_data.num_seqs):
            gene_seq = gene_data.gene_seqs[seq_i]
            print(gene_seq)

            # if it contains SNPs
            seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
            if len(seq_snps) > 0:

                # one hot code allele sequences
                aseqs_1hot = alleles_1hot(
                    gene_seq, gene_data.seqs_1hot[seq_i], seq_snps
                )

                # initialize batcher
                batcher_gene = batcher.Batcher(
                    aseqs_1hot, batch_size=model.hp.batch_size
                )

                # construct allele gene_seq's
                allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

                # predict alleles
                allele_tss_preds = model.predict_genes(
                    sess,
                    batcher_gene,
                    allele_gene_seqs,
                    rc=options.rc,
                    shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    tss_radius=options.tss_radius,
                )

                # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
                allele_tss_preds = allele_tss_preds.reshape(
                    (aseqs_1hot.shape[0], gene_seq.num_tss, -1)
                )

                # extract reference and SNP alt predictions
                ref_tss_preds = allele_tss_preds[0]  # TSSs x Targets
                alt_tss_preds = allele_tss_preds[1:]  # SNPs x TSSs x Targets

                # compute TSS SED scores
                snp_tss_sed = alt_tss_preds - ref_tss_preds
                snp_tss_ser = np.log2(alt_tss_preds + options.log_pseudo) - np.log2(
                    ref_tss_preds + options.log_pseudo
                )

                # compute gene-level predictions
                ref_gene_preds, gene_ids = gene.map_tss_genes(
                    ref_tss_preds, gene_seq.tss_list, options.tss_radius
                )
                alt_gene_preds = []
                for snp_i in range(len(seq_snps)):
                    agp, _ = gene.map_tss_genes(
                        alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius
                    )
                    alt_gene_preds.append(agp)
                alt_gene_preds = np.array(alt_gene_preds)

                # compute gene SED scores
                gene_sed = alt_gene_preds - ref_gene_preds
                gene_ser = np.log2(alt_gene_preds + options.log_pseudo) - np.log2(
                    ref_gene_preds + options.log_pseudo
                )

                # for each SNP
                for snp_i in range(len(seq_snps)):
                    snp = seq_snps[snp_i]

                    # initialize gene data structures
                    snp_dist_gene = {}

                    # for each TSS
                    for tss_i in range(gene_seq.num_tss):
                        tss = gene_seq.tss_list[tss_i]

                        # SNP distance to TSS
                        snp_dist = abs(snp.pos - tss.pos)
                        if tss.gene_id in snp_dist_gene:
                            snp_dist_gene[tss.gene_id] = min(
                                snp_dist_gene[tss.gene_id], snp_dist
                            )
                        else:
                            snp_dist_gene[tss.gene_id] = snp_dist

                        # for each target
                        if options.tss_table:
                            for ti in range(ref_tss_preds.shape[1]):

                                # check if nonzero
                                if options.all_sed or not np.isclose(
                                    snp_tss_sed[snp_i, tss_i, ti], 0, atol=1e-4
                                ):

                                    # print
                                    cols = (
                                        snp.rsid,
                                        bvcf.cap_allele(snp.ref_allele),
                                        bvcf.cap_allele(snp.alt_alleles[0]),
                                        tss.identifier,
                                        snp_dist,
                                        ref_tss_preds[tss_i, ti],
                                        alt_tss_preds[snp_i, tss_i, ti],
                                        snp_tss_sed[snp_i, tss_i, ti],
                                        snp_tss_ser[snp_i, tss_i, ti],
                                        ti,
                                        target_ids[ti],
                                        target_labels[ti],
                                    )
                                    if options.csv:
                                        print(
                                            ",".join([str(c) for c in cols]),
                                            file=sed_tss_out,
                                        )
                                    else:
                                        print(
                                            "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                            % cols,
                                            file=sed_tss_out,
                                        )

                    # for each gene
                    for gi in range(len(gene_ids)):
                        gene_str = gene_ids[gi]
                        if gene_ids[gi] in gene_data.multi_seq_genes:
                            gene_str = "%s_multi" % gene_ids[gi]

                        # print rows to gene table
                        for ti in range(ref_gene_preds.shape[1]):

                            # check if nonzero
                            if options.all_sed or not np.isclose(
                                gene_sed[snp_i, gi, ti], 0, atol=1e-4
                            ):

                                # print
                                cols = [
                                    snp.rsid,
                                    bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(snp.alt_alleles[0]),
                                    gene_str,
                                    snp_dist_gene[gene_ids[gi]],
                                    ref_gene_preds[gi, ti],
                                    alt_gene_preds[snp_i, gi, ti],
                                    gene_sed[snp_i, gi, ti],
                                    gene_ser[snp_i, gi, ti],
                                    ti,
                                    target_ids[ti],
                                    target_labels[ti],
                                ]
                                if options.csv:
                                    print(
                                        ",".join([str(c) for c in cols]),
                                        file=sed_gene_out,
                                    )
                                else:
                                    print(
                                        "%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s"
                                        % tuple(cols),
                                        file=sed_gene_out,
                                    )

                # clean up
                gc.collect()

    sed_gene_out.close()
    if options.tss_table:
        sed_tss_out.close()
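
The gene-level scoring in the example above is plain NumPy broadcasting: SED is the raw difference between alternate- and reference-allele predictions, and SER is the log2 ratio stabilized by a pseudocount. A minimal standalone sketch of just that step (the shapes below are illustrative assumptions, not the script's real dimensions):

import numpy as np

log_pseudo = 1.0
ref_gene_preds = np.random.rand(3, 4)     # genes x targets (assumed shape)
alt_gene_preds = np.random.rand(2, 3, 4)  # SNPs x genes x targets (assumed shape)

# broadcasting subtracts the reference from every SNP's prediction matrix
gene_sed = alt_gene_preds - ref_gene_preds
gene_ser = np.log2(alt_gene_preds + log_pseudo) \
           - np.log2(ref_gene_preds + log_pseudo)

# rows are only written when the score is meaningfully nonzero
nonzero_mask = ~np.isclose(gene_sed, 0, atol=1e-4)
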
Example #8
def main():
  usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-b',
      dest='batch_size',
      default=256,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-c',
      dest='csv',
      default=False,
      action='store_true',
      help='Print table as CSV [Default: %default]')
  parser.add_option(
      '-f',
      dest='genome_fasta',
      default='%s/assembly/hg19.fa' % os.environ['HG19'],
      help='Genome FASTA from which sequences will be drawn [Default: %default]'
  )
  parser.add_option(
      '-g',
      dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option(
      '-l',
      dest='seq_len',
      type='int',
      default=131072,
      help='Sequence length provided to the model [Default: %default]')
  parser.add_option(
      '--local',
      dest='local',
      default=1024,
      type='int',
      help='Local SAD score [Default: %default]')
  parser.add_option(
      '-n',
      dest='norm_file',
      default=None,
      help='Normalize SAD scores')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='sad',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-p',
      dest='processes',
      default=None,
      type='int',
      help='Number of processes, passed by multi script')
  parser.add_option(
      '--pseudo',
      dest='log_pseudo',
      default=1,
      type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]'
  )
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t',
      dest='targets_file',
      default=None,
      help='File specifying target indexes and labels in table format')
  parser.add_option(
      '--ti',
      dest='track_indexes',
      default=None,
      help='Comma-separated list of target indexes to output BigWig tracks')
  parser.add_option(
      '-u',
      dest='penultimate',
      default=False,
      action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 3:
    # single worker
    params_file = args[0]
    model_file = args[1]
    vcf_file = args[2]

  elif len(args) == 5:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    vcf_file = args[3]
    worker_index = int(args[4])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error('Must provide parameters and model files and QTL VCF file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  if options.track_indexes is None:
    options.track_indexes = []
  else:
    options.track_indexes = [int(ti) for ti in options.track_indexes.split(',')]
    if not os.path.isdir('%s/tracks' % options.out_dir):
      os.mkdir('%s/tracks' % options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # setup model

  job = basenji.dna_io.read_job_params(params_file)
  job['seq_length'] = options.seq_len

  if 'num_targets' not in job:
    print(
        "Must specify number of targets (num_targets) in the parameters file.",
        file=sys.stderr)
    exit(1)

  if 'target_pool' not in job:
    print(
        "Must specify target pooling (target_pool) in the parameters file.",
        file=sys.stderr)
    exit(1)

  if options.targets_file is None:
    target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
    target_labels = ['']*len(target_ids)
    target_subset = None

  else:
    # Unfortunately, this target file differs from some others
    # in that it needs to specify the indexes from the original
    # set. In the future, I should standardize to this version.

    target_ids = []
    target_labels = []
    target_subset = []

    for line in open(options.targets_file):
      a = line.strip().split('\t')
      target_subset.append(int(a[0]))
      target_ids.append(a[1])
      target_labels.append(a[3])

    # a subset covering all targets is equivalent to no subsetting
    if len(target_subset) == job['num_targets']:
      target_subset = None

  # build model
  t0 = time.time()
  model = basenji.seqnn.SeqNN()
  model.build(job, target_subset=target_subset)
  print('Model building time %f' % (time.time() - t0), flush=True)

  if options.penultimate:
    # labels become inappropriate
    target_ids = ['']*model.cnn_filters[-1]
    target_labels = target_ids

  # read target normalization factors
  target_norms = np.ones(len(target_labels))
  if options.norm_file is not None:
    ti = 0
    for line in open(options.norm_file):
      target_norms[ti] = float(line.strip())
      ti += 1


  #################################################################
  # load SNPs

  snps = bvcf.vcf_snps(vcf_file)

  # filter for worker SNPs
  if options.processes is not None:
    snps = [
        snps[si] for si in range(len(snps))
        if si % options.processes == worker_index
    ]

  #################################################################
  # setup output

  header_cols = ('rsid', 'ref', 'alt',
                  'ref_pred', 'alt_pred', 'sad', 'sar', 'geo_sad',
                  'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                  'ref_xpred', 'alt_xpred', 'xsad', 'xsar',
                  'target_index', 'target_id', 'target_label')

  if options.csv:
    sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
    print(','.join(header_cols), file=sad_out)
  else:
    sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
    print(' '.join(header_cols), file=sad_out)


  #################################################################
  # process

  # open genome FASTA
  genome_open = pysam.Fastafile(options.genome_fasta)

  # determine local start and end
  loc_mid = model.target_length // 2
  loc_start = loc_mid - (options.local//2) // model.target_pool
  loc_end = loc_start + options.local // model.target_pool

  snp_i = 0

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # construct first batch
    batch_1hot, batch_snps, snp_i = snps_next_batch(
        snps, snp_i, options.batch_size, options.seq_len, genome_open)

    while len(batch_snps) > 0:
      ###################################################
      # predict

      # initialize batcher
      batcher = basenji.batcher.Batcher(batch_1hot, batch_size=model.batch_size)

      # predict
      batch_preds = model.predict(sess, batcher,
                      rc=options.rc, shifts=options.shifts,
                      penultimate=options.penultimate)

      # normalize
      batch_preds /= target_norms


      ###################################################
      # collect and print SADs

      pi = 0
      for snp in batch_snps:
        # get reference prediction (LxT)
        ref_preds = batch_preds[pi]
        pi += 1

        # sum across length
        ref_preds_sum = ref_preds.sum(axis=0, dtype='float64')

        # print tracks
        for ti in options.track_indexes:
          ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir, snp.rsid,
                                                     ti)
          bigwig_write(snp, options.seq_len, ref_preds[:, ti], model,
                       ref_bw_file, options.genome_file)

        for alt_al in snp.alt_alleles:
          # get alternate prediction (LxT)
          alt_preds = batch_preds[pi]
          pi += 1

          # sum across length
          alt_preds_sum = alt_preds.sum(axis=0, dtype='float64')

          # compare reference to alternative via mean subtraction
          sad_vec = alt_preds - ref_preds
          sad = alt_preds_sum - ref_preds_sum

          # compare reference to alternative via mean log division
          sar = np.log2(alt_preds_sum + options.log_pseudo) \
                  - np.log2(ref_preds_sum + options.log_pseudo)

          # compare geometric means
          sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \
                      - np.log2(ref_preds.astype('float64') + options.log_pseudo)
          geo_sad = sar_vec.sum(axis=0)

          # sum locally
          ref_preds_loc = ref_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64')
          alt_preds_loc = alt_preds[loc_start:loc_end,:].sum(axis=0, dtype='float64')

          # compute SAD locally
          sad_loc = alt_preds_loc - ref_preds_loc
          sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \
                      - np.log2(ref_preds_loc + options.log_pseudo)

          # compute max difference position
          max_li = np.argmax(np.abs(sar_vec), axis=0)

          # print table lines
          for ti in range(len(sad)):
            # print line
            cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele), bvcf.cap_allele(alt_al),
                    ref_preds_sum[ti], alt_preds_sum[ti], sad[ti], sar[ti], geo_sad[ti],
                    ref_preds_loc[ti], alt_preds_loc[ti], sad_loc[ti], sar_loc[ti],
                    ref_preds[max_li[ti], ti], alt_preds[max_li[ti], ti], sad_vec[max_li[ti],ti], sar_vec[max_li[ti],ti],
                    ti, target_ids[ti], target_labels[ti])
            if options.csv:
              print(','.join([str(c) for c in cols]), file=sad_out)
            else:
              print(
                  '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s'
                  % cols,
                  file=sad_out)

          # print tracks
          for ti in options.track_indexes:
            alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % (options.out_dir,
                                                       snp.rsid, ti)
            bigwig_write(snp, options.seq_len, alt_preds[:, ti], model,
                         alt_bw_file, options.genome_file)

      ###################################################
      # construct next batch

      batch_1hot, batch_snps, snp_i = snps_next_batch(
          snps, snp_i, options.batch_size, options.seq_len, genome_open)

  sad_out.close()
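
The per-variant table columns in this example reduce to a handful of statistics over an (L x T) prediction matrix: global SAD/SAR from length-summed predictions, geometric-mean SAD from per-position log ratios, and the max-difference position for the 'x' columns. A standalone sketch of those reductions (array sizes are assumptions for illustration):

import numpy as np

log_pseudo = 1.0
ref_preds = np.random.rand(128, 4)  # length x targets, reference allele (assumed)
alt_preds = np.random.rand(128, 4)  # length x targets, alternate allele (assumed)

# global scores from length-summed predictions
sad = alt_preds.sum(axis=0) - ref_preds.sum(axis=0)
sar = np.log2(alt_preds.sum(axis=0) + log_pseudo) \
      - np.log2(ref_preds.sum(axis=0) + log_pseudo)

# per-position log ratios; their sum compares geometric means
sar_vec = np.log2(alt_preds + log_pseudo) - np.log2(ref_preds + log_pseudo)
geo_sad = sar_vec.sum(axis=0)

# per-target position of the largest absolute log ratio
max_li = np.argmax(np.abs(sar_vec), axis=0)
xsar = sar_vec[max_li, np.arange(sar_vec.shape[1])]
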
Example #9
def main():
  usage = ('usage: %prog [options] <params_file> <model_file> <genes_hdf5_file>'
           ' <vcf_file>')
  parser = OptionParser(usage)
  parser.add_option(
      '-a',
      dest='all_sed',
      default=False,
      action='store_true',
      help=
      'Print all variant-gene pairs, as opposed to only nonzero [Default: %default]')
  parser.add_option(
      '-b',
      dest='batch_size',
      default=None,
      type='int',
      help='Batch size [Default: %default]')
  parser.add_option(
      '-c',
      dest='csv',
      default=False,
      action='store_true',
      help='Print table as CSV [Default: %default]')
  parser.add_option(
      '-g',
      dest='genome_file',
      default='%s/assembly/human.hg19.genome' % os.environ['HG19'],
      help='Chromosome lengths file [Default: %default]')
  parser.add_option(
      '-o',
      dest='out_dir',
      default='sed',
      help='Output directory for tables and plots [Default: %default]')
  parser.add_option(
      '-p',
      dest='processes',
      default=None,
      type='int',
      help='Number of processes, passed by multi script')
  parser.add_option(
      '--pseudo',
      dest='log_pseudo',
      default=0.125,
      type='float',
      help='Log2 pseudocount [Default: %default]')
  parser.add_option(
      '-r',
      dest='tss_radius',
      default=0,
      type='int',
      help='Radius of bins considered to quantify TSS transcription [Default: %default]')
  parser.add_option(
      '--rc',
      dest='rc',
      default=False,
      action='store_true',
      help=
      'Average the forward and reverse complement predictions when testing [Default: %default]')
  parser.add_option(
      '--shifts',
      dest='shifts',
      default='0',
      help='Ensemble prediction shifts [Default: %default]')
  parser.add_option(
      '-t',
      dest='targets_file',
      default=None,
      help='File specifying target indexes and labels in table format')
  parser.add_option(
      '-u',
      dest='penultimate',
      default=False,
      action='store_true',
      help='Compute SED in the penultimate layer [Default: %default]')
  parser.add_option(
      '-x',
      dest='tss_table',
      default=False,
      action='store_true',
      help='Print TSS table in addition to gene [Default: %default]')
  (options, args) = parser.parse_args()

  if len(args) == 4:
    # single worker
    params_file = args[0]
    model_file = args[1]
    genes_hdf5_file = args[2]
    vcf_file = args[3]

  elif len(args) == 6:
    # multi worker
    options_pkl_file = args[0]
    params_file = args[1]
    model_file = args[2]
    genes_hdf5_file = args[3]
    vcf_file = args[4]
    worker_index = int(args[5])

    # load options
    options_pkl = open(options_pkl_file, 'rb')
    options = pickle.load(options_pkl)
    options_pkl.close()

    # update output directory
    options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

  else:
    parser.error(
        'Must provide parameters and model files, genes HDF5 file, and QTL VCF'
        ' file')

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  options.shifts = [int(shift) for shift in options.shifts.split(',')]

  #################################################################
  # read genes HDF5

  gene_data = genedata.GeneData(genes_hdf5_file)

  # filter for worker sequences
  if options.processes is not None:
    gene_data.worker(worker_index, options.processes)

  #################################################################
  # prep SNPs

  # load SNPs
  snps = bvcf.vcf_snps(vcf_file)

  # intersect w/ segments
  print('Intersecting gene sequences with SNPs...', end='')
  sys.stdout.flush()
  seqs_snps = bvcf.intersect_seqs_snps(
      vcf_file, gene_data.gene_seqs, vision_p=0.5)
  print('%d sequences w/ SNPs' % len(seqs_snps))


  #################################################################
  # setup model

  job = params.read_job_params(params_file)

  job['seq_length'] = gene_data.seq_length
  job['seq_depth'] = gene_data.seq_depth
  job['target_pool'] = gene_data.pool_width

  if 'num_targets' not in job and gene_data.num_targets is not None:
    job['num_targets'] = gene_data.num_targets

  if 'num_targets' not in job:
    print("Must specify number of targets (num_targets) in the \
            parameters file.", file=sys.stderr)
    exit(1)

  if options.targets_file is None:
    target_labels = gene_data.target_labels
    target_subset = None
    target_ids = gene_data.target_ids
    if target_ids is None:
      target_ids = ['t%d'%ti for ti in range(job['num_targets'])]
      target_labels = ['']*len(target_ids)

  else:
    # Unfortunately, this target file differs from some others
    # in that it needs to specify the indexes from the original
    # set. In the future, I should standardize to this version.

    target_ids = []
    target_labels = []
    target_subset = []

    for line in open(options.targets_file):
      a = line.strip().split('\t')
      target_subset.append(int(a[0]))
      target_ids.append(a[1])
      target_labels.append(a[3])

    # a subset covering all targets is equivalent to no subsetting
    if len(target_subset) == job['num_targets']:
      target_subset = None

  # build model
  model = seqnn.SeqNN()
  model.build_feed(job, target_subset=target_subset)

  if options.penultimate:
    # labels become inappropriate
    target_ids = ['']*model.hp.cnn_filters[-1]
    target_labels = target_ids

  #################################################################
  # compute, collect, and print SEDs

  header_cols = ('rsid', 'ref', 'alt', 'gene', 'tss_dist', 'ref_pred',
                 'alt_pred', 'sed', 'ser', 'target_index', 'target_id',
                 'target_label')
  if options.csv:
    sed_gene_out = open('%s/sed_gene.csv' % options.out_dir, 'w')
    print(','.join(header_cols), file=sed_gene_out)
    if options.tss_table:
      sed_tss_out = open('%s/sed_tss.csv' % options.out_dir, 'w')
      print(','.join(header_cols), file=sed_tss_out)

  else:
    sed_gene_out = open('%s/sed_gene.txt' % options.out_dir, 'w')
    print(' '.join(header_cols), file=sed_gene_out)
    if options.tss_table:
      sed_tss_out = open('%s/sed_tss.txt' % options.out_dir, 'w')
      print(' '.join(header_cols), file=sed_tss_out)

  # helper variables
  pred_buffer = model.hp.batch_buffer // model.hp.target_pool

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # for each gene sequence
    for seq_i in range(gene_data.num_seqs):
      gene_seq = gene_data.gene_seqs[seq_i]
      print(gene_seq)

      # if it contains SNPs
      seq_snps = [snps[snp_i] for snp_i in seqs_snps[seq_i]]
      if len(seq_snps) > 0:

        # one hot code allele sequences
        aseqs_1hot = alleles_1hot(gene_seq, gene_data.seqs_1hot[seq_i], seq_snps)

        # initialize batcher
        batcher_gene = batcher.Batcher(aseqs_1hot, batch_size=model.hp.batch_size)

        # construct allele gene_seq's
        allele_gene_seqs = [gene_seq] * aseqs_1hot.shape[0]

        # predict alleles
        allele_tss_preds = model.predict_genes(sess, batcher_gene, allele_gene_seqs,
                                                rc=options.rc, shifts=options.shifts,
                                                embed_penultimate=options.penultimate,
                                                tss_radius=options.tss_radius)

        # reshape (Alleles x TSSs) x Targets to Alleles x TSSs x Targets
        allele_tss_preds = allele_tss_preds.reshape((aseqs_1hot.shape[0], gene_seq.num_tss, -1))

        # extract reference and SNP alt predictions
        ref_tss_preds = allele_tss_preds[0] # TSSs x Targets
        alt_tss_preds = allele_tss_preds[1:] # SNPs x TSSs x Targets

        # compute TSS SED scores
        tss_sed = alt_tss_preds - ref_tss_preds
        tss_ser = np.log2(alt_tss_preds + options.log_pseudo) \
                    - np.log2(ref_tss_preds + options.log_pseudo)

        # compute gene-level predictions
        ref_gene_preds, gene_ids = gene.map_tss_genes(ref_tss_preds, gene_seq.tss_list, options.tss_radius)
        alt_gene_preds = []
        for snp_i in range(len(seq_snps)):
          agp, _ = gene.map_tss_genes(alt_tss_preds[snp_i], gene_seq.tss_list, options.tss_radius)
          alt_gene_preds.append(agp)
        alt_gene_preds = np.array(alt_gene_preds)

        # compute gene SED scores
        gene_sed = alt_gene_preds - ref_gene_preds
        gene_ser = np.log2(alt_gene_preds + options.log_pseudo) \
                    - np.log2(ref_gene_preds + options.log_pseudo)

        # for each SNP
        for snp_i in range(len(seq_snps)):
          snp = seq_snps[snp_i]

          # initialize gene data structures
          snp_dist_gene = {}

          # for each TSS
          for tss_i in range(gene_seq.num_tss):
            tss = gene_seq.tss_list[tss_i]

            # SNP distance to TSS
            snp_dist = abs(snp.pos - tss.pos)
            if tss.gene_id in snp_dist_gene:
              snp_dist_gene[tss.gene_id] = min(snp_dist_gene[tss.gene_id], snp_dist)
            else:
              snp_dist_gene[tss.gene_id] = snp_dist

            # for each target
            if options.tss_table:
              for ti in range(ref_tss_preds.shape[1]):

                # check if nonzero
                if options.all_sed or not np.isclose(tss_sed[snp_i,tss_i,ti], 0, atol=1e-4):

                  # print
                  cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele),
                          bvcf.cap_allele(snp.alt_alleles[0]),
                          tss.identifier, snp_dist, ref_tss_preds[tss_i,ti], alt_tss_preds[snp_i,tss_i,ti],
                          tss_sed[snp_i,tss_i,ti], tss_ser[snp_i,tss_i,ti], ti, target_ids[ti], target_labels[ti])
                  if options.csv:
                    print(','.join([str(c) for c in cols]), file=sed_tss_out)
                  else:
                    print(
                        '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                        % cols,
                        file=sed_tss_out)

          # for each gene
          for gi in range(len(gene_ids)):
            gene_str = gene_ids[gi]
            if gene_ids[gi] in gene_data.multi_seq_genes:
              gene_str = '%s_multi' % gene_ids[gi]

            # print rows to gene table
            for ti in range(ref_gene_preds.shape[1]):

              # check if nonzero
              if options.all_sed or not np.isclose(gene_sed[snp_i,gi,ti], 0, atol=1e-4):

                # print
                cols = [
                    snp.rsid,
                    bvcf.cap_allele(snp.ref_allele),
                    bvcf.cap_allele(snp.alt_alleles[0]), gene_str,
                    snp_dist_gene[gene_ids[gi]], ref_gene_preds[gi,ti], alt_gene_preds[snp_i,gi,ti],
                    gene_sed[snp_i,gi,ti], gene_ser[snp_i,gi,ti], ti, target_ids[ti], target_labels[ti]
                ]
                if options.csv:
                  print(','.join([str(c) for c in cols]), file=sed_gene_out)
                else:
                  print(
                      '%-13s %s %5s %16s %5d %7.4f %7.4f %7.4f %7.4f %4d %12s %s'
                      % tuple(cols),
                      file=sed_gene_out)

        # clean up
        gc.collect()

  sed_gene_out.close()
  if options.tss_table:
    sed_tss_out.close()
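
The snp_dist_gene dictionary in this example keeps only the minimum SNP-to-TSS distance per gene. The same pattern in isolation, with hypothetical (gene_id, position) tuples standing in for the real TSS objects:

tss_list = [('geneA', 100), ('geneA', 500), ('geneB', 900)]  # hypothetical TSSs
snp_pos = 250

snp_dist_gene = {}
for gene_id, tss_pos in tss_list:
    snp_dist = abs(snp_pos - tss_pos)
    # keep the closest TSS seen so far for each gene
    snp_dist_gene[gene_id] = min(snp_dist_gene.get(gene_id, snp_dist), snp_dist)

# snp_dist_gene == {'geneA': 150, 'geneB': 650}
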
Example #10
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='figure_width',
                      default=20,
                      type='float',
                      help='Figure width [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default='%s/assembly/hg19.fa' % os.environ['HG19'],
        help='Genome FASTA from which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA from which minor allele sequences will be drawn')
    parser.add_option(
        '-g',
        dest='gain',
        default=False,
        action='store_true',
        help='Draw a sequence logo for the gain score, too [Default: %default]'
    )
    parser.add_option(
        '-l',
        dest='satmut_len',
        default=200,
        type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-m',
                      dest='mc_n',
                      default=0,
                      type='int',
                      help='Monte carlo iterations [Default: %default]')
    parser.add_option('--min',
                      dest='min_limit',
                      default=0.01,
                      type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option(
        '-n',
        dest='load_sat_npy',
        default=False,
        action='store_true',
        help='Load the predictions from .npy files [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_vcf',
                      help='Output directory [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('-s',
                      dest='seq_len',
                      default=131072,
                      type='int',
                      help='Input sequence length [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets',
        default='0',
        help='Comma-separated target indexes [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters and model files and VCF')
    else:
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    # decide which targets to obtain
    target_indexes = [int(ti) for ti in options.targets.split(',')]

    #################################################################
    # prep SNP sequences
    #################################################################
    # load SNPs
    snps = vcf.vcf_snps(vcf_file)

    for si in range(len(snps)):
        print(snps[si])

    # get one hot coded input sequences
    if not options.genome2_fasta:
        seqs_1hot, seq_headers, snps, seqs = vcf.snps_seq1(
            snps, options.seq_len, options.genome1_fasta, return_seqs=True)
    else:
        seqs_1hot, seq_headers, snps, seqs = vcf.snps2_seq1(
            snps,
            options.seq_len,
            options.genome1_fasta,
            options.genome2_fasta,
            return_seqs=True)

    seqs_n = seqs_1hot.shape[0]

    #################################################################
    # setup model
    #################################################################
    job = params.read_job_params(params_file)

    job['seq_length'] = seqs_1hot.shape[1]
    job['seq_depth'] = seqs_1hot.shape[2]

    if 'num_targets' not in job or 'target_pool' not in job:
        print('Must provide num_targets and target_pool in parameters file',
              file=sys.stderr)
        exit(1)

    # build model
    model = seqnn.SeqNN()
    model.build_feed(job,
                     ensemble_rc=options.rc,
                     ensemble_shifts=options.shifts,
                     target_subset=target_indexes)

    # initialize saver
    saver = tf.train.Saver()

    #################################################################
    # predict and process
    #################################################################

    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        for si in range(seqs_n):
            header = seq_headers[si]
            header_fs = fs_clean(header)

            print('Mutating sequence %d / %d' % (si + 1, seqs_n), flush=True)

            # write sequence
            fasta_out = open('%s/seq%d.fa' % (options.out_dir, si), 'w')
            end_len = (len(seqs[si]) - options.satmut_len) // 2
            print('>seq%d\n%s' % (si, seqs[si][end_len:-end_len]),
                  file=fasta_out)
            fasta_out.close()

            #################################################################
            # predict modifications

            if options.load_sat_npy:
                sat_preds = np.load('%s/seq%d_preds.npy' %
                                    (options.out_dir, si))

            else:
                # supplement with saturated mutagenesis
                sat_seqs_1hot = satmut_seqs(seqs_1hot[si:si + 1],
                                            options.satmut_len)

                # initialize batcher
                batcher_sat = batcher.Batcher(sat_seqs_1hot,
                                              batch_size=model.hp.batch_size)

                # predict
                sat_preds = model.predict_h5(sess, batcher_sat)
                np.save('%s/seq%d_preds.npy' % (options.out_dir, si),
                        sat_preds)

            #################################################################
            # compute delta, loss, and gain matrices

            # compute the matrix of prediction deltas: (4 x L_sm x T) array
            sat_delta = delta_matrix(seqs_1hot[si], sat_preds,
                                     options.satmut_len)

            # sat_loss, sat_gain = loss_gain(sat_delta, sat_preds[si], options.satmut_len)
            sat_loss = sat_delta.min(axis=0)
            sat_gain = sat_delta.max(axis=0)

            ##############################################
            # plot

            for ti in range(len(target_indexes)):
                # setup plot
                sns.set(style='white', font_scale=1)
                spp = subplot_params(sat_delta.shape[1])

                if options.gain:
                    plt.figure(figsize=(options.figure_width, 4))
                    ax_logo_loss = plt.subplot2grid((4, spp['heat_cols']),
                                                    (0, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_logo_gain = plt.subplot2grid((4, spp['heat_cols']),
                                                    (1, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_sad = plt.subplot2grid((4, spp['heat_cols']),
                                              (2, spp['sad_start']),
                                              colspan=spp['sad_span'])
                    ax_heat = plt.subplot2grid((4, spp['heat_cols']), (3, 0),
                                               colspan=spp['heat_cols'])
                else:
                    plt.figure(figsize=(options.figure_width, 3))
                    ax_logo_loss = plt.subplot2grid((3, spp['heat_cols']),
                                                    (0, spp['logo_start']),
                                                    colspan=spp['logo_span'])
                    ax_sad = plt.subplot2grid((3, spp['heat_cols']),
                                              (1, spp['sad_start']),
                                              colspan=spp['sad_span'])
                    ax_heat = plt.subplot2grid((3, spp['heat_cols']), (2, 0),
                                               colspan=spp['heat_cols'])

                # plot sequence logo
                plot_seqlogo(ax_logo_loss, seqs_1hot[si], -sat_loss[:, ti])
                if options.gain:
                    plot_seqlogo(ax_logo_gain, seqs_1hot[si], sat_gain[:, ti])

                # plot SAD
                plot_sad(ax_sad, sat_loss[:, ti], sat_gain[:, ti])

                # plot heat map
                plot_heat(ax_heat, sat_delta[:, :, ti], options.min_limit)

                plt.tight_layout()
                plt.savefig('%s/%s_t%d.pdf' % (options.out_dir, header_fs, ti),
                            dpi=600)
                plt.close()
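
The loss and gain tracks plotted above come straight from the (4 x L x T) delta matrix: at each position, loss is the most negative mutation effect across the four nucleotides and gain is the most positive. A minimal sketch with assumed shapes:

import numpy as np

sat_delta = np.random.randn(4, 200, 2)  # nucleotides x positions x targets (assumed)

sat_loss = sat_delta.min(axis=0)  # (L x T) worst-case effect per position
sat_gain = sat_delta.max(axis=0)  # (L x T) best-case effect per position
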
Example #11
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-m',
        dest='plot_map',
        default=False,
        action='store_true',
        help='Plot contact map for each allele [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='scd',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='scd_stats',
        default='SCD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if options.plot_map:
        plot_dir = options.out_dir
    else:
        plot_dir = None

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.scd_stats = options.scd_stats.split(',')

    #################################################################
    # read parameters

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_train = params['train']
    params_model = params['model']

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(params_model['num_targets'])]
        target_labels = [''] * len(target_ids)

    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, params_model['seq_length'],
                                          genome_open)

            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape(
            [tf.Dimension(params_model['seq_length']),
             tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(params_train['batch_size'])
    dataset = dataset.prefetch(2 * params_train['batch_size'])
    dataset_iter = iter(dataset)

    # def get_chunk(chunk_size=32):
    #   """Get a chunk of data from the dataset iterator."""
    #   x = []
    #   for ci in range(chunk_size):
    #     try:
    #       x.append(next(dataset_iter))
    #     except StopIteration:
    #       break

    #################################################################
    # setup model

    # load model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #################################################################
    # setup output

    scd_out = initialize_output_h5(options.out_dir, options.scd_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # process

    szi = 0
    sum_write_thread = None

    # predict first
    # batch_seqs = get_chunk()
    # batch_preds = seqnn_model.predict(batch_seqs, steps=batch_seqs)
    batch_preds = seqnn_model.predict(dataset_iter, generator=True, steps=32)

    while len(batch_preds) > 0:
        # count predicted SNPs
        num_snps = batch_preds.shape[0] // 2

        # block for last thread
        if sum_write_thread is not None:
            sum_write_thread.join()

        # summarize and write
        sum_write_thread = threading.Thread(target=summarize_write,
                                            args=(batch_preds, scd_out, szi,
                                                  options.scd_stats, plot_dir,
                                                  seqnn_model.diagonal_offset))
        sum_write_thread.start()

        # update SNP index
        szi += num_snps

        # predict next
        try:
            # batch_preds = seqnn_model.predict(get_chunk())
            batch_preds = seqnn_model.predict(dataset_iter,
                                              generator=True,
                                              steps=32)
        except ValueError:
            batch_preds = []

    print('Waiting for threads to finish.', flush=True)
    # guard: no writer thread exists if the first prediction batch was empty
    if sum_write_thread is not None:
        sum_write_thread.join()

    scd_out.close()
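
The prediction loop above overlaps GPU inference with CPU summarization: each finished batch is handed to a writer thread, which is joined before the next batch's results are written. The same pattern stripped to its skeleton, with stand-in functions in place of the model and the HDF5 writer:

import threading

def predict_batch(i):        # stand-in for seqnn_model.predict(...)
    return [i] * 4

def summarize_write(preds):  # stand-in for the real summarize_write(...)
    print('writing', preds)

write_thread = None
for i in range(3):
    preds = predict_batch(i)      # compute the next batch
    if write_thread is not None:
        write_thread.join()       # block until the previous write finishes
    write_thread = threading.Thread(target=summarize_write, args=(preds,))
    write_thread.start()

if write_thread is not None:
    write_thread.join()           # drain the final write
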
Example #12
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-b',
                      dest='batch_size',
                      default=256,
                      type='int',
                      help='Batch size [Default: %default]')
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/assembly/hg19.fa' % os.environ['HG19'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default='%s/assembly/human.hg19.genome' %
                      os.environ['HG19'],
                      help='Chromosome lengths file [Default: %default]')
    parser.add_option('--h5',
                      dest='out_h5',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.h5 [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD,xSAR',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option('-z',
                      dest='out_zarr',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # setup model

    job = basenji.params.read_job_params(
        params_file, require=['seq_length', 'num_targets', 'target_pool'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_table(options.targets_file)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    # build model
    t0 = time.time()
    model = basenji.seqnn.SeqNN()
    model.build_feed(job,
                     ensemble_rc=options.rc,
                     ensemble_shifts=options.shifts,
                     embed_penultimate=options.penultimate,
                     target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # load SNPs

    snps = bvcf.vcf_snps(vcf_file)

    # filter for worker SNPs
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(snps),
                                    options.processes + 1,
                                    dtype='int')
        snps = snps[worker_bounds[worker_index]:worker_bounds[worker_index +
                                                              1]]

    num_snps = len(snps)

    #################################################################
    # setup output

    header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad', 'sar',
                   'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad', 'lsar',
                   'ref_xpred', 'alt_xpred', 'xsad', 'xsar', 'target_index',
                   'target_id', 'target_label')

    if options.out_h5:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    elif options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    else:
        if options.csv:
            sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sad_out)
        else:
            sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sad_out)

    #################################################################
    # process

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # determine local start and end
    loc_mid = model.target_length // 2
    loc_start = loc_mid - (options.local // 2) // model.hp.target_pool
    loc_end = loc_start + options.local // model.hp.target_pool

    snp_i = 0
    szi = 0

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # construct first batch
        batch_1hot, batch_snps, snp_i = snps_next_batch(
            snps, snp_i, options.batch_size, job['seq_length'], genome_open)

        while len(batch_snps) > 0:
            ###################################################
            # predict

            # initialize batcher
            batcher = basenji.batcher.Batcher(batch_1hot,
                                              batch_size=model.hp.batch_size)

            # predict
            # batch_preds = model.predict(sess, batcher,
            #                 rc=options.rc, shifts=options.shifts,
            #                 penultimate=options.penultimate)
            batch_preds = model.predict_h5(sess, batcher)

            # normalize
            batch_preds /= target_norms

            ###################################################
            # collect and print SADs

            pi = 0
            for snp in batch_snps:
                # get reference prediction (LxT)
                ref_preds = batch_preds[pi]
                pi += 1

                # sum across length
                ref_preds_sum = ref_preds.sum(axis=0, dtype='float64')

                # print tracks
                for ti in options.track_indexes:
                    ref_bw_file = '%s/tracks/%s_t%d_ref.bw' % (options.out_dir,
                                                               snp.rsid, ti)
                    bigwig_write(snp, job['seq_length'], ref_preds[:, ti],
                                 model, ref_bw_file, options.genome_file)

                for alt_al in snp.alt_alleles:
                    # get alternate prediction (LxT)
                    alt_preds = batch_preds[pi]
                    pi += 1

                    # sum across length
                    alt_preds_sum = alt_preds.sum(axis=0, dtype='float64')

                    # compare reference to alternative via mean subtraction
                    sad_vec = alt_preds - ref_preds
                    sad = alt_preds_sum - ref_preds_sum

                    # compare reference to alternative via mean log division
                    sar = np.log2(alt_preds_sum + options.log_pseudo) \
                            - np.log2(ref_preds_sum + options.log_pseudo)

                    # compare geometric means
                    sar_vec = np.log2(alt_preds.astype('float64') + options.log_pseudo) \
                                - np.log2(ref_preds.astype('float64') + options.log_pseudo)
                    geo_sad = sar_vec.sum(axis=0)

                    # sum locally
                    ref_preds_loc = ref_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')
                    alt_preds_loc = alt_preds[loc_start:loc_end, :].sum(
                        axis=0, dtype='float64')

                    # compute SAD locally
                    sad_loc = alt_preds_loc - ref_preds_loc
                    sar_loc = np.log2(alt_preds_loc + options.log_pseudo) \
                                - np.log2(ref_preds_loc + options.log_pseudo)

                    # compute max difference position
                    max_li = np.argmax(np.abs(sar_vec), axis=0)

                    if options.out_h5 or options.out_zarr:
                        sad_out['SAD'][szi, :] = sad.astype('float16')
                        sad_out['xSAR'][szi, :] = np.array(
                            [sar_vec[max_li[ti], ti] for ti in range(num_targets)],
                            dtype='float16')
                        szi += 1

                    else:
                        # print table lines
                        for ti in range(len(sad)):
                            # print line
                            cols = (snp.rsid, bvcf.cap_allele(snp.ref_allele),
                                    bvcf.cap_allele(alt_al), ref_preds_sum[ti],
                                    alt_preds_sum[ti], sad[ti], sar[ti],
                                    geo_sad[ti], ref_preds_loc[ti],
                                    alt_preds_loc[ti], sad_loc[ti], sar_loc[ti],
                                    ref_preds[max_li[ti], ti],
                                    alt_preds[max_li[ti], ti],
                                    sad_vec[max_li[ti], ti],
                                    sar_vec[max_li[ti], ti],
                                    ti, target_ids[ti], target_labels[ti])
                            if options.csv:
                                print(','.join([str(c) for c in cols]),
                                      file=sad_out)
                            else:
                                print(
                                    '%-13s %6s %6s | %8.2f %8.2f %8.3f %7.4f %7.3f | %7.3f %7.3f %7.3f %7.4f | %7.3f %7.3f %7.3f %7.4f | %4d %12s %s'
                                    % cols,
                                    file=sad_out)

                    # print tracks
                    for ti in options.track_indexes:
                        alt_bw_file = '%s/tracks/%s_t%d_alt.bw' % (
                            options.out_dir, snp.rsid, ti)
                        bigwig_write(snp, job['seq_length'], alt_preds[:, ti],
                                     model, alt_bw_file, options.genome_file)

            ###################################################
            # construct next batch

            batch_1hot, batch_snps, snp_i = snps_next_batch(
                snps, snp_i, options.batch_size, job['seq_length'],
                genome_open)

    ###################################################
    # compute SAD distributions across variants

    if options.out_h5 or options.out_zarr:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset('percentiles', data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = '%s_pct' % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype('float16')

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    if not options.out_zarr:
        sad_out.close()
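
The percentile grid built at the end of this example is deliberately denser in the tails (0.001 steps) than in the middle (0.01 steps), which suits the heavy-tailed SAD distributions it summarizes. Reconstructed standalone, with a random score matrix as a stand-in for the saved stats:

import numpy as np

d_fine, d_coarse = 0.001, 0.01
percentiles = np.concatenate([
    np.arange(d_fine, 0.1, d_fine),  # fine-grained lower tail
    np.arange(0.1, 0.9, d_coarse),   # coarse middle
    np.arange(0.9, 1, d_fine),       # fine-grained upper tail
])

scores = np.random.randn(1000, 4)  # variants x targets (assumed shape)
score_pct = np.percentile(scores, 100 * percentiles, axis=0).T.astype('float16')
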
Example #13
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f', dest='genome_fasta', default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-l', dest='plot_lim_min', default=0.1, type='float',
                      help='Heatmap plot limit [Default: %default]')
    parser.add_option('-m', dest='plot_map', default=False, action='store_true',
                      help='Plot contact map for each allele [Default: %default]')
    parser.add_option('-o', dest='out_dir', default='scd',
                      help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p', dest='processes', default=None, type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--rc', dest='rc', default=False, action='store_true',
                      help='Average forward and reverse complement predictions [Default: %default]')
    parser.add_option('--shifts', dest='shifts', default='0', type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('--stats', dest='scd_stats', default='SCD',
                      help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option('-t', dest='targets_file', default=None, type='str',
                      help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide parameters file, model file, and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if options.plot_map:
        plot_dir = options.out_dir
    else:
        plot_dir = None

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.scd_stats = options.scd_stats.split(',')

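    # seed Python's RNG so runs are reproducible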
    random.seed(44)

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_train = params['train']
    params_model = params['model']

    if options.targets_file is not None:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description

    #################################################################
    # setup model

    # load model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
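    # average predictions over the reverse complement and/or shifted
    # sequences requested via --rc and --shifts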
    seqnn_model.build_ensemble(options.rc, options.shifts)

    # dummy target info
    if options.targets_file is None:
        num_targets = seqnn_model.num_targets()
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0, num_snps, options.processes + 1,
                                    dtype='int')

        # read this worker's share of SNPs from the VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read all SNPs from the VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

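    # generator over one-hot coded allele sequences; bvcf.snp_seq1 returns
    # the reference sequence followed by the alternate(s), so a biallelic
    # SNP contributes two consecutive entries to the prediction stream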
    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, params_model['seq_length'],
                                          genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    scd_out = initialize_output_h5(options.out_dir, options.scd_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # predict SNP scores, write output

    write_thread = None
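    # (write_thread is not used in the code below; presumably left over
    # from a threaded output path)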

    # initialize predictions stream
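    # (PredStreamGen feeds the generator's sequences through the model in
    # batches and buffers the predictions for consumption by index below)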
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

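    # each SNP consumes two consecutive stream entries: the reference
    # prediction, then the alternate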
    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, scd_out, si, options.scd_stats,
                  plot_dir, seqnn_model.diagonal_offset, options.plot_lim_min)

    genome_open.close()
    scd_out.close()
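
# A hypothetical invocation of this script (the script name, paths, and file
# names below are placeholders, not taken from the source):
#
#   python akita_scd.py -f data/hg38.fa --rc --shifts "1,0,-1" \
#       -o scd_out params.json model.h5 variants.vcf
#
# Three positional arguments (params, model, VCF) take the single-worker path
# above; five arguments take the multi-worker path.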