Esempio n. 1
0
def plot_tuning_mtype_grid(par_df, auc_df, use_clf, args, cdata):
    """Plot pairwise tuned hyper-parameter values for each mutation type.

    Draws a ``P x P`` grid of axes, where ``P`` is the number of tuning priors
    on ``use_clf``.

    - Diagonal panels show the parameter name over the grid of tested values.
    - Upper-triangle panels place one bubble per mutation type at its median
      tuned values. Mutation types sharing the same pair of medians are packed
      together with ``circlify``.
    - Lower-triangle panels show the mean tuned values with a small random
      jitter.

    Bubbles are coloured by ``auc_cmap`` applied to the first quartile of each
    mutation type's AUC values. They are sized by the number of training
    samples with the mutation relative to the cohort size. Log-spaced priors
    are drawn on a log10 scale.

    Args:
        par_df: Tuned parameter values. ``par_df[par_name]`` is reduced with
            ``median(axis=1)`` / ``mean(axis=1)`` to one value per mutation
            type (presumably one column per cross-validation run -- confirm).
        auc_df: AUC values, one row per mutation type.
        use_clf: Classifier whose ``tune_priors`` is a sequence of
            ``(par_name, tested_values)`` pairs.
        args: Parsed command-line arguments used to build the output path.
        cdata: Cohort data providing ``train_pheno`` and ``get_samples``.

    The figure is written as SVG under ``plot_dir`` and then closed.
    """
    # Imported once up front rather than on every pass through the innermost
    # plotting loop. The matplotlib backend is deliberately not chosen here:
    # calling ``mpl.use`` after pyplot has created figures can switch backends,
    # which closes every open figure (including this one).
    from circlify import circlify

    def prior_scale(tune_distr):
        """Return tested values on the plotting scale and padded axis limits.

        Log-spaced priors are mapped to log10; the axis is padded on each side
        by half the mean spacing between consecutive tested values.
        """
        if detect_log_distr(tune_distr):
            use_distr = [np.log10(par_val) for par_val in tune_distr]
        else:
            use_distr = tune_distr

        distr_diff = np.mean(np.diff(use_distr))
        return (use_distr,
                use_distr[0] - distr_diff / 2,
                use_distr[-1] + distr_diff / 2)

    def tuned_summary(par_name, tune_distr):
        """Return per-mutation-type median and mean of one tuned parameter."""
        par_vals = par_df[par_name]
        if detect_log_distr(tune_distr):
            par_vals = np.log10(par_vals)

        return par_vals.median(axis=1), par_vals.mean(axis=1)

    diag_kws = dict(color='#116611', ls='--', linewidth=4.1, alpha=0.27)
    grid_kws = dict(color='#116611', ls=':', linewidth=2.3, alpha=0.19)

    par_count = len(use_clf.tune_priors)
    # squeeze=False keeps ``axarr`` two-dimensional even for a single prior.
    fig, axarr = plt.subplots(figsize=(0.5 + 7 * par_count, 7 * par_count),
                              nrows=par_count,
                              ncols=par_count,
                              squeeze=False)

    # First quartile of each mutation type's AUCs, and its colour.
    auc_vals = auc_df.quantile(q=0.25, axis=1)
    auc_clrs = auc_vals.apply(auc_cmap)
    size_vec = [
        461 * sum(cdata.train_pheno(mtype)) /
        (len(cdata.get_samples()) * par_count) for mtype in auc_vals.index
    ]

    # Diagonal panels: parameter name over its grid of tested values.
    for i, (par_name, tune_distr) in enumerate(use_clf.tune_priors):
        diag_ax = axarr[i, i]
        diag_ax.grid(False)

        use_distr, plt_min, plt_max = prior_scale(tune_distr)
        if detect_log_distr(tune_distr):
            par_lbl = par_name + '\n(log-scale)'
        else:
            par_lbl = par_name

        diag_ax.set_xlim(plt_min, plt_max)
        diag_ax.set_ylim(plt_min, plt_max)
        diag_ax.text((plt_min + plt_max) / 2, (plt_min + plt_max) / 2,
                     par_lbl,
                     ha='center',
                     fontsize=28,
                     weight='semibold')

        for par_val in use_distr:
            diag_ax.axhline(y=par_val, **diag_kws)
            diag_ax.axvline(x=par_val, **diag_kws)

    # Off-diagonal panels: row ``i`` parameter on the y-axis of the upper
    # panel ``[i, j]`` and on the x-axis of the mirrored lower panel ``[j, i]``.
    for (i, (par_name1, tn_distr1)), (j, (par_name2, tn_distr2)) in combn(
            enumerate(use_clf.tune_priors), 2):
        use_distr1, plt_ymin, plt_ymax = prior_scale(tn_distr1)
        use_distr2, plt_xmin, plt_xmax = prior_scale(tn_distr2)
        par_meds1, par_means1 = tuned_summary(par_name1, tn_distr1)
        par_meds2, par_means2 = tuned_summary(par_name2, tn_distr2)

        par_meds1 = par_meds1[auc_clrs.index]
        par_meds2 = par_meds2[auc_clrs.index]
        y_adj = (plt_ymax - plt_ymin) / len(tn_distr1)
        x_adj = (plt_xmax - plt_xmin) / len(tn_distr2)
        # aspect ratio of the panel's data ranges, so packed bubbles stay round
        plt_adj = (plt_xmax - plt_xmin) / (plt_ymax - plt_ymin)

        # Upper triangle: one packed cluster per distinct pair of medians.
        for med1, med2 in set(zip(par_meds1, par_meds2)):
            use_indx = (par_meds1 == med1) & (par_meds2 == med2)
            grp_clrs = auc_clrs[use_indx]

            # cluster footprint grows sub-linearly with its member count
            cnt_adj = use_indx.sum() ** 0.49
            use_sizes = [s for s, ix in zip(size_vec, use_indx) if ix]
            sort_indx = sorted(enumerate(use_sizes),
                               key=lambda x: x[1],
                               reverse=True)

            # NOTE(review): assumes circlify returns circles in the same order
            # as the (largest-first) sizes passed in -- confirm against the
            # installed circlify version.
            for k, circ in enumerate(circlify([s for _, s in sort_indx])):
                axarr[i, j].scatter(
                    med2 + (1 / 23) * cnt_adj * circ.y * plt_adj,
                    med1 + (1 / 23) * cnt_adj * circ.x * plt_adj**-1,
                    s=sort_indx[k][1],
                    # positional lookup within the group; the single colour is
                    # wrapped in a list so it is not read as values to map
                    c=[grp_clrs.iloc[sort_indx[k][0]]],
                    alpha=0.36,
                    edgecolor='black')

        # Lower triangle: jittered per-mutation-type means.
        par_means1 += np.random.normal(0, y_adj / 27, auc_df.shape[0])
        par_means2 += np.random.normal(0, x_adj / 27, auc_df.shape[0])
        axarr[j, i].scatter(par_means1[auc_clrs.index],
                            par_means2[auc_clrs.index],
                            s=size_vec,
                            c=auc_clrs,
                            alpha=0.36,
                            edgecolor='black')

        axarr[i, j].set_xlim(plt_xmin, plt_xmax)
        axarr[i, j].set_ylim(plt_ymin, plt_ymax)
        axarr[j, i].set_ylim(plt_xmin, plt_xmax)
        axarr[j, i].set_xlim(plt_ymin, plt_ymax)

        annot_placed = place_annot(par_meds2,
                                   par_meds1,
                                   size_vec=size_vec,
                                   annot_vec=auc_vals.index,
                                   x_range=plt_xmax - plt_xmin,
                                   y_range=plt_ymax - plt_ymin)

        for annot_x, annot_y, annot, halign in annot_placed:
            axarr[i, j].text(annot_x, annot_y, annot, size=11, ha=halign)

        for par_val1 in use_distr1:
            axarr[i, j].axhline(y=par_val1, **grid_kws)
            axarr[j, i].axvline(x=par_val1, **grid_kws)

        for par_val2 in use_distr2:
            axarr[i, j].axvline(x=par_val2, **grid_kws)
            axarr[j, i].axhline(y=par_val2, **grid_kws)

    plt.tight_layout()
    fig.savefig(os.path.join(
        plot_dir, args.expr_source,
        "{}__samps-{}".format(args.cohort, args.samp_cutoff),
        args.model_name.split('__')[0],
        "{}__tuning-mtype-grid.svg".format(args.model_name.split('__')[1])),
                bbox_inches='tight',
                format='svg')

    plt.close()
