def state_mean(dm: DataManager, *, hlpr: PlotHelper, uni: UniverseGroup,
               mean_of: str, **plot_kwargs):
    """Plots the spatial mean of a CopyMeGrid dataset over time.

    The dataset named ``mean_of`` is taken from the ``data/CopyMeGrid`` group
    of the selected universe and averaged over its ``x`` and ``y``
    dimensions. The result is drawn as a line against the ``time``
    coordinate on the plot helper's current axis.

    Args:
        dm (DataManager): The data manager (not used by this function)
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving.
            The line is drawn on ``hlpr.ax``.
        uni (UniverseGroup): The selected universe data
        mean_of (str): The name of the CopyMeGrid dataset to average
        **plot_kwargs: Passed on to matplotlib's ``Axes.plot``
    """
    grid_group = uni['data/CopyMeGrid']
    spatial_mean = grid_group[mean_of].mean(['x', 'y'])

    # `hlpr.ax` is the currently selected matplotlib Axes object; it offers
    # the same plotting interface as `matplotlib.pyplot`.
    axis = hlpr.ax
    axis.plot(spatial_mean.coords['time'], spatial_mean, **plot_kwargs)

    # Hand defaults to the plot helper; they are used once the helper's
    # title / label helpers are invoked.
    default_title = "Mean '{}'".format(mean_of)
    default_ylabel = "<{}>".format(mean_of)
    hlpr.provide_defaults('set_title', title=default_title)
    hlpr.provide_defaults('set_labels', y=default_ylabel)
def time_series(*, data: dict, hlpr: PlotHelper, **plot_kwargs):
    """This is a generic plotting function that plots one or multiple time
    series from the 'y' tag that is selected via the DAG framework.

    The y data needs to be an xarray object:

        - If y is an xr.Dataset, all data variables are plotted and their
          name is used as the label.
        - If y is an xr.DataArray, it is assumed to be one- or
          two-dimensional. For two-dimensional data, one line is drawn per
          coordinate of the non-``time`` dimension. That coordinate value is
          used as the label: ``.2g``-formatted if numeric, ``str`` otherwise.

    For the x axis values, the corresponding 'time' coordinates are used.

    Args:
        data (dict): The data selected by the DAG framework
        hlpr (PlotHelper): The plot helper
        **plot_kwargs: Passed on to matplotlib.pyplot.plot

    Raises:
        ValueError: If y is a DataArray that is neither one- nor
            two-dimensional
        TypeError: If y is neither an xr.Dataset nor an xr.DataArray
    """
    y = data['y']

    def _coord_label(coord) -> str:
        """Formats a scalar coordinate value for use as a legend label."""
        value = coord.item()
        try:
            return "{:.2g}".format(value)
        except (TypeError, ValueError):
            # Non-numeric coordinates (e.g. strings) do not support the `g`
            # format specifier; fall back to their plain representation.
            return str(value)

    if isinstance(y, xr.Dataset):
        # Simply plot all data variables, labelled by their name
        for dvar, line in y.data_vars.items():
            hlpr.ax.plot(line.coords['time'], line, label=dvar, **plot_kwargs)

    elif isinstance(y, xr.DataArray):
        # A DataArray may be one- or two-dimensional
        if y.ndim == 1:
            hlpr.ax.plot(y.coords['time'], y, **plot_kwargs)

        elif y.ndim == 2:
            # Draw one line per coordinate of the non-time dimension
            loop_dim = [d for d in y.dims if d != 'time'][0]

            for c in y.coords[loop_dim]:
                line = y.sel({loop_dim: c})
                hlpr.ax.plot(line.coords['time'], line,
                             label=_coord_label(c), **plot_kwargs)

            # Provide a default title to the legend: name of the loop dimension
            hlpr.provide_defaults('set_legend',
                                  title="${}$ coordinate".format(loop_dim))

        else:
            raise ValueError("Given y-data needs to be one- or two-"
                             "dimensional, was of dimensionality {}! Data: {}"
                             "".format(y.ndim, y))

    else:
        raise TypeError("Expected xr.Dataset or xr.DataArray, got {}"
                        "".format(type(y)))
def some_DAG_based_CopyMeGrid_plot(*, data: dict, hlpr: PlotHelper,
                                   scatter_kwargs: dict = None):
    """An example of a generic, DAG-based plot function.

    Nothing spectacular happens here. The ``scatter`` tag is scattered
    against the ``x_values`` tag. A sine curve ``amplitude * sin(x *
    frequency) + offset`` is also drawn, with amplitude, frequency and offset
    all provided via the DAG. Finally, the legend helper is enabled.

    Ideally, plot functions should be as generic as possible: just a bridge
    between some data (produced by the DAG) and its visualization. This plot
    is a (completely made-up) example for cases where it makes sense to have
    a rather specific plot...

    .. note::

        To specify the plot in a generic way, omit the ``creator_type``
        argument in the decorator. If you want to specialize the plot
        function, you can of course specify the creator type, but most often
        that's not necessary.

    Check the decorator's signature and the utopya and dantro documentation
    for more information on the DAG framework.

    Args:
        data (dict): The selected DAG data; must contain the tags
            ``x_values``, ``scatter``, ``amplitude``, ``frequency`` and
            ``offset``.
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving.
            Plotting happens on ``hlpr.ax``.
        scatter_kwargs (dict, optional): Passed to matplotlib.pyplot.scatter
    """
    axis = hlpr.ax
    x_vals = data['x_values']

    # Scatter the DAG-provided points; a missing/empty kwargs dict means
    # "no extra arguments"
    axis.scatter(x_vals, data['scatter'], **(scatter_kwargs or {}))

    # Build and draw the sine curve
    amplitude = data['amplitude']
    phase = x_vals * data['frequency']
    sine = amplitude * np.sin(phase) + data['offset']
    axis.plot(x_vals, sine, label="my sine")

    # Let the plot helper draw a legend
    hlpr.mark_enabled("set_legend")
