Example #1
0
def plot_per_day(data: dict[str, dict[datetime, int]],
                 sort: bool = True,
                 cut: float | None = None):
    """Plot one line per tag from a mapping of tagged time series.

    For every ``(tag, {time: value})`` entry the samples are converted with the
    module helpers ``_get_middle_times`` and ``_get_avg_per_day`` and drawn
    with ``plt.plot``; the lines are then labelled in place with
    ``matplotx.line_labels``.

    Args:
        data: Mapping ``tag -> {time: value}``. The last two entries (in dict
            order) of each inner mapping are treated as the most recent
            samples when ranking/filtering.
        sort: If True, reorder the series by their latest increase using
            ``_argsort`` reversed, so the largest comes first.
        cut: If given, keep only series whose latest increase is strictly
            greater than ``cut`` times the largest latest increase.

    Returns:
        The ``matplotlib.pyplot`` module, with the lines drawn on the current
        axes.
    """

    def _latest_increase(vals: dict) -> int:
        # Difference between the two most recent samples. A series with fewer
        # than two samples has no measurable increase and ranks as 0 instead
        # of raising IndexError.
        samples = list(vals.values())
        return samples[-1] - samples[-2] if len(samples) >= 2 else 0

    # convert to a list of (tag, vals) tuples to be able to sort and filter
    items = list(data.items())

    if sort:
        # sort them such that the largest at the last time step gets plotted
        # first and the colors are in a nice order
        increases = [_latest_increase(vals) for _, vals in items]
        items = [items[i] for i in _argsort(increases)[::-1]]

    if cut is not None and items:
        # drop series whose latest increase is <= cut * (largest increase);
        # the `items` guard avoids calling max() on an empty sequence
        increases = [_latest_increase(vals) for _, vals in items]
        threshold = cut * max(increases)
        items = [item for item, inc in zip(items, increases) if inc > threshold]

    times = []
    values = []
    labels = []
    for tag, vals in items:
        t = list(vals.keys())
        v = list(vals.values())
        times.append(_get_middle_times(t))
        values.append(_get_avg_per_day(t, v))
        labels.append(tag)

    # start plotting from the 0 before the first positive value
    for j, (tm, val) in enumerate(zip(times, values)):
        # Reset for every series: previously the start index leaked from the
        # prior series (or was undefined -> NameError) when a series never
        # became positive. Such a series is now kept in full.
        start = 0
        for i, x in enumerate(val):
            if x > 0:
                start = max(i - 1, 0)
                break
        times[j] = tm[start:]
        values[j] = val[start:]

    n = len(times)
    for rank, (time, vals, label) in enumerate(zip(times, values, labels)):
        # earlier series get a higher zorder, i.e. they are drawn on top
        plt.plot(time, vals, label=label, zorder=n - rank)

    matplotx.line_labels()

    return plt
Example #2
0
def plot_scalar(
        y: np.ndarray,
        x: np.ndarray = None,
        label: str = None,
        xlabel: str = None,
        ylabel: str = None,
        fig_axes: FigAxes = None,
        outfile: os.PathLike = None,
        **kwargs,
) -> FigAxes:
    """Plot a one-dimensional series ``y`` against ``x``.

    Args:
        y: One-dimensional values (anything ``np.asarray`` accepts).
        x: x coordinates; defaults to ``0 .. len(y) - 1``.
        label: Line label. When given it is drawn next to the line with
            ``matplotx.line_labels`` on the axes that was plotted into.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        fig_axes: Existing ``(fig, ax)`` pair to draw into; a new pair is
            created with ``plt.subplots()`` when omitted.
        outfile: If given, the figure is written with ``savefig``.
        **kwargs: Forwarded to ``ax.plot``.

    Returns:
        The ``(fig, ax)`` pair that was drawn into.

    Raises:
        ValueError: If ``y`` is not one-dimensional.
    """
    y = np.asarray(y)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if y.ndim != 1:
        raise ValueError(f'Expected a 1-D array for `y`, got shape {y.shape}')
    if x is None:
        x = np.arange(len(y))

    if fig_axes is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = fig_axes

    ax.plot(x, y, label=label, **kwargs)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if label is not None:
        # label the axes we actually drew on; `plt.gca()` may be a different
        # axes when the caller supplies `fig_axes`
        matplotx.line_labels(ax=ax)

    if outfile is not None:
        savefig(fig, outfile)

    return fig, ax