Esempio n. 2
0
def plot_tuning_mtype(par_df, auc_df, use_clf, args, cdata):
    """Plot each mutation type's tuned hyper-parameter values against its AUC.

    The figure has one column per tuning prior on ``use_clf`` and three rows:

    - Top row: per-mutation-type median tuned value against the first
      quartile of its AUC values.
    - Thin middle row: the parameter name over the grid of tested values.
    - Bottom row: the same as the top row, but using the mean tuned value.

    Tuned values are jittered horizontally so overlapping points stay visible.
    Markers are sized by the share of cohort samples belonging to each
    mutation type. Log-spaced priors are drawn on a log10 scale.

    Args:
        par_df: Tuned parameter values; ``par_df[par_name]`` is reduced with
            ``median(axis=1)`` / ``mean(axis=1)`` to one value per mutation
            type.
        auc_df: AUC values with one row per mutation type.
        use_clf: Classifier whose ``tune_priors`` lists ``(name, values)``
            pairs.
        args: Command-line arguments that determine the output path.
        cdata: Cohort data providing ``mtree`` and ``get_samples``.

    The figure is saved as SVG under ``plot_dir`` and closed.
    """
    n_priors = len(use_clf.tune_priors)
    fig, axarr = plt.subplots(figsize=(1 + 9 * n_priors, 13),
                              nrows=3,
                              ncols=n_priors,
                              gridspec_kw={'height_ratios': [1, 0.3, 1]},
                              squeeze=False,
                              sharex=False,
                              sharey=True)

    # First quartile of each mutation type's AUC values.
    auc_vals = auc_df.quantile(q=0.25, axis=1)
    size_vec = [198 * len(mtype.get_samples(cdata.mtree))
                / len(cdata.get_samples())
                for mtype in auc_vals.index]

    chance_kws = dict(color='#550000', linewidth=2.3, linestyle='--',
                      alpha=0.32)
    guide_kws = dict(color='#116611', ls=':', linewidth=1.3, alpha=0.16)

    for col, (par_name, tune_distr) in enumerate(use_clf.tune_priors):
        med_ax, lbl_ax, mean_ax = axarr[:, col]
        lbl_ax.set_axis_off()
        mean_ax.tick_params(length=6)

        if detect_log_distr(tune_distr):
            par_vals = np.log10(par_df[par_name])
            use_distr = [np.log10(val) for val in tune_distr]
            par_lbl = par_name + '\n(log-scale)'
        else:
            par_vals = par_df[par_name]
            use_distr = tune_distr
            par_lbl = par_name

        med_vals = par_vals.median(axis=1)[auc_vals.index]
        mean_vals = par_vals.mean(axis=1)[auc_vals.index]

        # pad the x-axis by half the mean spacing of the tested grid
        distr_diff = np.mean(np.diff(use_distr))
        x_lo = use_distr[0] - distr_diff / 2
        x_hi = use_distr[-1] + distr_diff / 2
        for ax in (med_ax, lbl_ax, mean_ax):
            ax.set_xlim(x_lo, x_hi)

        lbl_ax.text((use_distr[0] + use_distr[-1]) / 2, 0.5, par_lbl,
                    ha='center', va='center', fontsize=25, weight='semibold')

        # horizontal jitter, scaled to the spacing of the tested grid
        grid_span = use_distr[-1] - use_distr[0]
        med_vals += np.random.normal(
            0, grid_span / (len(tune_distr) * 17), auc_df.shape[0])
        mean_vals += np.random.normal(
            0, grid_span / (len(tune_distr) * 23), auc_df.shape[0])

        med_ax.scatter(med_vals, auc_vals, s=size_vec, c='black', alpha=0.23)
        mean_ax.scatter(mean_vals, auc_vals, s=size_vec, c='black',
                        alpha=0.23)

        for ax in (med_ax, mean_ax):
            ax.set_ylim(0, 1)
            ax.set_ylabel("1st Quartile AUC", size=19, weight='semibold')
            ax.axhline(y=0.5, **chance_kws)

        for par_val in use_distr:
            lbl_ax.axvline(x=par_val, color='#116611', ls='--',
                           linewidth=3.4, alpha=0.27)
            med_ax.axvline(x=par_val, **guide_kws)
            mean_ax.axvline(x=par_val, **guide_kws)

        annot_placed = place_annot(med_vals,
                                   auc_vals.values.tolist(),
                                   size_vec=size_vec,
                                   annot_vec=auc_vals.index,
                                   x_range=grid_span + 2 * distr_diff,
                                   y_range=1)
        for annot_x, annot_y, annot, halign in annot_placed:
            med_ax.text(annot_x, annot_y, annot, size=8, ha=halign)

    plt.tight_layout(h_pad=0)
    name_parts = args.model_name.split('__')
    out_file = os.path.join(
        plot_dir, args.expr_source,
        "{}__samps-{}".format(args.cohort, args.samp_cutoff),
        name_parts[0],
        "{}__tuning-mtype.svg".format(name_parts[1]))
    fig.savefig(out_file, bbox_inches='tight', format='svg')

    plt.close()
