def save_json(self, name, blob):
        """
        Serialize ``blob`` to JSON and write it to ``<data_dir>/dataframes/<name>``.

        Args:
            name: file name (including any extension) inside ``dataframes/``.
            blob: any object accepted by ``json.dumps``.
        """
        target = os.path.join(ensure_sub_dir(self.data_dir, "dataframes/"), name)
        payload = json.dumps(blob)
        with open(target, "w") as handle:
            handle.write(payload)
    def plot_activation_fns(self, act_fns):
        """
        Draw every activation function in ``act_fns`` over [-5, 5] on one axis.

        Each callable receives a 1-D torch tensor of 50 evenly spaced points
        and is labelled by ``str(fn)``. The figure is shown interactively
        unless ``self.save_fig`` is set, in which case it is written as a
        300-dpi PNG under ``figures/act_fns/``.
        """
        inputs = torch.tensor(np.linspace(-5, 5, 50))
        fig, ax = plt.subplots(figsize=(7, 5))

        for act_fn in act_fns:
            ax.plot(inputs, act_fn(inputs), label=str(act_fn))
        ax.legend()

        if self.save_fig:
            out_dir = ensure_sub_dir(self.data_dir, f"figures/act_fns/")
            joined = " & ".join(str(act_fn) for act_fn in act_fns)
            out_path = os.path.join(out_dir, f"{joined}.png")
            print(f"Saving... {out_path}")
            plt.savefig(out_path, dpi=300)
        else:
            print("Not saving.")
            plt.show()
    def plot_type_specific_weights(self, net_name, case):
        """
        Grouped bar chart of mean absolute weight per layer, one coloured bar
        per activation function (the second index level of the weight frame),
        with SEM error bars.

        Args:
            net_name: key into the module-level ``nets`` table, which supplies
                the layer state keys and their display names.
            case: experimental case forwarded to ``load_weight_df``.
        """
        weight_df = self.stats_processor.load_weight_df(net_name, case)

        layer_names = nets[net_name]["state_keys"]
        layer_keys = list(layer_names.keys())
        fn_level = weight_df.index.levels[1]
        n_fns = len(fn_level)
        bar_width = 1.0 / n_fns
        error_style = dict(lw=1, capsize=3, capthick=1)
        offsets = np.array([k * 1.25 for k in range(len(layer_keys))])

        fig, ax = plt.subplots(figsize=(14, 8))
        palette = sns.color_palette("hls", n_fns)

        for fn_name, colour in zip(fn_level, palette):
            means = weight_df["avg_weight"][:, fn_name][layer_keys]
            sems = weight_df["sem_weight"][:, fn_name][layer_keys]
            ax.bar(offsets,
                   means,
                   bar_width,
                   yerr=sems,
                   label=fn_name,
                   error_kw=error_style,
                   color=colour)
            # shift the next activation function's bars one bar-width right
            offsets = [pos + bar_width for pos in offsets]

        ax.set_title("Weight distribution across layers after training")
        ax.set_xlabel("Layer")
        ax.set_ylabel("Mean abs weight per layer")
        ax.legend()

        # centre each tick under its group of bars
        centre = (n_fns - 1) / (2. * n_fns)
        ax.set_xticks([centre + k * 1.25 for k in range(len(layer_keys))])
        ax.set_xticklabels(list(layer_names.values()))

        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        out_dir = ensure_sub_dir(self.data_dir,
                                 f"figures/{net_name}/weight distr/")
        out_path = os.path.join(out_dir, f"{case} weight distr.png")
        print(f"Saving... {out_path}")
        plt.savefig(out_path, dpi=300)
    def save(self, sub_dir_name, filename):
        """
        Write the current matplotlib figure to ``figures/<sub dir>/<filename>``.

        An SVG is always written; a 300-dpi PNG is also written when
        ``self.save_png`` is truthy. A non-None ``self.sub_dir_name`` takes
        precedence over the ``sub_dir_name`` argument.
        """
        folder = sub_dir_name if self.sub_dir_name is None else self.sub_dir_name
        base = os.path.join(
            ensure_sub_dir(self.data_dir, f"figures/{folder}"), filename)

        print(f"Saving... {base}")
        plt.savefig(f"{base}.svg")
        if self.save_png:
            plt.savefig(f"{base}.png", dpi=300)
