def make_figure(
    data,
    config,
    methods,
    infos,
    runtimes,
):
    """
    Plot convergence curves, one column of panels per channel count.

    The top row shows the median (over axis 0) of ``"head_errors"`` on
    log-log axes, the bottom row the median of ``"head_costs"`` on a
    log-x axis. The costs are plotted as stored, no offset is subtracted.
    The runtime of every method is printed to stdout as a side effect.

    Parameters
    ----------
    data:
        Not used by this function; kept so existing callers keep working.
    config:
        Mapping whose ``"n_chan"`` entry lists the channel counts; each one
        gets its own column of panels.
    methods:
        Collection of method names (iterating it must yield strings). Its
        length sets the number of colors and its items drive the runtime
        report.
    infos:
        ``infos[n_chan][method]`` must hold the arrays ``"head_errors"`` and
        ``"head_costs"`` for every method listed in the module-level
        ``methods_order``.
    runtimes:
        ``runtimes[n_chan][method]`` is the number printed for each method.

    Returns
    -------
    fig, axes
        The figure and the ``(2, n_columns)`` array of axes.
    """
    # suffix removed from legend labels so that each method appears only once
    pca_suffix = " (PCA)"
    # methods that are drawn but ignored when choosing the cost y-range
    ylim_excluded = ("NCG", "IPA+NCG")
    iter_ticks = [1, 10, 100, 1000]

    seaborn_config(len(methods))

    # squeeze=False keeps `axes` two-dimensional even for a single column
    fig, axes = plt.subplots(
        2, len(config["n_chan"]), figsize=figsize, squeeze=False
    )

    leg_handles = {}

    for c, n_chan in enumerate(config["n_chan"]):
        ax_err = axes[0, c]
        ax_cost = axes[1, c]

        for method in methods:
            print(f"=== {method:7s} runtime={runtimes[n_chan][method]:.3f} ===")

        # running [min, max] of the cost curves that define the y-range
        cost_ylim = [np.inf, -np.inf]

        for method in methods_order:
            info = infos[n_chan][method]

            # residual
            median_head_error = np.median(info["head_errors"], axis=0)
            ax_err.loglog(
                np.arange(len(median_head_error)) + 1,
                median_head_error,
                label=method,
            )

            # cost
            cost_agg = np.median(info["head_costs"], axis=0)
            if method not in ylim_excluded:
                cost_ylim[0] = np.minimum(cost_agg.min(), cost_ylim[0])
                cost_ylim[1] = np.maximum(cost_agg.max(), cost_ylim[1])
            ax_cost.semilogx(np.arange(len(cost_agg)) + 1, cost_agg, label=method)

        # Ticks and limits are set once, after all curves are drawn.
        # NOTE(review): keep this after the plotting calls; loglog/semilogx
        # re-apply the log scale, which is expected to reset manual ticks.
        ax_err.set_ylim([1e-32, 1000])
        ax_err.set_xticks(iter_ticks)
        ax_err.set_xticklabels(["", "", "", ""])
        if c > 0:
            ax_err.set_yticks([])
        else:
            ax_err.set_yticks([1e-30, 1e-20, 1e-10, 1e0])
        ax_cost.set_xticks(iter_ticks)
        ax_cost.yaxis.labelpad = 1

        # X axis limits
        ax_err.set_xlim([1, 500])
        ax_cost.set_xlim([1, 500])

        # titles and axis labels
        ax_err.set_title(f"$M={n_chan}$")
        ax_cost.set_xlabel("Iteration")
        if c == 0:
            ax_err.set_ylabel("SeDJoCo Residual")
            ax_cost.set_ylabel("Surrogate Cost")

        # keep track of the legend; the label is normalized *before* the
        # membership test so the first handle seen for a method is kept
        handles, labels = ax_err.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            if lbl.endswith(pca_suffix):
                lbl = lbl[: -len(pca_suffix)]
            if lbl not in leg_handles:
                leg_handles[lbl] = hand

        # y-range of the cost panel: anchored at 20% of the observed range
        # above the minimum (upper limit), with the lower limit pushed 5%
        # further away from that anchor than the observed minimum
        cost_ylim = np.array(cost_ylim)
        cost_ylim_m = 0.80 * cost_ylim[0] + 0.20 * cost_ylim[1]
        cost_ylim = cost_ylim_m + np.r_[1.05, 0.0] * (cost_ylim - cost_ylim_m)
        ax_cost.set_ylim(cost_ylim)

    sns.despine(fig=fig)

    fig.tight_layout(pad=0.1, h_pad=0.5)
    fig.legend(
        leg_handles.values(),
        leg_handles.keys(),
        fontsize="x-small",
        loc="upper center",
        bbox_to_anchor=[0.5, 1.01],
        ncol=len(methods_order),
        frameon=False,
    )
    # leave room for the legend above the panels
    fig.subplots_adjust(top=0.86)

    # put both y-labels of the first column at the same horizontal position
    for j in range(2):
        axes[j, 0].yaxis.set_label_coords(-0.41, 0.5)

    return fig, axes
