Esempio n. 1
0
 def select_sensors(self):
     """Restrict the signal processor's continuous data to the wanted sensors.

     Channels listed in ``self.rejected_chans`` are removed first. Then,
     unless ``self.sensor_names`` is None or the string ``'all'``, only the
     channels selected by ``self.sensor_names`` (via ``select_channels``)
     are kept. Finally ``self.sensor_names`` is replaced by the channel
     names that remain (``cnt.axes[-1]``).
     """
     if len(self.rejected_chans) > 0:
         self.signal_processor.cnt = select_channels(
             self.signal_processor.cnt, self.rejected_chans, invert=True)
     # Compare to 'all' by value. `x is not 'all'` tests object identity,
     # which depends on string interning (and is a SyntaxWarning since
     # Python 3.8). The isinstance guard avoids an element-wise comparison
     # when sensor_names is array-like, e.g. the axes[-1] stored below by
     # an earlier call.
     wants_all = (isinstance(self.sensor_names, str)
                  and self.sensor_names == 'all')
     if self.sensor_names is not None and not wants_all:
         self.signal_processor.cnt = select_channels(
             self.signal_processor.cnt, self.sensor_names)
     self.sensor_names = self.signal_processor.cnt.axes[-1]
Esempio n. 2
0
def clean_train_test_cnt(train_cnt,
                         test_cnt,
                         train_cleaner,
                         test_cleaner,
                         copy_data=False):
    """Clean a train/test pair of continuous recordings consistently.

    The training set is cleaned first. Channels rejected there are removed
    from the test set before the test set itself is cleaned (with
    ``ignore_chans=True``). Both sets are then restricted with
    ``restrict_cnt`` to their cleaner's marker codes, their clean trials and
    their non-rejected channels.

    Parameters
    ----------
    train_cnt, test_cnt :
        Continuous data of the training and test set.
    train_cleaner, test_cleaner :
        Objects providing ``clean(cnt, ...)`` and a ``marker_def`` mapping
        whose values are iterables of marker codes.
    copy_data : bool, optional
        Forwarded to ``restrict_cnt`` (default: False).

    Returns
    -------
    (clean_train_cnt, clean_test_cnt)

    Raises
    ------
    AssertionError
        If cleaning the test set rejects any channel.
    """
    log.info("Clean Training Set...")
    train_clean_result = train_cleaner.clean(train_cnt)
    log_clean_result(train_clean_result)
    # remove chans rejected by train cleaner from test set
    test_cnt = select_channels(test_cnt,
                               train_clean_result.rejected_chan_names,
                               invert=True)

    log.info("Clean Test Set...")
    test_clean_result = test_cleaner.clean(test_cnt, ignore_chans=True)
    log_clean_result(test_clean_result)
    # "{}" rather than "{:s}": a list does not support the "s" format spec,
    # so "{:s}" raised TypeError instead of showing this message.
    assert len(test_clean_result.rejected_chan_names) == 0, (
        "There should be no rejected channels on test set, instead got "
        "{}".format(test_clean_result.rejected_chan_names))

    log.info("Create Cleaned Cnt Sets...")
    # flatten the per-class marker code lists into one list of codes
    train_markers = list(
        itertools.chain.from_iterable(train_cleaner.marker_def.values()))
    clean_train_cnt = restrict_cnt(train_cnt,
                                   train_markers,
                                   train_clean_result.clean_trials,
                                   train_clean_result.rejected_chan_names,
                                   copy_data=copy_data)
    test_markers = list(
        itertools.chain.from_iterable(test_cleaner.marker_def.values()))
    clean_test_cnt = restrict_cnt(test_cnt,
                                  test_markers,
                                  test_clean_result.clean_trials,
                                  test_clean_result.rejected_chan_names,
                                  copy_data=copy_data)
    return clean_train_cnt, clean_test_cnt
Esempio n. 3
0
def clean_train_test_cnt(train_cnt, test_cnt, train_cleaner, test_cleaner,
        copy_data=False):
    """Clean a train/test pair of continuous recordings consistently.

    The training set is cleaned first. Channels rejected there are removed
    from the test set before the test set itself is cleaned (with
    ``ignore_chans=True``). Both sets are then restricted with
    ``restrict_cnt`` to their cleaner's ``marker_def`` values, their clean
    trials and their non-rejected channels.

    Parameters
    ----------
    train_cnt, test_cnt :
        Continuous data of the training and test set.
    train_cleaner, test_cleaner :
        Objects providing ``clean(cnt, ...)`` and a ``marker_def`` mapping.
    copy_data : bool, optional
        Forwarded to ``restrict_cnt`` (default: False).

    Returns
    -------
    (clean_train_cnt, clean_test_cnt)

    Raises
    ------
    AssertionError
        If cleaning the test set rejects any channel.
    """
    log.info("Clean Training Set...")
    train_clean_result = train_cleaner.clean(train_cnt)
    log_clean_result(train_clean_result)
    # remove chans rejected by train cleaner from test set
    test_cnt = select_channels(test_cnt,
        train_clean_result.rejected_chan_names, invert=True)

    log.info("Clean Test Set...")
    test_clean_result = test_cleaner.clean(test_cnt, ignore_chans=True)
    log_clean_result(test_clean_result)
    # "{}" rather than "{:s}": a list does not support the "s" format spec,
    # so "{:s}" raised TypeError instead of showing this message.
    assert len(test_clean_result.rejected_chan_names) == 0, (
        "There should be no rejected channels on test set, instead got "
        "{}".format(test_clean_result.rejected_chan_names))

    log.info("Create Cleaned Cnt Sets...")
    # NOTE(review): marker_def.values() is passed through unflattened here;
    # presumably the matching select_marker_classes accepts that - confirm.
    clean_train_cnt = restrict_cnt(train_cnt,
        train_cleaner.marker_def.values(),
        train_clean_result.clean_trials,
        train_clean_result.rejected_chan_names,
        copy_data=copy_data)
    clean_test_cnt = restrict_cnt(test_cnt,
        test_cleaner.marker_def.values(),
        test_clean_result.clean_trials,
        test_clean_result.rejected_chan_names,
        copy_data=copy_data)
    return clean_train_cnt, clean_test_cnt