Ejemplo n.º 5
0
    def plot_activation_fns(self, act_fns, clr_set="husl"):
        """
        Plot the given activation functions on one square, zoomed-in axis.

        Each function is evaluated on 10,000 points spanning [-100, 100], but
        the view is clipped to x in [-2, 2] and y in [-1, 2]. The figure is
        shown interactively unless ``self.save_fig`` is set, in which case it
        is written as both SVG and 300-dpi PNG under ``figures/act_fns/``.

        Args:
            act_fns: sequence of callables accepting a 1-D torch tensor;
                ``str(fn)`` is used for the legend label and the file name.
            clr_set: seaborn palette name used to colour the curves.
        """
        x = torch.tensor(np.linspace(-100, 100, 10000))
        fig, ax = plt.subplots(figsize=(5, 5))
        clrs = sns.color_palette(clr_set, len(act_fns))

        for fn, clr in zip(act_fns, clrs):
            ax.plot(x, fn(x), label=str(fn), c=clr, linewidth=3)

        # faint reference axes through the origin
        ax.axhline(y=0, color="k", linestyle="--", alpha=0.2)
        ax.axvline(x=0, color="k", linestyle="--", alpha=0.2)

        ax.set_xticks([-1, 0, 1])
        ax.set_xticklabels([-1, 0, 1])
        ax.set_yticks([-1, 0, 1])
        ax.set_yticklabels([-1, 0, 1])
        ax.set_xlim([-2, 2])
        ax.set_ylim([-1, 2])
        ax.set_aspect("equal", "box")
        ax.set_xlabel("Input", fontsize=large_font_size)
        ax.set_ylabel("Activation", fontsize=large_font_size)
        ax.legend(fontsize=small_font_size, loc="upper left")
        plt.tight_layout()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/act_fns/")
        fn_names = " & ".join(str(fn) for fn in act_fns)
        filename = os.path.join(sub_dir, fn_names)
        print(f"Saving... {filename}")
        plt.savefig(f"{filename}.svg")
        plt.savefig(f"{filename}.png", dpi=300)
