Ejemplo n.º 1
0
def plot_state(times,
               values,
               time_unit=ms,
               var_unit=None,
               var_name=None,
               axes=None,
               **kwds):
    '''
    Plot a recorded variable (e.g. a membrane potential) over time.

    Parameters
    ----------
    times : `~brian2.units.fundamentalunits.Quantity`
        The array of times for the data points given in ``values``.
    values : `~brian2.units.fundamentalunits.Quantity`, `~numpy.ndarray`
        The values to plot, either a 1D array with the same length as ``times``,
        or a 2D array with ``len(times)`` rows. The array is not modified.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    var_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use to plot the ``values`` (e.g. ``mV`` for a membrane
        potential). If none is given (the default), an attempt is made to
        find a good scale automatically based on the ``values``.
    var_name : str, optional
        The name of the variable that is plotted. Used for the axis label.
        If not given, the label only states the unit (if any).
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    if var_unit is None and isinstance(values, Quantity):
        var_unit = _get_best_unit(values)
    if var_unit is not None:
        # Deliberately NOT an in-place division: ``values`` belongs to the
        # caller (possibly a view into recorded monitor data) and an in-place
        # operation would overwrite it (and change its units).
        values = values / var_unit
    axes.plot(times / time_unit, values, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    if var_name is not None:
        if var_unit is not None:
            axes.set_ylabel('%s (%s)' % (var_name, var_unit))
        else:
            axes.set_ylabel('%s' % var_name)
    elif var_unit is not None:
        # No name available: only state the unit instead of printing "None"
        axes.set_ylabel('in units of %s' % var_unit)
    return axes
Ejemplo n.º 2
0
def plot_state(times, values, time_unit=ms, var_unit=None, var_name=None,
               axes=None, **kwds):
    '''
    Plot a recorded variable (e.g. a membrane potential) over time.

    Parameters
    ----------
    times : `~brian2.units.fundamentalunits.Quantity`
        The array of times for the data points given in ``values``.
    values : `~brian2.units.fundamentalunits.Quantity`, `~numpy.ndarray`
        The values to plot, either a 1D array with the same length as ``times``,
        or a 2D array with ``len(times)`` rows. The array is not modified.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    var_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use to plot the ``values`` (e.g. ``mV`` for a membrane
        potential). If none is given (the default), an attempt is made to
        find a good scale automatically based on the ``values``.
    var_name : str, optional
        The name of the variable that is plotted. Used for the axis label.
        If not given, the label only states the unit (if any).
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    if var_unit is None and isinstance(values, Quantity):
        var_unit = values._get_best_unit()
    if var_unit is not None:
        # Deliberately NOT an in-place division: ``values`` belongs to the
        # caller (possibly a view into recorded monitor data) and an in-place
        # operation would overwrite it (and change its units).
        values = values / var_unit
    axes.plot(times / time_unit, values, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    if var_name is not None:
        if var_unit is not None:
            axes.set_ylabel('%s (%s)' % (var_name, var_unit))
        else:
            axes.set_ylabel('%s' % var_name)
    elif var_unit is not None:
        # No name available: only state the unit instead of printing "None"
        axes.set_ylabel('in units of %s' % var_unit)
    return axes
Ejemplo n.º 3
0
def plot_raster(spike_indices, spike_times, time_unit=ms,
                axes=None, **kwds):
    '''
    Plot a "raster plot", a plot of neuron indices over spike times. The default
    marker used for plotting is ``'.'``, it can be overridden with the
    ``marker`` keyword argument.

    Parameters
    ----------
    spike_indices : `~numpy.ndarray`
        The indices of spiking neurons, corresponding to the times given in
        ``spike_times``.
    spike_times : `~brian2.units.fundamentalunits.Quantity`
        A sequence of spike times.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    # Provide the '.' marker as a keyword default rather than as a format
    # string: combining a '.' format string with a user-supplied ``marker``
    # makes matplotlib warn about a "redundantly defined" property. Without a
    # format string, matplotlib would also connect the points, so disable the
    # line unless the caller explicitly asked for a line style.
    kwds.setdefault('marker', '.')
    if 'linestyle' not in kwds and 'ls' not in kwds:
        kwds['linestyle'] = 'none'
    axes.plot(spike_times/time_unit, spike_indices, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    axes.set_ylabel('neuron index')
    return axes
Ejemplo n.º 4
0
def plot_raster(spike_indices, spike_times, time_unit=ms, axes=None, **kwds):
    '''
    Plot a "raster plot", a plot of neuron indices over spike times. The default
    marker used for plotting is ``'.'``, it can be overridden with the
    ``marker`` keyword argument.

    Parameters
    ----------
    spike_indices : `~numpy.ndarray`
        The indices of spiking neurons, corresponding to the times given in
        ``spike_times``.
    spike_times : `~brian2.units.fundamentalunits.Quantity`
        A sequence of spike times.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    # Provide the '.' marker as a keyword default rather than as a format
    # string: combining a '.' format string with a user-supplied ``marker``
    # makes matplotlib warn about a "redundantly defined" property. Without a
    # format string, matplotlib would also connect the points, so disable the
    # line unless the caller explicitly asked for a line style.
    kwds.setdefault('marker', '.')
    if 'linestyle' not in kwds and 'ls' not in kwds:
        kwds['linestyle'] = 'none'
    axes.plot(spike_times / time_unit, spike_indices, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    axes.set_ylabel('neuron index')
    return axes
Ejemplo n.º 5
0
def plot_rate(times, rate, time_unit=ms, rate_unit=Hz, axes=None, **kwds):
    '''
    Plot a population rate over time.

    Parameters
    ----------
    times : `~brian2.units.fundamentalunits.Quantity`
        The time points at which the ``rate`` is measured.
    rate : `~brian2.units.fundamentalunits.Quantity`
        The population rate for each time point in ``times``.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    rate_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the rate axis. Defaults to ``Hz``.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    axes.plot(times / time_unit, rate / rate_unit, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    axes.set_ylabel('population rate (%s)' % rate_unit)
    return axes
Ejemplo n.º 6
0
def plot_rate(times, rate, time_unit=ms, rate_unit=Hz, axes=None, **kwds):
    '''
    Plot a population rate over time.

    Parameters
    ----------
    times : `~brian2.units.fundamentalunits.Quantity`
        The time points at which the ``rate`` is measured.
    rate : `~brian2.units.fundamentalunits.Quantity`
        The population rate for each time point in ``times``.
    time_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the time axis. Defaults to ``ms``, but longer
        simulations could use ``second``, for example.
    rate_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use for the rate axis. Defaults to ``Hz``.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to matplotlib's
        `~matplotlib.axes.Axes.plot` command. This can be used to set plot
        properties such as the ``color``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)
    axes.plot(times/time_unit, rate/rate_unit, **kwds)
    axes.set_xlabel('time (%s)' % time_unit)
    axes.set_ylabel('population rate (%s)' % rate_unit)
    return axes
def plot_morphology(morphology, plot_3d=None, show_compartments=False,
                    show_diameter=False, colors=('darkblue', 'darkred'),
                    values=None, value_norm=(None, None), value_colormap='hot',
                    value_colorbar=True, value_unit=None, axes=None):
    '''
    Plot a given `~brian2.spatialneuron.morphology.Morphology` in 2D or 3D.

    Parameters
    ----------
    morphology : `~brian2.spatialneuron.morphology.Morphology`
        The morphology to plot
    plot_3d : bool, optional
        Whether to plot the morphology in 3D or in 2D. If not set (the default)
        a morphology where all z values are 0 is plotted in 2D, otherwise it is
        plotted in 3D.
    show_compartments : bool, optional
        Whether to plot a dot at the center of each compartment. Defaults to
        ``False``.
    show_diameter : bool, optional
        Whether to plot the compartments with the diameter given in the
        morphology. Defaults to ``False``.
    colors : sequence of color specifications
        A list of colors that is cycled through for each new section. Can be
        any color specification that matplotlib understands (e.g. a string such
        as ``'darkblue'`` or a tuple such as `(0, 0.7, 0)`.
    values : ~brian2.units.fundamentalunits.Quantity, optional
        Values to fill compartment patches with a color that corresponds to
        their given value. The passed object is not modified.
    value_norm : tuple or callable, optional
        Normalization function to scale the displayed values. Can be a tuple
        of a minimum and a maximum value (where either of them can be ``None``
        to denote taking the minimum/maximum from the data) or a function that
        takes a value and returns the scaled value (e.g. as returned by
        `.matplotlib.colors.PowerNorm`). For a tuple of values, will use
        `.matplotlib.colors.Normalize` with the given ``vmin`` and ``vmax``
        values and ``clip=True``. 3D plots only support the tuple form.
    value_colormap : str or matplotlib.colors.Colormap, optional
        Desired colormap for plots. Either the name of a standard colormap
        or a `.matplotlib.colors.Colormap` instance. Defaults to ``'hot'``.
        Note that this uses ``matplotlib`` color maps even for 3D plots with
        Mayavi.
    value_colorbar : bool or dict, optional
        Whether to add a colorbar for the ``values``. Defaults to ``True``,
        but will be ignored if no ``values`` are provided. Can also be a
        dictionary with the keyword arguments for matplotlib's
        `~.matplotlib.figure.Figure.colorbar` method (2D plot), or for
        Mayavi's `~.mayavi.mlab.scalarbar` method (3D plot).
    value_unit : `Unit`, optional
        A `Unit` to rescale the values for display in the colorbar. Does not
        have any visible effect if no colorbar is used. If not specified, will
        try to determine the "best unit" itself.
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`, optional
        A matplotlib `~matplotlib.axes.Axes` (for 2D plots) or mayavi
        `~mayavi.core.api.Scene` (for 3D plots) instance, where the plot will
        be added.

    Returns
    -------
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`
        The `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene` instance that
        was used for plotting. This object allows to modify the plot further,
        e.g. by setting the plotted range, the axis labels, the plot title, etc.

    Raises
    ------
    TypeError
        If ``value_unit`` is not a `Unit`, if ``value_norm`` is a tuple that
        does not have two elements, or if a non-tuple ``value_norm`` is used
        for a 3D plot.
    ImportError
        If a 3D plot is requested but mayavi is not installed.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import (_setup_axes_matplotlib,
                                           _setup_axes_mayavi)

    if plot_3d is None:
        # Decide whether to use 2d or 3d plotting based on the coordinates
        flat_morphology = FlatMorphology(morphology)
        plot_3d = any(np.abs(flat_morphology.z) > 1e-12)

    if values is not None:
        # Name used in the colorbar label (objects with a ``name`` attribute
        # provide it, otherwise fall back to a generic name)
        if hasattr(values, 'name'):
            value_varname = values.name
        else:
            value_varname = 'values'
        if value_unit is not None:
            if not isinstance(value_unit, Unit):
                raise TypeError(f'\'value_unit\' has to be a unit but is '
                                f'\'{type(value_unit)}\'.')
            fail_for_dimension_mismatch(value_unit, values,
                                        'The \'value_unit\' argument needs '
                                        'to have the same dimensions as '
                                        'the \'values\'.')
        else:
            if have_same_dimensions(values, DIMENSIONLESS):
                value_unit = 1.
            else:
                value_unit = values[:].get_best_unit()
        # Keep the original (dimensioned) values for the dimension checks of
        # the normalization bounds below
        orig_values = values
        values = values/value_unit
        if isinstance(value_norm, tuple):
            if not len(value_norm) == 2:
                raise TypeError('Need a (vmin, vmax) tuple for the value '
                                'normalization, but got a tuple of length '
                                f'{len(value_norm)}.')
            vmin, vmax = value_norm
            # The bounds are rescaled out-of-place: they are objects supplied
            # by the caller and must not be modified.
            if vmin is not None:
                err_msg = ('The minimum value in \'value_norm\' needs to '
                           'have the same units as \'values\'.')
                fail_for_dimension_mismatch(vmin, orig_values,
                                            error_message=err_msg)
                vmin = vmin / value_unit
            if vmax is not None:
                err_msg = ('The maximum value in \'value_norm\' needs to '
                           'have the same units as \'values\'.')
                fail_for_dimension_mismatch(vmax, orig_values,
                                            error_message=err_msg)
                vmax = vmax / value_unit
            if plot_3d:
                value_norm = (vmin, vmax)
            else:
                value_norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
                # Fill in any missing bound from the data
                value_norm.autoscale_None(values)
        elif plot_3d:
            raise TypeError('3d plots only support normalizations given by '
                            'a (min, max) tuple.')
        value_colormap = plt.get_cmap(value_colormap)

    if plot_3d:
        try:
            import mayavi.mlab as mayavi
        except ImportError as err:
            raise ImportError('3D plotting needs the mayavi library') from err
        axes = _setup_axes_mayavi(axes)
        # Suspend rendering while the scene is being built
        axes.scene.disable_render = True
        surf = _plot_morphology3D(morphology, axes, colors=colors,
                                  values=values, value_norm=value_norm,
                                  value_colormap=value_colormap,
                                  show_diameters=show_diameter,
                                  show_compartments=show_compartments)
        if values is not None and value_colorbar:
            if not isinstance(value_colorbar, Mapping):
                value_colorbar = {}
                if not have_same_dimensions(value_unit, DIMENSIONLESS):
                    unit_str = f' ({value_unit!s})'
                else:
                    unit_str = ''
                if value_varname:
                    value_colorbar['title'] = f'{value_varname}{unit_str}'
            cb = mayavi.scalarbar(surf, **value_colorbar)
            # Make text dark gray
            cb.title_text_property.color = (0.1, 0.1, 0.1)
            cb.label_text_property.color = (0.1, 0.1, 0.1)
        axes.scene.disable_render = False
    else:
        axes = _setup_axes_matplotlib(axes)

        _plot_morphology2D(morphology, axes, colors,
                           values, value_norm, value_colormap,
                           show_compartments=show_compartments,
                           show_diameter=show_diameter)
        axes.set_xlabel('x (um)')
        axes.set_ylabel('y (um)')
        axes.set_aspect('equal')
        if values is not None and value_colorbar:
            divider = make_axes_locatable(axes)
            cax = divider.append_axes("right", size="5%", pad=0.1)
            mappable = ScalarMappable(norm=value_norm, cmap=value_colormap)
            mappable.set_array([])
            fig = axes.get_figure()
            if not isinstance(value_colorbar, Mapping):
                value_colorbar = {}
                if not have_same_dimensions(value_unit, DIMENSIONLESS):
                    unit_str = f' ({value_unit!s})'
                else:
                    unit_str = ''
                if value_varname:
                    value_colorbar['label'] = f'{value_varname}{unit_str}'
            fig.colorbar(mappable, cax=cax, **value_colorbar)
    return axes
Ejemplo n.º 8
0
def plot_synapses(sources, targets, values=None, var_unit=None,
                  var_name=None, plot_type='scatter', axes=None, **kwds):
    '''
    Plot the connections between source and target neurons, optionally
    color-coded by a per-synapse value.

    Parameters
    ----------
    sources : `~numpy.ndarray` of int
        The source indices of the connections (as returned by
        ``Synapses.i``).
    targets : `~numpy.ndarray` of int
        The target indices of the connections (as returned by
        ``Synapses.j``).
    values : `~brian2.units.fundamentalunits.Quantity`, `~numpy.ndarray`
        The values to plot, a 1D array of the same size as ``sources`` and
        ``targets``.
    var_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use to plot the ``values`` (e.g. ``mV`` for a membrane
        potential). If none is given (the default), an attempt is made to
        find a good scale automatically based on the ``values``.
    var_name : str, optional
        The name of the variable that is plotted. Used for the axis label.
    plot_type : {``'scatter'``, ``'image'``, ``'hexbin'``}, optional
        What type of plot to use. Can be ``'scatter'`` (the default) to draw
        a scatter plot, ``'image'`` to display the connections as a matrix or
        ``'hexbin'`` to display a 2D histogram using matplotlib's
        `~matplotlib.axes.Axes.hexbin` function.
        For a large number of synapses, ``'scatter'`` will be very slow.
        Similarly, an ``'image'`` plot will use a lot of memory for connections
        between two large groups. For a small number of neurons and synapses,
        ``'hexbin'`` will be hard to interpret.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to the respective
        matplotlib command (`~matplotlib.axes.Axes.scatter` if the
        ``plot_type`` is ``'scatter'``, `~matplotlib.axes.Axes.imshow` for
        ``'image'``, and `~matplotlib.axes.Axes.hexbin` for ``'hexbin'``).
        This can be used to set plot properties such as the ``marker``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.

    Raises
    ------
    TypeError
        If ``sources``, ``targets`` and ``values`` differ in length.
    ValueError
        If ``plot_type`` is not one of the supported types.
    NotImplementedError
        If ``values`` are given, there are multiple synapses for some
        source-target pair and ``plot_type`` is not ``'hexbin'``.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)

    sources = np.asarray(sources)
    targets = np.asarray(targets)
    if not len(sources) == len(targets):
        raise TypeError('Length of sources and targets does not match.')

    if plot_type not in ['scatter', 'image', 'hexbin']:
        raise ValueError("plot_type has to be either 'scatter', 'image', or "
                         "'hexbin' (was: %r)" % plot_type)

    # Get some information out of the values if provided
    if values is not None:
        if len(values) != len(sources):
            raise TypeError('Length of values and sources/targets does not '
                            'match.')
        if var_name is None:
            var_name = getattr(values, 'name', None)  # works for a VariableView
        if var_unit is None:
            try:
                var_unit = values[:]._get_best_unit()
            except AttributeError:
                pass
        if var_unit is not None:
            values = values / var_unit

    multiple_synapses = False
    if plot_type != 'hexbin':
        # For "hexbin", we are binning multiple synapses anyway, so we don't
        # have to make a difference for multiple synapses
        connection_count = Counter(zip(sources, targets))
        multiple_synapses = np.any(np.array(list(connection_count.values())) > 1)

    edgecolor = kwds.pop('edgecolor', 'none')

    if multiple_synapses:
        # Only reachable for 'scatter' and 'image' plots (see above)
        if values is not None:
            raise NotImplementedError("Plotting variables with multiple "
                                      "synapses per source-target pair is only "
                                      "implemented for 'hexbin' plots.")
        unique_sources, unique_targets = zip(*connection_count.keys())
        n_synapses = list(connection_count.values())
        bounds, cmap, norm = _discrete_color_mapping(kwds.pop('cmap', None),
                                                     n_synapses)
        # Make the plot
        if plot_type == 'scatter':
            marker = kwds.pop('marker', ',')
            axes.scatter(unique_sources, unique_targets, marker=marker,
                         c=n_synapses, edgecolor=edgecolor, cmap=cmap,
                         norm=norm, **kwds)
        else:
            # NOTE(review): presumably _int_connection_matrix stores the
            # counts in an 8-bit integer matrix — confirm against its
            # definition.
            assert np.max(n_synapses) < 256
            matrix = _int_connection_matrix(unique_sources, unique_targets,
                                                 n_synapses)
            origin = kwds.pop('origin', 'lower')
            interpolation = kwds.pop('interpolation', 'nearest')
            axes.imshow(matrix, origin=origin, interpolation=interpolation,
                        cmap=cmap, norm=norm,
                        extent=(min(unique_sources) - 0.5, max(unique_sources) + 0.5,
                                min(unique_targets) - 0.5, max(unique_targets) + 0.5),
                        **kwds)

        # Add the colorbar
        locatable_axes = make_axes_locatable(axes)
        cax = locatable_axes.append_axes('right', size='5%', pad=0.05)
        mpl.colorbar.ColorbarBase(cax, cmap=cmap,
                                  norm=norm,
                                  ticks=bounds-0.5)
        cax.set_ylabel('number of synapses')
    else:
        if plot_type == 'scatter':
            marker = kwds.pop('marker', ',')
            color = kwds.pop('color', values if values is not None else None)
            plotted = axes.scatter(sources, targets, marker=marker, c=color,
                                   edgecolor=edgecolor, **kwds)
        elif plot_type == 'image':
            if values is not None:
                matrix = _float_connection_matrix(sources, targets, values)
            else:
                matrix = _int_connection_matrix(sources, targets, 1)
            origin = kwds.pop('origin', 'lower')
            interpolation = kwds.pop('interpolation', 'nearest')
            vmin = kwds.pop('vmin', 1 if values is None else None)
            plotted = axes.imshow(matrix, origin=origin,
                                  interpolation=interpolation,
                                  vmin=vmin,
                                  extent=(min(sources) - 0.5, max(sources) + 0.5,
                                          min(targets) - 0.5, max(targets) + 0.5),
                                  **kwds)
        elif plot_type == 'hexbin':
            if values is None:  # Counting synapses
                mincnt = kwds.pop('mincnt', 1)
            else:
                mincnt = kwds.pop('mincnt', None)
            plotted = axes.hexbin(sources, targets, C=values, mincnt=mincnt,
                                  **kwds)

        if values is not None or plot_type == 'hexbin':
            # Add a colorbar. Use the figure of the target axes rather than
            # pyplot's "current" figure, which may be a different figure (or
            # be created from scratch if the axes were built without pyplot).
            locatable_axes = make_axes_locatable(axes)
            cax = locatable_axes.append_axes('right', size='7.5%', pad=0.05)
            axes.get_figure().colorbar(plotted, cax=cax)
            if var_name is None:
                if var_unit is not None:
                    cax.set_ylabel('in units of %s' % str(var_unit))
            else:
                label = var_name
                if var_unit is not None:
                    label += ' (%s)' % str(var_unit)
                cax.set_ylabel(label)

    axes.set_xlim(-0.5, max(sources) + 0.5)
    axes.set_ylim(-0.5, max(targets) + 0.5)
    axes.set_xlabel('source neuron index')
    axes.set_ylabel('target neuron index')
    # Prevent floating point values on the axes (e.g. when zooming in)
    axes.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes.yaxis.set_major_locator(MaxNLocator(integer=True))
    return axes