Esempio n. 4
0
def restrict_cnt(cnt, classes, clean_trials, rejected_chan_names, copy_data=False):
    """Reduce continuous data to clean trials of the given classes.

    Applies three filters in sequence: keep markers of ``classes``, keep
    the epochs listed in ``clean_trials``, then drop the channels named in
    ``rejected_chan_names``. ``copy_data`` is forwarded to the two marker
    filters.
    """
    class_filtered = select_marker_classes(cnt, classes, copy_data)
    trial_filtered = select_marker_epochs(class_filtered, clean_trials,
                                          copy_data)
    return select_channels(trial_filtered, rejected_chan_names, invert=True)
 def test_select_channels(self):
     """Selecting channels with an array of regexes."""
     before = self.dat.data.copy()
     self.dat = select_channels(self.dat, ['ca.*', 'cc1'])
     kept_columns = np.array([0, 1, -1])
     expected_names = np.array(['ca1', 'ca2', 'cc1'])
     np.testing.assert_array_equal(self.dat.axes[-1], expected_names)
     np.testing.assert_array_equal(self.dat.data, before[:, kept_columns])
 def test_select_channels_inverse(self):
     """Removing channels with an array of regexes."""
     before = self.dat.data.copy()
     self.dat = select_channels(self.dat, ['ca.*', 'cc1'], invert=True)
     remaining_columns = np.array([2, 3])
     expected_names = np.array(['cb1', 'cb2'])
     np.testing.assert_array_equal(self.dat.axes[-1], expected_names)
     np.testing.assert_array_equal(self.dat.data,
                                   before[:, remaining_columns])
Esempio n. 7
0
    def preprocess_set(self):
        """Drop rejected channels, select sensors and run optional steps.

        Operates on ``self.cnt``. The optional steps (Cz zeroing,
        resampling, common average reference, exponential standardization)
        are applied in that order, each only if its flag is set.
        """
        # Rejected channels are removed here rather than during cleaning,
        # so the clean function can be called repeatedly without its
        # results changing.
        self.cnt = select_channels(self.cnt, self.rejected_chan_names,
                                   invert=True)
        if self.sensor_names is not None:
            # The order of sensor_names is not applied: the channels named
            # there are kept in their original order.
            self.cnt = select_channels(self.cnt, self.sensor_names)

        # (is-enabled check, transform) pairs, evaluated lazily and in order
        optional_steps = (
            (lambda: self.set_cz_to_zero is True,
             lambda cnt: set_channel_to_zero(cnt, 'Cz')),
            (lambda: self.resample_fs is not None,
             lambda cnt: resample_cnt(cnt, newfs=self.resample_fs)),
            (lambda: self.common_average_reference is True,
             lambda cnt: common_average_reference_cnt(cnt)),
            (lambda: self.standardize_cnt is True,
             lambda cnt: exponential_standardize_cnt(cnt)),
        )
        for is_enabled, apply_step in optional_steps:
            if is_enabled():
                self.cnt = apply_step(self.cnt)
Esempio n. 8
0
def restrict_cnt(cnt,
                 classes,
                 clean_trials,
                 rejected_chan_names,
                 copy_data=False):
    """Reduce continuous data to clean trials of the given classes.

    Keeps only markers of ``classes``, then only the epochs in
    ``clean_trials``, then removes ``rejected_chan_names``.
    ``copy_data`` is forwarded to the two marker-selection steps.
    """
    result = select_marker_classes(cnt, classes, copy_data)
    result = select_marker_epochs(result, clean_trials, copy_data)
    return select_channels(result, rejected_chan_names, invert=True)