Example #3
0
def plot_leapfrogs(
        y: np.ndarray,
        x: np.ndarray = None,
        fig_axes: FigAxes = None,
        xlabel: str = None,
        ylabel: str = None,
        outfile: os.PathLike = None,
) -> FigAxes:
    """Plot one chain-averaged line per leapfrog step.

    Args:
        y: Three-dimensional values laid out as
            ``[ndraws, nleapfrog, nchains]``; the last axis is averaged.
        x: x coordinates; defaults to ``0 .. ndraws - 1``.
        fig_axes: Existing ``(fig, ax)`` pair to draw into; a new pair is
            created with ``plt.subplots()`` when omitted.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        outfile: If given, the figure is written with ``savefig``.

    Returns:
        The ``(fig, ax)`` pair that was drawn into.

    Raises:
        ValueError: If ``y`` is not three-dimensional.
    """
    y = np.asarray(y)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if y.ndim != 3:
        raise ValueError(f'Expected a 3-D array for `y`, got shape {y.shape}')

    if fig_axes is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = fig_axes

    if x is None:
        x = np.arange(y.shape[0])

    # y.shape = [ndraws, nleapfrog, nchains]
    nlf = y.shape[1]
    yavg = y.mean(-1)  # average over chains -> [ndraws, nleapfrog]
    cmap = plt.get_cmap('viridis')
    for lf in range(nlf):
        ax.plot(x, yavg[:, lf], color=cmap(lf / nlf), label=f'{lf}')

    # label the axes we drew on, not whatever `plt.gca()` currently is
    matplotx.line_labels(ax=ax)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    if outfile is not None:
        savefig(fig, outfile)

    return fig, ax
Example #4
0
    def plot_dataArray(
        self,
        val: xr.DataArray,
        key: str = None,
        therm_frac: float = 0.,
        num_chains: int = 0,
        title: str = None,
        outdir: str = None,
        subplots_kwargs: dict[str, Any] = None,
        plot_kwargs: dict[str, Any] = None,
    ) -> tuple:
        """Plot the trace of ``val`` and optionally save it as ``{key}.svg``.

        The layout depends on ``val.values.ndim``:

        * 2-D: a figure with two subfigures. The left one holds ``val.plot``
          (colorbar labelled ``key``) plus averaged / individual traces; the
          right one holds the mean over the last axis next to a KDE of all
          values.
        * 1-D: a single line.
        * 3-D: one viridis-coloured line per index of axis 1, averaged over
          the last axis.

        Args:
            val: Data to plot.
            key: Name used for labels and the output filename.
            therm_frac: Fraction of leading entries along axis 0 to drop.
            num_chains: Number of individual traces to overlay.
            title: Optional figure suptitle.
            outdir: If given, save the figure as ``{outdir}/{key}.svg``.
            subplots_kwargs: Extra kwargs for ``plt.subplots`` (1-D / 3-D);
                ``figsize`` defaults to ``set_size()``. Not mutated.
            plot_kwargs: Extra kwargs forwarded to the main ``ax.plot`` calls.

        Returns:
            ``(fig, subfigs, axes)``: ``subfigs`` is None unless 2-D, and
            ``axes`` is ``(ax, ax1)`` for 2-D or the single axes otherwise.

        Raises:
            ValueError: If ``val`` is not 1-, 2- or 3-dimensional.
        """
        # Copy so the caller's dicts are never modified ('figsize' is
        # inserted below).
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        subplots_kwargs = (
            {} if subplots_kwargs is None else dict(subplots_kwargs)
        )
        figsize = subplots_kwargs.get('figsize', set_size())
        subplots_kwargs.update({'figsize': figsize})
        subfigs = None

        # NOTE(review): originally annotated as [nchains, ndraws] -- confirm.
        arr = val.values
        steps = np.arange(arr.shape[0])

        if therm_frac > 0:
            drop = int(therm_frac * arr.shape[0])
            arr = arr[drop:]
            steps = steps[drop:]

        if arr.ndim == 2:
            figsize = (3 * figsize[0], 1.5 * figsize[1])

            fig = plt.figure(figsize=figsize, constrained_layout=True)
            subfigs = fig.subfigures(1, 2)

            gs_kw = {'width_ratios': [1.33, 0.33]}
            (ax, ax1) = subfigs[1].subplots(1,
                                            2,
                                            sharey=True,
                                            gridspec_kw=gs_kw)
            ax.grid(alpha=0.2)
            ax1.grid(False)
            color = plot_kwargs.get('color', None)
            label = r'$\langle$' + f' {key} ' + r'$\rangle$'
            ax.plot(steps,
                    arr.mean(-1),
                    lw=1.5 * LW,
                    label=label,
                    **plot_kwargs)
            sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
            ax1.set_xticks([])
            ax1.set_xticklabels([])
            sns.despine(ax=ax, top=True, right=True)
            sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
            ax1.set_xlabel('')
            axes = (ax, ax1)

            ax0 = subfigs[0].subplots(1, 1)
            # NOTE(review): `val` is not trimmed by `therm_frac`, unlike `arr`.
            im = val.plot(ax=ax0)
            im.colorbar.set_label(key)
            sns.despine(subfigs[0])
            # NOTE(review): `steps` has length arr.shape[0], but arr.mean(0)
            # and arr[idx, :] have length arr.shape[1]; matplotlib rejects
            # mismatched x/y lengths, so these calls only succeed for square
            # arrays. Confirm the intended axis order.
            ax0.plot(steps, arr.mean(0), lw=2., color=color)
            for idx in range(min(num_chains, arr.shape[0])):
                ax0.plot(steps, arr[idx, :], lw=1., alpha=0.7, color=color)

        else:
            if arr.ndim == 1:
                fig, ax = plt.subplots(**subplots_kwargs)
                ax.plot(steps, arr, **plot_kwargs)
            elif arr.ndim == 3:
                fig, ax = plt.subplots(**subplots_kwargs)
                cmap = plt.get_cmap('viridis')
                nlf = arr.shape[1]
                for idx in range(nlf):
                    # average over the last axis for each index of axis 1
                    y = arr[:, idx, :].mean(-1)
                    ax.plot(steps, y, color=cmap(idx / nlf), label=f'{idx}')
            else:
                raise ValueError('Unexpected shape encountered')
            axes = ax

            ax.set_ylabel(key)

            if num_chains > 0 and arr.ndim > 1:
                lw = LW / 2.
                for idx in range(min(num_chains, arr.shape[1])):
                    # plot values of individual chains, arr[:, idx]
                    ax.plot(steps,
                            arr[:, idx],
                            alpha=0.5,
                            lw=lw / 2.,
                            **plot_kwargs)

        # Label lines on the main axes. `plt.gca()` is the last-created axes
        # (ax0 in the 2-D layout), which holds no labelled lines.
        matplotx.line_labels(ax=ax)
        ax.set_xlabel('draw')
        if title is not None:
            fig.suptitle(title)

        if outdir is not None:
            fig.savefig(Path(outdir).joinpath(f'{key}.svg'),
                        dpi=400,
                        bbox_inches='tight')

        return (fig, subfigs, axes)