def sweep2d(data, hlpr: PlotHelper, x: str, y: str, z: str,
            plot_kwargs: dict = None):
    """For multiverse runs, this produces a two dimensional plot showing
    specified values.

    The values of ``z`` are drawn with ``pcolor``, using the ``y`` parameter
    values as rows and the ``x`` parameter values as columns. Ticks are
    labelled with the parameter values rounded to 3 decimals, and a colorbar
    labelled ``z`` is appended on the right.

    Arguments:
        data: Mapping with a ``'data'`` entry that can be indexed by the
            names ``x``, ``y`` and ``z``. Each indexed item must provide its
            values via ``.data`` (e.g. an xarray Dataset).
        hlpr (PlotHelper): The plot helper providing the axis (``hlpr.ax``)
            and figure (``hlpr.fig``)
        x (str): the first parameter dimension of the diagram (columns)
        y (str): the second parameter dimension of the diagram (rows)
        z (str): the data dimension (cell values)
        plot_kwargs (dict, optional): kwargs passed to the pcolor plot function
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    dset = data['data']
    x_vals = dset[x].data
    y_vals = dset[y].data

    df = pd.DataFrame(dset[z].data, index=y_vals, columns=x_vals)
    im = hlpr.ax.pcolor(df, **plot_kwargs)

    # Without explicit X/Y, pcolor draws cell i over [i, i + 1]; place the
    # ticks at the cell centres and label them with the parameter values.
    hlpr.ax.set_yticks(np.linspace(0.5, len(y_vals) - 0.5, len(y_vals)))
    hlpr.ax.set_yticklabels([np.around(v, 3) for v in y_vals])
    hlpr.ax.set_xticks(np.linspace(0.5, len(x_vals) - 0.5, len(x_vals)))
    hlpr.ax.set_xticklabels([np.around(v, 3) for v in x_vals])

    # The colorbar goes into a narrow axis appended to the right of the plot
    divider = make_axes_locatable(hlpr.ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    cbar = hlpr.fig.colorbar(im, cax=cax)
    cbar.set_label(z)

    # NOTE(review): presumably re-selects the main axis so that subsequent
    # helper calls do not target the colorbar axis -- confirm
    hlpr.select_axis(0, 0)
def bifurcation(dm: DataManager, *, hlpr: PlotHelper, mv_data,
                avg_window: int = 20, dim: str = None,
                plot_kwargs: dict = None, title: str = None):
    """For multiverse runs, this plot finds the upper turning points of the
    average opinion and plots them against the sweep parameter. Only works
    for a single sweep parameter or for a sweep parameter and different
    seeds.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimension `dim` in which the sweep was performed. For a
          single sweep dimension, the sweep parameter is automatically deduced
        - use the `select/subspace` key to set values for all other parameters

    For every value along ``dim`` (and every seed, if more than one seed is
    present), the ``opinion`` data is averaged over its second axis. The
    result is smoothed with a rolling mean of width ``avg_window``, and the
    maxima found by ``find_extrema`` are scattered at that parameter value.

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the plot helper; plotting happens on axis (0, 1)
            of the figure created by ``setup_figure``
        mv_data (xr.Dataset): the extracted multidimensional dataset
        avg_window (int): the smoothing window for the rolling average of the
            opinion dataset
        dim (str, optional): the parameter dimension of the diagram. If no
            str is passed, an attempt will be made to automatically deduce the
            sweep dimension.
        plot_kwargs (dict, optional): kwargs passed to the scatter plot
            function. Its ``color`` entry (default: 'navy') is also used
            for the legend marker.
        title (str, optional): custom plot title

    Raises:
        ValueError: if ``dim`` is given but is not a dimension of ``mv_data``
        ValueError: for a parameter dimension higher than 3 (or 4 if the
            sweep is also conducted over the seed); presumably raised by
            ``deduce_sweep_dimension`` -- confirm
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if dim is None:
        dim = deduce_sweep_dimension(mv_data)
    elif dim not in mv_data.dims:
        raise ValueError(f"Dimension '{dim}' not available in multiverse data."
                         f" Available: {mv_data.coords}")

    # get datasets and cfg ....................................................
    dataset = mv_data['opinion']
    keys, cfg = get_keys_cfg(mv_data, dm['multiverse'].pspace.default,
                             keys_to_ignore=[dim, 'time'])

    # figure setup ............................................................
    figure, axs = setup_figure(cfg, plot_name='bifurcation', title=title,
                               dim1=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # data analysis ...........................................................
    def _smoothed_mean_maxima(selector: dict) -> list:
        """Maxima of the rolling-averaged mean opinion of one universe."""
        # `selector` maps dimension names to positional indices.
        # NOTE(review): assumes the selection is 2d with the averaged-over
        # axis (presumably the vertices) at axis 1 -- confirm
        opinions = np.asarray(dataset[selector])
        means_glob = pd.Series(np.mean(opinions, axis=1)).rolling(
            window=avg_window).mean()
        return find_extrema(means_glob)['max']['y']

    # Get the turning points of the average opinion (maxima only). If a sweep
    # over seeds was performed, the data points of all seeds are collected
    # for each x-value.
    log.info("Starting data analysis ...")
    multiple_seeds = 'seed' in mv_data.coords and len(dataset['seed']) > 1

    to_plot = []
    for i in range(len(dataset[dim])):
        keys[dim] = i
        if multiple_seeds:
            extremes = []
            for j in range(len(dataset['seed'])):
                keys['seed'] = j
                extremes.extend(_smoothed_mean_maxima(keys))
        else:
            extremes = _smoothed_mean_maxima(keys)
        to_plot.append((dataset[dim][i].data, extremes))
    log.info("Data analysis complete.")

    # plot scatter plot of extrema ............................................
    for param_val, maxima in to_plot:
        hlpr.ax.scatter([param_val] * len(maxima), maxima, **plot_kwargs)

    hlpr.ax.set_xlabel(convert_to_label(dim))
    hlpr.ax.set_ylabel(r'mean opinion $\bar{\sigma}$')

    # Use the scatter color for the legend marker, falling back to 'navy'
    # if no color was configured
    color = plot_kwargs.get('color')
    if color is None:
        color = 'navy'

    legend_elements = [Line2D([0], [0],
                              label=(r'$\bar{\sigma}^\prime = 0$,'
                                     + r'$\bar{\sigma}^{\prime \prime} < 0$'),
                              lw=0, marker='o', color=color,
                              markerfacecolor=color, markersize=5)]
    hlpr.ax.legend(handles=legend_elements, bbox_to_anchor=(1, 1.01),
                   loc='lower right', ncol=2, fontsize='xx-small')
def bifurcation_diagram(dm: DataManager, *, hlpr: PlotHelper,
                        mv_data: xr.Dataset,
                        dim: str = None,
                        dims: Tuple[str, str] = None,
                        analysis_steps: Sequence[Union[str, Tuple[str, str]]],
                        custom_analysis_funcs: Dict[str, Callable] = None,
                        analysis_kwargs: dict = None,
                        visualization_kwargs: dict = None,
                        to_plot: dict = None,
                        **kwargs) -> None:
    """Plots a bifurcation diagram for one or two parameter dimensions
    (arguments ``dim`` or ``dims``).

    For every point of the selected parameter space, the analysis steps are
    applied to that universe's data until one of them is conclusive. The
    conclusive step's attractor key determines how the universe is
    visualized:

        - as a rectangular patch (configured via ``visualization_kwargs``)
        - in the 1d case additionally as a scatter of the attractor state
          (configured via ``to_plot``)

    Args:
        dm (DataManager): The data manager from which to retrieve the data
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving
        mv_data (xr.Dataset): The extracted multidimensional dataset
        dim (str, optional): The required parameter dimension of the 1d
            bifurcation diagram.
        dims (Tuple[str, str], optional): The required parameter dimensions
            (x, y) of the 2d-bifurcation diagram.
        analysis_steps (Sequence): The analysis steps that are to be made
            until one is conclusive. Applied per universe.

                - If seq of str: The str will also be used as attractor key
                  for plotting if the test is conclusive.
                - If seq of Tuple(str, str): The first str defines the
                  attractor key for plotting, the second str is a key within
                  custom_analysis_funcs.

            Default analysis_funcs are:

                - endpoint: utopya.dataprocessing.find_endpoint
                - fixpoint: utopya.dataprocessing.find_fixpoint
                - multistability: utdp.find_multistability
                - oscillation: utdp.find_oscillation
                - scatter: resolve_scatter

        custom_analysis_funcs (dict): A collection of custom analysis
            functions that will overwrite the default analysis funcs
            (recursive update).
        analysis_kwargs (dict, optional): Looked up by the analysis function
            name (the second entry of a resolved analysis step). The subentry
            (dict) is passed on to the analysis function.
        visualization_kwargs (dict, optional): The entries need to match the
            analysis_steps. The subentry (dict) is used to configure a
            rectangle to visualize the conclusive analysis step. It is
            passed to matplotlib.patches.Rectangle (legend) and
            matplotlib.collections.PatchCollection (plot). xy, width, height
            and angle are ignored and set automatically. Required in the 2d
            bifurcation diagram.
        to_plot (dict, optional): The configuration for the data to plot.
            The entries of this key need to match the data_vars selected in
            mv_data. It is used to visualize the state of the attractor in
            addition to the visualization kwargs. Only for the 1d-bifurcation
            diagram. Sub-keys:

                - ``label`` (str, optional): label in plot
                - ``plot_kwargs`` (dict, optional): passed to scatter for
                  every universe

                    - color (str, recommended): unique color for every
                      data_variable across universes

        **kwargs: Collection of optional dicts passed to different functions

            - plot_coords_kwargs (dict): Passed to ax.scatter to mark the
              universe's center in the bifurcation diagram. In the 1d case,
              its ``y`` entry (default 0.) sets the marker height.
            - rectangle_map_kwargs (dict): Passed to
              utopya.plot_funcs._utils.calc_pxmap_rectangles
            - legend_kwargs (dict): Passed to ax.legend

    Raises:
        KeyError: if neither or both of ``dim`` and ``dims`` are given
        ValueError: if ``dims`` is not of length 2
    """
    def resolve_analysis_steps(
            analysis_steps: Sequence[Union[str, Tuple[str, str]]]
        ) -> Sequence[Tuple[str, str]]:
        """Resolve instances of str to Tuple[str, str] in the sequence.

        A new list is returned; the given sequence (part of the caller's
        configuration) is not modified.

        Args:
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The
                original sequence

        Returns:
            Sequence[Tuple[str, str]]: The sequence of attractor_key and
                analysis_func pairs.
        """
        return [(step, step) if isinstance(step, str) else step
                for step in analysis_steps]

    def resolve_to_plot_kwargs(to_plot: dict) -> dict:
        """Resolves the to_plot dict, e.g. adding labels if not explicitly
        specified.

        The per-variable ``plot_kwargs`` are copied so that neither this
        function nor later processing mutates the caller's configuration.

        Args:
            to_plot (dict): The to_plot dict to parse

        Returns:
            data_vars_plot_kwargs (dict): A dict with the 'plot_kwargs' for
                every data_var in to_plot.
        """
        if not to_plot:
            return {}

        data_vars_plot_kwargs = {}
        for k, v in to_plot.items():
            plot_kwargs = dict(v.get("plot_kwargs") or {})
            if not plot_kwargs.get('label'):
                plot_kwargs['label'] = v.get('label', k)
            data_vars_plot_kwargs[k] = plot_kwargs

        return data_vars_plot_kwargs

    def create_legend_handles(*, visualization_kwargs: dict,
                              data_vars_plot_kwargs: dict):
        """Creates legend handles

        Processes entries in data_vars_plot_kwargs (from to_plot) and
        visualization_kwargs. The ``label`` entry is removed from each
        data_vars_plot_kwargs entry (it is only used for the legend). A
        ``linewidth`` of 0 is set as default in both dicts (in place).

        Args:
            visualization_kwargs (dict): The visualization kwargs
            data_vars_plot_kwargs (dict): The resolved entries in to_plot

        Returns:
            Tuple[list, list]: Tuple of legend handles and legend labels lists
                as required by ax.legend(handles, labels)
            data_vars_plot_kwargs: The updated data_vars_plot_kwargs
                entries as required by append_plot_attractor.
        """
        # Some defaults
        circle_kwargs = dict(xy=(.5, .5), radius=.25, edgecolor="none")
        rect_kwargs = dict(xy=(0., 0.), height=.75, width=1., edgecolor="none")

        # Lists to be populated for matplotlib legend
        legend_handles = []
        legend_labels = []

        for k, dvar_kwargs in data_vars_plot_kwargs.items():
            label = dvar_kwargs.pop('label', k)
            dvar_kwargs.setdefault('linewidth', 0.)

            # Determine color
            if 'color' in dvar_kwargs:
                color = dvar_kwargs['color']
            elif 'cmap' in dvar_kwargs:
                cmap = dvar_kwargs['cmap']
                # `matplotlib.cm.get_cmap` was removed in matplotlib 3.9;
                # prefer the colormap registry where it is available
                if isinstance(cmap, str) and hasattr(mpl, 'colormaps'):
                    cmap = mpl.colormaps[cmap]
                elif not isinstance(cmap, mpl.colors.Colormap):
                    cmap = mpl.cm.get_cmap(cmap)
                color = cmap(1.)
            else:
                log.warning("No color defined for data_var '{}'!".format(k))
                color = None

            # Create and add the handle and the label
            legend_handles.append(Circle(**circle_kwargs, facecolor=color))
            legend_labels.append(label)

        for k, vis_kws in visualization_kwargs.items():
            if 'to_plot' in vis_kws:
                for dvar_name, dvar_kwargs in vis_kws['to_plot'].items():
                    # Make sure a linewidth is set
                    dvar_kwargs.setdefault('linewidth', 0.)

                    # Create and add the handle and the label
                    legend_handles.append(
                        Rectangle(**rect_kwargs, **dvar_kwargs))
                    legend_labels.append(dvar_kwargs.get('label', dvar_name))

            else:
                # NOTE These kwargs are deliberately NOT added to
                #      data_vars_plot_kwargs: that would apply patch styling
                #      to the scatter of a data_var sharing the name `k`.
                vis_kws.setdefault('linewidth', 0.)
                legend_handles.append(Rectangle(**rect_kwargs, **vis_kws))
                legend_labels.append(vis_kws.get('label', k))

        return [legend_handles, legend_labels], data_vars_plot_kwargs

    def apply_analysis_steps(data: xr.Dataset,
                             analysis_steps: Sequence[Union[str,
                                                            Tuple[str, str]]],
                             *, analysis_funcs: dict, analysis_kwargs: dict):
        """Perform the sequence of analysis steps until the first conclusive.

        Args:
            data (xr.Dataset): The data to analyse.
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The
                resolved (attractor_key, analysis_func) pairs that are to be
                tried in order until one is conclusive.
            analysis_funcs (dict): Map of analysis function names to their
                Callables. Names not found here are looked up in
                utopya.dataprocessing.
            analysis_kwargs (dict): Map of analysis function names to the
                kwargs (dict) passed on to that function.

        Returns:
            analysis_key (str): The key of the conclusive analysis step, or
                None if no step was conclusive.
            attractor (xr.Dataset): The data corresponding to this analysis,
                or None if no step was conclusive.
        """
        for analysis_key, analysis_func in analysis_steps:
            analysis_func_kwargs = analysis_kwargs.get(analysis_func, {})

            # resolve the analysis function from its name
            if isinstance(analysis_func, str):
                if analysis_func in analysis_funcs:
                    analysis_func = analysis_funcs[analysis_func]
                else:
                    # Try to get it from dataprocessing ... might fail.
                    analysis_func = getattr(utopya.dataprocessing,
                                            analysis_func)

            # Perform the analysis step
            conclusive, attractor = analysis_func(data, **analysis_func_kwargs)

            # Return if conclusive
            if conclusive:
                return analysis_key, attractor

        # Return non-conclusive
        return None, None

    def resolve_rectangle(coord: dict, rectangles: xr.Dataset) -> Rectangle:
        """Resolve the rectangle patch at this coordinate

        Args:
            coord (dict): The bifurcation parameter's coordinate
            rectangles (xr.Dataset): The rectangles that cover the 2D space
                spanned by the coordinates. The `coord` should be one entry of
                rectangles.coords.

        Raises:
            ValueError: Coordinate not available in rectangles.

        Returns:
            Rectangle: A rectangle around a universe with coord and shape
                defined by rectangles.
        """
        try:
            rectangle = rectangles.sel(coord)
        except Exception as exc:
            raise ValueError("The requested paramspace coordinate(s) {} are "
                             "not coordinates of rectangles {}. Plot failed."
                             "".format(coord, rectangles.coords)) from exc

        rect_spec = rectangle['rect_spec']
        return Rectangle(*rect_spec.item())

    def append_vis_patch(attractor_key: str, attractor: xr.Dataset,
                         vis_patches: dict, vis_kwargs: dict,
                         **resolve_rectangle_args):
        """Append a visualization patch for a conclusive universe.

        For attractor keys 'fixpoint' and 'endpoint' with a 'to_plot' entry
        in their visualization kwargs, the patch is filed under
        ``<attractor_key>_<data_var>``. Here ``<data_var>`` is the data
        variable holding the highest value. All other attractor keys file
        the patch under the attractor key itself, if such a collection
        exists.

        Args:
            attractor_key (str): Key according to which to decode the
                attractor
            attractor (xr.Dataset): The Dataset with the encoded attractor
                information.
            vis_patches (dict): The map of attractor_key to List[Rectangle]
                where to append the new patch
            vis_kwargs (dict): The visualization kwargs
            **resolve_rectangle_args: Passed on to resolve_rectangle

        Returns:
            vis_patches (dict): The new map of attractor_key to
                List[Rectangle]
        """
        # Depending on the kind of attractor, add different patches
        attractor_vis_kwargs = vis_kwargs.get(attractor_key)
        if attractor_vis_kwargs is None:
            return vis_patches

        # Postprocess fixpoint and to_plot, append rectangle
        if (attractor_key in ('fixpoint', 'endpoint')
                and 'to_plot' in attractor_vis_kwargs):
            max_value = -np.inf
            max_name = None
            for data_var_name, data_var in attractor.data_vars.items():
                var_max = data_var.max()
                if var_max > max_value:
                    max_value = var_max
                    max_name = data_var_name

            if max_name is None:
                # No data variable with a comparable (non-NaN) maximum; there
                # is nothing to assign this universe's patch to.
                log.debug("No maximal data variable found for attractor "
                          "'%s'; skipping its patch.", attractor_key)
                return vis_patches

            rect = resolve_rectangle(**resolve_rectangle_args)
            vis_patches[attractor_key + '_' + max_name].append(rect)

        # Append rectangle
        elif vis_patches.get(attractor_key) is not None:
            rect = resolve_rectangle(**resolve_rectangle_args)
            vis_patches[attractor_key].append(rect)

        return vis_patches

    def append_plot_attractor(attractor_key: str, attractor: xr.Dataset, *,
                              coord: float = None, scatter_kwargs: list,
                              **plot_kwargs):
        """Resolves how to plot attractor of specified type at specific
        bifurcation parameter value.

        Args:
            attractor_key (str): Key according to which to decode the
                attractor
            attractor (xr.Dataset): The Dataset with the encoded attractor
                information. See possible encodings
            coord (float, optional): The bifurcation parameter's coordinate;
                if None, it is taken from the attractor's ``dim`` coordinate
            scatter_kwargs (list): The list of scatter kwargs where to append
                the new scatter.
            plot_kwargs (dict, optional): The kwargs used to specify
                ax.scatter where the entries match the attractor.data_vars

        Possible encodings, i.e. values for ``attractor_key``:

            ``fixpoint``/``endpoint``: xr.Dataset with dimensions ()
            ``scatter``: xr.Dataset with dimensions (time: >=1)
            ``multistability``: xr.Dataset with dimensions
                (<initial_state>: >= 1)
            ``oscillation``: xr.Dataset with dimensions (osc: 2), the minimum
                and maximum

        NaN entries of fixpoint, endpoint and multistability attractors are
        dropped before scattering.

        Raises:
            KeyError: Unknown attractor_key
            KeyError: No bifurcation coordinate received

        Returns:
            scatter_kwargs: The new list of scatter kwargs
        """
        # Get the bifurcation parameter coordinate. Compare against None
        # explicitly: a coordinate value of 0 is valid.
        if coord is None:
            try:
                coord = attractor[dim]
            except KeyError as err:
                raise KeyError(
                    "No bifurcation parameter coordinate '{}' "
                    "could be found! Either have it as a "
                    "coordinate in 'attractor' or pass it to "
                    "'plot_attractor' explicitly.".format(dim)) from err

        # Resolve the scatter kwargs depending on attractor key
        if attractor_key in ('fixpoint', 'endpoint', 'multistability'):
            for data_var_name, data_var in attractor.data_vars.items():
                # `x != np.nan` is always True; `notnull` actually filters
                data_var = data_var.where(data_var.notnull(), drop=True)
                entries = len(data_var) if data_var.shape else 1
                scatter_kwargs.append(
                    dict(x=[coord] * entries, y=data_var,
                         **plot_kwargs.get(data_var_name, {})))

        elif attractor_key == 'scatter':
            for data_var_name, data_var in attractor.data_vars.items():
                var_kwargs = plot_kwargs.get(data_var_name, {})
                # With a colormap configured, color the points by time
                color_by = ({'c': attractor['time']}
                            if 'cmap' in var_kwargs else {})
                scatter_kwargs.append(
                    dict(x=[coord] * len(data_var.data), y=data_var,
                         **color_by, **var_kwargs))

        elif attractor_key == 'oscillation':
            for data_var_name, data_var in attractor.data_vars.items():
                scatter_kwargs.append(
                    dict(x=[coord] * len(data_var.data), y=data_var,
                         **plot_kwargs.get(data_var_name, {})))

        elif attractor_key:
            raise KeyError("Invalid attractor-key '{}'! "
                           "Available keys: 'endpoint', 'fixpoint',"
                           " 'multistability', 'scatter', 'oscillation'."
                           "".format(attractor_key))

        return scatter_kwargs

    def resolve_scatter(data: xr.Dataset, *, spin_up_time: int = 0,
                        **kwargs) -> tuple:
        """A mock analysis function to plot all times larger than or equal
        to a spin up time. It is always conclusive.
        """
        return True, data.where(data.time >= spin_up_time, drop=True)

    # .........................................................................
    # Check argument values
    if not dim and not dims:
        raise KeyError("No dim (str) or dims (Tuple[str, str]) specified. "
                       "Use dim for a 1d-bifurcation diagram and dims for a "
                       "2d-bifurcation diagram.")
    if dim and dims:
        raise KeyError("dim='{}' and dims='{}' specified. "
                       "Use either dim for a 1d-bifurcation diagram or dims "
                       "for a 2d-bifurcation diagram."
                       "".format(dim, dims))
    if dims is not None and len(dims) != 2:
        raise ValueError("Argument dims should be of length 2, but was: {}"
                         "".format(dims))
    # TODO In the future, consider not using `dim` below here but handling it
    #      via the length of `dims`.

    # Default values
    if visualization_kwargs is None:
        visualization_kwargs = {}
    if analysis_kwargs is None:
        analysis_kwargs = {}

    # Resolve legend handles and visualization kwargs
    data_vars_plot_kwargs = resolve_to_plot_kwargs(to_plot)
    (handles, labels), data_vars_plot_kwargs = create_legend_handles(
        visualization_kwargs=visualization_kwargs,
        data_vars_plot_kwargs=data_vars_plot_kwargs)

    # Define default analysis functions
    analysis_funcs = dict(endpoint=utdp.find_endpoint,
                          fixpoint=utdp.find_fixpoint,
                          multistability=utdp.find_multistability,
                          oscillation=utdp.find_oscillation,
                          scatter=resolve_scatter)

    # If given, update
    if custom_analysis_funcs:
        log.debug("Updating with custom analysis functions ...")
        analysis_funcs = recursive_update(analysis_funcs,
                                          custom_analysis_funcs)

    analysis_steps = resolve_analysis_steps(analysis_steps)

    # Obtain the rectangles covering the space spanned by the coordinates and
    # the parameter coordinates to iterate over
    rectangle_map_kwargs = kwargs.get('rectangle_map_kwargs', {})
    if dim:
        rects, limits = calc_pxmap_rectangles(x_coords=mv_data[dim].values,
                                              y_coords=None,
                                              **rectangle_map_kwargs)
        param_iter = mv_data[dim].values
    else:
        rects, limits = calc_pxmap_rectangles(
            x_coords=mv_data[dims[0]].values,
            y_coords=mv_data[dims[1]].values,
            **rectangle_map_kwargs)
        param_iter = itertools.product(mv_data[dims[0]].values,
                                       mv_data[dims[1]].values)

    # Map of analysis_key to list[mpatch.Rectangle]
    vis_patches = {}
    for analysis_key, _ in analysis_steps:
        if not visualization_kwargs.get(analysis_key):
            continue

        if 'to_plot' in visualization_kwargs[analysis_key]:
            for var_key in visualization_kwargs[analysis_key]['to_plot']:
                vis_patches[analysis_key + '_' + var_key] = []
        else:
            vis_patches[analysis_key] = []

    # The List[dict] passed to ax.scatter
    scatter_kwargs = []
    scatter_coords_kwargs = []
    plot_coords_kwargs = kwargs.get('plot_coords_kwargs')

    # Iterate the parameter coordinates
    for param_coord in param_iter:
        # Resolve the param_coord to dict
        if dim:
            param_coord = {dim: param_coord}
        else:
            param_coord = {dims[0]: param_coord[0], dims[1]: param_coord[1]}

        # Mark the universe's coordinate
        if plot_coords_kwargs:
            if dim:
                # The marker height comes from the config's `y` entry. Do not
                # pop it from the shared dict: that would apply it only to
                # the very first universe.
                marker_kwargs = {k: v for k, v in plot_coords_kwargs.items()
                                 if k != 'y'}
                scatter_coords_kwargs.append({
                    'x': param_coord[dim],
                    'y': plot_coords_kwargs.get('y', 0.),
                    **marker_kwargs
                })
            else:
                scatter_coords_kwargs.append({
                    'x': param_coord[dims[0]],
                    'y': param_coord[dims[1]],
                    **plot_coords_kwargs
                })

        # Select the data and analyse
        data = mv_data.sel(param_coord)
        attractor_key, attractor = apply_analysis_steps(
            data, analysis_steps,
            analysis_funcs=analysis_funcs, analysis_kwargs=analysis_kwargs)

        # Non-conclusive universes are neither patched nor scattered
        if not attractor_key:
            continue

        # Append a rectangular patch to the attractor_key's patch collection
        if dim:
            y = rectangle_map_kwargs.get('default_pos', (0., 0.))[1]
            coord = dict(x=param_coord[dim], y=y)
        else:
            coord = dict(x=param_coord[dims[0]], y=param_coord[dims[1]])

        vis_patches = append_vis_patch(attractor_key, attractor,
                                       vis_patches, visualization_kwargs,
                                       coord=coord, rectangles=rects)

        # For 1d case ...
        if dim and to_plot:
            scatter_kwargs = append_plot_attractor(
                attractor_key, attractor, coord=param_coord[dim],
                scatter_kwargs=scatter_kwargs, **data_vars_plot_kwargs)

    # Draw collection of visualization patches
    for analysis_key, _ in analysis_steps:
        vis_kwargs = visualization_kwargs.get(analysis_key)
        if not vis_kwargs:
            continue

        if 'to_plot' in vis_kwargs:
            for var_key, var_kwargs in vis_kwargs['to_plot'].items():
                pc = PatchCollection(vis_patches[analysis_key + '_' + var_key],
                                     **var_kwargs)
                hlpr.ax.add_collection(pc)
        else:
            pc = PatchCollection(vis_patches[analysis_key], **vis_kwargs)
            hlpr.ax.add_collection(pc)

    # Scatter the universe's coordinates
    for kws in scatter_coords_kwargs:
        hlpr.ax.scatter(**kws)

    # Scatter the attractor
    for kws in scatter_kwargs:
        hlpr.ax.scatter(**kws)

    # Provide PlotHelper defaults
    hlpr.provide_defaults('set_limits', **limits)

    if dim:
        hlpr.provide_defaults('set_labels', x=dim, y='state')
    else:
        hlpr.provide_defaults('set_labels', x=dims[0], y=dims[1])

    # Only draw a legend if there is something to show
    if handles:
        hlpr.ax.legend(handles, labels,
                       handler_map={Circle: HandlerEllipse()},
                       **kwargs.get('legend_kwargs', {}))
def opinion_at_time(dm: DataManager, *, hlpr: PlotHelper, uni: UniverseGroup,
                    age_groups: list = None, num_bins: int = 100,
                    plot_kwargs: dict = None, time_step: float,
                    title: str = None, to_plot: str,
                    val_range: tuple = (0., 1.)):
    """Plots the opinion state at a specific time frame.

    Depending on ``to_plot``, one of three views is drawn on axis (0, 1):

        - the overall opinion histogram
        - a stacked histogram per group
        - separate histograms for discriminators and non-discriminators,
          together with the overall histogram outline

    If the model mode is 'ageing', the groups are the age intervals given by
    ``age_groups``. In all other cases, the model's groups are used.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        hlpr (PlotHelper): The plot helper
        uni (UniverseGroup): The selected universe data
        age_groups (list, optional): The age interval boundaries used for
            the 'by_group' distribution in the 'ageing' mode. Defaults to
            ``[10, 20, 40, 60, 80]``.
        num_bins (int, optional): Number of histogram bins
        plot_kwargs (dict, optional): Passed on to the bar / hist plot calls
        time_step (float): the time frame to plot, as a fraction in [0, 1]
            of the total length (0: first, 1: last recorded step)
        title (str, optional): Custom plot title, forwarded to setup_figure
        to_plot (str): What to plot: 'overall', 'by_group' (differentiate
            by groups) or 'discriminators'
        val_range (tuple, optional): Value range of the histograms and of
            the x-axis

    Raises:
        TypeError: if the age groups have fewer than two entries (in which
            case no binning is possible)
        ValueError: if to_plot or time_step are invalid
    """
    if age_groups is None:
        age_groups = [10, 20, 40, 60, 80]
    if plot_kwargs is None:
        plot_kwargs = {}

    if len(age_groups) < 2:
        raise TypeError("'Age groups' list needs at least two entries!")
    if to_plot not in ('overall', 'by_group', 'discriminators'):
        raise ValueError(f"Unrecognized argument {to_plot}: must be one of "
                         "'overall', 'by_group' or 'discriminators'")
    if time_step < 0 or time_step > 1:
        raise ValueError("time_step must be in [0, 1]")

    # datasets ................................................................
    model_cfg = uni['cfg']['OpDisc']
    ageing = model_cfg['mode'] == 'ageing'

    opinions = uni['data/OpDisc/nw/opinion']
    time = opinions.coords['time'].data
    # Map the fraction onto an index in [0, number of time steps - 1]
    time_idx = int(time_step * (opinions['time'].size - 1))

    group_labels = uni['data/OpDisc/nw/group_label']
    if ageing:
        groups = np.asarray(group_labels, dtype=int)
        num_groups = len(age_groups) - 1
        group_list = age_groups
    else:
        # NOTE(review): only the first row is used, presumably because group
        # labels are constant in time outside the ageing mode -- confirm
        groups = np.asarray(group_labels[0, :], dtype=int)
        num_groups = model_cfg['number_of_groups']
        group_list = list(range(num_groups))

    # figure setup ............................................................
    # Only forward the title if one was given, so that setup_figure's own
    # default applies otherwise.
    setup_kwargs = {} if title is None else dict(title=title)
    figure, axs = setup_figure(uni['cfg'], plot_name='opinion', **setup_kwargs)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # data analysis and plotting ..............................................
    if to_plot == 'by_group':
        # Histogram of each group's opinions at the selected time step
        hist_counts = np.zeros((num_bins, num_groups))
        data_by_groups = data_by_group(opinions, groups, group_list,
                                       val_range, num_bins, ageing=ageing)
        for k in range(num_groups):
            counts, _ = np.histogram(data_by_groups[k][time_idx],
                                     bins=num_bins, range=val_range)
            hist_counts[:, k] = counts

        # get pretty labels
        if ageing:
            labels = [f"Ages {group_list[g]}-{group_list[g+1]}"
                      for g in range(num_groups)]
            # The last interval is open-ended if nobody is older than its
            # upper bound
            max_age = np.amax(groups)
            if age_groups[-1] >= max_age:
                labels[-1] = f"Ages {group_list[-2]}+"
        else:
            labels = [f"Group {g+1}" for g in group_list]

        X = pd.DataFrame(hist_counts, columns=labels)
        X.plot.bar(stacked=True, ax=hlpr.ax, legend=False, rot=0,
                   **plot_kwargs)
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
                       ncol=num_groups, fontsize='xx-small')
        # Bar positions are bin indices; label them with opinion values
        hlpr.ax.set_xticks(np.linspace(0, num_bins - 1, 11))
        hlpr.ax.set_xticklabels(
            [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])

    elif to_plot == 'overall':
        hlpr.ax.hist(opinions[time_idx][:], bins=num_bins, alpha=1,
                     **plot_kwargs)

    elif to_plot == 'discriminators':
        discriminators = uni['data/OpDisc/nw/discriminators']
        # Broadcast the discriminator flags onto the opinion array; entries
        # with discriminators == 0 are masked out of `ops_disc`
        mask = np.empty(opinions.shape, dtype=bool)
        mask[:, :] = (discriminators == 0)
        ops_disc = np.ma.MaskedArray(opinions, mask)
        ops_nondisc = np.ma.MaskedArray(opinions, ~mask)

        hlpr.ax.hist(ops_disc[time_idx].compressed()[:], bins=num_bins,
                     alpha=0.5, **plot_kwargs, label='disc')
        hlpr.ax.hist(ops_nondisc[time_idx].compressed()[:], bins=num_bins,
                     alpha=0.5, **plot_kwargs, label='non-disc')
        hlpr.ax.hist(opinions[time_idx][:], bins=num_bins, alpha=1,
                     **plot_kwargs, histtype='step')
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
                       ncol=num_groups, fontsize='xx-small')

    hlpr.ax.set_xlim(val_range[0], val_range[1])
    hlpr.ax.text(0.02, 0.97, f'step {time[time_idx]}',
                 transform=hlpr.ax.transAxes, fontsize='xx-small')
