def make_figure(
    data, config, methods, infos, runtimes,
):
    """Plot residual and cost convergence, one subplot column per channel count.

    The top row shows, for every method, the median of ``"head_errors"`` on
    log-log axes (labelled "SeDJoCo Residual"); the bottom row shows the median
    of ``"head_costs"`` on a log x-axis (labelled "Surrogate Cost"). Methods are
    drawn in the order of the module-level ``methods_order``; ``seaborn_config``
    and ``figsize`` are module-level names defined elsewhere in this file.

    Args:
        data: Unused; kept for call compatibility.
        config: Mapping whose ``"n_chan"`` entry lists the channel counts, one
            subplot column each.
        methods: Method names; its length sizes the color palette and the
            runtime of every entry is printed for each channel count.
        infos: ``infos[n_chan][method]`` provides ``"head_errors"`` and
            ``"head_costs"`` arrays; their median over axis 0 is plotted.
        runtimes: ``runtimes[n_chan][method]`` is printed with 3 decimals.

    Returns:
        ``(fig, axes)`` where ``axes`` is a 2-D array of shape
        ``(2, len(config["n_chan"]))``.
    """
    seaborn_config(len(methods))

    n_cols = len(config["n_chan"])
    # squeeze=False keeps ``axes`` 2-D so ``axes[row, col]`` also works with a
    # single channel count
    fig, axes = plt.subplots(2, n_cols, figsize=figsize, squeeze=False)

    # these methods are plotted but do not influence the cost y-limits
    ylim_excluded = ("NCG", "IPA+NCG")
    pca_suffix = " (PCA)"

    leg_handles = {}
    for c, n_chan in enumerate(config["n_chan"]):
        ax_err = axes[0, c]
        ax_cost = axes[1, c]

        for method in methods:
            print(f"=== {method:7s} runtime={runtimes[n_chan][method]:.3f} ===")

        cost_ylim = [np.inf, -np.inf]
        for method in methods_order:
            info = infos[n_chan][method]

            # residual
            median_head_error = np.median(info["head_errors"], axis=0)
            ax_err.loglog(
                np.arange(len(median_head_error)) + 1, median_head_error, label=method
            )

            # cost
            cost_agg = np.median(info["head_costs"], axis=0)
            if method not in ylim_excluded:
                cost_ylim[0] = np.minimum(cost_agg.min(), cost_ylim[0])
                cost_ylim[1] = np.maximum(cost_agg.max(), cost_ylim[1])
            ax_cost.semilogx(np.arange(len(cost_agg)) + 1, cost_agg, label=method)

        # residual row: fixed limits and ticks, y tick labels on the first column only
        ax_err.set_ylim([1e-32, 1000])
        ax_err.set_xticks([1, 10, 100, 1000])
        ax_err.set_xticklabels(["", "", "", ""])
        if c > 0:
            ax_err.set_yticks([])
        else:
            ax_err.set_yticks([1e-30, 1e-20, 1e-10, 1e0])

        ax_cost.set_xticks([1, 10, 100, 1000])
        ax_cost.yaxis.labelpad = 1

        # X axis limits
        ax_err.set_xlim([1, 500])
        ax_cost.set_xlim([1, 500])
        # titles and axis labels
        ax_err.set_title(f"$M={n_chan}$")
        ax_cost.set_xlabel("Iteration")
        if c == 0:
            ax_err.set_ylabel("SeDJoCo Residual")
            ax_cost.set_ylabel("Surrogate Cost")

        # one legend entry per method; a " (PCA)" variant shares the entry of
        # the plain name (the suffix is stripped BEFORE the duplicate check),
        # and the first handle seen is kept
        handles, labels = ax_err.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            if lbl.endswith(pca_suffix):
                lbl = lbl[: -len(pca_suffix)]
            leg_handles.setdefault(lbl, hand)

        # zoom on the low end of the cost range: the view spans from 1% of the
        # range below the minimum up to 20% of the range above it; skipped when
        # no method contributed finite values (set_ylim rejects NaN/inf)
        cost_ylim = np.array(cost_ylim, dtype=float)
        if np.all(np.isfinite(cost_ylim)):
            cost_ylim_m = 0.80 * cost_ylim[0] + 0.20 * cost_ylim[1]
            cost_ylim = cost_ylim_m + np.r_[1.05, 0.0] * (cost_ylim - cost_ylim_m)
            ax_cost.set_ylim(cost_ylim)

    sns.despine(fig=fig)

    fig.tight_layout(pad=0.1, h_pad=0.5)

    fig.legend(
        leg_handles.values(),
        leg_handles.keys(),
        fontsize="x-small",
        loc="upper center",
        bbox_to_anchor=[0.5, 1.01],
        ncol=len(methods_order),
        frameon=False,
    )
    # leave room at the top for the legend
    fig.subplots_adjust(top=0.86)

    # align the y labels of the first column
    for j in range(2):
        axes[j, 0].yaxis.set_label_coords(-0.41, 0.5)

    return fig, axes
