Example 1
def main():
    usage = 'usage: %prog [options] <model> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='slice_center',
                      default=None,
                      type='int',
                      help='Slice center positions [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('--species', dest='species', default='human')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

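    # three invocation forms are accepted: a single worker passes
    # <model> <vcf_file>; the multi-process driver prepends a pickled
    # options file, and each worker additionally receives its index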
    if len(args) == 2:
        # single worker
        model_file = args[0]
        vcf_file = args[1]

    elif len(args) == 3:
        # multi separate
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

        # save out dir
        out_dir = options.out_dir

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # restore the saved output directory
        options.out_dir = out_dir

    elif len(args) == 4:
        # multi worker
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]
        worker_index = int(args[3])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide model and VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read targets

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = tf.saved_model.load(model_file).model

    # query model sequence length and number of targets via a null prediction
    seq_length = seqnn_model.predict_on_batch.input_signature[0].shape[1]
    null_1hot = np.zeros((1, seq_length, 4))
    null_preds = seqnn_model.predict_on_batch(null_1hot)
    null_preds = null_preds[options.species].numpy()
    _, targets_length, num_targets = null_preds.shape

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')
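        # np.linspace yields processes+1 integer boundaries, so worker i
        # handles the contiguous SNP range [bounds[i], bounds[i+1])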

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # create SNP sequence generator
    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, seq_length, genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels, targets_length)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = PredStreamGen(seqnn_model,
                                 snp_gen(),
                                 rc=options.rc,
                                 shifts=options.shifts,
                                 slice_center=options.slice_center,
                                 species=options.species)

    # predictions index
    pi = 0

    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                  options.log_pseudo)

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
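
The -t targets file read by these scripts is a tab-separated table indexed by target number, with identifier and description columns (the only fields the code above touches). A minimal sketch of writing and re-reading such a table, with made-up target names:

import pandas as pd

# hypothetical two-target table; real files carry one row per model output track
targets_df = pd.DataFrame({
    'identifier': ['CNhs11760', 'ENCFF413AHU'],
    'description': ['aorta CAGE', 'K562 DNase']
})
targets_df.index.name = 'index'
targets_df.to_csv('targets.txt', sep='\t')

# read it back exactly as main() does
targets_df = pd.read_csv('targets.txt', sep='\t', index_col=0)
target_ids = targets_df.identifier      # target ids
target_labels = targets_df.description  # target labels
target_slice = targets_df.index         # integer indexes used to slice model outputs
print(target_ids.tolist(), target_labels.tolist(), target_slice.values)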
Example 2
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs lie in center region [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '--flip',
        dest='flip_ref',
        default=False,
        action='store_true',
        help='Flip reference/alternate alleles when simple [Default: %default]'
    )
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--threads',
        dest='threads',
        default=False,
        action='store_true',
        help='Run CPU math and output in a separate thread [Default: %default]'
    )
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    if options.penultimate:
        parser.error('Not implemented for TF2')

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)
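    # slice model outputs to the requested targets and average predictions
    # over the reverse-complement / shift ensemble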

    num_targets = seqnn_model.num_targets()
    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])
    else:
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, params_model['seq_length'],
                                options.center_pct)

    # delimit sequence boundaries
    [sc.delimit(params_model['seq_length']) for sc in snp_clusters]

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    if options.threads:
        snp_threads = []
        snp_queue = Queue()
        for i in range(1):
            sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                           options.log_pseudo)
            sw.start()
            snp_threads.append(sw)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params['train']['batch_size'])

    # predictions index
    pi = 0

    # SNP index
    si = 0

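    # each cluster contributes one reference prediction followed by one
    # alternate prediction per SNP; flipped SNPs swap ref/alt before scoring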
    for snp_cluster in snp_clusters:
        ref_preds = preds_stream[pi]
        pi += 1

        for snp in snp_cluster.snps:
            # print(snp, flush=True)

            alt_preds = preds_stream[pi]
            pi += 1

            if snp_flips[si]:
                ref_preds, alt_preds = alt_preds, ref_preds

            if options.threads:
                # queue SNP
                snp_queue.put((ref_preds, alt_preds, si))
            else:
                # process SNP
                write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                          options.log_pseudo)

            # update SNP index
            si += 1

    # finish queue
    if options.threads:
        print('Waiting for threads to finish.', flush=True)
        snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
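
Every example partitions the VCF across workers with the same np.linspace call. A small self-contained illustration of how those boundaries split a hypothetical set of 10 SNPs over 3 worker processes:

import numpy as np

num_snps = 10   # hypothetical VCF size
processes = 3   # hypothetical worker count

# processes+1 integer boundaries spanning [0, num_snps]
worker_bounds = np.linspace(0, num_snps, processes + 1, dtype='int')
print(worker_bounds)  # [ 0  3  6 10]

# worker i reads the half-open SNP range [bounds[i], bounds[i+1])
for worker_index in range(processes):
    start_i = worker_bounds[worker_index]
    end_i = worker_bounds[worker_index + 1]
    print('worker %d: SNPs %d to %d' % (worker_index, start_i, end_i))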
Example 3
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs lie in center region [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '--flip',
        dest='flip_ref',
        default=False,
        action='store_true',
        help='Flip reference/alternate alleles when simple [Default: %default]'
    )
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and collect target information

    job = params.read_job_params(params_file,
                                 require=['seq_length', 'num_targets'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])
    else:
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, job['seq_length'], options.center_pct)

    # delimit sequence boundaries
    [sc.delimit(job['seq_length']) for sc in snp_clusters]

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape([tf.Dimension(job['seq_length']),
                        tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(job['batch_size'])
    dataset = dataset.prefetch(2 * job['batch_size'])
    # dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/device:GPU:0'))

    iterator = dataset.make_one_shot_iterator()
    data_ops = iterator.get_next()
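    # note: tf.Dimension, make_one_shot_iterator, and the Session/Saver usage
    # below are TensorFlow 1.x graph-mode APIs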

    #################################################################
    # setup model

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # setup output

    snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    snp_threads = []

    snp_queue = Queue()
    for i in range(1):
        sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                       options.log_pseudo)
        sw.start()
        snp_threads.append(sw)

    #################################################################
    # predict SNP scores, write output

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # initialize predictions stream
        preds_stream = PredStream(sess, model, 32)

        # predictions index
        pi = 0

        # SNP index
        si = 0

        for snp_cluster in snp_clusters:
            ref_preds = preds_stream[pi]
            pi += 1

            for snp in snp_cluster.snps:
                # print(snp, flush=True)

                alt_preds = preds_stream[pi]
                pi += 1

                # queue SNP
                if snp_flips[si]:
                    snp_queue.put((alt_preds, ref_preds, si))
                else:
                    snp_queue.put((ref_preds, alt_preds, si))

                # update SNP index
                si += 1

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    # define percentiles
    d_fine = 0.001
    d_coarse = 0.01
    percentiles_neg = np.arange(d_fine, 0.1, d_fine)
    percentiles_base = np.arange(0.1, 0.9, d_coarse)
    percentiles_pos = np.arange(0.9, 1, d_fine)

    percentiles = np.concatenate(
        [percentiles_neg, percentiles_base, percentiles_pos])
    sad_out.create_dataset('percentiles', data=percentiles)
    pct_len = len(percentiles)

    for sad_stat in options.sad_stats:
        sad_stat_pct = '%s_pct' % sad_stat

        # compute
        sad_pct = np.percentile(sad_out[sad_stat], 100 * percentiles, axis=0).T
        sad_pct = sad_pct.astype('float16')

        # save
        sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    sad_out.close()
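
The SNP generators in these scripts yield one-hot encoded sequences of shape (seq_length, 4), one channel per nucleotide. A toy encoder in that spirit, assuming an A/C/G/T channel order (the real encoding lives in the basenji helper modules, not shown here):

import numpy as np

def one_hot_dna(seq):
    """One-hot encode a DNA string into a (len(seq), 4) array.

    The A/C/G/T channel order is an assumption for illustration;
    unrecognized bases (e.g. N) are left all-zero.
    """
    lookup = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    seq_1hot = np.zeros((len(seq), 4), dtype='float32')
    for i, nt in enumerate(seq.upper()):
        if nt in lookup:
            seq_1hot[i, lookup[nt]] = 1
    return seq_1hot

print(one_hot_dna('ACGTN'))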
Example 4
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('--cpu',
                      dest='cpu',
                      default=False,
                      action='store_true',
                      help='Run without a GPU [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option('--txt',
                      dest='out_txt',
                      default=False,
                      action='store_true',
                      help='Output stats to text table [Default: %default]')
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    parser.add_option('-z',
                      dest='out_zarr',
                      default=False,
                      action='store_true',
                      help='Output stats to sad.zarr [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters

    job = params.read_job_params(params_file,
                                 require=['seq_length', 'num_targets'])

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(job['num_targets'])]
        target_labels = [''] * len(target_ids)
        target_subset = None

    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_subset = targets_df.index
        if len(target_subset) == job['num_targets']:
            target_subset = None

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, job['seq_length'], genome_open)

            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape([tf.Dimension(job['seq_length']),
                        tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(job['batch_size'])
    dataset = dataset.prefetch(2 * job['batch_size'])
    if not options.cpu:
        dataset = dataset.apply(
            tf.contrib.data.prefetch_to_device('/device:GPU:0'))

    iterator = dataset.make_one_shot_iterator()
    data_ops = iterator.get_next()

    #################################################################
    # setup model

    # build model
    t0 = time.time()
    model = seqnn.SeqNN()
    model.build_sad(job,
                    data_ops,
                    ensemble_rc=options.rc,
                    ensemble_shifts=options.shifts,
                    embed_penultimate=options.penultimate,
                    target_subset=target_subset)
    print('Model building time %f' % (time.time() - t0), flush=True)

    if options.penultimate:
        # labels become inappropriate
        target_ids = [''] * model.hp.cnn_filters[-1]
        target_labels = target_ids

    # read target normalization factors
    target_norms = np.ones(len(target_labels))
    if options.norm_file is not None:
        ti = 0
        for line in open(options.norm_file):
            target_norms[ti] = float(line.strip())
            ti += 1

    num_targets = len(target_ids)

    #################################################################
    # setup output

    if options.out_zarr:
        sad_out = initialize_output_zarr(options.out_dir, options.sad_stats,
                                         snps, target_ids, target_labels)

    elif options.out_txt:
        header_cols = ('rsid', 'ref', 'alt', 'ref_pred', 'alt_pred', 'sad',
                       'sar', 'geo_sad', 'ref_lpred', 'alt_lpred', 'lsad',
                       'lsar', 'ref_xpred', 'alt_xpred', 'xsad', 'xsar',
                       'target_index', 'target_id', 'target_label')

        if options.csv:
            sad_out = open('%s/sad_table.csv' % options.out_dir, 'w')
            print(','.join(header_cols), file=sad_out)
        else:
            sad_out = open('%s/sad_table.txt' % options.out_dir, 'w')
            print(' '.join(header_cols), file=sad_out)

    else:
        sad_out = initialize_output_h5(options.out_dir, options.sad_stats,
                                       snps, target_ids, target_labels)

    #################################################################
    # process

    szi = 0
    sum_write_thread = None
    sw_batch_size = 32 // job['batch_size']
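    # predict in chunks of sw_batch_size batches (~32 sequences); each chunk is
    # summarized and written on a background thread while the next is predicted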

    # initialize saver
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load variables into session
        saver.restore(sess, model_file)

        # predict first
        batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

        while batch_preds.shape[0] > 0:
            # count predicted SNPs
            num_snps = batch_preds.shape[0] // 2

            # normalize
            batch_preds /= target_norms

            # block for last thread
            if sum_write_thread is not None:
                sum_write_thread.join()

            # summarize and write
            sum_write_thread = threading.Thread(target=summarize_write,
                                                args=(batch_preds, sad_out,
                                                      szi, options.sad_stats,
                                                      options.log_pseudo))
            sum_write_thread.start()

            # update SNP index
            szi += num_snps

            # predict next
            batch_preds = model.predict_tfr(sess, test_batches=sw_batch_size)

    print('Waiting for threads to finish.', flush=True)
    sum_write_thread.join()

    ###################################################
    # compute SAD distributions across variants

    if not options.out_txt:
        # define percentiles
        d_fine = 0.001
        d_coarse = 0.01
        percentiles_neg = np.arange(d_fine, 0.1, d_fine)
        percentiles_base = np.arange(0.1, 0.9, d_coarse)
        percentiles_pos = np.arange(0.9, 1, d_fine)

        percentiles = np.concatenate(
            [percentiles_neg, percentiles_base, percentiles_pos])
        sad_out.create_dataset('percentiles', data=percentiles)
        pct_len = len(percentiles)

        for sad_stat in options.sad_stats:
            sad_stat_pct = '%s_pct' % sad_stat

            # compute
            sad_pct = np.percentile(sad_out[sad_stat],
                                    100 * percentiles,
                                    axis=0).T
            sad_pct = sad_pct.astype('float16')

            # save
            sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype='float16')

    if not options.out_zarr:
        sad_out.close()
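
Examples 4 and 5 overlap prediction with CPU summarization by handing each prediction chunk to a short-lived writer thread and joining it before the next chunk is queued. A minimal sketch of that double-buffering pattern, with placeholder functions and data standing in for the model and the HDF5 output:

import threading

def summarize_write_stub(chunk, out, offset):
    # stand-in for summarize_write: record where the chunk would be written
    out.append((offset, sum(chunk)))

def predict_chunks():
    # stand-in for repeated model.predict_tfr / model.predict calls
    yield from ([1, 2], [3, 4], [5, 6])

out = []
szi = 0
sum_write_thread = None
for batch_preds in predict_chunks():
    # block until the previous chunk has been written
    if sum_write_thread is not None:
        sum_write_thread.join()
    # write this chunk in the background while the next one is predicted
    sum_write_thread = threading.Thread(target=summarize_write_stub,
                                        args=(batch_preds, out, szi))
    sum_write_thread.start()
    szi += len(batch_preds) // 2
if sum_write_thread is not None:
    sum_write_thread.join()
print(out)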
Example 5
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-m',
        dest='plot_map',
        default=False,
        action='store_true',
        help='Plot contact map for each allele [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='scd',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='scd_stats',
        default='SCD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if options.plot_map:
        plot_dir = options.out_dir
    else:
        plot_dir = None

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.scd_stats = options.scd_stats.split(',')

    #################################################################
    # read parameters

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_train = params['train']
    params_model = params['model']

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(params_model['num_targets'])]
        target_labels = [''] * len(target_ids)

    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, params_model['seq_length'],
                                          genome_open)

            for snp_1hot in snp_1hot_list:
                yield {'sequence': snp_1hot}

    snp_types = {'sequence': tf.float32}
    snp_shapes = {
        'sequence':
        tf.TensorShape(
            [tf.Dimension(params_model['seq_length']),
             tf.Dimension(4)])
    }

    dataset = tf.data.Dataset.from_generator(snp_gen,
                                             output_types=snp_types,
                                             output_shapes=snp_shapes)
    dataset = dataset.batch(params_train['batch_size'])
    dataset = dataset.prefetch(2 * params_train['batch_size'])
    dataset_iter = iter(dataset)
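    # TF2 eager-style dataset iterator (contrast with the TF1 one-shot
    # iterators used in Examples 3 and 4)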

    # def get_chunk(chunk_size=32):
    #   """Get a chunk of data from the dataset iterator."""
    #   x = []
    #   for ci in range(chunk_size):
    #     try:
    #       x.append(next(dataset_iter))
    #     except StopIteration:
    #       break

    #################################################################
    # setup model

    # load model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    #################################################################
    # setup output

    scd_out = initialize_output_h5(options.out_dir, options.scd_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # process

    szi = 0
    sum_write_thread = None

    # predict first
    # batch_seqs = get_chunk()
    # batch_preds = seqnn_model.predict(batch_seqs, steps=batch_seqs)
    batch_preds = seqnn_model.predict(dataset_iter, generator=True, steps=32)

    while len(batch_preds) > 0:
        # count predicted SNPs
        num_snps = batch_preds.shape[0] // 2

        # block for last thread
        if sum_write_thread is not None:
            sum_write_thread.join()

        # summarize and write
        sum_write_thread = threading.Thread(target=summarize_write,
                                            args=(batch_preds, scd_out, szi,
                                                  options.scd_stats, plot_dir,
                                                  seqnn_model.diagonal_offset))
        sum_write_thread.start()

        # update SNP index
        szi += num_snps

        # predict next
        try:
            # batch_preds = seqnn_model.predict(get_chunk())
            batch_preds = seqnn_model.predict(dataset_iter,
                                              generator=True,
                                              steps=32)
        except ValueError:
            batch_preds = []

    print('Waiting for threads to finish.', flush=True)
    sum_write_thread.join()

    scd_out.close()
Example 6
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-l',
                      dest='plot_lim_min',
                      default=0.1,
                      type='float',
                      help='Heatmap plot limit [Default: %default]')
    parser.add_option(
        '-m',
        dest='plot_map',
        default=False,
        action='store_true',
        help='Plot contact map for each allele [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='scd',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='scd_stats',
        default='SCD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters and model files and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if options.plot_map:
        plot_dir = options.out_dir
    else:
        plot_dir = None

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.scd_stats = options.scd_stats.split(',')

    random.seed(44)

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_train = params['train']
    params_model = params['model']

    if options.targets_file is not None:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description

    #################################################################
    # setup model

    # load model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    # dummy target info
    if options.targets_file is None:
        num_targets = seqnn_model.num_targets()
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, params_model['seq_length'],
                                          genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    scd_out = initialize_output_h5(options.out_dir, options.scd_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # predict SNP scores, write output

    write_thread = None

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, scd_out, si, options.scd_stats,
                  plot_dir, seqnn_model.diagonal_offset, options.plot_lim_min)

    genome_open.close()
    scd_out.close()