Ejemplo n.º 6
0
    def plot_act_fn_mapping(self, act_fn1, act_fn2):
        """
        Visualize the expressivity of mixed activation functions.

        Draws two figures:
          1. the unit circle (input space);
          2. the image of that circle under ``act_fn1`` applied to both
             coordinates, under ``act_fn2`` applied to both, and the "mixed"
             map that applies ``act_fn1`` to the first coordinate and
             ``act_fn2`` to the second.

        Args:
            act_fn1, act_fn2: callables accepting a 1-D torch tensor; their
                ``str()`` forms name the saved file.

        When ``self.save_fig`` is set, only the second (output space) figure
        is written, as an SVG under ``figures/act_fns/``.
        """

        # plot input space: the unit circle
        fig, ax = plt.subplots(figsize=(7, 5))
        circle = plt.Circle((0, 0), 1, color="k", fill=False, linewidth=2)
        ax.add_artist(circle)
        ax.axis("equal")
        ax.set(xlim=(-2, 2), ylim=(-2, 2))
        ax.axvline(0, linestyle="--", alpha=0.25, color="k")
        ax.axhline(0, linestyle="--", alpha=0.25, color="k")

        # plot output space
        fig, ax = plt.subplots(figsize=(7, 5))
        ax.axis("equal")
        ax.set(xlim=(-2, 2), ylim=(-2, 2))
        ax.axvline(0, linestyle="--", alpha=0.25, color="k")
        ax.axhline(0, linestyle="--", alpha=0.25, color="k")
        # sample the unit circle: first coordinate sin(t), second cos(t)
        t = np.arange(0, 2 * np.pi, 1 / 100)
        x1 = torch.tensor(np.sin(t))
        x2 = torch.tensor(np.cos(t))
        ax.plot(x1, x2, "k:", linewidth=2, label="Input")

        # same function applied to both coordinates
        ax.plot(act_fn1(x1), act_fn1(x2), "b--", linewidth=2, label="act_fn1")
        ax.plot(act_fn2(x1), act_fn2(x2), "g--", linewidth=2, label="act_fn2")

        # mixed: a different function on each coordinate
        ax.plot(act_fn1(x1), act_fn2(x2), "r", linewidth=2, label="Mixed")
        plt.legend()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/act_fns/")
        # name the file after the two functions being compared
        fn_names = " & ".join(str(fn) for fn in (act_fn1, act_fn2))
        filename = os.path.join(sub_dir, f"{fn_names}.svg")
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)
Ejemplo n.º 7
0
class WeightVisualizer():
    """
    Bar-chart views of network weights, backed by a ``StatsProcessor``.

    Figures are shown interactively unless ``save_fig`` is set, in which case
    they are written as SVGs under ``<data_dir>/figures/<net_name>/``.
    """

    def __init__(self, data_dir, n_classes=10, save_fig=False):
        """
        Args:
            data_dir: root directory for loaded data and saved outputs.
            n_classes: forwarded to ``StatsProcessor``.
            save_fig: write figures to disk instead of showing them.
        """
        self.data_dir = data_dir
        self.save_fig = save_fig

        self.stats_processor = StatsProcessor(data_dir, n_classes)

    def plot_type_specific_weights(self, net_name, case):
        """
        Plots mean absolute weights for each cell type across layers.

        One coloured bar per activation function (second index level of the
        weight frame) per layer, with SEM error bars.

        Args:
            net_name: key into the module-level ``nets`` table, which supplies
                the layer state keys and their display names.
            case: experimental case forwarded to ``load_weight_df``.
        """

        # pull data
        df = self.stats_processor.load_weight_df(net_name, case)

        # plot
        state_keys = list(nets[net_name]["state_keys"].keys())
        x = np.array([i * 1.25 for i in range(len(state_keys))])
        n_act_fns = len(df.index.levels[1])
        width = 1.0 / n_act_fns
        err_kw = dict(lw=1, capsize=3, capthick=1)

        fig, ax = plt.subplots(figsize=(14,8))
        clrs = sns.color_palette("hls", n_act_fns)

        for i in range(n_act_fns):

            act_fn = df.index.levels[1][i]

            yvals = df["avg_weight"][:, act_fn][state_keys]
            yerr = df["sem_weight"][:, act_fn][state_keys]

            ax.bar(x, yvals, width, yerr=yerr, label=act_fn, error_kw=err_kw,
                color=clrs[i])

            # update bar locations for next group
            x = [loc + width for loc in x]

        ax.set_title("Weight distribution across layers after training")
        ax.set_xlabel("Layer")
        ax.set_ylabel("Mean abs weight per layer")
        ax.legend()

        # centre each tick under its group of bars
        loc = (n_act_fns - 1) / (2. * n_act_fns)
        ax.set_xticks([loc + i * 1.25 for i in range(len(state_keys))])
        labels = list(nets[net_name]["state_keys"].values())
        ax.set_xticklabels(labels)

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/{net_name}/weight distr/")
        filename = f"{case} weight distr.svg"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)

    def plot_weight_changes(self, net_name, cases, train_schemes):
        """
        Plots average change in weights over training for the given
        experimental cases.

        One coloured bar per (train scheme, case) pair per layer, with SEM
        error bars.

        Args:
            net_name: network whose weight-change data is loaded.
            cases (list): Experimental cases to include in figure.
            train_schemes (list): Training schemes to include in figure.

        Returns:
            None.

        Raises:
            ValueError: if none of the requested (scheme, case) pairs is
                present in the loaded data.
        """
        # pull data
        df = self.stats_processor.load_weight_change_df(net_name, cases, train_schemes)

        # Bars come from the per-layer ".weight" columns, and the tick labels
        # are built from the same list so bars and labels always line up.
        # NOTE(review): assumes the i-th ".sem" column is the SEM of the i-th
        # ".weight" column — confirm the column order in the loaded frame.
        state_keys = [k for k in df.columns if k.endswith(".weight")]
        sem_cols = [k for k in df.columns if k.endswith(".sem")]
        df_groups = df.groupby(["train_scheme", "case"])

        # The frame is grouped by two keys, so groups must be addressed by
        # (scheme, case) tuples; keep only pairs that exist in the data.
        group_keys = []
        for scheme in train_schemes:
            for case in cases:
                if (scheme, case) in df_groups.groups:
                    group_keys.append((scheme, case))
                else:
                    print(f"No weight change data for {scheme} {case}; skipping.")
        if not group_keys:
            raise ValueError(
                f"No weight change data for schemes {train_schemes} "
                f"and cases {cases}")
        n_groups = len(group_keys)

        # plot
        x = np.array([i * 1.25 for i in range(len(state_keys))])
        width = 1.0 / n_groups
        err_kw = dict(lw=1, capsize=3, capthick=1)

        fig, ax = plt.subplots(figsize=(14,8))
        clrs = sns.color_palette("hls", n_groups)

        for (scheme, case), clr in zip(group_keys, clrs):

            group = df_groups.get_group((scheme, case))
            yvals = group[state_keys].values[0]
            yerr = group[sem_cols].values[0] if sem_cols else None

            ax.bar(x, yvals, width, yerr=yerr, label=f"{scheme} {case}",
                error_kw=err_kw, color=clr)

            # update bar locations for next group
            x = [loc + width for loc in x]

        ax.set_title("Weight changes by layer during training")
        ax.set_xlabel("Layer")
        ax.set_ylabel("Mean abs weight change per layer")
        ax.legend()

        # centre each tick under its group of bars
        loc = (n_groups - 1) / (2. * n_groups)
        ax.set_xticks([loc + i * 1.25 for i in range(len(state_keys))])
        labels = [k[:-len(".weight")] for k in state_keys]
        ax.set_xticklabels(labels)

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/{net_name}/weight change/")
        case_names = " & ".join(cases)
        filename = f"{case_names} weight.svg"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)

    def save_df(self, name, df):
        """
        Write ``df`` as CSV (with header) to ``<data_dir>/dataframes/<name>``.
        """
        sub_dir = ensure_sub_dir(self.data_dir, f"dataframes/")
        filename = os.path.join(sub_dir, name)
        df.to_csv(filename, header=True, columns=df.columns)
