Ejemplo n.º 1
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the log-scaled axes and fetch the spectral statistics.

        Reads `stats.umisf` and `stats.svals`. If either attribute is missing,
        the instance is flagged with `is_active = False` and
        `not_available_text` is called on the axes with an explanatory message.

        `key0`, `E`, `P` and `kwargs` are accepted but not used here.
        """
        _, axes = freshfig(fignum, (6, 3), loc='3333')
        axes.set_yscale('log')
        axes.set_xlabel('Sing. value index')

        self.ax = axes
        self.plot_u = plot_u
        self.init_incomplete = True

        try:
            # Assigned one at a time on purpose: if only `svals` is missing,
            # `msft` has still been set by the time the exception is raised.
            self.msft = stats.umisf
            self.sprd = stats.svals
        except AttributeError:
            self.is_active = False
            not_available_text(axes, "Spectral stats not being computed")
Ejemplo n.º 2
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Prepare the figure, axes and bin edges for a histogram of weights.

        If `stats` has no attribute `w`, the instance is flagged with
        `is_active = False` and nothing else (no figure) is created.

        `key0`, `plot_u`, `E`, `P` and `kwargs` are accepted but not used here.
        """
        if not hasattr(stats, 'w'):
            self.is_active = False
            return

        _, ax = freshfig(fignum, (7, 3),
                         loc='3323',
                         gridspec_kw={'bottom': .15})

        ax.set_xscale('log')
        ax.set_xlabel('Weight')  # was misspelled 'Weigth'
        ax.set_ylabel('Count')

        self.stats = stats
        self.ax = ax
        self.hist = []
        # 31 logarithmically spaced bin edges spanning [1e-10, 1].
        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