Ejemplo n.º 9
0
def plot_dendrogram(morphology, axes=None):
    """
    Plot a "dendrogram" of a morphology, i.e. an abstract representation which
    visualizes the branching structure and the length of each section.

    The y axis shows the distance from the root (in um, taken from the
    flattened morphology's ``end_distance``). Each child section is drawn as
    a vertical line from its parent's value to its own value, and the
    children of a section are joined by a horizontal line. Terminal sections
    are placed at integer x positions; every other section sits halfway
    between its outermost terminals.

    Parameters
    ----------
    morphology : `~brian2.spatialneuron.morphology.Morphology`
        The morphology to visualize.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.
    """
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib

    axes = _setup_axes_matplotlib(axes)
    # Get some information from the flattened morphology
    flat_morpho = FlatMorphology(morphology)
    section_depth = flat_morpho.depth[flat_morpho.starts]
    section_distance = flat_morpho.end_distance / float(um)
    n_sections = flat_morpho.sections
    max_depth = max(flat_morpho.depth)
    max_children = max(flat_morpho.morph_children_num)
    children = flat_morpho.morph_children

    length_metric = section_distance

    # Each point should be in the middle of its two outermost terminal points
    # We go backwards through the tree, noting for each point all terminal
    # indices in its subtree
    terminals = [set() for _ in range(n_sections)]
    terminal_counter = 0
    for d in range(max_depth, -1, -1):
        for idx in np.nonzero(section_depth == d)[0]:
            # Children are stored in fixed-size slots of ``max_children``
            # entries, with 1-based section indices (hence the ``- 1`` below)
            child_start_idx = (idx + 1) * max_children
            num_children = flat_morpho.morph_children_num[idx + 1]
            if num_children == 0:
                terminals[idx] = {terminal_counter}
                terminal_counter += 1
            else:
                child_indices = children[child_start_idx : child_start_idx + num_children]
                terminals[idx].update(*[terminals[c - 1] for c in child_indices])

    # Now we make sure that subtrees starting at a lower x value will be left
    # of other subtrees
    # This is probably not the most efficient algorithm, but it seems to work
    order_strings = [[] for _ in range(terminal_counter)]
    for idx in np.argsort(length_metric):
        child_terminals = terminals[idx]
        for t, order_string in enumerate(order_strings):
            order_string.append("A" if t in child_terminals else "B")
    order_strings = ["".join(s) for s in order_strings]
    terminal_x_values = np.argsort(np.argsort(order_strings))
    # Use the re-arranged values to calculate the actual x value for the tree
    min_index = [min(terminal_x_values[np.array(list(ts), dtype=int)]) for ts in terminals]
    max_index = [max(terminal_x_values[np.array(list(ts), dtype=int)]) for ts in terminals]

    x_values = (np.array(min_index) + np.array(max_index)) / 2.0

    # Plot the dendrogram with lengths of the vertical lines representing the
    # total distance to the root. The root marker goes onto the target axes,
    # not pyplot's "current" axes (which may belong to another figure).
    axes.plot(x_values[0], length_metric[0], "ko", clip_on=False)
    for sec, (x, depth) in enumerate(zip(x_values, length_metric)):
        child_start_idx = (sec + 1) * max_children
        num_children = flat_morpho.morph_children_num[sec + 1]
        if num_children > 0:
            child_indices = children[child_start_idx : child_start_idx + num_children]
            child_depth = length_metric[child_indices - 1]
            child_x = x_values[child_indices - 1]
            axes.vlines(child_x, depth, child_depth, clip_on=False, lw=2)
            axes.hlines(depth, min(child_x), max(child_x), lw=2)
    axes.set_xticks([])
    axes.set_ylabel("distance from root (um)")
    axes.set_xlim(-1, terminal_counter)
    return axes