def op_groups(dm: DataManager, *, uni: UniverseGroup, hlpr: PlotHelper,
              age_groups: list = None, num_bins: int = 100,
              time_idx: int = None, title: str = None,
              val_range: tuple = (0., 1.)):
    """Plots an animated stacked histogram of the opinion distribution of
    each group.

    In the 'ageing' mode (``OpDisc.mode``), the groups are the age intervals
    given by ``age_groups``; otherwise ``OpDisc.number_of_groups`` groups
    are used.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        uni (UniverseGroup): The selected universe
        hlpr (PlotHelper): The plot helper
        age_groups (list, optional): The age intervals to be plotted in the
            'ageing' mode. Defaults to ``[10, 20, 40, 60, 80]``.
        num_bins (int): Binning of the histogram
        time_idx (int, optional): If given, only plot this single frame
            (e.g. ``-1`` for the last frame, ``0`` for the first one). If
            None (default), all frames are animated.
        title (str, optional): Custom plot title
        val_range (tuple, optional): Value range of the histogram; also used
            for the x tick labels

    Raises:
        TypeError: if the 'age_groups' list does not contain at least two
            entries
    """
    # Resolve the default here to avoid a shared mutable default argument
    if age_groups is None:
        age_groups = [10, 20, 40, 60, 80]
    if len(age_groups) < 2:
        raise TypeError("'age_groups' list must contain at least 2 entries!")

    # 0 is a valid frame index, so compare against None, not truthiness
    single_frame = time_idx is not None

    # figure setup ..............................................................
    figure, axs = setup_figure(uni['cfg'], plot_name='op_groups', title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # datasets ..................................................................
    ageing = uni['cfg']['OpDisc']['mode'] == 'ageing'
    if ageing:
        # all time steps of the group labels are used in the ageing mode
        groups = np.asarray(uni['data/OpDisc/nw/group_label'], dtype=int)
    else:
        # otherwise only the labels of the first time step are used
        groups = np.asarray(uni['data/OpDisc/nw/group_label'][0, :],
                            dtype=int)

    num_groups = (len(age_groups) - 1 if ageing
                  else uni['cfg']['OpDisc']['number_of_groups'])
    group_list = age_groups if ageing else list(range(num_groups))

    opinions = uni['data/OpDisc/nw/opinion']
    time = opinions.coords['time'].data
    time_steps = opinions.coords['time'].size

    # data analysis .............................................................
    # The result is indexed as data_by_groups[group][time index]; presumably
    # the opinions of that group's members at that time (see data_by_group)
    data_by_groups = data_by_group(opinions, groups, group_list, val_range,
                                   num_bins, ageing=ageing)

    # Histogram of each group's opinion distribution at every time step
    to_plot = np.zeros((time_steps, num_bins, num_groups))
    for t in range(time_steps):
        for k in range(num_groups):
            counts, _ = np.histogram(data_by_groups[k][t], bins=num_bins,
                                     range=val_range)
            to_plot[t, :, k] = counts

    # Legend labels
    if ageing:
        labels = [f"Ages {group_list[k]}-{group_list[k+1]}"
                  for k in range(num_groups)]
        if age_groups[-1] >= np.amax(groups):
            # no group label exceeds the last edge: make it open-ended
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {k+1}" for k in group_list]

    # pandas draws the bars at positions 0 ... num_bins-1; label 11 evenly
    # spaced ticks with the corresponding values of val_range
    tick_positions = np.linspace(0, num_bins - 1, 11)
    tick_labels = [f"{v:g}"
                   for v in np.linspace(val_range[0], val_range[1], 11)]

    # plotting ..................................................................
    # pandas bar charts have no 'set height' function, so the axis is cleared
    # and the chart is redrawn entirely for every frame; without clearing,
    # successive frames would be superimposed.
    def draw_frame(t: int):
        """Clears the axis and draws the stacked bar chart of time index t"""
        hlpr.ax.clear()
        frame = pd.DataFrame(to_plot[t, :, :], columns=labels)
        frame.plot.bar(stacked=True, ax=hlpr.ax, legend=False, rot=0)
        hlpr.ax.text(0.02, 0.97, f'step {time[t]}',
                     transform=hlpr.ax.transAxes, fontsize='xx-small')
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
                       ncol=num_groups, fontsize='xx-small')
        hlpr.ax.set_xticks(tick_positions)
        hlpr.ax.set_xticklabels(tick_labels)
        hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
        hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])

    def update_data(stepsize: int = 1):
        """Animation generator: draws a frame, then yields to the writer"""
        if single_frame:
            log.info(f"Plotting distribution at time step {time[time_idx]} ...")
            draw_frame(time_idx)
            yield
            return

        stepsize = max(stepsize, 1)
        if time_steps < stepsize:
            log.warning("Stepsize is greater than number of steps. "
                        "Continue by plotting first and last frame.")
            stepsize = max(time_steps - 1, 1)

        frame_idxs = range(0, time_steps, stepsize)
        log.info(f"Plotting animation with {len(frame_idxs)} frames ...")
        for t in frame_idxs:
            draw_frame(t)
            yield

    hlpr.register_animation_update(update_data)