Ejemplo n.º 3
0
    def plot(self,
             statkey="rmse.a",
             axes=AXES_ROLES,
             get_style=default_styles,
             fignum=None,
             figsize=None,
             panels=None,
             title2=None,
             costfun=None,
             unique_labels=True):
        """Plot the averages of `statkey` as a function of `axes["inner"]`.

        Optionally, the experiments can be grouped by `axes["outer"]`,
        producing a figure with columns of panels.
        First of all, though, mean and optimum computations are done for
        `axes["mean"]` and `axes["optim"]` (see `xpSpace.tune`).

        This is entirely analogous to the roles of `axes` in `xpSpace.print`.

        The optimal parameters are plotted in smaller panels below the main plot.
        This can be prevented by providing the figure axes through the `panels` arg.

        Parameters
        ----------
        statkey : name of the statistic, forwarded to `self.table_tree`.
        axes : dict of axis roles; `axes["inner"]` must hold exactly one name.
        get_style : callable mapping a row coordinate to a dict of plot kwargs.
        fignum, figsize : forwarded to `freshfig` (only if `panels is None`).
        panels : pre-made matplotlib axes to draw into (made 2d if necessary).
        title2 : optional second line appended to the figure title.
        costfun : NOTE(review): accepted but never used in this method --
            it looks like it was meant to be forwarded to `self.table_tree`;
            confirm against that method's signature.
        unique_labels : if true, a label already used by an earlier line is
            removed from the style, so it appears only once in the legends.

        Returns
        -------
        The `tables` object from `self.table_tree`, with the attributes `fig`,
        `xp_dict` and `axes_roles` attached.
        """
        def is_first_col(ax):
            """Tell whether `ax` sits in the first column of its grid.

            `Axes.is_first_col` was deprecated in Matplotlib 3.4 and removed
            in 3.6, in favour of the method on the subplot-spec. Fall back
            on the old method for versions where the new one does not exist.
            """
            try:
                return ax.get_subplotspec().is_first_col()
            except AttributeError:
                return ax.is_first_col()

        def plot1(panelcol, row, style):
            """Plot a given line (row) in the main panel and the optim panels.

            Involves: Sort, insert None's, handle constant lines.
            """
            # Make a full row (yy) of vals, whether is_constant or not.
            row.is_constant = all(x == row.Coord(None) for x in row)
            yy = [
                row[0] if row.is_constant else y for y in row.get_for(xticks)
            ]

            # Plot main
            row.vals = [getattr(y, 'val', None) for y in yy]
            row.handles = {}
            row.handles["main_panel"] = panelcol[0].plot(
                xticks, row.vals, **style)[0]

            # Plot tuning params
            row.tuned_coords = {}  # Store ordered, "transposed" argmins
            argmins = [getattr(y, 'tuned_coord', None) for y in yy]
            # `or ()`: tolerate axes["optim"] being None, like the rest of
            # this method does.
            for a, panel in zip(axes["optim"] or (), panelcol[1:]):
                tuned = [getattr(coord, a, None) for coord in argmins]
                row.tuned_coords[a] = tuned

                # Plotting all None's sets axes units (like any plotting call)
                # which can cause trouble if the axes units were actually supposed
                # to be categorical (eg upd_a), but this is only revealed later.
                if any(y is not None for y in tuned):
                    row.handles[a] = panel.plot(xticks, tuned, **style)

        # Nest axes through table_tree()
        assert len(axes["inner"]) == 1, "You must chose the abscissa."
        axes, tables = self.table_tree(statkey, axes)
        xticks = self.tickz(axes["inner"][0])

        # Figure panels
        if panels is None:
            nrows = len(axes['optim'] or ()) + 1
            ncols = len(tables)
            maxW = 12.7  # my mac screen
            figsize = figsize or (min(5 * ncols, maxW), 7)
            gs = dict(
                height_ratios=[6] + [1] * (nrows - 1),
                hspace=0.05,
                wspace=0.05,
                # eyeballed:
                left=0.15 / (1 + np.log(ncols)),
                right=0.97,
                bottom=0.06,
                top=0.9)
            # Create
            _, panels = freshfig(num=fignum,
                                 figsize=figsize,
                                 nrows=nrows,
                                 sharex=True,
                                 ncols=ncols,
                                 sharey='row',
                                 gridspec_kw=gs)
            # Always 2d, with one column per table.
            panels = np.ravel(panels).reshape((-1, ncols))
        else:
            panels = np.atleast_2d(panels)

        # Title
        fig = panels[0, 0].figure
        fig_title = "Average wrt. time"
        if axes["mean"] is not None:
            fig_title += f" and {axes['mean']}"
        if title2 is not None:
            fig_title += "\n" + str(title2)
        fig.suptitle(fig_title)

        # Loop outer
        label_register = set()  # mv inside loop to get legend on each panel
        for table_panels, (table_coord, table) in zip(panels.T,
                                                      tables.items()):
            table.panels = table_panels
            title = '' if axes["outer"] is None else repr(table_coord)

            # Plot
            for coord, row in table.items():
                style = get_style(coord)

                # Rm duplicate labels (contrary to coords, labels can
                # be "tampered" with, and so can be duplicate).
                # Styles without any label are left alone; previously
                # they raised KeyError here.
                if unique_labels:
                    label = style.get("label")
                    if label is not None:
                        if label in label_register:
                            del style["label"]
                        else:
                            label_register.add(label)

                plot1(table.panels, row, style)

            # Beautify
            panel0 = table.panels[0]
            panel0.set_title(title)
            if is_first_col(panel0):
                panel0.set_ylabel(statkey)
            with set_tmp(mpl_logger, 'level', 99):  # silence "no label" msg
                panel0.legend()
            table.panels[-1].set_xlabel(axes["inner"][0])
            # Tuning panels:
            for a, panel in zip(axes["optim"] or (), table.panels[1:]):
                if is_first_col(panel):
                    panel.set_ylabel(f"Optim.\n{a}")

        tables.fig = fig
        tables.xp_dict = self
        tables.axes_roles = axes
        return tables
