Example #1
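# A best-guess import header, assuming the Basenji package layout; the page
# omits the imports, and satmut_gen and ScoreWorker are helpers defined
# elsewhere in the same script.
import gc
import json
import os
from optparse import OptionParser
from queue import Queue

import h5py
import numpy as np
import pandas as pd

from basenji import seqnn, stream, vcf
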
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='mut_down',
        default=0,
        type='int',
        help=
        'Nucleotides downstream of center sequence to mutate [Default: %default]'
    )
    parser.add_option('-f',
                      dest='figure_width',
                      default=20,
                      type='float',
                      help='Figure width [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
        help='Genome FASTA from which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA from which minor allele sequences will be drawn')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=200,
        type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_vcf',
                      help='Output directory [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='sum',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='mut_up',
        default=0,
        type='int',
        help=
        'Nucleotides upstream of center sequence to mutate [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.error('Must provide parameters file, model file, and VCF file')
    else:
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = [
        sad_stat.lower() for sad_stat in options.sad_stats.split(',')
    ]

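    # derive the mutation window: explicit -u/-d flanks take precedence;
    # otherwise -l is split evenly around the sequence center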
    if options.mut_up > 0 or options.mut_down > 0:
        options.mut_len = options.mut_up + options.mut_down
    else:
        assert (options.mut_len > 0)
        options.mut_up = options.mut_len // 2
        options.mut_down = options.mut_len - options.mut_up

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read targets
    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()

    #################################################################
    # SNP sequence dataset

    # load SNPs
    snps = vcf.vcf_snps(vcf_file)

    # get one hot coded input sequences
    if not options.genome2_fasta:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps_seq1(
            snps,
            params_model['seq_length'],
            options.genome1_fasta,
            return_seqs=True)
    else:
        seqs_1hot, seq_headers, snps, seqs_dna = vcf.snps2_seq1(
            snps,
            params_model['seq_length'],
            options.genome1_fasta,
            options.genome2_fasta,
            return_seqs=True)
    num_seqs = seqs_1hot.shape[0]

    # determine mutation region limits
    seq_mid = params_model['seq_length'] // 2
    mut_start = seq_mid - options.mut_up
    mut_end = mut_start + options.mut_len

    # make sequence generator
    seqs_gen = satmut_gen(seqs_dna, mut_start, mut_end)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('label', data=np.array(seq_headers, dtype='S'))
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))
    for sad_stat in options.sad_stats:
        scores_h5.create_dataset(sad_stat,
                                 dtype='float16',
                                 shape=(num_seqs, options.mut_len, 4,
                                        num_targets))

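    # one reference prediction plus three alternate nucleotides at each of
    # the mut_len mutated positions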
    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
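    # a single ScoreWorker drains the queue; widen range(1) for more threads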
    for i in range(1):
        sw = ScoreWorker(score_queue, scores_h5, options.sad_stats, mut_start,
                         mut_end)
        sw.start()
        score_threads.append(sw)

    #################################################################
    # predict scores and write output

    # find center
    preds_length = seqnn_model.target_lengths[0]
    center_start = preds_length // 2
    if preds_length % 2 == 0:
        center_end = center_start + 2
    else:
        center_end = center_start + 1

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, seqs_gen,
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_seqs):
        print('Predicting %d' % si, flush=True)

        # collect sequence predictions
        seq_preds_sum = []
        seq_preds_center = []
        seq_preds_scd = []
        preds_mut0 = preds_stream[pi]
        for spi in range(preds_per_seq):
            preds_mut = preds_stream[pi]
            preds_sum = preds_mut.sum(axis=0)
            seq_preds_sum.append(preds_sum)
            if 'center' in options.sad_stats:
                preds_center = preds_mut[center_start:center_end, :].sum(
                    axis=0)
                seq_preds_center.append(preds_center)
            if 'scd' in options.sad_stats:
                preds_scd = np.sqrt(((preds_mut - preds_mut0)**2).sum(axis=0))
                seq_preds_scd.append(preds_scd)
            pi += 1
        seq_preds_sum = np.array(seq_preds_sum)
        seq_preds_center = np.array(seq_preds_center)
        seq_preds_scd = np.array(seq_preds_scd)

        # wait for previous to finish
        score_queue.join()

        # queue sequence for scoring
        seq_pred_stats = (seq_preds_sum, seq_preds_center, seq_preds_scd)
        score_queue.put((seqs_dna[si], seq_pred_stats, si))

        gc.collect()

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
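
# Entry point so the script can be executed directly (the usage string
# above references %prog).
if __name__ == '__main__':
    main()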
Example #2
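# Assumed imports (omitted in the original snippet); cluster_snps, SNPWorker,
# initialize_output_h5, write_snp, and write_pct are defined elsewhere in the
# same script.
import json
import os
import pickle
from optparse import OptionParser
from queue import Queue

import numpy as np
import pandas as pd
import pysam

from basenji import seqnn, stream
from basenji import vcf as bvcf
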
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs to lie in the center region '
        '[Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '--flip',
        dest='flip_ref',
        default=False,
        action='store_true',
        help='Flip reference/alternate alleles when the variant is simple '
        '[Default: %default]'
    )
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--threads',
        dest='threads',
        default=False,
        action='store_true',
        help='Run CPU math and output in a separate thread [Default: %default]'
    )
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters file, model file, and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    if options.penultimate:
        parser.error('Not implemented for TF2')

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()
    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])
    else:
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, params_model['seq_length'],
                                options.center_pct)

    # delimit sequence boundaries
    for sc in snp_clusters:
        sc.delimit(params_model['seq_length'])

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    if options.threads:
        snp_threads = []
        snp_queue = Queue()
        for i in range(1):
            sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                           options.log_pseudo)
            sw.start()
            snp_threads.append(sw)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    # SNP index
    si = 0

    for snp_cluster in snp_clusters:
        ref_preds = preds_stream[pi]
        pi += 1

        for snp in snp_cluster.snps:
            # print(snp, flush=True)

            alt_preds = preds_stream[pi]
            pi += 1

            if snp_flips[si]:
                ref_preds, alt_preds = alt_preds, ref_preds

            if options.threads:
                # queue SNP
                snp_queue.put((ref_preds, alt_preds, si))
            else:
                # process SNP
                write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                          options.log_pseudo)

            # update SNP index
            si += 1

    # finish queue
    if options.threads:
        print('Waiting for threads to finish.', flush=True)
        snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
Example #3
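# Assumed imports (omitted in the original snippet); bigwig_write is a helper
# defined elsewhere in the same script.
import json
import os
import pickle
from optparse import OptionParser

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf

from basenji import bed, dna_io, seqnn, stream
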
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-b',
        dest='bigwig_indexes',
        default=None,
        help='Comma-separated list of target indexes to write BigWigs')
    parser.add_option('-e',
                      dest='embed_layer',
                      default=None,
                      type='int',
                      help='Embed sequences using the specified layer index.')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-g',
                      dest='genome_file',
                      default=None,
                      help='Chromosome length information [Default: %default]')
    parser.add_option(
        '-l',
        dest='site_length',
        default=None,
        type='int',
        help='Prediction site length. [Default: params.seq_length]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='pred_out',
                      help='Output directory [Default: %default]')
    # parser.add_option('--plots', dest='plots',
    #     default=False, action='store_true',
    #     help='Make heatmap plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('-s',
                      dest='sum',
                      default=False,
                      action='store_true',
                      help='Sum site predictions [Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        params_file = args[0]
        model_file = args[1]
        bed_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        bed_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)
    else:
        parser.error('Must provide parameters file, model file, and BED file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]

    if options.bigwig_indexes is not None:
        options.bigwig_indexes = [
            int(bi) for bi in options.bigwig_indexes.split(',')
        ]
    else:
        options.bigwig_indexes = []

    if len(options.bigwig_indexes) > 0:
        bigwig_dir = '%s/bigwig' % options.out_dir
        if not os.path.isdir(bigwig_dir):
            os.mkdir(bigwig_dir)

    #################################################################
    # read parameters and collect target information

    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    # initialize model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    if options.embed_layer is not None:
        seqnn_model.build_embed(options.embed_layer)
        _, preds_length, preds_depth = seqnn_model.embed.output.shape
    else:
        _, preds_length, preds_depth = seqnn_model.model.output.shape

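    # TF1-style graphs report Dimension objects instead of ints; unwrap them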
    if type(preds_length) == tf.compat.v1.Dimension:
        preds_length = preds_length.value
        preds_depth = preds_depth.value

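    # predictions come in preds_window-bp bins; seq_crop is the number of
    # input base pairs trimmed from each edge before scoring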
    preds_window = seqnn_model.model_strides[0]
    seq_crop = seqnn_model.target_crops[0] * preds_window

    #################################################################
    # sequence dataset

    if options.site_length is None:
        options.site_length = preds_window * preds_length
        print('site_length: %d' % options.site_length)

    # construct model sequences
    model_seqs_dna, model_seqs_coords = bed.make_bed_seqs(
        bed_file,
        options.genome_fasta,
        params_model['seq_length'],
        stranded=False)

    # construct site coordinates
    site_seqs_coords = bed.read_bed_coords(bed_file, options.site_length)

    # filter sequences for this worker
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(model_seqs_dna),
                                    options.processes + 1,
                                    dtype='int')
        model_seqs_dna = model_seqs_dna[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]
        model_seqs_coords = model_seqs_coords[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]
        site_seqs_coords = site_seqs_coords[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]

    num_seqs = len(model_seqs_dna)

    #################################################################
    # setup output

    assert (preds_length % 2 == 0)
    preds_mid = preds_length // 2

    assert (options.site_length % preds_window == 0)
    site_preds_length = options.site_length // preds_window

    assert (site_preds_length % 2 == 0)
    site_preds_start = preds_mid - site_preds_length // 2
    site_preds_end = site_preds_start + site_preds_length

    # initialize HDF5
    out_h5_file = '%s/predict.h5' % options.out_dir
    if os.path.isfile(out_h5_file):
        os.remove(out_h5_file)
    out_h5 = h5py.File(out_h5_file, 'w')

    # create predictions
    if options.sum:
        out_h5.create_dataset('preds',
                              shape=(num_seqs, preds_depth),
                              dtype='float16')
    else:
        out_h5.create_dataset('preds',
                              shape=(num_seqs, site_preds_length, preds_depth),
                              dtype='float16')

    # store site coordinates
    site_seqs_chr, site_seqs_start, site_seqs_end = zip(*site_seqs_coords)
    site_seqs_chr = np.array(site_seqs_chr, dtype='S')
    site_seqs_start = np.array(site_seqs_start)
    site_seqs_end = np.array(site_seqs_end)
    out_h5.create_dataset('chrom', data=site_seqs_chr)
    out_h5.create_dataset('start', data=site_seqs_start)
    out_h5.create_dataset('end', data=site_seqs_end)

    #################################################################
    # predict scores, write output

    # define sequence generator
    def seqs_gen():
        for seq_dna in model_seqs_dna:
            yield dna_io.dna_1hot(seq_dna)

    # predict
    preds_stream = stream.PredStreamGen(seqnn_model, seqs_gen(),
                                        params['train']['batch_size'])

    for si in range(num_seqs):
        preds_seq = preds_stream[si]

        # slice site
        preds_site = preds_seq[site_preds_start:site_preds_end, :]

        # write
        if options.sum:
            out_h5['preds'][si] = preds_site.sum(axis=0)
        else:
            out_h5['preds'][si] = preds_site

        # write bigwig
        for ti in options.bigwig_indexes:
            bw_file = '%s/s%d_t%d.bw' % (bigwig_dir, si, ti)
            bigwig_write(preds_seq[:, ti], model_seqs_coords[si], bw_file,
                         options.genome_file, seq_crop)

    # close output HDF5
    out_h5.close()
Example #4
def main():
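    # NOTE: the page omits this example's header and imports; it appears to
    # use the same modules as Example #1 plus pickle and basenji.bed, with
    # satmut_gen, ScoreWorker, and (for --plots) a PlotWorker defined
    # elsewhere in the same script.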
    usage = 'usage: %prog [options] <params_file> <model_file> <bed_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='mut_down',
        default=0,
        type='int',
        help=
        'Nucleotides downstream of center sequence to mutate [Default: %default]'
    )
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '-l',
        dest='mut_len',
        default=0,
        type='int',
        help='Length of center sequence to mutate [Default: %default]')
    parser.add_option('-o',
                      dest='out_dir',
                      default='sat_mut',
                      help='Output directory [Default: %default]')
    parser.add_option('--plots',
                      dest='plots',
                      default=False,
                      action='store_true',
                      help='Make heatmap plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Ensemble forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='sum',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '-u',
        dest='mut_up',
        default=0,
        type='int',
        help=
        'Nucleotides upstream of center sequence to mutate [Default: %default]'
    )
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        bed_file = args[2]

    elif len(args) == 4:
        # master script
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        bed_file = args[3]

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        bed_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide parameters file, model file, and BED file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = [
        sad_stat.lower() for sad_stat in options.sad_stats.split(',')
    ]

    if options.mut_up > 0 or options.mut_down > 0:
        options.mut_len = options.mut_up + options.mut_down
    else:
        assert (options.mut_len > 0)
        options.mut_up = options.mut_len // 2
        options.mut_down = options.mut_len - options.mut_up

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    # read targets
    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_table(options.targets_file, index_col=0)
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    num_targets = seqnn_model.num_targets()

    #################################################################
    # sequence dataset

    # read sequences from BED
    seqs_dna, seqs_coords = bed.make_bed_seqs(bed_file,
                                              options.genome_fasta,
                                              params_model['seq_length'],
                                              stranded=True)

    # filter sequences for this worker
    if options.processes is not None:
        worker_bounds = np.linspace(0,
                                    len(seqs_dna),
                                    options.processes + 1,
                                    dtype='int')
        seqs_dna = seqs_dna[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]
        seqs_coords = seqs_coords[
            worker_bounds[worker_index]:worker_bounds[worker_index + 1]]

    num_seqs = len(seqs_dna)

    # determine mutation region limits
    seq_mid = params_model['seq_length'] // 2
    mut_start = seq_mid - options.mut_up
    mut_end = mut_start + options.mut_len

    # make sequence generator
    seqs_gen = satmut_gen(seqs_dna, mut_start, mut_end)

    #################################################################
    # setup output

    scores_h5_file = '%s/scores.h5' % options.out_dir
    if os.path.isfile(scores_h5_file):
        os.remove(scores_h5_file)
    scores_h5 = h5py.File(scores_h5_file, 'w')
    scores_h5.create_dataset('seqs',
                             dtype='bool',
                             shape=(num_seqs, options.mut_len, 4))
    for sad_stat in options.sad_stats:
        scores_h5.create_dataset(sad_stat,
                                 dtype='float16',
                                 shape=(num_seqs, options.mut_len, 4,
                                        num_targets))

    # store mutagenesis sequence coordinates
    seqs_chr, seqs_start, _, seqs_strand = zip(*seqs_coords)
    seqs_chr = np.array(seqs_chr, dtype='S')
    seqs_start = np.array(seqs_start) + mut_start
    seqs_end = seqs_start + options.mut_len
    seqs_strand = np.array(seqs_strand, dtype='S')
    scores_h5.create_dataset('chrom', data=seqs_chr)
    scores_h5.create_dataset('start', data=seqs_start)
    scores_h5.create_dataset('end', data=seqs_end)
    scores_h5.create_dataset('strand', data=seqs_strand)

    preds_per_seq = 1 + 3 * options.mut_len

    score_threads = []
    score_queue = Queue()
    for i in range(1):
        sw = ScoreWorker(score_queue, scores_h5, options.sad_stats, mut_start,
                         mut_end)
        sw.start()
        score_threads.append(sw)
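
    # NOTE: with --plots set, the loop below enqueues onto plot_queue, which
    # this snippet never creates. A minimal sketch, assuming a hypothetical
    # PlotWorker thread class analogous to ScoreWorker (a matching
    # plot_queue.join() would belong next to score_queue.join() below):
    if options.plots:
        plot_queue = Queue()
        plot_worker = PlotWorker(plot_queue, options.out_dir)
        plot_worker.start()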

    #################################################################
    # predict scores, write output

    # find center
    preds_length = seqnn_model.target_lengths[0]
    center_start = preds_length // 2
    if preds_length % 2 == 0:
        center_end = center_start + 2
    else:
        center_end = center_start + 1

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, seqs_gen,
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_seqs):
        print('Predicting %d' % si, flush=True)

        # collect sequence predictions
        seq_preds_sum = []
        seq_preds_center = []
        seq_preds_scd = []
        preds_mut0 = preds_stream[pi]
        for spi in range(preds_per_seq):
            preds_mut = preds_stream[pi]
            preds_sum = preds_mut.sum(axis=0)
            seq_preds_sum.append(preds_sum)
            if 'center' in options.sad_stats:
                preds_center = preds_mut[center_start:center_end, :].sum(
                    axis=0)
                seq_preds_center.append(preds_center)
            if 'scd' in options.sad_stats:
                preds_scd = np.sqrt(((preds_mut - preds_mut0)**2).sum(axis=0))
                seq_preds_scd.append(preds_scd)
            pi += 1
        seq_preds_sum = np.array(seq_preds_sum)
        seq_preds_center = np.array(seq_preds_center)
        seq_preds_scd = np.array(seq_preds_scd)

        # wait for previous to finish
        score_queue.join()

        # queue sequence for scoring
        seq_pred_stats = (seq_preds_sum, seq_preds_center, seq_preds_scd)
        score_queue.put((seqs_dna[si], seq_pred_stats, si))

        # queue sequence for plotting
        if options.plots:
            plot_queue.put((seqs_dna[si], seq_preds_sum, si))

        gc.collect()

    # finish queue
    print('Waiting for threads to finish.', flush=True)
    score_queue.join()

    # close output HDF5
    scores_h5.close()
Example #5
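# Assumed imports (omitted in the original snippet); initialize_output_h5 and
# write_snp are defined elsewhere in the same script.
import json
import os
import pickle
import random
from optparse import OptionParser

import numpy as np
import pandas as pd
import pysam

from basenji import seqnn, stream
from basenji import vcf as bvcf
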
def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-f',
                      dest='genome_fasta',
                      default=None,
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-l',
                      dest='plot_lim_min',
                      default=0.1,
                      type='float',
                      help='Heatmap plot limit [Default: %default]')
    parser.add_option(
        '-m',
        dest='plot_map',
        default=False,
        action='store_true',
        help='Plot contact map for each allele [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='scd',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help=
        'Average forward and reverse complement predictions [Default: %default]'
    )
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='scd_stats',
        default='SCD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters file, model file, and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)
    if options.plot_map:
        plot_dir = options.out_dir
    else:
        plot_dir = None

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.scd_stats = options.scd_stats.split(',')

    random.seed(44)

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_train = params['train']
    params_model = params['model']

    if options.targets_file is not None:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description

    #################################################################
    # setup model

    # load model
    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_ensemble(options.rc, options.shifts)

    # dummy target info
    if options.targets_file is None:
        num_targets = seqnn_model.num_targets()
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

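    # yields reference and alternate one-hot sequences for each SNP; the
    # prediction stream below consumes them in ref/alt pairs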
    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, params_model['seq_length'],
                                          genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    scd_out = initialize_output_h5(options.out_dir, options.scd_stats, snps,
                                   target_ids, target_labels)

    #################################################################
    # predict SNP scores, write output

    write_thread = None

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, scd_out, si, options.scd_stats,
                  plot_dir, seqnn_model.diagonal_offset, options.plot_lim_min)

    genome_open.close()
    scd_out.close()