def test_closing_fig(block=False):
    """
    Test that closing a figure is noticed, so that selecting the same subplot
    name again has to recreate the subplot afterwards.

    :param block: Forwarded to plt.show(block=...).
    """
    subplot_name = 'fdsfsd'
    x = np.linspace(0, 10, 100)

    fig = plt.figure()
    select_subplot(subplot_name)
    plt.plot(np.sin(x))
    plt.close(fig)  # Close the figure holding the subplot ...

    select_subplot(subplot_name)  # ... so selecting the name again must build a new subplot.
    plt.plot(np.cos(x))
    plt.show(block=block)
def compare_learning_curves(records: Sequence[ExperimentRecord], show_now=True):
    """
    Plot the train and test error curves of several experiment records on one new figure.

    Each record gets its own colour. Its train curve is dashed and semi-transparent (alpha 0.7),
    and its test curve is solid. Legend labels are 'train: ' / 'test: ' followed by the
    arguments that differ between records.

    :param records: The experiment records to compare. Each record's result is indexed as
        result[:, 'epoch'], result[:, 'train_error'] and result[:, 'test_error'].
    :param show_now: If True, call plt.show() at the end.
    """
    _, argdiffs = separate_common_args(records, as_dicts=True, only_shared_argdiffs=False)
    plt.figure()
    ax = select_subplot(1)
    color_cycle = get_lines_color_cycle()
    # Largest per-curve minimum error over all plotted curves, used to pick the y-range.
    max_min_error = None
    for rec, ad, c in zip(records, argdiffs, color_cycle):
        result = rec.get_result()
        # Raw strings: '\l' and '\e' are not valid escape sequences.
        arg_label = dict_to_str(ad).replace('lambdas', r'$\lambda$').replace('epsilon', r'$\epsilon$')
        for subset in ('train_error', 'test_error'):
            is_train = subset == 'train_error'
            ax.plot(
                result[:, 'epoch'],
                result[:, subset],
                label=('train: ' if is_train else 'test: ') + arg_label,
                linestyle='--' if is_train else '-',
                alpha=0.7 if is_train else 1,
                color=c)
            curve_min = min(result[:, subset])
            max_min_error = curve_min if max_min_error is None else max(max_min_error, curve_min)
    if max_min_error is not None:
        ax.legend()
        # Use the largest minimum over ALL curves (not just the last one) so no curve's minimum is cut off.
        ax.set_ybound(0, max(10, max_min_error * 1.5))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Classification Error')
    # Explicit True: ax.grid() with no arguments toggles visibility.
    ax.grid(True)
    if show_now:
        plt.show()
def use_dbplot_axis(name, fig=None, layout=None, clear=False):
    """
    Select the subplot identified by ``name`` on the dbplot figure ``fig``.

    :param name: Subplot identifier, forwarded to select_subplot.
    :param fig: Name of the dbplot figure (None for the default one).
    :param layout: Subplot layout; when None, the module-level _default_layout is used.
    :param clear: If True, clear the axis before returning it.
    :return: The axis returned by select_subplot.
    """
    target_figure = _get_dbplot_plot_object(fig).figure
    chosen_layout = layout if layout is not None else _default_layout
    selected_ax = select_subplot(name, fig=target_figure, layout=chosen_layout)
    if clear:
        selected_ax.clear()
    return selected_ax
def test_positioning_plots(block=False):
    """
    Draw a line plot of 10 random samples at each of three explicit grid positions.

    :param block: Forwarded to plt.show(block=...).
    """
    for grid_position in [(0, 1), (1, 1), (1, 0)]:
        select_subplot(position=grid_position)
        plt.plot(np.random.randn(10))
    plt.show(block=block)
def test_expanding_subplots(block=False):
    """
    Exercise named subplots.

    Draws on a first named subplot, then adds a second one showing an image,
    then re-selects the first name and draws on it again.

    :param block: Forwarded to plt.show(block=...).
    """
    first_name, second_name = 'agfdsfgdg', 'dsxfdsgf'
    x = np.linspace(0, 10, 100)

    select_subplot(first_name)
    plt.plot(np.sin(x))

    select_subplot(second_name)
    plt.imshow(np.random.randn(10, 10))

    select_subplot(first_name)
    plt.plot(np.cos(x))

    plt.show(block=block)