Ejemplo n.º 4
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the phase-space trajectory plot and return its updater.

        Creates the figure and (2d or 3d) axes, allocates rolling arrays for
        the tails, and draws the initial, empty line and scatter artists.

        `params_orig`, `M` and `is_3d` are taken from the enclosing scope.
        Values in `kwargs` take precedence over those in `params_orig`.

        Returns
        -------
        update : callable `update(key, E, P)` that rolls the data arrays
            forward and refreshes the artists.
        """
        xx, yy, mu, _, chrono = \
            stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, 'w')
        if p.Tplot == 0:
            K_plot = 1
        else:
            _, K_lag, a_lag = validate_lag(p.Tplot, chrono)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings.
        # Use len() rather than `== []`, which compares element-wise
        # (and may fail to broadcast) if the setting is a numpy array.
        if len(p.dims) == 0:
            p.dims = arange(M)
        if len(p.labels) == 0:
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        freshfig(fignum, figsize=(5, 5), loc='2321')
        ax = plt.subplot(111, projection='3d' if is_3d else None)
        ax.set_facecolor('w')
        ax.set_title("Phase space trajectories")
        # Tune plot
        for ind, (label, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind,
                         *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            # Look up set_xlabel/set_ylabel/set_zlabel by name.
            # Unlike eval(), this passes the label through verbatim
            # (quotes and backslashes, e.g. in LaTeX, are not re-interpreted).
            getattr(ax, "set_%slabel" % t)(label)

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if list(p.obs_inds) == list(p.dims):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if 'E' in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0]
                for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if 'mu' in d:
            h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
        h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
        if 'y' in d:
            h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if 'E' in d:
            s.E = ax.scatter(*E.T[p.dims],
                             s=3**2,
                             c=[hn.get_color() for hn in h.E])
        if 'mu' in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(*ones(M),
                         s=14**2,
                         c=[h.x.get_color()],
                         marker=(5, 1),
                         zorder=99)

        def update(key, E, P):
            """Insert the data for `key` and redraw tails and scatter points."""
            k, kObs, faus = key
            show_y = 'y' in d and kObs is not None

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, 'z')
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if 'y' in d:
                    d.y.insert(ind, yy[kObs, :] if show_y else nan * ones(M))
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if 'y' in d:
                update_tail(h.y, d.y)
            # The ensemble artists are only refreshed when there is no mean.
            if 'mu' in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            else:
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

        return update
Ejemplo n.º 5
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the marginal time-series plot and return its updater.

        One panel (row) is created per plotted state dimension. Each panel
        gets a line for the truth and, when available, markers for the
        observations, lines for the ensemble members, and lines for the
        mean and the mean +/- std.

        `params_orig` is taken from the enclosing scope.
        Values in `kwargs` take precedence over those in `params_orig`.

        Returns
        -------
        update : callable `update(key, E, P)` that rolls the data arrays
            forward and refreshes the lines of every panel.
        """
        xx, yy, mu, std, chrono = \
            stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, chrono)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, 'w')
        if has_w:
            K_plot += a_lag

        # Choose marginal dims to plot.
        # Use len() rather than `== []`, which compares element-wise
        # (and may fail to broadcast) if the setting is a numpy array.
        if len(p.dims) == 0:
            Nx = min(10, xx.shape[-1])
            DimsX = dapper.tools.math.linspace_int(xx.shape[-1], Nx)
        else:
            Nx = len(p.dims)
            DimsX = p.dims
        # Pre-process obs dimensions
        # Rm inds of obs if not in DimsX
        iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
        # Rm obs_inds    if not in DimsX
        DimsY = [m for m in p.obs_inds if m in DimsX]
        # Get dim (within y) of each x
        DimsY = [DimsY.index(m) if m in DimsY else None for m in DimsX]
        Ny = len(iiY)

        # Set up figure, axes
        _, axs = freshfig(fignum, (5, 7),
                          loc='231-22',
                          nrows=Nx,
                          sharex=True)
        if Nx == 1:
            axs = [axs]

        # Tune plots
        axs[0].set_title("Marginal time series")
        has_labels = len(p.labels) != 0
        for ix, (m, ax) in enumerate(zip(DimsX, axs)):
            ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1 / p.zoomy))
            ax.set_ylabel(p.labels[ix] if has_labels else "$x_{%d}$" % m)
        axs[-1].set_xlabel('Time (t)')

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        d.t = RollingArray((K_plot, ))
        d.x = RollingArray((K_plot, Nx))
        h.x = []
        d.y = RollingArray((K_plot, Ny))
        h.y = []  # one entry per panel; None where the dim is unobserved
        if E is not None:
            d.E = RollingArray((K_plot, len(E), Nx))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, Nx))
            h.mu = []
            d.s = RollingArray((K_plot, 2, Nx))
            h.s = []

        # Plot (invisible coz everything here is nan, for the moment).
        for ix, (iy, ax) in enumerate(zip(DimsY, axs)):
            h.x += ax.plot(d.t, d.x[:, ix], 'k')
            # Keep h.y aligned with the panels (index ix), since the obs
            # column iy need not follow the order of the panels.
            if iy is not None:
                h.y += ax.plot(d.t, d.y[:, iy], 'g*', ms=10)
            else:
                h.y.append(None)
            if 'E' in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if 'mu' in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
            if 's' in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

        def update(key, E, P):
            """Insert the data for `key` and refresh the lines of each panel."""
            k, kObs, faus = key

            EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][DimsX])
                if 's' in d:
                    d.s.insert(ind,
                               mu[key][DimsX] + [[1], [-1]] * std[key][DimsX])
                d.t.insert(ind, chrono.tt[k])
                d.y.insert(
                    ind,
                    yy[kObs, iiY] if kObs is not None else nan * ones(Ny))
                d.x.insert(ind, xx[k, DimsX])

            # Update graphs
            for ix, (iy, ax) in enumerate(zip(DimsY, axs)):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if iy is not None:
                    h.y[ix].set_data(d.t, d.y[:, iy])
                if 'mu' in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if 's' in d:
                    for b in [0, 1]:
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if 'E' in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

        return update