Esempio n. 9
0
    def preprocess_set(self):
        """Drop rejected channels, select sensors and run optional steps.

        Operates on ``self.cnt``. The optional steps (Cz zeroing,
        resampling, common average reference, exponential standardization)
        are applied in that order, each only if its flag is set.
        """
        # Rejected channels are removed here rather than during cleaning,
        # so the clean function can be called repeatedly without its
        # results changing.
        self.cnt = select_channels(self.cnt,
                                   self.rejected_chan_names,
                                   invert=True)
        if self.sensor_names is not None:
            # The order of sensor_names is not applied: the channels named
            # there are kept in their original order.
            self.cnt = select_channels(self.cnt, self.sensor_names)

        # (is-enabled check, transform) pairs, evaluated lazily and in order
        optional_steps = (
            (lambda: self.set_cz_to_zero is True,
             lambda cnt: set_channel_to_zero(cnt, 'Cz')),
            (lambda: self.resample_fs is not None,
             lambda cnt: resample_cnt(cnt, newfs=self.resample_fs)),
            (lambda: self.common_average_reference is True,
             lambda cnt: common_average_reference_cnt(cnt)),
            (lambda: self.standardize_cnt is True,
             lambda cnt: exponential_standardize_cnt(cnt)),
        )
        for is_enabled, apply_step in optional_steps:
            if is_enabled():
                self.cnt = apply_step(self.cnt)
Esempio n. 10
0
 def preprocess_test_set(self):
     """Select sensors on the test set and apply the enabled steps.

     When ``self.sensor_names`` is set it is replaced by its
     ``sort_topologically`` result, which then selects the channels of
     ``self.test_cnt``. Afterwards Cz zeroing, resampling, common average
     reference and exponential standardization run in that order, each
     only if its flag is set.
     """
     if self.sensor_names is not None:
         ordered_names = sort_topologically(self.sensor_names)
         self.sensor_names = ordered_names
         self.test_cnt = select_channels(self.test_cnt, ordered_names)

     # (is-enabled check, transform) pairs, evaluated lazily and in order
     optional_steps = (
         (lambda: self.set_cz_to_zero is True,
          lambda cnt: set_channel_to_zero(cnt, 'Cz')),
         (lambda: self.resample_fs is not None,
          lambda cnt: resample_cnt(cnt, newfs=self.resample_fs)),
         (lambda: self.common_average_reference is True,
          lambda cnt: common_average_reference_cnt(cnt)),
         (lambda: self.standardize_cnt is True,
          lambda cnt: exponential_standardize_cnt(cnt)),
     )
     for is_enabled, apply_step in optional_steps:
         if is_enabled():
             self.test_cnt = apply_step(self.test_cnt)
Esempio n. 11
0
 def preprocess_test_set(self):
     """Select sensors on the test set and apply the enabled steps.

     When ``self.sensor_names`` is set it is replaced by its
     ``sort_topologically`` result, which then selects the channels of
     ``self.test_cnt``. Afterwards Cz zeroing, resampling, common average
     reference and exponential standardization run in that order, each
     only if its flag is set.
     """
     if self.sensor_names is not None:
         ordered_names = sort_topologically(self.sensor_names)
         self.sensor_names = ordered_names
         self.test_cnt = select_channels(self.test_cnt, ordered_names)

     # (is-enabled check, transform) pairs, evaluated lazily and in order
     optional_steps = (
         (lambda: self.set_cz_to_zero is True,
          lambda cnt: set_channel_to_zero(cnt, 'Cz')),
         (lambda: self.resample_fs is not None,
          lambda cnt: resample_cnt(cnt, newfs=self.resample_fs)),
         (lambda: self.common_average_reference is True,
          lambda cnt: common_average_reference_cnt(cnt)),
         (lambda: self.standardize_cnt is True,
          lambda cnt: exponential_standardize_cnt(cnt)),
     )
     for is_enabled, apply_step in optional_steps:
         if is_enabled():
             self.test_cnt = apply_step(self.test_cnt)