def draw_graph(*,
               hlpr: PlotHelper,
               data: dict,
               graph_group_tag: str = "graph_group",
               graph: Union[nx.Graph, xr.DataArray] = None,
               graph_creation: dict = None,
               graph_drawing: dict = None,
               graph_animation: dict = None,
               register_property_maps: Sequence[str] = None,
               clear_existing_property_maps: bool = True,
               suptitle_kwargs: dict = None):
    """Draws a graph either from a :py:class:`~utopya.datagroup.GraphGroup`
    or directly from a ``networkx.Graph`` using the
    :py:class:`~utopya.plot_funcs._graph.GraphPlot` class.

    If the graph object is to be created from a graph group, the latter needs
    to be selected via the TransformationDAG. Additional property maps can
    also be made available for plotting, see the ``register_property_maps``
    argument.
    Animations can be created either from a graph group by using the select
    interface in ``graph_animation`` or by passing a DataArray of networkx
    graphs via the ``graph`` argument.

    For more information on how to use the transformation framework, refer
    to the `dantro documentation <https://dantro.readthedocs.io/en/stable/plotting/plot_data_selection.html>`_.

    For more information on how to configure the graph layout, refer to the
    :py:class:`~utopya.plot_funcs._graph.GraphPlot` documentation.

    Args:
        hlpr (PlotHelper): The PlotHelper instance for this plot
        data (dict): Data from TransformationDAG selection
        graph_group_tag (str, optional): The TransformationDAG tag of the
            graph group
        graph (Union[nx.Graph, xr.DataArray], optional): If given, the
            ``data`` and ``graph_creation`` arguments are ignored and this
            graph is drawn directly. If a DataArray of graphs is given, the
            first graph is drawn for a single graph plot. In animation mode
            the (flattened) array represents the animation frames.
        graph_creation (dict, optional): Configuration of the graph creation.
            Passed to ``GraphGroup.create_graph``.
        graph_drawing (dict, optional): Configuration of the graph layout.
            Passed to :py:class:`~utopya.plot_funcs._graph.GraphPlot`. A deep
            copy is used, so the given dict is not modified.
        graph_animation (dict, optional): Animation configuration. The
            following arguments are allowed:

            times (dict, optional):
                *Deprecated*: Equivalent to a sel.time entry.
            sel (dict, optional):
                Select by value. Dictionary with dimension names as keys.
                The values may either be coordinate values or a dict with a
                single ``from_property`` (str) entry which specifies a
                container within the GraphGroup or registered external data
                from which the coordinates are extracted.
            isel (dict, optional):
                Select by index. Coordinate indices keyed by dimension. May
                be given together with ``sel`` if no key appears in both.
            update_positions (bool, optional):
                Whether to update the node positions for each frame by
                recalculating the layout with the parameters specified in
                graph_drawing.positions. If this parameter is not given or
                false, the positions are calculated once initially and then
                fixed.
            update_colormapping (bool, optional):
                Whether to reconfigure the nodes' and edges'
                :py:class:`~utopya.plot_funcs._mpl_helpers.ColorManager`
                for each frame (default=False). If False, the colormapping
                (and the colorbar) is configured with the first frame and
                then fixed.
            skip_empty_frames (bool, optional):
                Whether to skip the frames where the selected graph is
                missing or of a type different than ``nx.Graph``
                (default=False). If False, such frames are empty.
        register_property_maps (Sequence[str], optional): Names of properties
            to be registered in the graph group before the graph creation.
            The property names must be valid TransformationDAG tags, i.e.,
            be available in ``data``. Note that the tags may not conflict
            with any valid path reachable from inside the selected
            ``GraphGroup``.
        clear_existing_property_maps (bool, optional): Whether to clear any
            existing property maps from the selected ``GraphGroup``. This is
            enabled by default to reduce side effects from previous plots.
            Set this to False if you have property maps registered with the
            GraphGroup that you would like to keep.
        suptitle_kwargs (dict, optional): Passed on to the PlotHelper's
            ``set_suptitle`` helper function. Only used in animation mode.
            The ``title`` can be a format string containing a placeholder
            with the dimension name as key for each dimension along which
            selection is done. The format string is updated for each frame
            of the animation. The default is ``<dim-name> = {<dim-name>}``
            for each dimension.

    Raises:
        ValueError: On invalid or non-computed TransformationDAG tags in
            ``register_property_maps`` or invalid graph group tag, or if an
            empty DataArray is passed as ``graph`` for a single graph plot.
        TypeError: In animation mode, if no graph group is used and
            ``graph`` is not an xr.DataArray. This is raised once the
            animation frames are generated.
    """
    # Work on a copy such that the original configuration is not modified.
    # Its sub-configurations (positions, nodes, edges, ...) are forwarded to
    # GraphPlot via keyword expansion.
    graph_drawing = copy.deepcopy(graph_drawing if graph_drawing else {})

    def get_dag_data(tag):
        """Returns the DAG result for ``tag``; raises a ValueError listing
        the available tags if there is none."""
        try:
            return data[tag]

        except KeyError as err:
            _available_tags = ", ".join(data.keys())
            raise ValueError(
                f"No tag '{tag}' found in the data selected by the DAG! Make "
                "sure the tag is named correctly and is selected to be "
                "computed; adjust the 'compute_only' argument if needed."
                "\nThe following tags are available in the DAG results: "
                f"{_available_tags}") from err

    # Prepare graph group and external property data
    graph_group = get_dag_data(graph_group_tag) if graph is None else None

    property_maps = None
    if register_property_maps:
        property_maps = {tag: get_dag_data(tag)
                         for tag in register_property_maps}

    # If not in animation mode, make a single graph plot
    if not hlpr.animation_enabled:
        if graph_group is not None:
            # Create GraphPlot from graph group
            gp = GraphPlot.from_group(
                graph_group=graph_group,
                graph_creation=graph_creation,
                register_property_maps=property_maps,
                clear_existing_property_maps=clear_existing_property_maps,
                fig=hlpr.fig,
                ax=hlpr.ax,
                **graph_drawing,
            )

        else:
            if graph_creation is not None:
                warnings.warn(
                    "Received both a 'graph' argument and a 'graph_creation' "
                    "configuration. The latter will be ignored. To remove "
                    "this warning set graph_creation to None.", UserWarning)

            if isinstance(graph, xr.DataArray):
                if graph.size == 0:
                    raise ValueError(
                        "Received an empty DataArray as 'graph' argument; "
                        "need at least one graph to draw.")

                # Use the first array element for a single graph plot
                g = graph.values.flat[0]

                if graph.size > 1:
                    log.caution(
                        "Received a DataArray of size %d as 'graph' argument, "
                        "performing a single plot using the first entry. "
                        "Animations must be enabled by setting "
                        "animation.enabled to True.", graph.size,
                    )

            else:
                g = graph

            # Create a GraphPlot from a nx.Graph
            gp = GraphPlot(g=g, fig=hlpr.fig, ax=hlpr.ax, **graph_drawing)

        # Make the actual plot via the GraphPlot
        gp.draw()

    # In animation mode, register the animation frame generator
    else:
        # Prepare animation kwargs for the update routine
        graph_animation = graph_animation if graph_animation else {}

        def update():
            """The animation frames generator.

            See :py:meth:`~utopya.plot_funcs.dag.graph.graph_animation_update`.
            """
            if graph_group is None and not isinstance(graph, xr.DataArray):
                raise TypeError(
                    "Failed to create animation due to invalid type of the "
                    "'graph' argument. Required: xr.DataArray (filled "
                    f"with graph objects). Received: {type(graph)}")

            yield from graph_animation_update(
                hlpr=hlpr,
                graphs=graph if isinstance(graph, xr.DataArray) else None,
                graph_group=graph_group,
                graph_creation=graph_creation,
                register_property_maps=property_maps,
                clear_existing_property_maps=clear_existing_property_maps,
                suptitle_kwargs=suptitle_kwargs,
                animation_kwargs=graph_animation,
                **graph_drawing,
            )

        hlpr.register_animation_update(update)
def opinion_animation(dm: DataManager, *, uni: UniverseGroup, hlpr: PlotHelper,
                      time_idx: int, to_plot: str, num_bins: int = 100,
                      val_range: tuple = (0., 1.)):
    """Plots an animated histogram of a user dataset of the ``OpDyn`` model.

    NOTE(review): another function named ``opinion_animation`` is defined
    further below in this module. That later definition replaces this one at
    import time, so this version appears to be unreachable by name.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        uni (UniverseGroup): The selected universe
        hlpr (PlotHelper): The plot helper
        time_idx (int): Time index of the initially drawn histogram. A falsy
            value (None or 0) selects the first time step; negative values
            index from the end.
        to_plot (str): Name of the dataset below ``data/OpDyn/nw_users``
        num_bins (int, optional): Number of histogram bins; a falsy value
            falls back to 100
        val_range (tuple, optional): Value range of the histogram; a falsy
            value falls back to (0, 1)
    """
    opinions = uni['data/OpDyn/nw_users/' + to_plot]
    life_cycle = int(uni['cfg']['OpDyn']['life_cycle'])
    time = opinions['time'].data
    time_steps = time.size

    # Falsy arguments fall back to the first frame / the default binning
    initial_idx = time_idx if time_idx else 0
    bins = num_bins if num_bins else 100
    start, stop = val_range if val_range else (0., 1.)

    # Histogram at the initial time step. Its bin edges are reused for all
    # animation frames, so the bars only need to be resized.
    counts_at_time, bin_edges = np.histogram(opinions[initial_idx, :],
                                             range=(start, stop), bins=bins)
    bin_widths = np.diff(bin_edges)
    bin_pos = bin_edges[:-1] + bin_widths / 2.  # bin midpoints

    bars = hlpr.ax.bar(bin_pos, counts_at_time, width=bin_widths,
                       color='dodgerblue')
    time_text = hlpr.ax.text(0.02, 0.95, '', transform=hlpr.ax.transAxes)

    def update_data(stepsize: int = 1):
        """Animation generator: updates the bar heights, then yields"""
        stepsize = max(stepsize, 1)
        if time_steps < stepsize:
            warnings.warn("Stepsize is greater than number of steps. "
                          "Continue by plotting first and last frame.")
            # Actually plot the first and last frame, as announced
            stepsize = max(time_steps - 1, 1)

        frame_idxs = range(0, time_steps, stepsize)
        log.info("Plotting animation with %d frames ...", len(frame_idxs))

        user_ageing = uni['cfg']['OpDyn']['user_ageing'] == 'on'
        for idx in frame_idxs:
            log.debug("Plotting frame for time index %d ...", idx)

            counts_step, _ = np.histogram(opinions[idx, :], bin_edges)
            for rect, height in zip(bars, counts_step):
                rect.set_height(height)

            if user_ageing:
                time_text.set_text('%.0f years' % int(time[idx] / life_cycle))
            else:
                time_text.set_text('step %.0f' % time[idx])

            hlpr.ax.relim()
            hlpr.ax.autoscale_view(scalex=False)
            yield

    # Register this update method with the helper, which takes care of the rest
    hlpr.register_animation_update(update_data)
def _choose_representatives(final_vals, bin_edges, hist, *, max_reps: int,
                            rep_threshold) -> list:
    """Picks representative vertex indices from a final opinion histogram.

    Bins are visited from the most to the least populated; empty bins are
    skipped. In a first pass, one randomly chosen vertex is taken from each
    of the first ``max_reps`` bins. If fewer than ``max_reps`` vertices were
    chosen, a second pass over at most as many bins as there are free slots
    adds one further vertex for every bin holding more than
    ``rep_threshold`` vertices.

    Args:
        final_vals: 1D array-like of final opinions, indexed by vertex
        bin_edges: Histogram bin edges (as returned by ``np.histogram``)
        hist: Histogram counts per bin (as returned by ``np.histogram``)
        max_reps (int): The maximum total number of representatives
        rep_threshold: Group size (exclusive) above which a second
            representative per group is allowed

    Returns:
        list: The chosen vertex indices (int), in order of selection
    """
    final_vals = np.asarray(final_vals, dtype=float)
    num_bins = len(hist)

    # Assign each vertex to a bin using the same convention as np.histogram:
    # bins are half-open, except the last one which includes its upper edge.
    # Vertices outside the binned range (or NaN) are marked with -1.
    vertex_bins = np.searchsorted(bin_edges, final_vals, side='right') - 1
    vertex_bins[final_vals == bin_edges[-1]] = num_bins - 1
    in_range = (final_vals >= bin_edges[0]) & (final_vals <= bin_edges[-1])
    vertex_bins[~in_range] = -1

    # Bins ordered from most to least populated, empty bins dropped
    sort_idxs = np.argsort(hist)[::-1]
    hist_sorted = hist[sort_idxs]
    nonempty = hist_sorted > 0
    bins_to_rep = sort_idxs[nonempty]
    counts_to_rep = hist_sorted[nonempty]

    # Random visiting order of the vertices (uses numpy's global RNG)
    rand_vertex_idxs = np.arange(len(final_vals))
    np.random.shuffle(rand_vertex_idxs)
    shuffled_bins = vertex_bins[rand_vertex_idxs]

    reps = []
    chosen = set()

    def pick_from(bin_idx):
        """Appends the first not-yet-chosen vertex (random order) of a bin"""
        for v in rand_vertex_idxs[shuffled_bins == bin_idx]:
            v = int(v)
            if v not in chosen:
                reps.append(v)
                chosen.add(v)
                return

    # First representative for the largest groups
    for bin_idx in bins_to_rep[:max_reps]:
        pick_from(bin_idx)

    # If more reps are available, add a second one for larger groups
    free_slots = max_reps - len(reps)
    if free_slots > 0:
        for bin_idx, count in zip(bins_to_rep[:free_slots],
                                  counts_to_rep[:free_slots]):
            if count > rep_threshold:
                pick_from(bin_idx)

    return reps


def opinion_time_series(*, hlpr: PlotHelper, data: dict, bins: int = 100,
                        opinion_range=None, representatives: dict = None,
                        density_kwargs: dict = None, hist_kwargs: dict = None):
    """Plots the temporal development of the opinion density and the final
    opinion distribution.

    Args:
        hlpr (PlotHelper): The Plot Helper
        data (dict): The data from DAG selection. Needs the ``opinion`` tag
            (with ``time`` and ``vertex_idx`` dimensions) and, if
            ``opinion_range`` is not given, the ``opinion_space`` tag.
        bins (int, optional): The number of bins for the histograms
        opinion_range (tuple, optional): range of opinions to be plotted. If
            none provided, it is deduced from the opinion space config: its
            ``interval`` for a 'continuous' space and ``(0, num_opinions)``
            for a 'discrete' space.
        representatives (dict, optional): kwargs for representative
            trajectories. Possible configurations:
            'enabled': Whether to show representative trajectories
                (default: false).
            'max_reps': The maximum total number of chosen representatives
                (default: number of vertices)
            'rep_threshold': Lower threshold for the final group size above
                which a second representative per group is allowed
                (default: 1)
        density_kwargs (dict, optional): Passed to plt.imshow (density plot).
            ``origin`` defaults to 'lower', so low opinions are drawn at the
            bottom, consistent with the y-axis; ``aspect`` defaults to 'auto'.
        hist_kwargs (dict, optional): Passed to plt.hist (final distribution)

    Raises:
        ValueError: if ``opinion_range`` is not given and the opinion space
            type is neither 'continuous' nor 'discrete'
    """
    opinions = data['opinion']
    final_opinions = opinions.isel(time=-1)

    if opinion_range is None:
        cfg_op_space = data['opinion_space']
        space_type = cfg_op_space['type']
        if space_type == 'continuous':
            opinion_range = cfg_op_space['interval']
        elif space_type == 'discrete':
            opinion_range = (0, cfg_op_space['num_opinions'])
        else:
            raise ValueError(
                "Cannot deduce the opinion range for an opinion space of "
                f"type '{space_type}'! Pass `opinion_range` explicitly.")

    num_vertices = opinions.coords['vertex_idx'].size
    time = opinions.coords['time'].data

    # Copy so that the defaults set below do not leak into the caller's dict
    density_kwargs = dict(density_kwargs) if density_kwargs else {}
    # Row 0 of `densities` is the lowest opinion bin; with imshow's default
    # origin='upper' it would be drawn at the top of the opinion axis
    density_kwargs.setdefault('origin', 'lower')
    density_kwargs.setdefault('aspect', 'auto')
    hist_kwargs = hist_kwargs if hist_kwargs else {}
    rep_kwargs = representatives if representatives else {}

    # Plot an opinion density heatmap over time ...............................
    densities = np.empty((bins, len(time)))
    for i in range(len(time)):
        densities[:, i], _ = np.histogram(opinions.isel(time=i),
                                          range=opinion_range, bins=bins)

    density_plot = hlpr.ax.imshow(densities,
                                  extent=[time[0], time[-1], *opinion_range],
                                  **density_kwargs)
    hlpr.provide_defaults("set_limits", y=opinion_range)
    hlpr.ax.get_yaxis().set_minor_locator(ticker.MultipleLocator(0.1))

    # Add colorbar for the density plot
    plt.colorbar(density_plot, ax=hlpr.ax)

    # Plot representative trajectories ........................................
    # The representative vertices are chosen based on the final opinion
    # groups; reps are first picked for the largest opinion groups.
    if rep_kwargs.get('enabled', False):
        hist, bin_edges = np.histogram(final_opinions, range=opinion_range,
                                       bins=bins)
        reps = _choose_representatives(
            final_opinions, bin_edges, hist,
            max_reps=rep_kwargs.get('max_reps', num_vertices),
            rep_threshold=rep_kwargs.get('rep_threshold', 1))

        for r in reps:
            hlpr.ax.plot(time, opinions.isel(vertex_idx=r), lw=0.8)

    # Plot histogram of the final opinions ...................................
    hlpr.select_axis(1, 0)
    hlpr.ax.hist(final_opinions, range=opinion_range, bins=bins,
                 **hist_kwargs)
    hlpr.provide_defaults("set_limits", y=opinion_range)