Example #5
0
    def plot(
        self,
        val: torch.Tensor,
        key: str = None,
        therm_frac: float = 0.,
        num_chains: int = 0,
        title: str = None,
        outdir: str = None,
        subplots_kwargs: dict[str, Any] = None,
        plot_kwargs: dict[str, Any] = None,
    ):
        """Plot the history of a metric and optionally save it as SVG.

        ``val`` is converted to a NumPy array ``arr`` and drawn according to
        its dimensionality:

        * 2-D: two subfigures; the right one shows the mean over the last
          axis next to a KDE of all values.
        * 1-D: a single line.
        * 3-D: one line per index of axis 1, averaged over the last axis.

        Args:
            val: A tensor, or a sequence of tensors / floats / array-likes.
            key: Name used for labels and the output filename.
            therm_frac: Fraction of leading entries along axis 0 to drop.
            num_chains: If > 0 and ``arr`` has more than one dimension, also
                draw up to this many ``arr[:, idx]`` slices.
            title: Optional figure suptitle.
            outdir: If given, save the figure as ``{outdir}/{key}.svg``.
            subplots_kwargs: Extra kwargs for ``plt.subplots`` (1-D / 3-D);
                ``figsize`` defaults to ``set_size()``. Not mutated.
            plot_kwargs: Extra kwargs forwarded to ``ax.plot``.

        Returns:
            ``(fig, subfigs, axes)``: ``subfigs`` is None unless 2-D, and
            ``axes`` is ``(ax, ax1)`` for 2-D or the single axes otherwise.

        Raises:
            ValueError: If ``arr`` is not 1-, 2- or 3-dimensional.
        """
        # Copy so the caller's dicts are never modified.
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        subplots_kwargs = (
            {} if subplots_kwargs is None else dict(subplots_kwargs)
        )
        figsize = subplots_kwargs.get('figsize', set_size())
        subplots_kwargs.update({'figsize': figsize})

        if isinstance(val, torch.Tensor):
            # detach from autograd and move to host memory before .numpy()
            arr = val.detach().cpu().numpy()
        elif isinstance(val[0], torch.Tensor):
            # a plain list of tensors has no `.detach()`; convert element-wise
            arr = np.array([t.detach().cpu().numpy() for t in val])
        elif isinstance(val[0], float):
            arr = np.array(val)
        else:
            arr = np.array([np.array(i) for i in val])

        subfigs = None
        steps = np.arange(arr.shape[0])
        if therm_frac > 0:
            drop = int(therm_frac * arr.shape[0])
            arr = arr[drop:]
            steps = steps[drop:]

        if arr.ndim == 2:
            figsize = (3 * figsize[0], 1.5 * figsize[1])

            fig = plt.figure(figsize=figsize, constrained_layout=True)
            subfigs = fig.subfigures(1, 2)

            gs_kw = {'width_ratios': [1.33, 0.33]}
            (ax, ax1) = subfigs[1].subplots(1,
                                            2,
                                            sharey=True,
                                            gridspec_kw=gs_kw)
            ax.grid(alpha=0.2)
            ax1.grid(False)
            color = plot_kwargs.get('color', None)
            label = r'$\langle$' + f' {key} ' + r'$\rangle$'
            ax.plot(steps,
                    arr.mean(-1),
                    lw=1.5 * LW,
                    label=label,
                    **plot_kwargs)
            sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
            ax1.set_xticks([])
            ax1.set_xticklabels([])
            sns.despine(ax=ax, top=True, right=True)
            sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
            ax1.set_xlabel('')
            axes = (ax, ax1)
        else:
            if arr.ndim == 1:
                fig, ax = plt.subplots(**subplots_kwargs)
                ax.plot(steps, arr, **plot_kwargs)
            elif arr.ndim == 3:
                fig, ax = plt.subplots(**subplots_kwargs)
                for idx in range(arr.shape[1]):
                    # label with the index value (previously the literal
                    # string 'idx', so every line got the same label)
                    ax.plot(steps,
                            arr[:, idx, :].mean(-1),
                            label=f'{idx}',
                            **plot_kwargs)
            else:
                raise ValueError('Unexpected shape encountered')
            axes = ax

            ax.set_ylabel(key)

        if num_chains > 0 and arr.ndim > 1:
            lw = LW / 2.
            for idx in range(min(num_chains, arr.shape[1])):
                # plot values of individual chains, arr[:, idx]
                ax.plot(steps,
                        arr[:, idx],
                        alpha=0.5,
                        lw=lw / 2.,
                        **plot_kwargs)

        # label the main axes; `plt.gca()` would be ax1 in the 2-D layout
        matplotx.line_labels(ax=ax)
        ax.set_xlabel('draw')
        if title is not None:
            fig.suptitle(title)

        if outdir is not None:
            fig.savefig(Path(outdir).joinpath(f'{key}.svg'),
                        dpi=400,
                        bbox_inches='tight')

        return fig, subfigs, axes