Esempio n. 12
0
def plot_timeinterval(data,
                      r_square=None,
                      highlights=None,
                      hcolors=None,
                      legend=True,
                      reg_chans=None,
                      position=None):
    """Plot all channels of a Data object as one time-interval plot.

    Data with at most two dimensions is plotted as is. Data with more
    dimensions (epoched data) is first averaged over its first axis.
    Optionally an r-square strip is drawn underneath the main plot.

    Parameters
    ----------
    data : wyrm.types.Data
        The data to plot.
    r_square : [values], optional
        r-squared values drawn in a narrow strip below the main plot
        (default: None, no strip).
    highlights : [[int, int)], optional
        Areas to highlight, each given as start point (included) and end
        point (excluded) (default: None).
    hcolors : [colors], optional
        Colors used for the highlighted areas (default: None).
    legend : Boolean, optional
        Whether to draw the legend (default: True).
    reg_chans : [regular expression], optional
        If given, only channels matching these regular expressions are
        plotted (default: None).
    position : [x, y, width, height], optional
        Rectangle the plot is confined to. If None, a new figure is created
        and the default layout is used (default: None).

    Returns
    -------
    Matplotlib.Axes or (Matplotlib.Axes, Matplotlib.Axes)
        The time-interval axes, or a tuple of the time-interval axes and
        the r-square axes when ``r_square`` is given.

    Examples
    --------
    Plots all channels contained in data with a legend.

    >>> plot_timeinterval(data)

    Same as above, but without the legend.

    >>> plot_timeinterval(data, legend=False)

    Adds r-square values to the plot.

    >>> plot_timeinterval(data, r_square=[values])

    Adds a highlighted area to the plot.

    >>> plot_timeinterval(data, highlights=[[200, 400]])

    To specify the colors of the highlighted areas use 'hcolors'.

    >>> plot_timeinterval(data, highlights=[[200, 400]], hcolors=['red'])
    """
    dcopy = data.copy()
    with_r2 = r_square is not None
    # layouts as [x, y, width, height]; the main plot shrinks to make room
    # for the r-square strip when one is requested
    ti_rect = [.07, .12, .9, .85] if with_r2 else [.07, .07, .9, .9]
    r2_rect = [.07, .07, .9, .05]

    pos_r2 = None
    if position is None:
        plt.figure()
        pos_ti = ti_rect
        if with_r2:
            pos_r2 = r2_rect
    else:
        pos_ti = _transform_rect(position, ti_rect)
        if with_r2:
            pos_r2 = _transform_rect(position, r2_rect)

    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # epoched data: average over axis 0 and keep only the last two axes
    # together with their names and units
    if len(data.data.shape) > 2:
        axes, names, units = ([attr[-2], attr[-1]]
                              for attr in (dcopy.axes, dcopy.names,
                                           dcopy.units))
        dcopy = Data(np.mean(dcopy.data, axis=0), axes, names, units)

    ax_ti = _subplot_timeinterval(dcopy,
                                  position=pos_ti,
                                  epoch=-1,
                                  highlights=highlights,
                                  hcolors=hcolors,
                                  legend=legend)
    ax_ti.xaxis.labelpad = 0

    ax_r2 = None
    if with_r2:
        ax_r2 = _subplot_r_square(r_square, position=pos_r2)
        # x tick padding is scaled by the height of the main-plot rect
        ax_ti.tick_params(axis='x', direction='in', pad=30 * pos_ti[3])

    # NOTE(review): plt.grid acts on the current axes, presumably the last
    # one created by the helpers above - confirm this is intended.
    plt.grid(True)

    return (ax_ti, ax_r2) if with_r2 else ax_ti
 def test_select_channels_copy(self):
     """Select channels must not change the original parameter."""
     snapshot = self.dat.copy()
     # the return value is irrelevant; only the input must stay untouched
     select_channels(self.dat, ['ca.*'])
     self.assertEqual(snapshot, self.dat)
Esempio n. 14
0
 def test_select_channels_copy(self):
     """Select channels must not change the original parameter."""
     snapshot = self.dat.copy()
     # the return value is irrelevant; only the input must stay untouched
     select_channels(self.dat, ["ca.*"])
     self.assertEqual(snapshot, self.dat)
Esempio n. 15
0
 def test_select_channels_swapaxis(self):
     """Select channels works with non default chanaxis."""
     channels_first = swapaxes(self.dat, 0, 1)
     selected = select_channels(channels_first, ["ca.*"], chanaxis=0)
     via_chanaxis = swapaxes(selected, 0, 1)
     direct = select_channels(self.dat, ["ca.*"])
     self.assertEqual(via_chanaxis, direct)
Esempio n. 16
0
 def test_select_channels(self):
     """Selecting channels with an array of regexes."""
     before = self.dat.data.copy()
     self.dat = select_channels(self.dat, ["ca.*", "cc1"])
     kept_columns = np.array([0, 1, -1])
     expected_names = np.array(["ca1", "ca2", "cc1"])
     np.testing.assert_array_equal(self.dat.axes[-1], expected_names)
     np.testing.assert_array_equal(self.dat.data, before[:, kept_columns])
Esempio n. 17
0
 def test_select_channels_inverse(self):
     """Removing channels with an array of regexes."""
     before = self.dat.data.copy()
     self.dat = select_channels(self.dat, ["ca.*", "cc1"], invert=True)
     remaining_columns = np.array([2, 3])
     expected_names = np.array(["cb1", "cb2"])
     np.testing.assert_array_equal(self.dat.axes[-1], expected_names)
     np.testing.assert_array_equal(self.dat.data,
                                   before[:, remaining_columns])