def histogram(*, data: dict, hlpr: PlotHelper, x: str, hue: str = None,
              frames: str = None, coarsen_by: int = None, align: str = 'edge',
              bin_widths: Union[str, Sequence[float]] = None,
              suptitle_kwargs: dict = None, show_histogram_info: bool = True,
              **bar_kwargs):
    """Shows a distribution as a stacked bar plot, allowing animation.

    Expects as DAG result ``counts`` an xr.DataArray of one, two, or three
    dimensions. Depending on the ``hue`` and ``frames`` arguments, this will
    be represented as a stacked barplot and as an animation, respectively.

    Args:
        data (dict): The DAG results
        hlpr (PlotHelper): The PlotHelper
        x (str): The name of the dimension that represents the position of
            the histogram bins. How each bar is placed relative to its
            coordinate is determined by ``align``.
        hue (str, optional): Which dimension to represent by stacking bars of
            different hue on top of each other
        frames (str, optional): Which dimension to represent by animation
        coarsen_by (int, optional): Not supported by this function; if a
            value is given, a warning is emitted and the data is plotted
            without coarsening.
        align (str, optional): Where to align bins; passed to ``bar``. By
            default, uses ``edge`` for alignment, as this is more exact for
            histograms.
        bin_widths (Union[str, Sequence[float]], optional): If not given,
            will use the difference between the ``x`` coordinates as bin
            widths, padding on the right side using the last value.
            If a string, assume that it is a DAG result and retrieve it from
            ``data``. Otherwise, use it directly for the ``width`` argument
            of ``plt.bar``, i.e. assume it's a scalar or a sequence of bin
            widths.
        suptitle_kwargs (dict, optional): Passed on to the ``set_suptitle``
            helper in animation mode. Its ``title`` entry is a format string
            with the placeholders ``dim`` (the name of the ``frames``
            dimension) and ``value`` (the current coordinate value); it
            defaults to ``"{dim:} = {value}"``. The given dict is not
            modified.
        show_histogram_info (bool, optional): Whether to draw a box with
            information about the histogram.
        **bar_kwargs: Passed on ``hlpr.ax.bar`` invocation

    Returns:
        None

    Raises:
        ValueError: Bad dimensionality or missing ``bin_widths`` DAG result
    """
    def stacked_bar_plot(ax, dists: xr.DataArray, bin_widths):
        """Given a 2D xr.DataArray, plots a stacked barplot onto ``ax``"""
        bottom = None  # to keep track of the bottom edges for stacking

        # Create the iterator
        if hue:
            # groupby sorts by label; restore the original coordinate order
            hues = [c.item() for c in dists.coords[hue]]
            dist_iter = sorted(dists.groupby(hue),
                               key=lambda item: hues.index(item[0]))
        else:
            dist_iter = [(None, dists)]

        # Create the plots for each hue value
        for label, dist in dist_iter:
            dist = dist.squeeze(drop=True)
            ax.bar(dist.coords[x], dist, align=align, width=bin_widths,
                   bottom=bottom, label=label, **bar_kwargs)
            bottom = dist.data if bottom is None else bottom + dist.data

        # Annotate it
        if not show_histogram_info:
            return

        total_sum = dists.sum().item()
        ax.text(1, 1,
                (f"$N_{{bins}} = {dists.coords[x].size}$, "
                 fr"$\Sigma_{{{x}}} = {total_sum:.4g}$"),
                transform=ax.transAxes,
                verticalalignment='bottom', horizontalalignment='right',
                fontdict=dict(fontsize="smaller"),
                bbox=dict(facecolor="white", linewidth=.5, pad=2))

    if coarsen_by is not None:
        warnings.warn("The `coarsen_by` argument is not supported by this "
                      "plot function and is ignored.", UserWarning)

    # Retrieve the data
    dists = data['counts']

    # Check expected dimensions
    expected_ndim = 1 + bool(hue) + bool(frames)
    if dists.ndim != expected_ndim:
        raise ValueError(f"With `hue: {hue}` and `frames: {frames}`, expected "
                         f"{expected_ndim}-dimensional data, but got:\n"
                         f"{dists}")

    # Calculate bin widths
    if bin_widths is None:
        bin_widths = dists.coords[x].diff(x)
        bin_widths = bin_widths.pad({x: (0, 1)}, mode='edge')

    elif isinstance(bin_widths, str):
        log.remark("Using DAG result '%s' for bin widths ...", bin_widths)
        try:
            bin_widths = data[bin_widths]
        except KeyError as err:
            raise ValueError(f"No DAG result '{bin_widths}' available for bin "
                             "widths. Make sure `compute_only` is set such "
                             "that the result will be computed.") from err

    # Allow dynamically plotting without animation
    if not frames:
        hlpr.disable_animation()
        stacked_bar_plot(hlpr.ax, dists, bin_widths)
        return
    # else: want an animation. Everything below here is only for that case.
    hlpr.enable_animation()

    # Determine the maximum, such that the scale is always the same
    max_counts = float(dists.sum(hue).max() if hue else dists.max())

    # Prepare some parameters for the update routine; work on a copy such
    # that the caller's dict is not modified
    suptitle_kwargs = dict(suptitle_kwargs) if suptitle_kwargs else {}
    suptitle_kwargs.setdefault('title', "{dim:} = {value}")

    # Define an animation update function. All frames are plotted therein.
    # There is no need to plot the first frame _outside_ the update function,
    # because it would be discarded anyway.
    def update():
        """The animation update function: a python generator"""
        log.note("Commencing histogram animation for %d time steps ...",
                 len(dists.coords[frames]))

        for t, _dists in dists.groupby(frames):
            # Plot a frame onto an empty canvas
            hlpr.ax.clear()
            stacked_bar_plot(hlpr.ax, _dists, bin_widths)

            # Set the y-limits
            hlpr.invoke_helper('set_limits', y=[0, max_counts * 1.05])

            # Apply the suptitle format string, then invoke the helper
            st_kwargs = copy.deepcopy(suptitle_kwargs)
            st_kwargs['title'] = st_kwargs['title'].format(dim=frames,
                                                           value=t)
            hlpr.invoke_helper('set_suptitle', **st_kwargs)

            # Done with this frame. Let the writer grab it.
            yield

    # Register the animation update with the helper
    hlpr.register_animation_update(update, invoke_helpers_before_grab=True)
def age_groups(dm: DataManager, *, uni: UniverseGroup, hlpr: PlotHelper,
               time_idx: int, to_plot: str, num_bins: int = 50,
               val_range: tuple = (0., 1.), ages: list):
    """Plots an (animated) stacked histogram of a user dataset of the
    ``OpDyn`` model, with the users split into age groups.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        uni (UniverseGroup): The selected universe
        hlpr (PlotHelper): The plot helper
        time_idx (int): Time index of the initially drawn frame. A falsy
            value (None or 0) selects the first time step; negative values
            index from the end.
        to_plot (str): Name of the dataset below ``data/OpDyn/nw_users``
        num_bins (int, optional): Number of histogram bins spanning
            ``val_range``; a falsy value falls back to 50
        val_range (tuple, optional): Value range of the histogram; a falsy
            value falls back to (0, 1). Values outside are not counted.
        ages (list): Edges of the age groups, applied to the ``age_u``
            dataset. Users whose age lies outside these edges are not
            counted. If the last edge is above 120, the last group is
            labelled as open-ended.
    """
    data = uni['data/OpDyn/nw_users/' + to_plot]
    user_ages = uni['data/OpDyn/nw_users/age_u']
    life_cycle = int(uni['cfg']['OpDyn']['life_cycle'])
    time_steps = data.coords['time'].size
    time = data.coords['time'].data

    num_age_groups = len(ages) - 1
    labels = [f"{ages[i]}–{ages[i + 1]}" for i in range(num_age_groups)]
    if ages[-1] > 120:
        labels[-1] = f"{ages[-2]}+"

    # Falsy arguments fall back to the first frame / the default binning
    initial_idx = time_idx if time_idx else 0
    bins = num_bins if num_bins else 50
    start, stop = val_range if val_range else (0., 1.)

    # Group the data by value bin and age bin. Explicit value edges are used:
    # pd.cut with an integer number of bins would span the min-max range of
    # each time step's data instead of val_range, which the labels assume.
    value_edges = np.linspace(start, stop, bins + 1)
    data_by_age = np.zeros((time_steps, bins, num_age_groups))
    for t in range(time_steps):
        value_bins = pd.cut(np.asarray(data[t, :]), value_edges,
                            labels=False, include_lowest=True)
        age_bins = pd.cut(np.asarray(user_ages[t, :]), ages,
                          labels=False, include_lowest=True)
        # pd.cut yields NaN for entries outside the edges; skip those
        valid = ~(np.isnan(value_bins) | np.isnan(age_bins))
        np.add.at(data_by_age[t],
                  (value_bins[valid].astype(int), age_bins[valid].astype(int)),
                  1)

    bin_index = np.around(np.linspace(start, stop, bins), 3)
    frames = [pd.DataFrame(data_by_age[i, :, :], index=bin_index,
                           columns=labels)
              for i in range(time_steps)]
    log.info("Created %d dataframes for plotting", len(frames))

    # colors to use; subsampled so that the age groups spread over the list.
    # The same palette is used for every frame.
    colors = [
        'gold', 'orange', 'darkorange', 'orangered', 'firebrick', 'darkred',
        'indigo', 'navy', 'royalblue', 'cornflowerblue', 'slategray', 'peru',
        'saddlebrown', 'black'
    ]
    palette = colors[0:-1:max(1, len(colors) // len(ages))]

    frames[initial_idx].plot.bar(stacked=True, ax=hlpr.ax, color=palette,
                                 legend=False)

    # Tick labels map the bar positions 0 ... bins-1 onto val_range
    tick_positions = np.linspace(0, bins - 1, 11)
    tick_labels = [f"{v:g}" for v in np.linspace(start, stop, 11)]

    def update_data(stepsize: int = 1):
        """Animation generator: redraws the stacked bar chart per frame"""
        stepsize = max(stepsize, 1)
        if time_steps < stepsize:
            warnings.warn("Stepsize is greater than number of steps. "
                          "Continue by plotting first and last frame.")
            stepsize = max(time_steps - 1, 1)

        frame_idxs = range(0, time_steps, stepsize)
        log.info("Plotting animation with %d frames ...", len(frame_idxs))

        for t in frame_idxs:
            # pandas bar charts cannot be updated in place: clear and redraw
            hlpr.ax.clear()
            frames[t].plot.bar(stacked=True, ax=hlpr.ax, legend=False,
                               color=palette, rot=0)
            hlpr.ax.text(0.02, 0.95,
                         '%.0f years' % int(time[t] / life_cycle),
                         transform=hlpr.ax.transAxes)
            hlpr.ax.legend(loc='upper right')
            hlpr.ax.set_xticks(tick_positions)
            hlpr.ax.set_xticklabels(tick_labels)
            hlpr.ax.set_title(hlpr.axis_cfg['set_title']['title'])
            hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
            hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])
            yield

    # Register this update method with the helper, which takes care of the rest
    hlpr.register_animation_update(update_data)
def opinion_animation(dm: DataManager, *, uni: UniverseGroup, hlpr: PlotHelper,
                      num_bins: int = 100, time_idx: int, title: str = None,
                      val_range: tuple = (0., 1.)):
    """Plots an animated histogram of the opinion distribution over time.

    If the model mode is 'conflict_undir', the opinion distributions of the
    discriminators and non-discriminators are also shown on separate axes.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        uni (UniverseGroup): The selected universe
        hlpr (PlotHelper): The plot helper
        num_bins (int): Binning of the histogram
        time_idx (int, optional): If not None, only plot this single frame
            (e.g. ``-1`` for the last frame, ``0`` for the first one). If
            None, all frames are animated; the initially drawn histogram then
            aggregates the opinions of all time steps.
        title (str, optional): Custom plot title
        val_range (tuple, optional): Value range of the histogram
    """
    # 0 is a valid frame index, so compare against None, not truthiness
    single_frame = time_idx is not None

    # figure layout .............................................................
    # 'conflict_undir' has a non-standard plot layout with two additional axes
    # for the discriminators' and non-discriminators' opinion distributions
    disc_plot = uni['cfg']['OpDisc']['mode'] == 'conflict_undir'
    if disc_plot:
        figure, axs = setup_figure(uni['cfg'], plot_name='opinion_anim',
                                   title=title, figsize=(10, 15),
                                   ncols=2, nrows=3,
                                   height_ratios=[1, 6, 3],
                                   width_ratios=[1, 1],
                                   gridspec=[(0, slice(0, 2)),
                                             (1, slice(0, 2)),
                                             (2, 0), (2, 1)])
    else:
        figure, axs = setup_figure(uni['cfg'], plot_name='opinion_anim',
                                   title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)

    # datasets ..................................................................
    opinions = uni['data/OpDisc/nw/opinion']
    time = opinions['time'].data
    time_steps = time.size

    # Data to plot plus axis-specific info. 'all' must stay the first entry:
    # the other axes reuse its y-limits when drawing a frame.
    to_plot = {
        'all': {'data': opinions, 'axs_idx': 1, 'text': '',
                'color': 'dodgerblue'},
    }

    # data analysis .............................................................
    if disc_plot:
        # Opinions of only the discriminators and non-discriminators
        discriminators = uni['data/OpDisc/nw/discriminators']
        p_disc = uni['cfg']['OpDisc']['discriminators']
        # True where the `discriminators` entry is 0
        mask = np.empty(opinions.shape, dtype=bool)
        mask[:, :] = (discriminators == 0)
        to_plot['disc'] = {
            'data': np.ma.MaskedArray(opinions, mask),
            'axs_idx': 2,
            'color': 'teal',
            'text': f'discriminators ($p_d$={p_disc})',
        }
        to_plot['nondisc'] = {
            'data': np.ma.MaskedArray(opinions, ~mask),
            'axs_idx': 3,
            'color': 'mediumaquamarine',
            'text': f'non-discriminators ($1-p_d$={1 - p_disc:g})',
        }

    # histograms ................................................................
    def get_hist_data(input_data):
        """Returns counts, bin edges and bin midpoints of the histogram"""
        counts, bin_edges = np.histogram(input_data, range=val_range,
                                         bins=num_bins)
        bin_pos = bin_edges[:-1] + (np.diff(bin_edges) / 2.)
        return counts, bin_edges, bin_pos

    def frame_data(key, t):
        """Returns the values of dataset `key` at time (index/indices) t"""
        values = to_plot[key]['data'][t, :]
        if isinstance(values, np.ma.MaskedArray):
            # np.histogram ignores masks; drop the masked entries explicitly
            values = values.compressed()
        return values

    # Initial histograms, axis ranges, and axis descriptions in the upper
    # left corners
    initial_t = time_idx if single_frame else range(time_steps)
    bars = {}
    texts = {}
    for key, spec in to_plot.items():
        counts, bin_edges, pos = get_hist_data(frame_data(key, initial_t))
        hlpr.select_axis(0, spec['axs_idx'])
        hlpr.ax.set_xlim(val_range)
        bars[key] = hlpr.ax.bar(pos, counts, width=np.diff(bin_edges),
                                color=spec['color'])
        texts[key] = hlpr.ax.text(0.02, 0.93, spec['text'],
                                  transform=hlpr.ax.transAxes)

    # animate ...................................................................
    def draw_frame(t):
        """Sets the bar heights of all axes to the histograms at index t"""
        y_lim = None
        for key, spec in to_plot.items():
            hlpr.select_axis(0, spec['axs_idx'])
            counts_at_t, _, _ = get_hist_data(frame_data(key, t))
            for rect, height in zip(bars[key], counts_at_t):
                rect.set_height(height)

            if key == 'all':
                texts[key].set_text(f'step {time[t]}')
                hlpr.ax.relim()
                hlpr.ax.autoscale_view(scalex=False)
                y_lim = hlpr.ax.get_ylim()
            else:
                # rescale ylim to the same value for all plots
                hlpr.ax.set_ylim(y_lim)

    def update_data(stepsize: int = 1):
        """Animation generator: updates the histograms, then yields"""
        if single_frame:
            log.info(
                f"Plotting distribution at time step {time[time_idx]} ...")
            draw_frame(time_idx)
            yield
            return

        stepsize = max(stepsize, 1)
        if time_steps < stepsize:
            log.warning("Stepsize is greater than number of steps. Continue by "
                        "plotting first and last frame.")
            stepsize = max(time_steps - 1, 1)

        frame_idxs = range(0, time_steps, stepsize)
        log.info(f"Plotting animation with {len(frame_idxs)} frames ...")
        for t in frame_idxs:
            draw_frame(t)
            yield

    hlpr.register_animation_update(update_data)