Ejemplo n.º 10
0
def plot_morphology(
    morphology, plot_3d=None, show_compartments=False, show_diameter=False, colors=("darkblue", "darkred"), axes=None
):
    """
    Plot a given `~brian2.spatialneuron.morphology.Morphology` in 2D or 3D.

    Parameters
    ----------
    morphology : `~brian2.spatialneuron.morphology.Morphology`
        The morphology to plot
    plot_3d : bool, optional
        Whether to plot the morphology in 3D or in 2D. If not set (the default)
        a morphology where all z values are 0 is plotted in 2D, otherwise it is
        plotted in 3D. z values with an absolute value of at most ``1e-12``
        count as 0 for this decision.
    show_compartments : bool, optional
        Whether to plot a dot at the center of each compartment. Defaults to
        ``False``.
    show_diameter : bool, optional
        Whether to plot the compartments with the diameter given in the
        morphology. Defaults to ``False``.
    colors : sequence of color specifications, optional
        A list of colors that is cycled through for each new section. Can be
        any color specification that matplotlib understands (e.g. a string such
        as ``'darkblue'`` or a tuple such as ``(0, 0.7, 0)``). Defaults to
        ``('darkblue', 'darkred')``.
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`, optional
        A matplotlib `~matplotlib.axes.Axes` (for 2D plots) or mayavi
        `~mayavi.core.api.Scene` (for 3D plots) instance, where the plot will
        be added. The value is passed through ``_setup_axes_matplotlib`` or
        ``_setup_axes_mayavi``, which provide the object actually used.

    Returns
    -------
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`
        The `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene` instance that
        was used for plotting. This object allows to modify the plot further,
        e.g. by setting the plotted range, the axis labels, the plot title, etc.

    Raises
    ------
    ImportError
        If a 3D plot is requested (or chosen automatically) but the mayavi
        library cannot be imported.
    """
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib, _setup_axes_mayavi

    if plot_3d is None:
        # Decide whether to use 2d or 3d plotting based on the coordinates:
        # a single non-negligible z coordinate is enough to switch to 3D.
        flat_morphology = FlatMorphology(morphology)
        plot_3d = np.any(np.abs(flat_morphology.z) > 1e-12)

    if plot_3d:
        try:
            # Only checks that mayavi is importable; the module itself is not
            # used directly in this function.
            import mayavi.mlab  # noqa: F401
        except ImportError as err:
            raise ImportError("3D plotting needs the mayavi library") from err
        axes = _setup_axes_mayavi(axes)
        # Suspend rendering while the sections are added. The finally clause
        # guarantees rendering is re-enabled even if plotting fails, so the
        # scene is never left in a non-rendering state.
        axes.scene.disable_render = True
        try:
            _plot_morphology3D(
                morphology, axes, colors=colors, show_diameters=show_diameter, show_compartments=show_compartments
            )
        finally:
            axes.scene.disable_render = False
    else:
        axes = _setup_axes_matplotlib(axes)
        _plot_morphology2D(morphology, axes, colors, show_compartments=show_compartments, show_diameter=show_diameter)
        axes.set_xlabel("x (um)")
        axes.set_ylabel("y (um)")
        axes.set_aspect("equal")

    return axes