class Visualizer():
    """
    Accuracy and activation-function figures, backed by a ``StatsProcessor``.

    Figures are shown interactively unless ``save_fig`` is set, in which case
    they are written under ``<data_dir>/figures/``.
    """

    def __init__(self, data_dir, n_classes=10, save_fig=False, refresh=False):
        """
        Args:
            data_dir: root directory for loaded data and saved figures.
            n_classes: forwarded to ``StatsProcessor``.
            save_fig: write figures to disk instead of showing them.
            refresh: forwarded to ``StatsProcessor.load_final_acc_df``.
        """
        self.data_dir = data_dir
        self.save_fig = save_fig
        self.refresh = refresh

        self.stats_processor = StatsProcessor(data_dir, n_classes)

    def plot_activation_fns(self, act_fns):
        """
        Plots the given activation functions on the same figure.

        Each callable receives a 1-D torch tensor of 50 points in [-5, 5] and
        is labelled by ``str(fn)``; when saving, a 300-dpi PNG is written
        under ``figures/act_fns/``.
        """

        x = np.linspace(-5, 5, 50)
        x = torch.tensor(x)
        fig, ax = plt.subplots(figsize=(7, 5))

        for fn in act_fns:
            y = fn(x)
            ax.plot(x, y, label=str(fn))

        ax.legend()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/act_fns/")
        fn_names = " & ".join([str(fn) for fn in act_fns])
        filename = f"{fn_names}.png"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)

    def plot_prediction(self, pred_type="linear"):
        """
        Unfinished: loads and groups the final-accuracy data and opens an
        empty figure, but draws nothing yet. ``pred_type`` is not used.
        """

        # pull data
        df, case_dict, index_cols = self.stats_processor.load_final_acc_df(
            self.refresh)
        df_groups = df.groupby(index_cols)

        # plot
        fig, ax = plt.subplots(figsize=(9, 12))

    def plot_final_accuracy(self, net_name, control_cases, mixed_cases):
        """
        Plot accuracy at the end of training for given control cases
        and mixed case, including predicted mixed case accuracy based
        on linear combination of control cases.

        Left axis: each control case at its activation-function parameter
        value. Right axis: each mixed case's actual accuracy (dot with a
        2-std error bar) and its prediction (x), the mean accuracy of the
        single-parameter cases whose parameter is one of its components.
        """

        # pull data
        # NOTE(review): this call passes (net_name, cases) and unpacks two
        # values, whereas scatter_final_acc calls load_final_acc_df(refresh)
        # and unpacks three — one of them does not match the live
        # StatsProcessor; confirm which signature is current.
        acc_df, case_dict = self.stats_processor.load_final_acc_df(
            net_name, control_cases + mixed_cases)
        acc_df_groups = acc_df.groupby("case")

        # plot...
        handles = []
        labels = []

        fig, axes = plt.subplots(nrows=1,
                                 ncols=2,
                                 figsize=(14, 8),
                                 sharey=True)
        fig.subplots_adjust(wspace=0)
        # one colour per control case, plus two (actual, prediction) per
        # mixed case
        n_ctrl = len(control_cases)
        clrs = sns.color_palette("hls", n_ctrl + 2 * len(mixed_cases))

        for i, case in enumerate(control_cases):

            group = acc_df_groups.get_group(case)
            p = float(case_dict[case][0])

            # error bars = 2 standard devs
            yvals = group["final_val_acc"]["mean"].values
            yerr = group["final_val_acc"]["std"].values * 2
            h = axes[0].errorbar(p,
                                 yvals[0],
                                 yerr=yerr,
                                 label=case,
                                 capsize=3,
                                 elinewidth=1,
                                 c=clrs[i],
                                 fmt=".")

            handles.append(h)
            labels.append(case)

        # plot mixed cases
        for i, mixed_case in enumerate(mixed_cases):

            # actual
            group = acc_df_groups.get_group(mixed_case)
            y_act = group["final_val_acc"]["mean"].values[0]
            y_err = group["final_val_acc"]["std"].values * 2
            l = f"{mixed_case} actual"
            h = axes[1].errorbar(i,
                                 y_act,
                                 yerr=y_err,
                                 label=l,
                                 capsize=3,
                                 elinewidth=1,
                                 c=clrs[n_ctrl + 2 * i],
                                 fmt=".")

            labels.append(l)
            handles.append(h)

            # predicted: mean accuracy of the single-parameter cases whose
            # parameter is one of this mixed case's components
            ps = list(case_dict[mixed_case])
            component_cases = [
                k for k, v in case_dict.items() if len(v) == 1 and v[0] in ps
            ]
            y_pred = acc_df["final_val_acc"]["mean"][component_cases].mean()
            l = f"{mixed_case} prediction"
            # Axes.plot returns a list of Line2D; the legend needs the artist
            h = axes[1].plot(i,
                             y_pred,
                             "x",
                             label=l,
                             c=clrs[n_ctrl + 2 * i + 1])[0]

            labels.append(l)
            handles.append(h)

        fig.suptitle("Final accuracy")
        axes[0].set_xlabel("Activation function parameter value")
        axes[1].set_xlabel("Mixed cases")
        axes[0].set_ylabel("Final validation accuracy")
        axes[1].xaxis.set_ticks([])

        # shrink second axis by 20% to make room for the legend
        box = axes[1].get_position()
        axes[1].set_position([box.x0, box.y0, box.width * 0.8, box.height])

        # append legend to second axis
        axes[1].legend(handles,
                       labels,
                       loc='center left',
                       bbox_to_anchor=(1, 0.5))

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir,
                                 f"figures/{net_name}/final accuracy/")
        cases = " & ".join(mixed_cases)
        filename = f"{cases} final acc.png"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)

    def plot_accuracy(self, dataset, net_name, schemes, cases):
        """
        Plots accuracy over training for different experimental cases.

        One line per (train scheme, case) group, showing mean validation
        accuracy per epoch with a shaded band of +/- 2 standard deviations.

        Args:
            dataset
            net_name
            schemes
            cases (list): Experimental cases to include in figure.

        Returns:
            None.

        """
        # pull data
        acc_df = self.stats_processor.load_accuracy_df(dataset, net_name,
                                                       cases, schemes)

        # group and compute stats
        acc_df.set_index(["train_scheme", "case", "epoch"], inplace=True)
        acc_df_groups = acc_df.groupby(["train_scheme", "case", "epoch"])
        acc_df_stats = acc_df_groups.agg({"acc": [np.mean, np.std]})
        acc_df_stats_groups = acc_df_stats.groupby(["train_scheme", "case"])

        # plot
        fig, ax = plt.subplots(figsize=(14, 8))
        clrs = sns.color_palette("hls", len(acc_df_stats_groups.groups))

        # longest curve, used to place the epoch ticks (0 if there is no data)
        n_epochs = 0
        for group, clr in zip(acc_df_stats_groups.groups, clrs):

            scheme, case = group
            group_data = acc_df_stats_groups.get_group((scheme, case))

            # error bars = 2 standard devs
            yvals = group_data["acc"]["mean"].values
            yerr = group_data["acc"]["std"].values * 2
            ax.plot(range(len(yvals)), yvals, label=f"{scheme} {case}", c=clr)
            ax.fill_between(range(len(yvals)),
                            yvals - yerr,
                            yvals + yerr,
                            alpha=0.1,
                            facecolor=clr)
            n_epochs = max(n_epochs, len(yvals))

        ax.set_title("Classification accuracy during training")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Validation accuracy")
        ax.legend()
        # a tick every `step` epochs, covering every plotted epoch index
        step = 5
        ax.set_xticks(list(range(0, n_epochs, step)))

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir,
                                 f"figures/{dataset}/{net_name}/accuracy/")
        case_names = " & ".join(cases)
        filename = f"{case_names} accuracy.png"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)

    def scatter_final_acc(self, dataset, net_names, schemes, act_fns):
        """
        Plot a scatter plot of predicted vs actual final accuracy for the
        given mixed cases.

        Each point is one (net, scheme, mixed case) of ``dataset``. The
        actual value is the case's mean final validation accuracy. The
        predicted value is the mean over the single-parameter cases whose
        parameter is one of the mixed case's components, the same linear
        prediction used by ``plot_final_accuracy``. Error bars are 2 standard
        deviations.

        Args:
            dataset: dataset to plot.
            net_names: networks to include; marker shape follows this order.
            schemes: training schemes to include; marker fill follows this
                order.
            act_fns: currently only used to name the saved file.
        """

        # pull data
        df, case_dict, index_cols = self.stats_processor.load_final_acc_df(
            self.refresh)
        df_groups = df.groupby(index_cols)

        # NOTE(review): assumes index_cols is (dataset, net_name,
        # train_scheme, case) and that case_dict maps each case to its
        # component parameters (one entry => single-parameter control case),
        # as plot_final_accuracy does — confirm against StatsProcessor.
        selected = [
            (net, scheme, case)
            for gset, net, scheme, case in df_groups.groups
            if gset == dataset and net in net_names and scheme in schemes
            and len(case_dict.get(case, ())) > 1
        ]
        mixed_cases = sorted({case for _, _, case in selected})

        # plot
        fig, ax = plt.subplots(figsize=(14, 14))
        fmts = [".", "^"]
        mfcs = ["None", None]
        clrs = sns.color_palette("husl", len(mixed_cases))

        # plot mixed cases
        for net, scheme, case in selected:

            g_data = df_groups.get_group((dataset, net, scheme, case))
            fmt = fmts[net_names.index(net) % len(fmts)]
            mfc = mfcs[schemes.index(scheme) % len(mfcs)]
            clr = clrs[mixed_cases.index(case)]

            # actual
            y_act = g_data["final_val_acc"]["mean"].values[0]
            y_err = g_data["final_val_acc"]["std"].values[0] * 2

            # prediction - component (single-parameter) cases of this mix
            ps = list(case_dict[case])
            component_keys = [
                (dataset, net, scheme, k)
                for k, v in case_dict.items()
                if len(v) == 1 and v[0] in ps
                and (dataset, net, scheme, k) in df_groups.groups
            ]
            if not component_keys:
                print(f"No component cases for {net} {scheme} {case}; skipping.")
                continue
            component_data = [
                df_groups.get_group(k)["final_val_acc"] for k in component_keys
            ]
            component_accs = np.array(
                [d["mean"].values[0] for d in component_data])
            component_stds = np.array(
                [d["std"].values[0] * 2 for d in component_data])
            x_pred = component_accs.mean()
            x_err = component_stds.mean()

            # plot
            ax.errorbar(x_pred,
                        y_act,
                        xerr=x_err,
                        yerr=y_err,
                        label=f"{net} {scheme} {case}",
                        elinewidth=1,
                        c=clr,
                        fmt=fmt,
                        markersize=10,
                        markerfacecolor=mfc)

        # plot reference line (prediction == actual)
        x = np.linspace(0, 1, 50)
        ax.plot(x, x, c=(0.5, 0.5, 0.5, 0.25), dashes=[6, 2])

        # set figure text
        ax.set_title(
            f"Linear predicted vs actual mixed case final accuracy - {dataset}",
            fontsize=18)
        ax.set_xlabel("Predicted", fontsize=16)
        ax.set_ylabel("Actual", fontsize=16)
        ax.set_xlim([0.1, 1])
        ax.set_ylim([0.1, 1])
        ax.set_aspect("equal", "box")
        ax.legend()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/scatter/{dataset}")
        net_names = ", ".join(net_names)
        schemes = ", ".join(schemes)
        act_fns = ", ".join(act_fns)
        filename = f"{net_names}_{schemes}_{act_fns}_scatter.png"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(filename, dpi=300)


