Ejemplo n.º 1
0
def demo_locatable_axes_hard(fig):
    """Draw the demo image with a fixed-width colorbar to its right.

    A ``SubplotDivider`` covering cell 2 of a 2x2 grid on *fig* is split
    horizontally into three columns: the image axes (width follows the
    image axes), a 0.05 inch gap and a 0.2 inch colorbar axes.  Both axes
    are positioned by locators obtained from that divider.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the image axes and the colorbar axes.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    layout = SubplotDivider(fig, 2, 2, 2, aspect=True)

    # Both axes start on the divider's rectangle; the locators assigned
    # below place them in their own columns.
    image_ax = Axes(fig, layout.get_position())
    cbar_ax = Axes(fig, layout.get_position())

    columns = [Size.AxesX(image_ax), Size.Fixed(0.05), Size.Fixed(0.2)]
    rows = [Size.AxesY(image_ax)]
    layout.set_horizontal(columns)
    layout.set_vertical(rows)

    # Column 0 holds the image, column 2 the colorbar (column 1 is the gap).
    for target, column in ((image_ax, 0), (cbar_ax, 2)):
        target.set_axes_locator(layout.new_locator(nx=column, ny=0))
    for target in (image_ax, cbar_ax):
        fig.add_axes(target)

    # Keep only the right-hand ticks on the colorbar axes.
    cbar_ax.axis["left"].toggle(all=False)
    cbar_ax.axis["right"].toggle(ticks=True)

    image, bounds = get_demo_image()
    mappable = image_ax.imshow(image, extent=bounds, interpolation="nearest")
    plt.colorbar(mappable, cax=cbar_ax)
    plt.setp(cbar_ax.get_yticklabels(), visible=False)
Ejemplo n.º 2
0
def demo_locatable_axes_hard(fig1):
    """Draw the demo image with a fixed-width colorbar to its right.

    ``LocatableAxes`` was deprecated in Matplotlib 3.0 and later removed,
    so the axes are built from ``mpl_toolkits.axes_grid1.mpl_axes.Axes``.
    That class still provides the ``ax.axis["left"]``-style access used
    below, and it accepts an axes locator.

    Parameters
    ----------
    fig1 : matplotlib.figure.Figure
        Figure that receives the image axes and the colorbar axes.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    divider = SubplotDivider(fig1, 2, 2, 2, aspect=True)

    # axes for image
    ax = Axes(fig1, divider.get_position())

    # axes for colorbar
    ax_cb = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]

    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    # column 0: image, column 1: padding, column 2: colorbar
    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    ax_cb.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(ax)
    fig1.add_axes(ax_cb)

    # only keep the right-hand ticks on the colorbar axes
    ax_cb.axis["left"].toggle(all=False)
    ax_cb.axis["right"].toggle(ticks=True)

    Z, extent = get_demo_image()

    im = ax.imshow(Z, extent=extent, interpolation="nearest")
    plt.colorbar(im, cax=ax_cb)
    plt.setp(ax_cb.get_yticklabels(), visible=False)