def group_avg(dm: DataManager, *, uni: UniverseGroup, hlpr: PlotHelper,
              age_groups: list = None, num_bins: int = 100,
              title: str = None, val_range: tuple = (0, 1)):
    """Plots the average opinion of each group over time.

    Time runs downwards along the y-axis and the mean opinion along the
    x-axis; the standard deviation is shown as a horizontal error bar. Time
    steps at which a group has no members are left blank.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        uni (UniverseGroup): The selected universe
        hlpr (PlotHelper): The plot helper
        age_groups (list, optional): the age binning to be plotted for the
            'ageing' model. Defaults to ``[10, 20, 40, 60, 80]``.
        num_bins (int, optional): binning size for the histogram
        title (str, optional): custom title for the plot
        val_range (tuple, optional): binning range for the histogram; also
            used as the x-axis range

    Raises:
        TypeError: if the 'age_groups' list does not contain at least two
            entries
    """
    # Resolve the default here to avoid a shared mutable default argument
    if age_groups is None:
        age_groups = [10, 20, 40, 60, 80]
    if len(age_groups) < 2:
        raise TypeError("'age_groups' list must contain at least 2 entries!")

    # figure setup ..............................................................
    figure, axs = setup_figure(uni['cfg'], plot_name='group_avg', title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # get data ..................................................................
    ageing = uni['cfg']['OpDisc']['mode'] == 'ageing'
    opinions = uni['data/OpDisc/nw/opinion']
    if ageing:
        # all time steps of the group labels are used in the ageing mode
        groups = np.asarray(uni['data/OpDisc/nw/group_label'], dtype=int)
    else:
        # otherwise only the labels of the first time step are used
        groups = np.asarray(uni['data/OpDisc/nw/group_label'][0, :],
                            dtype=int)

    num_groups = (len(age_groups) - 1 if ageing
                  else uni['cfg']['OpDisc']['number_of_groups'])
    group_list = age_groups if ageing else list(range(num_groups))

    time_steps = opinions['time'].size
    time = np.asarray(opinions['time'].data)

    hlpr.ax.set_xlim(val_range[0], val_range[1])
    # inverted y-axis: the first time step is at the top
    hlpr.ax.set_ylim(time[-1], time[0])

    # data analysis .............................................................
    # Mean opinion and std of each group. NaN marks time steps at which a
    # group is empty (e.g. an age group not present for a period of time), so
    # these are drawn as gaps instead of a spurious opinion of 0.
    means = np.full((time_steps, num_groups), np.nan)
    stddevs = np.full_like(means, np.nan)
    data_by_groups = data_by_group(opinions, groups, group_list, val_range,
                                   num_bins, ageing=ageing)
    for k in range(num_groups):
        for t in range(time_steps):
            group_ops = data_by_groups[k][t]
            if len(group_ops) == 0:
                continue
            means[t, k] = np.mean(group_ops)
            stddevs[t, k] = np.std(group_ops)

    # plotting ..................................................................
    # Legend labels
    if ageing:
        labels = [f"Ages {group_list[k]}-{group_list[k+1]}"
                  for k in range(num_groups)]
        if age_groups[-1] >= np.amax(groups):
            # no group label exceeds the last edge: make it open-ended
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {k+1}" for k in range(num_groups)]

    # Plot mean opinion with std as errorbar
    for i in range(num_groups):
        hlpr.ax.errorbar(means[:, i], time, xerr=stddevs[:, i], lw=2,
                         alpha=0.8, elinewidth=25. / time_steps,
                         label=labels[i], capsize=0, capthick=1)

    hlpr.ax.set_xticks(np.linspace(val_range[0], val_range[1], 11),
                       minor=False)
    hlpr.ax.xaxis.grid(True, which='major', lw=0.1)

    hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
                   ncol=num_groups + 1, fontsize='xx-small')
def sweep1d(dm: DataManager,
            *,
            hlpr: PlotHelper,
            mv_data,
            age_groups: list = [10, 20, 40, 60, 80],
            dim: str = None,
            plot_by_groups: bool = True,
            plot_kwargs: dict = {},
            to_plot: str):
    """Plots statistical measures of the final distribution over a sweep
    parameter.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimension `dim` in which the sweep was performed. For a
          single sweep dimension, the sweep parameter is automatically
          deduced
        - use the `select/subspace` key to set values for all other
          parameters

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the plot helper
        mv_data (xr.Dataset): the extracted multidimensional dataset
        age_groups (list): The age intervals to be plotted in the final_ax
            distribution plot for the 'ageing' mode.
        dim (str, optional): the parameter dimension of the diagram. If none
            is provided, an attempt will be made to automatically deduce the
            sweep parameter
        plot_by_groups (bool, optional): only for 'means' and 'stddevs':
            plot one line per group (default) or a single line averaged over
            all groups
        plot_kwargs (dict): kwargs passed to the errorbar / plot function
        to_plot (str): the data to be plotted. Can be:
            - absolute_area: the area (unsigned) of the mean minus 0.5 of
              the entire population. is always positive.
            - area: the area (signed) of the mean minus 0.5 of the entire
              population. can be positive or negative.
            - area_comp: computes both the absolute and the signed area
              (with errors) and writes them to the CSV file
              ``<plot output path without extension>_N_<num_groups>_phom_<homophily_parameter>.csv``;
              nothing is plotted.
            - area_diff: the difference of the absolute and the signed area
            - means: the mean of each group at the final time step, with an
              error
            - stddevs: the stddev of each group at the final time step, with
              an error

    Raises:
        ValueError: if an unknown 'to_plot' argument is passed
        ValueError: if the dimension cannot be deduced
        ValueError: if the dimension is not available
    """
    import os

    if to_plot not in [
            'absolute_area', 'area', 'area_comp', 'area_diff', 'means',
            'stddevs'
    ]:
        raise ValueError(f"Unknown statistical variable {to_plot}!")

    if dim is None:
        dim = deduce_sweep_dimension(mv_data)
    elif dim not in mv_data.dims:
        raise ValueError(
            f"Dimension '{dim}' not available in multiverse data."
            f" Available: {mv_data.coords}")

    # get datasets and cfg ....................................................
    keys, cfg = get_keys_cfg(mv_data,
                             dm['multiverse'].pspace.default,
                             keys_to_ignore=[dim, 'time'])
    mode = cfg['OpDisc']['mode']
    ageing = mode == 'ageing'
    if ageing:
        num_groups = len(age_groups) - 1
        group_list = age_groups
    else:
        num_groups = cfg['OpDisc']['number_of_groups']
        group_list = list(range(num_groups))

    # pretty labels
    if ageing:
        labels = [
            f"Ages {group_list[i]}-{group_list[i + 1]}"
            for i in range(num_groups)
        ]
        max_age = np.amax(mv_data[keys]['group_label'])
        if age_groups[-1] >= max_age:
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {g + 1}" for g in group_list]

    # figure setup ............................................................
    figure, axs = setup_figure(cfg, plot_name=to_plot, dim1=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # data analysis and plotting ..............................................
    log.info("Commencing data analytics ...")
    x_vals = mv_data.coords[dim].data

    if to_plot == 'absolute_area':
        data_to_plot, err = get_absolute_area(mv_data, keys, dim=dim)
        hlpr.ax.errorbar(x_vals, data_to_plot, yerr=err, **plot_kwargs)

    elif to_plot == 'area':
        data_to_plot, err = get_area(mv_data, keys, dim=dim)
        hlpr.ax.errorbar(x_vals, data_to_plot, yerr=err, **plot_kwargs)

    elif to_plot == 'area_comp':
        abs_area, abs_err = get_absolute_area(mv_data, keys, dim=dim)
        area, area_err = get_area(mv_data, keys, dim=dim)

        # write data values for further evaluation
        res = np.vstack((abs_area, abs_err, area, area_err))
        idx = ['abs_area', 'abs_area_err', 'area', 'area_err']
        df = pd.DataFrame(res, idx, x_vals)
        phom = cfg['OpDisc']['homophily_parameter']
        # Derive the CSV path from the plot output path by swapping the
        # extension, so the CSV can never overwrite the plot output itself,
        # whatever the output file is called.
        base, _ = os.path.splitext(hlpr.out_path)
        df.to_csv(f"{base}_N_{num_groups}_phom_{phom}.csv")
        log.info("Finished writing files")

    elif to_plot == 'area_diff':
        abs_area, _ = get_absolute_area(mv_data, keys, dim=dim)
        area, _ = get_area(mv_data, keys, dim=dim)
        hlpr.ax.plot(x_vals, np.subtract(abs_area, area), **plot_kwargs)

    elif to_plot in ('means', 'stddevs'):
        keys['time'] = -1
        data_to_plot, err = means_stddevs_by_group(mv_data,
                                                   group_list,
                                                   dim,
                                                   keys,
                                                   mode=mode,
                                                   ageing=ageing,
                                                   num_groups=num_groups,
                                                   which=to_plot,
                                                   time_step=-1)
        if plot_by_groups:
            for i in range(len(err)):
                hlpr.ax.errorbar(x_vals,
                                 data_to_plot[i],
                                 yerr=err[i],
                                 label=labels[i],
                                 **plot_kwargs)
            hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                           loc='lower right',
                           ncol=num_groups + 1,
                           fontsize='xx-small')
        else:
            # a single line: average over all groups
            hlpr.ax.errorbar(x_vals,
                             np.mean(data_to_plot, axis=0),
                             yerr=np.mean(err, axis=0),
                             **plot_kwargs)

    log.info("Data analysis complete.")

    # set axis labels etc .....................................................
    hlpr.ax.set_xlabel(convert_to_label(dim))
    hlpr.ax.set_ylabel(convert_to_label(to_plot))
def group_avg_anim(dm: DataManager,
                   *,
                   hlpr: PlotHelper,
                   mv_data,
                   dim: str,
                   age_groups: list = [10, 20, 40, 60, 80],
                   num_bins: int = 100,
                   title: str = None,
                   val_range: tuple = (0, 1),
                   write: bool = False):
    """Plots an animation of the average opinion evolution by group as a
    function of the sweep parameter: one frame per value of ``dim``.

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the plot helper
        mv_data (xr.Dataset): the extracted multidimensional dataset
        dim (str): the parameter dimension of the diagram
        age_groups (list): the age binning used for the 'ageing' mode; each
            pair of consecutive entries defines one group
        num_bins (int, optional): binning size for the histogram
        title (str, optional): custom plot title
        val_range (tuple, optional): binning range for the histogram
        write (bool, optional): only for ``dim == 'homophily_parameter'``: if
            true, the normalized difference between the last and the first
            group mean is written for each time step (rows) and each R_p
            factor (columns) to ``widths_<mode>.csv`` in the directory of the
            plot output (for the purposes of my thesis only)

    Raises:
        ValueError: if the dimension is not present in the multiverse data
        ValueError: if, besides 'vertex', 'time' and ``dim``, another
            dimension with more than one entry is present
        ValueError: if the sweep parameter is 'seed' (to do)
    """
    import os

    if dim not in mv_data.dims:
        raise ValueError(f"Dimension '{dim}' not available in multiverse data."
                         f" Available: {mv_data.coords}")
    if len(mv_data.dims) > 3:
        for key in mv_data.dims.keys():
            if key not in ['vertex', 'time', dim] and mv_data.dims[key] > 1:
                raise ValueError(
                    f"Too many dimensions! Use 'subspace' to "
                    f"select specific values for keys other than {dim}!")
    if dim == 'seed':
        raise ValueError("'seed' sweeps currently not supported.")

    # datasets ................................................................
    # manually modify any subspace entries in the cfg
    cfg = dm['multiverse'].pspace.default
    # replace the mode if it is a subspace selection
    if 'mode' in mv_data.coords:
        cfg['OpDisc']['mode'] = str(mv_data.coords['mode'].data[0])
    mode = cfg['OpDisc']['mode']
    ageing = mode == 'ageing'

    # select index 0 along every dimension except 'vertex'
    keys = {d: 0 for d in mv_data.dims}
    keys.pop('vertex')
    if ageing:
        # group labels change over time: keep the full time axis
        keys.pop('time')
        groups = np.asarray(mv_data['group_label'][keys], dtype=int)
    else:
        # group labels do not change over time: first time step suffices
        groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        keys.pop('time')

    # replace any other subspace selection parameters
    for key in keys:
        if key != dim:
            cfg['OpDisc'][key] = mv_data.coords[key].data[0]

    if ageing:
        num_groups = len(age_groups) - 1
        group_list = age_groups
    else:
        num_groups = cfg['OpDisc']['number_of_groups']
        group_list = list(range(num_groups))
    num_vertices = cfg['OpDisc']['nw']['num_vertices']
    time_steps = mv_data['time'].size
    time = mv_data['time'].data
    num_params = len(mv_data.coords[dim])

    # figure layout ...........................................................
    figure, axs = setup_figure(cfg,
                               plot_name='group_avgs_anim',
                               title=title,
                               dim=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # data analysis ...........................................................
    # mean opinion and std of each group (via data_by_group) for every value
    # of the sweep parameter; groups without members at a time step stay NaN
    means = np.full((num_params, time_steps, num_groups), np.nan)
    stddevs = np.full_like(means, np.nan)
    for param in range(num_params):
        keys[dim] = param
        data = np.asarray(mv_data[keys]['opinion'])
        data_by_groups = data_by_group(data,
                                       groups,
                                       group_list,
                                       val_range,
                                       num_bins,
                                       ageing=ageing)
        for k in range(num_groups):
            for t in range(time_steps):
                members = data_by_groups[k][t]
                if len(members) == 0:
                    continue
                means[param, t, k] = np.mean(members)
                stddevs[param, t, k] = np.std(members)

    log.info("Finished data analysis.")

    # plotting ................................................................
    # pretty labels
    if ageing:
        labels = [
            f"Ages {group_list[i]}-{group_list[i + 1]}"
            for i in range(num_groups)
        ]
        if age_groups[-1] >= np.amax(groups):
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {i + 1}" for i in range(num_groups)]

    # R_p factors of the sweep values; only shown for p_hom sweeps
    P, Q = R_p_factors(num_vertices, num_groups)
    R_p_fs = R_p(mv_data.coords[dim], num_groups, P, Q, mode)

    def update_data(stepsize: int = 1):
        """Draws one frame per sweep value and yields after each frame."""
        log.info(f"Plotting animation with {num_params} frames ...")
        for param in range(num_params):
            hlpr.ax.clear()
            hlpr.ax.set_xlim(0, 1)
            hlpr.ax.set_ylim(time[-1], 0)
            hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
            hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])

            if dim == 'homophily_parameter':
                sw_text = (
                    f"$R_p=${R_p_fs[param]:.3f} ({convert_to_label(dim)} = {mv_data[dim][param].data})"
                )
            else:
                sw_text = f"{convert_to_label(dim)}={mv_data[dim][param].data}"
            hlpr.ax.text(0,
                         1.02,
                         sw_text,
                         fontsize='x-small',
                         transform=hlpr.ax.transAxes)

            for i in range(num_groups):
                hlpr.ax.errorbar(means[param, :, i],
                                 time,
                                 xerr=stddevs[param, :, i],
                                 lw=2,
                                 alpha=1,
                                 elinewidth=0.1,
                                 label=labels[i],
                                 capsize=1,
                                 capthick=0.3)
            hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                           loc='lower right',
                           ncol=num_groups,
                           fontsize='xx-small')
            yield

    hlpr.register_animation_update(update_data)

    # write data values for further evaluation ...............................
    # This is for the purpose of my thesis only and will be removed upon
    # completion.
    if write and dim == 'homophily_parameter':
        w_0 = np.min(means[:, -1, -1] - means[:, -1, 0])
        w_max = np.max(means[-1, :, -1] - means[-1, :, 0])
        # rows: time steps, columns: sweep values
        widths = ((means[:, :, -1] - means[:, :, 0] - w_0) /
                  (w_max - w_0)).T
        df = pd.DataFrame(widths, time, R_p_fs)
        # Write next to the plot output instead of string-replacing its file
        # name, so the CSV can never overwrite the animation itself.
        csv_path = os.path.join(os.path.dirname(hlpr.out_path),
                                f'widths_{mode}.csv')
        df.to_csv(csv_path)
        log.info("Finished writing files")
