Beispiel #1
0
def demo_locatable_axes_hard(fig):
    """Draw the demo image with a fixed-width colorbar beside it.

    Both the image axes and the colorbar axes are positioned by a single
    ``SubplotDivider`` occupying cell 2 (top-right) of a 2x2 grid.  The
    divider's columns are: image width, 0.05 inch of padding, and a 0.2 inch
    colorbar; its single row follows the image height.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    divider = SubplotDivider(fig, 2, 2, 2, aspect=True)

    # Both axes start at the divider's rectangle; their locators re-place
    # them at draw time.
    image_ax = Axes(fig, divider.get_position())
    cbar_ax = Axes(fig, divider.get_position())

    columns = [
        Size.AxesX(image_ax),  # width follows the image axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]
    rows = [Size.AxesY(image_ax)]
    divider.set_horizontal(columns)
    divider.set_vertical(rows)

    # Column 1 is the padding, so the colorbar lives in column 2.
    for target, column in ((image_ax, 0), (cbar_ax, 2)):
        target.set_axes_locator(divider.new_locator(nx=column, ny=0))
    for target in (image_ax, cbar_ax):
        fig.add_axes(target)

    # Colorbar axis: hide the left side entirely, keep ticks on the right.
    cbar_ax.axis["left"].toggle(all=False)
    cbar_ax.axis["right"].toggle(ticks=True)

    image, extent = get_demo_image()
    mappable = image_ax.imshow(image, extent=extent, interpolation="nearest")
    plt.colorbar(mappable, cax=cbar_ax)
    plt.setp(cbar_ax.get_yticklabels(), visible=False)
Beispiel #2
0
def addColorbar(mappable, ax):
    """ Append colorbar to axes

    The colorbar axes is placed to the right of ``ax`` (0.1 inch of padding,
    0.2 inch wide) by a ``SubplotDivider`` built from ``ax``'s own subplot
    spec, so ``ax`` is re-positioned by the divider as well.

    Parameters
    ----------
    mappable :
        a mappable object (e.g. the return value of ``imshow``)
    ax :
        an axes object; it must have been created as a subplot, because its
        subplot spec is reused to build the divider

    Returns
    -------
    cbax :
        colorbar axes object

    Raises
    ------
    ValueError
        if ``ax`` is not positioned by a subplot spec.

    Notes
    -----
    This is mostly useful for axes created with :func:`curvedEarthAxes`.

    written by Sebastien, 2013-04

    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # ``LocatableAxes`` was removed from axes_grid1 in newer matplotlib;
    # ``mpl_axes.Axes`` provides the same ``axis[...]`` toggle interface.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    import matplotlib.pyplot as plt

    fig1 = ax.get_figure()

    # ``ax.get_geometry()`` was deprecated (3.4) and removed (3.6); the subplot
    # spec works across versions and also preserves multi-cell spans.
    get_spec = getattr(ax, "get_subplotspec", None)
    subplotspec = get_spec() if get_spec is not None else None
    if subplotspec is None:
        raise ValueError("addColorbar requires an axes created as a subplot")
    divider = SubplotDivider(fig1, subplotspec, aspect=True)

    # axes for colorbar
    cbax = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.1),  # padding, 0.1 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]
    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    # Column 1 is the padding, so the colorbar goes in column 2.
    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(cbax)

    # Only the right side of the colorbar axes shows ticklabels and label.
    cbax.axis["left"].toggle(all=False)
    cbax.axis["top"].toggle(all=False)
    cbax.axis["bottom"].toggle(all=False)
    cbax.axis["right"].toggle(ticklabels=True, label=True)

    plt.colorbar(mappable, cax=cbax)

    return cbax