if __name__ == "__main__":

    # Ad-hoc entry point: build a Visualizer over the shared data mount with
    # figure saving and cache refresh both disabled.
    data_dir = "/home/briardoty/Source/allen-inst-cell-types/data_mountpoint"
    visualizer = Visualizer(data_dir, 10, save_fig=False, refresh=False)
Ejemplo n.º 13
0
    def plot_predictions(self,
                         dataset,
                         net_names,
                         schemes,
                         excl_arr,
                         pred_type="min",
                         cross_family=None,
                         pred_std=False):
        """
        Plot a single axis figure of offset from predicted number of epochs
        it takes a net to get to ``self.stats_processor.pct``% of its peak
        accuracy.

        There is one horizontal row per mixed configuration, sorted by net
        name and then by offset, both descending. Each row shows the mean
        offset (dot) and a +/- 1.98 * std band.
        - Rows whose band lies entirely below zero are drawn opaque; the
          rest are faded.
        - A solid band marks cross-family rows, or any row when
          ``cross_family`` is given; a dotted band marks within-family rows.
        - A star marks rows whose ``{pred_type}_pred_epochs_past_rej_h0``
          flag is set (BH-corrected significance, per the column name).

        Args:
            dataset: dataset name to keep.
            net_names: list of network names to keep; also fixes colour order.
            schemes: list of training schemes to keep.
            excl_arr: substrings; cases containing any of them are dropped.
            pred_type: "min" or "linear" — which prediction to compare to.
            cross_family: if not None, keep only rows whose ``cross_fam``
                equals this value.
            pred_std: if True, also draw the prediction's +/- 1.98 * std band
                around zero.
        """

        # pull data
        df, case_dict, idx_cols = self.stats_processor.load_max_acc_df()

        # performance relative to predictions (actual minus predicted epochs;
        # negative means the target accuracy was reached sooner than predicted)
        vs = "epochs_past_vs"
        df[f"{vs}_linear"] = df["epochs_past"]["mean"] - df[
            "linear_pred_epochs_past"]["mean"]
        df[f"{vs}_min"] = df["epochs_past"]["mean"] - df[
            "min_pred_epochs_past"]["mean"]

        # filter dataframe
        df = df.query(f"is_mixed") \
            .query(f"dataset == '{dataset}'") \
            .query(f"net_name in {net_names}") \
            .query(f"train_scheme in {schemes}")
        if cross_family is not None:
            df = df.query(f"cross_fam == {cross_family}")
        for excl in excl_arr:
            df = df.query(f"not case.str.contains('{excl}')", engine="python")
        sort_df = df.sort_values(["net_name", f"{vs}_{pred_type}"],
                                 ascending=False)

        # nothing to plot; bail out before np.max([]) below raises
        if sort_df.empty:
            print("No configurations match the given filters; nothing to plot.")
            return

        # determine each label length for alignment (only index level 3,
        # the case name, is shown)
        lengths = {}
        label_idxs = [3]
        for lvl in label_idxs:
            lengths[lvl] = np.max(
                [len(x) for x in sort_df.index.unique(level=lvl)]) + 2

        # plot
        plt.figure(figsize=(16, 16))
        plt.gca().axvline(0, color='k', linestyle='--')
        clrs = sns.color_palette("Set2", len(net_names))

        ylabels = dict()
        handles = dict()
        sig_arr = list()
        i = 0
        xmax = 0
        xmin = 0
        for midx in sort_df.index.values:

            # dataset, net, scheme, case, mixed, cross-family
            d, n, s, c, m, cf = midx
            clr = clrs[net_names.index(n)]

            # shade every other row so rows are easier to follow
            if np.mod(i, 2) == 0:
                plt.gca().axhspan(i - .5, i + .5, alpha=0.1, color="k")

            # stats
            perf = sort_df.loc[midx][f"{vs}_{pred_type}"].values[0]
            err = sort_df.loc[midx]["epochs_past"]["std"] * 1.98

            xmin = min(xmin, perf - err)
            xmax = max(xmax, perf + err)

            # solid for cross-family (or when filtering by family), dotted
            # for within-family
            linestyle = "-" if (cf or cross_family is not None) else ":"
            band = [perf - err, perf + err]

            # plot "good" (whole band below zero) and "bad"
            if perf + err < 0:
                plt.plot(band, [i, i],
                         linestyle=linestyle,
                         c=clr,
                         linewidth=6,
                         alpha=.8)
                h = plt.plot(perf, i, c=clr, marker="o")
                handles[n] = h[0]
            else:
                plt.plot(band, [i, i],
                         linestyle=linestyle,
                         c=clr,
                         linewidth=6,
                         alpha=.2)
                plt.plot(band, [i, i],
                         linestyle=linestyle,
                         linewidth=6,
                         c="k",
                         alpha=.1)
                h = plt.plot(perf, i, c=clr, marker="o", alpha=0.5)
                if handles.get(n) is None:
                    handles[n] = h[0]

            # optionally, plot the 95% ci for the prediction
            if pred_std:
                pred_err = sort_df.loc[midx][f"{pred_type}_pred_epochs_past"][
                    "std"] * 1.98
                plt.plot([-pred_err, pred_err], [i, i],
                         linestyle="-",
                         c="k",
                         linewidth=6,
                         alpha=.2)

            # BH corrected significance
            sig_arr.append(
                sort_df.loc[midx,
                            f"{pred_type}_pred_epochs_past_rej_h0"].values[0])

            # make an aligned label
            label_arr = [d, n, s, c]
            ylabels[i] = "".join(
                [label_arr[lvl].ljust(lengths[lvl]) for lvl in label_idxs])

            # track vars
            i += 1

        n_rows = i

        # legend entry for the significance marker (drawn off-axis)
        h = plt.plot(n_rows + 100, 0, "k*", alpha=0.5)
        handles["p < 0.05"] = h[0]

        # horizontal scale for the star offset and axis padding; fall back to
        # the data span when nothing extends right of zero
        x_unit = xmax if xmax > 0 else xmax - xmin
        if x_unit <= 0:
            x_unit = 1.0

        # indicate BH corrected significance
        for row, is_sig in enumerate(sig_arr):
            if is_sig:
                plt.plot(xmax + x_unit / 12., row, "k*", alpha=0.5)

        # determine padding for labels
        max_length = np.max([len(l) for l in ylabels.values()])

        # add legend-only handles (drawn off-axis)
        if cross_family is None:
            h1 = plt.gca().axhline(n_rows + 100,
                                   color="k",
                                   linestyle="-",
                                   alpha=0.5)
            h2 = plt.gca().axhline(n_rows + 100,
                                   color="k",
                                   linestyle=":",
                                   alpha=0.5)
            handles["cross-family"] = h1
            handles["within-family"] = h2

        # set figure text
        plt.xlabel(
            f"N epochs to reach {self.stats_processor.pct}% peak accuracy relative to {pred_type} prediction",
            fontsize=16,
            labelpad=10)
        plt.ylabel("Network configuration", fontsize=16, labelpad=10)
        plt.yticks(list(ylabels.keys()), ylabels.values(), ha="left")
        plt.ylim(-0.5, n_rows - 0.5)
        plt.legend(handles.values(),
                   handles.keys(),
                   fontsize=14,
                   loc="lower left")
        plt.xlim([xmin - x_unit / 10., xmax + x_unit / 10.])
        yax = plt.gca().get_yaxis()
        yax.set_tick_params(pad=max_length * 7)
        plt.tight_layout()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/learning prediction")
        net_names = ", ".join(net_names)
        schemes = ", ".join(schemes)
        filename = f"{dataset}_{net_names}_{schemes}_{pred_type}-learning-prediction"
        if cross_family == True:
            filename += "_xfam"
        elif cross_family == False:
            filename += "_infam"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(f"{filename}.svg")
        plt.savefig(f"{filename}.png", dpi=300)
    def plot_accuracy(self,
                      dataset,
                      net_name,
                      schemes,
                      groups,
                      cases,
                      inset=True):
        """
        Plots validation accuracy over training for different experimental cases.

        Draws one line per (dataset, net_name, train_scheme, case) combination.
        Each line is the per-epoch mean of ``val_acc`` (in percent) with a
        shaded band of +/- 1.98 standard deviations. Optionally adds a zoomed
        inset over the final epochs.

        Args:
            dataset: Dataset identifier; forwarded to the stats processor and
                used in the title and output filename.
            net_name: Network name; forwarded to the stats processor and used
                in the title and output filename.
            schemes (list): Training schemes to include; joined into the
                output filename.
            groups (list): Values of the ``group`` column to keep.
            cases (list): Experimental cases to include in figure; joined
                into the output filename.
            inset (bool): If True, add a zoomed inset showing the last
                (up to) 10 epochs of every trajectory.

        Returns:
            None. Shows the figure when ``self.save_fig`` is falsy, otherwise
            saves it as .svg and .png under ``figures/accuracy/``.
        """
        # pull data
        acc_df = self.stats_processor.load_accuracy_df(dataset, net_name,
                                                       schemes, cases)

        # keep only the requested groups, then aggregate val_acc per epoch
        acc_df = acc_df.query(f"group in {groups}")
        index_cols = ["dataset", "net_name", "train_scheme", "case", "epoch"]
        acc_df = acc_df.set_index(index_cols)
        df_stats = acc_df.groupby(index_cols).agg(
            {"val_acc": [np.mean, np.std]})

        # one trajectory per (dataset, net_name, train_scheme, case)
        df_groups = df_stats.groupby(index_cols[:-1])

        # plot
        fig, ax = plt.subplots(figsize=(14, 8))
        clrs = sns.color_palette("hls", len(df_groups.groups))

        y_arr = []
        yerr_arr = []
        for group, clr in zip(df_groups.groups, clrs):

            _, _, s, c = group
            group_data = df_groups.get_group(group)

            # shaded band = +/- 1.98 standard deviations (~2 SD), in percent
            yvals = group_data["val_acc"]["mean"].values * 100
            yerr = group_data["val_acc"]["std"].values * 1.98 * 100
            epochs = range(len(yvals))
            ax.plot(epochs, yvals, label=f"{s} {c}", c=clr)
            ax.fill_between(epochs,
                            yvals - yerr,
                            yvals + yerr,
                            alpha=0.1,
                            facecolor=clr)

            # keep for the zoomed inset
            y_arr.append(yvals)
            yerr_arr.append(yerr)

        # zoomed inset over the final epochs
        if inset:
            axins = zoomed_inset_axes(ax, zoom=10, loc=8)
            for yvals, yerr, clr in zip(y_arr, yerr_arr, clrs):
                # clamp so runs shorter than 10 epochs don't produce an
                # x/y length mismatch
                nlast = min(10, len(yvals))
                x = list(range(len(yvals) - nlast, len(yvals)))
                y_end = yvals[len(yvals) - nlast:]
                yerr_end = yerr[len(yerr) - nlast:]
                axins.plot(x, y_end, c=clr)
                axins.fill_between(x,
                                   y_end - yerr_end,
                                   y_end + yerr_end,
                                   alpha=0.1,
                                   facecolor=clr)

            mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
            axins.xaxis.set_ticks([])

        ax.set_title(
            f"Classification accuracy during training: {net_name} on {dataset}",
            fontsize=20)
        ax.set_xlabel("Epoch", fontsize=16)
        ax.set_ylabel("Validation accuracy (%)", fontsize=16)
        ax.set_ylim([0, 100])
        ax.legend(fontsize=14)

        plt.tight_layout()

        # optional saving
        if not self.save_fig:
            print("Not saving.")
            plt.show()
            return

        sub_dir = ensure_sub_dir(self.data_dir, f"figures/accuracy/")
        case_names = ", ".join(cases)
        schemes = ", ".join(schemes)
        filename = f"{dataset}_{net_name}_{schemes}_{case_names} accuracy"
        filename = os.path.join(sub_dir, filename)
        print(f"Saving... {filename}")
        plt.savefig(f"{filename}.svg")
        plt.savefig(f"{filename}.png", dpi=300)

    def plot_single_accuracy(self,
                             dataset,
                             net_name,
                             scheme,
                             case,
                             metric="z",
                             sample=0):
        """
        Plot single net accuracy trajectory with windowed z score