Ejemplo n.º 6
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the correlation-matrix image and the auto-correlation plot.

        The figure has two rows: an image of the (initially identity)
        correlation matrix with a colorbar, and below it two lines labelled
        'Correlation' and 'Abs. corr.'.

        The instance is flagged with `is_active = False` (and nothing is
        plotted) when `E` is None and `P` (or its `diag`, for a `CovMat`)
        is all NaN. When the state length exceeds 1003, only
        `not_available_text` is called on the upper axes.
        """
        GS = {'height_ratios': [4, 1], 'hspace': 0.09, 'top': 0.95}
        fig, (ax, ax2) = freshfig(fignum, (5, 6),
                                  loc='2321',
                                  nrows=2,
                                  gridspec_kw=GS)

        if E is None and np.isnan(
                P.diag if isinstance(P, CovMat) else P).all():
            not_available_text(ax, ('Not available in replays'
                                    '\ncoz full Ens/Cov not stored.'))
            self.is_active = False
            return

        Nx = len(stats.mu[key0])
        if Nx > 1003:
            not_available_text(ax)
            return

        C = np.eye(Nx)
        # Mask half (the lower triangle, diagonal included).
        # NB: the builtin `bool` replaces the `np.bool` alias,
        # which was removed in NumPy 1.24.
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True
        # Make colormap. Log-transform cmap,
        # but not internally in matplotlib,
        # so as to avoid transforming the colorbar too.
        cmap = plt.get_cmap('RdBu_r')
        trfm = mpl.colors.SymLogNorm(linthresh=0.2,
                                     linscale=0.2,
                                     base=np.e,
                                     vmin=-1,
                                     vmax=1)
        cmap = cmap(trfm(np.linspace(-0.6, 0.6, cmap.N)))
        cmap = mpl.colors.ListedColormap(cmap)
        #
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        plt.box(False)
        ax.set_facecolor('w')
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
        line_AC, = ax2.plot(arange(Nx), ones(Nx), label='Correlation')
        line_AA, = ax2.plot(arange(Nx), ones(Nx), label='Abs. corr.')
        ax2.hlines(0, 0, Nx - 1, 'k', 'dotted', lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(frameon=True,
                   facecolor='w',
                   bbox_to_anchor=(1, 1),
                   loc='upper left',
                   borderaxespad=0.02)

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, 'w'):
            self.w = stats.w
Ejemplo n.º 7
0
    def __init__(self, fignum, stats, key0, plot_u, E, P,
                 Tplot=None, **kwargs):
        """Create one panel per group of diagnostics, with empty lines.

        The groups ('RMS' and 'Values') and the look of each diagnostic are
        given by the style tables below. A diagnostic is skipped if
        `deep_getattr(stats, name)` raises AttributeError.

        `star` is taken from the enclosing scope. `key0`, `E`, `P` and
        `kwargs` are accepted but not used here.
        """
        def lin(offset, slope):
            """Return the affine map x -> offset + slope * x."""
            def affine(x):
                return offset + slope * x
            return affine

        inv_N = 1 / getattr(stats.xp, 'N', 99)

        # STYLE TABLES - Defines which/how diagnostics get plotted
        # Columns: transf, shape, plt kwargs
        styles = {
            'RMS': {
                'err.rms': [None, None, dict(c='k', label='Error')],
                'std.rms': [None, None,
                            dict(c='b', label='Spread', alpha=0.6)],
            },
            'Values': {
                'skew': [None, None,
                         dict(c='g', label=star + r'Skew/$\sigma^3$')],
                'kurt': [None, None,
                         dict(c='r', label=star + r'Kurt$/\sigma^4{-}3$')],
                'trHK': [None, None,
                         dict(c='k', label=star + 'HK')],
                'infl': [lin(-10, 10), 'step',
                         dict(c='c', label='10(infl-1)')],
                'N_eff': [lin(0, inv_N), 'dirac',
                          dict(c='y', label='N_eff/N', lw=3)],
                'iters': [lin(0, .1), 'dirac',
                          dict(c='m', label='iters/10')],
                'resmpl': [None, 'dirac',
                           dict(c='k', label='resampled?')],
            },
        }

        n_panels = len(styles)
        _, axs = freshfig(fignum, (5, 1 + n_panels),
                          loc='2311',
                          nrows=n_panels,
                          sharex=True,
                          gridspec_kw={'left': 0.125, 'right': 0.76})

        axs[0].set_title("Diagnostics")
        for group, ax in zip(styles, axs):
            ax.set_ylabel(group)
        # The loop variable is deliberately re-used: it is the last panel.
        bottom_ax = ax
        bottom_ax.set_xlabel('Time (t)')
        viz.adjust_position(bottom_ax, y0=0.03)

        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.t)

        def init_ax(panel, style_table):
            """Create the (empty) line of each available stat in the table."""
            lines = {}
            for name, (transf, shape, plt_kwargs) in style_table.items():

                # SKIP -- if stats[name] is not in existence
                # Note: The nan check/deletion comes after the first kObs.
                try:
                    stat = deep_getattr(stats, name)
                except AttributeError:
                    continue
                # PS: recall (from series.py) that even if store_u is false,
                # stat[k] is still present if liveplots=True via the k_tmp
                # functionality.

                # Unpack style
                ln = {
                    'transf': transf or (lambda x: x),
                    'shape': shape,
                    'plt': plt_kwargs,
                }

                # Create series
                is_faust = isinstance(stat, FAUSt)
                ln['plot_u'] = plot_u if is_faust else False
                if is_faust:
                    length = comp_K_plot(K_lag, a_lag, ln['plot_u'])
                else:
                    length = a_lag
                ln['data'] = RollingArray(length)
                ln['tt'] = RollingArray(length)

                # Plot (init)
                ln['handle'], = panel.plot(ln['tt'], ln['data'], **ln['plt'])

                # Plotting only nans yield ugly limits. Revert to defaults.
                panel.set_xlim(0, 1)
                panel.set_ylim(0, 1)

                lines[name] = ln
            return lines

        # Plot
        self.d = [
            init_ax(panel, table)
            for table, panel in zip(styles.values(), axs)
        ]

        # Horizontal line at y=0
        self.baseline0, = bottom_ax.plot(bottom_ax.get_xlim(), [0, 0],
                                         c=0.5 * ones(3),
                                         lw=0.7,
                                         label='_nolegend_')

        # Store
        self.axs = axs
        self.stats = stats
        self.init_incomplete = True
Ejemplo n.º 8
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the 2x2 streamfunction images and return their updater.

        The panels show the mean, truth, std. and error fields. A colorbar
        is placed to the right of each row.

        `cm`, `clims`, `square`, `obs_inds`, `ind2sub` and `format_time` are
        taken from the enclosing scope.

        Returns
        -------
        update : callable `update(key, E, P)` that refreshes the four images,
            the observation markers and the time text.
        """
        gridspec = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
        fig, axs = freshfig(fignum, (6, 6),
                            loc='231-22-3',
                            nrows=2,
                            ncols=2,
                            sharex=True,
                            sharey=True,
                            gridspec_kw=gridspec)

        for panel in axs.flatten():
            panel.set_aspect('equal', viz.adjustable_box_or_forced())

        (ax_11, ax_12), (ax_21, ax_22) = axs

        # White grid on the top row, black grid on the bottom row.
        grid_styles = [
            (ax_11, 'w', 0.2),
            (ax_12, 'w', 0.2),
            (ax_21, 'k', 0.1),
            (ax_22, 'k', 0.1),
        ]
        for panel, color, width in grid_styles:
            panel.grid(color=color, linewidth=width)

        def add_colorbar_axes(neighbour):
            """Add a slim axes just to the right of `neighbour`."""
            box = neighbour.get_position()
            margin = 0.1 * box.height
            return fig.add_axes([
                box.x1 + 0.03, box.y0 + margin, 0.04, box.height - 2 * margin
            ])

        ax_13 = add_colorbar_axes(ax_12)  # upper colorbar
        ax_23 = add_colorbar_axes(ax_22)  # lower colorbar

        # Extract data arrays
        xx, _, mu, std, err = stats.xx, stats.yy, stats.mu, stats.std, stats.err
        k0 = key0[0]
        tt = stats.HMM.t.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k0]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(std[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        # Obs init -- a list where item 0 is the handle of something invisible.
        lh = [ax_12.plot(0, 0)[0]]

        sx = '$\\psi$'
        panel_titles = [
            (ax_11, 'mean '),
            (ax_12, 'true '),
            (ax_21, 'std. '),
            (ax_22, 'err. '),
        ]
        for panel, prefix in panel_titles:
            panel.set_title(prefix + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for image, clim in zip((im_11, im_12, im_21, im_22), clims):
            image.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for cbar_ax in (ax_13, ax_23):
            cbar_ax.yaxis.set_tick_params('major',
                                          length=2,
                                          width=0.5,
                                          direction='in',
                                          left=True,
                                          right=True)
            # make ticks appear over colorbar patch
            cbar_ax.set_axisbelow('line')

        # Title
        fig.suptitle("Streamfunction (" + sx + ")")
        # Time info
        text_t = ax_12.text(1,
                            1.1,
                            format_time(None, None, None),
                            transform=ax_12.transAxes,
                            family='monospace',
                            ha='left')

        def update(key, E, P):
            """Refresh the images, obs markers and time text for `key`."""
            k, kObs, faus = key
            t = tt[k]

            fields = [
                (im_11, mu[key]),
                (im_12, xx[k]),
                (im_21, std[key]),
                (im_22, err[key]),
            ]
            for image, field in fields:
                image.set_data(square(field))

            # Remove previous obs
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if kObs is not None and obs_inds is not None:
                lh[0] = ax_12.plot(*ind2sub(obs_inds(t))[::-1],
                                   'k.',
                                   ms=1,
                                   zorder=5)[0]

            text_t.set_text(format_time(k, kObs, t))

        return update
Ejemplo n.º 9
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the 1d amplitude plot and return its updater.

        Plots the state (truth, and either mean +/- confidence lines or the
        ensemble members) against the state index, plus the observations
        if `obs_inds` is given.

        `params_orig` is taken from the enclosing scope.
        Values in `kwargs` take precedence over those in `params_orig`.

        Returns
        -------
        update : callable `update(key, E, P)` that refreshes the y-data of
            the lines, the time text and the visibility of the obs markers.
        """
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Use len() rather than `== []`, which compares element-wise
        # (and may fail to broadcast) if dims is a numpy array.
        if len(p.dims) == 0:
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = freshfig(fignum, (8, 5), loc='2312-3')
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(ii,
                              nan1,
                              "b-",
                              lw=1,
                              label=(str(p.conf_mult) + r'$\sigma$ conf'))
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii,
                              wrap(nanE[0]),
                              **p.ens_props,
                              lw=1,
                              label='Ensemble')
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        line_x, = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
        if p.obs_inds is not None:
            line_y, = ax.plot(p.obs_inds,
                              nan * p.obs_inds,
                              'g*',
                              ms=5,
                              label='Obs')

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        # Index element by element, so that this also works
        # when dims was given as a plain list (not an array).
        ax.set_xticklabels([p.dims[i] for i in xt])

        ax.set_xlabel('State index')
        ax.set_ylabel('Value')
        ax.legend(loc='upper right')

        text_t = ax.text(0.01,
                         0.01,
                         format_time(None, None, None),
                         transform=ax.transAxes,
                         family='monospace',
                         ha='left')

        # Init visibility (must come after legend):
        if p.obs_inds is not None:
            line_y.set_visible(False)

        def update(key, E, P):
            """Refresh the lines, time text and obs visibility for `key`."""
            k, kObs, faus = key

            if p.conf_mult:
                sigma = mu[key] + p.conf_mult * stats.std[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, kObs, stats.HMM.t.tt[k]))

            if p.obs_inds is not None:
                # Show the obs when 'f' is in faus, hide them when 'u' is.
                if 'f' in faus:
                    line_y.set_ydata(yy[kObs])
                    line_y.set_zorder(5)
                    line_y.set_visible(True)
                if 'u' in faus:
                    line_y.set_visible(False)

        return update