コード例 #1
0
ファイル: plot.py プロジェクト: hendrikheller/wyrm
def plot_tenten(data, highlights=None, hcolors=None, legend=False, scale=True,
                reg_chans=None):
    """Plots channels on a grid system.

    Iterates over every channel in the data structure. If the channelname
    matches a channel in the tenten-system it will be plotted in a grid of
    rectangles. The grid is structured like the tenten-system itself, but in
    a simplified manner. The rows, in which channels appear, are predetermined,
    the channels are ordered automatically within their respective row.
    Areas to highlight can be specified, those areas will be marked with colors
    in every timeinterval plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    highlights : [[int, int)], optional
        List of tuples containing the start point (included) and end point
        (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlight areas (default: None).
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default: False).
    scale : Boolean, optional
        Flag to switch plotting of a scale in the top right corner of the grid
        (default: True). If False, no slot is reserved for the scale and
        ``None`` is returned in its place.
    reg_chans : [regular expressions], optional
        A list of regular expressions. The plot will be limited to those
        channels matching the regular expressions (default: None).

    Returns
    -------
    [Matplotlib.Axes], Matplotlib.Axes or None
        Returns the plotted timeinterval axes as a list of Matplotlib.Axes and
        the plotted scale as a single Matplotlib.Axes (``None`` if ``scale``
        is False).

    Raises
    ------
    ValueError
        If none of the (selected) channels is part of the tenten-system.

    Examples
    --------
    Plotting of all channels within a Data object
    >>> plot_tenten(data)

    Plotting of all channels with a highlighted area
    >>> plot_tenten(data, highlights=[[200, 400]])

    Plotting of all channels beginning with 'A'
    >>> plot_tenten(data, reg_chans=['A.*'])
    """
    dcopy = data.copy()
    # this dictionary determines which y-position corresponds with which row in the grid
    ordering = {4.0: 0,
                3.5: 0,
                3.0: 1,
                2.5: 2,
                2.0: 3,
                1.5: 4,
                1.0: 5,
                0.5: 6,
                0.0: 7,
                -0.5: 8,
                -1.0: 9,
                -1.5: 10,
                -2.0: 11,
                -2.5: 12,
                -2.6: 12,
                -3.0: 13,
                -3.5: 14,
                -4.0: 15,
                -4.5: 15,
                -5.0: 16}

    # all the channels with their x- and y-position
    system = _get_system()

    # one (initially empty) list for every potential row of channels
    channel_lists = [[] for _ in range(max(ordering.values()) + 1)]

    if reg_chans is not None:
        dcopy = pro.select_channels(dcopy, reg_chans)

    # distribute the channels to the lists by their y-position
    # entries in channel_lists: (<channel_name>, <x-position>, <position in Data>)
    for data_idx, chan in enumerate(dcopy.axes[1]):
        if chan in tts.channels:
            x_pos = system[chan][0]
            y_pos = system[chan][1]
            channel_lists[ordering[y_pos]].append((chan, x_pos, data_idx))

    # keep only the occupied rows, each sorted (stably) by x-position
    rows = [sorted(row, key=lambda entry: entry[1])
            for row in channel_lists if row]

    if not rows:
        # without this check, the scale bookkeeping below would fail with a
        # cryptic IndexError
        raise ValueError('None of the channels in data is part of the '
                         'tenten-system.')

    # calculate the needed dimensions of the grid: number of axes per row
    columns = [len(row) for row in rows]

    if scale:
        # add another axes to the end of the first row for the scale
        columns[0] += 1

    plt.figure()
    grid = calc_centered_grid(columns, hpad=.01, vpad=.01)

    # axis used for sharing axes between channels (the first plotted one)
    masterax = None
    ax = []

    k = 0            # index of the next free slot in grid
    scale_ax = None  # index of the slot reserved for the scale

    for row_idx, row in enumerate(rows):
        for chan, _, data_idx in row:
            ax.append(_subplot_timeinterval(dcopy, grid[k], epoch=-1, highlights=highlights, hcolors=hcolors,
                                            labels=False, legend=legend, channel=data_idx, shareaxis=masterax))
            if masterax is None:
                masterax = ax[0]

            # hide the axes labeling; booleans are required, since the strings
            # 'on'/'off' are both truthy in current matplotlib
            plt.tick_params(axis='both', which='both', labelbottom=False, labeltop=False, labelleft=False,
                            labelright=False, top=False, right=False)

            # label every subplot with its channel name
            plt.gca().annotate(chan, (0.05, 0.05), xycoords='axes fraction')

            k += 1

        if scale and row_idx == 0:
            # the slot right after the last axes in the first row holds the scale
            scale_ax = k
            k += 1

    sc = None
    if scale:
        # plot the scale axes, labeled with the last value of the time axis
        xtext = dcopy.axes[0][-1]
        # raw string: "\m" is not a valid escape sequence
        sc = _subplot_scale(str(xtext) + ' ms', r"$\mu$V", position=grid[scale_ax])

    return ax, sc