Ejemplo n.º 11
0
def plot_morphology(morphology, plot_3d=None, show_compartments=False,
                    show_diameter=False, colors=('darkblue', 'darkred'),
                    axes=None):
    '''
    Plot a given `~brian2.spatialneuron.morphology.Morphology` in 2D or 3D.

    Parameters
    ----------
    morphology : `~brian2.spatialneuron.morphology.Morphology`
        The morphology to plot
    plot_3d : bool, optional
        Whether to plot the morphology in 3D or in 2D. If not set (the default)
        a morphology where all z values are 0 is plotted in 2D, otherwise it is
        plotted in 3D. z values with an absolute value of at most ``1e-12``
        count as 0 for this decision.
    show_compartments : bool, optional
        Whether to plot a dot at the center of each compartment. Defaults to
        ``False``.
    show_diameter : bool, optional
        Whether to plot the compartments with the diameter given in the
        morphology. Defaults to ``False``.
    colors : sequence of color specifications, optional
        A list of colors that is cycled through for each new section. Can be
        any color specification that matplotlib understands (e.g. a string such
        as ``'darkblue'`` or a tuple such as ``(0, 0.7, 0)``). Defaults to
        ``('darkblue', 'darkred')``.
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`, optional
        A matplotlib `~matplotlib.axes.Axes` (for 2D plots) or mayavi
        `~mayavi.core.api.Scene` (for 3D plots) instance, where the plot will
        be added. The value is passed through ``_setup_axes_matplotlib`` or
        ``_setup_axes_mayavi``, which provide the object actually used.

    Returns
    -------
    axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`
        The `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene` instance that
        was used for plotting. This object allows to modify the plot further,
        e.g. by setting the plotted range, the axis labels, the plot title, etc.

    Raises
    ------
    ImportError
        If a 3D plot is requested (or chosen automatically) but the mayavi
        library cannot be imported.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import (_setup_axes_matplotlib,
                                           _setup_axes_mayavi)
    if plot_3d is None:
        # Decide whether to use 2d or 3d plotting based on the coordinates:
        # a single non-negligible z coordinate is enough to switch to 3D.
        flat_morphology = FlatMorphology(morphology)
        plot_3d = np.any(np.abs(flat_morphology.z) > 1e-12)

    if plot_3d:
        try:
            # Only checks that mayavi is importable; the module itself is not
            # used directly in this function.
            import mayavi.mlab  # noqa: F401
        except ImportError as err:
            raise ImportError('3D plotting needs the mayavi library') from err
        axes = _setup_axes_mayavi(axes)
        # Suspend rendering while the sections are added. The finally clause
        # guarantees rendering is re-enabled even if plotting fails, so the
        # scene is never left in a non-rendering state.
        axes.scene.disable_render = True
        try:
            _plot_morphology3D(morphology, axes, colors=colors,
                               show_diameters=show_diameter,
                               show_compartments=show_compartments)
        finally:
            axes.scene.disable_render = False
    else:
        axes = _setup_axes_matplotlib(axes)
        _plot_morphology2D(morphology, axes, colors,
                           show_compartments=show_compartments,
                           show_diameter=show_diameter)
        axes.set_xlabel('x (um)')
        axes.set_ylabel('y (um)')
        axes.set_aspect('equal')

    return axes