Beispiel #3
0
def plot_heatmap(fig2, Z):
    """Draw ``Z`` as a heatmap in ``fig2`` with a fixed-width colorbar.

    A ``SubplotDivider`` covering the whole figure (1x1 grid) positions two
    axes: the image, then 0.05 inch of padding, then a 0.2 inch colorbar.
    Afterwards, if the figure's window supports it, the window is moved.

    Parameters
    ----------
    fig2 : matplotlib.figure.Figure
        Figure the image and colorbar axes are added to.
    Z : array-like
        Data passed to ``np.flipud`` and then ``imshow`` (presumably 2-D).
        It is flipped vertically before display and drawn over the unit
        square ``extent=(0, 1, 0, 1)`` with the ``cm.coolwarm`` colormap.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # ``LocatableAxes`` was removed from axes_grid1 in newer matplotlib;
    # ``mpl_axes.Axes`` provides the same ``axis[...]`` toggle interface.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    Z = np.flipud(Z)

    divider = SubplotDivider(fig2, 1, 1, 1, aspect=True)

    # axes for image
    ax = Axes(fig2, divider.get_position())

    # axes for colorbar
    ax_cb = Axes(fig2, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]

    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    # Column 1 is the padding, so the colorbar goes in column 2.
    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    ax_cb.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig2.add_axes(ax)
    fig2.add_axes(ax_cb)

    ax_cb.axis["left"].toggle(all=False)
    ax_cb.axis["right"].toggle(ticks=True)

    im = ax.imshow(Z,
                   cmap=cm.coolwarm,
                   extent=(0, 1, 0, 1),
                   interpolation="nearest")
    # Target fig2 explicitly: ``plt.colorbar`` operates on ``plt.gcf()``,
    # which is not necessarily fig2.
    fig2.colorbar(im, cax=ax_cb)
    plt.setp(ax_cb.get_yticklabels(), visible=False)

    # Best-effort window placement for fig2's own window (not whatever
    # figure happens to be current).  ``geometry().getRect()`` and
    # ``setGeometry`` appear to be Qt widget methods; on other backends
    # (Tk, inline, non-interactive) they do not exist, so skip rather than
    # raise AttributeError.
    mngr = getattr(fig2.canvas, "manager", None)
    window = getattr(mngr, "window", None)
    if window is not None and hasattr(window, "setGeometry"):
        _, _, width, height = window.geometry().getRect()
        window.setGeometry(width + 200, 100, width, height)
def plot2dHeatMap(fig, x, y, z):
    '''z is a 2d grid; x and y are implicit linspaces

    Draws ``z`` (flipped vertically with ``np.flipud``) into ``fig`` over the
    unit square ``extent=(0, 1, 0, 1)`` with the ``cm.coolwarm`` colormap and
    a fixed-width colorbar.  A ``SubplotDivider`` covering the whole figure
    positions the image, 0.05 inch of padding, then a 0.2 inch colorbar.

    ``x`` and ``y`` are accepted for interface compatibility but are not
    used: the axes always span 0..1 in both directions.

    Returns the image axes.
    '''
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    # ``LocatableAxes`` was removed from axes_grid1 in newer matplotlib;
    # ``mpl_axes.Axes`` provides the same ``axis[...]`` toggle interface.
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    z = np.flipud(z)

    divider = SubplotDivider(fig, 1, 1, 1, aspect=True)

    # axes for image
    ax = Axes(fig, divider.get_position())

    # axes for colorbar
    ax_cb = Axes(fig, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]

    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    # Column 1 is the padding, so the colorbar goes in column 2.
    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    ax_cb.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig.add_axes(ax)
    fig.add_axes(ax_cb)

    ax_cb.axis["left"].toggle(all=False)
    ax_cb.axis["right"].toggle(ticks=True)

    im = ax.imshow(z,
                   cmap=cm.coolwarm,
                   extent=(0, 1, 0, 1),
                   interpolation="nearest")
    # Target ``fig`` explicitly: ``plt.colorbar`` operates on ``plt.gcf()``,
    # which is not necessarily ``fig``.
    fig.colorbar(im, cax=ax_cb)
    plt.setp(ax_cb.get_yticklabels(), visible=False)

    return ax
Beispiel #5
0
def add_cbar(mappable, ax):
    """
    Append colorbar to axes
    Copied from DaViTPy: https://github.com/vtsuperdarn/davitpy/blob/1b578ea2491888e3d97d6e0a8bc6d8cc7c9211fb/davitpy/utils/plotUtils.py#L674

    The colorbar axes is placed to the right of ``ax`` (0.1 inch of padding,
    0.2 inch wide) by a ``SubplotDivider`` built from ``ax``'s subplot spec;
    ``ax`` itself is re-positioned by the same divider.

    :param mappable: mappable to draw the colorbar for (e.g. from ``imshow``)
    :param ax: axes created as a subplot
    :return: the colorbar axes
    :raises ValueError: if ``ax`` is not positioned by a subplot spec
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes
    import matplotlib.pyplot as plt

    fig1 = ax.get_figure()

    # ``ax.get_geometry()`` was deprecated (3.4) and removed (3.6); the subplot
    # spec works across versions and also preserves multi-cell spans.
    get_spec = getattr(ax, "get_subplotspec", None)
    subplotspec = get_spec() if get_spec is not None else None
    if subplotspec is None:
        raise ValueError("add_cbar requires an axes created as a subplot")
    divider = SubplotDivider(fig1, subplotspec, aspect=True)

    # axes for colorbar
    cbax = Axes(fig1, divider.get_position())

    h = [
        Size.AxesX(ax),  # main axes
        Size.Fixed(0.1),  # padding, 0.1 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ]
    v = [Size.AxesY(ax)]

    divider.set_horizontal(h)
    divider.set_vertical(v)

    # Column 1 is the padding, so the colorbar goes in column 2.
    ax.set_axes_locator(divider.new_locator(nx=0, ny=0))
    cbax.set_axes_locator(divider.new_locator(nx=2, ny=0))

    fig1.add_axes(cbax)

    # Only the right side of the colorbar axes shows ticklabels and label.
    cbax.axis["left"].toggle(all=False)
    cbax.axis["top"].toggle(all=False)
    cbax.axis["bottom"].toggle(all=False)
    cbax.axis["right"].toggle(ticklabels=True, label=True)

    plt.colorbar(mappable, cax=cbax)

    return cbax
def demo_locatable_axes_hard(fig):
    """Show the demo image with a 0.2 inch colorbar laid out by a divider.

    One ``SubplotDivider`` in cell 2 (top-right) of a 2x2 grid controls two
    axes added directly to ``fig``.  Its columns are the image width, 0.05
    inch of padding, and the 0.2 inch colorbar; its single row tracks the
    image height.
    """
    from mpl_toolkits.axes_grid1 import SubplotDivider, Size
    from mpl_toolkits.axes_grid1.mpl_axes import Axes

    divider = SubplotDivider(fig, 2, 2, 2, aspect=True)

    image_ax = fig.add_axes(divider.get_position(), axes_class=Axes)
    # Figure.add_axes could otherwise mistake a second axes at the same
    # rectangle for the first one; a distinct label keeps them separate.
    cbar_ax = fig.add_axes(divider.get_position(), axes_class=Axes, label="cb")

    divider.set_horizontal([
        Size.AxesX(image_ax),  # width follows the image axes
        Size.Fixed(0.05),  # padding, 0.05 inch
        Size.Fixed(0.2),  # colorbar, 0.2 inch
    ])
    divider.set_vertical([Size.AxesY(image_ax)])

    # Column 1 is the padding, so the colorbar goes in column 2.
    for target, column in ((image_ax, 0), (cbar_ax, 2)):
        target.set_axes_locator(divider.new_locator(nx=column, ny=0))

    # Colorbar axis: nothing on the left, ticks on the right.
    cbar_ax.axis["left"].toggle(all=False)
    cbar_ax.axis["right"].toggle(ticks=True)

    image, extent = get_demo_image()

    mappable = image_ax.imshow(image, extent=extent)
    plt.colorbar(mappable, cax=cbar_ax)
    cbar_ax.yaxis.set_tick_params(labelright=False)
Beispiel #7
0
def view_layers(logdir, mode=0, ppc=20):
    '''
    Interactively browse per-layer data stored in pickled run logs.

    Builds a matplotlib figure with one subplot row per layer (each split by
    ``_divide_axes_grid``) and an ipywidgets control panel (run log, pattern,
    colormap, value range, step index, plus loss/figure readouts).  The
    control values are passed to ``_draw_layers`` via
    ``widgets.interactive_output`` and the panel is shown with ``display``.

    :param logdir: path to log directory that contains pickled run logs
    :param mode: viewing mode index. Must be an int between 0 and 2
        0: limits the viewing to feedforward information only (weights, biases, net_input, output)
        1: same as 0, but also includes gradient information (gweights, gbiases, gnet_input, goutput)
        2: same as 1, but also includes cumulative gradient information
    :param ppc: forwarded to ``_make_figure`` (presumably pixels per cell —
        TODO confirm)
    :return: None
    '''
    plt.ion()
    # get runlog filenames and paths
    FILENAMES, RUNLOG_PATHS = [sorted(l) for l in list_pickles(logdir)]

    # get testing epochs and losses data
    EPOCHS, LOSSES, LOSS_SUMS = get_data_by_key(
        runlog_path=RUNLOG_PATHS[0], keys=['enum', 'loss',
                                           'loss_sum']).values()

    # get layer names and layer dims to set up figure
    layer_names = get_layer_names(runlog_path=RUNLOG_PATHS[0])
    layer_names.reverse()
    layer_dims = get_layer_dims(runlog_path=RUNLOG_PATHS[0],
                                layer_names=layer_names)

    # set up and make figure
    figure = _make_figure(layer_dims=layer_dims,
                          mode=mode,
                          ppc=ppc,
                          dpi=96,
                          fig_title='view_layers: ' + logdir)

    num_layers = len(layer_names)
    # Only the first entry of the reversed layer list gets the target display.
    disp_targs = [True] + [False] * (num_layers - 1)

    # Row height ratios are identical for every subplot row: each layer's
    # first dimension plus fixed headroom.  Loop-invariant, so computed once.
    row_height_ratios = [dim[0] + 1.8 for dim in layer_dims.values()]

    axes_dicts = []
    for i, (layer_name, disp_targ) in enumerate(zip(layer_names, disp_targs)):
        sp_divider = SubplotDivider(figure,
                                    num_layers,
                                    1,
                                    i + 1,
                                    aspect=True,
                                    anchor='NW')
        # NOTE: writes private matplotlib attributes to give rows unequal
        # heights; each divider gets its own copy of the ratio list.
        sp_divider._subplotspec._gridspec._row_height_ratios = list(
            row_height_ratios)
        axes_dicts.append(
            _divide_axes_grid(mpl_figure=figure,
                              divider=sp_divider,
                              layer_name=layer_name.upper().replace('_', ' '),
                              inp_size=layer_dims[layer_name][1],
                              layer_size=layer_dims[layer_name][0],
                              mode=mode,
                              target=disp_targ))
    plt.tight_layout()

    _widget_layout = widgets.Layout(width='100%')

    run_widget = widgets.Dropdown(options=dict(zip(FILENAMES, RUNLOG_PATHS)),
                                  description='Run log: ',
                                  value=RUNLOG_PATHS[0],
                                  layout=_widget_layout)

    cmap_widget = widgets.Dropdown(options=sorted([
        'BrBG', 'bwr', 'coolwarm', 'PiYG', 'PRGn', 'PuOr', 'RdBu', 'RdGy',
        'RdYlBu', 'RdYlGn', 'seismic'
    ]),
                                   description='Colors: ',
                                   value='coolwarm',
                                   disabled=False,
                                   layout=_widget_layout)

    vrange_widget = widgets.FloatSlider(value=1.0,
                                        min=0,
                                        max=8,
                                        step=.1,
                                        description='V-range: ',
                                        continuous_update=False,
                                        layout=_widget_layout)

    step_widget = widgets.IntSlider(value=0,
                                    min=0,
                                    max=len(EPOCHS) - 1,
                                    step=1,
                                    description='Step index: ',
                                    continuous_update=False,
                                    layout=_widget_layout)

    # Pattern dropdown maps each option label to its index.
    pattern_options = get_pattern_options(runlog_path=RUNLOG_PATHS[0],
                                          tind=step_widget.value)
    options_map = {option: i for i, option in enumerate(pattern_options)}
    pattern_widget = widgets.Dropdown(options=options_map,
                                      value=0,
                                      description='Pattern: ',
                                      disabled=False,
                                      layout=_widget_layout)

    loss_observer = LossDataObsever(
        epoch_list=EPOCHS,
        loss_list=LOSSES,
        loss_sum_list=LOSS_SUMS,
        tind=step_widget.value,
        pind=pattern_widget.value,
    )

    fig_observer = FigureObserver(mpl_figure=figure)

    step_widget.observe(handler=loss_observer.on_epoch_change, names='value')
    pattern_widget.observe(handler=loss_observer.on_pattern_change,
                           names='value')

    def on_runlog_change(change):
        # Reload epoch/loss data for the newly selected run log and reset the
        # step and pattern selections.
        # NOTE(review): pattern options are built from the first run log only
        # and are not refreshed here — confirm all run logs share them.
        if change['type'] == 'change' and change['name'] == 'value':
            newEPOCHS, newLOSSES, newLOSS_SUMS = get_data_by_key(
                runlog_path=change['new'], keys=['enum', 'loss',
                                                 'loss_sum']).values()
            step_widget.max = len(newEPOCHS) - 1
            step_widget.value = 0
            pattern_widget.value = 0
            loss_observer.new_runlog(newEPOCHS, newLOSSES, newLOSS_SUMS)

    run_widget.observe(on_runlog_change)

    controls_dict = dict(
        runlog_path=run_widget,
        img_dicts=widgets.fixed(axes_dicts),
        layer_names=widgets.fixed(layer_names),
        colormap=cmap_widget,
        vrange=vrange_widget,
        tind=step_widget,
        pind=pattern_widget,
    )

    row_layout = widgets.Layout(display='flex',
                                flex_flow='row',
                                justify_content='center')

    stretch_layout = widgets.Layout(display='flex',
                                    flex_flow='row',
                                    justify_content='space-around')

    control_panel_rows = [
        widgets.Box(
            children=[controls_dict['runlog_path'], controls_dict['pind']],
            layout=row_layout),
        widgets.Box(
            children=[controls_dict['colormap'], controls_dict['vrange']],
            layout=row_layout),
        widgets.Box(children=[controls_dict['tind']], layout=row_layout),
        widgets.Box(children=[
            loss_observer.epoch_widget, loss_observer.loss_sum_widget,
            loss_observer.loss_widget, fig_observer.widget
        ],
                    layout=stretch_layout)
    ]

    controls_panel = widgets.Box(children=control_panel_rows,
                                 layout=widgets.Layout(display='flex',
                                                       flex_flow='column',
                                                       padding='5px',
                                                       border='ridge 1px',
                                                       align_items='stretch',
                                                       width='100%'))

    # Re-run _draw_layers whenever any control in controls_dict changes.
    widgets.interactive_output(f=_draw_layers, controls=controls_dict)
    display(controls_panel)