Ejemplo n.º 3
0
def plot_heatmap(fig2, Z):
    """Show *Z* as a heat map on *fig2* with a fixed-width colorbar.

    The array is flipped vertically (``np.flipud``) and drawn over the unit
    square with the ``coolwarm`` colormap.  A 0.2 inch colorbar sits 0.05
    inch to the right of the image.  When the current figure manager exposes
    a window with ``setGeometry`` (a Qt-style window), that window is moved
    to x = width + 200, y = 100 while keeping its size.

    Parameters
    ----------
    fig2 : matplotlib.figure.Figure
        Figure that receives the image axes and the colorbar axes.
    Z : array-like
        2-D data to display.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # LocatableAxes was deprecated in Matplotlib 3.0 and later removed;
    # mpl_axes.Axes provides the same ``ax.axis[...]`` interface.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    Z = np.flipud(Z)

    divider = SubplotDivider(fig2, 1, 1, 1, aspect=True)

    # axes for image
    ax = Axes(fig2, divider.get_position())

    # axes for colorbar
    ax_cb = Axes(fig2, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]

    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    ax_cb.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig2.add_axes(ax)
    fig2.add_axes(ax_cb)

    ax_cb.axis["left"].toggle(all=False)
    ax_cb.axis["right"].toggle(ticks=True)

    im = ax.imshow(Z,
                   cmap=cm.coolwarm,
                   extent=(0, 1, 0, 1),
                   interpolation="nearest")
    plt.colorbar(im, cax=ax_cb)
    plt.setp(ax_cb.get_yticklabels(), visible=False)

    # Repositioning relies on a Qt-style window (geometry().getRect() /
    # setGeometry()); other backends (Agg, Tk, ...) do not provide it, so
    # skip the move there instead of raising AttributeError.
    mngr = plt.get_current_fig_manager()
    window = getattr(mngr, "window", None)
    if window is not None and hasattr(window, "setGeometry"):
        _, _, dx, dy = window.geometry().getRect()
        window.setGeometry(dx + 200, 100, dx, dy)
Ejemplo n.º 4
0
def plot2dHeatMap(fig, x, y, z):
    '''Show the 2-D grid *z* as a heat map with a fixed-width colorbar.

    z is a 2d grid; x and y are implicit linspaces.  *x* and *y* are not
    read by this function: the image is always drawn over the unit square
    ``extent=(0, 1, 0, 1)`` after flipping *z* vertically.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the image axes and the colorbar axes.
    x, y :
        Unused; kept for interface compatibility.
    z : array-like
        2-D data to display.

    Returns
    -------
    ax :
        The image axes.
    '''
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # LocatableAxes was deprecated in Matplotlib 3.0 and later removed;
    # mpl_axes.Axes provides the same ``ax.axis[...]`` interface.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    z = np.flipud(z)

    divider = SubplotDivider(fig, 1, 1, 1, aspect=True)

    # axes for image
    ax = Axes(fig, divider.get_position())

    # axes for colorbar
    ax_cb = Axes(fig, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]

    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    ax_cb.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig.add_axes(ax)
    fig.add_axes(ax_cb)

    ax_cb.axis["left"].toggle(all=False)
    ax_cb.axis["right"].toggle(ticks=True)

    im = ax.imshow(z,
                   cmap=cm.coolwarm,
                   extent=(0, 1, 0, 1),
                   interpolation="nearest")
    plt.colorbar(im, cax=ax_cb)
    plt.setp(ax_cb.get_yticklabels(), visible=False)

    return ax
Ejemplo n.º 5
0
def addColorbar(mappable, ax):
    """ Append colorbar to axes

    Parameters
    ----------
    mappable :
        a mappable object
    ax :
        an axes object; it must be a subplot, because its SubplotSpec is
        used to build the divider

    Returns
    -------
    cbax :
        colorbar axes object

    Notes
    -----
    This is mostly useful for axes created with :func:`curvedEarthAxes`.

    The subplot cell of ``ax`` is split into the main axes, a 0.1 inch gap
    and a 0.2 inch colorbar column.  ``ax`` is re-positioned through a
    locator from that divider.

    written by Sebastien, 2013-04

    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # LocatableAxes was deprecated in Matplotlib 3.0 and later removed.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    import matplotlib.pyplot as plt

    fig1 = ax.get_figure()
    # Axes.get_geometry() was deprecated in Matplotlib 3.4 and later
    # removed; SubplotDivider accepts the SubplotSpec directly.
    divider = SubplotDivider(fig1, ax.get_subplotspec(), aspect=True)

    # axes for colorbar
    cbax = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.1),  # padding
        Size.Fixed(0.2),  # colorbar
    ]
    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(cbax)

    # ticks, tick labels and axis label only on the right edge
    cbax.axis["left"].toggle(all=False)
    cbax.axis["top"].toggle(all=False)
    cbax.axis["bottom"].toggle(all=False)
    cbax.axis["right"].toggle(ticklabels=True, label=True)

    plt.colorbar(mappable, cax=cbax)

    return cbax