コード例 #2
0
ファイル: plot.py プロジェクト: hendrikheller/wyrm
def plot_timeinterval(data, r_square=None, highlights=None, hcolors=None,
                      legend=True, reg_chans=None, position=None):
    """Plot all channels of ``data`` in a single timeinterval plot.

    Continuous data is drawn as-is. Epoched data (more than two dimensions)
    is first averaged over its first axis. Optionally, a strip of r-squared
    values is drawn below the main plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    r_square : [values], optional
        r-squared values to plot in a strip beneath the main plot
        (default: None).
    highlights : [[int, int)], optional
        List of [start, end) pairs (start included, end excluded) marking the
        areas to highlight (default: None).
    hcolors : [colors], optional
        Colors to use for the highlight areas (default: None).
    legend : Boolean, optional
        Whether to draw the legend (default: True).
    reg_chans : [regular expression], optional
        Regular expressions restricting the plot to matching channels
        (default: None).
    position : [x, y, width, height], optional
        Rectangle the plot is confined to. If None, a new figure is opened
        and the default layout rectangles are used (default: None).

    Returns
    -------
    Matplotlib.Axes or (Matplotlib.Axes, Matplotlib.Axes)
        The timeinterval axes alone, or, when ``r_square`` is given, a tuple
        of the timeinterval axes and the r-squared axes.

    Examples
    --------
    Plots all channels contained in data with a legend.

    >>> plot_timeinterval(data)

    Same as above, but without the legend.

    >>> plot_timeinterval(data, legend=False)

    Adds r-square values to the plot.

    >>> plot_timeinterval(data, r_square=[values])

    Adds a highlighted area to the plot.

    >>> plot_timeinterval(data, highlights=[[200, 400]])

    To specify the colors of the highlighted areas use 'hcolors'.

    >>> plot_timeinterval(data, highlights=[[200, 400]], hcolors=['red'])
    """
    dcopy = data.copy()

    # layout rectangles: main plot alone, main plot above an r^2 strip,
    # and the r^2 strip itself
    layout_solo = [.07, .07, .9, .9]
    layout_main = [.07, .12, .9, .85]
    layout_strip = [.07, .07, .9, .05]

    with_r2 = r_square is not None
    main_layout = layout_main if with_r2 else layout_solo

    if position is None:
        # no enclosing rectangle given: draw into a fresh figure
        plt.figure()
        pos_ti = main_layout
        pos_r2 = layout_strip if with_r2 else None
    else:
        # map the layouts into the caller-supplied rectangle
        pos_ti = _transform_rect(position, main_layout)
        pos_r2 = _transform_rect(position, layout_strip) if with_r2 else None

    if reg_chans is not None:
        dcopy = pro.select_channels(dcopy, reg_chans)

    def _last_two(seq):
        # the last two entries of seq, as a new list
        return [seq[-2], seq[-1]]

    # epoched data: collapse the epoch axis by taking the mean
    if len(data.data.shape) > 2:
        averaged = np.mean(dcopy.data, axis=0)
        dcopy = Data(averaged, _last_two(dcopy.axes),
                     _last_two(dcopy.names), _last_two(dcopy.units))

    ti_ax = _subplot_timeinterval(dcopy, position=pos_ti, epoch=-1,
                                  highlights=highlights, hcolors=hcolors,
                                  legend=legend)
    ti_ax.xaxis.labelpad = 0

    if not with_r2:
        plt.grid(True)
        return ti_ax

    r2_ax = _subplot_r_square(r_square, position=pos_r2)
    # ticks point inward; the label pad scales with the plot's height
    ti_ax.tick_params(axis='x', direction='in', pad=30 * pos_ti[3])
    # NOTE(review): plt.grid acts on the current axes, presumably the r^2
    # axes at this point -- confirm that this is intended.
    plt.grid(True)
    return ti_ax, r2_ax