Esempio n. 18
0
def plot_tenten(data, highlights=None, hcolors=None, legend=False, scale=True,
                reg_chans=None):
    """Plots channels on a grid system.

    Iterates over every channel in the data structure. If the
    channelname matches a channel in the tenten-system it will be
    plotted in a grid of rectangles. The grid is structured like the
    tenten-system itself, but in a simplified manner. The rows, in which
    channels appear, are predetermined, the channels are ordered
    automatically within their respective row. Areas to highlight can be
    specified, those areas will be marked with colors in every
    timeinterval plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    highlights : [[int, int)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlight areas (default: None).
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default: False).
    scale : Boolean, optional
        Intended to switch plotting of a scale in the top right corner of
        the grid (default: True). NOTE: currently not used; the scale axes
        is always plotted.
    reg_chans : [regular expressions]
        A list of regular expressions. The plot will be limited to those
        channels matching the regular expressions.

    Returns
    -------
    [Matplotlib.Axes], Matplotlib.Axes
        Returns the plotted timeinterval axes as a list of
        Matplotlib.Axes and the plotted scale as a single
        Matplotlib.Axes.

    Examples
    --------
    Plotting of all channels within a Data object

    >>> plot_tenten(data)

    Plotting of all channels with a highlighted area

    >>> plot_tenten(data, highlights=[[200, 400]])

    Plotting of all channels beginning with 'A'

    >>> plot_tenten(data, reg_chans=['A.*'])
    """
    dcopy = data.copy()
    # this dictionary determines which y-position corresponds with which row in the grid
    ordering = {4.0: 0,
                3.5: 0,
                3.0: 1,
                2.5: 2,
                2.0: 3,
                1.5: 4,
                1.0: 5,
                0.5: 6,
                0.0: 7,
                -0.5: 8,
                -1.0: 9,
                -1.5: 10,
                -2.0: 11,
                -2.5: 12,
                -2.6: 12,
                -3.0: 13,
                -3.5: 14,
                -4.0: 15,
                -4.5: 15,
                -5.0: 16}

    # all the channels with their x- and y-position
    system = dict(CHANNEL_10_20)

    # one (initially empty) list per potential row of channels
    channel_lists = [[] for _ in range(18)]

    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # distribute the channels to the lists by their y-position; entries are
    # (<channel_name>, <x-position>, <position in Data>)
    for count, c in enumerate(dcopy.axes[-1]):
        if c in system:
            channel_lists[ordering[system[c][1]]].append(
                (c, system[c][0], count))

    # sort the lists of channels by their x-position
    for row_channels in channel_lists:
        row_channels.sort(key=lambda entry: entry[1])

    # number of axes needed in every non-empty row
    columns = [len(row_channels) for row_channels in channel_lists
               if len(row_channels) != 0]

    # add another axes to the first row for the scale
    # (raises IndexError if no channel of data is in the tenten-system)
    columns[0] += 1

    plt.figure()
    grid = calc_centered_grid(columns, hpad=.01, vpad=.01)

    # axis used for sharing axes between channels
    masterax = None
    ax = []

    row = 0
    k = 0
    scale_ax = 0

    for row_channels in channel_lists:
        if not row_channels:
            continue
        for i, (name, _, data_index) in enumerate(row_channels):
            ax.append(_subplot_timeinterval(dcopy, grid[k], epoch=-1,
                                            highlights=highlights,
                                            hcolors=hcolors, labels=False,
                                            legend=legend,
                                            channel=data_index,
                                            shareaxis=masterax))
            if masterax is None:
                masterax = ax[0]

            # hide the axes labelling; matplotlib expects booleans here,
            # the strings 'off' are truthy and newer releases no longer
            # translate them to False
            plt.tick_params(axis='both', which='both', labelbottom=False,
                            labeltop=False, labelleft=False,
                            labelright=False, top=False, right=False)

            # at this moment just to show what's what
            plt.gca().annotate(name, (0.05, 0.05), xycoords='axes fraction')

            k += 1

            if row == 0 and i == len(row_channels) - 1:
                # this is the last axes in the first row; reserve the next
                # grid cell for the scale
                scale_ax = k
                k += 1

        row += 1

    # plot the scale axes, labelled with the last value of the first axis
    xtext = dcopy.axes[0][-1]
    # raw string: "\m" is not a valid escape sequence in a normal literal
    sc = _subplot_scale(str(xtext) + ' ms', r"$\mu$V", position=grid[scale_ax])

    return ax, sc