def compare_learning_curves_new(records: Sequence[ExperimentRecord], show_now=True):
    """
    Plot the init/neg train and test error curves of several experiment records on one new figure.

    Each record gets its own colour. Test curves are opaque and train curves are drawn at alpha 0.5.
    'init' curves are solid and 'neg' curves dotted. Legend labels are '<train|test>-<init|neg>: '
    followed by the arguments that differ between records.

    :param records: The experiment records to compare. Each record's result is indexed as
        result[:, 'epoch'] and result[:, <subset>] for the four subsets listed below.
    :param show_now: If True, call plt.show() at the end.
    """
    _, argdiffs = separate_common_args(records, as_dicts=True, only_shared_argdiffs=False)
    plt.figure()
    ax = select_subplot(1)
    color_cycle = get_lines_color_cycle()
    subsets = ('test_init_error', 'test_neg_error', 'train_init_error', 'train_neg_error')
    # (linestyle, alpha) for each entry of `subsets`, in the same order.
    styles = (('-', 1), (':', 1), ('-', .5), (':', .5))
    maxminscore = 0
    plotted_any = False
    for rec, ad, c in zip(records, argdiffs, color_cycle):
        result = rec.get_result()
        # Raw strings: '\l' and '\e' are not valid escape sequences.
        arg_label = dict_to_str(ad).replace('lambdas', r'$\lambda$').replace('epsilon', r'$\epsilon$')
        for subset, (linestyle, alpha) in izip_equal(subsets, styles):
            curve_label = (
                ('train' if 'train' in subset else 'test') + '-'
                + ('init' if 'init' in subset else 'neg') + ': ' + arg_label)
            ax.plot(result[:, 'epoch'], result[:, subset], label=curve_label,
                    linestyle=linestyle, alpha=alpha, color=c)
            maxminscore = max(maxminscore, min(result[:, subset]))
            plotted_any = True
    if plotted_any:
        # Legend and y-range only need to be set once, after all curves exist.
        ax.legend()
        ax.set_ybound(0, max(10, maxminscore * 1.5))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Classification Error')
    # Explicit True: ax.grid() with no arguments toggles visibility.
    ax.grid(True)
    if show_now:
        plt.show()
def compare(records: Sequence[ExperimentRecord], show_now=True):
    """
    Plot the train and test error curves of several experiment records on subplot 1 of the current figure.

    Each record gets its own colour. Its train curve is dashed and semi-transparent (alpha 0.7),
    and its test curve is solid. Both curves are labelled with the arguments that differ between records.

    :param records: The experiment records to compare. Each record's result is indexed as
        result[:, 'epoch'], result[:, 'train_error'] and result[:, 'test_error'].
    :param show_now: If True, call plt.show() at the end.
    """
    _, argdiffs = separate_common_args(records, as_dicts=True, only_shared_argdiffs=False)
    ax = select_subplot(1)
    color_cycle = get_lines_color_cycle()
    # Largest per-curve minimum error over all plotted curves, used to pick the y-range.
    max_min_error = None
    for rec, ad, c in zip(records, argdiffs, color_cycle):
        result = rec.get_result()
        # Raw strings: '\l' and '\e' are not valid escape sequences.
        arg_label = dict_to_str(ad).replace('lambdas', r'$\lambda$').replace('epsilon', r'$\epsilon$')
        for subset in ('train_error', 'test_error'):
            is_train = subset == 'train_error'
            ax.plot(result[:, 'epoch'], result[:, subset],
                    label=arg_label,
                    linestyle='--' if is_train else '-',
                    alpha=0.7 if is_train else 1,
                    color=c)
            curve_min = min(result[:, subset])
            max_min_error = curve_min if max_min_error is None else max(max_min_error, curve_min)
    if max_min_error is not None:
        ax.legend()
        # Use the largest minimum over ALL curves (not just the last one) so no curve's minimum is cut off.
        ax.set_ybound(0, max(10, max_min_error * 1.5))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Classification Error')
    # Explicit True: ax.grid() with no arguments toggles visibility.
    ax.grid(True)
    if show_now:
        plt.show()