Ejemplo n.º 6
0
def demo_locatable_axes_hard(fig):
    """Draw the demo image with a fixed-width colorbar to its right.

    Cell 2 of a 2x2 ``SubplotDivider`` grid on *fig* is split horizontally
    into three columns: the image axes, a 0.05 inch gap and a 0.2 inch
    colorbar axes.  Both axes are added with ``Figure.add_axes`` and then
    positioned by locators from the divider.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the image axes and the colorbar axes.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    layout = SubplotDivider(fig, 2, 2, 2, aspect=True)

    image_ax = fig.add_axes(layout.get_position(), axes_class=Axes)
    # The distinct label keeps Figure.add_axes from treating this request
    # (same rectangle, same class) as the axes that already exists.
    cbar_ax = fig.add_axes(layout.get_position(), axes_class=Axes, label="cb")

    layout.set_horizontal([
        Size.AxesX(image_ax),  # image column
        Size.Fixed(0.05),      # gap, 0.05 inch
        Size.Fixed(0.2),       # colorbar column, 0.2 inch
    ])
    layout.set_vertical([Size.AxesY(image_ax)])

    for target, column in ((image_ax, 0), (cbar_ax, 2)):
        target.set_axes_locator(layout.new_locator(nx=column, ny=0))

    # Colorbar axes: hide the left side, keep ticks on the right.
    cbar_ax.axis["left"].toggle(all=False)
    cbar_ax.axis["right"].toggle(ticks=True)

    image, bounds = get_demo_image()
    mappable = image_ax.imshow(image, extent=bounds)
    plt.colorbar(mappable, cax=cbar_ax)
    cbar_ax.yaxis.set_tick_params(labelright=False)
Ejemplo n.º 7
0
def addColorbar(mappable, ax):
    """ Append colorbar to axes

    Parameters
    ----------
    mappable :
        a mappable object
    ax :
        an axes object; it must be a subplot, because its SubplotSpec is
        used to build the divider

    Returns
    -------
    cbax :
        colorbar axes object

    Notes
    -----
    This is mostly useful for axes created with :func:`curvedEarthAxes`.

    The subplot cell of ``ax`` is split into the main axes, a 0.1 inch gap
    and a 0.2 inch colorbar column.  ``ax`` is re-positioned through a
    locator from that divider.

    written by Sebastien, 2013-04

    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # LocatableAxes was deprecated in Matplotlib 3.0 and later removed.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    import matplotlib.pyplot as plt

    fig1 = ax.get_figure()
    # Axes.get_geometry() was deprecated in Matplotlib 3.4 and later
    # removed; SubplotDivider accepts the SubplotSpec directly.
    divider = SubplotDivider(fig1, ax.get_subplotspec(), aspect=True)

    # axes for colorbar
    cbax = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.1),  # padding
        Size.Fixed(0.2),  # colorbar
    ]
    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(cbax)

    # ticks, tick labels and axis label only on the right edge
    cbax.axis["left"].toggle(all=False)
    cbax.axis["top"].toggle(all=False)
    cbax.axis["bottom"].toggle(all=False)
    cbax.axis["right"].toggle(ticklabels=True, label=True)

    plt.colorbar(mappable, cax=cbax)

    return cbax