Example #6
0
    def plot(  # noqa: C901
        self,
        time_unit: str = "s",
        relative_to: int | None = None,
        logx: str | bool = "auto",
        logy: str | bool = "auto",
    ):
        """Draw one curve per entry of ``self.labels`` against ``self.n_range``.

        Without ``self.flop`` the curves are runtimes, scaled to ``time_unit``
        or divided by the timings of entry ``relative_to``. With ``self.flop``
        they are FLOPS (``flop / time``) or, when ``relative_to`` is given,
        the ratio ``time[relative_to] / time``.

        Args:
            time_unit: A key of ``si_time`` or ``"auto"`` (picked by
                ``_auto_time_unit`` from the smallest timing).
            relative_to: Index of the reference entry for relative plots.
            logx: ``"auto"`` picks a log x-axis when ``n_range`` is positive
                and its interior points lie within 10% of an evenly
                log-spaced ramp.
            logy: ``"auto"`` is linear for relative or FLOP plots and follows
                ``logx`` otherwise.
        """
        if logx == "auto":
            if np.any(self.n_range <= 0):
                logx = False
            else:
                log_n = np.log(self.n_range)
                ideal = np.linspace(log_n[0], log_n[-1], len(log_n))
                # the endpoints coincide by construction; compare the interior
                interior_dev = (log_n - ideal)[1:-1] / log_n[1:-1]
                logx = np.all(np.abs(interior_dev) <= 0.1)

        if logy == "auto":
            linear_only = relative_to is not None or self.flop is not None
            logy = False if linear_only else logx

        plot_fn = {
            (True, True): plt.loglog,
            (True, False): plt.semilogx,
            (False, True): plt.semilogy,
            (False, False): plt.plot,
        }[(bool(logx), bool(logy))]

        if self.flop is not None:
            if relative_to is not None:
                series = self.timings_s[relative_to] / self.timings_s
                plt.title(f"FLOPS relative to {self.labels[relative_to]}")
            else:
                series = self.flop / self.timings_s
                plt.title("FLOPS")

            for curve, name in zip(series, self.labels):
                plot_fn(self.n_range, curve, label=name)
        else:
            if relative_to is not None:
                series = self.timings_s / self.timings_s[relative_to]
                ylabel = f"Runtime\nrelative to {self.labels[relative_to]}"
            else:
                # Allowed values: ("s", "ms", "us", "ns", "auto")
                if time_unit == "auto":
                    time_unit = _auto_time_unit(np.min(self.timings_s))
                else:
                    assert time_unit in si_time, "Provided `time_unit` is not valid"
                series = self.timings_s / si_time[time_unit]
                ylabel = f"Runtime [{time_unit}]"

            for curve, name in zip(series, self.labels):
                plot_fn(self.n_range, curve, label=name)

            matplotx.ylabel_top(ylabel)

        if self.xlabel:
            plt.xlabel(self.xlabel)
        if self.title:
            plt.title(self.title)
        if relative_to is not None and not logy:
            plt.gca().set_ylim(bottom=0)

        matplotx.line_labels()