def dbplot(data, name=None, plot_type=None, axis=None, plot_mode='live', draw_now=True, hang=False, title=None,
           fig=None, xlabel=None, ylabel=None, draw_every=None, layout=None, legend=None, grid=False,
           wait_for_display_sec=0, cornertext=None):
    """
    Plot arbitrary data and continue execution.  This program tries to figure out what type of plot to use.

    :param data: Any data.  Hopefully, we at dbplot will be able to figure out a plot for it.
    :param name: A name uniquely identifying this plot.
    :param plot_type: A specialized constructor to be used the first time when plotting.  You can also pass a
        string (a key of PLOT_CONSTRUCTORS) to give hints as to what kind of plot you want (can resolve cases where
        the given data could be plotted in multiple ways):
        'line': Plots a line plot
        'img': An image plot
        'colour': A colour image plot
        'pic': A picture (no scale bars, axis labels, etc).
    :param axis: A string identifying which axis to plot on.  By default, it is the same as "name".  Only use this
        argument if you intend to make multiple dbplots share the same axis.  An Axes or SubplotSpec object is
        also accepted.
    :param plot_mode: Influences how the data should be used to choose the plot type:
        'live': Best for 'live' plots that you intend to update as new data arrives
        'static': Best for 'static' plots, that you do not intend to update
        'image': Try to represent the plot as an image
    :param draw_now: Draw the plot now (you may choose False if you're going to add another plot immediately after
        and don't want to have to draw this one again).
    :param hang: Hang on the plot (wait for it to be closed before continuing)
    :param title: Title of the plot (will default to name if not included)
    :param fig: Name of the figure - use this when you want to create multiple figures.  On the first call you may
        instead pass a matplotlib Figure, which then becomes the default (unnamed) dbplot figure.
    :param xlabel: If given, x-label set on the axis when this plot is first created.
    :param ylabel: If given, y-label set on the axis when this plot is first created.
    :param draw_every: If given, the figure is redrawn on the first call and then on every draw_every-th call
        (with draw_now) for this (fig, name).
    :param layout: Layout passed to select_subplot when a new named axis is created (default: _default_layout).
    :param legend: If given, labels passed to the axis legend.
    :param grid: Turn the grid on (applied when this plot is first created).
    :param wait_for_display_sec: In server mode, you can choose to wait maximally wait_for_display_sec seconds before
        this call returns.  In case plotting is finished earlier, the call returns earlier.  Setting
        wait_for_display_sec to a negative number will cause the call to block until the plot has been displayed.
    :param cornertext: If given, text annotated near the top-left corner of the figure; later calls update it.
    :return: The axis holding this plot (None when the call is redirected to the plotting server).
    """
    if is_server_plotting_on():
        # Redirect the function call to the plotting server.  The flag gets turned on in a configuration file.  It is
        # turned off when this file is run ON the plotting server, from the first line in plotting_server.py
        # NOTE: locals() must be captured before any other local variable is created.
        arg_locals = locals().copy()
        from artemis.remote.plotting.plotting_client import dbplot_remotely
        dbplot_remotely(arg_locals=arg_locals)
        return

    if isinstance(fig, plt.Figure):
        assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)"
        _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={})
        fig = None
    elif fig not in _DBPLOT_FIGURES or not plt.fignum_exists(_DBPLOT_FIGURES[fig].figure.number):
        # Second condition handles closed figures: a fresh figure replaces them.
        _DBPLOT_FIGURES[fig] = _PlotWindow(figure=_make_dbplot_figure(), subplots=OrderedDict(), axes={})
        if fig is not None:
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the window title belongs to the
            # figure manager, which is absent for canvases without a window.
            manager = getattr(_DBPLOT_FIGURES[fig].figure.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(fig)

    plot_window = _DBPLOT_FIGURES[fig]
    suplot_dict = plot_window.subplots

    if axis is None:
        axis = name

    if name not in suplot_dict:
        # Build the plot object for this name.
        if isinstance(plot_type, str):
            plot = PLOT_CONSTRUCTORS[plot_type]()
        elif plot_type is None:
            plot = get_plot_from_data(data, mode=plot_mode)
        else:
            assert hasattr(plot_type, "__call__")
            plot = plot_type()

        # Resolve which axis to draw on.
        if isinstance(axis, SubplotSpec):
            axis = plt.subplot(axis)
        if isinstance(axis, Axes):
            ax = axis
            ax_name = str(axis)
        elif isinstance(axis, string_types) or axis is None:
            ax = select_subplot(axis, fig=plot_window.figure, layout=_default_layout if layout is None else layout)
            ax_name = axis
        else:
            raise Exception("Axis specifier must be a string, an Axis object, or a SubplotSpec object. Not {}".format(axis))

        if ax_name not in plot_window.axes:
            # First plot on this axis: title it and register it so other plots can share it.
            ax.set_title(name)
            plot_window.axes[ax_name] = ax

        plot_window.subplots[name] = _Subplot(axis=plot_window.axes[ax_name], plot_object=plot)
        plt.sca(plot_window.axes[ax_name])
        if xlabel is not None:
            plot_window.subplots[name].axis.set_xlabel(xlabel)
        if ylabel is not None:
            plot_window.subplots[name].axis.set_ylabel(ylabel)
        if draw_every is not None:
            _draw_counters[fig, name] = -1
        if grid:
            plt.grid()

    # Update the relevant data and plot it.
    plot = plot_window.subplots[name].plot_object
    plot.update(data)
    plot.plot()

    if cornertext is not None:
        if not hasattr(plot_window.figure, '__cornertext'):
            plot_window.figure.__cornertext = next(iter(plot_window.subplots.values())).axis.annotate(
                cornertext, xy=(0, 0), xytext=(0.01, 0.98), textcoords='figure fraction')
        else:
            plot_window.figure.__cornertext.set_text(cornertext)
    if title is not None:
        plot_window.subplots[name].axis.set_title(title)
    if legend is not None:
        plot_window.subplots[name].axis.legend(legend, loc='best', framealpha=0.5)

    if draw_now and not _hold_plots:
        if draw_every is not None:
            if (fig, name) not in _draw_counters:
                # draw_every was not given when this plot was created; start counting now instead of raising KeyError.
                _draw_counters[fig, name] = -1
            _draw_counters[fig, name] += 1
            if _draw_counters[fig, name] % draw_every != 0:
                return plot_window.subplots[name].axis
        if hang:
            plt.figure(plot_window.figure.number)
            plt.show()
        else:
            redraw_figure(plot_window.figure)
    return plot_window.subplots[name].axis
def dbplot(data, name=None, plot_type=None, axis=None, plot_mode='live', draw_now=True, hang=False, title=None,
           fig=None, xlabel=None, ylabel=None, draw_every=None, layout=None, legend=None, grid=False,
           wait_for_display_sec=0, cornertext=None, reset_color_cycle=False):
    """
    Plot arbitrary data and continue execution.  This program tries to figure out what type of plot to use.

    :param data: Any data.  Hopefully, we at dbplot will be able to figure out a plot for it.  PyTorch tensors are
        converted to numpy arrays first.
    :param name: A name uniquely identifying this plot.
    :param Union[Callable[[],LinePlot],str,Tuple[Callable, Dict]] plot_type : A specialized constructor to be used the
        first time when plotting.  Several predefined constructors are defined in the DBPlotTypes class - you can pass
        those.  For back-compatibility you can also pass a string matching the name of one of the fields in the
        DBPlotTypes class.
        DBPlotTypes.LINE: Plots a line plot
        DBPlotTypes.IMG: An image plot
        DBPlotTypes.COLOUR: A colour image plot
        DBPlotTypes.PIC: A picture (no scale bars, axis labels, etc)
        You can also pass a tuple of (constructor_or_name, keyword_args) where keyword_args is a dict of arguments to
        the plot constructor.
    :param axis: A string identifying which axis to plot on.  By default, it is the same as "name".  Only use this
        argument if you intend to make multiple dbplots share the same axis.  An Axes or SubplotSpec object is
        also accepted.
    :param plot_mode: Influences how the data should be used to choose the plot type:
        'live': Best for 'live' plots that you intend to update as new data arrives
        'static': Best for 'static' plots, that you do not intend to update
        'image': Try to represent the plot as an image
    :param draw_now: Draw the plot now (you may choose False if you're going to add another plot immediately after
        and don't want to have to draw this one again).
    :param hang: Hang on the plot (wait for it to be closed before continuing)
    :param title: Title of the plot (will default to name if not included)
    :param fig: Name of the figure - use this when you want to create multiple figures.
    :param xlabel: If given, x-label set on the axis when this plot is first created.
    :param ylabel: If given, y-label set on the axis when this plot is first created.
    :param draw_every: If given, a Checkpoints(draw_every) object created with the plot decides which calls draw.
    :param layout: Layout passed to select_subplot when a new named axis is created (default: _default_layout).
    :param legend: If given, labels passed to the axis legend.
    :param grid: Turn the grid on (applied when this plot is first created).
    :param wait_for_display_sec: In server mode, you can choose to wait maximally wait_for_display_sec seconds before
        this call returns.  In case plotting is finished earlier, the call returns earlier.  Setting
        wait_for_display_sec to a negative number will cause the call to block until the plot has been displayed.
    :param cornertext: If given, text annotated near the top-left corner of the figure; later calls update it.
    :param reset_color_cycle: If True, reset the property (colour) cycle of the axis before updating the data.
    :return: The axis holding this plot (None when the call is redirected to the plotting server).
    """
    if is_server_plotting_on():
        # Redirect the function call to the plotting server.  The flag gets turned on in a configuration file.  It is
        # turned off when this file is run ON the plotting server, from the first line in plotting_server.py
        # NOTE: locals() must be captured before any other local variable is created.
        arg_locals = locals().copy()
        from artemis.remote.plotting.plotting_client import dbplot_remotely
        dbplot_remotely(arg_locals=arg_locals)
        return

    if data.__class__.__module__ == 'torch' and data.__class__.__name__ == 'Tensor':
        # Detect torch tensors by class name so torch need not be imported here.
        data = data.detach().cpu().numpy()

    plot_object = _get_dbplot_plot_object(fig)  # type: _PlotWindow
    suplot_dict = plot_object.subplots

    if axis is None:
        axis = name

    if name not in suplot_dict:  # Initialize new axis
        if isinstance(plot_type, str):
            plot = DBPlotTypes.from_string(plot_type)()
        elif isinstance(plot_type, tuple):
            # The first element may be a plot type name OR a constructor (as documented); the previous assertion
            # only allowed strings, which made the callable branch below unreachable.
            assert len(plot_type) == 2 \
                and (isinstance(plot_type[0], str) or callable(plot_type[0])) \
                and isinstance(plot_type[1], dict), \
                'If you specify a tuple for plot_type, we expect (name_or_constructor, arg_dict). Got: {}'.format(plot_type)
            plot_type_name, plot_type_args = plot_type
            if isinstance(plot_type_name, str):
                plot = DBPlotTypes.from_string(plot_type_name)(**plot_type_args)
            elif callable(plot_type_name):
                plot = plot_type_name(**plot_type_args)
            else:
                raise Exception('The first argument of the plot type tuple must be a plot type name or a callable plot type constructor.')
        elif plot_type is None:
            plot = get_plot_from_data(data, mode=plot_mode)
        else:
            assert hasattr(plot_type, "__call__")
            plot = plot_type()

        # Resolve which axis to draw on.
        if isinstance(axis, SubplotSpec):
            axis = plt.subplot(axis)
        if isinstance(axis, Axes):
            ax = axis
            ax_name = str(axis)
        elif isinstance(axis, string_types) or axis is None:
            ax = select_subplot(axis, fig=plot_object.figure, layout=_default_layout if layout is None else layout)
            ax_name = axis
        else:
            raise Exception("Axis specifier must be a string, an Axis object, or a SubplotSpec object. Not {}".format(axis))

        if ax_name not in plot_object.axes:
            # First plot on this axis: title it and register it so other plots can share it.
            ax.set_title(name)
            plot_object.axes[ax_name] = ax

        plot_object.subplots[name] = _Subplot(axis=plot_object.axes[ax_name], plot_object=plot)
        plt.sca(plot_object.axes[ax_name])
        if xlabel is not None:
            plot_object.subplots[name].axis.set_xlabel(xlabel)
        if ylabel is not None:
            plot_object.subplots[name].axis.set_ylabel(ylabel)
        if draw_every is not None:
            _draw_counters[fig, name] = Checkpoints(draw_every)
        if grid:
            plt.grid()

    plot = plot_object.subplots[name].plot_object
    if reset_color_cycle:
        # Axes.set_color_cycle no longer exists in Matplotlib; set_prop_cycle(None) resets to the rcParams default.
        use_dbplot_axis(axis, fig=fig, clear=False).set_prop_cycle(None)

    plot.update(data)

    # Update Labels...
    if cornertext is not None:
        if not hasattr(plot_object.figure, '__cornertext'):
            plot_object.figure.__cornertext = next(iter(plot_object.subplots.values())).axis.annotate(
                cornertext, xy=(0, 0), xytext=(0.01, 0.98), textcoords='figure fraction')
        else:
            plot_object.figure.__cornertext.set_text(cornertext)
    if title is not None:
        plot_object.subplots[name].axis.set_title(title)
    if legend is not None:
        plot_object.subplots[name].axis.legend(legend, loc='best', framealpha=0.5)

    if draw_now and not _hold_plots and (
            draw_every is None or ((fig, name) not in _draw_counters) or _draw_counters[fig, name]()):
        plot.plot()
        display_figure(plot_object.figure, hang=hang)

    return plot_object.subplots[name].axis
def plot_mnist_energy_results(results, x_scale=('flops', 'int-energy')):
    """
    Plot test classification error against computational cost for the Rounding, Sigma-Delta and original networks.

    One subplot is drawn per (x_scale, dataset) pair with a horizontal layout, and the figure is shown at the end.

    :param results: Mapping from experiment name to result.  Entries whose name contains 'lambda' (case-insensitive)
        supply the 'td' and 'round' curves; results['unoptimized'] supplies the original network.  Each result is
        indexed as res[dataset, subset, net_type][measure].
    :param x_scale: One of, or a sequence of, 'flops', 'multadd-flops', 'int-energy', 'float-energy'.
    """
    try:
        string_type = basestring  # Python 2: covers str and unicode
    except NameError:  # Python 3
        string_type = str
    if isinstance(x_scale, string_type):
        x_scale = [x_scale]
    assert all(x in ('flops', 'multadd-flops', 'int-energy', 'float-energy') for x in x_scale)

    subset = 'test'
    # .items() works on both Python 2 and 3 (iteritems is Python-2 only).
    lambda_results = [res for name, res in results.items() if 'lambda' in name.lower()]
    dataset_names = OrderedDict([('mnist', 'MNIST'), ('temp_mnist', 'Temporal MNIST')])
    last_dataset = list(dataset_names)[-1]  # keys()[-1] fails on Python 3
    energy_scales = ('int-energy', 'float-energy')
    yticks = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20])

    plt.figure()
    first_plot = True
    for x_s in x_scale:
        for dataset in dataset_names:
            last_plot = x_s == x_scale[-1] and dataset == last_dataset
            ax = select_subplot((x_s, dataset), layout='h')
            plt.subplots_adjust(bottom=.15, wspace=0)

            flop_measure = 'MFlops' if x_s == 'flops' else 'MFlops-multadd'
            td_flops = [res[dataset, subset, 'td'][flop_measure] for res in lambda_results]
            td_errors = [res[dataset, subset, 'td']['class_error'] for res in lambda_results]
            round_flops = [res[dataset, subset, 'round'][flop_measure] for res in lambda_results]
            round_errors = [res[dataset, subset, 'round']['class_error'] for res in lambda_results]
            truth = results['unoptimized'][dataset, subset, 'truth']
            original_flops = truth['Sparse MFlops'], truth['Dense MFlops']
            original_errors = [truth['class_error']] * 2

            if x_s in energy_scales:
                dtype = 'int' if x_s == 'int-energy' else 'float'
                # Get the energy in nJ: estimate_energy_cost presumably returns uJ for M-op counts
                # (pJ/op * M ops = uJ), and * 1e3 converts uJ -> nJ.  TODO confirm units of estimate_energy_cost.
                td_x = estimate_energy_cost(td_flops, 'add', dtype=dtype, n_bits=32) * 1e3
                round_x = estimate_energy_cost(round_flops, 'add', dtype=dtype, n_bits=32) * 1e3
                original_x = estimate_energy_cost(original_flops, 'mult-add', dtype=dtype, n_bits=32) * 1e3
            else:
                # MFlops * 1e3 -> kOps
                td_x = np.array(td_flops) * 1e3
                round_x = np.array(round_flops) * 1e3
                original_x = np.array(original_flops) * 1e3

            plt.plot(round_x, round_errors, linewidth=2, label='Rounding Network', marker='.', markersize=10)
            # Raw string: '\S' and '\D' are not valid escape sequences.
            plt.plot(td_x, td_errors, label=r'$\Sigma\Delta$ Network', linewidth=2, marker='.', markersize=10)
            plt.plot(original_x, original_errors, label='Original Network', marker='.', markersize=20, linewidth=2)

            if x_s in energy_scales:
                plt.xlabel('nJ/sample')  # pico*mega = micro
                plt.gca().set_xscale('log')
            else:
                plt.xlabel('kOps/sample')
                # Hide every other x tick label (presumably to avoid overlapping labels on the linear axis).
                for label in ax.get_xticklabels()[::2]:
                    label.set_visible(False)

            plt.gca().set_yscale('log')
            plt.title(dataset_names[dataset])
            plt.grid()
            plt.gca().yaxis.set_ticks(yticks)
            plt.gca().set_ylim(1.5, 20)
            if first_plot:
                # Only the leftmost subplot gets a y-label and y tick labels.
                plt.ylabel('Classification Error (%)')
                first_plot = False
                plt.gca().set_yticklabels([str(i) for i in yticks])
            else:
                plt.gca().set_yticklabels(['' for _ in yticks])
            if last_plot:
                plt.legend(loc='upper right', framealpha=0.5, fontsize='medium')
    plt.show()
