def _get_viewer_data(data_source, case_id=None):
    """
    Get the data needed by the N2 viewer as a dictionary.

    Parameters
    ----------
    data_source : <Problem> or <Group> or str
        A Problem or Group or case recorder filename containing the model or model data.
        If the case recorder file from a parallel run has separate metadata, the
        filenames can be specified with a comma, e.g.: case.sql_0,case.sql_meta
    case_id : int or str or None
        Case name or index of case in SQL file.

    Returns
    -------
    dict
        A dictionary containing information about the model for use by the viewer.
        An empty dict is returned (with a warning) if the model is not a Group or if
        a non-root Group is given.

    Raises
    ------
    TypeError
        If data_source is not a Problem, Group or str.
    """
    if isinstance(data_source, Problem):
        root_group = data_source.model

        if not isinstance(root_group, Group):
            issue_warning("The model is not a Group, viewer data is unavailable.")
            return {}

        driver = data_source.driver
        driver_name = driver.__class__.__name__
        driver_type = 'doe' if isinstance(driver, DOEDriver) else 'optimization'

        driver_options = {key: _serialize_single_option(driver.options._dict[key])
                          for key in driver.options}

        if driver_type == 'optimization' and hasattr(driver, 'opt_settings'):
            driver_opt_settings = driver.opt_settings
        else:
            driver_opt_settings = None

    elif isinstance(data_source, Group):
        if not data_source.pathname:  # root group
            root_group = data_source
            driver_name = None
            driver_type = None
            driver_options = None
            driver_opt_settings = None
        else:
            # this function only makes sense when it is at the root
            issue_warning(
                f"Viewer data is not available for sub-Group '{data_source.pathname}'.")
            return {}

    elif isinstance(data_source, str):
        if ',' in data_source:
            filenames = data_source.split(',')
            cr = CaseReader(filenames[0], metadata_filename=filenames[1])
        else:
            cr = CaseReader(data_source)

        data_dict = cr.problem_metadata

        if case_id is not None:
            cases = cr.get_case(case_id)
            print(f"Using source: {cases.source}\nCase: {cases.name}")

            def recurse(children, stack):
                # Walk the recorded tree and attach the recorded value of every input and
                # output to its node under the 'val' key. 'stack' holds the names of the
                # enclosing subsystems so that the dotted pathname can be rebuilt.
                for child in children:
                    child_type = child['type']
                    if child_type == 'subsystem':
                        if child['name'] != '_auto_ivc':
                            stack.append(child['name'])
                            recurse(child['children'], stack)
                            stack.pop()
                    elif child_type in ('input', 'output'):
                        recorded = cases.inputs if child_type == 'input' else cases.outputs
                        if recorded is None:
                            child['val'] = 'N/A'
                        else:
                            path = '.'.join(stack + [child['name']])
                            # A variable may be absent from the case (e.g. filtered out of
                            # the recording); show it as unavailable rather than failing.
                            try:
                                child['val'] = recorded[path]
                            except KeyError:
                                child['val'] = 'N/A'

            recurse(data_dict['tree']['children'], [])

        # Delete the variables key since it's not used in N2
        if 'variables' in data_dict:
            del data_dict['variables']

        # Older recordings might not have this.
        if 'md5_hash' not in data_dict:
            data_dict['md5_hash'] = None

        return data_dict

    else:
        raise TypeError(f"Viewer data is not available for '{data_source}'. "
                        "The source must be a Problem, model or the filename of a recording.")

    data_dict = {}
    data_dict['tree'] = _get_tree_dict(root_group)
    data_dict['md5_hash'] = root_group._generate_md5_hash()

    connections_list = []

    # map of pathnames to index of pathname in list (systems in cycles only)
    sys_idx = {}

    G = root_group.compute_sys_graph(comps_only=True)

    # Associate each component with the index of the strongly connected component (SCC)
    # that contains it.
    strongdict = {}
    for scc_i, strong_comp in enumerate(nx.strongly_connected_components(G)):
        for comp in strong_comp:
            strongdict[comp] = scc_i

        if len(strong_comp) > 1:
            # these IDs are only used when back edges are present
            for name in strong_comp:
                sys_idx[name] = len(sys_idx)

    comp_orders = {name: i for i, name in enumerate(root_group._ordered_comp_name_iter())}

    # Both passes below rely on the same edge ordering (edge indices are stored in
    # 'matrix' and looked up in 'edge_ids'), so materialize the edges once.
    edges = list(G.edges())

    # 1 is added to the indices of all edges in the matrix so that we can use 0 entries to
    # indicate that there is no connection.
    n_comps = len(comp_orders)
    matrix = np.zeros((n_comps, n_comps), dtype=np.int32)
    edge_ids = []
    for edge_i, (src, tgt) in enumerate(edges):
        if strongdict[src] == strongdict[tgt]:
            matrix[comp_orders[src], comp_orders[tgt]] = edge_i + 1  # bump edge index by 1
            edge_ids.append((sys_idx[src], sys_idx[tgt]))
        else:
            edge_ids.append(None)

    for edge_i, (src, tgt) in enumerate(edges):
        conns = G.get_edge_data(src, tgt)['conns']

        if strongdict[src] == strongdict[tgt]:
            start = comp_orders[src]
            end = comp_orders[tgt]
            # get a view here so we can remove this edge from submat temporarily to eliminate
            # an 'if' check inside the nested list comprehension for edges_list
            rem = matrix[start:start + 1, end:end + 1]
            rem[0, 0] = 0

            if end < start:
                start, end = end, start

            # All other in-cycle edges between components whose execution order lies in
            # the [start, end] range of this edge.
            submat = matrix[start:end + 1, start:end + 1]
            nz = submat[submat > 0]

            rem[0, 0] = edge_i + 1  # put removed edge back

            if nz.size > 1:
                nz -= 1  # convert back to correct edge index
                edges_list = [edge_ids[i] for i in nz]
                for vsrc, vtgtlist in conns.items():
                    for vtgt in vtgtlist:
                        connections_list.append({'src': vsrc, 'tgt': vtgt,
                                                 'cycle_arrows': edges_list})
                continue

        for vsrc, vtgtlist in conns.items():
            for vtgt in vtgtlist:
                connections_list.append({'src': vsrc, 'tgt': vtgt})

    data_dict['sys_pathnames_list'] = list(sys_idx)
    data_dict['connections_list'] = connections_list
    data_dict['abs2prom'] = root_group._var_abs2prom

    data_dict['driver'] = {
        'name': driver_name,
        'type': driver_type,
        'options': driver_options,
        'opt_settings': driver_opt_settings
    }
    data_dict['design_vars'] = root_group.get_design_vars(use_prom_ivc=False)
    data_dict['responses'] = root_group.get_responses()

    data_dict['declare_partials_list'] = _get_declare_partials(root_group)

    return data_dict