def sweep2d(dm: DataManager,
            *,
            hlpr: PlotHelper,
            mv_data,
            age_groups: list = [10, 20, 40, 60, 80],
            x: str,
            y: str,
            plot_kwargs: dict = {},
            stacked: bool = False,
            to_plot: str):
    """For multiverse runs, this produces a two dimensional plot showing
    specified values.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimensions `x` and `y` in which the sweep was performed
        - use the `select/subspace` key to set values for all other
          parameters

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the plot helper
        mv_data (xr.Dataset): the extracted multidimensional dataset
        age_groups (list): The age intervals to be plotted in the final_ax
            distribution plot for the 'ageing' mode.
        x (str): the first parameter dimension of the diagram.
        y (str): the second parameter dimension of the diagram.
        plot_kwargs (dict, optional): kwargs passed to the plot function
            (``pcolor`` for the heatmap, ``plot`` if stacked)
        stacked (bool): whether to plot a 2d heatmap (False) or one line per
            `y` value over the `x` values (True)
        to_plot (str): the data to be plotted. Can be:
            - extreme_means_diff: the difference between the means of the
              outer groups at the final time step. not compatible with a
              seed sweep.
            - avg_of_means_diff_to_05: the average of the absolute distance
              of each group to 0.5 at the final time step
            - avg_of_stddevs: the average of the stddevs of the groups at the
              final time step
            - absolute_area: the area (unsigned) of the mean minus 0.5 of
              the entire population. is always positive.
            - area: the area (signed) of the mean minus 0.5 of the entire
              population. can be positive or negative.
            - area_diff: the difference of the absolute area and the signed
              area under the means curve.

    Raises:
        ValueError: if an unknown 'to_plot' argument is passed
        ValueError: if a sweep over 'seed' is performed and to_plot is
            'extreme_means_diff'.
    """
    if to_plot not in [
            'extreme_means_diff', 'avg_of_means_diff_to_05', 'avg_of_stddevs',
            'absolute_area', 'area', 'area_diff'
    ]:
        raise ValueError(f"Unknown statistical variable {to_plot}!")

    if ((to_plot == 'extreme_means_diff') and ('seed' in mv_data.coords)
            and (len(mv_data.coords['seed'].data) > 1)):
        raise ValueError(
            "Plotting does not support 'seed' at this time. Select"
            " a single value using the 'subspace' key")

    # get datasets and cfg ....................................................
    keys, cfg = get_keys_cfg(mv_data,
                             dm['multiverse'].pspace.default,
                             keys_to_ignore=[x, y])
    mode = cfg['OpDisc']['mode']
    ageing = mode == 'ageing'
    if ageing:
        num_groups = len(age_groups) - 1
        group_list = age_groups
    else:
        num_groups = cfg['OpDisc']['number_of_groups']
        group_list = list(range(num_groups))

    requires_group_label = [
        'extreme_means_diff', 'avg_of_means_diff_to_05', 'avg_of_stddevs'
    ]

    # get group labels (only loaded for the measures that need them)
    if ageing:
        # group labels change over time
        keys.update({x: 0, y: 0})
        if to_plot in requires_group_label:
            groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        for ele in [x, y]:
            keys.pop(ele)
    else:
        # group labels do not change over time
        keys.update({'time': 0, x: 0, y: 0})
        if to_plot in requires_group_label:
            groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        for ele in [x, y, 'time']:
            keys.pop(ele)

    # figure setup ............................................................
    figure, axs = setup_figure(cfg, plot_name=to_plot, dim1=x, dim2=y)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    x_vals = mv_data.coords[x].data
    y_vals = mv_data.coords[y].data

    # data analysis: data_to_plot has shape (len(y), len(x)) ..................
    if to_plot == 'extreme_means_diff':
        data_to_plot = difference_of_extreme_means(mv_data,
                                                   x,
                                                   y,
                                                   groups,
                                                   group_list,
                                                   ageing=ageing,
                                                   group_1=0,
                                                   group_2=-1,
                                                   time_step=-1)
    elif to_plot in ('avg_of_means_diff_to_05', 'avg_of_stddevs'):
        which = 'means' if to_plot == 'avg_of_means_diff_to_05' else 'stddevs'
        data_to_plot = avg_of_means_stddevs(mv_data,
                                            x,
                                            y,
                                            groups,
                                            group_list,
                                            keys,
                                            mode,
                                            num_groups,
                                            which=which,
                                            ageing=ageing,
                                            time_step=-1)
    else:
        # area measures: computed along y, one column per x value
        data_to_plot = np.zeros((len(y_vals), len(x_vals)))
        for param1 in range(len(x_vals)):
            keys[x] = param1
            if to_plot == 'absolute_area':
                column = get_absolute_area(mv_data, keys, dim=y)[0]
            elif to_plot == 'area':
                column = get_area(mv_data, keys, dim=y)[0]
            else:  # 'area_diff'
                column = np.subtract(
                    get_absolute_area(mv_data, keys, dim=y)[0],
                    get_area(mv_data, keys, dim=y)[0])
            data_to_plot[:, param1] = column

    # plotting ................................................................
    if stacked:
        # one line per y value, plotted over the actual x parameter values
        for i, y_val in enumerate(y_vals):
            hlpr.ax.plot(x_vals,
                         data_to_plot[i, :],
                         label=f'{convert_to_label(y)}={y_val}',
                         **plot_kwargs)
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                       loc='lower right',
                       ncol=len(y_vals) + 1,
                       fontsize='xx-small')
    else:
        df = pd.DataFrame(data_to_plot, index=y_vals, columns=x_vals)
        im = hlpr.ax.pcolor(df, **plot_kwargs)

        # pcolor cells span [i, i+1]: place the ticks at the cell centres
        hlpr.ax.set_ylabel(parameters[y], rotation=0)
        hlpr.ax.set_yticks(
            np.linspace(0.5, len(y_vals) - 0.5, len(y_vals)))
        hlpr.ax.set_yticklabels([np.around(i, 3) for i in y_vals])

        hlpr.ax.set_xlabel(parameters[x])
        hlpr.ax.set_xticks(
            np.linspace(0.5, len(x_vals) - 0.5, len(x_vals)))
        hlpr.ax.set_xticklabels([np.around(i, 3) for i in x_vals])

        divider = make_axes_locatable(hlpr.ax)
        cax = divider.append_axes("right", size="5%", pad=0.2)
        cbar = figure.colorbar(im, cax=cax)
        cbar.set_label(convert_to_label(to_plot))
def sweep(dm: DataManager,
          *,
          hlpr: PlotHelper,
          mv_data,
          plot_prop: str,
          dim: str,
          dim2: str = None,
          stack: bool = False,
          no_errors: bool = False,
          data_name: str = 'opinion_u',
          bin_number: int = 100,
          plot_kwargs: dict = None):
    """Plots a bifurcation diagram for one parameter dimension (dim), i.e.
    plots the chosen distribution measure over the parameter, or - if a
    second parameter dimension (dim2) is given - plots the 2d parameter
    space as a heatmap (if not stacked).

    If a 'seed' coordinate is present besides the parameter dimension(s),
    the measure is averaged over the seeds and the standard deviation over
    the seeds is used as error.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - change `data_name` if needed
        - choose the dimension `dim` (and `dim2`) in which the sweep was
          performed.

    Arguments:
        dm (DataManager): The data manager (not used by this function)
        hlpr (PlotHelper): The plot helper
        mv_data (xr.Dataset): The extracted multidimensional dataset
        plot_prop (str): The quantity that is extracted from the data.
            Available are: ['number_of_peaks', 'localization',
            'max_distance', 'polarization', 'final_variance',
            'convergence_time']
        dim (str): The parameter dimension of the diagram
        dim2 (str, optional): The second parameter dimension of the diagram;
            required if a second, non-seed parameter dimension is present
        stack (bool, optional): Whether the plots for dim2 are stacked (one
            line per dim2 value) or combined into a heatmap
        no_errors (bool, optional): Plot plain lines instead of error bars
        data_name (str, optional): Name of the data variable in ``mv_data``
            that is analysed
        bin_number (int, optional): default: 100
            number of bins for the discretization of the final distribution
        plot_kwargs (dict, optional): passed to the plot function. The
            entries ``cmap_kwargs`` (keys ``min``, ``max``, ``cmap``) and
            ``markers`` (list of marker styles) are consumed here and not
            forwarded. The dict passed in is not modified.

    Raises:
        ValueError: If ``dim`` or a required ``dim2`` is not available
        TypeError: for more than 5 coordinates (more than 4 if there is no
            'seed' coordinate)
        ValueError: If 'data_name' data does not exist
        ValueError: If 'plot_prop' is invalid
    """
    if dim not in mv_data.dims:
        raise ValueError("Dimension `dim` not available in multiverse data."
                         " Was: {} with value: '{}'."
                         " Available: {}"
                         "".format(type(dim), dim, mv_data.coords))

    num_coords = len(mv_data.coords)
    if num_coords > 5:
        raise TypeError("mv_data has more than two extra parameter dimensions."
                        " Are: {}. Chosen dim: {}. (Max: ['vertex', 'time', "
                        "'seed'] + 2)".format(mv_data.coords, dim))
    if (num_coords > 4) and ('seed' not in mv_data.coords):
        raise TypeError("mv_data has more than two extra parameter dimensions."
                        " Are: {}. Chosen dim: {}. (Max: ['vertex', 'time', "
                        "'seed'] + 2)".format(mv_data.coords, dim))

    # A second, non-seed parameter dimension requires `dim2`
    has_seeds = 'seed' in mv_data.coords
    needs_dim2 = num_coords == 5 or (num_coords == 4 and not has_seeds)
    if needs_dim2 and dim2 not in mv_data.dims:
        raise ValueError("Dimension `dim2` not available in multiverse data,"
                         " but a second parameter dimension is present."
                         " Was: '{}'. Available: {}"
                         "".format(dim2, mv_data.coords))

    # Work on a copy: entries are popped below and must not leak back into
    # the caller's (possibly reused) configuration dict.
    plot_kwargs = dict(plot_kwargs) if plot_kwargs else {}

    # Default plot configurations
    if no_errors:
        plot_kwargs_default_1d = {'marker': 'o', 'ms': 2, 'lw': 0.4}
    else:
        plot_kwargs_default_1d = {'fmt': 'o', 'ls': '-', 'lw': .4,
                                  'mec': None, 'capsize': 1, 'mew': .6,
                                  'ms': 2}
    plot_kwargs_default_2d = {'origin': 'lower', 'cmap': 'Spectral_r'}

    # plot functions ..........................................................
    def plot_data_1d(param_plot, data_plot, std, plot_kwargs):
        if no_errors:
            hlpr.ax.plot(param_plot, data_plot, **plot_kwargs)
        else:
            hlpr.ax.errorbar(param_plot, data_plot, yerr=std, **plot_kwargs)

    def thinned_ticks(params):
        """Tick positions and '{:.2f}' labels, thinned to about 10 ticks."""
        step = 1 if len(params) <= 10 else int(np.ceil(len(params) / 10))
        positions = np.arange(len(params))[::step]
        labels = np.array(["{:.2f}".format(p) for p in params[::step]])
        return positions, labels

    def plot_data_2d(data_plot, param1, param2, plot_kwargs):
        image = hlpr.ax.imshow(data_plot, **plot_kwargs)
        xticks, xticklabels = thinned_ticks(param1)
        yticks, yticklabels = thinned_ticks(param2)
        hlpr.ax.set_xticks(xticks)
        hlpr.ax.set_yticks(yticks)
        hlpr.ax.set_xticklabels(xticklabels)
        hlpr.ax.set_yticklabels(yticklabels)
        plt.colorbar(image)

    # analysis functions ......................................................
    def final_state_of(raw_data):
        """Values at the last time step with NaN entries removed."""
        final_state = raw_data.isel(time=-1)
        return final_state[~np.isnan(final_state)]

    def get_number_of_peaks(raw_data):
        final_state = final_state_of(raw_data)
        # binning of the final opinion distribution with binsize=1/bin_number
        hist_data, _ = np.histogram(final_state, range=(0., 1.),
                                    bins=bin_number)
        return len(find_peaks(hist_data, prominence=15, distance=5)[0])

    def get_localization(raw_data):
        final_state = final_state_of(raw_data)
        hist, _ = np.histogram(final_state, range=(0., 1.), bins=bin_number,
                               density=True)
        # density * bin width -> probability mass per bin
        hist *= 1 / bin_number
        return np.sum(hist**4) / np.sum(hist**2)**2

    def get_max_distance(raw_data):
        values = np.asarray(final_state_of(raw_data), dtype=float)
        if values.size == 0:
            return 0.
        return float(values.max() - values.min())

    def get_polarization(raw_data):
        values = np.asarray(final_state_of(raw_data), dtype=float)
        if values.size == 0:
            return 0.
        # sum_{i,j} (x_i - x_j)^2 == 2 * n * sum_i (x_i - mean)^2, in O(n)
        return 2 * values.size * np.sum((values - values.mean())**2)

    def get_final_variance(raw_data):
        return np.var(final_state_of(raw_data).values)

    def get_convergence_time(raw_data):
        # first time at which no gap between neighbouring (sorted) values
        # lies strictly between min_tol and max_tol; 0. if there is none
        max_tol = 0.05
        min_tol = 0.005
        for t in raw_data.coords['time']:
            gaps = np.diff(np.sort(raw_data.sel(time=t)))
            if not np.any((gaps > min_tol) & (gaps < max_tol)):
                return t
        return 0.

    property_funcs = {
        'number_of_peaks': get_number_of_peaks,
        'localization': get_localization,
        'max_distance': get_max_distance,
        'polarization': get_polarization,
        'final_variance': get_final_variance,
        'convergence_time': get_convergence_time,
    }
    if plot_prop not in property_funcs:
        raise ValueError("'plot_prop' invalid! Was: {}".format(plot_prop))
    get_property = property_funcs[plot_prop]

    # data handling and plot setup ............................................
    legend = False
    heatmap = False

    if data_name not in mv_data.data_vars:
        raise ValueError("'{}' not available in multiverse data."
                         " Available in multiverse field: {}"
                         "".format(data_name, mv_data.data_vars))

    # this is the dataset containing the chosen data to plot
    # for all parameter combinations
    dataset = mv_data[data_name]

    # number of different parameter values i.e. number of points in the graph
    num_param = len(dataset[dim])

    # Get additional information for plotting
    leg_title = {
        "num_vertices": "$N$",
        "weighting": r"$\kappa$",
        "rewiring": "$r$",
        "p_rewire": "$p_{rewire}$",
    }.get(dim2, dim2)

    cmap_kwargs = plot_kwargs.pop("cmap_kwargs", None)
    if cmap_kwargs:
        cmin = cmap_kwargs.get("min", 0.)
        cmax = cmap_kwargs.get("max", 1.)
        cmap = plt.get_cmap(cmap_kwargs.get("cmap"))
    markers = plot_kwargs.pop("markers", None)

    def line_color(i, num_lines):
        """Colour of line i, evenly spread over [cmin, cmax]."""
        # max(..., 1) avoids a division by zero for a single line
        return cmap(cmin + i * (cmax - cmin) / max(num_lines - 1., 1.))

    def evaluate(selection):
        """(value, std) of the property at one parameter point.

        With a 'seed' coordinate, value and std are the mean and standard
        deviation over all seeds; otherwise std is 0.
        """
        if not has_seeds:
            return get_property(dataset[selection]), 0.
        per_seed = np.zeros(len(dataset['seed']))
        for k in range(per_seed.size):
            per_seed[k] = get_property(dataset[{**selection, 'seed': k}])
        return np.mean(per_seed), np.std(per_seed)

    if num_coords == 3:
        # Only one parameter sweep: plot the quantity over the parameter
        plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
        data_plot = np.zeros(num_param)
        param_plot = np.zeros(num_param)
        std = np.zeros(num_param)
        for idx, data in enumerate(dataset):
            data_plot[idx] = get_property(data)
            param_plot[idx] = data[dim]
        if markers:
            plot_kwargs['marker'] = markers[0]
        plot_data_1d(param_plot, data_plot, std, plot_kwargs)

    elif num_coords == 4 and has_seeds:
        # One parameter sweep plus seeds: average over the seeds
        plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
        data_plot = np.zeros(num_param)
        param_plot = np.zeros(num_param)
        std = np.zeros(num_param)
        for i in range(num_param):
            data_plot[i], std[i] = evaluate({dim: i})
            param_plot[i] = dataset[dim][i]
        if markers:
            plot_kwargs['marker'] = markers[0]
        plot_data_1d(param_plot, data_plot, std, plot_kwargs)

    elif num_coords in (4, 5) and stack:
        # Two parameter sweeps (plus seeds if 5 coords): one line per dim2
        legend = True
        plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
        param_plot = dataset[dim]
        num_lines = len(dataset[dim2])
        for i in range(num_lines):
            # fresh buffers per line so already-plotted lines never share
            # their data arrays with the next one
            line_data = np.zeros(num_param)
            line_std = np.zeros(num_param)
            for j in range(num_param):
                line_data[j], line_std[j] = evaluate({dim2: i, dim: j})
            recursive_update(plot_kwargs,
                             {'label': "{}".format(dataset[dim2][i].data)})
            if cmap_kwargs:
                recursive_update(plot_kwargs,
                                 {'color': line_color(i, num_lines)})
            if markers:
                plot_kwargs['marker'] = markers[i % len(markers)]
            plot_data_1d(param_plot, line_data, line_std, plot_kwargs)

    elif num_coords in (4, 5):
        # Two parameter sweeps: colour-coded 2d grid
        plot_kwargs = recursive_update(plot_kwargs_default_2d, plot_kwargs)
        heatmap = True
        num_param2 = len(dataset[dim2])
        data_2d = np.zeros((num_param2, num_param))
        for i in range(num_param):
            for j in range(num_param2):
                data_2d[j, i] = evaluate({dim: i, dim2: j})[0]
        param1 = np.asarray(dataset[dim], dtype=float)
        param2 = np.asarray(dataset[dim2], dtype=float)
        plot_data_2d(data_2d, param1, param2, plot_kwargs)

    # Add labels and title
    if heatmap:
        hlpr.provide_defaults('set_labels', x=dim, y=dim2)
        hlpr.provide_defaults('set_title', title=plot_prop)
    else:
        hlpr.provide_defaults(
            'set_labels',
            x=(r"$\epsilon$" if dim == "tolerance_u" else dim))
        y_labels = {
            'localization': "$L$",
            'final_variance': r"var($\sigma$)",
            'number_of_peaks': "$N_{peaks}$",
            'max_distance': "$d_{max}$",
            'convergence_time': "$T_{conv}$",
        }
        hlpr.provide_defaults('set_labels',
                              y=y_labels.get(plot_prop, plot_prop))

    if legend:
        hlpr.ax.legend(title=leg_title)

    # Set minor ticks
    if plot_prop in ("max_distance", "localization"):
        hlpr.ax.get_xaxis().set_major_locator(ticker.MultipleLocator(0.1))
        hlpr.ax.get_xaxis().set_minor_locator(ticker.MultipleLocator(0.05))
        hlpr.ax.get_yaxis().set_minor_locator(ticker.MultipleLocator(0.1))