def _build_isr_mosaic(n_rows, n_algos):
    """Return the ``subplot_mosaic`` layout string and the per-row panel keys.

    Row ``b`` is one wide convergence panel followed by ``n_algos`` single-cell
    histogram panels. ``mos_map[b][0]`` is the key of the wide panel and
    ``mos_map[b][k]`` (``k >= 1``) the key of the k-th histogram of that row.

    Raises:
        ValueError: if ``ascii_letters`` has too few letters for the layout.
    """
    n_needed = n_rows * (n_algos + 1)
    if n_needed > len(ascii_letters):
        raise ValueError(
            f"mosaic needs {n_needed} panel keys but only "
            f"{len(ascii_letters)} letters are available"
        )
    # the convergence panel spans half as many cells as there are histograms,
    # but at least one cell so that it always exists
    left_width = max(1, n_algos // 2)
    rows = []
    mos_map = []
    for b in range(n_rows):
        left_key = ascii_letters[b]
        hist_keys = [ascii_letters[n_rows * i + b] for i in range(1, n_algos + 1)]
        rows.append(left_key * left_width + "".join(hist_keys))
        mos_map.append([left_key] + hist_keys)
    return "\n".join(rows), mos_map


def make_plot_isr(config, arg_isr_tables, arg_cost_tables, with_pca=True):
    """Plot ISR convergence and final-ISR histograms, one row per parameter set.

    Each row has a convergence panel followed by one horizontal histogram of
    the final ISR per algorithm. In the convergence panel, runs whose final
    ISR is below ``fail_thresh`` are averaged into a solid line and the other
    runs into a dashed line of the same color. Each histogram is annotated
    with the percentage of runs in each group.

    Relies on module-level names defined elsewhere in this file:
    ``include_algos``, ``ascii_letters``, ``seaborn_config``, ``figsize``,
    ``bss``, ``fail_thresh``, ``title_dict``, ``n_bins``, ``sns``,
    ``leg_space`` and ``fig_width``.

    Args:
        config: Mapping with ``"params"`` (list of dicts with ``"pca"``,
            ``"n_freq"``, ``"n_chan"``) and ``"algos"`` (algorithm ->
            dict with ``"algo"`` and ``"kwargs"["n_iter"]``).
        arg_isr_tables: One dict per entry of ``config["params"]`` mapping an
            algorithm to a 2-D ISR table (rows are runs, the last column is
            the final ISR).
        arg_cost_tables: Not plotted; only bounds the iteration over
            parameter sets (``zip`` stops at the shortest input).
        with_pca: Keep only parameter sets whose ``"pca"`` flag equals this.

    Returns:
        ``(fig, axes)`` with ``axes`` the dict returned by
        ``plt.subplot_mosaic``.

    Raises:
        ValueError: if no parameter set matches ``with_pca`` or the layout
            needs more panels than there are letters in ``ascii_letters``.
    """
    # pick only the desired parameter sets and, for each, the included algorithms
    params = []
    isr_tables = []
    for p, isr, _cost in zip(config["params"], arg_isr_tables, arg_cost_tables):
        if p["pca"] != with_pca:
            continue
        params.append(p)
        isr_tables.append(
            {name: tab for name, tab in isr.items() if name in include_algos}
        )
    if not params:
        raise ValueError(f"no parameter set with pca={with_pca}")

    # per row, the plotted algorithms in ``include_algos`` order; histogram
    # columns are assigned by position in this list, so a missing algorithm
    # does not push later ones past the end of the row
    row_algos = [
        [algo for algo in include_algos if algo in tab] for tab in isr_tables
    ]
    n_algos = max(len(algos) for algos in row_algos)
    n_rows = len(params)
    mosaic, mos_map = _build_isr_mosaic(n_rows, n_algos)

    # prepare the style
    seaborn_config(n_algos)

    # create the figure
    fig_size = (figsize[0], figsize[1] * n_rows / 3)
    fig, axes = plt.subplot_mosaic(mosaic, figsize=fig_size)

    # container for some info we will fill as we go
    leg_handles = {}
    y_lim_isr = [0, 1]
    x_lim_hist_isr = [0, 0]
    percent_converge = []
    f_agg = np.mean

    for ip, algos in enumerate(row_algos):
        ax_conv = axes[mos_map[ip][0]]
        percent_converge.append({})

        for col, algo in enumerate(algos, start=1):
            table = isr_tables[ip][algo]

            n_iter = config["algos"][algo]["kwargs"]["n_iter"]
            algo_name = config["algos"][algo]["algo"]
            # algorithms flagged in ``bss.is_dual_update`` get a checkpoint
            # every second iteration, the others every iteration
            step = 2 if bss.is_dual_update[algo_name] else 1
            callback_checkpoints = np.arange(0, n_iter + 1, step)

            # shared ISR axis: cover the minimum and the 99.5th percentile
            y_lim_isr = [
                min(y_lim_isr[0], table.min()),
                max(y_lim_isr[1], np.percentile(table, 99.5)),
            ]

            I_s = table[:, -1] < fail_thresh  # separation is successful
            I_f = table[:, -1] >= fail_thresh  # separation fails

            # successful runs: solid line, which also provides the legend entry
            lines = ax_conv.semilogx(
                callback_checkpoints,
                f_agg(table[I_s, :], axis=0),
                label=title_dict[algo],
            )
            color = lines[0].get_color()

            # failed runs: dashed line in the same color, without a label
            ax_conv.plot(
                callback_checkpoints,
                f_agg(table[I_f, :], axis=0),
                alpha=0.6,
                c=color,
                linestyle="--",
            )

            # fraction of successful runs and mean final ISR of each group
            percent_converge[ip][algo] = (
                np.sum(I_s) / len(I_s),
                f_agg(table[I_s, -1]),
                f_agg(table[I_f, -1]),
            )

            # histogram of the final ISR
            bin_heights, _bins, _patches = axes[mos_map[ip][col]].hist(
                table[:, -1],
                bins=n_bins,
                orientation="horizontal",
                density=True,
                color=color,
                linewidth=0.0,
            )

            # keep track of required length of x-axis for the histograms
            x_lim_hist_isr[1] = max(x_lim_hist_isr[1], bin_heights.max())

        # collect the labels, first handle per label wins
        handles, labels = ax_conv.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            leg_handles.setdefault(lbl, hand)

    sns.despine(fig=fig, offset=1.0)

    # arrange the axes
    for ip, (pmt, algos) in enumerate(zip(params, row_algos)):
        n_freq = pmt["n_freq"]
        n_chan = pmt["n_chan"]
        ax_conv = axes[mos_map[ip][0]]
        is_last_row = ip == n_rows - 1

        # iteration ticks only on the bottom row
        ax_conv.set_xticks([1, 10, 100, 1000] if is_last_row else [])
        ax_conv.set_xlim([0.9, 1000])
        ax_conv.set_ylim(y_lim_isr)

        for col, algo in enumerate(algos, start=1):
            ax_hist = axes[mos_map[ip][col]]
            ax_hist.set_ylim(y_lim_isr)
            ax_hist.set_yticks([])
            ax_hist.set_xlim(x_lim_hist_isr)

            if is_last_row:
                # bottom-row histograms carry the algorithm name as a single tick
                ax_hist.set_xticks([np.mean(x_lim_hist_isr)])
                ax_hist.set_xticklabels([title_dict[algo]])
            else:
                ax_hist.set_xticks([])

            # write the percentage of successful / failed runs onto the
            # histogram, slightly above the mean final ISR of each group
            p, y_s, y_f = percent_converge[ip][algo]
            x_text = 0.25 * x_lim_hist_isr[1]
            ax_hist.annotate(
                f"{100 * p:4.1f}%",
                [x_text, y_s + 4.0],
                fontsize="x-small",
                ha="left",
            )
            ax_hist.annotate(
                f"{100 * (1 - p):4.1f}%",
                [x_text, min(y_f + 4.0, y_lim_isr[1] - 1.0)],
                fontsize="x-small",
                ha="left",
            )

        # hide histogram panels left unused by a row with fewer algorithms
        for key in mos_map[ip][len(algos) + 1:]:
            axes[key].set_axis_off()

        ax_conv.set_title(f"$F={n_freq}$ $M={n_chan}$")
        ax_conv.set_ylabel("ISR [dB]")

    # note placed at a hard-coded data position in the first convergence panel
    axes[mos_map[0][0]].annotate("IP/ISS overlap", [1, -28], fontsize="xx-small")
    axes[mos_map[-1][0]].set_xlabel("Iteration")

    fig.tight_layout(pad=0.1, w_pad=0.2, h_pad=1)
    fig.legend(
        leg_handles.values(),
        leg_handles.keys(),
        title="Algorithm",
        title_fontsize="x-small",
        fontsize="x-small",
        bbox_to_anchor=[1, 0.5],
        loc="center right",
        frameon=False,
    )

    # leave room on the right for the legend
    fig.subplots_adjust(right=1 - leg_space / (fig_width))

    return fig, axes
def make_figure_cost(config, arg_isr_tables, arg_cost_tables, with_pca=True):
    """Plot the mean cost per iteration, one panel per parameter set.

    For each selected parameter set, every algorithm of ``include_algos_cost``
    that is also selected via ``include_algos`` is drawn as the mean of its
    cost table over axis 0, on a log x-axis. The top y-limit is fixed at
    ``-100000``; the bottom is padded 5% below the lowest mean cost seen so far.

    Relies on module-level names defined elsewhere in this file:
    ``include_algos``, ``include_algos_cost``, ``seaborn_config``,
    ``figsize_cost``, ``title_dict`` and ``sns``.

    Args:
        config: Mapping whose ``"params"`` entry is a list of dicts with
            ``"pca"``, ``"n_freq"`` and ``"n_chan"``.
        arg_isr_tables: One dict per parameter set; only its keys are used to
            decide which algorithms are drawn.
        arg_cost_tables: One dict per parameter set mapping an algorithm to a
            cost table whose mean over axis 0 is plotted.
        with_pca: Keep only parameter sets whose ``"pca"`` flag equals this.

    Returns:
        ``(fig, axes)`` where ``axes`` is a 1-D array with one Axes per
        selected parameter set.

    Raises:
        ValueError: if no parameter set matches ``with_pca``.
    """
    # pick only the desired parameter sets and, for each, the included algorithms
    params = []
    isr_tables = []
    cost_tables = []
    for p, isr, cost in zip(config["params"], arg_isr_tables, arg_cost_tables):
        if p["pca"] != with_pca:
            continue
        params.append(p)
        cost_tables.append(cost)
        isr_tables.append(
            {name: tab for name, tab in isr.items() if name in include_algos}
        )
    if not params:
        raise ValueError(f"no parameter set with pca={with_pca}")

    # palette size: number of algorithms kept for the last parameter set
    seaborn_config(len(isr_tables[-1]))

    # squeeze=False then [0]: ``axes`` is always 1-D, even for a single panel
    fig, axes = plt.subplots(1, len(params), figsize=figsize_cost, squeeze=False)
    axes = axes[0]

    leg_handles = {}

    y_lim_up = -100000
    ticks = [
        [-200000, y_lim_up],
        [-300000, y_lim_up],
        [-400000, y_lim_up],
    ]
    ticklabels = [
        [
            "$-2 x 10^5$",
            "$-10^5$",
        ],
        [
            "$-3 x 10^5$",
            "$-10^5$",
        ],
        [
            "$-4 x 10^5$",
            "$-10^5$",
        ],
    ]

    # lowest mean cost over all panels drawn so far; each panel's bottom limit
    # is derived from it
    cost_min = np.inf

    for ip, pmt in enumerate(params):
        n_freq = pmt["n_freq"]
        n_chan = pmt["n_chan"]
        ax = axes[ip]

        for algo in include_algos_cost:
            # only algorithms also selected via ``include_algos`` are drawn
            if algo not in isr_tables[ip]:
                continue
            agg_cost = np.mean(cost_tables[ip][algo], axis=0)
            ax.semilogx(
                np.arange(1, len(agg_cost) + 1),
                agg_cost,
                label=title_dict[algo],
            )
            cost_min = min(cost_min, agg_cost.min())

        # collect the labels, first handle per label wins
        handles, labels = ax.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            leg_handles.setdefault(lbl, hand)

        ax.set_xlabel("Iteration")
        ax.set_xticks([1, 10, 100])
        ax.set_title(f"$F={n_freq}$ $M={n_chan}$")

        # fixed top; bottom padded by 5% of the visible range. The padding is
        # computed from the unpadded minimum (a scalar) so it never compounds
        # across panels; skipped while nothing finite has been plotted.
        if np.isfinite(cost_min):
            ax.set_ylim([cost_min - 0.05 * (y_lim_up - cost_min), y_lim_up])

        if ip < len(ticks):
            ax.set_yticks(ticks[ip])
        if ip < len(ticklabels):
            ax.set_yticklabels(ticklabels[ip], fontsize=20)
        ax.tick_params(axis="y", labelsize="x-small", rotation=30, pad=-0.1)

    sns.despine(fig=fig, offset=0.1)

    fig.tight_layout(pad=0.1, w_pad=0.0, h_pad=1.0)
    axes[-1].legend(
        leg_handles.values(),
        leg_handles.keys(),
        title="Algorithm",
        title_fontsize="x-small",
        fontsize="x-small",
        bbox_to_anchor=[1.15, 1.1],
        loc="upper right",
        frameon=False,
    )

    return fig, axes
    if cli_args.pca:
        pca_str = " (PCA)"
    else:
        pca_str = ""

    all_algos = [
        "IVA-NG" + pca_str,
        "FastIVA" + pca_str,
        "AuxIVA-IP" + pca_str,
        "AuxIVA-ISS" + pca_str,
        "AuxIVA-IP2" + pca_str,
        "AuxIVA-IPA" + pca_str,
    ]

    seaborn_config(n_colors=len(all_algos), style="whitegrid")

    if not os.path.exists("figures"):
        os.mkdir("figures")

    fig_dir = "figures/{}_{}_{}".format(parameters["name"],
                                        parameters["_date"],
                                        parameters["_git_sha"])

    if not os.path.exists(fig_dir):
        os.mkdir(fig_dir)

    plt_kwargs = {
        "improvements": {
            "yticks": [[-10, 0, 10, 20], [-10, 0, 10, 20], [0, 10, 20, 30]],
        },