def dbplot(data, name=None, plot_type=None, axis=None, plot_mode='live', draw_now=True, hang=False, title=None,
           fig=None, xlabel=None, ylabel=None, draw_every=None, layout=None, legend=None, grid=False,
           wait_for_display_sec=0):
    """
    Plot arbitrary data.  This program tries to figure out what type of plot to use.

    :param data: Any data.  Hopefully, we at dbplot will be able to figure out a plot for it.
    :param name: A name uniquely identifying this plot.
    :param plot_type: A specialized constructor to be used the first time when plotting.  You can also pass certain
        strings to give hints as to what kind of plot you want (can resolve cases where the given data could be
        plotted in multiple ways), e.g.:
        'line': Plots a line plot
        'img': An image plot
        'colour': A colour image plot
        'pic': A picture (no scale bars, axis labels, etc).
        The full list of accepted strings is the dict of constructors in the body below.
    :param axis: A string identifying which axis to plot on.  By default, it is the same as "name".  Only use this
        argument if you intend to make multiple dbplots share the same axis.  An Axes or SubplotSpec object is
        also accepted.
    :param plot_mode: Influences how the data should be used to choose the plot type:
        'live': Best for 'live' plots that you intend to update as new data arrives
        'static': Best for 'static' plots, that you do not intend to update
        'image': Try to represent the plot as an image
    :param draw_now: Draw the plot now (you may choose False if you're going to add another plot immediately after
        and don't want to have to draw this one again).
    :param hang: Hang on the plot (wait for it to be closed before continuing)
    :param title: Title of the plot (will default to name if not included)
    :param fig: Name of the figure - use this when you want to create multiple figures.  On the first call you may
        instead pass a matplotlib Figure, which then becomes the default (unnamed) dbplot figure.
    :param xlabel: If given, x-label set on the axis when this plot is first created.
    :param ylabel: If given, y-label set on the axis when this plot is first created.
    :param draw_every: If given, the figure is redrawn on the first call and then on every draw_every-th call
        (with draw_now) for this (fig, name).
    :param layout: Layout passed to select_subplot when a new named axis is created (default: _default_layout).
    :param legend: If given, labels passed to the axis legend.
    :param grid: Turn the grid on (applied when this plot is first created).
    :param wait_for_display_sec: In server mode, you can choose to wait maximally wait_for_display_sec seconds before
        this call returns.  In case plotting is finished earlier, the call returns earlier.  Setting
        wait_for_display_sec to a negative number will cause the call to block until the plot has been displayed.
    :return: The axis holding this plot (None when the call is redirected to the plotting server).
    """
    if is_server_plotting_on():
        # Redirect the function call to the plotting server.  The flag gets turned on in a configuration file.  It is
        # turned off when this file is run ON the plotting server, from the first line in plotting_server.py
        # NOTE: locals() must be captured before any other local variable is created.
        arg_locals = locals().copy()
        # NOTE(review): other versions of dbplot import `dbplot_remotely`; confirm this spelling matches plotting_client.
        from artemis.remote.plotting.plotting_client import dbplot_remotetly
        dbplot_remotetly(arg_locals=arg_locals)
        return

    if isinstance(fig, plt.Figure):
        assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)"
        # Wrap the figure in a _PlotWindow: storing the bare Figure made the `.subplots` access below fail.
        _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={})
        fig = None
    elif fig not in _DBPLOT_FIGURES:
        _DBPLOT_FIGURES[fig] = _PlotWindow(figure=_make_dbplot_figure(), subplots=OrderedDict(), axes={})
        if fig is not None:
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the window title belongs to the
            # figure manager, which is absent for canvases without a window.
            manager = getattr(_DBPLOT_FIGURES[fig].figure.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(fig)

    plot_window = _DBPLOT_FIGURES[fig]
    suplot_dict = plot_window.subplots

    if axis is None:
        axis = name

    if name not in suplot_dict:
        # Build the plot object for this name.
        if isinstance(plot_type, str):
            plot = {
                'line': LinePlot,
                'thick-line': lambda: LinePlot(plot_kwargs={'linewidth': 3}),
                'pos_line': lambda: LinePlot(y_bounds=(0, None), y_bound_extend=(0, 0.05)),
                'bar': BarPlot,
                'img': ImagePlot,
                'colour': lambda: ImagePlot(is_colour_data=True),
                'equal_aspect': lambda: ImagePlot(aspect='equal'),
                'image_history': lambda: MovingImagePlot(),
                'fixed_line_history': lambda: MovingPointPlot(buffer_len=100),
                'pic': lambda: ImagePlot(show_clims=False, aspect='equal'),
                'notice': lambda: TextPlot(max_history=1, horizontal_alignment='center', vertical_alignment='center', size='x-large'),
                'cost': lambda: MovingPointPlot(y_bounds=(0, None), y_bound_extend=(0, 0.05)),
                'percent': lambda: MovingPointPlot(y_bounds=(0, 100)),
                'trajectory': lambda: Moving2DPointPlot(axes_update_mode='expand'),
                'trajectory+': lambda: Moving2DPointPlot(axes_update_mode='expand', x_bounds=(0, None), y_bounds=(0, None)),
                'histogram': lambda: HistogramPlot(edges=np.linspace(-5, 5, 20)),
                'cumhist': lambda: CumulativeLineHistogram(edges=np.linspace(-5, 5, 20)),
            }[plot_type]()
        elif plot_type is None:
            plot = get_plot_from_data(data, mode=plot_mode)
        else:
            assert hasattr(plot_type, "__call__")
            plot = plot_type()

        # Resolve which axis to draw on.
        if isinstance(axis, SubplotSpec):
            axis = plt.subplot(axis)
        if isinstance(axis, Axes):
            ax = axis
            ax_name = str(axis)
        elif isinstance(axis, basestring) or axis is None:
            ax = select_subplot(axis, fig=plot_window.figure, layout=_default_layout if layout is None else layout)
            ax_name = axis
        else:
            raise Exception("Axis specifier must be a string, an Axis object, or a SubplotSpec object. Not {}".format(axis))

        if ax_name not in plot_window.axes:
            # First plot on this axis: title it and register it so other plots can share it.
            ax.set_title(name)
            plot_window.axes[ax_name] = ax

        plot_window.subplots[name] = _Subplot(axis=plot_window.axes[ax_name], plot_object=plot)
        plt.sca(plot_window.axes[ax_name])
        if xlabel is not None:
            plot_window.subplots[name].axis.set_xlabel(xlabel)
        if ylabel is not None:
            plot_window.subplots[name].axis.set_ylabel(ylabel)
        if draw_every is not None:
            _draw_counters[fig, name] = -1
        if grid:
            plt.grid()

    # Update the relevant data and plot it.
    plot = plot_window.subplots[name].plot_object
    plot.update(data)
    plot.plot()
    if title is not None:
        plot_window.subplots[name].axis.set_title(title)
    if legend is not None:
        plot_window.subplots[name].axis.legend(legend, loc='best', framealpha=0.5)

    if draw_now and not _hold_plots:
        if draw_every is not None:
            if (fig, name) not in _draw_counters:
                # draw_every was not given when this plot was created; start counting now instead of raising KeyError.
                _draw_counters[fig, name] = -1
            _draw_counters[fig, name] += 1
            if _draw_counters[fig, name] % draw_every != 0:
                return plot_window.subplots[name].axis
        if hang:
            plt.figure(plot_window.figure.number)
            plt.show()
        else:
            redraw_figure(plot_window.figure)
    return plot_window.subplots[name].axis
def dbplot(data, name=None, plot_type=None, axis=None, plot_mode='live', draw_now=True, hang=False, title=None,
           fig=None, xlabel=None, ylabel=None, draw_every=None, layout=None, legend=None, grid=False,
           wait_for_display_sec=0, cornertext=None, reset_color_cycle=False):
    """
    Plot arbitrary data and continue execution.  This program tries to figure out what type of plot to use.

    :param data: Any data.  Hopefully, we at dbplot will be able to figure out a plot for it.
    :param name: A name uniquely identifying this plot.
    :param plot_type: A specialized constructor to be used the first time when plotting.  You can also pass a
        string (a key of PLOT_CONSTRUCTORS) to give hints as to what kind of plot you want (can resolve cases where
        the given data could be plotted in multiple ways):
        'line': Plots a line plot
        'img': An image plot
        'colour': A colour image plot
        'pic': A picture (no scale bars, axis labels, etc).
        A tuple (name, arg_dict) is also accepted: PLOT_CONSTRUCTORS[name] is called with **arg_dict.
    :param axis: A string identifying which axis to plot on.  By default, it is the same as "name".  Only use this
        argument if you intend to make multiple dbplots share the same axis.  An Axes or SubplotSpec object is
        also accepted.
    :param plot_mode: Influences how the data should be used to choose the plot type:
        'live': Best for 'live' plots that you intend to update as new data arrives
        'static': Best for 'static' plots, that you do not intend to update
        'image': Try to represent the plot as an image
    :param draw_now: Draw the plot now (you may choose False if you're going to add another plot immediately after
        and don't want to have to draw this one again).
    :param hang: Hang on the plot (wait for it to be closed before continuing)
    :param title: Title of the plot (will default to name if not included)
    :param fig: Name of the figure - use this when you want to create multiple figures.  On the first call you may
        instead pass a matplotlib Figure, which then becomes the default (unnamed) dbplot figure.
    :param xlabel: If given, x-label set on the axis when this plot is first created.
    :param ylabel: If given, y-label set on the axis when this plot is first created.
    :param draw_every: If given, a Checkpoints(draw_every) object created with the plot decides which calls draw.
    :param layout: Layout passed to select_subplot when a new named axis is created (default: _default_layout).
    :param legend: If given, labels passed to the axis legend.
    :param grid: Turn the grid on (applied when this plot is first created).
    :param wait_for_display_sec: In server mode, you can choose to wait maximally wait_for_display_sec seconds before
        this call returns.  In case plotting is finished earlier, the call returns earlier.  Setting
        wait_for_display_sec to a negative number will cause the call to block until the plot has been displayed.
    :param cornertext: If given, text annotated near the top-left corner of the figure; later calls update it.
    :param reset_color_cycle: If True, reset the property (colour) cycle of the axis before updating the data.
    :return: The axis holding this plot (None when the call is redirected to the plotting server).
    """
    if is_server_plotting_on():
        # Redirect the function call to the plotting server.  The flag gets turned on in a configuration file.  It is
        # turned off when this file is run ON the plotting server, from the first line in plotting_server.py
        # NOTE: locals() must be captured before any other local variable is created.
        arg_locals = locals().copy()
        from artemis.remote.plotting.plotting_client import dbplot_remotely
        dbplot_remotely(arg_locals=arg_locals)
        return

    if isinstance(fig, plt.Figure):
        assert None not in _DBPLOT_FIGURES, "If you pass a figure, you can only do it on the first call to dbplot (for now)"
        _DBPLOT_FIGURES[None] = _PlotWindow(figure=fig, subplots=OrderedDict(), axes={})
        fig = None
    elif fig not in _DBPLOT_FIGURES or not plt.fignum_exists(_DBPLOT_FIGURES[fig].figure.number):
        # Second condition handles closed figures: a fresh figure replaces them.
        _DBPLOT_FIGURES[fig] = _PlotWindow(figure=_make_dbplot_figure(), subplots=OrderedDict(), axes={})
        if fig is not None:
            # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the window title belongs to the
            # figure manager, which is absent for canvases without a window.
            manager = getattr(_DBPLOT_FIGURES[fig].figure.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(fig)

    plot_window = _DBPLOT_FIGURES[fig]
    suplot_dict = plot_window.subplots

    if axis is None:
        axis = name

    if name not in suplot_dict:  # Initialize new axis
        if isinstance(plot_type, str):
            plot = PLOT_CONSTRUCTORS[plot_type]()
        elif isinstance(plot_type, tuple):
            assert len(plot_type) == 2 and isinstance(plot_type[0], str) and isinstance(plot_type[1], dict), \
                'If you specify a tuple for plot_type, we expect (name, arg_dict). Got: {}'.format(plot_type)
            plot_type_name, plot_type_args = plot_type
            plot = PLOT_CONSTRUCTORS[plot_type_name](**plot_type_args)
        elif plot_type is None:
            plot = get_plot_from_data(data, mode=plot_mode)
        else:
            assert hasattr(plot_type, "__call__")
            plot = plot_type()

        # Resolve which axis to draw on.
        if isinstance(axis, SubplotSpec):
            axis = plt.subplot(axis)
        if isinstance(axis, Axes):
            ax = axis
            ax_name = str(axis)
        elif isinstance(axis, string_types) or axis is None:
            ax = select_subplot(axis, fig=plot_window.figure, layout=_default_layout if layout is None else layout)
            ax_name = axis
        else:
            raise Exception("Axis specifier must be a string, an Axis object, or a SubplotSpec object. Not {}".format(axis))

        if ax_name not in plot_window.axes:
            # First plot on this axis: title it and register it so other plots can share it.
            ax.set_title(name)
            plot_window.axes[ax_name] = ax

        plot_window.subplots[name] = _Subplot(axis=plot_window.axes[ax_name], plot_object=plot)
        plt.sca(plot_window.axes[ax_name])
        if xlabel is not None:
            plot_window.subplots[name].axis.set_xlabel(xlabel)
        if ylabel is not None:
            plot_window.subplots[name].axis.set_ylabel(ylabel)
        if draw_every is not None:
            _draw_counters[fig, name] = Checkpoints(draw_every)
        if grid:
            plt.grid()

    plot = plot_window.subplots[name].plot_object
    if reset_color_cycle:
        # Axes.set_color_cycle no longer exists in Matplotlib; set_prop_cycle(None) resets to the rcParams default.
        get_dbplot_axis(axis_name=axis, fig=fig).set_prop_cycle(None)

    plot.update(data)

    # Update Labels...
    if cornertext is not None:
        if not hasattr(plot_window.figure, '__cornertext'):
            plot_window.figure.__cornertext = next(iter(plot_window.subplots.values())).axis.annotate(
                cornertext, xy=(0, 0), xytext=(0.01, 0.98), textcoords='figure fraction')
        else:
            plot_window.figure.__cornertext.set_text(cornertext)
    if title is not None:
        plot_window.subplots[name].axis.set_title(title)
    if legend is not None:
        plot_window.subplots[name].axis.legend(legend, loc='best', framealpha=0.5)

    if draw_now and not _hold_plots and (
            draw_every is None or ((fig, name) not in _draw_counters) or _draw_counters[fig, name]()):
        plot.plot()
        if hang:
            plt.figure(plot_window.figure.number)
            plt.show()
        else:
            redraw_figure(plot_window.figure)
    return plot_window.subplots[name].axis