Esempio n. 19
0
def plot_timeinterval(data, r_square=None, highlights=None, hcolors=None,
                      legend=True, reg_chans=None, position=None):
    """Plot all channels of a Data object as one time-interval plot.

    Data with at most two dimensions is plotted as is. Data with more
    dimensions (epoched data) is first averaged over its first axis.
    Optionally an r-square strip is drawn underneath the main plot.

    Parameters
    ----------
    data : wyrm.types.Data
        The data to plot.
    r_square : [values], optional
        r-squared values drawn in a narrow strip below the main plot
        (default: None, no strip).
    highlights : [[int, int)], optional
        Areas to highlight, each given as start point (included) and end
        point (excluded) (default: None).
    hcolors : [colors], optional
        Colors used for the highlighted areas (default: None).
    legend : Boolean, optional
        Whether to draw the legend (default: True).
    reg_chans : [regular expression], optional
        If given, only channels matching these regular expressions are
        plotted (default: None).
    position : [x, y, width, height], optional
        Rectangle the plot is confined to. If None, a new figure is created
        and the default layout is used (default: None).

    Returns
    -------
    Matplotlib.Axes or (Matplotlib.Axes, Matplotlib.Axes)
        The time-interval axes, or a tuple of the time-interval axes and
        the r-square axes when ``r_square`` is given.

    Examples
    --------
    Plots all channels contained in data with a legend.

    >>> plot_timeinterval(data)

    Same as above, but without the legend.

    >>> plot_timeinterval(data, legend=False)

    Adds r-square values to the plot.

    >>> plot_timeinterval(data, r_square=[values])

    Adds a highlighted area to the plot.

    >>> plot_timeinterval(data, highlights=[[200, 400]])

    To specify the colors of the highlighted areas use 'hcolors'.

    >>> plot_timeinterval(data, highlights=[[200, 400]], hcolors=['red'])
    """
    dcopy = data.copy()
    with_r2 = r_square is not None
    # layouts as [x, y, width, height]; the main plot shrinks to make room
    # for the r-square strip when one is requested
    ti_rect = [.07, .12, .9, .85] if with_r2 else [.07, .07, .9, .9]
    r2_rect = [.07, .07, .9, .05]

    pos_r2 = None
    if position is None:
        plt.figure()
        pos_ti = ti_rect
        if with_r2:
            pos_r2 = r2_rect
    else:
        pos_ti = _transform_rect(position, ti_rect)
        if with_r2:
            pos_r2 = _transform_rect(position, r2_rect)

    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # epoched data: average over axis 0 and keep only the last two axes
    # together with their names and units
    if len(data.data.shape) > 2:
        axes, names, units = ([attr[-2], attr[-1]]
                              for attr in (dcopy.axes, dcopy.names,
                                           dcopy.units))
        dcopy = Data(np.mean(dcopy.data, axis=0), axes, names, units)

    ax_ti = _subplot_timeinterval(dcopy, position=pos_ti, epoch=-1,
                                  highlights=highlights, hcolors=hcolors,
                                  legend=legend)
    ax_ti.xaxis.labelpad = 0

    ax_r2 = None
    if with_r2:
        ax_r2 = _subplot_r_square(r_square, position=pos_r2)
        # x tick padding is scaled by the height of the main-plot rect
        ax_ti.tick_params(axis='x', direction='in', pad=30 * pos_ti[3])

    # NOTE(review): plt.grid acts on the current axes, presumably the last
    # one created by the helpers above - confirm this is intended.
    plt.grid(True)

    return (ax_ti, ax_r2) if with_r2 else ax_ti
 def test_select_channels_swapaxis(self):
     """Select channels works with non default chanaxis."""
     channels_first = swapaxes(self.dat, 0, 1)
     selected = select_channels(channels_first, ['ca.*'], chanaxis=0)
     via_chanaxis = swapaxes(selected, 0, 1)
     direct = select_channels(self.dat, ['ca.*'])
     self.assertEqual(via_chanaxis, direct)
Esempio n. 21
0
def plot_tenten(data,
                highlights=None,
                hcolors=None,
                legend=False,
                scale=True,
                reg_chans=None):
    """Plots channels on a grid system.

    Iterates over every channel in the data structure. If the
    channelname matches a channel in the tenten-system it will be
    plotted in a grid of rectangles. The grid is structured like the
    tenten-system itself, but in a simplified manner. The rows, in which
    channels appear, are predetermined, the channels are ordered
    automatically within their respective row. Areas to highlight can be
    specified, those areas will be marked with colors in every
    timeinterval plot.

    Parameters
    ----------
    data : wyrm.types.Data
        Data object containing the data to plot.
    highlights : [[int, int)]
        List of tuples containing the start point (included) and end
        point (excluded) of each area to be highlighted (default: None).
    hcolors : [colors], optional
        A list of colors to use for the highlight areas (default: None).
    legend : Boolean, optional
        Flag to switch plotting of the legend on or off (default: False).
    scale : Boolean, optional
        Intended to switch plotting of a scale in the top right corner of
        the grid (default: True). NOTE: currently not used; the scale axes
        is always plotted.
    reg_chans : [regular expressions]
        A list of regular expressions. The plot will be limited to those
        channels matching the regular expressions.

    Returns
    -------
    [Matplotlib.Axes], Matplotlib.Axes
        Returns the plotted timeinterval axes as a list of
        Matplotlib.Axes and the plotted scale as a single
        Matplotlib.Axes.

    Examples
    --------
    Plotting of all channels within a Data object

    >>> plot_tenten(data)

    Plotting of all channels with a highlighted area

    >>> plot_tenten(data, highlights=[[200, 400]])

    Plotting of all channels beginning with 'A'

    >>> plot_tenten(data, reg_chans=['A.*'])
    """
    dcopy = data.copy()
    # this dictionary determines which y-position corresponds with which row in the grid
    ordering = {
        4.0: 0,
        3.5: 0,
        3.0: 1,
        2.5: 2,
        2.0: 3,
        1.5: 4,
        1.0: 5,
        0.5: 6,
        0.0: 7,
        -0.5: 8,
        -1.0: 9,
        -1.5: 10,
        -2.0: 11,
        -2.5: 12,
        -2.6: 12,
        -3.0: 13,
        -3.5: 14,
        -4.0: 15,
        -4.5: 15,
        -5.0: 16
    }

    # all the channels with their x- and y-position
    system = dict(CHANNEL_10_20)

    # one (initially empty) list per potential row of channels
    channel_lists = [[] for _ in range(18)]

    if reg_chans is not None:
        dcopy = proc.select_channels(dcopy, reg_chans)

    # distribute the channels to the lists by their y-position; entries are
    # (<channel_name>, <x-position>, <position in Data>)
    for count, c in enumerate(dcopy.axes[-1]):
        if c in system:
            channel_lists[ordering[system[c][1]]].append(
                (c, system[c][0], count))

    # sort the lists of channels by their x-position
    for row_channels in channel_lists:
        row_channels.sort(key=lambda entry: entry[1])

    # number of axes needed in every non-empty row
    columns = [len(row_channels) for row_channels in channel_lists
               if len(row_channels) != 0]

    # add another axes to the first row for the scale
    # (raises IndexError if no channel of data is in the tenten-system)
    columns[0] += 1

    fig = plt.figure()
    grid = calc_centered_grid(columns, hpad=.01, vpad=.01)

    # axis used for sharing axes between channels
    masterax = None
    ax = []

    row = 0
    k = 0
    scale_ax = 0

    for row_channels in channel_lists:
        if not row_channels:
            continue
        for i, (name, _, data_index) in enumerate(row_channels):
            ax.append(
                _subplot_timeinterval(dcopy,
                                      grid[k],
                                      epoch=-1,
                                      highlights=highlights,
                                      hcolors=hcolors,
                                      labels=False,
                                      legend=legend,
                                      channel=data_index,
                                      shareaxis=masterax))
            if masterax is None:
                masterax = ax[0]

            # hide the axes labelling; matplotlib expects booleans here,
            # the strings 'off' are truthy and newer releases no longer
            # translate them to False
            plt.tick_params(axis='both',
                            which='both',
                            labelbottom=False,
                            labeltop=False,
                            labelleft=False,
                            labelright=False,
                            top=False,
                            right=False)

            # at this moment just to show what's what
            plt.gca().annotate(name, (0.05, 0.05),
                               xycoords='axes fraction')

            k += 1

            if row == 0 and i == len(row_channels) - 1:
                # this is the last axes in the first row; reserve the next
                # grid cell for the scale
                scale_ax = k
                k += 1

        row += 1

    # plot the scale axes, labelled with the last value of the first axis
    xtext = dcopy.axes[0][-1]
    # raw string: "\m" is not a valid escape sequence in a normal literal
    sc = _subplot_scale(fig,
                        str(xtext) + ' ms',
                        r"$\mu$V",
                        position=grid[scale_ax])

    return ax, sc
