Exemplo n.º 1
0
def main():
    """Score VCF SNPs with a Torch Basset model and tabulate SAD values.

    Usage: ``%prog [options] <model_th> <vcf_file>``

    Pipeline:
      1. Load SNPs from ``vcf_file`` (``vcf.vcf_snps``) and build one-hot
         coded input sequences of length ``-l``. They come from the ``-f``
         genome or, when both ``--f1`` and ``--f2`` are given, from those
         two genomes.
      2. Write the sequences to ``<out_dir>/model_in.h5`` (dataset
         ``test_in``). Unless ``-d`` names pre-computed output, run
         ``basset_predict.lua`` to write ``<out_dir>/model_out.txt``.
      3. Read one whitespace-separated prediction row per input sequence.
         Rows are consumed in order: each SNP's reference row, then one row
         per alternative allele.
      4. For every alternative allele and target, write the reference
         prediction, the alternative prediction and SAD (alt - ref). Output
         goes to ``<out_dir>/sad_table.txt`` (space-aligned) or
         ``sad_table.csv`` (with ``-c``). The index-SNP and score columns
         appear only with ``-i`` / ``-s``, and the header always matches
         the columns written.
      5. With ``-e``, draw an |SAD| heatmap for each index-SNP group that
         holds more than one variant.

    Raises:
        RuntimeError: if ``basset_predict.lua`` exits with a non-zero status.
    """
    usage = 'usage: %prog [options] <model_th> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c', dest='csv', default=False, action='store_true', help='Print table as CSV [Default: %default]')
    parser.add_option('--cuda', dest='cuda', default=False, action='store_true', help='Predict on the GPU [Default: %default]')
    parser.add_option('--cudnn', dest='cudnn', default=False, action='store_true', help='Predict on the GPU w/ CuDNN [Default: %default]')
    parser.add_option('-d', dest='model_hdf5_file', default=None, help='Pre-computed model output as a text table, one row of target predictions per sequence [Default: %default]')
    parser.add_option('-e', dest='heatmaps', default=False, action='store_true', help='Draw score heatmaps, grouped by index SNP [Default: %default]')
    parser.add_option('-f', dest='genome_fasta', default='%s/data/genomes/hg19.fa'%os.environ['BASSETDIR'], help='Genome FASTA from which sequences will be drawn [Default: %default]')
    parser.add_option('--f1', dest='genome1_fasta', default=None, help='Genome FASTA from which major allele sequences will be drawn')
    parser.add_option('--f2', dest='genome2_fasta', default=None, help='Genome FASTA from which minor allele sequences will be drawn')
    parser.add_option('-i', dest='index_snp', default=False, action='store_true', help='SNPs are labeled with their index SNP as column 6 [Default: %default]')
    parser.add_option('-l', dest='seq_len', type='int', default=600, help='Sequence length provided to the model [Default: %default]')
    parser.add_option('-m', dest='min_limit', default=0.1, type='float', help='Minimum heatmap limit [Default: %default]')
    parser.add_option('-o', dest='out_dir', default='sad', help='Output directory for tables and plots [Default: %default]')
    parser.add_option('-s', dest='score', default=False, action='store_true', help='SNPs are labeled with scores as column 7 [Default: %default]')
    parser.add_option('-t', dest='targets_file', default=None, help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) != 2:
        parser.error('Must provide Torch model and VCF file')
    model_th, vcf_file = args

    # The two-genome path needs both FASTAs. A lone --f1 or --f2 would
    # otherwise be silently ignored in favour of -f.
    if bool(options.genome1_fasta) != bool(options.genome2_fasta):
        parser.error('--f1 and --f2 must be provided together')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # prep SNP sequences
    #################################################################
    snps = vcf.vcf_snps(vcf_file, options.index_snp, options.score, options.genome2_fasta is not None)

    # get one hot coded input sequences
    if options.genome1_fasta and options.genome2_fasta:
        seq_vecs, seqs, seq_headers, snps = vcf.snps2_seq1(snps, options.seq_len, options.genome1_fasta, options.genome2_fasta)
    else:
        seq_vecs, seqs, seq_headers, snps = vcf.snps_seq1(snps, options.seq_len, options.genome_fasta)

    # Reshape to (sequences, 4, 1, length) for Torch. Floor division keeps
    # the length an int under Python 3 too (identical under Python 2).
    seq_vecs = seq_vecs.reshape((seq_vecs.shape[0], 4, 1, seq_vecs.shape[1] // 4))

    model_in_file = '%s/model_in.h5' % options.out_dir
    with h5py.File(model_in_file, 'w') as h5f:
        h5f.create_dataset('test_in', data=seq_vecs)

    #################################################################
    # predict in Torch
    #################################################################
    if options.model_hdf5_file is None:
        if options.cudnn:
            cuda_str = '-cudnn'
        elif options.cuda:
            cuda_str = '-cuda'
        else:
            cuda_str = ''

        options.model_hdf5_file = '%s/model_out.txt' % options.out_dir
        cmd = 'basset_predict.lua -rc %s %s %s %s' % (cuda_str, model_th, model_in_file, options.model_hdf5_file)
        print(cmd)
        # Fail here rather than later on a missing or partial output file.
        if subprocess.call(cmd, shell=True) != 0:
            raise RuntimeError('Error running basset_predict.lua: %s' % cmd)

    # one row of per-target predictions per input sequence
    with open(options.model_hdf5_file) as preds_in:
        seq_preds = np.array([np.array([np.float16(p) for p in line.split()]) for line in preds_in])

    #################################################################
    # collect and print SADs
    #################################################################
    if options.targets_file is None:
        target_labels = ['t%d' % ti for ti in range(seq_preds.shape[1])]
    else:
        with open(options.targets_file) as targets_in:
            target_labels = [line.split()[0] for line in targets_in]

    # Build the header and the aligned-text row format from the same column
    # list, so the optional index/score columns agree between them.
    header_cols = ['rsid']
    text_fmts = ['%-13s']
    if options.index_snp:
        header_cols.append('index')
        text_fmts.append('%-13s')
    if options.score:
        header_cols.append('score')
        text_fmts.append('%5.3f')
    header_cols += ['ref', 'alt', 'target', 'ref_pred', 'alt_pred', 'sad']
    text_fmts += ['%6s', '%6s', '%12s', '%6.4f', '%6.4f', '%7.4f']
    text_fmt = ' '.join(text_fmts)

    if options.csv:
        sad_path = '%s/sad_table.csv' % options.out_dir
        header_line = ','.join(header_cols)
    else:
        sad_path = '%s/sad_table.txt' % options.out_dir
        header_line = ' '.join(header_cols)

    # hash by index snp
    sad_matrices = {}
    sad_labels = {}
    sad_scores = {}

    with open(sad_path, 'w') as sad_out:
        sad_out.write(header_line + '\n')

        pi = 0
        for snp in snps:
            # reference prediction row precedes this SNP's alternative rows
            ref_preds = seq_preds[pi, :]
            pi += 1
            ref_cap = vcf.cap_allele(snp.ref_allele)

            # leading columns shared by every row of this SNP
            snp_cols = [snp.rsid]
            if options.index_snp:
                snp_cols.append(snp.index_snp)
            if options.score:
                snp_cols.append(snp.score)
            snp_cols = tuple(snp_cols)

            for alt_al in snp.alt_alleles:
                alt_preds = seq_preds[pi, :]
                pi += 1
                alt_cap = vcf.cap_allele(alt_al)

                # normalize by reference
                alt_sad = alt_preds - ref_preds
                sad_matrices.setdefault(snp.index_snp, []).append(alt_sad)

                # label as mutation from reference
                alt_label = '%s_%s>%s' % (snp.rsid, ref_cap, alt_cap)
                sad_labels.setdefault(snp.index_snp, []).append(alt_label)

                sad_scores.setdefault(snp.index_snp, []).append(snp.score)

                for ti in range(len(alt_sad)):
                    cols = snp_cols + (ref_cap, alt_cap, target_labels[ti], ref_preds[ti], alt_preds[ti], alt_sad[ti])
                    if options.csv:
                        sad_out.write(','.join([str(c) for c in cols]) + '\n')
                    else:
                        sad_out.write(text_fmt % cols + '\n')

    #################################################################
    # plot SAD heatmaps
    #################################################################
    if options.heatmaps:
        for ii in sad_matrices:
            sad_matrix = abs(np.array(sad_matrices[ii]))
            print('%s %s' % (ii, sad_matrix.shape))

            # single-variant groups are not plotted
            if sad_matrix.shape[0] <= 1:
                continue

            vlim = max(options.min_limit, sad_matrix.max())
            plt.figure(figsize=(20, 0.5 + 0.5*sad_matrix.shape[0]))

            if options.score:
                # narrow score column to the left of the SAD heatmap
                score_mat = np.reshape(np.array(sad_scores[ii]), (-1, 1))
                grid_cols = 12
                ax_score = plt.subplot2grid((1, grid_cols), (0, 0))
                ax_sad = plt.subplot2grid((1, grid_cols), (0, 1), colspan=(grid_cols-1))
                sns.heatmap(score_mat, xticklabels=False, yticklabels=False, vmin=0, vmax=1, cmap='Reds', cbar=False, ax=ax_score)
            else:
                ax_sad = plt.gca()

            sns.heatmap(sad_matrix, xticklabels=target_labels, yticklabels=sad_labels[ii], vmin=0, vmax=vlim, ax=ax_sad)

            for tick in ax_sad.get_xticklabels():
                tick.set_rotation(-45)
                tick.set_horizontalalignment('left')
                tick.set_fontsize(5)

            plt.tight_layout()
            if ii == '.':
                out_pdf = '%s/sad_heat.pdf' % options.out_dir
            else:
                out_pdf = '%s/sad_%s_heat.pdf' % (options.out_dir, ii)
            plt.savefig(out_pdf)
            plt.close()
Exemplo n.º 2
0
def main():
    """Plot saturated-mutagenesis model predictions around SNPs from a VCF.

    Usage: ``%prog [options] <model_file> <vcf_file>``

    Pipeline:
      1. Load SNPs from ``vcf_file`` and extract sequences of length ``-l``
         from the ``-f`` genome. Write their one-hot coding to
         ``<out_dir>/model_in.h5`` (dataset ``test_in``).
      2. Unless ``-d`` names pre-computed output, run
         ``basset_sat_predict.lua`` (mutating ``-n`` nt around each SNP) to
         produce ``<out_dir>/model_out.h5``.
      3. Load the ``seq_mod_preds`` dataset, indexed here as
         [sequence, substitution, position, target]. Axis 1 presumably
         holds the 4 nucleotide substitutions (TODO confirm). When fewer
         positions than ``-l`` were predicted, trim the sequences to the
         centred window.
      4. For each sequence and each target in ``-t`` (``-1`` = all), save a
         figure with a sequence logo, loss/gain curves and a
         reference-normalised heatmap.
      5. For every target, write tab-separated rows
         ``header, position, target, loss, gain`` to
         ``<out_dir>/table.txt``. Here loss = reference - min and
         gain = max - reference.
    """
    usage = 'usage: %prog [options] <model_file> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-d', dest='model_hdf5_file', default=None, help='Pre-computed model output as HDF5 [Default: %default]')
    parser.add_option('-f', dest='genome_fasta', default='%s/data/genomes/hg19.fa'%os.environ['BASSETDIR'], help='Genome FASTA from which sequences will be drawn [Default: %default]')
    parser.add_option('-g', dest='gain_height', default=False, action='store_true', help='Nucleotide heights determined by the max of loss and gain [Default: %default]')
    parser.add_option('-l', dest='seq_len', type='int', default=600, help='Sequence length provided to the model [Default: %default]')
    parser.add_option('-m', dest='min_limit', default=0.1, type='float', help='Minimum heatmap limit [Default: %default]')
    parser.add_option('-n', dest='center_nt', default=200, type='int', help='Nt around the SNP to mutate and plot in the heatmap [Default: %default]')
    parser.add_option('-o', dest='out_dir', default='heat', help='Output directory [Default: %default]')
    parser.add_option('-t', dest='targets', default='0', help='Comma-separated list of target indexes to plot (or -1 for all) [Default: %default]')
    (options, args) = parser.parse_args()

    if len(args) != 2:
        parser.error('Must provide Basset model file and input SNPs in VCF format')
    model_file, vcf_file = args

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # prep SNP sequences
    #################################################################
    snps = vcf.vcf_snps(vcf_file)

    # get one hot coded input sequences
    seqs_1hot, seqs, seq_headers = vcf.snps_seq1(snps, options.genome_fasta, options.seq_len)

    # Reshape to (sequences, 4, 1, length) for Torch. Floor division keeps
    # the length an int under Python 3 too (identical under Python 2).
    seqs_1hot = seqs_1hot.reshape((seqs_1hot.shape[0], 4, 1, seqs_1hot.shape[1] // 4))

    model_input_hdf5 = '%s/model_in.h5' % options.out_dir
    with h5py.File(model_input_hdf5, 'w') as h5f:
        h5f.create_dataset('test_in', data=seqs_1hot)

    #################################################################
    # Torch predict modifications
    #################################################################
    if options.model_hdf5_file is None:
        options.model_hdf5_file = '%s/model_out.h5' % options.out_dir
        torch_cmd = 'basset_sat_predict.lua -center_nt %d %s %s %s' % (options.center_nt, model_file, model_input_hdf5, options.model_hdf5_file)
        if subprocess.call(torch_cmd, shell=True):
            message('Error running basset_sat_predict.lua', 'error')

    #################################################################
    # load modification predictions
    #################################################################
    with h5py.File(options.model_hdf5_file, 'r') as hdf5_in:
        seq_mod_preds = np.array(hdf5_in['seq_mod_preds'])

    # Trim seqs to the centred window covered by seq_mod_preds. Floor
    # division keeps the offset an int so it can be used as a slice bound.
    delta_start = 0
    delta_len = seq_mod_preds.shape[2]
    if delta_len < options.seq_len:
        delta_start = (options.seq_len - delta_len) // 2
        for si in range(len(seqs)):
            seqs[si] = seqs[si][delta_start:delta_start+delta_len]

    # decide which cells to plot
    if options.targets == '-1':
        plot_cells = range(seq_mod_preds.shape[3])
    else:
        plot_cells = [int(ci) for ci in options.targets.split(',')]

    #################################################################
    # plot
    #################################################################
    rdbu = sns.color_palette("RdBu_r", 10)

    # Plot style is global state, so set it once. The axes line width goes
    # through rc: a bare sns.axes_style({...}) call only returns a style
    # dict and never applied it.
    sns.set(style='white', font_scale=0.5, rc={'axes.linewidth': 1})

    # figure grid layout (columns of a 3-row subplot2grid)
    heat_cols = 400
    sad_start = 1
    sad_end = 323
    logo_start = 0
    logo_end = 324

    with open('%s/table.txt' % options.out_dir, 'w') as table_out:
        for si in range(seq_mod_preds.shape[0]):
            header = seq_headers[si]
            seq = seqs[si]
            out_prefix = '%s/%s' % (options.out_dir, header.replace(':', '_'))

            # plot some descriptive heatmaps for each individual cell type
            for ci in plot_cells:
                seq_mod_preds_cell = seq_mod_preds[si, :, :, ci]
                real_pred_cell = get_real_pred(seq_mod_preds_cell, seq)

                # change relative to the reference prediction
                norm_matrix = seq_mod_preds_cell - real_pred_cell
                min_scores = seq_mod_preds_cell.min(axis=0)
                max_scores = seq_mod_preds_cell.max(axis=0)
                minmax_matrix = np.vstack([min_scores - real_pred_cell, max_scores - real_pred_cell])

                plt.figure(figsize=(20, 3))
                ax_logo = plt.subplot2grid((3, heat_cols), (0, logo_start), colspan=(logo_end-logo_start))
                ax_sad = plt.subplot2grid((3, heat_cols), (1, sad_start), colspan=(sad_end-sad_start))
                ax_heat = plt.subplot2grid((3, heat_cols), (2, 0), colspan=heat_cols)

                # letter heights scale with loss (or max of loss/gain with -g)
                vlim = max(options.min_limit, abs(minmax_matrix).max())
                if options.gain_height:
                    seq_heights = 0.25 + 1.75/vlim*(abs(minmax_matrix).max(axis=0))
                else:
                    seq_heights = 0.25 + 1.75/vlim*(-minmax_matrix[0])
                logo_eps = '%s_c%d_seq.eps' % (out_prefix, ci)
                seq_logo(seq, seq_heights, logo_eps)

                # rasterize the EPS logo so it can be placed on the figure
                logo_png = '%s.png' % logo_eps[:-4]
                logo_cmd = 'convert -density 300 %s %s' % (logo_eps, logo_png)
                if subprocess.call(logo_cmd, shell=True):
                    message('Error running convert', 'error')
                with Image.open(logo_png) as logo:
                    ax_logo.imshow(logo)
                ax_logo.set_axis_off()

                # plot loss and gain SAD scores
                ax_sad.plot(-minmax_matrix[0], c=rdbu[0], label='loss', linewidth=1)
                ax_sad.plot(minmax_matrix[1], c=rdbu[-1], label='gain', linewidth=1)
                ax_sad.set_xlim(0, minmax_matrix.shape[1])
                ax_sad.legend()
                for axis in ['top', 'bottom', 'left', 'right']:
                    ax_sad.spines[axis].set_linewidth(0.5)

                # plot real-normalized scores on a symmetric color scale
                vlim = max(options.min_limit, abs(norm_matrix).max())
                sns.heatmap(norm_matrix, linewidths=0, cmap='RdBu_r', vmin=-vlim, vmax=vlim, xticklabels=False, ax=ax_heat)
                ax_heat.yaxis.set_ticklabels('TGCA', rotation='horizontal')

                plt.tight_layout()
                plt.savefig('%s_c%d_heat.pdf' % (out_prefix, ci), dpi=300)
                plt.close()

            #################################################################
            # print table of nt variability for each cell
            #################################################################
            for ci in range(seq_mod_preds.shape[3]):
                seq_mod_preds_cell = seq_mod_preds[si, :, :, ci]
                real_pred_cell = get_real_pred(seq_mod_preds_cell, seq)

                loss_matrix = real_pred_cell - seq_mod_preds_cell.min(axis=0)
                gain_matrix = seq_mod_preds_cell.max(axis=0) - real_pred_cell

                # positions are reported in untrimmed sequence coordinates
                for pos in range(seq_mod_preds_cell.shape[1]):
                    cols = [header, delta_start+pos, ci, loss_matrix[pos], gain_matrix[pos]]
                    table_out.write('\t'.join([str(c) for c in cols]) + '\n')
Exemplo n.º 3
0
def main():
    """Score VCF SNPs with a Torch Basset model and tabulate SAD values.

    Usage: ``%prog [options] <model_th> <vcf_file>``

    Pipeline:
      1. Load SNPs from ``vcf_file`` and build one-hot coded input sequences
         of length ``-l``. They come from the ``-f`` genome or, when both
         ``--f1`` and ``--f2`` are given, from those two genomes.
      2. Write the sequences to ``<out_dir>/model_in.h5`` (dataset
         ``test_in``). Unless ``-d`` names pre-computed output, run
         ``basset_predict.lua`` to write ``<out_dir>/model_out.txt``.
      3. Read one whitespace-separated prediction row per input sequence:
         each SNP's reference row, then one row per alternative allele.
      4. Write SAD (alt - ref) results. Without ``--dense``, write one row
         per (allele, target) with a header, to ``sad_table.txt`` or
         ``sad_table.csv`` (``-c``). With ``--dense``, write one row per
         allele holding a (ref_pred, sad) pair per target, with no header,
         always to ``sad_table.txt``. Index-SNP and score columns hold
         ``'.'`` unless ``-i`` / ``-s`` is given. CSV fields are written
         without the text-alignment padding.
      5. With ``-e``, draw an |SAD| heatmap for each index-SNP group that
         holds more than one variant.

    Raises:
        RuntimeError: if ``basset_predict.lua`` exits with a non-zero status.
    """
    usage = 'usage: %prog [options] <model_th> <vcf_file>'
    parser = OptionParser(usage)
    parser.add_option('-c',
                      dest='csv',
                      default=False,
                      action='store_true',
                      help='Print table as CSV [Default: %default]')
    parser.add_option('--cuda',
                      dest='cuda',
                      default=False,
                      action='store_true',
                      help='Predict on the GPU [Default: %default]')
    parser.add_option('--cudnn',
                      dest='cudnn',
                      default=False,
                      action='store_true',
                      help='Predict on the GPU w/ CuDNN [Default: %default]')
    parser.add_option(
        '-d',
        dest='model_hdf5_file',
        default=None,
        help=
        'Pre-computed model output as a text table, one row of target predictions per sequence [Default: %default]'
    )
    parser.add_option(
        '--dense',
        dest='dense_table',
        default=False,
        action='store_true',
        help=
        'Print a dense SNP x Targets table, as opposed to a SNP/Target pair per line [Default: %default]'
    )
    parser.add_option(
        '-e',
        dest='heatmaps',
        default=False,
        action='store_true',
        help='Draw score heatmaps, grouped by index SNP [Default: %default]')
    parser.add_option(
        '-f',
        dest='genome_fasta',
        default='%s/data/genomes/hg19.fa' % os.environ['BASSETDIR'],
        help=
        'Genome FASTA from which sequences will be drawn [Default: %default]')
    parser.add_option(
        '--f1',
        dest='genome1_fasta',
        default=None,
        help='Genome FASTA from which major allele sequences will be drawn')
    parser.add_option(
        '--f2',
        dest='genome2_fasta',
        default=None,
        help='Genome FASTA from which minor allele sequences will be drawn')
    parser.add_option(
        '-i',
        dest='index_snp',
        default=False,
        action='store_true',
        help=
        'SNPs are labeled with their index SNP as column 6 [Default: %default]'
    )
    parser.add_option(
        '-l',
        dest='seq_len',
        type='int',
        default=1000,
        help='Sequence length provided to the model [Default: %default]')
    parser.add_option('-m',
                      dest='min_limit',
                      default=0.1,
                      type='float',
                      help='Minimum heatmap limit [Default: %default]')
    parser.add_option(
        '-o',
        dest='out_dir',
        default='sad',
        help='Output directory for tables and plots [Default: %default]')
    parser.add_option(
        '-s',
        dest='score',
        default=False,
        action='store_true',
        help='SNPs are labeled with scores as column 7 [Default: %default]')
    parser.add_option(
        '-t',
        dest='targets_file',
        default=None,
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) != 2:
        parser.error('Must provide Torch model and VCF file')
    model_th, vcf_file = args

    # The two-genome path needs both FASTAs. A lone --f1 or --f2 would
    # otherwise be silently ignored in favour of -f.
    if bool(options.genome1_fasta) != bool(options.genome2_fasta):
        parser.error('--f1 and --f2 must be provided together')

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    #################################################################
    # prep SNP sequences
    #################################################################
    snps = vcf.vcf_snps(vcf_file, options.index_snp, options.score,
                        options.genome2_fasta is not None)

    # get one hot coded input sequences
    if options.genome1_fasta and options.genome2_fasta:
        seq_vecs, seq_headers, snps = vcf.snps2_seq1(snps, options.seq_len,
                                                     options.genome1_fasta,
                                                     options.genome2_fasta)
    else:
        seq_vecs, seq_headers, snps = vcf.snps_seq1(snps, options.seq_len,
                                                    options.genome_fasta)

    # reshape to (sequences, 4, 1, length) for torch
    seq_vecs = seq_vecs.reshape(
        (seq_vecs.shape[0], 4, 1, seq_vecs.shape[1] // 4))

    model_in_file = '%s/model_in.h5' % options.out_dir
    with h5py.File(model_in_file, 'w') as h5f:
        h5f.create_dataset('test_in', data=seq_vecs)

    #################################################################
    # predict in Torch
    #################################################################
    if options.model_hdf5_file is None:
        if options.cudnn:
            cuda_str = '-cudnn'
        elif options.cuda:
            cuda_str = '-cuda'
        else:
            cuda_str = ''

        options.model_hdf5_file = '%s/model_out.txt' % options.out_dir
        cmd = 'basset_predict.lua -rc %s %s %s %s' % (
            cuda_str, model_th, model_in_file, options.model_hdf5_file)
        print(cmd)
        # Fail here rather than later on a missing or partial output file.
        if subprocess.call(cmd, shell=True) != 0:
            raise RuntimeError('Error running basset_predict.lua: %s' % cmd)

    # one row of per-target predictions per input sequence
    with open(options.model_hdf5_file) as preds_in:
        seq_preds = np.array([
            np.array([np.float16(p) for p in line.split()])
            for line in preds_in
        ])

    #################################################################
    # collect and print SADs
    #################################################################
    if options.targets_file is None:
        target_labels = ['t%d' % ti for ti in range(seq_preds.shape[1])]
    else:
        with open(options.targets_file) as targets_in:
            target_labels = [line.split()[0] for line in targets_in]

    sep = ',' if options.csv else ' '
    # The dense table keeps its historical .txt name even with -c.
    if options.csv and not options.dense_table:
        sad_path = '%s/sad_table.csv' % options.out_dir
    else:
        sad_path = '%s/sad_table.txt' % options.out_dir

    # hash by index snp
    sad_matrices = {}
    sad_labels = {}
    sad_scores = {}

    with open(sad_path, 'w') as sad_out:
        if not options.dense_table:
            header_cols = ('rsid', 'index', 'score', 'ref', 'alt', 'target',
                           'ref_pred', 'alt_pred', 'sad')
            print(sep.join(header_cols), file=sad_out)

        pi = 0
        for snp in snps:
            # reference prediction row precedes this SNP's alternative rows
            ref_preds = seq_preds[pi, :]
            pi += 1
            ref_cap = vcf.cap_allele(snp.ref_allele)

            # Index SNP and score, '.' when not requested. Alignment
            # padding is applied for the text table only.
            index_str = snp.index_snp if options.index_snp else '.'
            score_str = ('%.3f' % snp.score) if options.score else '.'
            if options.csv:
                snp_is, snp_score = index_str, score_str
            else:
                snp_is, snp_score = '%-13s' % index_str, '%5s' % score_str

            for alt_al in snp.alt_alleles:
                alt_preds = seq_preds[pi, :]
                pi += 1
                alt_cap = vcf.cap_allele(alt_al)

                # normalize by reference
                alt_sad = alt_preds - ref_preds
                sad_matrices.setdefault(snp.index_snp, []).append(alt_sad)

                # label as mutation from reference
                alt_label = '%s_%s>%s' % (snp.rsid, ref_cap, alt_cap)
                sad_labels.setdefault(snp.index_snp, []).append(alt_label)

                sad_scores.setdefault(snp.index_snp, []).append(snp.score)

                if options.dense_table:
                    # one row per allele: (ref_pred, sad) pair per target
                    cols = [snp.rsid, snp_is, snp_score, ref_cap, alt_cap]
                    for ti in range(len(alt_sad)):
                        cols += ['%.4f' % ref_preds[ti], '%.4f' % alt_sad[ti]]
                    print(sep.join([str(c) for c in cols]), file=sad_out)
                else:
                    for ti in range(len(alt_sad)):
                        cols = (snp.rsid, snp_is, snp_score, ref_cap, alt_cap,
                                target_labels[ti], ref_preds[ti],
                                alt_preds[ti], alt_sad[ti])
                        if options.csv:
                            print(','.join([str(c) for c in cols]),
                                  file=sad_out)
                        else:
                            print('%-13s %s %5s %6s %6s %12s %6.4f %6.4f %7.4f'
                                  % cols,
                                  file=sad_out)

    #################################################################
    # plot SAD heatmaps
    #################################################################
    if options.heatmaps:
        for ii in sad_matrices:
            sad_matrix = abs(np.array(sad_matrices[ii]))
            print(ii, sad_matrix.shape)

            # single-variant groups are not plotted
            if sad_matrix.shape[0] <= 1:
                continue

            vlim = max(options.min_limit, sad_matrix.max())
            plt.figure(figsize=(20, 0.5 + 0.5 * sad_matrix.shape[0]))

            if options.score:
                # narrow score column to the left of the SAD heatmap
                score_mat = np.reshape(np.array(sad_scores[ii]), (-1, 1))
                grid_cols = 12
                ax_score = plt.subplot2grid((1, grid_cols), (0, 0))
                ax_sad = plt.subplot2grid((1, grid_cols), (0, 1),
                                          colspan=(grid_cols - 1))

                sns.heatmap(score_mat,
                            xticklabels=False,
                            yticklabels=False,
                            vmin=0,
                            vmax=1,
                            cmap='Reds',
                            cbar=False,
                            ax=ax_score)
            else:
                ax_sad = plt.gca()

            sns.heatmap(sad_matrix,
                        xticklabels=target_labels,
                        yticklabels=sad_labels[ii],
                        vmin=0,
                        vmax=vlim,
                        ax=ax_sad)

            for tick in ax_sad.get_xticklabels():
                tick.set_rotation(-45)
                tick.set_horizontalalignment('left')
                tick.set_fontsize(5)

            plt.tight_layout()
            if ii == '.':
                out_pdf = '%s/sad_heat.pdf' % options.out_dir
            else:
                out_pdf = '%s/sad_%s_heat.pdf' % (options.out_dir, ii)
            plt.savefig(out_pdf)
            plt.close()