Exemple #1
0
def _read_text_predictions(pred_file):
    """Load whitespace-delimited predictions, one sequence per line.

    Args:
        pred_file: Path to a text file in which every line holds the
            whitespace-separated prediction values for one input sequence.

    Returns:
        A float16 numpy array with one row per line of the file.
    """
    # the context manager closes the handle even if a value fails to parse
    with open(pred_file) as pred_in:
        seq_preds = [
            np.array([float(p) for p in line.split()], dtype='float16')
            for line in pred_in
        ]
    return np.array(seq_preds, dtype='float16')


def main():
    """Score the SNPs of a VCF file with a Torch model and report SADs.

    Command line: ``[options] <model_th> <vcf_file>``.

    The reference and alternate allele sequences of every SNP are one hot
    coded and written to ``<out_dir>/model_in.h5``.  Unless pre-computed
    predictions are supplied with ``-d``, ``basset_predict.lua`` is run on
    that file.  For every alternate allele the SAD (alternate prediction
    minus reference prediction) is written to a table in ``out_dir`` and,
    with ``-e``, drawn as heatmaps grouped by index SNP.

    Raises:
        RuntimeError: If the Torch prediction script exits with a non-zero
            status.
    """
    usage = 'usage: %prog [options] <model_th> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('--cuda',
                      dest='cuda',
                      default=False,
                      action='store_true',
                      help='Predict on the GPU [Default: %default]')
    parser.add_option('--cudnn',
                      dest='cudnn',
                      default=False,
                      action='store_true',
                      help='Predict on the GPU w/ CuDNN [Default: %default]')
    parser.add_option(
        '-d',
        dest='model_hdf5_file',
        default=None,
        help='Pre-computed model output as HDF5 [Default: %default]')
    parser.add_option(
        '--dense',
        dest='dense_table',
        default=False,
        action='store_true',
        help=
        'Print a dense SNP x Targets table, as opposed to a SNP/Target pair per line [Default: %default]'
    )
    parser.add_option(
        '-e',
        dest='heatmaps',
        default=False,
        action='store_true',
        help='Draw score heatmaps, grouped by index SNP [Default: %default]')
    parser.add_option(
        '-f',
        dest='genome_fasta',
        default='%s/data/genomes/hg19.fa' % os.environ['BASSETDIR'],
        help=
        'Genome FASTA from which sequences will be drawn [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default=None,
        help='Genome FASTA which which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA which which minor allele sequences will be drawn')
    parser.add_option(
        '-i',
        dest='index_snp',
        default=False,
        action='store_true',
        help=
        'SNPs are labeled with their index SNP as column 6 [Default: %default]'
    )
    parser.add_option(
        '-l',
        dest='seq_len',
        type='int',
        default=1000,
        help='Sequence length provided to the model [Default: %default]')
    parser.add_option('-m',
                      dest='min_limit',
                      default=0.1,
                      type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option(
        '-s',
        dest='score',
        default=False,
        action='store_true',
        help='SNPs are labeled with scores as column 7 [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    # parser.error() exits the program, so no else branch is needed
    if len(args) != 2:
        parser.error('Must provide Torch model and VCF file')
    model_th, vcf_file = args

    # makedirs (rather than mkdir) also handles a nested output directory
    if not os.path.isdir(options.out_dir):
        os.makedirs(options.out_dir)

    #################################################################
    # prep SNP sequences
    #################################################################
    # load SNPs
    snps = bvcf.vcf_snps(vcf_file, options.index_snp, options.score,
                         options.genome2_fasta is not None)

    # get one hot coded input sequences; the two-genome path is only used
    # when both --f1 and --f2 are given
    if not options.genome1_fasta or not options.genome2_fasta:
        seq_vecs, seq_headers, snps = bvcf.snps_seq1(snps, options.seq_len,
                                                     options.genome_fasta)
    else:
        seq_vecs, seq_headers, snps = bvcf.snps2_seq1(snps, options.seq_len,
                                                      options.genome1_fasta,
                                                      options.genome2_fasta)

    # reshape sequences for torch
    seq_vecs = seq_vecs.reshape(
        (seq_vecs.shape[0], 4, 1, seq_vecs.shape[1] // 4))

    # write to HDF5
    model_in_file = '%s/model_in.h5' % options.out_dir
    with h5py.File(model_in_file, 'w') as h5f:
        h5f.create_dataset('test_in', data=seq_vecs)

    #################################################################
    # predict in Torch
    #################################################################
    # only a prediction file generated here is temporary; a pre-computed
    # file passed with -d belongs to the user and must survive the run
    preds_are_temp = options.model_hdf5_file is None
    if preds_are_temp:
        options.model_hdf5_file = '%s/model_out.txt' % options.out_dir

        # an argument list (no shell) keeps paths containing spaces or
        # shell metacharacters intact
        cmd = ['%s/src/basset_predict.lua' % os.environ['BASSETDIR'], '-rc']
        if options.cudnn:
            cmd.append('-cudnn')
        elif options.cuda:
            cmd.append('-cuda')
        cmd += [model_th, model_in_file, options.model_hdf5_file]
        print(' '.join(cmd))
        exit_code = subprocess.call(cmd)

        # clean up
        os.remove(model_in_file)

        if exit_code != 0:
            raise RuntimeError(
                'basset_predict.lua failed with exit code %d' % exit_code)

    # read in predictions
    seq_preds = _read_text_predictions(options.model_hdf5_file)

    # clean up
    if preds_are_temp:
        os.remove(options.model_hdf5_file)

    #################################################################
    # collect and print SADs
    #################################################################
    if options.targets_file is None:
        target_labels = ['t%d' % ti for ti in range(seq_preds.shape[1])]
    else:
        # the label is the first column; blank lines are skipped
        with open(options.targets_file) as targets_in:
            target_labels = [
                line.split()[0] for line in targets_in if line.strip()
            ]

    sep = ',' if options.csv else ' '

    if options.dense_table:
        # the dense table has no header line and always uses the .txt name
        header_cols = None
        sad_file = '%s/sad_table.txt' % options.out_dir
    else:
        # column names contain no spaces so that the space-delimited
        # header has exactly one token per column
        header_cols = ('rsid', 'index', 'score', 'ref', 'alt', 'target',
                       'ref_pred', 'alt_pred', 'sad')
        if options.csv:
            sad_file = '%s/sad_table.csv' % options.out_dir
        else:
            sad_file = '%s/sad_table.txt' % options.out_dir

    # hash by index snp
    sad_matrices = {}
    sad_labels = {}
    sad_scores = {}

    with open(sad_file, 'w') as sad_out:
        if header_cols is not None:
            print(sep.join(header_cols), file=sad_out)

        # predictions are ordered as the reference row of a SNP followed by
        # one row per alternate allele; pi walks through those rows
        pi = 0
        for snp in snps:
            # get reference prediction
            ref_preds = seq_preds[pi, :]
            pi += 1

            for alt_al in snp.alt_alleles:
                # get alternate prediction
                alt_preds = seq_preds[pi, :]
                pi += 1

                # normalize by reference
                alt_sad = alt_preds - ref_preds
                sad_matrices.setdefault(snp.index_snp, []).append(alt_sad)

                # label as mutation from reference
                ref_cap = bvcf.cap_allele(snp.ref_allele)
                alt_cap = bvcf.cap_allele(alt_al)
                alt_label = '%s_%s>%s' % (snp.rsid, ref_cap, alt_cap)
                sad_labels.setdefault(snp.index_snp, []).append(alt_label)

                # save scores
                sad_scores.setdefault(snp.index_snp, []).append(snp.score)

                # set index SNP ('.' when SNPs carry no index label)
                if options.index_snp:
                    snp_is = '%-13s' % snp.index_snp
                else:
                    snp_is = '%-13s' % '.'

                # set score ('.' when SNPs carry no score)
                if options.score:
                    snp_score = '%5.3f' % snp.score
                else:
                    snp_score = '%5s' % '.'

                # print table line(s)
                if options.dense_table:
                    # one line per allele: reference prediction and SAD
                    # for every target
                    cols = [snp.rsid, snp_is, snp_score, ref_cap, alt_cap]
                    for ref_pred, sad in zip(ref_preds, alt_sad):
                        cols += ['%.4f' % ref_pred, '%.4f' % sad]
                    print(sep.join([str(c) for c in cols]), file=sad_out)

                else:
                    # one line per allele/target pair
                    for ti in range(len(alt_sad)):
                        cols = (snp.rsid, snp_is, snp_score, ref_cap,
                                alt_cap, target_labels[ti], ref_preds[ti],
                                alt_preds[ti], alt_sad[ti])
                        if options.csv:
                            print(','.join([str(c) for c in cols]),
                                  file=sad_out)
                        else:
                            print(
                                '%-13s %s %5s %6s %6s %12s %6.4f %6.4f %7.4f'
                                % cols,
                                file=sad_out)

    #################################################################
    # plot SAD heatmaps
    #################################################################
    if options.heatmaps:
        for ii in sad_matrices:
            # convert fully to numpy arrays
            sad_matrix = abs(np.array(sad_matrices[ii]))
            print(ii, sad_matrix.shape)

            # groups holding a single allele are not plotted
            if sad_matrix.shape[0] > 1:
                vlim = max(options.min_limit, sad_matrix.max())
                score_mat = np.reshape(np.array(sad_scores[ii]), (-1, 1))

                # plot heatmap; the figure height grows with the row count
                plt.figure(figsize=(20, 0.5 + 0.5 * sad_matrix.shape[0]))

                if options.score:
                    # lay out scores in a narrow column left of the SADs
                    cols = 12
                    ax_score = plt.subplot2grid((1, cols), (0, 0))
                    ax_sad = plt.subplot2grid((1, cols), (0, 1),
                                              colspan=(cols - 1))

                    sns.heatmap(score_mat,
                                xticklabels=False,
                                yticklabels=False,
                                vmin=0,
                                vmax=1,
                                cmap='Reds',
                                cbar=False,
                                ax=ax_score)
                else:
                    ax_sad = plt.gca()

                sns.heatmap(sad_matrix,
                            xticklabels=target_labels,
                            yticklabels=sad_labels[ii],
                            vmin=0,
                            vmax=vlim,
                            ax=ax_sad)

                for tick in ax_sad.get_xticklabels():
                    tick.set_rotation(-45)
                    tick.set_horizontalalignment('left')
                    tick.set_fontsize(5)

                plt.tight_layout()
                # '.' is the group key when SNPs carry no index SNP label
                if ii == '.':
                    out_pdf = '%s/sad_heat.pdf' % options.out_dir
                else:
                    out_pdf = '%s/sad_%s_heat.pdf' % (options.out_dir, ii)
                plt.savefig(out_pdf)
                plt.close()
Exemple #2
0
def main():
    """Plot saturated mutagenesis heatmaps around the SNPs of a VCF file.

    Command line: ``[options] <model_file> <vcf_file>``.

    The SNP sequences are one hot coded and written to
    ``<out_dir>/model_in.h5``.  Unless pre-computed predictions are supplied
    with ``-d``, ``basset_sat_predict.lua`` is run on that file.  For every
    sequence and every selected target a sequence logo, loss/gain curves and
    a heatmap of prediction changes are saved to ``out_dir``, and the
    per-position loss and gain of every target are written to
    ``<out_dir>/table.txt``.

    Raises:
        RuntimeError: If the Torch prediction script exits with a non-zero
            status.
    """
    usage = 'usage: %prog [options] <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option(
        '-d',
        dest='model_hdf5_file',
        default=None,
        help='Pre-computed model output as HDF5 [Default: %default]')
    parser.add_option(
        '-f',
        dest='genome_fasta',
        default='%s/data/genomes/hg19.fa' % os.environ['BASSETDIR'],
        help=
        'Genome FASTA from which sequences will be drawn [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default=None,
        help='Genome FASTA which which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA which which minor allele sequences will be drawn')
    parser.add_option(
        '-g',
        dest='gain_height',
        default=False,
        action='store_true',
        help=
        'Nucleotide heights determined by the max of loss and gain [Default: %default]'
    )
    parser.add_option(
        '-l',
        dest='seq_len',
        type='int',
        default=600,
        help='Sequence length provided to the model [Default: %default]')
    parser.add_option('-m',
                      dest='min_limit',
                      default=0.1,
                      type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option(
        '-n',
        dest='center_nt',
        default=200,
        type='int',
        help=
        'Nt around the SNP to mutate and plot in the heatmap [Default: %default]'
    )
    parser.add_option('-o',
                      dest='out_dir',
                      default='heat',
                      help='Output directory [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets',
        default='-1',
        help=
        'Comma-separated list of target indexes to plot (or -1 for all) [Default: %default]'
    )
    (options, args) = parser.parse_args()

    # parser.error() exits the program, so no else branch is needed
    if len(args) != 2:
        parser.error(
            'Must provide Basset model file and input SNPs in VCF format')
    model_file, vcf_file = args

    # makedirs (rather than mkdir) also handles a nested output directory
    if not os.path.isdir(options.out_dir):
        os.makedirs(options.out_dir)

    #################################################################
    # prep SNP sequences
    #################################################################
    # load SNPs
    snps = bvcf.vcf_snps(vcf_file, pos2=(options.genome2_fasta is not None))

    # get one hot coded input sequences; the two-genome path is only used
    # when both --f1 and --f2 are given
    if not options.genome1_fasta or not options.genome2_fasta:
        seqs_1hot, seq_headers, snps, seqs = bvcf.snps_seq1(
            snps, options.seq_len, options.genome_fasta, return_seqs=True)
    else:
        seqs_1hot, seq_headers, snps, seqs = bvcf.snps2_seq1(
            snps,
            options.seq_len,
            options.genome1_fasta,
            options.genome2_fasta,
            return_seqs=True)

    # reshape sequences for torch
    seqs_1hot = seqs_1hot.reshape(
        (seqs_1hot.shape[0], 4, 1, seqs_1hot.shape[1] // 4))

    # write to HDF5
    model_input_hdf5 = '%s/model_in.h5' % options.out_dir
    with h5py.File(model_input_hdf5, 'w') as h5f:
        h5f.create_dataset('test_in', data=seqs_1hot)

    #################################################################
    # Torch predict modifications
    #################################################################
    if options.model_hdf5_file is None:
        options.model_hdf5_file = '%s/model_out.h5' % options.out_dir

        # an argument list (no shell) keeps paths containing spaces or
        # shell metacharacters intact
        torch_cmd = [
            '%s/src/basset_sat_predict.lua' % os.environ['BASSETDIR'],
            '-center_nt',
            '%d' % options.center_nt, model_file, model_input_hdf5,
            options.model_hdf5_file
        ]
        exit_code = subprocess.call(torch_cmd)
        if exit_code != 0:
            raise RuntimeError(
                'basset_sat_predict.lua failed with exit code %d' % exit_code)

    #################################################################
    # load modification predictions
    #################################################################
    # indexed below as [sequence, row, position, target]; the rows are
    # labelled A/C/G/T in the heatmap -- presumably one row per substituted
    # nucleotide (TODO confirm against basset_sat_predict.lua)
    with h5py.File(options.model_hdf5_file, 'r') as hdf5_in:
        seq_mod_preds = np.array(hdf5_in['seq_mod_preds'])

    # trim seqs to match seq_mod_preds length
    delta_start = 0
    delta_len = seq_mod_preds.shape[2]
    if delta_len < options.seq_len:
        delta_start = (options.seq_len - delta_len) // 2
        for si in range(len(seqs)):
            seqs[si] = seqs[si][delta_start:delta_start + delta_len]

    # decide which cells to plot
    if options.targets == '-1':
        plot_cells = range(seq_mod_preds.shape[3])
    else:
        plot_cells = [int(ci) for ci in options.targets.split(',')]

    #################################################################
    # plot
    #################################################################
    rdbu = sns.color_palette("RdBu_r", 10)

    nts = 'ACGT'

    # the context manager closes the table even if plotting fails
    with open('%s/table.txt' % options.out_dir, 'w') as table_out:
        for si in range(seq_mod_preds.shape[0]):
            header = seq_headers[si]
            header_fs = fs_clean(header)
            seq = seqs[si]

            # plot some descriptive heatmaps for each individual cell type
            for ci in plot_cells:
                seq_mod_preds_cell = seq_mod_preds[si, :, :, ci]
                real_pred_cell = basset_sat.get_real_pred(
                    seq_mod_preds_cell, seq)

                # compute matrices: per-position prediction change for
                # every row, plus the smallest and largest change
                norm_matrix = seq_mod_preds_cell - real_pred_cell
                min_scores = seq_mod_preds_cell.min(axis=0)
                max_scores = seq_mod_preds_cell.max(axis=0)
                minmax_matrix = np.vstack(
                    [min_scores - real_pred_cell, max_scores - real_pred_cell])

                # prepare figure
                sns.set(style='white', font_scale=0.5)
                # NOTE(review): the return value is discarded, so this call
                # likely has no effect on the plot -- confirm and remove
                sns.axes_style({'axes.linewidth': 1})
                spp = basset_sat.subplot_params(seq_mod_preds_cell.shape[1])
                plt.figure(figsize=(20, 3))
                ax_logo = plt.subplot2grid(
                    (3, spp['heat_cols']), (0, spp['logo_start']),
                    colspan=(spp['logo_end'] - spp['logo_start']))
                ax_sad = plt.subplot2grid(
                    (3, spp['heat_cols']), (1, spp['sad_start']),
                    colspan=(spp['sad_end'] - spp['sad_start']))
                ax_heat = plt.subplot2grid((3, spp['heat_cols']), (2, 0),
                                           colspan=spp['heat_cols'])

                # print a WebLogo of the sequence; letter heights scale
                # with the loss, or with the larger of loss and gain
                vlim = max(options.min_limit, abs(minmax_matrix).max())
                if options.gain_height:
                    seq_heights = 0.25 + 1.75 / vlim * (
                        abs(minmax_matrix).max(axis=0))
                else:
                    seq_heights = 0.25 + 1.75 / vlim * (-minmax_matrix[0])
                logo_eps = '%s/%s_c%d_seq.eps' % (options.out_dir, header_fs,
                                                  ci)
                seq_logo(seq, seq_heights, logo_eps, color_mode='meme')

                # add to figure; the EPS logo is converted to PNG with the
                # external "convert" program so that it can be loaded
                logo_png = '%s.png' % logo_eps[:-4]
                subprocess.call(
                    ['convert', '-density', '300', logo_eps, logo_png])
                logo = Image.open(logo_png)
                ax_logo.imshow(logo)
                ax_logo.set_axis_off()

                # plot loss and gain SAD scores
                ax_sad.plot(-minmax_matrix[0],
                            c=rdbu[0],
                            label='loss',
                            linewidth=1)
                ax_sad.plot(minmax_matrix[1],
                            c=rdbu[-1],
                            label='gain',
                            linewidth=1)
                ax_sad.set_xlim(0, minmax_matrix.shape[1])
                ax_sad.legend()
                for axis in ['top', 'bottom', 'left', 'right']:
                    ax_sad.spines[axis].set_linewidth(0.5)

                # plot real-normalized scores on a symmetric color scale
                vlim = max(options.min_limit, abs(norm_matrix).max())
                sns.heatmap(norm_matrix,
                            linewidths=0,
                            cmap='RdBu_r',
                            vmin=-vlim,
                            vmax=vlim,
                            xticklabels=False,
                            ax=ax_heat)
                ax_heat.yaxis.set_ticklabels(nts, rotation='horizontal')

                # save final figure
                plt.tight_layout()
                plt.savefig('%s/%s_c%d_heat.pdf' %
                            (options.out_dir, header_fs, ci),
                            dpi=300)
                plt.close()

            #################################################################
            # print table of nt variability for each cell
            #################################################################
            # the table covers every target, not only the plotted ones
            for ci in range(seq_mod_preds.shape[3]):
                seq_mod_preds_cell = seq_mod_preds[si, :, :, ci]
                real_pred_cell = basset_sat.get_real_pred(
                    seq_mod_preds_cell, seq)

                loss_matrix = real_pred_cell - seq_mod_preds_cell.min(axis=0)
                gain_matrix = seq_mod_preds_cell.max(axis=0) - real_pred_cell

                for pos in range(seq_mod_preds_cell.shape[1]):
                    # delta_start shifts pos back into the coordinates of
                    # the untrimmed input sequence
                    cols = [
                        header, delta_start + pos, ci, loss_matrix[pos],
                        gain_matrix[pos]
                    ]
                    print('\t'.join([str(c) for c in cols]), file=table_out)