def make_plot_isr(config, arg_isr_tables, arg_cost_tables, with_pca=True):
    """
    Plot ISR convergence curves next to histograms of the final ISR.

    One row is drawn per parameter set whose ``"pca"`` flag equals
    ``with_pca``. The left panel of a row shows, for every algorithm, the
    mean ISR over the runs whose final value is below the module-level
    ``fail_thresh`` (solid line) and over the remaining runs (dashed line).
    To its right, each algorithm gets a horizontal histogram of the final
    ISR values, annotated with the percentage of runs in each group.

    Parameters
    ----------
    config:
        Mapping with ``"params"`` (list of dicts holding ``"pca"``,
        ``"n_freq"`` and ``"n_chan"``) and ``"algos"`` (per algorithm:
        ``["kwargs"]["n_iter"]`` and ``["algo"]``).
    arg_isr_tables:
        One mapping per entry of ``config["params"]``, from algorithm name
        to a 2-D array; axis 0 is averaged and the last column is the final
        value.
    arg_cost_tables:
        Only zipped with the other two sequences (it bounds their common
        length); its content is not read here.
    with_pca:
        Value of the ``"pca"`` flag selecting the parameter sets to plot.

    Returns
    -------
    fig, axes
        The figure and the dict of axes returned by ``plt.subplot_mosaic``.
    """
    # keep only the desired parameter sets and, for each of them, only the
    # ISR tables of the algorithms listed in `include_algos`
    params = []
    isr_tables = []
    for p, isr, _cost in zip(config["params"], arg_isr_tables, arg_cost_tables):
        if p["pca"] != with_pca:
            continue
        params.append(p)
        isr_tables.append(
            {name: tbl for name, tbl in isr.items() if name in include_algos}
        )

    # Algorithms present in at least one row, in `include_algos` order.
    # Every one of them owns a fixed histogram column, so an algorithm that
    # is missing from the tables can no longer push a later algorithm past
    # the last column of the mosaic.
    shown_algos = [
        algo
        for algo in include_algos
        if any(algo in tables for tables in isr_tables)
    ]
    n_algos = len(shown_algos)
    hist_col = {algo: k + 1 for k, algo in enumerate(shown_algos)}

    # construct the mosaic: every axes is named by a single letter
    n_rows = len(params)
    assert n_rows * (n_algos + 1) <= len(ascii_letters), (
        "not enough letters to name all the panels of the mosaic"
    )
    # the convergence panel spans half as many cells as there are
    # histograms, but never less than one
    left_width = max(1, n_algos // 2)
    mos_map = []  # mos_map[row] = [convergence key, histogram keys...]
    mosaic_rows = []
    for b in range(n_rows):
        row_keys = [ascii_letters[b]]
        row_keys += [ascii_letters[n_rows * k + b] for k in range(1, n_algos + 1)]
        mos_map.append(row_keys)
        mosaic_rows.append(row_keys[0] * left_width + "".join(row_keys[1:]))
    mosaic = "\n".join(mosaic_rows)

    # prepare the style
    seaborn_config(n_algos)

    # create the figure
    fig_size = (figsize[0], figsize[1] * len(params) / 3)
    fig, axes = plt.subplot_mosaic(mosaic, figsize=fig_size)

    # containers for some info we will fill as we go
    leg_handles = {}
    y_lim_isr = [0, 1]
    x_lim_hist_isr = [0, 0]
    percent_converge = []
    f_agg = np.mean

    # first pass: draw the curves and histograms, collect the axis ranges
    for ip in range(n_rows):
        ax_conv = axes[mos_map[ip][0]]
        percent_converge.append({})

        for algo in shown_algos:
            if algo not in isr_tables[ip]:
                continue

            table = isr_tables[ip][algo]
            ax_hist = axes[mos_map[ip][hist_col[algo]]]

            n_iter = config["algos"][algo]["kwargs"]["n_iter"]
            algo_name = config["algos"][algo]["algo"]

            # dual-update algorithms are sampled every second iteration
            step = 2 if bss.is_dual_update[algo_name] else 1
            callback_checkpoints = np.arange(0, n_iter + 1, step)

            # the upper limit ignores the top 0.5% of the values
            y_lim_isr = [
                min(y_lim_isr[0], table.min()),
                max(y_lim_isr[1], np.percentile(table, 99.5)),
            ]

            I_s = table[:, -1] < fail_thresh  # separation is successful
            I_f = table[:, -1] >= fail_thresh  # separation fails

            # ISR convergence of the successful runs
            lines = ax_conv.semilogx(
                callback_checkpoints,
                f_agg(table[I_s, :], axis=0),
                label=title_dict[algo],
            )

            # keep the percentage and mean of success/failure
            percent_converge[ip][algo] = (
                np.sum(I_s) / len(I_s),
                f_agg(table[I_s, -1]),
                f_agg(table[I_f, -1]),
            )

            # the failing runs and the histogram reuse the main line's color
            color = lines[0].get_color()

            ax_conv.plot(
                callback_checkpoints,
                f_agg(table[I_f, :], axis=0),
                alpha=0.6,
                c=color,
                linestyle="--",
            )

            bin_heights, _, _ = ax_hist.hist(
                table[:, -1],
                bins=n_bins,
                orientation="horizontal",
                density=True,
                color=color,
                linewidth=0.0,
            )

            # keep track of required length of x-axis for the histograms
            x_lim_hist_isr[1] = max(x_lim_hist_isr[1], bin_heights.max())

        # collect the legend entries, first occurrence of each label wins
        handles, labels = ax_conv.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            if lbl not in leg_handles:
                leg_handles[lbl] = hand

    sns.despine(fig=fig, offset=1.0)

    # second pass: now that the common ranges are known, style the axes
    for ip, pmt in enumerate(params):
        n_freq = pmt["n_freq"]
        n_chan = pmt["n_chan"]
        is_last_row = ip == n_rows - 1
        ax_conv = axes[mos_map[ip][0]]

        # only the bottom row shows the iteration ticks
        ax_conv.set_xticks([1, 10, 100, 1000] if is_last_row else [])
        ax_conv.set_xlim([0.9, 1000])
        ax_conv.set_ylim(y_lim_isr)

        for algo in shown_algos:
            if algo not in isr_tables[ip]:
                continue

            ax_hist = axes[mos_map[ip][hist_col[algo]]]

            ax_hist.set_ylim(y_lim_isr)
            ax_hist.set_yticks([])
            ax_hist.set_xlim(x_lim_hist_isr)
            if is_last_row:
                # a single centered tick carries the algorithm name
                ax_hist.set_xticks([np.mean(x_lim_hist_isr)])
                ax_hist.set_xticklabels([title_dict[algo]])
            else:
                ax_hist.set_xticks([])

            # write the percentage of convergent/divergent runs directly
            # onto the histogram; a group without any run has a NaN mean
            # and gets no annotation
            frac_ok, y_s, y_f = percent_converge[ip][algo]
            x_txt = 0.25 * x_lim_hist_isr[1]
            if np.isfinite(y_s):
                ax_hist.annotate(
                    f"{100 * frac_ok:4.1f}%",
                    [x_txt, y_s + 4.0],
                    fontsize="x-small",
                    ha="left",
                )
            if np.isfinite(y_f):
                # keep the failure label inside the visible range
                ax_hist.annotate(
                    f"{100 * (1 - frac_ok):4.1f}%",
                    [x_txt, min(y_f + 4.0, y_lim_isr[1] - 1.0)],
                    fontsize="x-small",
                    ha="left",
                )

        ax_conv.set_title(f"$F={n_freq}$ $M={n_chan}$")
        ax_conv.set_ylabel("ISR [dB]")

    axes[mos_map[0][0]].annotate("IP/ISS overlap", [1, -28], fontsize="xx-small")
    axes[mos_map[-1][0]].set_xlabel("Iteration")

    fig.tight_layout(pad=0.1, w_pad=0.2, h_pad=1)

    fig.legend(
        leg_handles.values(),
        leg_handles.keys(),
        title="Algorithm",
        title_fontsize="x-small",
        fontsize="x-small",
        bbox_to_anchor=[1, 0.5],
        loc="center right",
        frameon=False,
    )

    # reserve the right margin for the legend
    fig.subplots_adjust(right=1 - leg_space / (fig_width))

    return fig, axes
def make_figure_cost(config, arg_isr_tables, arg_cost_tables, with_pca=True):
    """
    Plot the mean cost against the iteration index, one panel per parameter set.

    Only the parameter sets whose ``"pca"`` flag equals ``with_pca`` are
    drawn. In each panel, every algorithm of the module-level
    ``include_algos_cost`` that also has an ISR table (filtered by
    ``include_algos``) contributes the mean over axis 0 of its cost table.

    Parameters
    ----------
    config:
        Mapping with ``"params"``, a list of dicts holding ``"pca"``,
        ``"n_freq"`` and ``"n_chan"``.
    arg_isr_tables:
        One mapping per entry of ``config["params"]``, keyed by algorithm
        name. Only the keys are used, to decide which algorithms to plot.
    arg_cost_tables:
        One mapping per entry of ``config["params"]``, from algorithm name
        to a cost array that is averaged over axis 0.
    with_pca:
        Value of the ``"pca"`` flag selecting the parameter sets to plot.

    Returns
    -------
    fig, axes
        The figure and the 1-D array of axes (one per plotted parameter set).
    """
    # pick only the desired parameter sets and algorithms
    params = []
    isr_tables = []
    cost_tables = []
    # stays at zero when no parameter set matches `with_pca`
    n_algos = 0
    for p, isr, cost in zip(config["params"], arg_isr_tables, arg_cost_tables):
        if p["pca"] != with_pca:
            continue
        kept = {name: tbl for name, tbl in isr.items() if name in include_algos}
        params.append(p)
        cost_tables.append(cost)
        isr_tables.append(kept)
        # the count of the last selected parameter set is the one used
        n_algos = len(kept)

    # prepare the style
    seaborn_config(n_algos)

    # create the figure; squeeze=False plus taking the single row yields a
    # 1-D array of axes even when there is only one panel
    fig, axes = plt.subplots(1, len(params), figsize=figsize_cost, squeeze=False)
    axes = axes[0]

    # containers for some info we will fill as we go
    leg_handles = {}
    # running y-range; it is deliberately NOT reset between panels, so the
    # lower limit of a panel is never above that of the previous one
    y_lim = [np.inf, -np.inf]
    y_lim_up = -100000  # fixed upper limit of every panel
    # hand-picked y ticks and labels for the first three panels
    ticks = [
        [-200000, y_lim_up],
        [-300000, y_lim_up],
        [-400000, y_lim_up],
    ]
    ticklabels = [
        ["$-2 x 10^5$", "$-10^5$"],
        ["$-3 x 10^5$", "$-10^5$"],
        ["$-4 x 10^5$", "$-10^5$"],
    ]

    for ip, pmt in enumerate(params):
        n_freq = pmt["n_freq"]
        n_chan = pmt["n_chan"]
        ax = axes[ip]

        for algo in include_algos_cost:
            if algo not in isr_tables[ip]:
                continue

            agg_cost = np.mean(cost_tables[ip][algo], axis=0)

            ax.semilogx(
                np.arange(1, len(agg_cost) + 1), agg_cost, label=title_dict[algo]
            )

            y_lim = [
                min(y_lim[0], agg_cost.min()),
                max(y_lim[1], agg_cost.max()),
            ]

        # clip the top and extend the bottom by 5% of the resulting range;
        # plain scalar arithmetic keeps `y_lim` a list of two scalars
        y_lim[1] = y_lim_up
        if np.isfinite(y_lim[0]):
            y_lim[0] = y_lim[0] - 0.05 * (y_lim[1] - y_lim[0])
            ax.set_ylim(y_lim)
        # otherwise nothing has been plotted so far: keep the default range

        # collect the legend entries, first occurrence of each label wins
        handles, labels = ax.get_legend_handles_labels()
        for lbl, hand in zip(labels, handles):
            if lbl not in leg_handles:
                leg_handles[lbl] = hand

        ax.set_xlabel("Iteration")
        ax.set_xticks([1, 10, 100])
        ax.set_title(f"$F={n_freq}$ M={n_chan}")
        if ip < len(ticks):
            ax.set_yticks(ticks[ip])
        if ip < len(ticklabels):
            ax.set_yticklabels(ticklabels[ip], fontsize=20)
        ax.tick_params(axis="y", labelsize="x-small", rotation=30, pad=-0.1)

    sns.despine(fig=fig, offset=0.1)

    fig.tight_layout(pad=0.1, w_pad=0.0, h_pad=1.0)

    # the legend is attached to the last panel, slightly outside its corner
    axes[-1].legend(
        leg_handles.values(),
        leg_handles.keys(),
        title="Algorithm",
        title_fontsize="x-small",
        fontsize="x-small",
        bbox_to_anchor=[1.15, 1.1],
        loc="upper right",
        frameon=False,
    )

    return fig, axes
if cli_args.pca: pca_str = " (PCA)" else: pca_str = "" all_algos = [ "IVA-NG" + pca_str, "FastIVA" + pca_str, "AuxIVA-IP" + pca_str, "AuxIVA-ISS" + pca_str, "AuxIVA-IP2" + pca_str, "AuxIVA-IPA" + pca_str, ] seaborn_config(n_colors=len(all_algos), style="whitegrid") if not os.path.exists("figures"): os.mkdir("figures") fig_dir = "figures/{}_{}_{}".format(parameters["name"], parameters["_date"], parameters["_git_sha"]) if not os.path.exists(fig_dir): os.mkdir(fig_dir) plt_kwargs = { "improvements": { "yticks": [[-10, 0, 10, 20], [-10, 0, 10, 20], [0, 10, 20, 30]], },