Esempio n. 3
0
def plot_tuned_auc(out_dict, phn_dict, auc_dict, args):
    """Scatter the tuned value of every hyper-parameter against its AUC.

    One row of axes is drawn per tuning prior of the first classifier found
    in ``out_dict``. Every (mutation type, cis label) pair from every
    experiment becomes one point:

    - x is the tuned value, log10-transformed for log-spaced priors.
    - y is the matching entry of that experiment's AUC table.

    Marker shape encodes the source, colour encodes the cohort, and marker
    size scales with the mean of the mutation type's phenotype vector.

    Args:
        out_dict: Maps ``(src, coh, lvls)`` keys to output dicts holding the
            classifier under ``'Clf'`` and tuned values under ``'Pars'``
            (columns are ``(cis_lbl, par_name)`` pairs).
        phn_dict: Maps the same keys to per-mutation-type phenotype vectors.
        auc_dict: Maps the same keys to AUC tables indexed by mutation type
            with one column per cis label.
        args: Command-line arguments; ``args.classif`` names the output file.

    The figure is saved as SVG under ``plot_dir`` and closed.
    """
    tune_priors = tuple(out_dict.values())[0]['Clf'].tune_priors
    fig, axarr = plt.subplots(figsize=(17, 1 + 7 * len(tune_priors)),
                              nrows=len(tune_priors),
                              ncols=1,
                              squeeze=False)

    use_srcs = sorted(set(src for src, _, _ in out_dict.keys()))
    src_mrks = dict(zip(use_srcs, use_marks[:len(use_srcs)]))

    use_cohs = sorted(set(coh for _, coh, _ in out_dict.keys()))
    coh_clrs = dict(
        zip(use_cohs, sns.color_palette("muted", n_colors=len(use_cohs))))

    for ax, (par_name, tune_distr) in zip(axarr.flatten(), tune_priors):
        # x-limits extend one grid step past each end of the tested values
        if detect_log_distr(tune_distr):
            par_fnc = np.log10
            plt_xmin = 2 * np.log10(tune_distr[0]) - np.log10(tune_distr[1])
            plt_xmax = 2 * np.log10(tune_distr[-1]) - np.log10(tune_distr[-2])

        else:
            par_fnc = lambda x: x
            plt_xmin = 2 * tune_distr[0] - tune_distr[1]
            plt_xmax = 2 * tune_distr[-1] - tune_distr[-2]

        for (src, coh, lvls), ols in out_dict.items():
            tune_df = ols['Pars'].loc[:, (slice(None), par_name)]
            phn_list = phn_dict[src, coh, lvls]
            auc_df = auc_dict[src, coh, lvls]

            for mtype, vals in tune_df.iterrows():
                # ``Series.iteritems`` was removed in pandas 2.0; ``items`` is
                # the equivalent, long-supported spelling.
                for (cis_lbl, _), val in vals.items():
                    ax.scatter(par_fnc(val),
                               auc_df.loc[mtype, cis_lbl],
                               marker=src_mrks[src],
                               s=371 * np.mean(phn_list[mtype]),
                               c=[coh_clrs[coh]],
                               alpha=0.17,
                               edgecolor='none')

        ax.set_xlim(plt_xmin, plt_xmax)
        ax.set_ylim(0.48, 1.02)
        ax.tick_params(labelsize=19)
        ax.set_xlabel('Tuned {} Value'.format(par_name),
                      fontsize=27,
                      weight='semibold')

        ax.axhline(y=1.0, color='black', linewidth=2.1, alpha=0.37)
        ax.axhline(y=0.5,
                   color='#550000',
                   linewidth=2.7,
                   linestyle='--',
                   alpha=0.29)

    fig.text(-0.01,
             0.5,
             'Aggregate AUC',
             ha='center',
             va='center',
             fontsize=27,
             weight='semibold',
             rotation='vertical')

    plt.tight_layout(h_pad=1.7)
    fig.savefig(os.path.join(plot_dir,
                             "{}__tuned-auc.svg".format(args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()
Esempio n. 4
0
def plot_tuning_profile(out_dict, auc_dict, args):
    """Plot mean tuning-score profiles for each hyper-parameter by AUC bin.

    All AUC values in ``auc_dict`` are pooled, rounded to 4 decimals and cut
    into quantile bins. For each tuning prior, every experiment and cis label
    contributes one row per tested parameter value:

    - The score is ``avg - std`` taken from the ``'Acc'`` table.
    - The row is assigned to the AUC bin of its mutation type.

    For each bin, the mean score at each tested value is drawn as one line.

    Args:
        out_dict: Maps ``(src, coh, lvls)`` keys to output dicts holding the
            classifier under ``'Clf'`` and tuning records under ``'Acc'``
            with ``(cis_lbl, 'avg' | 'std' | 'par')`` columns. Each ``'par'``
            record is indexed by parameter name; presumably one record per
            tested setting -- confirm.
        auc_dict: Maps the same keys to AUC tables indexed by mutation type
            with one column per cis label.
        args: Command-line arguments; ``args.classif`` names the output file.

    The figure is saved as SVG under ``plot_dir`` and closed.
    """
    tune_priors = tuple(out_dict.values())[0]['Clf'].tune_priors
    fig, axarr = plt.subplots(figsize=(17, 1 + 7 * len(tune_priors)),
                              nrows=len(tune_priors),
                              ncols=1,
                              squeeze=False)

    use_aucs = pd.concat(auc_dict.values()).round(4)
    auc_bins = pd.qcut(
        use_aucs.values.flatten(),
        q=[0., 0.5, 0.75, 0.8, 0.85, 0.9, 0.92, 0.94, 0.96, 0.98, 0.99, 1.],
        precision=5).categories

    for ax, (par_name, tune_distr) in zip(axarr.flatten(), tune_priors):
        # x-limits extend one grid step past each end of the tested values
        if detect_log_distr(tune_distr):
            par_fnc = np.log10
            plt_xmin = 2 * np.log10(tune_distr[0]) - np.log10(tune_distr[1])
            plt_xmax = 2 * np.log10(tune_distr[-1]) - np.log10(tune_distr[-2])

        else:
            par_fnc = lambda x: x
            plt_xmin = 2 * tune_distr[0] - tune_distr[1]
            plt_xmax = 2 * tune_distr[-1] - tune_distr[-2]

        par_dfs = []
        for (src, coh, lvls), ols in out_dict.items():
            acc_df = ols['Acc']
            auc_df = auc_dict[src, coh, lvls]

            for cis_lbl in cis_lbls:
                # penalise unstable settings: mean score minus its spread
                tune_vals = pd.DataFrame.from_records(
                    acc_df.loc[:, (cis_lbl, 'avg')])
                tune_vals -= pd.DataFrame.from_records(
                    acc_df.loc[:, (cis_lbl, 'std')])
                par_vals = pd.DataFrame.from_records(
                    acc_df.loc[:, (cis_lbl, 'par')]).applymap(
                        itemgetter(par_name)).applymap(par_fnc)

                tune_vals.index = acc_df.index
                par_vals.index = acc_df.index
                par_df = pd.concat(
                    [par_vals.stack(), tune_vals.stack()],
                    axis=1,
                    keys=['par', 'auc'])

                par_df['auc_bin'] = [
                    auc_bins.get_loc(round(auc_df.loc[mtype, cis_lbl], 4))
                    for mtype, _ in par_df.index
                ]
                par_dfs.append(par_df)

        # Concatenate once: growing a DataFrame inside the loop re-copies it
        # on every iteration and concatenates against an empty frame.
        if par_dfs:
            plot_df = pd.concat(par_dfs)
        else:
            plot_df = pd.DataFrame(columns=['par', 'auc', 'auc_bin'])

        for _, bin_vals in plot_df.groupby('auc_bin'):
            plot_vals = bin_vals.groupby('par').mean()
            ax.plot(plot_vals.index, plot_vals.auc)

        ax.set_xlim(plt_xmin, plt_xmax)
        ax.set_ylim(0.45, 1.01)
        ax.tick_params(labelsize=19)
        ax.set_xlabel('Tested {} Value'.format(par_name),
                      fontsize=27,
                      weight='semibold')

        ax.axhline(y=1.0, color='black', linewidth=2.1, alpha=0.37)
        ax.axhline(y=0.5,
                   color='#550000',
                   linewidth=2.7,
                   linestyle='--',
                   alpha=0.29)

    fig.text(-0.01,
             0.5,
             'Aggregate AUC',
             ha='center',
             va='center',
             fontsize=27,
             weight='semibold',
             rotation='vertical')

    plt.tight_layout(h_pad=1.7)
    fig.savefig(os.path.join(plot_dir,
                             "{}__tuning-profile.svg".format(args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()
Esempio n. 5
0
def plot_tuning_auc(out_list, args):
    """Scatter tuned hyper-parameter values against AUC for All/Iso models.

    One row of axes is drawn per tuning prior of the first experiment's
    classifier. Each mutation type contributes two points per experiment:

    - the tuned value of the 'All' model against its 'All' AUC;
    - the tuned value of the 'Iso' model against its 'Iso' AUC.

    Tuned values are log10-transformed for log-spaced priors. Marker shape
    encodes the cohort, colour the gene, and marker size scales with the mean
    of the mutation type's phenotype vector.

    Args:
        out_list: Sequence of ``(out_dict, (pheno_dict, auc_list),
            (coh, gene))`` triples. ``out_dict`` holds the classifier under
            ``'Clf'`` and tuned values under ``'Tune'``, with two columns per
            parameter (All, then Iso). ``auc_list`` is indexed by mutation
            type with ``'All'``/``'Iso'`` columns.
        args: Command-line arguments; ``args.classif`` names the output file.

    The figure is saved as SVG under ``plot_dir`` and closed.
    """
    tune_priors = out_list[0][0]['Clf'].tune_priors
    fig, axarr = plt.subplots(figsize=(17, 1 + 7 * len(tune_priors)),
                              nrows=len(tune_priors),
                              ncols=1,
                              squeeze=False)

    use_cohs = sorted(set(coh for _, _, (coh, _) in out_list))
    coh_mrks = dict(zip(use_cohs, use_marks[:len(use_cohs)]))

    use_genes = sorted(set(gene for _, _, (_, gene) in out_list))
    gene_clrs = dict(
        zip(use_genes, sns.color_palette("muted", n_colors=len(use_genes))))

    for ax, (par_name, tune_distr) in zip(axarr.flatten(), tune_priors):
        # x-limits extend one grid step past each end of the tested values
        if detect_log_distr(tune_distr):
            par_fnc = np.log10
            plt_xmin = 2 * np.log10(tune_distr[0]) - np.log10(tune_distr[1])
            plt_xmax = 2 * np.log10(tune_distr[-1]) - np.log10(tune_distr[-2])

        else:
            par_fnc = lambda x: x
            plt_xmin = 2 * tune_distr[0] - tune_distr[1]
            plt_xmax = 2 * tune_distr[-1] - tune_distr[-2]

        for out_dict, (pheno_dict, auc_list), (coh, gene) in out_list:
            tune_df = out_dict['Tune'].loc[:, (slice(None), par_name)]

            for mtype, (all_val, iso_val) in tune_df.iterrows():
                mark_size = 551 * np.mean(pheno_dict[mtype])

                for cis_lbl, tune_val in (('All', all_val), ('Iso', iso_val)):
                    ax.scatter(par_fnc(tune_val),
                               auc_list.loc[mtype, cis_lbl],
                               marker=coh_mrks[coh],
                               s=mark_size,
                               # a single colour is wrapped in a list so the
                               # RGB triple is not read as values to colour-map
                               c=[gene_clrs[gene]],
                               alpha=0.23)

        ax.set_xlim(plt_xmin, plt_xmax)
        ax.set_ylim(0.48, 1.02)
        ax.tick_params(labelsize=19)
        ax.set_xlabel('Tuned {} Value'.format(par_name),
                      fontsize=27,
                      weight='semibold')

        ax.axhline(y=1.0, color='black', linewidth=2.1, alpha=0.37)
        ax.axhline(y=0.5,
                   color='#550000',
                   linewidth=2.7,
                   linestyle='--',
                   alpha=0.29)

    fig.text(-0.01,
             0.5,
             'Aggregate AUC',
             ha='center',
             va='center',
             fontsize=27,
             weight='semibold',
             rotation='vertical')

    plt.tight_layout(h_pad=1.7)
    fig.savefig(os.path.join(plot_dir,
                             "{}__tuning-auc.svg".format(args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()
Esempio n. 6
0
def plot_tuning_grid(out_list, args):
    """Plot pairwise tuned hyper-parameter values pooled across experiments.

    Draws a ``P x P`` grid of axes for the ``P`` tuning priors of
    ``use_clf``:

    - Diagonal panels show each parameter's name over its tested values.
    - The upper triangle scatters every tuned pair, coloured by ``auc_cmap``
      of its AUC.
    - The lower triangle scatters the same, transposed pairs, coloured by
      gene.

    Marker shape encodes the cohort. Marker size is the summed ``stat_dict``
    entry of each mutation combination, scaled so the largest is 341. Values
    are jittered so overlapping tuned values stay visible.

    Args:
        out_list: Sequence of ``(out_path, infer_df, tune_df)`` triples.
            ``out_path[0]`` is the cohort and ``out_path[1]`` the gene.
            ``tune_df`` holds tuned values under ``(label, par_name)``
            columns.
        args: Command-line arguments; ``args.classif`` names the output file.

    NOTE(review): ``use_clf`` and ``score_list`` are not parameters. They are
    resolved as module-level globals and must exist before this is called --
    confirm with the caller.
    """
    par_count = len(use_clf.tune_priors)
    # squeeze=False keeps ``axarr`` two-dimensional even for a single prior.
    fig, axarr = plt.subplots(figsize=(0.5 + 7 * par_count, 7 * par_count),
                              nrows=par_count,
                              ncols=par_count,
                              squeeze=False)

    # Each mutation type contributes two points per experiment, so every
    # per-experiment label is repeated twice per row.
    coh_vec = reduce(add, [[out_path[0]] * infer_df['Iso'].shape[0] * 2
                           for out_path, infer_df, _ in out_list])
    use_cohs = sorted(set(coh_vec))
    mark_vec = [use_marks[use_cohs.index(coh)] for coh in coh_vec]

    gene_vec = reduce(add, [[out_path[1]] * infer_df['Iso'].shape[0] * 2
                            for out_path, infer_df, _ in out_list])
    use_genes = sorted(set(gene_vec))
    gene_clrs = sns.color_palette("muted", n_colors=len(use_genes))
    clr_vec = [gene_clrs[use_genes.index(gn)] for gn in gene_vec]

    size_vec = np.concatenate([
        np.repeat([np.sum(stat_dict[mcomb]) for mcomb in tune_df.index], 2)
        for (_, _, tune_df), (stat_dict, _) in zip(out_list, score_list)
    ])
    size_vec = 341 * size_vec / np.max(size_vec)

    par_vals = {
        par_name: np.concatenate([
            tune_df.loc[:, (slice(None), par_name)].values.flatten()
            for (_, _, tune_df) in out_list
        ])
        for par_name, _ in use_clf.tune_priors
    }

    auc_vals = np.concatenate([
        auc_df.loc[tune_df.index].values.flatten()
        for (_, _, tune_df), (_, auc_df) in zip(out_list, score_list)
    ])
    auc_clrs = [auc_cmap(auc_val) for auc_val in auc_vals]

    # Diagonal panels: parameter name over its grid of tested values.
    for i, (par_name, tune_distr) in enumerate(use_clf.tune_priors):
        axarr[i, i].grid(False)

        if detect_log_distr(tune_distr):
            use_distr = [np.log10(par_val) for par_val in tune_distr]
            par_lbl = par_name + '\n(log-scale)'

        else:
            use_distr = tune_distr
            par_lbl = par_name

        # pad the axis by a ninth of the tested range on each side
        distr_diff = np.array(use_distr[-1]) - np.array(use_distr[0])
        plt_min = use_distr[0] - distr_diff / 9
        plt_max = use_distr[-1] + distr_diff / 9

        axarr[i, i].set_xlim(plt_min, plt_max)
        axarr[i, i].set_ylim(plt_min, plt_max)
        axarr[i, i].text((plt_min + plt_max) / 2, (plt_min + plt_max) / 2,
                         par_lbl,
                         ha='center',
                         fontsize=31,
                         weight='semibold')

        for par_val in use_distr:
            axarr[i, i].axhline(y=par_val,
                                color='#116611',
                                ls='--',
                                linewidth=1.9,
                                alpha=0.23)
            axarr[i, i].axvline(x=par_val,
                                color='#116611',
                                ls='--',
                                linewidth=1.9,
                                alpha=0.23)

    for (i, (par_name1, tn_distr1)), (j, (par_name2, tn_distr2)) in combn(
            enumerate(use_clf.tune_priors), 2):

        if detect_log_distr(tn_distr1):
            use_vals1 = np.log10(par_vals[par_name1])
            distr_diff = np.log10(np.array(tn_distr1[-1]))
            distr_diff -= np.log10(np.array(tn_distr1[0]))

            plt_ymin = np.log10(tn_distr1[0]) - distr_diff / 9
            plt_ymax = np.log10(tn_distr1[-1]) + distr_diff / 9

        else:
            use_vals1 = par_vals[par_name1]
            distr_diff = tn_distr1[-1] - tn_distr1[0]
            plt_ymin = tn_distr1[0] - distr_diff / 9
            plt_ymax = tn_distr1[-1] + distr_diff / 9

        if detect_log_distr(tn_distr2):
            use_vals2 = np.log10(par_vals[par_name2])
            distr_diff = np.log10(np.array(tn_distr2[-1]))
            distr_diff -= np.log10(np.array(tn_distr2[0]))

            plt_xmin = np.log10(tn_distr2[0]) - distr_diff / 9
            plt_xmax = np.log10(tn_distr2[-1]) + distr_diff / 9

        else:
            use_vals2 = par_vals[par_name2]
            distr_diff = tn_distr2[-1] - tn_distr2[0]
            plt_xmin = tn_distr2[0] - distr_diff / 9
            plt_xmax = tn_distr2[-1] + distr_diff / 9

        # Jitter into fresh arrays. For linear-scale priors ``use_vals`` is
        # the very array stored in ``par_vals``, so ``+=`` would accumulate
        # noise across every panel sharing that parameter; it would also fail
        # on integer-valued parameters.
        use_vals1 = use_vals1 + np.random.normal(
            0, (plt_ymax - plt_ymin) / (len(tn_distr1) * 11),
            auc_vals.shape[0])
        use_vals2 = use_vals2 + np.random.normal(
            0, (plt_xmax - plt_xmin) / (len(tn_distr2) * 11),
            auc_vals.shape[0])

        # Upper triangle coloured by AUC; single colours are wrapped in lists
        # so RGB(A) tuples are not read as values to colour-map.
        for use_val2, use_val1, mark_val, size_val, auc_val in zip(
                use_vals2, use_vals1, mark_vec, size_vec, auc_clrs):
            axarr[i, j].scatter(use_val2,
                                use_val1,
                                marker=mark_val,
                                s=size_val,
                                c=[auc_val],
                                alpha=0.35,
                                edgecolor='black')

        # Lower triangle (transposed) coloured by gene.
        for use_val1, use_val2, mark_val, size_val, gene_val in zip(
                use_vals1, use_vals2, mark_vec, size_vec, clr_vec):
            axarr[j, i].scatter(use_val1,
                                use_val2,
                                marker=mark_val,
                                s=size_val,
                                c=[gene_val],
                                alpha=0.35,
                                edgecolor='black')

        axarr[i, j].set_xlim(plt_xmin, plt_xmax)
        axarr[i, j].set_ylim(plt_ymin, plt_ymax)
        axarr[j, i].set_ylim(plt_xmin, plt_xmax)
        axarr[j, i].set_xlim(plt_ymin, plt_ymax)

    plt.tight_layout()
    fig.savefig(os.path.join(plot_dir,
                             "{}__tuning-grid.svg".format(args.classif)),
                bbox_inches='tight',
                format='svg')

    plt.close()