Esempio n. 22
0
 def select_sensors(self):
     """Keep only the requested sensors in the signal processor's data.

     Unless ``self.sensor_names`` is None or the string ``"all"``, the
     continuous data ``self.signal_processor.cnt`` is restricted to
     ``self.sensor_names`` via ``select_channels``. Afterwards
     ``self.sensor_names`` holds the remaining channel names
     (``cnt.axes[-1]``).
     """
     # Compare to "all" by value. `is not "all"` tests object identity,
     # which depends on string interning (and is a SyntaxWarning since
     # Python 3.8). The isinstance guard avoids an element-wise comparison
     # when sensor_names is array-like, e.g. the axes[-1] stored below by
     # an earlier call.
     wants_all = (isinstance(self.sensor_names, str)
                  and self.sensor_names == "all")
     if self.sensor_names is not None and not wants_all:
         self.signal_processor.cnt = select_channels(
             self.signal_processor.cnt, self.sensor_names)
     self.sensor_names = self.signal_processor.cnt.axes[-1]
Esempio n. 23
0
def run(ex, subject_id, with_breaks, min_freq, only_return_exp):
    """Build, train and evaluate a windowed CSP experiment for one subject.

    Creates a ``TwoFileCSPExperiment`` from the subject's train and test
    ``*.BBCI.mat`` files (folders from ``get_subject_config``). Unless
    ``only_return_exp`` is true, it then:

    - loads both sets and drops the Cz channel;
    - optionally adds break markers, then sets windowed markers;
    - cleans, preprocesses and trains;
    - stores runtime and misclassification rates in ``ex.info``;
    - saves the result as artifact ``csp_result.pkl``.

    Parameters
    ----------
    ex :
        Experiment object with an ``info`` mapping; also passed to
        ``save_pkl_artifact``.
    subject_id :
        Passed to ``get_subject_config``.
    with_breaks : bool
        If true, a '5 - Break' class with break start/stop markers is added.
    min_freq :
        Lowest filter-bank frequency passed to the experiment.
    only_return_exp : bool
        If true, return the experiment before any data is loaded.

    Returns
    -------
    The experiment if ``only_return_exp`` is true, otherwise None.

    Raises
    ------
    AssertionError
        If a set does not have 63 channels after removing Cz, or if the
        experiment does not report exactly one train/test accuracy.
    """
    start_time = time.time()
    ex.info['finished'] = False

    window_len = 2000
    window_stride = 500
    marker_def = {
        '1- Right Hand': [1],
        '2 - Feet': [4],
        '3 - Rotation': [8],
        '4 - Words': [10]
    }
    segment_ival = [0, window_len]
    n_selected_features = 20

    all_start_marker_vals = [1, 4, 8, 10]
    all_end_marker_vals, train_folders, test_folders = get_subject_config(
        subject_id)
    # Work on a copy: the break end marker may be appended below, and the
    # list returned by get_subject_config should not be modified in place
    # (it may be shared between calls - TODO confirm).
    all_end_marker_vals = list(all_end_marker_vals)

    if with_breaks:
        min_break_length_ms = 6000
        max_break_length_ms = 8000
        break_start_offset_ms = 1000
        break_stop_offset_ms = -500
        break_start_marker = 300
        break_end_marker = 301
        all_start_marker_vals.append(break_start_marker)
        all_end_marker_vals.append(break_end_marker)
        marker_def['5 - Break'] = [break_start_marker]

    # files are sorted within each folder; folders keep their given order
    train_files = list(itertools.chain.from_iterable(
        sorted(glob(os.path.join(folder, '*.BBCI.mat')))
        for folder in train_folders))
    test_files = list(itertools.chain.from_iterable(
        sorted(glob(os.path.join(folder, '*.BBCI.mat')))
        for folder in test_folders))
    train_set = MultipleBBCIDataset(train_files)
    test_set = MultipleBBCIDataset(test_files)

    csp_exp = TwoFileCSPExperiment(train_set,
                                   test_set,
                                   NoCleaner(marker_def=marker_def,
                                             segment_ival=segment_ival),
                                   NoCleaner(marker_def=marker_def,
                                             segment_ival=segment_ival),
                                   resample_fs=250,
                                   standardize_cnt=False,
                                   min_freq=min_freq,
                                   max_freq=34,
                                   last_low_freq=10,
                                   low_width=6,
                                   low_overlap=3,
                                   high_overlap=4,
                                   high_width=8,
                                   filt_order=3,
                                   standardize_filt_cnt=False,
                                   # same window as the cleaners use
                                   segment_ival=[0, window_len],
                                   standardize_epo=False,
                                   n_folds=None,
                                   n_top_bottom_csp_filters=5,
                                   n_selected_filterbands=None,
                                   forward_steps=2,
                                   backward_steps=1,
                                   stop_when_no_improvement=False,
                                   n_selected_features=n_selected_features,
                                   only_last_fold=True,
                                   restricted_n_trials=None,
                                   common_average_reference=False,
                                   ival_optimizer=None,
                                   shuffle=False,
                                   marker_def=marker_def,
                                   set_cz_to_zero=False,
                                   low_bound=0.)
    if only_return_exp:
        return csp_exp
    log.info("Loading train set...")
    csp_exp.load_bbci_set()
    log.info("Loading test set...")
    csp_exp.load_bbci_test_set()
    csp_exp.cnt = select_channels(csp_exp.cnt, ['Cz'], invert=True)
    assert len(csp_exp.cnt.axes[1]) == 63
    csp_exp.test_cnt = select_channels(csp_exp.test_cnt, ['Cz'], invert=True)
    assert len(csp_exp.test_cnt.axes[1]) == 63

    def _set_markers(cnt):
        # Identical marker setup for train and test: optional break
        # start/stop markers first, then the windowed markers.
        if with_breaks:
            add_break_start_stop_markers(cnt, all_start_marker_vals,
                                         all_end_marker_vals,
                                         min_break_length_ms,
                                         max_break_length_ms,
                                         break_start_offset_ms,
                                         break_stop_offset_ms,
                                         break_start_marker,
                                         break_end_marker)
        set_windowed_markers(
            cnt,
            all_start_marker_vals,
            all_end_marker_vals,
            window_len,
            window_stride,
        )

    _set_markers(csp_exp.cnt)
    _set_markers(csp_exp.test_cnt)

    log.info("Cleaning both sets...")
    csp_exp.clean_both_sets()
    log.info("Preprocessing train set...")
    csp_exp.preprocess_set()
    log.info("Preprocessing test set...")
    csp_exp.preprocess_test_set()
    csp_exp.remember_sensor_names()
    csp_exp.init_training_vars()
    log.info("Running Training...")
    csp_exp.run_training()
    end_time = time.time()
    run_time = end_time - start_time

    ex.info['finished'] = True
    result = CSPResult(csp_trainer=csp_exp,
                       parameters={},
                       training_time=run_time)
    assert len(csp_exp.multi_class.test_accuracy) == 1
    assert len(csp_exp.multi_class.train_accuracy) == 1
    ex.info['train_misclass'] = 1 - csp_exp.multi_class.train_accuracy[0]
    ex.info['test_misclass'] = 1 - csp_exp.multi_class.test_accuracy[0]
    ex.info['runtime'] = run_time
    save_pkl_artifact(ex, result, 'csp_result.pkl')
Esempio n. 24
0
 def select_sensors(self):
     """Keep only the requested sensors in the signal processor's data.

     Unless ``self.sensor_names`` is None or the string ``'all'``, the
     continuous data ``self.signal_processor.cnt`` is restricted to
     ``self.sensor_names`` via ``select_channels``. Afterwards
     ``self.sensor_names`` holds the remaining channel names
     (``cnt.axes[-1]``).
     """
     # Compare to 'all' by value. `is not 'all'` tests object identity,
     # which depends on string interning (and is a SyntaxWarning since
     # Python 3.8). The isinstance guard avoids an element-wise comparison
     # when sensor_names is array-like, e.g. the axes[-1] stored below by
     # an earlier call.
     wants_all = (isinstance(self.sensor_names, str)
                  and self.sensor_names == 'all')
     if self.sensor_names is not None and not wants_all:
         self.signal_processor.cnt = select_channels(
             self.signal_processor.cnt,
             self.sensor_names)
     self.sensor_names = self.signal_processor.cnt.axes[-1]