Ejemplo n.º 8
0
def add_cbar(mappable, ax):
    """
    Append colorbar to axes
    Copied from DaViTPy: https://github.com/vtsuperdarn/davitpy/blob/1b578ea2491888e3d97d6e0a8bc6d8cc7c9211fb/davitpy/utils/plotUtils.py#L674

    The subplot cell of ``ax`` is split into the main axes, a 0.1 inch gap
    and a 0.2 inch colorbar column.  ``ax`` is re-positioned through a
    locator from that divider.

    Parameters
    ----------
    mappable :
        object passed to ``plt.colorbar``
    ax :
        a subplot axes; its SubplotSpec is used to build the divider

    Returns
    -------
    cbax :
        the colorbar axes
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    import matplotlib.pyplot as plt

    fig1 = ax.get_figure()
    # Axes.get_geometry() was deprecated in Matplotlib 3.4 and later
    # removed; SubplotDivider accepts the SubplotSpec directly.
    divider = SubplotDivider(fig1, ax.get_subplotspec(), aspect=True)

    # axes for colorbar
    cbax = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.1),  # padding
        Size.Fixed(0.2),  # colorbar
    ]
    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(cbax)

    # ticks, tick labels and axis label only on the right edge
    cbax.axis["left"].toggle(all=False)
    cbax.axis["top"].toggle(all=False)
    cbax.axis["bottom"].toggle(all=False)
    cbax.axis["right"].toggle(ticklabels=True, label=True)

    plt.colorbar(mappable, cax=cbax)

    return cbax
Ejemplo n.º 9
0
    def plot(self, time, beam=None, maxground=2000, maxalt=500, step=1,
        showrefract=False, nr_cmap='jet_r', nr_lim=(0.8, 1.),
        raycolor='0.4',
        fig=None, rect=111):
        """Plot ray paths

        **Args**:
            * **time** (datetime.datetime): time of rays
            * [**beam**]: beam number; when None, the first beam stored for
              *time* is used (beam 0 is treated as a real beam number)
            * [**maxground**]: maximum ground range [km]
            * [**maxalt**]: highest altitude limit [km]
            * [**step**]: plot every *step*-th ray (rays are taken in the
              sorted key order of ``self.paths[time][beam]``)
            * [**showrefract**]: show refractive index along ray paths (supersedes raycolor)
            * [**nr_cmap**]: color map name for refractive index coloring
            * [**nr_lim**]: refractive index plotting limits (lower, upper)
            * [**raycolor**]: color of ray paths
            * [**rect**]: subplot specification
            * [**fig**]: A pylab.figure object (default to gcf)
        **Returns**:
            * **ax**: matplotlib.axes object containing formatting
            * **aax**: matplotlib.axes object containing data
            * **cbax**: matplotlib.axes object containing colorbar, or None
              when no colorbar is drawn (showrefract is False, or no ray
              segment was plotted)
        **Raises**:
            * AssertionError if *time*, or an explicitly given *beam*, is
              not present in ``self.paths``
        **Example**:
            ::

                # Show ray paths with colored refractive index along path
                import datetime as dt
                from models import raydarn
                sTime = dt.datetime(2012, 11, 18, 5)
                rto = raydarn.rtRun(sTime, rCode='bks', beam=12)
                rto.readRays() # read rays into memory
                ax, aax, cbax = rto.rays.plot(sTime, step=2, showrefract=True, nr_lim=[.85,1])
                ax.grid()

        written by Sebastien, 2013-04
        """
        from utils import plotUtils
        from matplotlib.collections import LineCollection
        import matplotlib.pyplot as plt
        import numpy as np

        ax, aax = plotUtils.curvedEarthAxes(fig=fig, rect=rect,
            maxground=maxground, maxalt=maxalt)

        # make sure that the required time and beam are present
        assert (time in self.paths), 'Unknown time %s' % time
        if beam is not None:
            assert (beam in self.paths[time]), 'Unknown beam %s' % beam
        else:
            # dict views are not indexable on Python 3; take the first key
            beam = next(iter(self.paths[time]))

        lcol = None
        for ir, (_, rays) in enumerate(sorted(self.paths[time][beam].items())):
            if ir % step:
                continue
            if not showrefract:
                aax.plot(rays['th'], rays['r']*1e-3, c=raycolor, zorder=2)
            else:
                # one line segment per pair of consecutive ray points, so
                # each segment can take its own refractive-index color
                points = np.array([rays['th'], rays['r']*1e-3]).T.reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)
                lcol = LineCollection(segments)
                lcol.set_cmap(nr_cmap)
                lcol.set_norm(plt.Normalize(*nr_lim))
                lcol.set_array(rays['nr'])
                aax.add_collection(lcol)

        cbax = None
        # Add a colorbar when plotting refractive index; all collections
        # share the same cmap/norm, so the last one drives the colorbar.
        if showrefract and lcol is not None:
            from mpl_toolkits.axes_grid1 import SubplotDivider, Size
            # LocatableAxes was deprecated in Matplotlib 3.0 and later removed
            from mpl_toolkits.axes_grid1.mpl_axes import Axes

            fig1 = ax.get_figure()
            # Axes.get_geometry() was deprecated in Matplotlib 3.4 and later
            # removed; SubplotDivider accepts the SubplotSpec directly.
            divider = SubplotDivider(fig1, ax.get_subplotspec(), aspect=True)

            # axes for colorbar
            cbax = Axes(fig1, divider.get_position())

            h = [Size.AxesX(ax), # main axes
                 Size.Fixed(0.1), # padding
                 Size.Fixed(0.2)] # colorbar
            v = [Size.AxesY(ax)]

            divider.set_horizontal(h)
            divider.set_vertical(v)

            ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
            cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

            fig1.add_axes(cbax)

            cbax.axis["left"].toggle(all=False)
            cbax.axis["top"].toggle(all=False)
            cbax.axis["bottom"].toggle(all=False)
            cbax.axis["right"].toggle(ticklabels=True, label=True)

            plt.colorbar(lcol, cax=cbax)
            cbax.set_ylabel("refractive index")

        return ax, aax, cbax
Ejemplo n.º 10
0
def view_layers(logdir, mode=0, ppc=20):
    '''
    Build and display the interactive layer viewer for the run logs in *logdir*.

    A figure with one divider row per layer is created.  A ``widgets``
    control panel is wired to ``_draw_layers`` through
    ``widgets.interactive_output``; it offers run log, pattern, colormap,
    value range and step index.  The panel is shown with ``display``.

    :param logdir: path to log directory that contains pickled run logs
    :param mode: viewing mode index. Must be an int between 0 and 2
        0: limits the viewing to feedforward information only (weights, biases, net_input, output)
        1: same as 0, but also includes gradient information (gweights, gbiases, gnet_input, goutput)
        2: same as 1, but also includes cumulative gradient information
    :param ppc: forwarded to ``_make_figure`` (presumably pixels per cell — confirm)
    :return: None
    '''
    plt.ion()
    # get runlog filenames and paths
    FILENAMES, RUNLOG_PATHS = [sorted(l) for l in list_pickles(logdir)]

    # get testing epochs and losses data
    EPOCHS, LOSSES, LOSS_SUMS = get_data_by_key(
        runlog_path=RUNLOG_PATHS[0], keys=['enum', 'loss',
                                           'loss_sum']).values()

    # get layer names and layer dims to set up figure
    layer_names = get_layer_names(runlog_path=RUNLOG_PATHS[0])
    layer_names.reverse()
    layer_dims = get_layer_dims(runlog_path=RUNLOG_PATHS[0],
                                layer_names=layer_names)

    # set up and make figure
    figure = _make_figure(layer_dims=layer_dims,
                          mode=mode,
                          ppc=ppc,
                          dpi=96,
                          fig_title='view_layers: ' + logdir)

    num_layers = len(layer_names)
    # only the first layer shows the target
    disp_targs = [True] + [False] * (num_layers - 1)

    # Row heights depend only on the layer dims, so compute them once
    # instead of on every loop iteration.
    row_height_ratios = [dim[0] + 1.8 for dim in layer_dims.values()]

    axes_dicts = []
    for i, (layer_name, disp_targ) in enumerate(zip(layer_names, disp_targs)):
        sp_divider = SubplotDivider(figure,
                                    num_layers,
                                    1,
                                    i + 1,
                                    aspect=True,
                                    anchor='NW')
        # NOTE: relies on private matplotlib attributes to set per-row
        # heights; each divider gets its own copy of the ratios list.
        sp_divider._subplotspec._gridspec._row_height_ratios = list(
            row_height_ratios)
        axes_dicts.append(
            _divide_axes_grid(mpl_figure=figure,
                              divider=sp_divider,
                              layer_name=layer_name.upper().replace('_', ' '),
                              inp_size=layer_dims[layer_name][1],
                              layer_size=layer_dims[layer_name][0],
                              mode=mode,
                              target=disp_targ))
    plt.tight_layout()

    _widget_layout = widgets.Layout(width='100%')

    run_widget = widgets.Dropdown(options=dict(zip(FILENAMES, RUNLOG_PATHS)),
                                  description='Run log: ',
                                  value=RUNLOG_PATHS[0],
                                  layout=_widget_layout)

    cmap_widget = widgets.Dropdown(options=sorted([
        'BrBG', 'bwr', 'coolwarm', 'PiYG', 'PRGn', 'PuOr', 'RdBu', 'RdGy',
        'RdYlBu', 'RdYlGn', 'seismic'
    ]),
                                   description='Colors: ',
                                   value='coolwarm',
                                   disabled=False,
                                   layout=_widget_layout)

    vrange_widget = widgets.FloatSlider(value=1.0,
                                        min=0,
                                        max=8,
                                        step=.1,
                                        description='V-range: ',
                                        continuous_update=False,
                                        layout=_widget_layout)

    step_widget = widgets.IntSlider(value=0,
                                    min=0,
                                    max=len(EPOCHS) - 1,
                                    step=1,
                                    description='Step index: ',
                                    continuous_update=False,
                                    layout=_widget_layout)

    pattern_options = get_pattern_options(runlog_path=RUNLOG_PATHS[0],
                                          tind=step_widget.value)
    options_map = {option: i for i, option in enumerate(pattern_options)}
    pattern_widget = widgets.Dropdown(options=options_map,
                                      value=0,
                                      description='Pattern: ',
                                      disabled=False,
                                      layout=_widget_layout)

    loss_observer = LossDataObsever(
        epoch_list=EPOCHS,
        loss_list=LOSSES,
        loss_sum_list=LOSS_SUMS,
        tind=step_widget.value,
        pind=pattern_widget.value,
    )

    fig_observer = FigureObserver(mpl_figure=figure)

    step_widget.observe(handler=loss_observer.on_epoch_change, names='value')
    pattern_widget.observe(handler=loss_observer.on_pattern_change,
                           names='value')

    def on_runlog_change(change):
        # run_widget is observed without a ``names`` filter, so ignore
        # every notification except a change of its value.
        if change['type'] == 'change' and change['name'] == 'value':
            newEPOCHS, newLOSSES, newLOSS_SUMS = get_data_by_key(
                runlog_path=change['new'], keys=['enum', 'loss',
                                                 'loss_sum']).values()
            step_widget.max = len(newEPOCHS) - 1
            step_widget.value = 0
            pattern_widget.value = 0
            loss_observer.new_runlog(newEPOCHS, newLOSSES, newLOSS_SUMS)

    run_widget.observe(on_runlog_change)

    controls_dict = dict(
        runlog_path=run_widget,
        img_dicts=widgets.fixed(axes_dicts),
        layer_names=widgets.fixed(layer_names),
        colormap=cmap_widget,
        vrange=vrange_widget,
        tind=step_widget,
        pind=pattern_widget,
    )

    row_layout = widgets.Layout(display='flex',
                                flex_flow='row',
                                justify_content='center')

    stretch_layout = widgets.Layout(display='flex',
                                    flex_flow='row',
                                    justify_content='space-around')

    control_panel_rows = [
        widgets.Box(
            children=[controls_dict['runlog_path'], controls_dict['pind']],
            layout=row_layout),
        widgets.Box(
            children=[controls_dict['colormap'], controls_dict['vrange']],
            layout=row_layout),
        widgets.Box(children=[controls_dict['tind']], layout=row_layout),
        widgets.Box(children=[
            loss_observer.epoch_widget, loss_observer.loss_sum_widget,
            loss_observer.loss_widget, fig_observer.widget
        ],
                    layout=stretch_layout)
    ]

    controls_panel = widgets.Box(children=control_panel_rows,
                                 layout=widgets.Layout(display='flex',
                                                       flex_flow='column',
                                                       padding='5px',
                                                       border='ridge 1px',
                                                       align_items='stretch',
                                                       width='100%'))

    widgets.interactive_output(f=_draw_layers, controls=controls_dict)
    display(controls_panel)