def opinion_animated(dm: DataManager,
                     *,
                     uni: UniverseGroup,
                     hlpr: PlotHelper,
                     num_bins: int = 100,
                     time_idx: int = None,
                     **plot_kwargs):
    """Plots an animated histogram of the opinion distribution over time.

    Arguments:
        dm (DataManager): the data manager (not used by this function)
        uni (UniverseGroup): the selected universe
        hlpr (PlotHelper): the plot helper
        num_bins (int): Binning of the histogram
        time_idx (int, optional): Only plot one single frame at this index
            (eg. -1 for the last frame). If None (or 0), all time steps are
            animated.
        **plot_kwargs: Passed to plt.bar

    Raises:
        ValueError: if the opinion space type is neither 'continuous' nor
            'discrete'
    """
    # datasets ................................................................
    opinions = uni['data/Opinionet/nw/opinion']
    time = opinions['time'].data
    time_steps = time.size

    cfg_op_space = uni['cfg']['Opinionet']['opinion_space']
    op_space_type = cfg_op_space['type']
    if op_space_type == 'continuous':
        val_range = cfg_op_space['interval']
    elif op_space_type == 'discrete':
        val_range = (0, cfg_op_space['num_opinions'])
    else:
        raise ValueError(f"Unknown opinion space type '{op_space_type}'! "
                         "Expected 'continuous' or 'discrete'.")

    # histograms ..............................................................
    def get_hist_data(input_data):
        """Histogram over val_range: counts, bin edges and bin centres."""
        counts, bin_edges = np.histogram(input_data,
                                         range=val_range,
                                         bins=num_bins)
        bin_pos = bin_edges[:-1] + (np.diff(bin_edges) / 2.)
        return counts, bin_edges, bin_pos

    # Initial frame: the selected time step, or the first one if animating.
    t_init = time_idx if time_idx else 0

    # calculate histograms, set axis ranges
    counts, bin_edges, pos = get_hist_data(opinions[t_init, :])
    hlpr.ax.set_xlim(val_range)
    bars = hlpr.ax.bar(pos, counts, width=np.diff(bin_edges), **plot_kwargs)
    text = hlpr.ax.text(0.02,
                        0.93,
                        f'step {time[t_init]}',
                        transform=hlpr.ax.transAxes)

    # animate .................................................................
    def update_data(stepsize: int = 1):
        """Updates the bar heights and the step text; yields once per frame."""
        if time_idx:
            log.info(
                f"Plotting distribution at time step {time[time_idx]} ...")
        else:
            log.info(
                f"Plotting animation with {opinions.shape[0] // stepsize} "
                "frames ...")

        if time_steps < stepsize:
            log.warning("Stepsize is greater than number of steps. Continue by "
                        "plotting first and last frame.")
            stepsize = time_steps - 1

        def draw_frame(t):
            counts_at_t, _, _ = get_hist_data(opinions[t, :])
            for rect, height in zip(bars, counts_at_t):
                rect.set_height(height)
            text.set_text(f'step {time[t]}')
            hlpr.ax.relim()
            hlpr.ax.autoscale_view(scalex=False)

        if time_idx:
            draw_frame(time_idx)
            yield
            return

        # a non-positive stepsize (e.g. 0 after the adjustment above for a
        # single time step) means: every frame
        for t in range(0, time_steps, max(stepsize, 1)):
            draw_frame(t)
            yield

    hlpr.register_animation_update(update_data)
def graph_animation_update(*, hlpr: PlotHelper, graphs: xr.DataArray = None,
                           graph_group=None, graph_creation: dict = None,
                           register_property_maps: dict = None,
                           clear_existing_property_maps: bool = True,
                           positions: dict = None,
                           animation_kwargs: dict = None,
                           suptitle_kwargs: dict = None,
                           **drawing_kwargs):
    """Graph animation frame generator. Yields whenever the plot helper may
    grab the current frame.

    If ``graphs`` is given, the networkx graphs in the array are used to
    create the frames. Otherwise, use a graph group. The frames are defined
    via the selectors in ``animation_kwargs``. From all provided coordinates
    the cartesian product is taken. Each of those points defines one graph
    and thus one frame. The selection kwargs in ``graph_creation`` are
    ignored silently.

    Args:
        hlpr (PlotHelper): The plot helper
        graphs (xr.DataArray, optional): Networkx graphs to draw. The array
            will be flattened beforehand.
        graph_group (None, optional): Required if ``graphs`` is None. The
            GraphGroup from which to generate the animation frames as
            specified via sel and isel in ``animation_kwargs``.
        graph_creation (dict, optional): Graph creation configuration. Passed
            to :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        register_property_maps (dict, optional): Passed to
            :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        clear_existing_property_maps (bool, optional): Passed to
            :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        positions (dict, optional): The node position configuration.
            If ``update_positions`` is True the positions are reconfigured
            for each frame. Otherwise, the positions computed for the first
            drawn graph are reused for all later frames.
        animation_kwargs (dict, optional): Animation configuration. The
            following arguments are allowed:

            times (dict, optional): *Deprecated*: Equivalent to a sel.time
                entry.
            sel (dict, optional): Select by value. Coordinate values (or
                ``from_property`` entry) keyed by dimension name.
            isel (dict, optional): Select by index. Coordinate indices keyed
                by dimension. May be given together with ``sel`` if no key
                appears in both.
            update_positions (bool, optional): Whether to reconfigure the
                node positions for each frame (default=False).
            update_colormapping (bool, optional): Whether to reconfigure the
                nodes' and edges'
                :py:class:`~utopya.plot_funcs._mpl_helpers.ColorManager`
                for each frame (default=False). If False, the colormapping
                (and the colorbar) is configured with the first drawn graph
                and then fixed.
            skip_empty_frames (bool, optional): Whether to skip the frames
                where the selected graph is missing or of a type different
                than ``nx.Graph`` (default=False). If False, such frames are
                empty.

        suptitle_kwargs (dict, optional): Passed on to the PlotHelper's
            ``set_suptitle`` helper function. Only used in animation mode.
            The ``title`` can be a format string containing a placeholder
            with the dimension name as key for each dimension along which
            selection is done. The format string is updated for each frame
            of the animation. The default is ``<dim-name> = {<dim-name>}``
            for each dimension.
        **drawing_kwargs: Passed to
            :py:class:`~utopya.plot_funcs._graph.GraphPlot`

    Yields:
        None, once per frame, after the frame has been drawn.
    """
    # Work on copies such that the caller's dicts are not mutated
    suptitle_kwargs = (copy.deepcopy(suptitle_kwargs)
                       if suptitle_kwargs is not None else {})
    animation_kwargs = (copy.deepcopy(animation_kwargs)
                        if animation_kwargs is not None else {})

    update_positions = animation_kwargs.pop("update_positions", False)
    update_colormapping = animation_kwargs.pop("update_colormapping", False)
    skip_empty_frames = animation_kwargs.pop("skip_empty_frames", False)
    sel = animation_kwargs.pop("sel", None)
    isel = animation_kwargs.pop("isel", None)

    if graphs is None:
        graphs = graph_array_from_group(
            graph_group=graph_group,
            graph_creation=graph_creation,
            register_property_maps=register_property_maps,
            clear_existing_property_maps=clear_existing_property_maps,
            sel=sel,
            isel=isel,
            **animation_kwargs,
        )

    else:
        # Apply selectors to the graphs DataArray
        if sel is not None:
            graphs = graphs.sel(**sel)

        if isel is not None:
            graphs = graphs.isel(**isel)

    # Prepare the suptitle format string once for all frames, e.g.
    # "time = {time}; seed = {seed}"
    if "title" not in suptitle_kwargs:
        suptitle_kwargs["title"] = "; ".join(
            f"{dim} = {{{dim}}}" for dim in graphs.coords
        )

    def get_graph_and_coords(graphs: xr.DataArray):
        """Generator that yields the (graph, coordinates) pairs"""
        dims = graphs.dims

        if dims:
            # Stack all dimensions of the xr.DataArray
            graphs = graphs.stack(_flat=dims)

            # For each entry, yield the graph and the coords
            for el in graphs:
                g = el.item()
                coords = {
                    d: c for d, c in zip(dims, el.coords["_flat"].item())
                }

                # Also get the scalar coordinates which were not stacked
                coords.update({
                    d: c.item() for d, c in el.coords.items() if d != "_flat"
                })

                yield g, coords

        else:
            # zero-dimensional: extract single graph item
            g = graphs.item()
            coords = {d: c.item() for d, c in graphs.coords.items()}

            yield g, coords

    # Don't show the axis. This is also done in GraphPlot.draw but need to do
    # it here in case the first frame is empty.
    hlpr.ax.axis("off")

    # Indicator for things to be done only once after the first
    # GraphPlot.draw call.
    _first_graph = True

    # Always configure the colormanagers on the first GraphPlot.draw call.
    # Overwritten by the update_colormapping kwarg afterwards.
    _update_colormapping = True

    # Loop over all remaining graphs
    for g, coords in get_graph_and_coords(graphs):
        _missing_val = not isinstance(g, nx.Graph)

        # On missing entries, skip the frame if skip_empty_frames=True.
        # If skip_empty_frames=False, the suptitle helper is applied but no
        # graph is drawn.
        if _missing_val and skip_empty_frames:
            continue

        if not _missing_val:
            # Use a new GraphPlot for the next frame such that the `select`
            # kwargs are re-evaluated.
            gp = GraphPlot(
                g=g,
                fig=hlpr.fig,
                ax=hlpr.ax,
                positions=positions,
                **drawing_kwargs,
            )

            gp.draw(
                suppress_cbar=not _update_colormapping,
                update_colormapping=_update_colormapping,
            )

        # Update the suptitle format string and invoke the helper
        st_kwargs = copy.deepcopy(suptitle_kwargs)
        st_kwargs["title"] = st_kwargs["title"].format(**coords)
        hlpr.invoke_helper("set_suptitle", **st_kwargs)

        if _first_graph:
            # Fix the y-position of the suptitle after the first invocation
            # since repetitive invocations of the set-suptitle helper would
            # re-adjust the subplots size each time. Use the matplotlib default
            # if not given.
            if "y" not in suptitle_kwargs:
                suptitle_kwargs["y"] = 0.98

            # Only hand over to the user's colormapping setting once a graph
            # was actually drawn; otherwise a leading empty frame would
            # consume the initial colormapping configuration and the first
            # real draw would have neither a configured colormap nor a
            # colorbar.
            if not _missing_val:
                _update_colormapping = update_colormapping

        # Let the writer grab the current frame
        yield

        if not _missing_val:
            # Clean up the frame; keep the colorbars if the colormapping is
            # fixed after the first draw
            gp.clear_plot(keep_colorbars=not _update_colormapping)

            if _first_graph:
                # If the positions should be fixed, overwrite the positions arg
                if not update_positions:
                    positions = dict(from_dict=gp.positions)

                # Done handling the first GraphPlot.draw call
                _first_graph = False