def _get_viewer_data(data_source, case_id=None):
    """
    Get the data needed by the N2 viewer as a dictionary.

    Parameters
    ----------
    data_source : <Problem> or <Group> or str
        A Problem or Group or case recorder filename containing the model or model data.
        If the case recorder file from a parallel run has separate metadata, the
        filenames can be specified with a comma, e.g.: case.sql_0,case.sql_meta
    case_id : int or str or None
        Case name or index of case in SQL file.

    Returns
    -------
    dict
        A dictionary containing information about the model for use by the viewer.
        An empty dict is returned (with a warning) if the model is not a Group or if
        a non-root Group is given.

    Raises
    ------
    TypeError
        If data_source is not a Problem, Group or str.
    """
    if isinstance(data_source, Problem):
        root_group = data_source.model

        if not isinstance(root_group, Group):
            simple_warning("The model is not a Group, viewer data is unavailable.")
            return {}

        driver = data_source.driver
        driver_name = driver.__class__.__name__
        driver_type = 'doe' if isinstance(driver, DOEDriver) else 'optimization'
        driver_options = {key: _serialize_single_option(driver.options._dict[key])
                          for key in driver.options}

        if driver_type == 'optimization' and hasattr(driver, 'opt_settings'):
            driver_opt_settings = driver.opt_settings
        else:
            driver_opt_settings = None

    elif isinstance(data_source, Group):
        if not data_source.pathname:  # root group
            root_group = data_source
            driver_name = None
            driver_type = None
            driver_options = None
            driver_opt_settings = None
        else:
            # this function only makes sense when it is at the root
            simple_warning(
                f"Viewer data is not available for sub-Group '{data_source.pathname}'.")
            return {}

    elif isinstance(data_source, str):
        if ',' in data_source:
            filenames = data_source.split(',')
            cr = CaseReader(filenames[0], metadata_filename=filenames[1])
        else:
            cr = CaseReader(data_source)

        data_dict = cr.problem_metadata

        if case_id is not None:
            cases = cr.get_case(case_id)
            print(f"Using source: {cases.source}\nCase: {cases.name}")

            def recurse(children, stack):
                # Walk the recorded tree and attach the recorded value of every input and
                # output to its node under the 'value' key. 'stack' holds the names of the
                # enclosing subsystems so that the dotted pathname can be rebuilt.
                for child in children:
                    child_type = child['type']
                    if child_type == 'subsystem':
                        if child['name'] != '_auto_ivc':
                            stack.append(child['name'])
                            recurse(child['children'], stack)
                            stack.pop()
                    elif child_type in ('input', 'output'):
                        recorded = cases.inputs if child_type == 'input' else cases.outputs
                        if recorded is None:
                            child['value'] = 'N/A'
                        else:
                            path = '.'.join(stack + [child['name']])
                            # A variable may be absent from the case (e.g. filtered out of
                            # the recording); show it as unavailable rather than failing.
                            try:
                                child['value'] = recorded[path]
                            except KeyError:
                                child['value'] = 'N/A'

            recurse(data_dict['tree']['children'], [])

        # Delete the variables key since it's not used in N2
        if 'variables' in data_dict:
            del data_dict['variables']

        # Older recordings might not have this.
        if 'md5_hash' not in data_dict:
            data_dict['md5_hash'] = None

        return data_dict

    else:
        raise TypeError(f"Viewer data is not available for '{data_source}'. "
                        "The source must be a Problem, model or the filename of a recording.")

    data_dict = {}
    comp_exec_idx = [0]  # list so pass by ref
    orders = {}
    data_dict['tree'] = _get_tree_dict(root_group, orders, comp_exec_idx)
    data_dict['md5_hash'] = root_group._generate_md5_hash()

    connections_list = []

    sys_pathnames_list = []  # list of pathnames of systems found in cycles
    sys_pathnames_dict = {}  # map of pathnames to index of pathname in list

    G = root_group.compute_sys_graph(comps_only=True)

    scc = nx.strongly_connected_components(G)

    for strong_comp in scc:
        if len(strong_comp) > 1:
            # these IDs are only used when back edges are present
            sys_pathnames_list.extend(strong_comp)
            for name in strong_comp:
                sys_pathnames_dict[name] = len(sys_pathnames_dict)

        # Out-edges of every component in this SCC. Each edge of G is visited exactly once
        # across all SCCs because every source node belongs to exactly one SCC. The list is
        # built once per SCC so the cycle-arrow search below does not re-query the graph
        # for every edge.
        scc_edges = list(G.edges(strong_comp))

        for src, tgt in scc_edges:
            conns = G.get_edge_data(src, tgt)['conns']

            if src in strong_comp and tgt in strong_comp:
                # Components without a recorded execution order are given -1 (and the
                # default is stored in 'orders', as before).
                exe_src = orders.setdefault(src, -1)
                exe_tgt = orders.setdefault(tgt, -1)
                exe_low, exe_high = sorted((exe_src, exe_tgt))

                # Every other in-cycle edge whose endpoints both execute within
                # [exe_low, exe_high] is drawn as a cycle arrow for this connection.
                edges_list = [
                    (sys_pathnames_dict[s], sys_pathnames_dict[t]) for s, t in scc_edges
                    if s in orders and exe_low <= orders[s] <= exe_high and
                    t in orders and exe_low <= orders[t] <= exe_high and
                    not (s == src and t == tgt) and t in sys_pathnames_dict
                ]
                for vsrc, vtgtlist in conns.items():
                    for vtgt in vtgtlist:
                        connections_list.append({'src': vsrc, 'tgt': vtgt,
                                                 'cycle_arrows': edges_list})
            else:  # edge is out of the SCC
                for vsrc, vtgtlist in conns.items():
                    for vtgt in vtgtlist:
                        connections_list.append({'src': vsrc, 'tgt': vtgt})

    data_dict['sys_pathnames_list'] = sys_pathnames_list
    data_dict['connections_list'] = connections_list
    data_dict['abs2prom'] = root_group._var_abs2prom

    data_dict['driver'] = {
        'name': driver_name,
        'type': driver_type,
        'options': driver_options,
        'opt_settings': driver_opt_settings
    }
    data_dict['design_vars'] = root_group.get_design_vars(use_prom_ivc=False)
    data_dict['responses'] = root_group.get_responses()

    data_dict['declare_partials_list'] = _get_declare_partials(root_group)

    return data_dict