Example #7
0
def plot_metric(
        val: np.ndarray,
        key: str = None,
        therm_frac: float = 0.,
        num_chains: int = 0,
        title: str = None,
        outdir: os.PathLike = None,
        subplots_kwargs: dict[str, Any] = None,
        plot_kwargs: dict[str, Any] = None,
        ext: str = 'png',
) -> tuple:
    """Plot a metric's history and optionally save it as ``{key}.{ext}``.

    The layout depends on the dimensionality of ``np.array(val)``:

    * ``[draws, chains]``: two subfigures; the chain-averaged trace (plus up
      to ``num_chains`` individual chains) next to a KDE of all values.
    * ``[draws]``: a single line.
    * ``[draws, nleapfrog, chains]``: one viridis-coloured, chain-averaged
      line per leapfrog step, with up to ``num_chains`` thin per-chain traces
      in the same colour.

    Args:
        val: Metric values (anything ``np.array`` accepts).
        key: Name used for labels and the output filename.
        therm_frac: Fraction of leading draws to drop before plotting.
        num_chains: Maximum number of individual chains to overlay.
        title: Optional figure suptitle.
        outdir: If given, save to ``{outdir}/{key}.{ext}`` unless that file
            already exists.
        subplots_kwargs: Extra kwargs for ``plt.subplots`` (1-D / 3-D);
            ``figsize`` defaults to ``set_size()``. Not mutated.
        plot_kwargs: Extra kwargs forwarded to ``ax.plot``. Not mutated.
        ext: File extension of the saved figure.

    Returns:
        ``(fig, subfigs, axes)``: ``subfigs`` is None unless 2-D, and ``axes``
        is ``(ax, ax1)`` for 2-D or the single axes otherwise.

    Raises:
        ValueError: If ``val`` is not 1-, 2- or 3-dimensional.
    """
    # Work on copies: the previous version wrote 'label' into, and popped
    # 'color' from, the caller's plot_kwargs.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    subplots_kwargs = {} if subplots_kwargs is None else dict(subplots_kwargs)
    figsize = subplots_kwargs.get('figsize', set_size())
    subplots_kwargs.update({'figsize': figsize})

    arr = np.array(val)

    subfigs = None
    steps = np.arange(arr.shape[0])
    if therm_frac > 0:
        # drop the leading draws
        drop = int(therm_frac * arr.shape[0])
        arr = arr[drop:]
        steps = steps[drop:]

    if arr.ndim == 2:
        # arr.shape = [draws, chains]
        figsize = (3 * figsize[0], 1.5 * figsize[1])

        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)

        gs_kw = {'width_ratios': [1.33, 0.33]}
        (ax, ax1) = subfigs[1].subplots(1, 2, sharey=True,
                                        gridspec_kw=gs_kw)
        ax.grid(alpha=0.2)
        ax1.grid(False)
        color = plot_kwargs.get('color', 'C0')
        label = r'$\langle$' + f' {key} ' + r'$\rangle$'
        ax.plot(steps, arr.mean(-1), lw=1.5*LW, label=label, **plot_kwargs)
        # Individual chains, drawn without a label. (A trailing loop after
        # the if/else used to draw them a second time.)
        chain_kwargs = {**plot_kwargs, 'label': None}
        for chain in range(min(num_chains, arr.shape[1])):
            ax.plot(steps, arr[:, chain], lw=LW/2., **chain_kwargs)
        sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        ax1.set_xlabel('')
        axes = (ax, ax1)
    else:
        if arr.ndim == 1:
            # arr.shape = [draws]
            fig, ax = plt.subplots(**subplots_kwargs)
            ax.plot(steps, arr, **plot_kwargs)
        elif arr.ndim == 3:
            # arr.shape = [draws, nleapfrog, chains]
            fig, ax = plt.subplots(**subplots_kwargs)
            nlf = arr.shape[1]
            cmap = plt.get_cmap('viridis', lut=nlf)
            # colours come from the colormap, so ignore any 'color' kwarg
            line_kwargs = {k: v for k, v in plot_kwargs.items()
                           if k != 'color'}
            for lf in range(nlf):
                y = arr[:, lf]  # [draws, chains]
                # normalise by the number of leapfrog steps (was: chains)
                color = cmap(lf / nlf)
                # separate loop variable so `lf` still names the leapfrog
                # step in the label below (it used to be shadowed)
                for chain in range(min(num_chains, y.shape[1])):
                    ax.plot(steps, y[:, chain], color=color,
                            lw=LW/4., alpha=0.7, **line_kwargs)
                ax.plot(steps, y.mean(-1), color=color,
                        label=f'{lf}', **line_kwargs)
        else:
            raise ValueError('Unexpected shape encountered')
        axes = ax

        ax.set_ylabel(key)

    # label the main axes; `plt.gca()` would be ax1 in the 2-D layout
    matplotx.line_labels(ax=ax)
    ax.set_xlabel('draw')
    if title is not None:
        fig.suptitle(title)

    if outdir is not None:
        outfile = Path(outdir).joinpath(f'{key}.{ext}')
        # never overwrite an existing plot
        if not outfile.is_file():
            fig.savefig(outfile, dpi=400, bbox_inches='tight')

    return fig, subfigs, axes
