Example #1
def state_mean(dm: DataManager, *, hlpr: PlotHelper, uni: UniverseGroup,
               mean_of: str, **plot_kwargs):
    """Calculates the mean of the `mean_of` dataset and performs a lineplot
    over time.
    
    Args:
        dm (DataManager): The data manager from which to retrieve the data
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving.
            To plot on the current axis, use ``hlpr.ax``.
        uni (UniverseGroup): The selected universe data
        mean_of (str): The name of the CopyMeGrid dataset of which the mean
            is to be calculated
        **plot_kwargs: Passed on to matplotlib.pyplot.plot
    """
    # Extract the data
    data = uni['data/CopyMeGrid'][mean_of]

    # Calculate the mean over the spatial dimensions
    mean = data.mean(['x', 'y'])

    # Call the plot function on the currently selected axis in the plot helper
    hlpr.ax.plot(mean.coords['time'], mean, **plot_kwargs)
    # NOTE `hlpr.ax` is just the current matplotlib.axes object. It has the
    #      same interface as `plt`, aka `matplotlib.pyplot`

    # Provide the plot helper with some information that is then used when
    # the helpers are invoked
    hlpr.provide_defaults('set_title', title="Mean '{}'".format(mean_of))
    hlpr.provide_defaults('set_labels', y="<{}>".format(mean_of))
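
Below is a minimal, framework-free sketch of the same idea, assuming synthetic data in place of the CopyMeGrid dataset and plain matplotlib in place of the PlotHelper:

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# Synthetic (time, x, y) data; in the plot function above, this array comes
# from uni['data/CopyMeGrid'][mean_of]
data = xr.DataArray(np.random.rand(50, 16, 16),
                    dims=('time', 'x', 'y'),
                    coords=dict(time=np.arange(50)))

# Reduce over the spatial dimensions and plot against time, as done above
mean = data.mean(['x', 'y'])
plt.plot(mean.coords['time'], mean)
plt.ylabel("<state>")
plt.show()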
Example #2
def time_series(*, data: dict, hlpr: PlotHelper, **plot_kwargs):
    """This is a generic plotting function that plots one or multiple time
    series from the 'y' tag that is selected via the DAG framework.

    The y data needs to be an xarray object. If y is an xr.DataArray, it is
    assumed to be one- or two-dimensional.
    If y is an xr.Dataset, all data variables are plotted and their name is
    used as the label.

    For the x axis values, the corresponding 'time' coordinates are used.
    
    Args:
        data (dict): The data selected by the DAG framework
        hlpr (PlotHelper): The plot helper
        **plot_kwargs: Passed on to matplotlib.pyplot.plot
    """
    y = data['y']

    # Distinguish between xr.Dataset and xr.DataArray input
    if isinstance(y, xr.Dataset):
        # Simply plot all data variables
        for dvar, line in y.data_vars.items():
            hlpr.ax.plot(line.coords['time'], line, label=dvar, **plot_kwargs)

    elif isinstance(y, xr.DataArray):
        # Also allow two-dimensional arrays
        if y.ndim == 1:
            hlpr.ax.plot(y.coords['time'], y, **plot_kwargs)

        elif y.ndim == 2:
            loop_dim = [d for d in y.dims if d != 'time'][0]

            for c in y.coords[loop_dim]:
                line = y.sel({loop_dim: c})
                hlpr.ax.plot(line.coords['time'],
                             line,
                             label="{:.2g}".format(c.item()),
                             **plot_kwargs)

            # Provide a default title to the legend: name of the loop dimension
            hlpr.provide_defaults('set_legend',
                                  title="${}$ coordinate".format(loop_dim))

        else:
            raise ValueError("Given y-data needs to be of one- or two-"
                             "dimensional, was of dimensionality {}! Data: {}"
                             "".format(y.ndim, y))

    else:
        raise TypeError("Expected xr.Dataset or xr.DataArray, got {}"
                        "".format(type(y)))
Example #3
def some_DAG_based_CopyMeGrid_plot(*,
                                   data: dict,
                                   hlpr: PlotHelper,
                                   scatter_kwargs: dict = None):
    """This is an example plot to show how to implement a generic DAG-based
    plot function.
    
    In this example, nothing spectacular happens: A scatter plot is made and
    a sine curve with some amplitude and frequency is plotted; the amplitude,
    frequency, and offset are provided via the DAG.
    
    Ideally, plot functions should be as generic as possible; just a bridge
    between some data (produced by the DAG) and its visualization.
    This plot is a (completely made-up) example for cases where it makes sense
    to have a rather specific plot...
    
    .. note::
    
        To specify the plot in a generic way, omit the ``creator_type``
        argument in the decorator. If you want to specialize the plot function,
        you can of course specify the creator type, but most often that's not
        necessary.
    
        Check the decorator's signature and the utopya and dantro documentation
        for more information on the DAG framework
    
    Args:
        data (dict): The selected DAG data, contains the required DAG tags.
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving.
            To plot on the current axis, use ``hlpr.ax``.
        scatter_kwargs (dict, optional): Passed to matplotlib.pyplot.scatter
    """
    x_vals = data['x_values']
    hlpr.ax.scatter(x_vals, data['scatter'],
                    **(scatter_kwargs if scatter_kwargs else {}))

    # Do the sine plot
    sine = (data['amplitude'] * np.sin(x_vals * data['frequency']) +
            data['offset'])
    hlpr.ax.plot(x_vals, sine, label="my sine")

    # Enable the legend
    hlpr.mark_enabled("set_legend")
Example #4
def sweep2d(data,
            hlpr: PlotHelper,
            x: str,
            y: str,
            z: str,
            plot_kwargs: dict = {}):
    """For multiverse runs, this produces a two dimensional plot showing
    specified values.

    Arguments:
        data (xarray): the dataset
        hlpr (PlotHelper): The plot helper
        x (str): the first parameter dimension of the diagram.
        y (str): the second parameter dimension of the diagram.
        z (str): the data dimension.
        plot_kwargs (dict, optional): kwargs passed to the pcolor plot function
    """

    df = pd.DataFrame(data['data'][z].data,
                      index=data['data'][y].data,
                      columns=data['data'][x].data)
    im = hlpr.ax.pcolor(df, **plot_kwargs)

    num_y = len(data['data'][y].data)
    hlpr.ax.set_yticks(np.linspace(0.5, num_y - 0.5, num_y))
    hlpr.ax.set_yticklabels([np.around(i, 3) for i in data['data'][y].data])

    num_x = len(data['data'][x].data)
    hlpr.ax.set_xticks(np.linspace(0.5, num_x - 0.5, num_x))
    hlpr.ax.set_xticklabels([np.around(i, 3) for i in data['data'][x].data])

    divider = make_axes_locatable(hlpr.ax)
    cax = divider.append_axes("right", size="5%", pad=0.2)
    cbar = hlpr.fig.colorbar(im, cax=cax)
    cbar.set_label(z)

    hlpr.select_axis(0, 0)
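
A minimal sketch of the pcolor-plus-colorbar pattern used above, assuming a synthetic 2D array in place of the multiverse dataset and a plain matplotlib axis in place of the PlotHelper:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots()
z = np.random.rand(5, 7)
im = ax.pcolor(z)

# Put the ticks at the cell centers, as done above for the sweep coordinates
ax.set_xticks(np.linspace(0.5, z.shape[1] - 0.5, z.shape[1]))
ax.set_yticks(np.linspace(0.5, z.shape[0] - 0.5, z.shape[0]))

# Attach the colorbar in a separate axis to the right of the plot
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(im, cax=cax).set_label("z")
plt.show()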
Example #5
def bifurcation(dm: DataManager,
                *,
                hlpr: PlotHelper,
                mv_data,
                avg_window: int=20,
                dim: str=None,
                plot_kwargs: dict=None,
                title: str=None):

    """For multiverse runs, this plot finds the upper turning points of the
    average opinion and plots them against the sweep parameter. Only works for a
    single sweep parameter or for a sweep parameter and different seeds.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimension `dim` in which the sweep was performed. For a single
          sweep dimension, the sweep parameter is automatically deduced
        - use the `select/subspace` key to set values for all other parameters

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): The plot helper
        mv_data (xr.Dataset): the extracted multidimensional dataset
        avg_window (int): the smoothing window for the rolling average of the
           opinion dataset
        dim (str, optional): the parameter dimension of the diagram. If no str
           is passed, an attempt will be made to automatically deduce the sweep
           dimension.
        plot_kwargs (dict, optional): kwargs passed to the scatter plot function
        title (str, optional): custom plot title

    Raises:
        ValueError: for a parameter dimension higher than 3 (or 4 if the
            sweep is also conducted over the seed)
        ValueError: if dim does not exist
    """
    if dim is None:
        dim = deduce_sweep_dimension(mv_data)

    else:
        if dim not in mv_data.dims:
            raise ValueError(f"Dimension '{dim}' not available in multiverse data."
                             f" Available: {mv_data.coords}")

    #get datasets and cfg ......................................................
    dataset = mv_data['opinion']
    time_steps = dataset['time'].size
    keys, cfg = get_keys_cfg(mv_data, dm['multiverse'].pspace.default,
                             keys_to_ignore=[dim, 'time'])

    #figure setup ..............................................................
    figure, axs = setup_figure(cfg, plot_name='bifurcation', title=title, dim1=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #data analysis .............................................................
    #get the turning points of the average opinion (maxima only). If a sweep over
    #seed was performed, multiple datapoints are collected per x-value
    log.info("Starting data analysis ...")
    to_plot = []
    if 'seed' in mv_data.coords and len(dataset['seed'])>1:
        for i in range(len(dataset[dim])):
            extremes = []
            for j in range(len(dataset['seed'])):
                keys[dim] = i
                keys['seed'] = j
                data = np.asarray(dataset[keys])
                means_glob = pd.Series(np.mean(data, axis=1)).rolling(window=avg_window).mean()
                res = find_extrema(means_glob)
                extremes.extend(res['max']['y'])

            to_plot.append((dataset[dim][i].data, extremes))
    else:
        for i in range(len(dataset[dim])):
            keys[dim] = i
            data = np.asarray(dataset[keys])
            means_glob = pd.Series(np.mean(data, axis=1)).rolling(window=avg_window).mean()
            extremes = find_extrema(means_glob)['max']['y']

            to_plot.append((dataset[dim][i].data, extremes))

    log.info("Data analysis complete.")

    #plot scatter plot of extrema ..............................................
    plot_kwargs = plot_kwargs if plot_kwargs else {}
    for p, o in to_plot:
        hlpr.ax.scatter([p] * len(o), o, **plot_kwargs)

    #hlpr.ax.set_ylim(0, 1)
    hlpr.ax.set_xlabel(convert_to_label(dim))
    hlpr.ax.set_ylabel(r'mean opinion $\bar{\sigma}$')
    legend_elements = [Line2D([0], [0],
                       label=(r'$\bar{\sigma}^\prime = 0$,'+
                              r'$\bar{\sigma}^{\prime \prime} < 0$'),
                       lw=0,
                       marker='o', color=plot_kwargs.get('color', 'navy'),
                       markerfacecolor=plot_kwargs.get('color', 'navy'),
                       markersize=5)]
    hlpr.ax.legend(handles=legend_elements, bbox_to_anchor=(1, 1.01),
                   loc='lower right', ncol=2, fontsize='xx-small')
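
The helpers used above (deduce_sweep_dimension, get_keys_cfg, setup_figure, find_extrema) are model-specific. As a rough, self-contained stand-in for the extrema step, a rolling mean combined with scipy's argrelmax yields the same kind of upper turning points; the sample series is invented:

import numpy as np
import pandas as pd
from scipy.signal import argrelmax

# Invented noisy oscillation standing in for the mean opinion time series
series = pd.Series(np.sin(np.linspace(0, 20, 500)) + 0.1 * np.random.rand(500))
smoothed = series.rolling(window=20).mean()

max_idx = argrelmax(smoothed.values)[0]   # indices of the local maxima
maxima = smoothed.values[max_idx]         # y-values of the upper turning points
print(maxima)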
Example #6
def bifurcation_diagram(dm: DataManager,
                        *,
                        hlpr: PlotHelper,
                        mv_data: xr.Dataset,
                        dim: str = None,
                        dims: Tuple[str, str] = None,
                        analysis_steps: Sequence[Union[str, Tuple[str, str]]],
                        custom_analysis_funcs: Dict[str, Callable] = None,
                        analysis_kwargs: dict = None,
                        visualization_kwargs: dict = None,
                        to_plot: dict = None,
                        **kwargs) -> None:
    """Plots a bifurcation diagram for one or two parameter dimensions
       (arguments ``dim`` or ``dims``). 

    Args:
        dm (DataManager): The data manager from which to retrieve the data
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving
        mv_data (xr.Dataset): The extracted multidimensional dataset
        dim (str, optional): The required parameter dimension of the 1d
            bifurcation diagram. 
        dims (Tuple[str, str], optional): The required parameter dimensions
            (x, y) of the 2d-bifurcation diagram.
        analysis_steps (Sequence): The analysis steps that are to be made
            until one is conclusive. Applied per universe.

            - If seq of str: The str will also be used as attractor key for 
                plotting if the test is conclusive.
            - If seq of Tuple(str, str): The first str defines the attractor
                key for plotting, the second str is a key within 
                custom_analysis_funcs.

            Default analysis_funcs are:

            - endpoint: utopya.dataprocessing.find_endpoint
            - fixpoint: utopya.dataprocessing.find_fixpoint
            - multistability: utdp.find_multistability
            - oscillation: utdp.find_oscillation
            - scatter: resolve_scatter
        custom_analysis_funcs (dict): A collection of custom analysis functions
            that will overwrite the default analysis funcs (recursive update).
        analysis_kwargs (dict, optional): The entries need to match the 
            analysis_steps. The subentry (dict) is passed on to the analysis 
            function.
        visualization_kwargs (dict, optional): The entries need to match the 
            analysis_steps. The subentry (dict) is used to configure a
            rectangle to visualize the conclusive analysis step. Is passed to
            matplotlib.patches.Rectangle. xy, width, height, and angle are
            ignored and set automatically. Required in 2d bifurcation diagram.
        to_plot (dict, optional): The configuration for the data to plot. The
            entries of this key need to match the data_vars selected in mv_data.
            It is used to visualize the state of the attractor additionally to 
            the visualization kwargs. Only for 1d-bifurcation diagram.
            sub_keys: 

            - ``label`` (str, optional): label in plot
            - ``plot_kwargs`` (dict, optional): passed to scatter for every
                universe

                - color (str, recommended): unique color for every 
                    data_variable across universes
        **kwargs: Collection of optional dicts passed to different functions

            - plot_coords_kwargs (dict): Passed to ax.scatter to mark the
                universe's center in the bifurcation diagram
            - rectangle_map_kwargs (dict): Passed to 
                utopya.plot_funcs._utils.calc_pxmap_rectangles
            - legend_kwargs (dict): Passed to ax.legend
    """
    def resolve_analysis_steps(
        analysis_steps: Sequence[Union[str, Tuple[str, str]]]
    ) -> Sequence[Tuple[str, str]]:
        """Resolve instance of str to Tuple[str, str] in sequence
        
        Args:
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The 
                original sequence
        
        Returns:
            analysis_steps (Sequence[Tuple[str, str]]): The sequence of
                attractor_key and analysis_func pairs.
        """
        for i, analysis_step in enumerate(analysis_steps):
            # get key and func for the analysis step
            if isinstance(analysis_step, str):
                analysis_steps[i] = [analysis_step, analysis_step]

        return analysis_steps

    def resolve_to_plot_kwargs(to_plot: dict) -> dict:
        """Resolves the to_plot dict, e.g. adding labels if not explicitly
        specified.
        
        Args:
            to_plot (dict): The to_plot dict to parse

        Returns:
            data_vars_plot_kwargs (dict): A dict with the 'plot_kwargs' for
                every data_var in to_plot.
        
        """
        if not to_plot:
            return {}

        data_vars_plot_kwargs = {}

        for k, v in to_plot.items():
            plot_kwargs = v.get("plot_kwargs", {})
            if not plot_kwargs.get('label'):
                plot_kwargs['label'] = v.get('label', k)
            data_vars_plot_kwargs[k] = plot_kwargs

        return data_vars_plot_kwargs

    def create_legend_handles(*, visualization_kwargs: dict,
                              data_vars_plot_kwargs: dict):
        """Creates legend handles
        
        Processes entries in data_vars_plot_kwargs (from to_plot) and
        visualization_kwargs.
        
        Args:
            visualization_kwargs (dict): The visualization kwargs
            data_vars_plot_kwargs (dict): The resolved entries in to_plot
        
        Returns:
            Tuple[list, list]: Tuple of legend handles and legend labels lists
                as required by ax.legend(handles, labels)
            data_vars_plot_kwargs: The updated entries data_vars_plot_kwargs
                as required by plot_attractor.
        """
        # Some defaults
        circle_kwargs = dict(xy=(.5, .5), radius=.25, edgecolor="none")
        rect_kwargs = dict(xy=(0., 0.), height=.75, width=1., edgecolor="none")

        # Lists to be populated for matplotlib legend
        legend_handles = []
        legend_labels = []

        for k, kwargs in data_vars_plot_kwargs.items():
            label = kwargs.pop('label', k)
            kwargs['linewidth'] = kwargs.get('linewidth', 0.)
            data_vars_plot_kwargs[k] = kwargs

            # Determine color
            if 'color' in kwargs:
                color = kwargs['color']

            elif 'cmap' in kwargs:
                cmap = mpl.cm.get_cmap(kwargs['cmap'])
                color = cmap(1.)

            else:
                log.warning("No color defined for data_var '{}'!".format(k))
                color = None

            # Create and add the handle and the label
            legend_handles.append(Circle(**circle_kwargs, facecolor=color))
            legend_labels.append(label)

        for k, kwargs in visualization_kwargs.items():
            if 'to_plot' in kwargs:
                for dvar_name, dvar_kwargs in kwargs['to_plot'].items():
                    # Make sure a linewidth is set
                    dvar_kwargs['linewidth'] = dvar_kwargs.get('linewidth', 0.)
                    kwargs['to_plot'][dvar_name] = dvar_kwargs

                    # Create and add the handle and the label
                    legend_handles.append(
                        Rectangle(**rect_kwargs, **dvar_kwargs))
                    legend_labels.append(dvar_kwargs.get('label', dvar_name))
            else:
                kwargs['linewidth'] = kwargs.get('linewidth', 0.)
                data_vars_plot_kwargs[k] = kwargs

                legend_handles.append(Rectangle(**rect_kwargs, **kwargs))
                legend_labels.append(kwargs.get('label', k))

        return [legend_handles, legend_labels], data_vars_plot_kwargs

    def apply_analysis_steps(data: xr.Dataset,
                             analysis_steps: Sequence[Union[str, Tuple[str,
                                                                       str]]],
                             *, analysis_funcs: dict, analysis_kwargs: dict):
        """Perform the sequence of analysis steps until the first conclusive.
        
        Args:
            data (xr.Dataset): The data to analyse.
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The analysis steps that are to be made
                until one is conclusive. Applied per universe.
            analysis_funcs (dict): The entries need to match the 
                analysis_steps. Map of the analysis_steps to their Callables
            analysis_kwargs (dict): The entries need to match the 
                analysis_steps. The subentry (dict) is passed on to the
                corresponding analysis function.
        
        Returns:
            analysis_key (str): The key of the conclusive analysis step.
            attractor (xr.Dataset): The data corresponding to this analysis.
        """
        for analysis_key, analysis_func in analysis_steps:
            analysis_func_kwargs = analysis_kwargs.get(analysis_func, {})

            # resolve the analysis function from its name
            if isinstance(analysis_func, str):
                if analysis_func in analysis_funcs:
                    analysis_func = analysis_funcs[analysis_func]

                else:
                    # Try to get it from dataprocessing ... might fail.
                    analysis_func = getattr(utopya.dataprocessing,
                                            analysis_func)

            # Perform the analysis step
            conclusive, attractor = analysis_func(data, **analysis_func_kwargs)

            # Return if conclusive
            if conclusive:
                return analysis_key, attractor

        # Return non-conclusive
        return None, None

    def resolve_rectangle(coord: dict, rectangles: xr.Dataset) -> Rectangle:
        """Resolve the rectangle patch at this coordinate
        
        Args:
            coord (dict): The bifurcation parameter's coordinate
            rectangles (xr.Dataset): The rectangles that cover the 2D space
                spanned by the coordiantes. The `coord` should be one entry of
                rectangles.coords.
        
        Raises:
            ValueError: Coordinate not available in rectangles.
        
        Returns:
            Rectangle: A rectangle around a universe with coord and
                shape defined by rectangles.
        """
        try:
            rectangle = rectangles.sel(coord)

        except Exception as exc:
            raise ValueError("The requested paramspace coordinate(s) {} are "
                             "not coordinates of rectangles {}. Plot failed."
                             "".format(coord, rectangles.coords)) from exc

        rect_spec = rectangle['rect_spec']
        return Rectangle(*rect_spec.item())

    def append_vis_patch(attractor_key: str, attractor: xr.Dataset,
                         vis_patches: dict, vis_kwargs: dict,
                         **resolve_rectangle_args):
        """Append visualization patch

        Performs a postprocessing step for

            - attractor keys 'fixpoint' and 'endpoint', if 'to_plot' is given
                in the corresponding entry of vis_kwargs: the data_var with
                the highest-valued datapoint is determined.
        
        Args:
            attractor_key (str): Key according to which to decode the attractor
            attractor (xr.Dataset): The Dataset with the encoded attractor 
                information. See possible encodings
            vis_patches (dict): The map of attractor_key to
                List[Rectangle] where to append the new patch
            vis_kwargs (dict): The visualization kwargs
            **resolve_rectangle_args: Passed on to resolve_rectangle
        
        Raises:
            ValueError: Bad postprocess key
        
        Returns:
            vis_patches (dict): The new map of attractor_key to
                List[Rectangle]
        
        """
        # Depending on the kind of attractor, add different patches
        kwargs = vis_kwargs.get(attractor_key)
        if kwargs is None:
            return vis_patches

        # Postprocess fixpoint and to_plot, append rectangle
        if ((attractor_key == 'fixpoint' or attractor_key == 'endpoint')
                and 'to_plot' in kwargs):
            max_value = -np.inf
            for data_var_name, data_var in attractor.data_vars.items():
                if data_var.max() > max_value:
                    max_value = data_var.max()
                    max_name = data_var_name

            rect = resolve_rectangle(**resolve_rectangle_args)

            attractor_var_key = attractor_key + '_' + max_name
            vis_patches[attractor_var_key].append(rect)

        # Append rectangle
        elif vis_patches.get(attractor_key) is not None:
            rect = resolve_rectangle(**resolve_rectangle_args)
            vis_patches[attractor_key].append(rect)

        return vis_patches

    def append_plot_attractor(attractor_key: str,
                              attractor: xr.Dataset,
                              *,
                              coord: float = None,
                              scatter_kwargs: list,
                              **plot_kwargs):
        """Resolves how to plot attractor of specified type 
        at specific bifurcation parameter value.

        Args:
            attractor_key (str): Key according to which to decode the attractor
            attractor (xr.Dataset): The Dataset with the encoded attractor 
                information. See possible encodings
            coord (float, optional): The bifurcation parameter's coordinate;
                if None, it is derived from the attractor's coordinates
            scatter_kwargs (list): The list of scatter datasets where to append
                the new scatter.
            plot_kwargs (dict, optional): The kwargs used to specify ax.scatter
                where the entries match the attractor.data_vars 

        Possible encodings, i.e. values for ``attractor_key``:
            ``fixpoint``: xr.Dataset with dimensions ()
            ``scatter``: xr.Dataset with dimensions (time: >=1)
            ``multistability``: xr.Dataset with dimensions
                (<initial_state>: >= 1)
            ``oscillation``: xr.Dataset with dimensions (osc: 2), the minimum
                and maximum
        
        NOTE the attractor must contain the bifurcation parameter coordinate
        
        Raises:
            KeyError: Unknown attractor_key
            KeyError: No bifurcation coordinate received
            ValueError: Attractor encoding mismatched with the given 
                attractor_key

        Returns:
            scatter_kwargs: The new list of scatter datasets
        """
        # Get the bifurcation parameter coordinate
        if not coord:
            try:
                coord = attractor[dim]

            except KeyError as err:
                raise KeyError(
                    "No bifurcation parameter coordinate '{}' "
                    "could be found! Either have it as a "
                    "coordinate in 'attractor' or pass it to "
                    "'plot_attractor' explicitly.".format(dim)) from err

        # Resolve the scatter kwargs depending on attractor key
        if attractor_key in ('fixpoint', 'endpoint', 'multistability'):
            for data_var_name, data_var in attractor.data_vars.items():
                data_var = data_var.where(data_var.notnull(), drop=True)
                entries = 1
                if data_var.shape:
                    entries = len(data_var)

                scatter_kwargs.append(
                    dict(x=[coord] * entries,
                         y=data_var,
                         **plot_kwargs.get(data_var_name, {})))

        elif attractor_key == 'scatter':
            for data_var_name, data_var in attractor.data_vars.items():
                if 'cmap' in plot_kwargs.get(data_var_name, {}):
                    scatter_kwargs.append(
                        dict(x=[coord] * len(data_var.data),
                             y=data_var,
                             c=attractor['time'],
                             **plot_kwargs.get(data_var_name, {})))

                else:
                    scatter_kwargs.append(
                        dict(x=[coord] * len(data_var.data),
                             y=data_var,
                             **plot_kwargs.get(data_var_name, {})))

        elif attractor_key == 'oscillation':
            for data_var_name, data_var in attractor.data_vars.items():
                scatter_kwargs.append(
                    dict(x=[coord] * len(data_var.data),
                         y=data_var,
                         **plot_kwargs.get(data_var_name, {})))

        elif attractor_key:
            raise KeyError("Invalid attractor-key '{}'! "
                           "Available keys: 'endpoint', fixpoint',"
                           " 'multistability', 'scatter', 'oscillation'."
                           "".format(attractor_key))

        return scatter_kwargs

    def resolve_scatter(data: xr.Dataset,
                        *,
                        spin_up_time: int = 0,
                        **kwargs) -> tuple:
        """A mock analysis function to plot all times larger than a spin
        up time.
        """
        return True, data.where(data.time >= spin_up_time, drop=True)

    # .........................................................................

    # Check argument values
    if not dim and not dims:
        raise KeyError("No dim (str) or dims (Tuple[str, str]) specified. "
                       "Use dim for a 1d-bifurcation diagram and dims for a "
                       "2d-bifurcation diagram.")

    if dim and dims:
        raise KeyError("dim='{}' and dims='{}' specified. "
                       "Use either dim for a 1d-bifurcation diagram or dims "
                       "for a 2d-bifurcation diagram."
                       "".format(dim, dims))

    if dims is not None and len(dims) != 2:
        raise ValueError("Argument dims should be of length 2, but was: {}"
                         "".format(dims))

    # TODO In the future, consider not using `dim` below here but handling it
    #      via the length of `dims`.

    # Default values
    if visualization_kwargs is None:
        visualization_kwargs = {}

    if analysis_kwargs is None:
        analysis_kwargs = {}

    # Resolve legend handles and visualization kwargs
    data_vars_plot_kwargs = resolve_to_plot_kwargs(to_plot)
    legend_handles, data_vars_plot_kwargs = create_legend_handles(
        visualization_kwargs=visualization_kwargs,
        data_vars_plot_kwargs=data_vars_plot_kwargs)

    # Define default analysis functions
    analysis_funcs = dict(endpoint=utdp.find_endpoint,
                          fixpoint=utdp.find_fixpoint,
                          multistability=utdp.find_multistability,
                          oscillation=utdp.find_oscillation,
                          scatter=resolve_scatter)

    # If given, update
    if custom_analysis_funcs:
        log.debug("Updating with custom analysis functions ...")
        analysis_funcs = recursive_update(analysis_funcs,
                                          custom_analysis_funcs)

    analysis_steps = resolve_analysis_steps(analysis_steps)

    # Obtain the rectangles covering space spanned by the coordinates
    rectangle_map_kwargs = kwargs.get('rectangle_map_kwargs', {})
    if dim:
        rects, limits = calc_pxmap_rectangles(x_coords=mv_data[dim].values,
                                              y_coords=None,
                                              **rectangle_map_kwargs)
    elif dims:
        rects, limits = calc_pxmap_rectangles(x_coords=mv_data[dims[0]].values,
                                              y_coords=mv_data[dims[1]].values,
                                              **rectangle_map_kwargs)

    # Obtain the list of param_coords to iterate
    if dim:
        param_iter = mv_data[dim].values

    elif dims:
        param_iter = itertools.product(mv_data[dims[0]].values,
                                       mv_data[dims[1]].values)

    # Map of analysis_key to list[mpatch.Rectangle]
    vis_patches = {}
    for analysis_key, _ in analysis_steps:
        if not visualization_kwargs.get(analysis_key):
            continue

        if 'to_plot' in visualization_kwargs[analysis_key]:
            for var_key, _ in visualization_kwargs[analysis_key][
                    'to_plot'].items():
                analysis_var_key = analysis_key + '_' + var_key
                vis_patches[analysis_var_key] = []

        else:
            vis_patches[analysis_key] = []

    # The List[dict] passed to ax.scatter
    scatter_kwargs = []
    scatter_coords_kwargs = []

    # Iterate the parameter coordinates
    for param_coord in param_iter:
        # Resolve the param_coord to dict
        if dim:
            param_coord = {dim: param_coord}

        elif dims:
            param_coord = {dims[0]: param_coord[0], dims[1]: param_coord[1]}

        # Plot coord
        if dim and kwargs.get('plot_coords_kwargs'):
            plot_coords_kwargs = kwargs.get('plot_coords_kwargs')
            scatter_coords_kwargs.append({
                'x': param_coord[dim],
                'y': plot_coords_kwargs.pop('y', 0.),
                **plot_coords_kwargs
            })

        if dims and kwargs.get('plot_coords_kwargs'):
            scatter_coords_kwargs.append({
                'x': param_coord[dims[0]],
                'y': param_coord[dims[1]],
                **kwargs.get('plot_coords_kwargs')
            })

        # Select the data and analyse
        data = mv_data.sel(param_coord)
        attractor_key, attractor = apply_analysis_steps(
            data,
            analysis_steps,
            analysis_funcs=analysis_funcs,
            analysis_kwargs=analysis_kwargs)

        # If conclusive, append a rectangular patch to the attractor_key's
        # patch collection
        if attractor_key:
            # Determine coordinate value
            if dim:
                rect_map_kwargs = kwargs.get('rectangle_map_kwargs', {})
                y = rect_map_kwargs.get('default_pos', (0., 0.))[1]
                coord = dict(x=param_coord[dim], y=y)

            elif dims:
                coord = dict(x=param_coord[dims[0]], y=param_coord[dims[1]])

            vis_patches = append_vis_patch(attractor_key,
                                           attractor,
                                           vis_patches,
                                           visualization_kwargs,
                                           coord=coord,
                                           rectangles=rects)

        # For 1d case ...
        if dim and to_plot:
            scatter_kwargs = append_plot_attractor(
                attractor_key,
                attractor,
                coord=param_coord[dim],
                scatter_kwargs=scatter_kwargs,
                **data_vars_plot_kwargs)

    # Draw collection of visualization patches
    for analysis_key, _ in analysis_steps:
        if not visualization_kwargs.get(analysis_key):
            continue

        if analysis_key in visualization_kwargs:
            vis_kwargs = visualization_kwargs[analysis_key]

            if 'to_plot' in vis_kwargs:
                for var_key, var_kwargs in vis_kwargs['to_plot'].items():
                    attractor_var_key = analysis_key + '_' + var_key

                    pc = PatchCollection(vis_patches[attractor_var_key],
                                         **var_kwargs)
                    hlpr.ax.add_collection(pc)

            else:
                pc = PatchCollection(vis_patches[analysis_key], **vis_kwargs)
                hlpr.ax.add_collection(pc)

        # else: nothing to do

    # Scatter the universe's coordinates
    for kws in scatter_coords_kwargs:
        hlpr.ax.scatter(**kws)

    # Scatter the attractor
    for kws in scatter_kwargs:
        hlpr.ax.scatter(**kws)

    # Provide PlotHelper defaults
    hlpr.provide_defaults('set_limits', **limits)

    if dim:
        hlpr.provide_defaults('set_labels', x=dim, y='state')
    elif dims:
        hlpr.provide_defaults('set_labels', x=dims[0], y=dims[1])

    if legend_handles:
        hlpr.ax.legend(legend_handles[0],
                       legend_handles[1],
                       handler_map={Circle: HandlerEllipse()},
                       **kwargs.get('legend_kwargs', {}))
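
The visualization patches above boil down to one matplotlib Rectangle per universe, grouped into a PatchCollection per conclusive analysis step. A stripped-down sketch with made-up coordinates and colors:

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection

fig, ax = plt.subplots()

# One unit-sized rectangle per 'universe', collected and drawn in one go
patches = [Rectangle((x, y), 1., 1.) for x, y in [(0, 0), (1, 0), (0, 1)]]
pc = PatchCollection(patches, facecolor='tab:blue', edgecolor='none')
ax.add_collection(pc)

ax.set_xlim(0, 2)
ax.set_ylim(0, 2)
plt.show()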
Example #7
def opinion_at_time(dm: DataManager,
                    *,
                    hlpr: PlotHelper,
                    uni: UniverseGroup,
                    age_groups: list = [10, 20, 40, 60, 80],
                    num_bins: int = 100,
                    plot_kwargs: dict = {},
                    time_step: float,
                    title: str = None,
                    to_plot: str,
                    val_range: tuple = (0., 1.)):
    """Plots opinion state at a specific time frame. If the model mode is 'ageing',
    the opinions of the specified age groups are plotted, in all other cases
    the opinions of the groups are shown.

    Arguments:
        age_groups (list): The age intervals to be plotted in the axs[2]
            distribution plot for the 'ageing' mode.
        time_step (float): the time frame to plot (as a fraction of the
            total length)
        title (str, optional): Custom plot title
        to_plot (str): which distribution to plot; one of 'overall',
            'by_group', or 'discriminators'

    Raises:
        TypeError: if the age groups have fewer than two entries (in which
        case no binning is possible)
        ValueError: if to_plot or time_step are invalid
    """
    if len(age_groups) < 2:
        raise TypeError("'Age groups' list needs at least two entries!")

    if to_plot not in ['overall', 'by_group', 'discriminators']:
        raise ValueError(f"Unrecognized argument {to_plot}: must be one of "
                         "'overall' or 'by_group'")
    if time_step < 0 or time_step > 1:
        raise ValueError("time_step must be in [0, 1]")

    #datasets...................................................................
    mode = uni['cfg']['OpDisc']['mode']
    ageing = (mode == 'ageing')
    time_idx = int(time_step *
                   (uni['data/OpDisc/nw/opinion']['time'].size - 1))
    opinions = uni['data/OpDisc/nw/opinion']
    if ageing:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'], dtype=int)
    else:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'][0, :], dtype=int)
    num_groups = len(
        age_groups) - 1 if ageing else uni['cfg']['OpDisc']['number_of_groups']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]
    time = uni['data/OpDisc/nw/opinion'].coords['time'].data

    #figure setup ..............................................................
    figure, axs = setup_figure(uni['cfg'], plot_name='opinion')
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    # data analysis and plotting................................................
    if to_plot == 'by_group':
        #get opinions by group
        to_plot = np.zeros((num_bins, num_groups))
        data_by_groups = data_by_group(opinions,
                                       groups,
                                       group_list,
                                       val_range,
                                       num_bins,
                                       ageing=ageing)

        #calculate a histogram of the opinion distribution at each time step
        for k in range(num_groups):
            counts, _ = np.histogram(data_by_groups[k][time_idx],
                                     bins=num_bins,
                                     range=val_range)
            to_plot[:, k] = counts[:]

        #get pretty labels
        if ageing:
            labels = [
                f"Ages {group_list[_]}-{group_list[_+1]}"
                for _ in range(num_groups)
            ]
            max_age = np.amax(groups)
            if (age_groups[-1] >= max_age):
                labels[-1] = f"Ages {group_list[-2]}+"
        else:
            labels = [f"Group {_+1}" for _ in group_list]
        X = pd.DataFrame(to_plot[:, :], columns=labels)
        X.plot.bar(stacked=True,
                   ax=hlpr.ax,
                   legend=False,
                   rot=0,
                   **plot_kwargs)
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                       loc='lower right',
                       ncol=num_groups,
                       fontsize='xx-small')
        hlpr.ax.set_xticks([i for i in np.linspace(0, num_bins - 1, 11)])
        hlpr.ax.set_xticklabels(
            [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])

    elif to_plot == 'overall':
        hlpr.ax.hist(opinions[time_idx][:],
                     bins=num_bins,
                     alpha=1,
                     **plot_kwargs)

    elif to_plot == 'discriminators':
        discriminators = uni['data/OpDisc/nw/discriminators']
        mask = np.empty(opinions.shape, dtype=bool)
        mask[:, :] = (discriminators == 0)
        ops_disc = np.ma.MaskedArray(opinions, mask)
        ops_nondisc = np.ma.MaskedArray(opinions, ~mask)
        hlpr.ax.hist(ops_disc[time_idx].compressed()[:],
                     bins=num_bins,
                     alpha=0.5,
                     **plot_kwargs,
                     label='disc')
        hlpr.ax.hist(ops_nondisc[time_idx].compressed()[:],
                     bins=num_bins,
                     alpha=0.5,
                     **plot_kwargs,
                     label='non-disc')
        hlpr.ax.hist(opinions[time_idx][:],
                     bins=num_bins,
                     alpha=1,
                     **plot_kwargs,
                     histtype='step')
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                       loc='lower right',
                       ncol=num_groups,
                       fontsize='xx-small')
        hlpr.ax.set_xlim(val_range[0], val_range[1])

    hlpr.ax.text(0.02,
                 0.97,
                 f'step {time[time_idx]}',
                 transform=hlpr.ax.transAxes,
                 fontsize='xx-small')
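
The 'by_group' branch above reduces to per-group histograms stacked into a bar chart via pandas. A self-contained sketch with invented opinion samples (the group labels and beta-distributed values are assumptions):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

num_bins = 20
groups = {"Group 1": np.random.beta(2, 5, 500),
          "Group 2": np.random.beta(5, 2, 500)}

# One histogram per group, then stack the counts as bar-chart columns
counts = {label: np.histogram(vals, bins=num_bins, range=(0., 1.))[0]
          for label, vals in groups.items()}
X = pd.DataFrame(counts)
X.plot.bar(stacked=True, rot=0)
plt.show()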
Example #8
def op_groups(dm: DataManager, *,
              uni: UniverseGroup,
              hlpr: PlotHelper,
              age_groups: list=[10, 20, 40, 60, 80],
              num_bins: int=100,
              time_idx: int=None,
              title: str=None,
              val_range: tuple=(0., 1.)):
    """Plots an animated stacked histogram of the opinion distribution of
       each group.

    Arguments:
        age_groups (list): The age intervals to be plotted in the final_ax
            distribution plot for the 'ageing' mode.
        num_bins(int): Binning of the histogram
        time_idx (int, optional): Only plot a single frame (e.g. the last one)
        title (str, optional): Custom plot title
        val_range(int, optional): Value range of the histogram

    Raises:
        TypeError: if the 'age_groups' list does not contain at least two
            entries
    """

    if len(age_groups)<2:
        raise TypeError("'age_groups' list must contain at least 2 entries!")

    #figure setup..............................................................
    figure, axs = setup_figure(uni['cfg'], plot_name='op_groups', title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #datasets ..................................................................
    ageing = True if uni['cfg']['OpDisc']['mode'] == 'ageing' else False
    if ageing:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'], dtype=int)
    else:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'][0, :], dtype=int)
    num_groups = len(age_groups)-1 if ageing else uni['cfg']['OpDisc']['number_of_groups']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]
    opinions = uni['data/OpDisc/nw/opinion']
    time = opinions.coords['time'].data
    time_steps = opinions.coords['time'].size

    #data analysis .............................................................
    #get opinions by group
    to_plot = np.zeros((time_steps, num_bins, num_groups))
    data_by_groups = data_by_group(opinions, groups, group_list, val_range,
                                   num_bins, ageing=ageing)
    #calculate a histogram of the opinion distribution at each time step
    for t in range(time_steps):
        for k in range(num_groups):
            counts, _ = np.histogram(data_by_groups[k][t], bins=num_bins,
                                     range=val_range)
            to_plot[t, :, k] = counts[:]

    #get pretty labels
    if ageing:
        labels = [f"Ages {group_list[_]}-{group_list[_+1]}" for _ in range(num_groups)]
        max_age = np.amax(groups)
        if (age_groups[-1]>=max_age):
            labels[-1]=f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {_+1}" for _ in group_list]

    #create pandas df for stacked bar plot
    if time_idx is None:
        X = [pd.DataFrame(to_plot[_, :, :], columns=labels) for _ in range(time_steps)]
    else:
        X = pd.DataFrame(to_plot[time_idx, :, :], columns=labels)

    #plotting ..................................................................
    #plot an animated stacked bar chart. Since there is no 'set height' function
    #for pandas charts, we need to clear the axis and entirely reformat the plot
    #for every frame. Clearing is necessary or else successive frames are simply
    #superimposed
    def update_data(stepsize: int=1):
        """Updates the stacked bar chart, frame by frame"""
        if time_idx is not None:
            log.info(f"Plotting distribution at time step {time[time_idx]} ...")
        else:
            log.info(f"Plotting animation with {opinions.shape[0]//stepsize} frames ...")
        next_frame_idx = 0
        if time_steps < stepsize:
            log.warning("Stepsize is greater than number of steps. "
                        "Continuing by plotting the first and last frame.")
            stepsize = time_steps - 1
        for t in range(time_steps):
            if t < next_frame_idx:
                continue
            hlpr.ax.clear()
            hlpr.ax.set_xlim(0, 1)
            if time_idx is not None:
                t = time_idx
                im = X.plot.bar(stacked=True, ax=hlpr.ax, legend=False, rot=0)
            else:
                im = X[t].plot.bar(stacked=True, ax=hlpr.ax, legend=False, rot=0)
            time_text = hlpr.ax.text(0.02, 0.97, '', transform=hlpr.ax.transAxes,
                                     fontsize ='xx-small')
            time_text.set_text(f'step {time[t]}')
            hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
                                           ncol=num_groups, fontsize='xx-small')
            hlpr.ax.set_xticks([i for i in np.linspace(0, num_bins-1, 11)])
            hlpr.ax.set_xticklabels([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
            hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
            hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])
            if time_idx is not None:
                yield
                break
            next_frame_idx = t + stepsize
            yield
    hlpr.register_animation_update(update_data)
Example #9
File: graph.py  Project: utopia-foss/utopia
def draw_graph(*,
               hlpr: PlotHelper,
               data: dict,
               graph_group_tag: str = "graph_group",
               graph: Union[nx.Graph, xr.DataArray] = None,
               graph_creation: dict = None,
               graph_drawing: dict = None,
               graph_animation: dict = None,
               register_property_maps: Sequence[str] = None,
               clear_existing_property_maps: bool = True,
               suptitle_kwargs: dict = None):
    """Draws a graph either from a :py:class:`~utopya.datagroup.GraphGroup` or
    directly from a ``networkx.Graph`` using the
    :py:class:`~utopya.plot_funcs._graph.GraphPlot` class.

    If the graph object is to be created from a graph group, the latter needs
    to be selected via the TransformationDAG. Additional property maps can
    also be made available for plotting, see the ``register_property_maps``
    argument.
    Animations can be created either from a graph group by using the select
    interface in ``graph_animation`` or by passing a DataArray of networkx
    graphs via the ``graph`` argument.

    For more information on how to use the transformation framework, refer to
    the `dantro documentation <https://dantro.readthedocs.io/en/stable/plotting/plot_data_selection.html>`_.

    For more information on how to configure the graph layout refer to the
    :py:class:`~utopya.plot_funcs._graph.GraphPlot` documentation.

    Args:
        hlpr (PlotHelper): The PlotHelper instance for this plot
        data (dict): Data from TransformationDAG selection
        graph_group_tag (str, optional): The TransformationDAG tag of the graph
            group
        graph (Union[nx.Graph, xr.DataArray], optional): If given, the ``data``
            and ``graph_creation`` arguments are ignored and this graph is
            drawn directly.
            If a DataArray of graphs is given, the first graph is drawn for a
            single graph plot. In animation mode the (flattened) array
            represents the animation frames.
        graph_creation (dict, optional): Configuration of the graph creation.
            Passed to ``GraphGroup.create_graph``.
        graph_drawing (dict, optional): Configuration of the graph layout.
            Passed to :py:class:`~utopya.plot_funcs._graph.GraphPlot`.
        graph_animation (dict, optional): Animation configuration. The
            following arguments are allowed:

            times (dict, optional):
                *Deprecated*: Equivalent to a sel.time entry.
            sel (dict, optional):
                Select by value. Dictionary with dimension names as keys. The
                values may either be coordinate values or a dict with a single
                ``from_property`` (str) entry which specifies a container
                within the GraphGroup or registered external data from which
                the coordinates are extracted.
            isel (dict, optional):
                Select by index. Coordinate indices keyed by dimension. May be
                given together with ``sel`` if no key appears in both.
            update_positions (bool, optional):
                Whether to update the node positions for each frame by
                recalculating the layout with the parameters specified in
                graph_drawing.positions. If this parameter is not given or
                false, the positions are calculated once initially and then
                fixed.
            update_colormapping (bool, optional):
                Whether to reconfigure the nodes' and edges'
                :py:class:`~utopya.plot_funcs._mpl_helpers.ColorManager` for
                each frame (default=False). If False, the colormapping (and the
                colorbar) is configured with the first frame and then fixed.
            skip_empty_frames (bool, optional):
                Whether to skip the frames where the selected graph is missing
                or of a type different than ``nx.Graph`` (default=False).
                If False, such frames are empty.

        register_property_maps (Sequence[str], optional): Names of properties
            to be registered in the graph group before the graph creation.
            The property names must be valid TransformationDAG tags, i.e., be
            available in ``data``. Note that the tags may not conflict with any
            valid path reachable from inside the selected ``GraphGroup``.
        clear_existing_property_maps (bool, optional): Whether to clear any
            existing property maps from the selected ``GraphGroup``. This is
            enabled by default to reduce side effects from previous plots.
            Set this to False if you have property maps registered with the
            GraphGroup that you would like to keep.
        suptitle_kwargs (dict, optional): Passed on to the PlotHelper's
            ``set_suptitle`` helper function. Only used in animation mode.
            The ``title`` can be a format string containing a placeholder with
            the dimension name as key for each dimension along which selection
            is done. The format string is updated for each frame of the
            animation. The default is ``<dim-name> = {<dim-name>}`` for each
            dimension.

    Raises:
        ValueError: On invalid or non-computed TransformationDAG tags in
            ``register_property_maps`` or invalid graph group tag.
    """
    # Work on a copy such that the original configuration is not modified
    graph_drawing = copy.deepcopy(graph_drawing if graph_drawing else {})

    # Get the sub-configurations for the drawing of the graph
    select = graph_drawing.get("select", {})
    pos_kwargs = graph_drawing.get("positions", {})
    node_kwargs = graph_drawing.get("nodes", {})
    edge_kwargs = graph_drawing.get("edges", {})
    node_label_kwargs = graph_drawing.get("node_labels", {})
    edge_label_kwargs = graph_drawing.get("edge_labels", {})
    mark_nodes_kwargs = graph_drawing.get("mark_nodes", {})
    mark_edges_kwargs = graph_drawing.get("mark_edges", {})

    def get_dag_data(tag):
        try:
            return data[tag]
        except KeyError as err:
            _available_tags = ", ".join(data.keys())
            raise ValueError(
                f"No tag '{tag}' found in the data selected by the DAG! Make "
                "sure the tag is named correctly and is selected to be "
                "computed; adjust the 'compute_only' argument if needed."
                "\nThe following tags are available in the DAG results: "
                f"{_available_tags}") from err

    # Prepare graph group and external property data
    graph_group = get_dag_data(graph_group_tag) if graph is None else None
    property_maps = None
    if register_property_maps:
        property_maps = {}
        for tag in register_property_maps:
            property_maps[tag] = get_dag_data(tag)

    # If not in animation mode, make a single graph plot
    if not hlpr.animation_enabled:
        # Set up a GraphPlot instance
        if graph_group is not None:
            # Create GraphPlot from graph group
            gp = GraphPlot.from_group(
                graph_group=graph_group,
                graph_creation=graph_creation,
                register_property_maps=property_maps,
                clear_existing_property_maps=clear_existing_property_maps,
                fig=hlpr.fig,
                ax=hlpr.ax,
                **graph_drawing,
            )

        else:
            if graph_creation is not None:
                warnings.warn(
                    "Received both a 'graph' argument and a 'graph_creation' "
                    "configuration. The latter will be ignored. To remove "
                    "this warning set graph_creation to None.", UserWarning)

            if isinstance(graph, xr.DataArray):
                # Use the first array element for a single graph plot
                g = graph.values.flat[0]

                if graph.size > 1:
                    log.caution(
                        "Received a DataArray of size %d as 'graph' argument, "
                        "performing a single plot using the first entry. "
                        "Animations must be enabled by setting "
                        "animation.enabled to True.",
                        graph.size,
                    )

            else:
                g = graph

            # Create a GraphPlot from a nx.Graph
            gp = GraphPlot(g=g, fig=hlpr.fig, ax=hlpr.ax, **graph_drawing)

        # Make the actual plot via the GraphPlot
        gp.draw()

    # In animation mode, register the animation frame generator
    else:
        # Prepare animation kwargs for the update routine
        graph_animation = graph_animation if graph_animation else {}

        def update():
            """The animation frames generator.
            See :py:meth:`~utopya.plot_funcs.dag.graph.graph_animation_update`.
            """
            if graph_group is None and not isinstance(graph, xr.DataArray):
                raise TypeError(
                    "Failed to create animation due to invalid type of the "
                    "'graph' argument. Required: xr.DataArray (filled "
                    f"with graph objects). Received: {type(graph)}")

            yield from graph_animation_update(
                hlpr=hlpr,
                graphs=graph if isinstance(graph, xr.DataArray) else None,
                graph_group=graph_group,
                graph_creation=graph_creation,
                register_property_maps=property_maps,
                clear_existing_property_maps=clear_existing_property_maps,
                suptitle_kwargs=suptitle_kwargs,
                animation_kwargs=graph_animation,
                **graph_drawing,
            )

        hlpr.register_animation_update(update)
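
Note: the examples below repeatedly use the same animation pattern: a generator that draws one frame per ``yield`` and is registered via ``hlpr.register_animation_update``. A minimal sketch of that pattern follows; the ``my_plot`` function and the layout of ``data['y']`` are assumptions for illustration, not part of any example here.

def my_plot(*, data: dict, hlpr: PlotHelper, **plot_kwargs):
    """Minimal sketch of the frame-per-yield animation pattern (assumes an
    xr.DataArray with 'time' and 'x' dimensions under the 'y' tag)."""
    y = data['y']

    def update():
        """Draws one frame per `yield`"""
        for t, frame in y.groupby('time'):
            hlpr.ax.clear()
            hlpr.ax.plot(frame.coords['x'], frame, **plot_kwargs)
            hlpr.ax.set_title(f"time = {t}")
            yield  # ... the animation writer grabs the current figure here

    hlpr.register_animation_update(update)
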
Example #10
def opinion_animation(dm: DataManager,
                      *,
                      uni: UniverseGroup,
                      hlpr: PlotHelper,
                      time_idx: int,
                      to_plot: str,
                      num_bins: int = 100,
                      val_range: tuple = (0., 1.)):
    """Plots an animated histogram of the opinion distribution over time.

    Args:
        dm (DataManager): The data manager from which to retrieve the data
        uni (UniverseGroup): The selected universe data
        hlpr (PlotHelper): The PlotHelper
        time_idx (int): Time index of the initially drawn frame (use -1 for
            the last frame)
        to_plot (str): Name of the OpDyn dataset to plot
        num_bins (int, optional): Number of histogram bins
        val_range (tuple, optional): Value range of the histogram
    """
    opinions = uni['data/OpDyn/nw_users/' + to_plot]
    life_cycle = int(uni['cfg']['OpDyn']['life_cycle'])
    time = opinions['time'].data
    time_steps = time.size
    if not time_idx:
        # Default to the first time step
        time_idx = 0
    if time_idx == -1:
        opinions_at_time = opinions[-1, :]

    else:
        opinions_at_time = opinions[time_idx, :]

    bins = num_bins if num_bins else 100
    start, stop = val_range if val_range else (0., 1.)

    # Calculate the opinion histogram at time_idx
    counts_at_time, bin_edges = np.histogram(opinions_at_time,
                                             range=(start, stop),
                                             bins=bins)

    # Calculate bin positions, i.e. midpoint of bin edges
    bin_pos = bin_edges[:-1] + (np.diff(bin_edges) / 2.)
    bar = hlpr.ax.bar(bin_pos,
                      counts_at_time,
                      width=np.diff(bin_edges),
                      color='dodgerblue')
    time_text = hlpr.ax.text(0.02, 0.95, '', transform=hlpr.ax.transAxes)

    def update_data(stepsize: int = 1):
        """Updates the data of the imshow objects"""
        log.info("Plotting animation with %d frames ...",
                 opinions.shape[0] // stepsize)

        next_frame_idx = 0

        if time_steps < stepsize:
            warnings.warn("Stepsize is greater than number of steps. "
                          "Continue by plotting fist and last frame.")

        for i in range(time_steps):
            time_idx = i
            if time_idx < next_frame_idx and time_idx < time_steps:
                continue

            log.debug("Plotting frame for time index %d ...", time_idx)

            # Get the opinions
            opinion = opinions[time_idx, :]

            counts_step, _ = np.histogram(opinion, bin_edges)

            for idx, rect in enumerate(bar):
                rect.set_height(counts_step[idx])

            if (uni['cfg']['OpDyn']['user_ageing'] == 'on'):
                time_text.set_text('%.0f years' %
                                   int(time[time_idx] / life_cycle))
            else:
                time_text.set_text('step %.0f' % time[time_idx])

            hlpr.ax.relim()
            hlpr.ax.autoscale_view(scalex=False)

            next_frame_idx = time_idx + stepsize

            yield

    # Register this update method with the helper, which takes care of the rest
    hlpr.register_animation_update(update_data)
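
The ``next_frame_idx`` bookkeeping above iterates over every ``stepsize``-th time index. A plain-Python sketch of the intended index selection (no project API involved; the function name is illustrative):

def frame_indices(time_steps: int, stepsize: int = 1) -> list:
    """Which time indices the update loop above is meant to plot"""
    if stepsize >= time_steps:
        # As the warning above suggests: fall back to first and last frame
        return [0, time_steps - 1]
    return list(range(0, time_steps, stepsize))

print(frame_indices(10, 3))   # [0, 3, 6, 9]
print(frame_indices(5, 10))   # [0, 4]
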
Example #11
def opinion_time_series(*,
                        hlpr: PlotHelper,
                        data: dict,
                        bins: int = 100,
                        opinion_range=None,
                        representatives: dict = None,
                        density_kwargs: dict = None,
                        hist_kwargs: dict = None):
    """Plots the temporal development of the opinion density and the final
    opinion distribution.

    Args:
        hlpr (PlotHelper): The Plot Helper
        data (dict): The data from DAG selection
        bins (int, optional): The number of bins for the histograms
        opinion_range (tuple, optional): The range of opinions to be plotted.
            If not provided, the opinion space from the config is used.
        representatives (dict, optional): kwargs for representative
            trajectories. Possible configurations:
                'enabled': Whether to show representative trajectories
                    (default: false).
                'max_reps': The maximum total number of chosen representatives
                'rep_threshold': Lower threshold for the final group size above
                    which a second representative per group is allowed
        density_kwargs (dict, optional): Passed to plt.imshow (density plot)
        hist_kwargs (dict, optional): Passed to plt.hist (final distribution)
    """
    opinions = data['opinion']
    final_opinions = opinions.isel(time=-1)
    cfg_op_space = data['opinion_space']
    if opinion_range is None:
        if cfg_op_space['type'] == 'continuous':
            opinion_range = cfg_op_space['interval']
        elif cfg_op_space['type'] == 'discrete':
            opinion_range = (0, cfg_op_space['num_opinions'])

    num_vertices = opinions.coords['vertex_idx'].size
    time = opinions.coords['time'].data

    density_kwargs = density_kwargs if density_kwargs else {}
    hist_kwargs = hist_kwargs if hist_kwargs else {}
    rep_kwargs = representatives if representatives else {}

    # Plot an opinion density heatmap over time ...............................

    densities = np.empty((bins, len(time)))

    for i in range(len(time)):
        densities[:, i] = np.histogram(opinions.isel(time=i),
                                       range=opinion_range,
                                       bins=bins)[0]

    density_plot = hlpr.ax.imshow(densities,
                                  extent=[time[0], time[-1], *opinion_range],
                                  **density_kwargs)

    hlpr.provide_defaults("set_limits", y=opinion_range)
    hlpr.ax.get_yaxis().set_minor_locator(ticker.MultipleLocator(0.1))

    # Add colorbar for the density plot
    plt.colorbar(density_plot, ax=hlpr.ax)

    # Plot representative trajectories ........................................
    # The representative vertices are chosen based on the final opinion groups.
    # Reps are first picked for the largest opinion groups.

    if (rep_kwargs.get('enabled', False)):

        max_reps = rep_kwargs.get('max_reps', num_vertices)
        rep_threshold = rep_kwargs.get('rep_threshold', 1)

        reps = []  # store the vertex ids of the representatives here

        hist, bin_edges = np.histogram(final_opinions,
                                       range=opinion_range,
                                       bins=bins)

        # idxs that sort hist in descending order
        sort_idxs = np.argsort(hist)[::-1]
        # sorted hist values
        hist_sorted = hist[sort_idxs]
        # lower bin edges sorted from highest to lowest respective hist value
        # (restricted to non-empty bins)
        bins_to_rep = bin_edges[:-1][sort_idxs][hist_sorted > 0]

        # Generate a random sequence of vertex indices
        rand_vertex_idxs = np.arange(num_vertices)
        np.random.shuffle(rand_vertex_idxs)

        # Find representative vertices
        for b, n in zip(bins_to_rep, range(max_reps)):
            for v in rand_vertex_idxs:
                if (final_opinions.isel(vertex_idx=v) > b
                        and final_opinions.isel(vertex_idx=v) < b + 1. / bins):
                    reps.append(v)
                    break

        # If more reps available, add a second rep for larger opinion groups
        if (len(reps) < max_reps):
            for h, b, n in zip(hist_sorted, bins_to_rep,
                               range(max_reps - len(reps))):
                if (h > rep_threshold):
                    for v in rand_vertex_idxs:
                        if (final_opinions.isel(vertex_idx=v) > b
                                and final_opinions.isel(vertex_idx=v) <
                                b + 1. / bins and v not in reps):
                            reps.append(v)
                            break

        # Plot temporal opinion development for all representatives
        for r in reps:
            hlpr.ax.plot(time, opinions.isel(vertex_idx=r), lw=0.8)

    # Plot histogram of the final opinions  ...................................

    hlpr.select_axis(1, 0)
    hist = hlpr.ax.hist(final_opinions,
                        range=opinion_range,
                        bins=bins,
                        **hist_kwargs)
    hlpr.provide_defaults("set_limits", y=opinion_range)
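
A standalone sketch of the representative-selection step above: sort the final-opinion histogram bins by occupancy and pick one random vertex per bin, largest bins first (numpy only, synthetic data; variable names are illustrative):

import numpy as np

rng = np.random.default_rng(42)
final_opinions = rng.random(200)     # stand-in for the final opinions
bins, max_reps = 20, 5

hist, bin_edges = np.histogram(final_opinions, range=(0., 1.), bins=bins)
sort_idxs = np.argsort(hist)[::-1]               # largest bins first
bins_to_rep = bin_edges[:-1][sort_idxs][hist[sort_idxs] > 0]

reps = []
rand_vertex_idxs = rng.permutation(final_opinions.size)
for b in bins_to_rep[:max_reps]:
    for v in rand_vertex_idxs:
        if b < final_opinions[v] <= b + 1. / bins:
            reps.append(v)
            break
print(reps)
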
Example #12
def histogram(*,
              data: dict,
              hlpr: PlotHelper,
              x: str,
              hue: str = None,
              frames: str = None,
              coarsen_by: int = None,
              align: str = 'edge',
              bin_widths: Union[str, Sequence[float]] = None,
              suptitle_kwargs: dict = None,
              show_histogram_info: bool = True,
              **bar_kwargs):
    """Shows a distribution as a stacked bar plot, allowing animation.

    Expects as DAG result ``counts`` an xr.DataArray of one, two, or three
    dimensions. Depending on the ``hue`` and ``frames`` arguments, this will
    be represented as a stacked barplot and as an animation, respectively.

    Args:
        data (dict): The DAG results
        hlpr (PlotHelper): The PlotHelper
        x (str): The name of the dimension that represents the position of the
            histogram bins. By default, these are the bin *centers*.
        hue (str, optional): Which dimension to represent by stacking bars of
            different hue on top of each other
        frames (str, optional): Which dimension to represent by animation
        coarsen_by (int, optional): By which factor to coarsen the dimension
            specified by ``x``. Uses xr.DataArray.coarsen and pads boundary
            values.
        align (str, optional): Where to align bins. By default, uses ``edge``
            for alignment, as this is more exact for histograms.
        bin_widths (Union[str, Sequence[float]], optional): If not given, will
            use the difference between the ``x`` coordinates as bin widths,
            padding on the right side using the last value.
            If a string, assume that it is a DAG result and retrieve it from
            ``data``. Otherwise, use it directly for the ``width`` argument of
            ``plt.bar``, i.e. assume it's a scalar or a sequence of bin widths.
        suptitle_kwargs (dict, optional): Passed on to the ``set_suptitle``
            helper; the ``title`` entry may be a format string expecting the
            ``dim`` and ``value`` keys
        show_histogram_info (bool, optional): Whether to draw a box with
            information about the histogram.
        **bar_kwargs: Passed on to the ``hlpr.ax.bar`` invocation

    Returns:
        None

    Raises:
        ValueError: Bad dimensionality or missing ``bin_widths`` DAG result
    """
    def stacked_bar_plot(ax, dists: xr.DataArray, bin_widths):
        """Given a 2D xr.DataArray, plots a stacked barplot"""
        bottom = None  # to keep track of the bottom edges for stacking

        # Create the iterator
        if hue:
            hues = [c.item() for c in dists.coords[hue]]
            original_sorting = lambda c: hues.index(c[0])
            dist_iter = sorted(dists.groupby(hue), key=original_sorting)
        else:
            dist_iter = [(None, dists)]

        # Create the plots for each hue value
        for label, dist in dist_iter:
            dist = dist.squeeze(drop=True)
            ax.bar(dist.coords[x],
                   dist,
                   align=align,
                   width=bin_widths,
                   bottom=bottom,
                   label=label,
                   **bar_kwargs)
            bottom = dist.data if bottom is None else bottom + dist.data

        # Annotate it
        if not show_histogram_info:
            return
        total_sum = dists.sum().item()
        hlpr.ax.text(1,
                     1, (f"$N_{{bins}} = {dist.coords[x].size}$, "
                         fr"$\Sigma_{{{x}}} = {total_sum:.4g}$"),
                     transform=hlpr.ax.transAxes,
                     verticalalignment='bottom',
                     horizontalalignment='right',
                     fontdict=dict(fontsize="smaller"),
                     bbox=dict(facecolor="white", linewidth=.5, pad=2))

    # Retrieve the data
    dists = data['counts']

    # Check expected dimensions
    expected_ndim = 1 + bool(hue) + bool(frames)
    if dists.ndim != expected_ndim:
        raise ValueError(f"With `hue: {hue}` and `frames: {frames}`, expected "
                         f"{expected_ndim}-dimensional data, but got:\n"
                         f"{dists}")

    # Calculate bin widths
    if bin_widths is None:
        bin_widths = dists.coords[x].diff(x)
        bin_widths = bin_widths.pad({x: (0, 1)}, mode='edge')

    elif isinstance(bin_widths, str):
        log.remark("Using DAG result '%s' for bin widths ...", bin_widths)
        try:
            bin_widths = data[bin_widths]
        except KeyError as err:
            raise ValueError(f"No DAG result '{bin_widths}' available for bin "
                             "widths. Make sure `compute_only` is set such "
                             "that the result will be computed.") from err

    # Allow dynamically plotting without animation
    if not frames:
        hlpr.disable_animation()
        stacked_bar_plot(hlpr.ax, dists, bin_widths)
        return
    # else: want an animation. Everything below here is only for that case.
    hlpr.enable_animation()

    # Determine the maximum, such that the scale is always the same
    max_counts = dists.sum(hue).max() if hue else dists.max()

    # Prepare some parameters for the update routine
    suptitle_kwargs = suptitle_kwargs if suptitle_kwargs else {}
    if 'title' not in suptitle_kwargs:
        suptitle_kwargs['title'] = "{dim:} = {value:d}"

    # Define an animation update function. All frames are plotted therein.
    # There is no need to plot the first frame _outside_ the update function,
    # because it would be discarded anyway.
    def update():
        """The animation update function: a python generator"""
        log.note("Commencing histogram animation for %d time steps ...",
                 len(dists.coords[frames]))

        for t, _dists in dists.groupby(frames):
            # Plot a frame onto an empty canvas
            hlpr.ax.clear()
            stacked_bar_plot(hlpr.ax, _dists, bin_widths)

            # Set the y-limits
            hlpr.invoke_helper('set_limits', y=[0, max_counts * 1.05])

            # Apply the suptitle format string, then invoke the helper
            st_kwargs = copy.deepcopy(suptitle_kwargs)
            st_kwargs['title'] = st_kwargs['title'].format(dim='time', value=t)
            hlpr.invoke_helper('set_suptitle', **st_kwargs)

            # Done with this frame. Let the writer grab it.
            yield

    # Register the animation update with the helper
    hlpr.register_animation_update(update, invoke_helpers_before_grab=True)
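
A standalone sketch of the default bin-width computation described in the docstring above: take the difference between neighbouring ``x`` coordinates and pad on the right with the last value (the coordinate values here are made up):

import numpy as np
import xarray as xr

counts = xr.DataArray(np.array([3, 5, 2, 1]), dims=('x',),
                      coords={'x': [0.0, 0.1, 0.3, 0.6]})

bin_widths = counts.coords['x'].diff('x')                # ~[0.1, 0.2, 0.3]
bin_widths = bin_widths.pad({'x': (0, 1)}, mode='edge')  # last value repeated
print(bin_widths.values)                                 # ~[0.1, 0.2, 0.3, 0.3]
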
Example #13
def age_groups(dm: DataManager,
               *,
               uni: UniverseGroup,
               hlpr: PlotHelper,
               time_idx: int,
               to_plot: str,
               num_bins: int = 50,
               val_range: tuple = (0., 1.),
               ages: list):
    """Plots an animated stacked bar chart of the opinion distribution,
    binned by age group.

    Args:
        dm (DataManager): The data manager from which to retrieve the data
        uni (UniverseGroup): The selected universe data
        hlpr (PlotHelper): The PlotHelper
        time_idx (int): Time index of the initially drawn frame (use -1 for
            the last frame)
        to_plot (str): Name of the OpDyn dataset to plot
        num_bins (int, optional): Number of histogram bins
        val_range (tuple, optional): Value range of the histogram
        ages (list): The age-group bin edges
    """
    data = uni['data/OpDyn/nw_users/' + to_plot]
    user_ages = uni['data/OpDyn/nw_users/age_u']
    life_cycle = int(uni['cfg']['OpDyn']['life_cycle'])
    time_steps = data.coords['time'].size
    time = data.coords['time'].data
    labels = [(str(ages[i]) + "–" + str(ages[i + 1]))
              for i in range(len(ages) - 1)]
    if (ages[-1] > 120):
        labels[-1] = str(ages[-2]) + "+"

    if not time_idx:
        # Default to the first time step
        time_idx = 0

    bins = num_bins if num_bins else 50
    start, stop = val_range if val_range else (0., 1.)

    # group data by age bins
    data_by_age = np.zeros((time_steps, bins, len(ages) - 1))
    for t in range(time_steps):
        data_bins = pd.cut(np.asarray(data[t, :]),
                           bins,
                           labels=False,
                           include_lowest=True)
        age_bins = pd.cut(np.asarray(user_ages[t, :]),
                          ages,
                          labels=False,
                          include_lowest=True)
        for i in range(len(data_bins)):
            data_by_age[t, int(data_bins[i]), int(age_bins[i])] += 1

    X = [pd.DataFrame(data_by_age[i, :, :],
                      index=np.around(np.linspace(start, stop, bins), 3),
                      columns=labels)
         for i in range(time_steps)]
    log.info("Created %d dataframes for plotting", len(X))

    #colors to use
    colors = [
        'gold', 'orange', 'darkorange', 'orangered', 'firebrick', 'darkred',
        'indigo', 'navy', 'royalblue', 'cornflowerblue', 'slategray', 'peru',
        'saddlebrown', 'black'
    ]

    if time_idx == -1:
        im = X[-1].plot.bar(
            stacked=True,
            ax=hlpr.ax,
            color=colors[0:-1:np.maximum(1, int(len(colors) / len(ages)))],
            legend=False)

    else:
        im = X[time_idx].plot.bar(stacked=True,
                                  ax=hlpr.ax,
                                  color=colors,
                                  legend=False)

    def update_data(stepsize: int = 1):
        """Updates the data of the imshow objects"""
        log.info("Plotting animation with %d frames ...",
                 data.shape[0] // stepsize)

        next_frame_idx = 0

        if time_steps < stepsize:
            warnings.warn("Stepsize is greater than number of steps. "
                          "Continue by plotting fist and last frame.")
            stepsize = time_steps - 1

        for time_idx in range(time_steps):
            if time_idx < next_frame_idx and time_idx < time_steps:
                continue
            hlpr.ax.clear()
            im = X[time_idx].plot.bar(
                stacked=True,
                ax=hlpr.ax,
                legend=False,
                color=colors[0:-1:np.maximum(1, int(len(colors) / len(ages)))],
                rot=0)
            time_text = hlpr.ax.text(0.02,
                                     0.95,
                                     '',
                                     transform=hlpr.ax.transAxes)
            timestamp = int(time[time_idx] / life_cycle)
            time_text.set_text('%.0f years' % timestamp)
            hlpr.ax.legend(loc='upper right')
            hlpr.ax.set_xticks([i for i in np.linspace(0, bins - 1, 11)])
            hlpr.ax.set_xticklabels(
                [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
            hlpr.ax.set_title(hlpr.axis_cfg['set_title']['title'])
            hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
            hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])
            next_frame_idx = time_idx + stepsize
            yield

    # Register this update method with the helper, which takes care of the rest
    hlpr.register_animation_update(update_data)
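
A standalone sketch of the ``pd.cut`` binning used above: map continuous opinion values and ages to integer bin indices, then count occurrences per (opinion bin, age group) pair (synthetic data; names are illustrative):

import numpy as np
import pandas as pd

opinions = np.array([0.05, 0.42, 0.43, 0.97])
user_ages = np.array([12, 35, 37, 71])
age_edges = [10, 20, 40, 60, 80]

op_bins = pd.cut(opinions, 10, labels=False, include_lowest=True)
age_bins = pd.cut(user_ages, age_edges, labels=False, include_lowest=True)

counts = np.zeros((10, len(age_edges) - 1))   # (opinion bin, age group)
for ob, ab in zip(op_bins, age_bins):
    counts[int(ob), int(ab)] += 1
print(op_bins, age_bins)   # e.g. [0 4 4 9] and [0 1 1 3]
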
Example #14
def opinion_animation(dm: DataManager,
                      *,
                      uni: UniverseGroup,
                      hlpr: PlotHelper,
                      num_bins: int = 100,
                      time_idx: int,
                      title: str = None,
                      val_range: tuple = (0., 1.)):
    """Plots an animated histogram of the opinion distribution over time. If
    the model mode is 'conflict_undir', the opinion distribution of the
    discriminators and non-discriminators is also shown.

    Arguments:
        num_bins (int, optional): Number of bins of the histogram
        time_idx (int, optional): Only plot one single frame (e.g. the last
            frame)
        title (str, optional): Custom plot title
        val_range (tuple, optional): Value range of the histogram
    """

    #figure layout..............................................................
    #the 'conflict_undir' mode has a non-standard plot layout with two
    #additional axes for the discriminators' and non-discriminators' opinion
    #distributions
    if uni['cfg']['OpDisc']['mode'] == 'conflict_undir':
        disc_plot = True
        figure, axs = setup_figure(uni['cfg'],
                                   plot_name='opinion_anim',
                                   title=title,
                                   figsize=(10, 15),
                                   ncols=2,
                                   nrows=3,
                                   height_ratios=[1, 6, 3],
                                   width_ratios=[1, 1],
                                   gridspec=[(0, slice(0, 2)), (1, slice(0,
                                                                         2)),
                                             (2, 0), (2, 1)])
    else:
        disc_plot = False
        figure, axs = setup_figure(uni['cfg'],
                                   plot_name='opinion_anim',
                                   title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)

    #datasets...................................................................
    opinions = uni['data/OpDisc/nw/opinion']
    time = opinions['time'].data
    time_steps = time.size
    #dict containing the data to plot, as well as axis-specific info
    to_plot = {
        'all': {
            'data': opinions,
            'axs_idx': 1,
            'text': '',
            'color': 'dodgerblue'
        }
    }

    #data analysis..............................................................
    if disc_plot:
        #calculate the opinions of only the discriminators and non-discriminators
        #respectively and add to the dict
        discriminators = uni['data/OpDisc/nw/discriminators']
        p_disc = uni['cfg']['OpDisc']['discriminators']
        mask = np.empty(opinions.shape, dtype=bool)
        mask[:, :] = (discriminators == 0)
        ops_disc = np.ma.MaskedArray(opinions, mask)
        ops_nondisc = np.ma.MaskedArray(opinions, ~mask)

        to_plot['disc'] = {
            'data': ops_disc,
            'axs_idx': 2,
            'color': 'teal',
            'text': f'discriminators ($p_d$={p_disc})'
        }

        to_plot['nondisc'] = {
            'data': ops_nondisc,
            'axs_idx': 3,
            'color': 'mediumaquamarine',
            'text': f'non-discriminators ($1-p_d$={1-p_disc})'
        }

    #get histograms.............................................................
    def get_hist_data(input_data):
        counts, bin_edges = np.histogram(input_data,
                                         range=val_range,
                                         bins=num_bins)
        bin_pos = bin_edges[:-1] + (np.diff(bin_edges) / 2.)

        return counts, bin_edges, bin_pos

    bars = {}
    t = time_idx if time_idx else range(time_steps)
    #calculate histograms, set axis ranges, set axis descriptions in upper left
    #corners
    for key in to_plot.keys():
        counts, bin_edges, pos = get_hist_data(to_plot[key]['data'][t, :])
        hlpr.select_axis(0, to_plot[key]['axs_idx'])
        hlpr.ax.set_xlim(val_range)
        bars[key] = hlpr.ax.bar(pos,
                                counts,
                                width=np.diff(bin_edges),
                                color=to_plot[key]['color'])

    for key in to_plot.keys():
        hlpr.select_axis(0, to_plot[key]['axs_idx'])
        to_plot[key]['text'] = hlpr.ax.text(0.02,
                                            0.93,
                                            to_plot[key]['text'],
                                            transform=hlpr.ax.transAxes)

    #animate....................................................................
    def update_data(stepsize: int = 1):
        """Updates the data of the imshow objects"""
        if time_idx:
            log.info(
                f"Plotting distribution at time step {time[time_idx]} ...")
        else:
            log.info(
                f"Plotting animation with {opinions.shape[0] // stepsize} "
                "frames ...")
        next_frame_idx = 0
        if time_steps < stepsize:
            log.warn("Stepsize is greater than number of steps. Continue by "
                     "plotting fist and last frame.")
            stepsize = time_steps - 1
        for t in range(time_steps):
            if t < next_frame_idx:
                continue
            if time_idx:
                t = time_idx
            for key in to_plot.keys():
                hlpr.select_axis(0, to_plot[key]['axs_idx'])
                data = to_plot[key]['data'][t, :]
                if key != 'all':
                    #for disc/non-disc plots, the data is a masked array and needs
                    #to be compressed (removing None values)
                    data = data.compressed()
                counts_at_t, _, _ = get_hist_data(data)
                for idx, rect in enumerate(bars[key]):
                    rect.set_height(counts_at_t[idx])
                if key == 'all':
                    to_plot[key]['text'].set_text(f'step {time[t]}')
                    hlpr.ax.relim()
                    hlpr.ax.autoscale_view(scalex=False)
                    y_max = hlpr.ax.get_ylim()
                else:
                    #rescale ylim to same value for all plots
                    hlpr.ax.set_ylim(y_max)
            if time_idx:
                yield
                break
            next_frame_idx = t + stepsize
            yield

    hlpr.register_animation_update(update_data)
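
A standalone sketch of the masked-array split used above for the 'conflict_undir' mode: hide one subpopulation with a boolean mask and recover the remaining values with ``.compressed()`` (synthetic data):

import numpy as np

opinions = np.array([0.1, 0.5, 0.9, 0.3])
is_discriminator = np.array([True, False, True, False])

ops_disc = np.ma.MaskedArray(opinions, mask=~is_discriminator)
ops_nondisc = np.ma.MaskedArray(opinions, mask=is_discriminator)

print(ops_disc.compressed())     # [0.1 0.9]
print(ops_nondisc.compressed())  # [0.5 0.3]
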
Example #15
def group_avg(dm: DataManager,
              *,
              uni: UniverseGroup,
              hlpr: PlotHelper,
              age_groups: list = [10, 20, 40, 60, 80],
              num_bins: int = 100,
              title: str = None,
              val_range: tuple = (0, 1)):
    """This function plots the average opinion of each group over time.

    Arguments:
       age_groups (list): the age binning to be plotted for the 'ageing' model
       num_bins (int, optional): binning size for the histogram
       title (str, optional): custom title for the plot
       val_range (tuple, optional): binning range for the histogram

    Raises:
        TypeError: if the 'age_groups' list does not contain at least two
            entries
    """
    if len(age_groups) < 2:
        raise TypeError("'age_groups' list must contain at least 2 entries!")

    #figure setup ..............................................................
    figure, axs = setup_figure(uni['cfg'], plot_name='group_avg', title=title)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #get data ..................................................................
    ageing = (uni['cfg']['OpDisc']['mode'] == 'ageing')
    opinions = uni['data/OpDisc/nw/opinion']
    if ageing:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'], dtype=int)
    else:
        groups = np.asarray(uni['data/OpDisc/nw/group_label'][0, :], dtype=int)
    num_groups = len(
        age_groups) - 1 if ageing else uni['cfg']['OpDisc']['number_of_groups']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]
    time_steps = opinions['time'].size
    time = np.asarray(opinions['time'].data)
    hlpr.ax.set_xlim(0, 1)
    hlpr.ax.set_ylim(time[-1], time[0])

    #data analysis..............................................................
    #calculate mean opinion and std of each group
    means = np.zeros((time_steps, num_groups))
    stddevs = np.zeros_like(means)
    data_by_groups = data_by_group(opinions,
                                   groups,
                                   group_list,
                                   val_range,
                                   num_bins,
                                   ageing=ageing)
    for k in range(num_groups):
        for t in range(time_steps):
            if len(data_by_groups[k][t]) == 0:
                #empty slices may occur if certain age groups are not present
                #for a period of time
                continue
            means[t, k] = np.mean(data_by_groups[k][t])
            stddevs[t, k] = np.std(data_by_groups[k][t])

    #plotting...................................................................
    #get pretty labels
    if ageing:
        labels = [
            f"Ages {group_list[_]}-{group_list[_+1]}"
            for _ in range(num_groups)
        ]
        max_age = np.amax(groups)
        if (age_groups[-1] >= max_age):
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {_+1}" for _ in range(num_groups)]

    #plot mean opinion with std as errorbar
    for i in range(num_groups):
        hlpr.ax.errorbar(means[:, i],
                         time,
                         xerr=stddevs[:, i],
                         lw=2,
                         alpha=0.8,
                         elinewidth=25. / time_steps,
                         label=labels[i],
                         capsize=0,
                         capthick=1)

    hlpr.ax.set_xticks(np.linspace(0, 1, 11), minor=False)
    hlpr.ax.xaxis.grid(True, which='major', lw=0.1)

    #temporary..................................................................
    #calculate the global mean and plot its turning points
    # mean_glob = pd.Series(np.mean(opinions, axis=1)).rolling(window=20).mean()
    # hlpr.ax.plot(mean_glob, time, lw=1, color='black', label='avg', zorder=num_groups+1)
    #
    # extremes = find_extrema(mean_glob, x=time)['max']
    # if extremes['y']:
    #      hlpr.ax.scatter(x=extremes['y'], y=extremes['x'], s=10, alpha=0.8, zorder=num_groups+2)
    #
    # constants = find_const_vals(mean_glob, time_steps, time=time, averaging_window=0.4, tolerance=0.01)
    #
    # if constants['t']:
    #     hlpr.ax.scatter(x=constants['x'], y=constants['t'], s=10, color='red', alpha=0.8, zorder=num_groups+2)
    hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                   loc='lower right',
                   ncol=num_groups + 1,
                   fontsize='xx-small')
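
A standalone sketch of the per-group statistics computed above, for the simple case of static group labels (the project helper ``data_by_group`` additionally handles the 'ageing' mode, in which labels change over time):

import numpy as np

rng = np.random.default_rng(0)
opinions = rng.random((5, 8))              # (time, vertex)
groups = np.array([0, 0, 1, 1, 1, 2, 2, 2])

num_groups = groups.max() + 1
means = np.zeros((opinions.shape[0], num_groups))
stddevs = np.zeros_like(means)
for k in range(num_groups):
    members = opinions[:, groups == k]     # all vertices of group k
    means[:, k] = members.mean(axis=1)
    stddevs[:, k] = members.std(axis=1)
print(means.shape)                         # (5, 3)
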
Example #16
File: sweep1d.py Project: ThGaskin/OpDisc
def sweep1d(dm: DataManager,
            *,
            hlpr: PlotHelper,
            mv_data,
            age_groups: list = [10, 20, 40, 60, 80],
            dim: str = None,
            plot_by_groups: bool = True,
            plot_kwargs: dict = {},
            to_plot: str):
    """Plots statistical measures of the final distribution over a sweep
    parameter.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimension `dim` in which the sweep was performed. For a single
          sweep dimension, the sweep parameter is automatically deduced
        - use the `select/subspace` key to set values for all other parameters

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the PlotHelper instance
        mv_data (xr.Dataset): the extracted multidimensional dataset
        age_groups (list): The age intervals to be plotted in the final
            distribution plot for the 'ageing' mode.
        dim (str, optional): the parameter dimension of the diagram. If none is
            provided, an attempt will be made to automatically deduce the sweep
            parameter
        plot_by_groups (bool, optional): whether to plot the measure for each
            group separately (default: True)
        plot_kwargs (dict): kwargs passed to the errorbar plot function
        to_plot (str): the data to be plotted. Can be:
            - absolute_area: the area (unsigned) of the mean minus 0.5 of the
              entire population. is always positive.
            - area: the area (signed) of the mean minus 0.5 of the entire
              population. can be positive or negative.
            - area_comp: calculates both the absolute and the signed area and
              writes them to a csv file
            - area_diff: the difference between the absolute and the signed
              area
            - means: the mean of each group at the final time step, with an
              error
            - stddevs: the stddev of each group at the final time step, with an
              error

    Raises:
        ValueError: if an unknown 'to_plot' argument is passed
        ValueError: if the dimension cannot be deduced
        ValueError: if the dimension is not available
    """

    if to_plot not in [
            'absolute_area', 'area', 'area_comp', 'area_diff', 'means',
            'stddevs'
    ]:
        raise ValueError(f"Unknown statistical variable {to_plot}!")

    if dim is None:
        dim = deduce_sweep_dimension(mv_data)
    else:
        if dim not in mv_data.dims:
            raise ValueError(
                f"Dimension '{dim}' not available in multiverse data."
                f" Available: {mv_data.coords}")

    #get datasets and cfg ......................................................
    keys, cfg = get_keys_cfg(mv_data,
                             dm['multiverse'].pspace.default,
                             keys_to_ignore=[dim, 'time'])
    mode = cfg['OpDisc']['mode']
    ageing = (mode == 'ageing')
    num_groups = len(
        age_groups) - 1 if ageing else cfg['OpDisc']['number_of_groups']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]
    #get pretty labels
    if ageing:
        labels = [
            f"Ages {group_list[_]}-{group_list[_+1]}"
            for _ in range(num_groups)
        ]
        max_age = np.amax(mv_data[keys]['group_label'])
        if (age_groups[-1] >= max_age):
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {_+1}" for _ in group_list]

    #figure setup ..............................................................
    figure, axs = setup_figure(cfg, plot_name=to_plot, dim1=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #data analysis and plotting ................................................
    log.info("Commencing data analytics ...")

    if to_plot == 'absolute_area':
        data_to_plot, err = get_absolute_area(mv_data, keys, dim=dim)
        hlpr.ax.errorbar(mv_data.coords[dim].data,
                         data_to_plot,
                         yerr=err,
                         **plot_kwargs)

    elif to_plot == 'area':
        data_to_plot, err = get_area(mv_data, keys, dim=dim)
        hlpr.ax.errorbar(mv_data.coords[dim].data,
                         data_to_plot,
                         yerr=err,
                         **plot_kwargs)

    elif to_plot == 'area_comp':
        data_to_plot_0, err_0 = get_absolute_area(mv_data, keys, dim=dim)
        # hlpr.ax.errorbar(mv_data.coords[dim].data, data_to_plot_0, yerr=err_0, **plot_kwargs, label=r'$\vert A \vert$')
        data_to_plot_1, err_1 = get_area(mv_data, keys, dim=dim)
        # hlpr.ax.errorbar(mv_data.coords[dim].data, data_to_plot_1, yerr=err_1, **plot_kwargs, label=r'$A$')
        # hlpr.ax.legend(bbox_to_anchor=(1, 1.01), loc='lower right',
        #                ncol=2, fontsize='xx-small')
        #write data values for further evaluation....................................

        res = np.vstack((data_to_plot_0, err_0, data_to_plot_1, err_1))
        idx = ['abs_area', 'abs_area_err', 'area', 'area_err']
        df = pd.DataFrame(res, idx, mv_data.coords[dim].data)
        phom = cfg['OpDisc']['homophily_parameter']
        df.to_csv(
            hlpr.out_path.replace('area.pdf',
                                  f'area_N_{num_groups}_phom_{phom}.csv'))
        log.info("Finished writing files")

    elif to_plot == 'area_diff':
        data_to_plot_0, err_0 = get_absolute_area(mv_data, keys, dim=dim)
        data_to_plot_1, err_1 = get_area(mv_data, keys, dim=dim)
        hlpr.ax.plot(mv_data.coords[dim].data,
                     np.subtract(data_to_plot_0, data_to_plot_1),
                     **plot_kwargs)

    elif to_plot == 'means' or to_plot == 'stddevs':
        keys['time'] = -1
        data_to_plot, err = means_stddevs_by_group(mv_data,
                                                   group_list,
                                                   dim,
                                                   keys,
                                                   mode=mode,
                                                   ageing=ageing,
                                                   num_groups=num_groups,
                                                   which=to_plot,
                                                   time_step=-1)
        if plot_by_groups:
            for i in range(len(err)):
                hlpr.ax.errorbar(mv_data.coords[dim].data,
                                 data_to_plot[i],
                                 yerr=err[i],
                                 label=labels[i],
                                 **plot_kwargs)
            hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                           loc='lower right',
                           ncol=num_groups + 1,
                           fontsize='xx-small')
        else:
            hlpr.ax.errorbar(mv_data.coords[dim].data,
                             np.mean(data_to_plot, axis=0),
                             yerr=np.mean(err, axis=0),
                             **plot_kwargs)

    log.info("Data analysis complete.")

    #set axis labels etc .......................................................
    hlpr.ax.set_xlabel(convert_to_label(dim))
    hlpr.ax.set_ylabel(convert_to_label(to_plot))
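
When ``dim`` is not given, the sweep dimension is deduced by the project helper ``deduce_sweep_dimension``, whose implementation is not shown here. A plausible stand-in (an assumption, not the actual helper) would pick the single non-structural dimension with more than one coordinate value:

import xarray as xr

def deduce_sweep_dimension_sketch(mv_data: xr.Dataset,
                                  structural=('vertex', 'time', 'seed')):
    """Hypothetical stand-in for the project's `deduce_sweep_dimension`"""
    candidates = [d for d, n in mv_data.sizes.items()
                  if d not in structural and n > 1]
    if len(candidates) != 1:
        raise ValueError("Could not deduce a unique sweep dimension; "
                         f"candidates: {candidates}")
    return candidates[0]
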
Example #17
def group_avg_anim(dm: DataManager,
                   *,
                   hlpr: PlotHelper,
                   mv_data,
                   dim: str,
                   age_groups: list = [10, 20, 40, 60, 80],
                   num_bins: int = 100,
                   title: str = None,
                   val_range: tuple = (0, 1),
                   write: bool = False):
    """Plots an animation of the average opinion evolution by group as a
    function of the sweep parameter.

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the PlotHelper instance
        mv_data (xr.Dataset): the extracted multidimensional dataset
        dim (str): the parameter dimension of the diagram
        num_bins (int, optional): binning size for the histogram
        title (str, optional): custom plot title
        val_range (tuple, optional): binning range for the histogram
        write (bool, optional): if True, write the widths of the distribution
           at each time step along with the corresponding R_p factors to a csv
           file (for the purposes of my thesis only)
    Raises:
        ValueError: if the dimension is not present in the multiverse data
        ValueError: if the multiverse data contains additional sweep
            dimensions that were not fixed via 'subspace'
        ValueError: if the sweep parameter is 'seed' (to do)
    """

    if dim not in mv_data.dims:
        raise ValueError(f"Dimension '{dim}' not available in multiverse data."
                         f" Available: {mv_data.coords}")
    if len(mv_data.dims) > 3:
        for key in mv_data.dims.keys():
            if key not in ['vertex', 'time', dim] and mv_data.dims[key] > 1:
                raise ValueError(
                    f"Too many dimensions! Use 'subspace' to "
                    f"select specific values for keys other than {dim}!")

    if dim == 'seed':
        raise ValueError("'seed' sweeps currently not supported.")

    #datasets...................................................................
    #manually modify any subspace entries in the cfg
    cfg = dm['multiverse'].pspace.default
    #replace the mode if it is a subspace selection
    if 'mode' in mv_data.coords:
        cfg['OpDisc']['mode'] = str(mv_data.coords['mode'].data[0])
    mode = cfg['OpDisc']['mode']
    ageing = (mode == 'ageing')
    #get group labels
    keys = dict(zip(dict(mv_data.dims).keys(), [0] * len(mv_data.dims)))
    keys.pop('vertex')
    if ageing:
        #group labels change over time
        keys.pop('time')
        groups = np.asarray(mv_data['group_label'][keys], dtype=int)
    else:
        #group labels do not change over time
        groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        keys.pop('time')
    #replace any other subspace selection parameters
    for key in keys:
        if key != dim:
            cfg['OpDisc'][key] = mv_data.coords[key].data[0]

    num_groups = len(
        age_groups) - 1 if ageing else cfg['OpDisc']['number_of_groups']
    num_vertices = cfg['OpDisc']['nw']['num_vertices']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]
    time_steps = mv_data['time'].size
    time = mv_data['time'].data

    #figure layout .............................................................
    figure, axs = setup_figure(cfg,
                               plot_name='group_avgs_anim',
                               title=title,
                               dim=dim)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #data analysis .............................................................
    #get mean opinion and std of each group using tools.data_by_group
    means = np.zeros((len(mv_data.coords[dim]), time_steps, num_groups))
    stddevs = np.zeros_like(means)
    for param in range(len(mv_data.coords[dim])):
        keys[dim] = param
        data = np.asarray(mv_data[keys]['opinion'])
        data_by_groups = data_by_group(data,
                                       groups,
                                       group_list,
                                       val_range,
                                       num_bins,
                                       ageing=ageing)
        for k in range(num_groups):
            for t in range(time_steps):
                means[param, t, k] = np.mean(data_by_groups[k][t])
                stddevs[param, t, k] = np.std(data_by_groups[k][t])
    log.info("Finished data analysis.")

    #plotting...................................................................
    #get pretty labels
    if ageing:
        labels = [
            f"Ages {group_list[_]}-{group_list[_+1]}"
            for _ in range(num_groups)
        ]
        max_age = np.amax(groups)
        if (age_groups[-1] >= max_age):
            labels[-1] = f"Ages {group_list[-2]}+"
    else:
        labels = [f"Group {_+1}" for _ in range(num_groups)]

    #calculate R_p factor (for p_hom sweeps)
    P, Q = R_p_factors(num_vertices, num_groups)
    R_p_fs = R_p(mv_data.coords[dim], num_groups, P, Q, mode)

    #animate
    def update_data(stepsize: int = 1):
        log.info(
            f"Plotting animation with {len(mv_data.coords[dim])} frames ...")
        for param in range(len(mv_data.coords[dim])):
            hlpr.ax.clear()
            hlpr.ax.set_xlim(0, 1)
            hlpr.ax.set_ylim(time[-1], 0)
            hlpr.ax.set_xlabel(hlpr.axis_cfg['set_labels']['x'])
            hlpr.ax.set_ylabel(hlpr.axis_cfg['set_labels']['y'])
            if dim == 'homophily_parameter':
                sw_text = (
                    f"$R_p=${R_p_fs[param]:.3f} ({convert_to_label(dim)} = {mv_data[dim][param].data})"
                )
            else:
                sw_text = f"{convert_to_label(dim)}={mv_data[dim][param].data}"
            sweep_text = hlpr.ax.text(0,
                                      1.02,
                                      sw_text,
                                      fontsize='x-small',
                                      transform=hlpr.ax.transAxes)
            for i in range(num_groups):
                hlpr.ax.errorbar(means[param, :, i],
                                 time,
                                 xerr=stddevs[param, :, i],
                                 lw=2,
                                 alpha=1,
                                 elinewidth=0.1,
                                 label=labels[i],
                                 capsize=1,
                                 capthick=0.3)
            hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                           loc='lower right',
                           ncol=num_groups,
                           fontsize='xx-small')
            yield

    hlpr.register_animation_update(update_data)

    #write data values for further evaluation....................................
    #This is for the purpose of my thesis only and will be removed upon
    #completion.
    if write and dim == 'homophily_parameter':
        widths = np.zeros((time_steps, len(mv_data.coords[dim])))
        w_0 = np.min(means[:, -1, -1] - means[:, -1, 0])
        w_max = np.max(means[-1, :, -1] - means[-1, :, 0])
        for param in range(len(mv_data.coords[dim])):
            widths[:, param] = (means[param, :, -1] - means[param, :, 0] -
                                w_0) / (w_max - w_0)
        df = pd.DataFrame(widths, time, R_p_fs)
        df.to_csv(
            hlpr.out_path.replace('group_avgs_anim.mp4', f'widths_{mode}.csv'))
        log.info("Finished writing files")
Example #18
File: sweep2d.py Project: ThGaskin/OpDisc
def sweep2d(dm: DataManager,
            *,
            hlpr: PlotHelper,
            mv_data,
            age_groups: list = [10, 20, 40, 60, 80],
            x: str,
            y: str,
            plot_kwargs: dict = {},
            stacked: bool = False,
            to_plot: str):
    """For multiverse runs, this produces a two dimensional plot showing
    specified values.

    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - choose the dimensions `x` and `y` in which the sweep was performed
        - use the `select/subspace` key to set values for all other parameters

    Arguments:
        dm (DataManager): the data manager from which to retrieve the data
        hlpr (PlotHelper): the PlotHelper instance
        mv_data (xr.Dataset): the extracted multidimensional dataset
        age_groups (list): The age intervals to be plotted in the final_ax
            distribution plot for the 'ageing' mode.
        x (str): the first parameter dimension of the diagram.
        y (str): the second parameter dimension of the diagram.
        plot_kwargs (dict, optional): kwargs passed to the scatter plot function
        stacked (bool): whether to plot a 2d heatmap or a stacked line plot
        to_plot (str): the data to be plotted. Can be:
            - extreme_means_diff: the difference between the means of the outer
              groups at the final time step. not compatible with a seed sweep.
            - avg_of_means_diff_to_05: the average of the absolute distance of
              each group's mean to 0.5 at the final time step
            - avg_of_stddevs: the average of the standard deviations of each
              group at the final time step
            - absolute_area: the area (unsigned) of the mean minus 0.5 of the
              entire population. is always positive.
            - area: the area (signed) of the mean minus 0.5 of the entire
              population. can be positive or negative.
            - area_diff: the difference between the absolute and the signed
              area under the means curve.

    Raises:
        ValueError: if a sweep over 'seed' is performed and to_plot is
            'extreme_means_diff'.
    """
    if ((to_plot == 'extreme_means_diff') and ('seed' in mv_data.coords)
            and (len(mv_data.coords['seed'].data) > 1)):
        raise ValueError(
            "Plotting does not support 'seed' at this time. Select"
            " a single value using the 'subspace' key")

    #get datasets and cfg ......................................................
    keys, cfg = get_keys_cfg(mv_data,
                             dm['multiverse'].pspace.default,
                             keys_to_ignore=[x, y])
    mode = cfg['OpDisc']['mode']
    ageing = (mode == 'ageing')
    num_groups = len(
        age_groups) - 1 if ageing else cfg['OpDisc']['number_of_groups']
    group_list = age_groups if ageing else [_ for _ in range(num_groups)]

    requires_group_label = [
        'extreme_means_diff', 'avg_of_means_diff_to_05', 'avg_of_stddevs'
    ]
    #get group labels
    if ageing:
        #group labels change over time
        keys.update({x: 0, y: 0})
        if to_plot in requires_group_label:
            groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        for ele in [x, y]:
            keys.pop(ele)
    else:
        #group labels do not change over time
        keys.update({'time': 0, x: 0, y: 0})
        if to_plot in requires_group_label:
            groups = np.asarray(mv_data['group_label'][keys], dtype=int)
        for ele in [x, y, 'time']:
            keys.pop(ele)

    #figure setup ..............................................................
    figure, axs = setup_figure(cfg, plot_name=to_plot, dim1=x, dim2=y)
    hlpr.attach_figure_and_axes(fig=figure, axes=axs)
    hlpr.select_axis(0, 1)

    #data analysis .............................................................
    if to_plot == 'extreme_means_diff':
        data_to_plot = difference_of_extreme_means(mv_data,
                                                   x,
                                                   y,
                                                   groups,
                                                   group_list,
                                                   ageing=ageing,
                                                   group_1=0,
                                                   group_2=-1,
                                                   time_step=-1)

    elif to_plot == 'avg_of_means_diff_to_05':
        data_to_plot = avg_of_means_stddevs(mv_data,
                                            x,
                                            y,
                                            groups,
                                            group_list,
                                            keys,
                                            mode,
                                            num_groups,
                                            which='means',
                                            ageing=ageing,
                                            time_step=-1)

    elif to_plot == 'avg_of_stddevs':
        data_to_plot = avg_of_means_stddevs(mv_data,
                                            x,
                                            y,
                                            groups,
                                            group_list,
                                            keys,
                                            mode,
                                            num_groups,
                                            which='stddevs',
                                            ageing=ageing,
                                            time_step=-1)

    elif to_plot == 'absolute_area':
        data_to_plot = np.zeros(
            (len(mv_data.coords[y]), len(mv_data.coords[x])))
        for param1 in range(len(mv_data.coords[x])):
            keys[x] = param1
            data_to_plot[:, param1] = get_absolute_area(mv_data, keys,
                                                        dim=y)[0]

    elif to_plot == 'area':
        data_to_plot = np.zeros(
            (len(mv_data.coords[y]), len(mv_data.coords[x])))
        for param1 in range(len(mv_data.coords[x])):
            keys[x] = param1
            data_to_plot[:, param1] = get_area(mv_data, keys, dim=y)[0]

    elif to_plot == 'area_diff':
        data_to_plot = np.zeros(
            (len(mv_data.coords[y]), len(mv_data.coords[x])))
        for param1 in range(len(mv_data.coords[x])):
            keys[x] = param1
            val_1 = get_absolute_area(mv_data, keys, dim=y)[0]
            val_2 = get_area(mv_data, keys, dim=y)[0]
            data_to_plot[:, param1] = np.subtract(val_1, val_2)

    #plotting ..................................................................
    if stacked:
        for i in range(len(mv_data.coords[y])):
            hlpr.ax.plot(
                data_to_plot[i, :],
                label=f'{convert_to_label(y)}={mv_data.coords[y].data[i]}',
                **plot_kwargs)
        hlpr.ax.legend(bbox_to_anchor=(1, 1.01),
                       loc='lower right',
                       ncol=len(mv_data.coords[y]) + 1,
                       fontsize='xx-small')

    else:
        df = pd.DataFrame(data_to_plot,
                          index=mv_data.coords[y].data,
                          columns=mv_data.coords[x].data)
        im = hlpr.ax.pcolor(df, **plot_kwargs)
        hlpr.ax.set_ylabel(parameters[y], rotation=0)
        hlpr.ax.set_yticks([
            i for i in np.linspace(0.5,
                                   len(mv_data.coords[y].data) -
                                   0.5, len(mv_data.coords[y].data))
        ])
        hlpr.ax.set_yticklabels(
            [np.around(i, 3) for i in mv_data.coords[y].data])

        hlpr.ax.set_xlabel(parameters[x])
        hlpr.ax.set_xticks([
            i for i in np.linspace(0.5,
                                   len(mv_data.coords[x].data) -
                                   0.5, len(mv_data.coords[x].data))
        ])
        hlpr.ax.set_xticklabels(
            [np.around(i, 3) for i in mv_data.coords[x].data])

        divider = make_axes_locatable(hlpr.ax)
        cax = divider.append_axes("right", size="5%", pad=0.2)
        cbar = figure.colorbar(im, cax=cax)

        cbar.set_label(convert_to_label(to_plot))
Example #19
def sweep( dm: DataManager, *,
           hlpr: PlotHelper,
           mv_data,
           plot_prop: str,
           dim: str,
           dim2: str=None,
           stack: bool=False,
           no_errors: bool=False,
           data_name: str='opinion_u',
           bin_number: int=100,
           plot_kwargs: dict=None):
    """Plots a bifurcation diagram for one parameter dimension (dim)
        i.e. plots the chosen final distribution measure over the parameter,
        or - if second parameter dimension (dim2) is given - plots the
        2d parameter space as a heatmap (if not stacked).
    
    Configuration:
        - use the `select/field` key to associate one or multiple datasets
        - change `data_name` if needed
        - choose the dimension `dim` (and `dim2`) in which the sweep was
          performed.
    
    Arguments:
        dm (DataManager): The data manager from which to retrieve the data
        hlpr (PlotHelper): The PlotHelper instance
        mv_data (xr.Dataset): The extracted multidimensional dataset
        plot_prop (str): The quantity that is extracted from the
            data. Available are: ['number_of_peaks', 'localization',
            'max_distance', 'polarization']
        dim (str): The parameter dimension of the diagram
        dim2 (str, optional): The second parameter dimension of the diagram
        stack (bool, optional): Whether the plots for dim2 are stacked
            or extend to heatmap
        data_name (str, optional): The name of the dataset to be evaluated
            (default: 'opinion_u')
        bin_number (int, optional): Number of bins for the discretization of
            the final distribution (default: 100)
        plot_kwargs (dict, optional): passed to the plot function
    
    Raises:
        TypeError: for a parameter dimension higher than 5
            (and higher than 4 if not swept over seed)
        ValueError: If 'data_name' data does not exist
    """

    # Drop coordinates that only have a single value
    # (i.e. when certain value selected in fields/select/..)
    # for coord in mv_data.coords:
    #     if len(mv_data[coord]) == 1:
    #         mv_data = mv_data[{coord: 0}]
    #         mv_data = mv_data.drop(coord)

    if dim not in mv_data.dims:
        raise ValueError("Dimension `dim` not available in multiverse data."
                         " Was: {} with value: '{}'."
                         " Available: {}"
                         "".format(type(dim), 
                                   dim,
                                   mv_data.coords))
    if len(mv_data.coords) > 5:
        raise TypeError("mv_data has more than two extra parameter dimensions."
                        " Are: {}. Chosen dim: {}. (Max: ['vertex', 'time', "
                        "'seed'] + 2)".format(mv_data.coords, dim))

    if (len(mv_data.coords) > 4) and ('seed' not in mv_data.coords):
        raise TypeError("mv_data has more than two extra parameter dimensions."
                        " Are: {}. Chosen dim: {}. (Max: ['vertex', 'time', "
                        "'seed'] + 2)".format(mv_data.coords, dim))

    plot_kwargs = (plot_kwargs if plot_kwargs else {})

    # Default plot configurations
    plot_kwargs_default_1d = {'fmt': 'o', 'ls': '-', 'lw': .4, 'mec': None,
                              'capsize': 1, 'mew': .6, 'ms': 2}
    
    if no_errors:
        plot_kwargs_default_1d = {'marker': 'o', 'ms': 2, 'lw': 0.4}

    plot_kwargs_default_2d = {'origin': 'lower', 'cmap': 'Spectral_r'}

    # analysis and plot functions ..............................................

    def plot_data_1d(param_plot, data_plot, std, plot_kwargs):
        if no_errors:
            hlpr.ax.plot(param_plot, data_plot, **plot_kwargs)
        else:
            hlpr.ax.errorbar(param_plot, data_plot, yerr=std, **plot_kwargs)

    def plot_data_2d(data_plot, param1, param2, plot_kwargs):

        heatmap = hlpr.ax.imshow(data_plot, **plot_kwargs)

        # Show at most ~10 tick labels per axis
        step1 = int(np.ceil(len(param1) / 10))
        step2 = int(np.ceil(len(param2) / 10))

        hlpr.ax.set_xticks(np.arange(len(param1))[::step1])
        hlpr.ax.set_yticks(np.arange(len(param2))[::step2])
        hlpr.ax.set_xticklabels(["{:.2f}".format(p) for p in param1[::step1]])
        hlpr.ax.set_yticklabels(["{:.2f}".format(p) for p in param2[::step2]])

        plt.colorbar(heatmap)

    def get_number_of_peaks(raw_data):
        final_state = raw_data.isel(time=-1)
        final_state = final_state[~np.isnan(final_state)]

        # binning of the final opinion distribution with binsize=1/bin_number
        hist_data,bin_edges = np.histogram(final_state, range=(0.,1.),
                                            bins=bin_number)

        # Count histogram peaks; the prominence and distance thresholds
        # filter out small fluctuations in the binned distribution
        peak_number = len(find_peaks(hist_data, prominence=15, distance=5)[0])

        return peak_number

    def get_localization(raw_data):
        final_state = raw_data.isel(time=-1)
        final_state = final_state[~np.isnan(final_state)]

        hist, bins = np.histogram(final_state, range=(0.,1.), bins=bin_number,
                                  density=True)
        # Convert the density to the probability mass per bin
        hist *= 1/bin_number

        # Inverse-participation-ratio-like localization measure:
        # L = (sum_i p_i^4) / (sum_i p_i^2)^2
        l = 0
        norm = 0
        for i in range(len(hist)):
            norm += hist[i]**4
            l += hist[i]**2

        l = norm/l**2

        return l

    def get_max_distance(raw_data):
        final_state = raw_data.isel(time=-1)
        final_state = final_state[~np.isnan(final_state)]

        # Spread of the final opinion distribution
        return float(final_state.max() - final_state.min())

    def get_polarization(raw_data):
        final_state = raw_data.isel(time=-1)
        final_state = final_state[~np.isnan(final_state)]

        # Polarization: sum of squared pairwise opinion differences (O(N^2))
        p = 0
        for i in range(len(final_state)):
            for j in range(len(final_state)):
                p += (final_state[i] - final_state[j])**2

        return p

    def get_final_variance(raw_data):
        final_state = raw_data.isel(time=-1)
        final_state = final_state[~np.isnan(final_state)]

        var = np.var(final_state.values)

        return var

    def get_convergence_time(raw_data):
        # Heuristic convergence time: the first time at which every pair of
        # neighboring (sorted) opinions is either closer than min_tol (within
        # one cluster) or farther apart than max_tol (separate clusters)
        max_tol = 0.05
        min_tol = 0.005
        ct = 0.
        for t in raw_data.coords['time']:
            nxt = False
            data = raw_data.sel(time=t)
            sorted_data = np.sort(data)
            for diff in np.diff(sorted_data):
                if diff > min_tol and diff < max_tol:
                    nxt = True
                    break
            if not nxt:
                ct = t
                break

        return ct

    def get_property(data, plot_prop: str=None):

        if plot_prop == 'number_of_peaks':
            return get_number_of_peaks(data)
        elif plot_prop == 'localization':
            return get_localization(data)
        elif plot_prop == 'max_distance':
            return get_max_distance(data)
        elif plot_prop == 'polarization':    
            return get_polarization(data)
        elif plot_prop == 'final_variance':    
            return get_final_variance(data)
        elif plot_prop == 'convergence_time':    
            return get_convergence_time(data)
        else:
            raise ValueError("'plot_prop' invalid! Was: {}".format(plot_prop))

    # data handling and plot setup .............................................
    legend = False
    heatmap = False

    if data_name not in mv_data.data_vars:
        raise ValueError("'{}' not available in multiverse data."
                         " Available data variables: {}"
                         "".format(data_name, list(mv_data.data_vars)))

    # this is the dataset containing the chosen data to plot
    # for all parameter combinations
    dataset = mv_data[data_name]

    # number of different parameter values i.e. number of points in the graph
    num_param = len(dataset[dim])

    # initialize arrays containing the data to plot:
    data_plot = np.zeros(num_param)
    param_plot = np.zeros(num_param)
    std = np.zeros_like(data_plot)

    # Get additional information for plotting
    leg_title = dim2
    if leg_title == "num_vertices":
        leg_title = "$N$"
    elif leg_title == "weighting":
        leg_title = r"$\kappa$"
    elif leg_title == "rewiring":
        leg_title = "$r$"
    elif leg_title == "p_rewire":
        leg_title = r"$p_{rewire}$"

    cmap_kwargs = plot_kwargs.pop("cmap_kwargs", None)
    if cmap_kwargs:
        cmin = cmap_kwargs.get("min", 0.)
        cmax = cmap_kwargs.get("max", 1.)
        cmap = cmap_kwargs.get("cmap")
        cmap = cm.get_cmap(cmap)

    markers = plot_kwargs.pop("markers", None)

    # If only one parameter sweep (dim) is done, the calculated quantity
    # is plotted against the parameter value.
    if (len(mv_data.coords) == 3):

        plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
        param_index = 0
        for data in dataset:

            data_plot[param_index] = get_property(data, plot_prop)
            param_plot[param_index] = data[dim]
            param_index += 1

        if markers:
            plot_kwargs['marker'] = markers[0]

        plot_data_1d(param_plot, data_plot, std, plot_kwargs)

    # if two sweeps are done, check if the seed is sweeped
    elif (len(mv_data.coords) == 4):

        # average over the seed in this case
        if 'seed' in mv_data.coords:

            plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)

            for i in range(len(dataset[dim])):

                num_seeds = len(dataset['seed'])
                arr = np.zeros(num_seeds)

                for j in range(num_seeds):

                    data = dataset[{dim: i, 'seed': j}]
                    arr[j] = get_property(data, plot_prop)

                data_plot[i] = np.mean(arr)
                param_plot[i] = dataset[dim][i]
                std[i] = np.std(arr)

            if markers:
                plot_kwargs['marker'] = markers[0]

            plot_data_1d(param_plot, data_plot, std, plot_kwargs)
        
        # If 'stack', plot data of both dimensions against dim values (1d),
        # otherwise map data on 2d sweep parameter grid (color-coded).
        elif stack:

            legend = True
            plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
            param_plot = dataset[dim]

            for i in range(len(dataset[dim2])):

                for j in range(num_param):

                    data = dataset[{dim2: i, dim: j}]
                    data_plot[j] = get_property(data, plot_prop)

                recursive_update(plot_kwargs, {'label': "{}"
                                    "".format(dataset[dim2][i].data)})
                if cmap_kwargs:
                    c = i * (cmax - cmin) / (len(dataset[dim2])-1.) + cmin
                    recursive_update(plot_kwargs, {'color': cmap(c)})

                if markers:
                    plot_kwargs['marker'] = markers[i%len(markers)]

                plot_data_1d(param_plot, data_plot, std, plot_kwargs)

        else:

            plot_kwargs = recursive_update(plot_kwargs_default_2d, plot_kwargs)
            heatmap = True
            num_param2 = len(dataset[dim2])
            data_plot = np.zeros((num_param2, num_param))
            param1 = np.zeros(num_param)
            param2 = np.zeros(num_param2)

            for i in range(num_param):
                param1[i] = dataset[dim][i]

                for j in range(num_param2):
                    param2[j] = dataset[dim2][j]
                    data = dataset[{dim: i, dim2: j}]
                    data_plot[j,i] = get_property(data, plot_prop)
            
            plot_data_2d(data_plot, param1, param2, plot_kwargs)

    elif (len(mv_data.coords) == 5):

        num_seeds = len(dataset['seed'])

        if stack:

            legend = True
            plot_kwargs = recursive_update(plot_kwargs_default_1d, plot_kwargs)
            param_plot = dataset[dim]
            param2 = dataset[dim2]
            
            for i in range(len(param2)):

                for j in range(num_param):

                    arr = np.zeros(num_seeds)

                    for k in range(num_seeds):
                        data = dataset[{dim2: i, dim: j, 'seed': k}]
                        arr[k] = get_property(data, plot_prop)

                    data_plot[j] = np.mean(arr)
                    std[j] = np.std(arr)

                recursive_update(plot_kwargs, {'label': "{}"
                                    "".format(dataset[dim2][i].data)})
                if cmap_kwargs:
                    c = i * (cmax - cmin) / (len(dataset[dim2])-1.) + cmin
                    recursive_update(plot_kwargs, {'color': cmap(c)})

                if markers:
                    plot_kwargs['marker'] = markers[i%len(markers)]

                plot_data_1d(param_plot, data_plot, std, plot_kwargs)

        else:

            plot_kwargs = recursive_update(plot_kwargs_default_2d, plot_kwargs)
            heatmap = True
            num_param2 = len(dataset[dim2])
            data_plot = np.zeros((num_param2, num_param))
            param1 = np.zeros(num_param)
            param2 = np.zeros(num_param2)

            for i in range(num_param):
                param1[i] = dataset[dim][i]

                for j in range(num_param2):
                    param2[j] = dataset[dim2][j]
                    arr = np.zeros(num_seeds)

                    for k in range(num_seeds):
                        data = dataset[{dim: i, dim2: j, 'seed': k}]
                        arr[k] = get_property(data, plot_prop)

                    data_plot[j,i] = np.mean(arr)

            plot_data_2d(data_plot, param1, param2, plot_kwargs)

    # (higher-dimensional parameter spaces were already rejected above)

    # Add labels and title
    if heatmap:
        hlpr.provide_defaults('set_labels', **{'x': dim, 'y': dim2})
        hlpr.provide_defaults('set_title', **{'title': plot_prop})

    else:
        hlpr.provide_defaults('set_labels',
             **{'x': (r"$\epsilon$" if dim == "tolerance_u" else dim)})
        if plot_prop == "localization":
            hlpr.provide_defaults('set_labels', **{'y': "$L$"})
        elif plot_prop == "final_variance":
            hlpr.provide_defaults('set_labels', **{'y': r"var($\sigma$)"})
        elif plot_prop == 'number_of_peaks':
            hlpr.provide_defaults('set_labels', **{'y': "$N_{peaks}$"})
        elif plot_prop == 'max_distance':
            hlpr.provide_defaults('set_labels', **{'y': "$d_{max}$"})
        elif plot_prop == 'convergence_time':
            hlpr.provide_defaults('set_labels', **{'y': "$T_{conv}$"})
        else:
            hlpr.provide_defaults('set_labels', **{'y': plot_prop})


        if legend:
            hlpr.ax.legend(title=leg_title)

        # Set major and minor tick locators for these properties
        if plot_prop in ("max_distance", "localization"):
            hlpr.ax.get_xaxis().set_major_locator(ticker.MultipleLocator(0.1))
            hlpr.ax.get_xaxis().set_minor_locator(ticker.MultipleLocator(0.05))
            hlpr.ax.get_yaxis().set_minor_locator(ticker.MultipleLocator(0.1))
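
A quick way to sanity-check the distribution measures used above is to evaluate them with plain NumPy on a toy final-state array. The following is only an illustrative sketch (the two-cluster toy data and all variable names are made up here, not model output); it mirrors the `localization` and `max_distance` computations from the helper functions above:

import numpy as np

# Toy final opinion snapshot in [0, 1]: two clusters around 0.2 and 0.8
rng = np.random.default_rng(42)
final_state = np.concatenate([rng.normal(0.2, 0.02, 500),
                              rng.normal(0.8, 0.02, 500)]).clip(0., 1.)
bin_number = 100

# Localization: probability mass per bin p_i, then L = (sum p_i^4) / (sum p_i^2)^2
hist, _ = np.histogram(final_state, range=(0., 1.), bins=bin_number,
                       density=True)
p = hist / bin_number                 # convert density to mass per bin
localization = np.sum(p**4) / np.sum(p**2)**2

# Maximum distance: spread of the final distribution
max_distance = final_state.max() - final_state.min()

print(localization, max_distance)

Here L equals 1 when all mass sits in a single bin and roughly 1/k when it is spread evenly over k bins, which is the behavior the bifurcation plot tracks along the swept parameter.
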
Example #20
def opinion_animated(dm: DataManager,
                     *,
                     uni: UniverseGroup,
                     hlpr: PlotHelper,
                     num_bins: int = 100,
                     time_idx: int = None,
                     **plot_kwargs):
    """Plots an animated histogram of the opinion distribution over time.

    Arguments:
        num_bins(int): Binning of the histogram
        time_idx (int, optional): Only plot one single frame (eg. last frame)
        plot_kwargs (dict, optional): Passed to plt.bar
    """

    #datasets...................................................................
    opinions = uni['data/Opinionet/nw/opinion']
    time = opinions['time'].data
    time_steps = time.size
    cfg_op_space = uni['cfg']['Opinionet']['opinion_space']
    if cfg_op_space['type'] == 'continuous':
        val_range = cfg_op_space['interval']
    elif cfg_op_space['type'] == 'discrete':
        val_range = (0, cfg_op_space['num_opinions'])
    else:
        raise ValueError("Unexpected opinion_space type: {}"
                         "".format(cfg_op_space['type']))

    #get histograms.............................................................
    def get_hist_data(input_data):
        counts, bin_edges = np.histogram(input_data,
                                         range=val_range,
                                         bins=num_bins)
        bin_pos = bin_edges[:-1] + (np.diff(bin_edges) / 2.)

        return counts, bin_edges, bin_pos

    t = time_idx if time_idx is not None else range(time_steps)

    #calculate histograms, set axis ranges
    counts, bin_edges, pos = get_hist_data(opinions[t, :])
    hlpr.ax.set_xlim(val_range)
    bars = hlpr.ax.bar(pos, counts, width=np.diff(bin_edges), **plot_kwargs)
    text = hlpr.ax.text(0.02,
                        0.93,
                        f'step {time[t]}',
                        transform=hlpr.ax.transAxes)

    #animate....................................................................
    def update_data(stepsize: int = 1):
        """Updates the bar heights for each animation frame"""
        if time_idx is not None:
            log.info(
                f"Plotting distribution at time step {time[time_idx]} ...")
        else:
            log.info(
                f"Plotting animation with {opinions.shape[0] // stepsize} "
                "frames ...")
        next_frame_idx = 0
        if time_steps < stepsize:
            log.warning("Stepsize is greater than number of steps. Continuing "
                        "by plotting only the first and last frame.")
            stepsize = time_steps - 1
        for t in range(time_steps):
            if t < next_frame_idx:
                continue
            if time_idx is not None:
                t = time_idx
            data = opinions[t, :]
            counts_at_t, _, _ = get_hist_data(data)
            for idx, rect in enumerate(bars):
                rect.set_height(counts_at_t[idx])
            # Update the step label and rescale the y-axis once per frame
            text.set_text(f'step {time[t]}')
            hlpr.ax.relim()
            hlpr.ax.autoscale_view(scalex=False)
            if time_idx is not None:
                yield
                break
            next_frame_idx = t + stepsize
            yield

    hlpr.register_animation_update(update_data)
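
The same bar-updating pattern can be sketched without the utopya animation machinery, using matplotlib's FuncAnimation directly. This is only a rough, self-contained illustration with made-up toy data; it is not how utopya registers animations (that happens via hlpr.register_animation_update, as above):

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Toy opinion trajectories, shape (time_steps, num_agents), contracting to 0.5
rng = np.random.default_rng(0)
spread = np.linspace(1., 0.1, 50)[:, None]
opinions = 0.5 + (rng.random((50, 200)) - 0.5) * spread
num_bins, val_range = 100, (0., 1.)

fig, ax = plt.subplots()
counts, bin_edges = np.histogram(opinions[0], range=val_range, bins=num_bins)
pos = bin_edges[:-1] + np.diff(bin_edges) / 2.
bars = ax.bar(pos, counts, width=np.diff(bin_edges))
ax.set_xlim(val_range)

def update(t):
    # Recompute the histogram for frame t and update the bar heights
    counts_t, _ = np.histogram(opinions[t], range=val_range, bins=num_bins)
    for rect, c in zip(bars, counts_t):
        rect.set_height(c)
    ax.relim()
    ax.autoscale_view(scalex=False)

# Keep a reference to the animation so it is not garbage-collected
anim = FuncAnimation(fig, update, frames=opinions.shape[0], interval=100)
plt.show()
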
Example #21
File: graph.py Project: utopia-foss/utopia
def graph_animation_update(*,
                           hlpr: PlotHelper,
                           graphs: xr.DataArray = None,
                           graph_group=None,
                           graph_creation: dict = None,
                           register_property_maps: dict = None,
                           clear_existing_property_maps: bool = True,
                           positions: dict = None,
                           animation_kwargs: dict = None,
                           suptitle_kwargs: dict = None,
                           **drawing_kwargs):
    """Graph animation frame generator. Yields whenever the plot helper may
    grab the current frame.

    If ``graphs`` is given, the networkx graphs in the array are used to create
    the frames.

    Otherwise, a graph group is used. The frames are defined via the selectors
    in ``animation_kwargs``: from all provided coordinates the cartesian
    product is taken, and each of those points defines one graph and thus one
    frame (a small standalone sketch of this flattening follows this example).
    Any selection kwargs in ``graph_creation`` are silently ignored.

    Args:
        hlpr (PlotHelper): The plot helper
        graphs (xr.DataArray, optional): Networkx graphs to draw. The array
            will be flattened beforehand.
        graph_group (GraphGroup, optional): Required if ``graphs`` is None.
            The GraphGroup from which to generate the animation frames as
            specified via sel and isel in ``animation_kwargs``.
        graph_creation (dict, optional): Graph creation configuration. Passed
            to :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        register_property_maps (dict, optional): Passed to
            :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        clear_existing_property_maps (bool, optional): Passed to
            :py:meth:`~utopya.plot_funcs._graph.GraphPlot.create_graph_from_group`
            if ``graph_group`` is given.
        positions (dict, optional): The node position configuration.
            If ``update_positions`` is True the positions are reconfigured for
            each frame.
        animation_kwargs (dict, optional): Animation configuration. The
            following arguments are allowed:

            times (dict, optional):
                *Deprecated*: Equivalent to a ``sel.time`` entry.
            sel (dict, optional):
                Select by value. Coordinate values (or ``from_property`` entry)
                keyed by dimension name.
            isel (dict, optional):
                Select by index. Coordinate indices keyed by dimension. May be
                given together with ``sel`` if no key appears in both.
            update_positions (bool, optional):
                Whether to reconfigure the node positions for each frame
                (default=False).
            update_colormapping (bool, optional):
                Whether to reconfigure the nodes' and edges'
                :py:class:`~utopya.plot_funcs._mpl_helpers.ColorManager` for
                each frame (default=False). If False, the colormapping (and the
                colorbar) is configured with the first frame and then fixed.
            skip_empty_frames (bool, optional):
                Whether to skip the frames where the selected graph is missing
                or of a type different than ``nx.Graph`` (default=False).
                If False, such frames are empty.

        suptitle_kwargs (dict, optional): Passed on to the PlotHelper's
            ``set_suptitle`` helper function. Only used in animation mode.
            The ``title`` can be a format string containing a placeholder with
            the dimension name as key for each dimension along which selection
            is done. The format string is updated for each frame of the
            animation. The default is ``<dim-name> = {<dim-name>}`` for each
            dimension.
        **drawing_kwargs: Passed to :py:class:`~utopya.plot_funcs._graph.GraphPlot`
    """
    suptitle_kwargs = (copy.deepcopy(suptitle_kwargs)
                       if suptitle_kwargs is not None else {})
    animation_kwargs = (copy.deepcopy(animation_kwargs)
                        if animation_kwargs is not None else {})
    update_positions = animation_kwargs.pop("update_positions", False)
    update_colormapping = animation_kwargs.pop("update_colormapping", False)
    skip_empty_frames = animation_kwargs.pop("skip_empty_frames", False)
    sel = animation_kwargs.pop("sel", None)
    isel = animation_kwargs.pop("isel", None)

    if graphs is None:
        graphs = graph_array_from_group(
            graph_group=graph_group,
            graph_creation=graph_creation,
            register_property_maps=register_property_maps,
            clear_existing_property_maps=clear_existing_property_maps,
            sel=sel,
            isel=isel,
            **animation_kwargs,
        )

    else:
        # Apply selectors to the graphs DataArray
        if sel is not None:
            graphs = graphs.sel(**sel)
        if isel is not None:
            graphs = graphs.isel(**isel)

    # Prepare the suptitle format string once for all frames
    if "title" not in suptitle_kwargs:
        suptitle_kwargs["title"] = "; ".join(
            [f"{dim}" + " = {" + dim + "}" for dim in graphs.coords])

    def get_graph_and_coords(graphs: xr.DataArray):
        """Generator that yields the (graph, coordinates) pairs"""
        dims = graphs.dims
        if dims:
            # Stack all dimensions of the xr.DataArray
            graphs = graphs.stack(_flat=dims)

            # For each entry, yield the graph and the coords
            for el in graphs:
                g = el.item()
                coords = {
                    d: c
                    for d, c in zip(dims, el.coords["_flat"].item())
                }
                # Also get the scalar coordinates which were not stacked
                coords.update({
                    d: c.item()
                    for d, c in el.coords.items() if d != "_flat"
                })
                yield g, coords

        else:
            # zero-dimensional: extract single graph item
            g = graphs.item()
            coords = {d: c.item() for d, c in graphs.coords.items()}
            yield g, coords

    # Don't show the axis. This is also done in GraphPlot.draw, but it needs
    # to be done here as well in case the first frame is empty.
    hlpr.ax.axis("off")

    # Indicator for things to be done only once after the first GraphPlot.draw
    # call.
    _first_graph = True

    # Always configure the colormanagers on the first GraphPlot.draw call.
    # Overwritten by the update_colormapping kwarg afterwards.
    _update_colormapping = True

    # Loop over all remaining graphs
    for g, coords in get_graph_and_coords(graphs):

        _missing_val = not isinstance(g, nx.Graph)

        # On missing entries, skip the frame if skip_empty_frames=True.
        # If skip_empty_frames=False, the suptitle helper is applied but no
        # graph is drawn.
        if _missing_val and skip_empty_frames:
            continue

        if not _missing_val:
            # Use a new GraphPlot for the next frame such that the `select`
            # kwargs are re-evaluated.
            gp = GraphPlot(
                g=g,
                fig=hlpr.fig,
                ax=hlpr.ax,
                positions=positions,
                **drawing_kwargs,
            )
            gp.draw(
                suppress_cbar=not _update_colormapping,
                update_colormapping=_update_colormapping,
            )

        # Update the suptitle format string and invoke the helper
        st_kwargs = copy.deepcopy(suptitle_kwargs)
        st_kwargs["title"] = st_kwargs["title"].format(**coords)
        hlpr.invoke_helper("set_suptitle", **st_kwargs)

        if _first_graph:
            # Fix the y-position of the suptitle after the first invocation
            # since repetitive invocations of the set-suptitle helper would
            # re-adjust the subplots size each time. Use the matplotlib default
            # if not given.
            if "y" not in suptitle_kwargs:
                suptitle_kwargs["y"] = 0.98

            _update_colormapping = update_colormapping

        # Let the writer grab the current frame
        yield

        if not _missing_val:
            # Clean up the frame
            gp.clear_plot(keep_colorbars=not _update_colormapping)

            if _first_graph:
                # If the positions should be fixed, overwrite the positions arg
                if not update_positions:
                    positions = dict(from_dict=gp.positions)

                # Done handling the first GraphPlot.draw call
                _first_graph = False
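
The frame selection above boils down to flattening the graphs DataArray over all of its dimensions and recovering each entry's coordinates, which is what get_graph_and_coords does. A minimal standalone sketch of that flattening, with placeholder strings instead of networkx graphs (the toy dims and values are made up for illustration):

import numpy as np
import xarray as xr

# Object-dtype array standing in for per-(seed, time) graphs
objs = np.empty((2, 3), dtype=object)
for i in range(2):
    for j in range(3):
        objs[i, j] = f"graph(seed={i}, time={j * 10})"

graphs = xr.DataArray(objs, dims=("seed", "time"),
                      coords={"seed": [0, 1], "time": [0, 10, 20]})

dims = graphs.dims
stacked = graphs.stack(_flat=dims)   # cartesian product of all coordinates

for el in stacked:
    g = el.item()                    # the object stored at this point
    coords = dict(zip(dims, el.coords["_flat"].item()))
    # One frame per coordinate combination; a default suptitle would read:
    title = "; ".join(f"{d} = {v}" for d, v in coords.items())
    print(g, "->", title)

Each iteration corresponds to one animation frame, and the coords dict is what gets substituted into the suptitle format string built at the top of graph_animation_update.
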