Example #1
from optparse import OptionParser
import os
import pickle

import numpy as np
import pandas as pd
import pysam
import tensorflow as tf

from basenji import vcf as bvcf

# initialize_output_h5, write_snp, write_pct, and PredStreamGen are
# defined elsewhere in the surrounding script

def main():
    usage = 'usage: %prog [options] <model> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='slice_center',
                      default=None,
                      type='int',
                      help='Slice center positions [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg38.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average forward and reverse complement predictions '
        '[Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option('--species',
                      dest='species',
                      default='human',
                      help='Prediction head to query [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) == 2:
        # single worker
        model_file = args[0]
        vcf_file = args[1]

    elif len(args) == 3:
        # multi separate
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

        # save out dir
        out_dir = options.out_dir

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = out_dir

    elif len(args) == 4:
        # multi worker
        options_pkl_file = args[0]
        model_file = args[1]
        vcf_file = args[2]
        worker_index = int(args[3])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error('Must provide model and VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    #################################################################
    # setup model

    seqnn_model = tf.saved_model.load(model_file).model

    # probe with a null sequence to query the prediction dimensions
    seq_length = seqnn_model.predict_on_batch.input_signature[0].shape[1]
    null_1hot = np.zeros((1, seq_length, 4))
    null_preds = seqnn_model.predict_on_batch(null_1hot)
    null_preds = null_preds[options.species].numpy()
    _, targets_length, num_targets = null_preds.shape

    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')
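        # linspace gives processes+1 cut points; e.g. 100 SNPs across
        # 4 processes -> bounds [0, 25, 50, 75, 100]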

        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])

    else:
        # read SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file)

    num_snps = len(snps)

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # create SNP sequence generator
    def snp_gen():
        for snp in snps:
            # get SNP sequences
            snp_1hot_list = bvcf.snp_seq1(snp, seq_length, genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels, targets_length)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = PredStreamGen(seqnn_model,
                                 snp_gen(),
                                 rc=options.rc,
                                 shifts=options.shifts,
                                 slice_center=options.slice_center,
                                 species=options.species)

    # predictions index
    pi = 0
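    # snp_gen yields the reference sequence then the alternate sequence
    # for each SNP, so the stream is consumed two predictions at a time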

    for si in range(num_snps):
        # get predictions
        ref_preds = preds_stream[pi]
        pi += 1
        alt_preds = preds_stream[pi]
        pi += 1

        # process SNP
        write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                  options.log_pseudo)

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
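
The statistic helpers (initialize_output_h5, write_snp, write_pct) and the
PredStreamGen class live elsewhere in the script. As a rough illustration
only, a write_snp-style SAD computation might look like the sketch below;
the function name and any statistic beyond 'SAD' are assumptions, not the
script's actual implementation.

import numpy as np

def write_snp_sketch(ref_preds, alt_preds, sad_out, si, stats, log_pseudo):
    # ref_preds / alt_preds: (targets_length, num_targets) prediction arrays
    ref_preds = ref_preds.astype('float64')
    alt_preds = alt_preds.astype('float64')

    if 'SAD' in stats:
        # SNP activity difference: the alt-ref delta summed along the sequence
        sad_out['SAD'][si] = (alt_preds - ref_preds).sum(axis=0)

    if 'SADlog' in stats:  # hypothetical log-ratio variant using the pseudocount
        log_alt = np.log2(alt_preds.sum(axis=0) + log_pseudo)
        log_ref = np.log2(ref_preds.sum(axis=0) + log_pseudo)
        sad_out['SADlog'][si] = log_alt - log_ref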
Example #2
from optparse import OptionParser
import json
import os
import pickle
from queue import Queue

import numpy as np
import pandas as pd
import pysam

from basenji import seqnn
from basenji import stream
from basenji import vcf as bvcf

# cluster_snps, initialize_output_h5, write_snp, write_pct, and SNPWorker
# are defined elsewhere in the surrounding script

def main():
    usage = 'usage: %prog [options] <params_file> <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-c',
        dest='center_pct',
        default=0.25,
        type='float',
        help='Require clustered SNPs lie in center region [Default: %default]')
    parser.add_option('-f',
                      dest='genome_fasta',
                      default='%s/data/hg19.fa' % os.environ['BASENJIDIR'],
                      help='Genome FASTA for sequences [Default: %default]')
    parser.add_option(
        '--flip',
        dest='flip_ref',
        default=False,
        action='store_true',
        help='Flip reference/alternate alleles when simple [Default: %default]'
    )
    parser.add_option('--local',
                      dest='local',
                      default=1024,
                      type='int',
                      help='Local SAD score [Default: %default]')
    parser.add_option('-n',
                      dest='norm_file',
                      default=None,
                      help='Normalize SAD scores')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-p',
                      dest='processes',
                      default=None,
                      type='int',
                      help='Number of processes, passed by multi script')
    parser.add_option('--pseudo',
                      dest='log_pseudo',
                      default=1,
                      type='float',
                      help='Log2 pseudocount [Default: %default]')
    parser.add_option(
        '--rc',
        dest='rc',
        default=False,
        action='store_true',
        help='Average forward and reverse complement predictions '
        '[Default: %default]')
    parser.add_option('--shifts',
                      dest='shifts',
                      default='0',
                      type='str',
                      help='Ensemble prediction shifts [Default: %default]')
    parser.add_option(
        '--stats',
        dest='sad_stats',
        default='SAD',
        help='Comma-separated list of stats to save. [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        type='str',
        help='File specifying target indexes and labels in table format')
    parser.add_option(
        '--ti',
        dest='track_indexes',
        default=None,
        type='str',
        help='Comma-separated list of target indexes to output BigWig tracks')
    parser.add_option(
        '--threads',
        dest='threads',
        default=False,
        action='store_true',
        help='Run CPU math and output in a separate thread [Default: %default]'
    )
    parser.add_option(
        '-u',
        dest='penultimate',
        default=False,
        action='store_true',
        help='Compute SED in the penultimate layer [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) == 3:
        # single worker
        params_file = args[0]
        model_file = args[1]
        vcf_file = args[2]

    elif len(args) == 5:
        # multi worker
        options_pkl_file = args[0]
        params_file = args[1]
        model_file = args[2]
        vcf_file = args[3]
        worker_index = int(args[4])

        # load options
        options_pkl = open(options_pkl_file, 'rb')
        options = pickle.load(options_pkl)
        options_pkl.close()

        # update output directory
        options.out_dir = '%s/job%d' % (options.out_dir, worker_index)

    else:
        parser.error(
            'Must provide parameters file, model file, and QTL VCF file')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    if options.track_indexes is None:
        options.track_indexes = []
    else:
        options.track_indexes = [
            int(ti) for ti in options.track_indexes.split(',')
        ]
        if not os.path.isdir('%s/tracks' % options.out_dir):
            os.mkdir('%s/tracks' % options.out_dir)

    options.shifts = [int(shift) for shift in options.shifts.split(',')]
    options.sad_stats = options.sad_stats.split(',')

    #################################################################
    # read parameters and targets

    # read model parameters
    with open(params_file) as params_open:
        params = json.load(params_open)
    params_model = params['model']
    params_train = params['train']

    if options.targets_file is None:
        target_slice = None
    else:
        targets_df = pd.read_csv(options.targets_file, sep='\t', index_col=0)
        target_ids = targets_df.identifier
        target_labels = targets_df.description
        target_slice = targets_df.index

    if options.penultimate:
        parser.error('Not implemented for TF2')

    #################################################################
    # setup model

    seqnn_model = seqnn.SeqNN(params_model)
    seqnn_model.restore(model_file)
    seqnn_model.build_slice(target_slice)
    seqnn_model.build_ensemble(options.rc, options.shifts)
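    # the ensemble averages predictions over the reverse complement and
    # the requested sequence shifts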

    num_targets = seqnn_model.num_targets()
    if options.targets_file is None:
        target_ids = ['t%d' % ti for ti in range(num_targets)]
        target_labels = [''] * len(target_ids)

    #################################################################
    # load SNPs

    # filter for worker SNPs
    if options.processes is not None:
        # determine boundaries
        num_snps = bvcf.vcf_count(vcf_file)
        worker_bounds = np.linspace(0,
                                    num_snps,
                                    options.processes + 1,
                                    dtype='int')

        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta,
                             start_i=worker_bounds[worker_index],
                             end_i=worker_bounds[worker_index + 1])
    else:
        # read sorted SNPs from VCF
        snps = bvcf.vcf_snps(vcf_file,
                             require_sorted=True,
                             flip_ref=options.flip_ref,
                             validate_ref_fasta=options.genome_fasta)

    # cluster SNPs by position
    snp_clusters = cluster_snps(snps, params_model['seq_length'],
                                options.center_pct)
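    # nearby SNPs are grouped so a single reference sequence prediction
    # can be shared by every SNP in the cluster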

    # delimit sequence boundaries
    for sc in snp_clusters:
        sc.delimit(params_model['seq_length'])

    # open genome FASTA
    genome_open = pysam.Fastafile(options.genome_fasta)

    # make SNP sequence generator
    def snp_gen():
        for sc in snp_clusters:
            snp_1hot_list = sc.get_1hots(genome_open)
            for snp_1hot in snp_1hot_list:
                yield snp_1hot

    #################################################################
    # setup output

    snp_flips = np.array([snp.flipped for snp in snps], dtype='bool')

    sad_out = initialize_output_h5(options.out_dir, options.sad_stats, snps,
                                   target_ids, target_labels)

    if options.threads:
        snp_threads = []
        snp_queue = Queue()
        for _ in range(1):  # one writer thread
            sw = SNPWorker(snp_queue, sad_out, options.sad_stats,
                           options.log_pseudo)
            sw.start()
            snp_threads.append(sw)

    #################################################################
    # predict SNP scores, write output

    # initialize predictions stream
    preds_stream = stream.PredStreamGen(seqnn_model, snp_gen(),
                                        params_train['batch_size'])

    # predictions index
    pi = 0

    # SNP index
    si = 0
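    # each cluster contributes one reference prediction, then one
    # alternate prediction per SNP in the cluster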

    for snp_cluster in snp_clusters:
        ref_preds = preds_stream[pi]
        pi += 1

        for snp in snp_cluster.snps:
            # print(snp, flush=True)

            alt_preds = preds_stream[pi]
            pi += 1

            if snp_flips[si]:
                ref_preds, alt_preds = alt_preds, ref_preds

            if options.threads:
                # queue SNP
                snp_queue.put((ref_preds, alt_preds, si))
            else:
                # process SNP
                write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats,
                          options.log_pseudo)

            # update SNP index
            si += 1

    # finish queue
    if options.threads:
        print('Waiting for threads to finish.', flush=True)
        snp_queue.join()

    # close genome
    genome_open.close()

    ###################################################
    # compute SAD distributions across variants

    write_pct(sad_out, options.sad_stats)
    sad_out.close()
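
With --threads, the CPU statistics and HDF5 writes run in a consumer thread
that overlaps with GPU prediction. The real SNPWorker is defined elsewhere
in the script; a minimal sketch of such a queue consumer (class name and
constructor signature assumed) could look like this:

import threading

class SNPWorkerSketch(threading.Thread):
    """Consume (ref_preds, alt_preds, si) tuples and write their stats."""

    def __init__(self, snp_queue, sad_out, stats, log_pseudo):
        super().__init__(daemon=True)
        self.queue = snp_queue
        self.sad_out = sad_out
        self.stats = stats
        self.log_pseudo = log_pseudo

    def run(self):
        while True:
            ref_preds, alt_preds, si = self.queue.get()
            # write_snp is the script's stat-writing helper used above
            write_snp(ref_preds, alt_preds, self.sad_out, si, self.stats,
                      self.log_pseudo)
            # task_done() is what lets snp_queue.join() return
            self.queue.task_done()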