class CaseViewer(object):
    """
    Notebook GUI to plot data from a CaseReader.

    Parameters
    ----------
    f : str or CaseReader
        The file from which the cases are to be viewed, or a CaseReader already opened on it.

    Attributes
    ----------
    _filename : str
        The filename associated with the record file.
    _case_reader : CaseReader
        The CaseReader instance used to retrieve data from the case file.
    _widgets : dict
        The ipywidgets that make up the GUI, keyed by name.
    _lines : list of mpl.Line2D
        The line objects on the matplotlib Axes.
    _scatters : list of PathCollection
        The scatter plot objects on the matplotlib Axes.
    _fig : mpl.Figure
        The figure on which the plots are displayed.
    _ax : mpl.Axes
        The matplotlib Axes on which the plots are displayed.
    _cmap : mpl.ColorMap
        The matplotlib color map used to show variation in the plots.
    _scalar_mappable : cm.ScalarMappable
        The object used to provide the extents of the color mapping.
    _colorbar : mpl.Colorbar
        The colorbar shown on the right side of the figure.
    _case_index_str : str
        Constant defining the "Case Index" string used throughout the CaseViewer.
    """

    def __init__(self, f):
        """
        Initialize the case viewer interface.
        """
        if mpl is None:
            raise RuntimeError('CaseViewer requires matplotlib and ipympl')
        if get_ipython is None:
            raise RuntimeError('CaseViewer requires jupyter')
        if ipw is None:
            raise RuntimeError('CaseViewer requires ipywidgets')
        if get_ipython() is None:
            raise RuntimeError('CaseViewer must be run from within a Jupyter notebook.')
        try:
            import ipympl  # noqa: F401 - only checking that the widget backend is installed
        except ImportError:
            raise RuntimeError('CaseViewer requires ipympl')

        get_ipython().run_line_magic('matplotlib', 'widget')

        self._case_reader = CaseReader(f) if isinstance(f, str) else f

        self._cmap = cm.viridis
        self._case_index_str = 'Case Index'
        self._filename = self._case_reader._filename

        self._make_gui()
        self._register_callbacks()

        self._fig, self._ax = plt.subplots(1, 1, figsize=(9, 9 / 1.6), tight_layout=True)
        norm = mpl.colors.Normalize(vmin=0, vmax=1)
        self._scalar_mappable = cm.ScalarMappable(norm=norm, cmap=self._cmap)
        self._colorbar = self._fig.colorbar(self._scalar_mappable, label='Case Index')
        self._scatters = []
        self._lines = []

        self._update_source_options()
        self._update_case_select_options()
        self._update_var_select_options('x')
        self._update_var_select_options('y')
        self._update_var_info('x')
        self._update_var_info('y')

    def _make_gui(self):
        """
        Define the widgets for the CaseViewer and display them.
        """
        self._widgets = {}

        self._widgets['source_select'] = ipw.Dropdown(description='Source',
                                                      layout=ipw.Layout(width='30%',
                                                                        height='auto'))

        self._widgets['cases_select'] = ipw.SelectMultiple(description='Cases',
                                                           layout=ipw.Layout(width='40%',
                                                                             height='auto'))

        self._widgets['case_select_button'] = ipw.Button(description='Select ' + '\u27F6',
                                                         layout={'width': '100%'})

        self._widgets['case_select_all_button'] = ipw.Button(
            description='Select All ' + '\u27F9', layout={'width': '100%'})

        self._widgets['case_remove_button'] = ipw.Button(description='Remove ' + '\u274c',
                                                         layout={'width': '100%'})

        self._widgets['cases_list'] = ipw.Select(layout=ipw.Layout(width='40%', height='auto'))

        self._widgets['x_filter'] = ipw.Text('', description='X-Axis Filter',
                                             layout=ipw.Layout(width='49%', height='auto'))

        var_types_list = ['outputs', 'inputs', 'optimization', 'desvars', 'constraints',
                          'objectives', 'residuals']

        self._widgets['x_var_type'] = ipw.Dropdown(options=var_types_list,
                                                   description='X Var Type',
                                                   value='outputs',
                                                   layout={'width': '49%'})

        self._widgets['x_select'] = ipw.Select(description='X-Axis',
                                               layout=ipw.Layout(width='auto', height='auto'))

        self._widgets['x_transform_select'] = ipw.Dropdown(options=_func_map.keys(),
                                                           value='None',
                                                           description='X Transform',
                                                           layout=ipw.Layout(width='95%',
                                                                             height='auto'))

        self._widgets['x_slice'] = ipw.Text('[...]', description='X Slice',
                                            layout=ipw.Layout(width='95%', height='auto'))

        self._widgets['x_info'] = ipw.HTML(value='', description='X Shape',
                                           layout={'width': '95%'})

        self._widgets['x_scale'] = ipw.Dropdown(options=['linear', 'log'],
                                                value='linear',
                                                description='X Scale',
                                                layout={'width': '95%'})

        self._widgets['y_filter'] = ipw.Text('', description='Y-Axis Filter',
                                             layout={'width': '49%', 'height': 'auto'})

        self._widgets['y_var_type'] = ipw.Dropdown(options=var_types_list,
                                                   description='Y Var Type',
                                                   value='outputs',
                                                   layout={'width': '49%'})

        self._widgets['y_select'] = ipw.Select(description='Y-Axis',
                                               layout=ipw.Layout(width='auto', height='auto'))

        self._widgets['y_transform_select'] = ipw.Dropdown(options=_func_map.keys(),
                                                           value='None',
                                                           description='Y Transform',
                                                           layout=ipw.Layout(width='95%',
                                                                             height='auto'))

        self._widgets['y_slice'] = ipw.Text('[...]', description='Y Slice',
                                            layout=ipw.Layout(width='95%', height='auto'))

        self._widgets['y_info'] = ipw.HTML(value='', description='Y Shape',
                                           layout={'width': '95%'})

        self._widgets['y_scale'] = ipw.Dropdown(options=['linear', 'log'],
                                                value='linear',
                                                description='Y Scale',
                                                layout={'width': '95%'})

        self._widgets['case_slider'] = ipw.IntSlider(value=1, min=0, max=1, step=1,
                                                     description='Case #',
                                                     disabled=False,
                                                     continuous_update=True,
                                                     orientation='horizontal',
                                                     readout=True,
                                                     readout_format='d',
                                                     layout={'width': '95%'})

        self._widgets['debug_output'] = ipw.Output(description='',
                                                   layout={'border': '0px solid black',
                                                           'width': '95%',
                                                           'height': '400'})

        display(ipw.VBox([
            self._widgets['source_select'],
            ipw.HBox([self._widgets['cases_select'],
                      ipw.VBox([self._widgets['case_select_button'],
                                self._widgets['case_select_all_button'],
                                self._widgets['case_remove_button']]),
                      self._widgets['cases_list']], layout={'width': '95%'}),
            ipw.HBox([ipw.VBox([ipw.HBox([self._widgets['x_filter'],
                                          self._widgets['x_var_type']]),
                                self._widgets['x_select']], layout={'width': '50%'}),
                      ipw.VBox([self._widgets['x_info'],
                                self._widgets['x_slice'],
                                self._widgets['x_transform_select'],
                                self._widgets['x_scale']], layout={'width': '20%'}),
                      ]),
            ipw.HBox([ipw.VBox([ipw.HBox([self._widgets['y_filter'],
                                          self._widgets['y_var_type']]),
                                self._widgets['y_select']], layout={'width': '50%'}),
                      ipw.VBox([self._widgets['y_info'],
                                self._widgets['y_slice'],
                                self._widgets['y_transform_select'],
                                self._widgets['y_scale']], layout={'width': '20%'})
                      ]),
            self._widgets['case_slider'],
            self._widgets['debug_output']
        ]))

    def _update_source_options(self):
        """
        Update the contents of the source selection dropdown menu.
        """
        sources = self._case_reader.list_sources(out_stream=None)
        self._widgets['source_select'].options = sources
        # A recording may contain no sources at all; leave the selection empty in that case.
        if sources:
            self._widgets['source_select'].value = sources[0]

    def _update_var_info(self, axis):
        """
        Update the variable info displayed.

        Parameters
        ----------
        axis : str
            'x' or 'y' - case insensitive.
        """
        if axis.lower() not in ('x', 'y'):
            raise ValueError(f'Unknown axis: {axis}')
        axis = axis.lower()  # widget keys are lower case

        cases = self._widgets['cases_list'].options

        if not cases:
            self._widgets[f'{axis}_info'].value = 'N/A'
            return

        var = self._widgets[f'{axis}_select'].value

        if var == self._case_index_str:
            shape = (len(cases),)
        elif var is None:
            shape = 'N/A'
        else:
            meta = _get_var_meta(self._case_reader, cases[0], var)
            shape = meta['shape']

        self._widgets[f'{axis}_info'].value = f'{shape}'

    def _update_case_select_options(self):
        """
        Update the available cases listed in the source_select widget.
        """
        src = self._widgets['source_select'].value
        available_cases = self._case_reader.list_cases(source=src, recurse=False,
                                                       out_stream=None)
        self._widgets['cases_select'].options = available_cases
        self._update_case_slider()

    def _update_var_select_options(self, axis):
        """
        Update the variables available for plotting.

        Parameters
        ----------
        axis : str
            'x' or 'y' - case insensitive.
        """
        with self._widgets['debug_output']:
            if axis.lower() not in ('x', 'y'):
                raise ValueError(f'Unknown axis: {axis}')
            axis = axis.lower()  # widget keys are lower case

            cases = self._widgets['cases_list'].options

            if not cases:
                self._widgets[f'{axis}_select'].options = []
                self._widgets[f'{axis}_info'].value = 'N/A'
                return

            w_var_select = self._widgets[f'{axis}_select']
            var_filter = self._widgets[f'{axis}_filter'].value
            var_type = self._widgets[f'{axis}_var_type'].value

            if var_type == 'optimization':
                var_names = _get_opt_vars(self._case_reader, cases)
            elif var_type in ('desvars', 'constraints', 'objectives'):
                var_names = _get_opt_vars(self._case_reader, cases, var_type=var_type)
            elif var_type == 'residuals':
                var_names = _get_resids_vars(self._case_reader, cases)
            else:
                var_names = _get_vars(self._case_reader, cases, var_types=var_type)

            # We have a list of available vars, now filter it.
            try:
                pattern = re.compile(var_filter)
            except re.error:
                # The filter is not (yet) a valid regular expression, e.g. while the user is
                # still typing it; keep the current options until it becomes valid.
                return
            filtered_list = [name for name in var_names if pattern.search(name)]

            # Only the x-axis may plot against the case index.
            if axis == 'x':
                w_var_select.options = [self._case_index_str] + filtered_list
            else:
                w_var_select.options = filtered_list

            self._update_var_info(axis)

    def _update_case_slider(self):
        """
        Update the extents of the case slider.
        """
        n = len(self._widgets['cases_list'].options)
        self._widgets['case_slider'].max = n
        self._widgets['case_slider'].value = n

    @staticmethod
    def _case_sort_key(case_name):
        """
        Return a key that sorts case names by their prefix and then by their numeric counter.

        Parameters
        ----------
        case_name : str
            A case name of the form '<prefix>|<counter>'.

        Returns
        -------
        tuple of (str, int)
            The text before the first '|' and the integer after the last '|'.
        """
        parts = case_name.split('|')
        return parts[0], int(parts[-1])

    def _add_cases(self, new_cases):
        """
        Merge the given cases into the chosen cases list and refresh the GUI.

        Parameters
        ----------
        new_cases : iterable of str
            The names of the cases to add. Cases already in the list are not duplicated.
        """
        clw = self._widgets['cases_list']
        clw.options = sorted(set(clw.options) | set(new_cases), key=self._case_sort_key)
        self._update_case_slider()
        self._update_var_select_options('x')
        self._update_var_select_options('y')
        self._update_plot()

    def _register_callbacks(self):
        """
        Register callback functions with the widgets.
        """
        self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            self._widgets['source_select'].observe(self._callback_select_source)
            self._widgets['case_select_button'].on_click(self._callback_select_case)
            self._widgets['case_select_all_button'].on_click(self._callback_select_all_cases)
            self._widgets['case_remove_button'].on_click(self._callback_remove_case)
            self._widgets['cases_list'].observe(self._callback_case_list_select)

            self._widgets['x_filter'].observe(self._callback_filter_vars, 'value')
            self._widgets['y_filter'].observe(self._callback_filter_vars, 'value')

            self._widgets['x_var_type'].observe(self._callback_filter_vars, 'value')
            self._widgets['y_var_type'].observe(self._callback_filter_vars, 'value')

            self._widgets['x_select'].observe(self._callback_select_var, 'value')
            self._widgets['y_select'].observe(self._callback_select_var, 'value')

            self._widgets['x_slice'].observe(self._callback_change_slice, 'value')
            self._widgets['y_slice'].observe(self._callback_change_slice, 'value')

            self._widgets['x_scale'].observe(self._callback_change_scale, 'value')
            self._widgets['y_scale'].observe(self._callback_change_scale, 'value')

            self._widgets['x_transform_select'].observe(self._callback_select_transform,
                                                        'value')
            self._widgets['y_transform_select'].observe(self._callback_select_transform,
                                                        'value')

            self._widgets['case_slider'].observe(self._callback_case_slider, 'value')

    def _callback_select_source(self, *args):
        """
        Repopulate cases_select with cases from the chosen source.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            self._update_case_select_options()

    def _callback_select_case(self, *args):
        """
        Add the selected case(s) to the chosen cases list.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            self._add_cases(self._widgets['cases_select'].value)

    def _callback_case_list_select(self, *args):
        """
        Update the plot when a different case is selected in the cases list.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            self._update_plot()

    def _callback_case_slider(self, *args):
        """
        Update the plot when the case slider is changed.

        Only the style (line width, marker size, alpha) of the existing artists is changed;
        the data is not re-plotted.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            selected_case_idx = self._widgets['case_slider'].value
            n_cases = len(self._widgets['cases_list'].options)

            for i in range(n_cases):
                lw, ms, size, alpha = _get_plot_style(i, selected_case_idx, n_cases)
                if i < len(self._scatters):
                    scat = self._scatters[i]
                    scat.set_sizes(size * np.ones_like(scat.get_sizes()))
                    scat.set_alpha(alpha)
                if i < len(self._lines):
                    for artist in self._lines[i]:
                        artist.set_linewidth(lw)
                        artist.set_markersize(ms)
                        artist.set_alpha(alpha)

    def _callback_select_all_cases(self, *args):
        """
        Add all available cases to the cases list.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            self._add_cases(self._widgets['cases_select'].options)

    def _callback_remove_case(self, *args):
        """
        Remove the selected case from the chosen cases list widget.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            clw = self._widgets['cases_list']
            new_list = list(clw.options)
            if clw.value in new_list:
                new_list.remove(clw.value)
            clw.options = new_list
            self._update_var_select_options('x')
            self._update_var_select_options('y')
            self._update_plot()

    def _callback_filter_vars(self, *args):
        """
        Update the plot and the available variables when the filtering criteria is changed.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        event = args[0]
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            w = event['owner']
            is_x = w is self._widgets['x_filter'] or w is self._widgets['x_var_type']
            self._update_var_select_options('x' if is_x else 'y')
            self._update_plot()

    def _callback_select_var(self, *args):
        """
        Update the variable info and the plot when a new variable is chosen.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        event = args[0]
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            w = event['owner']
            axis = 'x' if w is self._widgets['x_select'] else 'y'
            self._update_var_info(axis)
            if w.value is None:
                self._ax.clear()
            else:
                self._update_plot()

    def _callback_change_slice(self, *args):
        """
        Update the plot when a new, valid slice is provided.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        event = args[0]
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            s = event['owner'].value
            # Only re-plot once the slice text is bracketed, i.e. plausibly complete.
            if s.startswith('[') and s.endswith(']'):
                self._update_plot()

    def _callback_change_scale(self, *args):
        """
        Update the plot when the x or y axis scale is changed.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        if not _DEBUG:
            self._widgets['debug_output'].clear_output()
        with self._widgets['debug_output']:
            w = args[0]['owner']
            if w is self._widgets['x_scale']:
                self._ax.set_xscale(w.value)
            else:
                self._ax.set_yscale(w.value)

    def _callback_select_transform(self, *args):
        """
        Update the plot when a new transformation is chosen for the x or y variable.

        Parameters
        ----------
        args : tuple
            The information passed by the widget upon callback.
        """
        self._update_plot()

    @staticmethod
    def _axis_limits(vmin, vmax):
        """
        Compute padded axis limits for the range of plotted data.

        Parameters
        ----------
        vmin : float
            The smallest plotted value.
        vmax : float
            The largest plotted value.

        Returns
        -------
        tuple of (float, float) or None
            The lower and upper limit padded by 5% of the range (or by 0.1 when the range is
            essentially zero), or None if the limits are not finite, e.g. because nothing
            was plotted or the data contained inf/nan.
        """
        margin = (vmax - vmin) * 0.05
        if margin < 1.0E-16:
            margin = 0.1
        if not np.all(np.isfinite([vmin, vmax, margin])):
            return None
        return vmin - margin, vmax + margin

    def _redraw_plot(self):
        """
        Update the plot area by plotting one variable vs another over one or more cases.
        """
        # Start from an empty (inverted, infinite) range so that, if nothing ends up being
        # plotted, the limits are detected as invalid and left to matplotlib.
        x_min = y_min = np.inf
        x_max = y_max = -np.inf

        cases = self._widgets['cases_list'].options
        n_cases = len(cases)

        x_slice = self._widgets['x_slice'].value
        y_slice = self._widgets['y_slice'].value

        x_transform = self._widgets['x_transform_select'].value
        y_transform = self._widgets['y_transform_select'].value

        x_var = self._widgets['x_select'].value
        y_var = self._widgets['y_select'].value

        x_var_type = self._widgets['x_var_type'].value
        y_var_type = self._widgets['y_var_type'].value

        selected_case_idx = self._widgets['case_slider'].value

        max_size = 0

        if not _DEBUG:
            self._widgets['debug_output'].clear_output()

        with self._widgets['debug_output']:
            for i, case_name in enumerate(cases):
                lw, ms, size, alpha = _get_plot_style(i, selected_case_idx, n_cases)
                case = self._case_reader.get_case(case_name)

                if y_var_type == 'residuals':
                    y_val = _get_resids_val(case, y_var)
                else:
                    y_val = case.get_val(y_var)

                try:
                    y_val = _apply_slice(y_val, y_slice)
                except ValueError:
                    if _DEBUG:
                        print(f'Error while applying Y slice: {y_slice}')
                    continue
                y_val = _apply_transform(y_val, y_transform)
                if y_val is None:
                    continue

                if x_var != self._case_index_str:
                    if x_var_type == 'residuals':
                        x_val = _get_resids_val(case, x_var)
                    else:
                        x_val = case.get_val(x_var)
                else:
                    x_val = i * np.ones(y_val.size)

                try:
                    x_val = _apply_slice(x_val, x_slice)
                except ValueError:
                    if _DEBUG:
                        print(f'Error while applying X slice: {x_slice}')
                    continue
                x_val = _apply_transform(x_val, x_transform)
                if x_val is None:
                    continue

                if x_val.shape[0] != y_val.shape[0]:
                    print(f'Incompatible shapes: x.shape = {x_val.shape} '
                          f' y.shape = {y_val.shape}.')
                    print('Size along first axis must agree.')
                    return

                max_size = max(max_size, x_val.size)

                x_min = min(x_min, np.min(x_val))
                x_max = max(x_max, np.max(x_val))
                y_min = min(y_min, np.min(y_val))
                y_max = max(y_max, np.max(y_val))

                if x_var == self._case_index_str:
                    # Color each point by its index within the vector.
                    scat = self._ax.scatter(x_val, y_val, c=np.arange(x_val.size), s=size,
                                            cmap=self._cmap, alpha=alpha)
                    self._scatters.append(scat)
                else:
                    # Color each line by its case index.
                    color = self._cmap(float(i) / n_cases)
                    line = self._ax.plot(x_val, y_val, c=color, mfc=color, mec='k',
                                         marker='o', linestyle='-', linewidth=lw,
                                         markersize=ms, alpha=alpha)
                    self._lines.append(line)

            self._fig.canvas.flush_events()

            x_lims = self._axis_limits(x_min, x_max)
            if x_lims is not None:
                self._ax.set_xlim(*x_lims)

            y_lims = self._axis_limits(y_min, y_max)
            if y_lims is not None:
                self._ax.set_ylim(*y_lims)

            # Add the colorbar.
            # Color shows the index of each point in its vector if the x-axis is Case Index,
            # otherwise it shows the case index.
            if x_var == self._case_index_str:
                vmax = max_size
                cbar_label = 'Array Index'
            else:
                vmax = n_cases
                cbar_label = self._case_index_str

            self._scalar_mappable.set_clim(0, vmax)
            self._colorbar.set_label(cbar_label)

    def _update_plot(self):
        """
        Update the plot based on the contents of the widgets.
        """
        with self._widgets['debug_output']:
            cr = self._case_reader
            cases = self._widgets['cases_list'].options
            x_var = self._widgets['x_select'].value
            y_var = self._widgets['y_select'].value

            # The default slice '[...]' selects everything, so it is omitted from labels.
            x_slice = '' if self._widgets['x_slice'].value == '[...]' \
                else self._widgets['x_slice'].value
            y_slice = '' if self._widgets['y_slice'].value == '[...]' \
                else self._widgets['y_slice'].value

            x_transform = self._widgets['x_transform_select'].value
            y_transform = self._widgets['y_transform_select'].value

            x_var_type = self._widgets['x_var_type'].value
            y_var_type = self._widgets['y_var_type'].value

            try:
                self._ax.clear()
                self._scatters.clear()
                self._lines.clear()
            except AttributeError:
                # The axes / artist lists do not exist yet, so there is nothing to update.
                return

            if not cases or not x_var or not y_var:
                print('Nothing to plot')
                return

            x_units = 'None' if x_var == self._case_index_str \
                else _get_var_meta(cr, cases[0], x_var)['units']
            y_units = _get_var_meta(cr, cases[0], y_var)['units']

            self._redraw_plot()

            x_label = rf'{x_var}{x_slice}'
            y_label = rf'{y_var}{y_slice}'

            if x_var_type == 'residuals':
                x_label = rf'$\mathcal{{R}}$({x_label})'

            if y_var_type == 'residuals':
                y_label = rf'$\mathcal{{R}}$({y_label})'

            if x_transform != 'None':
                x_label = f'{x_transform}({x_label})'

            if y_transform != 'None':
                y_label = f'{y_transform}({y_label})'

            self._ax.set_xlabel(f'{x_label}\n({x_units})')
            self._ax.set_ylabel(f'{y_label}\n({y_units})')

            self._ax.grid(True)

            self._ax.set_xscale(self._widgets['x_scale'].value)
            self._ax.set_yscale(self._widgets['y_scale'].value)