Example #8
0
def plot_dataArray(
        val: xr.DataArray,
        key: str = None,
        therm_frac: float = 0.,
        num_chains: int = 0,
        title: str = None,
        outdir: str = None,
        subplots_kwargs: dict[str, Any] = None,
        plot_kwargs: dict[str, Any] = None,
) -> tuple:
    """Plot the trace of a DataArray with a ``draw`` coordinate.

    The layout depends on ``val.values.ndim``:

    * 2-D: two subfigures. The left one holds a ``draw`` x ``chain``
      pcolormesh. The right one holds the chain-mean trace, individual
      chains and a legend, next to a KDE of all values. The colormap is
      'RdBu_r' for data spanning 0, 'mako' for strictly positive data.
    * 1-D: a single line.
    * 3-D: one viridis-coloured, chain-averaged line per ``leapfrog`` index.

    Args:
        val: Data with a ``draw`` coordinate (and ``chain`` for 2-D / 3-D,
            ``leapfrog`` for 3-D).
        key: Name used for labels and the output filename. ``key == 'dt'``
            forces ``therm_frac = 0.2``.
        therm_frac: Fraction of leading draws to drop.
        num_chains: Number of individual chains to overlay (2-D).
        title: Optional figure suptitle.
        outdir: If given, save the figure as ``{outdir}/{key}.svg``.
        subplots_kwargs: Extra kwargs for ``plt.subplots`` (1-D / 3-D);
            ``figsize`` defaults to ``set_size()``. Not mutated.
        plot_kwargs: Extra kwargs forwarded to the main ``ax.plot`` calls.

    Returns:
        ``(fig, subfigs, axes)``: ``subfigs`` is None unless 2-D, and ``axes``
        is ``(ax, ax1)`` for 2-D or the single axes otherwise.

    Raises:
        ValueError: If ``val`` is not 1-, 2- or 3-dimensional.
    """
    # Copy so the caller's dicts are never modified.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    subplots_kwargs = {} if subplots_kwargs is None else dict(subplots_kwargs)
    figsize = subplots_kwargs.get('figsize', set_size())
    subplots_kwargs.update({'figsize': figsize})
    subfigs = None

    if key == 'dt':
        therm_frac = 0.2

    ndraws = len(val.coords['draw'])
    steps = np.arange(ndraws)
    if therm_frac > 0:
        # Drop leading *draws*. Previously the drop count came from
        # arr.shape[0] and sliced arr's first axis (chains per the original
        # comment) while `steps` was sliced along draws, so x/y lengths no
        # longer matched. Slicing `val` itself along the dimension carrying
        # the 'draw' coordinate keeps `val`, `arr` and `steps` aligned.
        drop = int(therm_frac * ndraws)
        draw_dim = val.coords['draw'].dims[0]
        val = val.isel({draw_dim: slice(drop, None)})
        steps = steps[drop:]

    arr = val.values  # shape: [nchains, ndraws]

    if arr.ndim == 2:
        figsize = (3 * figsize[0], 1.5 * figsize[1])

        fig = plt.figure(figsize=figsize, constrained_layout=True)
        subfigs = fig.subfigures(1, 2)

        gs_kw = {'width_ratios': [1.33, 0.33]}
        (ax, ax1) = subfigs[1].subplots(1, 2, sharey=True,
                                        gridspec_kw=gs_kw)
        ax.grid(alpha=0.2)
        ax1.grid(False)
        vmin = np.min(val)
        vmax = np.max(val)
        cmap = None
        if vmin < 0 < vmax:
            # data spans 0: diverging blue -> white (0) -> red colormap
            color = '#FF5252' if val.mean() > 0 else '#007DFF'
            cmap = 'RdBu_r'
        elif 0 < vmin < vmax:
            # strictly positive data: sequential 'mako' colormap
            cmap = 'mako'
            color = '#3FB5AD'
        else:
            color = plot_kwargs.get('color', f'C{np.random.randint(4)}')

        label = r'$\langle$' + f' {key} ' + r'$\rangle$'
        ax.plot(steps, val.mean('chain'),
                label=label, lw=1.5*LW, **plot_kwargs)
        sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
        ax1.set_xticks([])
        ax1.set_xticklabels([])
        sns.despine(ax=ax, top=True, right=True)
        sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
        ax1.set_xlabel('')
        axes = (ax, ax1)

        ax0 = subfigs[0].subplots(1, 1)
        val = val.dropna('chain')
        _ = xr.plot.pcolormesh(val, 'draw', 'chain',
                               ax=ax0, robust=True,
                               cmap=cmap, add_colorbar=True)

        if key is not None and 'eps' in key:
            _ = ax0.set_ylabel('leapfrog')

        sns.despine(subfigs[0])
        # `plt.autoscale(axis=ax0)` was a silent no-op (`axis` must be 'x',
        # 'y' or 'both'); autoscale ax0 itself as intended.
        ax0.autoscale(enable=True)

        ax.plot(steps, arr.mean(0), lw=2., color='k', label='avg')
        for idx in range(min(num_chains, arr.shape[0])):
            ax.plot(steps, arr[idx, :], lw=1., alpha=0.7, color=color)

        ax.legend(loc='best')

    else:
        if arr.ndim == 1:
            fig, ax = plt.subplots(**subplots_kwargs)
            ax.plot(steps, arr, **plot_kwargs)
        elif arr.ndim == 3:
            fig, ax = plt.subplots(**subplots_kwargs)
            cmap = plt.get_cmap('viridis')
            nlf = len(val.coords['leapfrog'])
            y = val.mean('chain')
            # NOTE(review): y[idx] indexes the first remaining dimension,
            # assumed to be 'leapfrog' -- confirm the dim order.
            for idx in range(nlf):
                ax.plot(steps, y[idx], color=cmap(idx / nlf),
                        label=f'{idx}')
        else:
            raise ValueError('Unexpected shape encountered')
        axes = ax

        ax.set_ylabel(key)

    if subfigs is None:
        # 1-D / 3-D: label lines in place (the 2-D layout uses a legend)
        matplotx.line_labels(ax=ax)
    ax.set_xlabel('draw')
    if title is not None:
        fig.suptitle(title)

    if outdir is not None:
        fig.savefig(Path(outdir).joinpath(f'{key}.svg'),
                    dpi=400, bbox_inches='tight')

    return (fig, subfigs, axes)
Example #9
0
    def plot(
        self,
        val: list | np.ndarray | tf.Tensor,
        key: str = None,
        therm_frac: float = 0.,
        num_chains: int = 0,
        title: str = None,
        outdir: str = None,
        subplots_kwargs: dict[str, Any] = None,
        plot_kwargs: dict[str, Any] = None,
        ext: str = 'svg',
    ):
        """Plot the history of a metric and optionally save it.

        ``val`` is converted to a NumPy array ``arr`` and drawn according to
        its dimensionality:

        * 2-D: two subfigures; the right one shows the mean over the last
          axis next to a KDE of all values.
        * 1-D: a single line.
        * 3-D: one line per index of axis 1, averaged over the last axis;
          only the first carries the ``<key>`` label.

        Args:
            val: A tensor, or a sequence of tensors / floats / array-likes.
            key: Name used for labels and the output filename ('' if None).
            therm_frac: Fraction of leading entries along axis 0 to drop.
            num_chains: If > 0 and ``arr`` has more than one dimension, also
                draw up to this many ``arr[:, idx]`` slices.
            title: Optional figure suptitle.
            outdir: If given, save the figure as ``{outdir}/{key}.{ext}``.
            subplots_kwargs: Extra kwargs for ``plt.subplots`` (1-D / 3-D);
                ``figsize`` defaults to ``set_size()``. Not mutated.
            plot_kwargs: Extra kwargs forwarded to ``ax.plot``.
            ext: File extension of the saved figure.

        Returns:
            ``(fig, subfigs, axes)``: ``subfigs`` is None unless 2-D, and
            ``axes`` is ``(ax, ax1)`` for 2-D or the single axes otherwise.

        Raises:
            ValueError: If ``arr`` is not 1-, 2- or 3-dimensional.
        """
        # Copy so the caller's dicts are never modified.
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        subplots_kwargs = (
            {} if subplots_kwargs is None else dict(subplots_kwargs)
        )
        figsize = subplots_kwargs.get('figsize', set_size())
        subplots_kwargs.update({'figsize': figsize})
        key = '' if key is None else key
        label = ' '.join([r'$\langle$', f'{key}', r'$\rangle$'])

        if isinstance(val, tf.Tensor):
            arr = val.numpy()
        elif isinstance(val[0], tf.Tensor):
            # a plain list of tensors has no `.numpy()`; convert element-wise
            arr = np.array([t.numpy() for t in val])
        elif isinstance(val[0], float):
            arr = np.array(val)
        else:
            arr = np.array([np.array(i) for i in val])

        subfigs = None
        steps = np.arange(arr.shape[0])
        if therm_frac > 0:
            drop = int(therm_frac * arr.shape[0])
            arr = arr[drop:]
            steps = steps[drop:]

        if arr.ndim == 2:
            figsize = (3 * figsize[0], 1.5 * figsize[1])

            fig = plt.figure(figsize=figsize, constrained_layout=True)
            subfigs = fig.subfigures(1, 2)

            gs_kw = {'width_ratios': [1.33, 0.33]}
            (ax, ax1) = subfigs[1].subplots(1,
                                            2,
                                            sharey=True,
                                            gridspec_kw=gs_kw)
            ax.grid(alpha=0.2)
            ax1.grid(False)
            color = plot_kwargs.get('color', None)
            ax.plot(steps,
                    arr.mean(-1),
                    lw=1.5 * LW,
                    label=label,
                    **plot_kwargs)
            sns.kdeplot(y=arr.flatten(), ax=ax1, color=color, shade=True)
            ax1.set_xticks([])
            ax1.set_xticklabels([])
            sns.despine(ax=ax, top=True, right=True)
            sns.despine(ax=ax1, top=True, right=True, left=True, bottom=True)
            ax1.set_xlabel('')
            ax1.set_ylabel('')
            axes = (ax, ax1)
            matplotx.line_labels(ax=ax)
        else:
            if arr.ndim == 1:
                fig, ax = plt.subplots(**subplots_kwargs)
                ax.plot(steps, arr, label=label, **plot_kwargs)
                axes = ax
                matplotx.line_labels(ax=ax)
            elif arr.ndim == 3:
                fig, ax = plt.subplots(**subplots_kwargs)
                for idx in range(arr.shape[1]):
                    if idx == 0:
                        ax.plot(steps,
                                arr[:, idx, :].mean(-1),
                                label=label,
                                **plot_kwargs)
                    else:
                        ax.plot(steps, arr[:, idx, :].mean(-1), **plot_kwargs)
                axes = ax
                matplotx.line_labels(ax=ax)
            else:
                raise ValueError('Unexpected shape encountered')

            ax.set_ylabel(key)

        if num_chains > 0 and arr.ndim > 1:
            for idx in range(min(num_chains, arr.shape[1])):
                # plot values of individual chains, arr[:, idx]
                ax.plot(steps,
                        arr[:, idx],
                        alpha=0.4,
                        lw=LW / 5.,
                        **plot_kwargs)

        ax.set_xlabel('draw')
        if title is not None:
            fig.suptitle(title)

        if outdir is not None:
            fig.savefig(Path(outdir).joinpath(f'{key}.{ext}'))

        return fig, subfigs, axes