コード例 #1
0
def normalize_signals(signals, normalization=None, axis=None, percent=None):
    """Apply one or more normalizations to ``signals``, in order, and return the result.

    Parameters
    ----------
    signals : numpy.ndarray
        The data to normalize.
    normalization : str or sequence of str, optional
        Normalization method name(s), applied in the given order. Handled names:
        "zscore", "mean" (demean), "min" (subtract the minimum), "max" (divide by
        the maximum), "minmax" ("min" then "max"), "std" (divide by the standard
        deviation), "baseline" (subtract a low percentile), "baseline-std",
        "baseline-<...amplitude>", and any other name containing "amplitude"
        (divide by the distance between two percentiles; a "max"/"mean" prefix,
        e.g. "maxamplitude", collapses that amplitude to one scalar by its
        max/mean). Non-string entries (such as the default ``None``) are skipped.
    axis : int, None or sequence, optional
        Axis for each normalization; cycled to match ``normalization``.
    percent : optional
        Percentile(s) for each normalization; cycled to match ``normalization``.
        "baseline" defaults to 1 and the amplitude methods default to ``[1, 99]``.

    Returns
    -------
    numpy.ndarray
        The normalized signals.

    Raises
    ------
    Whatever ``raise_value_error`` raises (presumably ``ValueError`` — confirm)
    for an unrecognised normalization name.
    """
    # Following pylab demean:

    def _broadcastable(x, y, axis):
        """Reshape ``y`` (a statistic of ``x`` reduced along ``axis``) so it broadcasts against ``x``."""
        if axis == 0 or axis is None or x.ndim <= 1:
            # Reduction over axis 0 (or to a scalar) already aligns by broadcasting.
            return y
        ind = [slice(None)] * x.ndim
        ind[axis] = np.newaxis
        # Multidimensional indexing must use a tuple: indexing with a list of
        # slices/newaxis is rejected by current NumPy versions.
        return y[tuple(ind)]

    def matrix_subtract_along_axis(x, y, axis=0):
        "Return x minus y, where y corresponds to some statistic of x along the specified axis"
        return x - _broadcastable(x, y, axis)

    def matrix_divide_along_axis(x, y, axis=0):
        "Return x divided by y, where y corresponds to some statistic of x along the specified axis"
        return x / _broadcastable(x, y, axis)

    for norm, ax, prcnd in zip(ensure_list(normalization), cycle(ensure_list(axis)), cycle(ensure_list(percent))):
        if isinstance(norm, string_types):
            if isequal_string(norm, "zscore"):
                signals = zscore(signals, axis=ax)  # / 3.0
            elif isequal_string(norm, "baseline-std"):
                # Composite normalizations forward the full axis/percent arguments,
                # which then cycle over the sub-steps (as in the branches below).
                signals = normalize_signals(signals, ["baseline", "std"], axis=axis, percent=percent)
            elif norm.find("baseline") == 0 and norm.find("amplitude") >= 0:
                # e.g. "baseline-maxamplitude" -> ["baseline", "maxamplitude"]
                signals = normalize_signals(signals, ["baseline", norm.split("-")[1]], axis=axis, percent=percent)
            elif isequal_string(norm, "minmax"):
                signals = normalize_signals(signals, ["min", "max"], axis=axis)
            elif isequal_string(norm, "mean"):
                signals = demean(signals, axis=ax)
            elif isequal_string(norm, "baseline"):
                if prcnd is None:
                    prcnd = 1
                signals = matrix_subtract_along_axis(signals, np.percentile(signals, prcnd, axis=ax), axis=ax)
            elif isequal_string(norm, "min"):
                signals = matrix_subtract_along_axis(signals, np.min(signals, axis=ax), axis=ax)
            elif isequal_string(norm, "max"):
                signals = matrix_divide_along_axis(signals, np.max(signals, axis=ax), axis=ax)
            elif isequal_string(norm, "std"):
                signals = matrix_divide_along_axis(signals, signals.std(axis=ax), axis=ax)
            elif norm.find("amplitude") >= 0:
                if prcnd is None:
                    prcnd = [1, 99]
                amplitude = np.percentile(signals, prcnd[1], axis=ax) - np.percentile(signals, prcnd[0], axis=ax)
                this_ax = ax
                prefix = norm.split("amplitude")[0]
                if isequal_string(prefix, "max"):
                    # Collapse to a single scalar: divide everything by one value.
                    amplitude = amplitude.max()
                    this_ax = None
                elif isequal_string(prefix, "mean"):
                    amplitude = amplitude.mean()
                    this_ax = None
                signals = matrix_divide_along_axis(signals, amplitude, axis=this_ax)
            else:
                # Report the offending entry; ``normalization`` itself may be a list
                # and could not be concatenated to a str.
                raise_value_error("Ignoring signals' normalization " + norm +
                                  ",\nwhich is not one of the currently available " + str(NORMALIZATION_METHODS) + "!",
                                  logger)
    return signals
コード例 #2
0
    def plot_ts(self,
                data,
                time=None,
                var_labels=None,
                mode="ts",
                subplots=None,
                special_idx=None,
                subtitles=None,
                labels=None,
                offset=0.5,
                time_unit="ms",
                title='Time series',
                figure_name=None,
                figsize=None):
        """Plot one or more time-series variables in "ts", "raster" or "traj" mode.

        Parameters
        ----------
        data : dict, numpy.ndarray, list or tuple
            The time series. A dict maps variable label -> data, and its keys
            replace ``var_labels``. A numpy array with fewer than 3 dimensions is
            expanded to 3-D and treated as a single variable; an array with 3 or
            more dimensions is split along axis 2 into one (squeezed) entry per
            variable. A list/tuple is taken as one entry per variable.
        time : forwarded to the plot-mode helper, which may replace it.
        var_labels : variable labels; also the default for ``subtitles``.
        mode : "ts" (default), "raster" or "traj".
        subplots : forwarded to the plot-mode helpers.
        special_idx : indices of time series to highlight (alpha and colour).
        subtitles : per-column titles; defaults to ``var_labels``.
        labels : a list with one entry per variable, or otherwise a value passed
            unchanged for every variable; expanded via ``generate_region_labels``.
        offset : forwarded to the raster helper.
        time_unit : forwarded to the time-series helper.
        title : figure title.
        figure_name : forwarded to ``self._save_figure``.
        figsize : figure size; ``self.config.LARGE_SIZE`` when not a list/tuple.

        Returns
        -------
        tuple
            ``(figure, axes, lines)``: the current pyplot figure, the created
            axes and the plotted line objects.
        """
        # ``None`` defaults (instead of shared mutable ``[]`` defaults) are
        # normalised here, so no list object is shared between calls.
        var_labels = [] if var_labels is None else var_labels
        special_idx = [] if special_idx is None else special_idx
        subtitles = [] if subtitles is None else subtitles
        labels = [] if labels is None else labels
        if not isinstance(figsize, (list, tuple)):
            figsize = self.config.LARGE_SIZE
        if isinstance(data, dict):
            # Materialise the dict views: ``data`` is indexed and assigned below,
            # and ``subtitles`` may default to ``var_labels`` and be indexed too.
            var_labels = list(data.keys())
            data = list(data.values())
        elif isinstance(data, numpy.ndarray):
            if len(data.shape) < 3:
                if len(data.shape) < 2:
                    data = numpy.expand_dims(data, 1)
                data = numpy.expand_dims(data, 2)
                data = [data]
            else:
                # Assuming a structure of Time X Space X Variables X Samples
                data = [data[:, :, iv].squeeze() for iv in range(data.shape[2])]
        elif isinstance(data, (list, tuple)):
            data = ensure_list(data)
        else:
            LOG.error("Input timeseries: %s \n"
                      "is not on of one of the following types: "
                      "[numpy.ndarray, dict, list, tuple]" % str(data))
        is_traj = isequal_string(mode, "traj")
        is_raster = isequal_string(mode, "raster")
        n_vars = len(data)
        data_lims = []
        for i_var, d in enumerate(data):
            if is_raster:
                # Subtract the mean along axis 0, then divide by the largest
                # max-min range found along axis 0.
                data[i_var] = d - d.mean(axis=0)
                drange = numpy.max(data[i_var].max(axis=0) - data[i_var].min(axis=0))
                data[i_var] = data[i_var] / drange  # zscore(d, axis=None)
            # NOTE(review): limits are taken from the input ``d``, i.e. before the
            # raster rescaling above — confirm this is intended.
            data_lims.append([d.min(), d.max()])
        data_shape = data[0].shape
        if len(data_shape) == 1:
            n_times = data_shape[0]
            nTS = 1
            for iV in range(n_vars):
                data[iV] = data[iV][:, numpy.newaxis]
        else:
            n_times, nTS = data_shape[:2]
        nSamples = data_shape[2] if len(data_shape) > 2 else 1
        n_special_idx = len(special_idx)
        if len(subtitles) == 0:
            subtitles = var_labels
        if isinstance(labels, list) and len(labels) == n_vars:
            labels = [
                generate_region_labels(nTS, label, ". ", self.print_ts_indices)
                for label in labels
            ]
        else:
            labels = [
                generate_region_labels(nTS, labels, ". ", self.print_ts_indices)
                for _ in range(n_vars)
            ]
        if is_traj:
            (data_fun, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun,
             subtitle, subtitle_col, axlabels, axlimits) = \
                self._trajectories_plot(n_vars, nTS, nSamples, subplots)
        elif is_raster:
            (data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun,
             subtitle, subtitle_col, axlabels, axlimits, axYticks) = \
                self._ts_plot(time, n_vars, nTS, n_times, time_unit, 0, offset, data_lims)
        else:
            (data_fun, time, plot_lines, projection, n_rows, n_cols, def_alpha, loopfun,
             subtitle, subtitle_col, axlabels, axlimits, axYticks) = \
                self._ts_plot(time, n_vars, nTS, n_times, time_unit, ensure_list(subplots)[0])
        alpha_ratio = 1.0 / nSamples
        alphas = numpy.maximum(numpy.array([def_alpha] * nTS) * alpha_ratio, 0.1)
        alphas[special_idx] = numpy.maximum(alpha_ratio, 0.1)
        if is_traj and (n_cols * n_rows > 1):
            colors = numpy.zeros((nTS, 4))
            # Highlighted series are drawn in opaque red.
            colors[special_idx] = numpy.tile([1.0, 0, 0, 1.0], (n_special_idx, 1))
        else:
            # pyplot.get_cmap exists across matplotlib versions, whereas
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
            cmap = pyplot.get_cmap('jet')
            colors = numpy.array([cmap(0.5 * iTS / nTS) for iTS in range(nTS)])
            colors[special_idx] = \
                numpy.array([cmap(1.0 - 0.25 * iTS / nTS) for iTS in range(n_special_idx)]).reshape((n_special_idx, 4))
        colors[:, 3] = alphas
        if is_raster:
            # Shrink tick labels when there are more than 100 series. Kept local:
            # writing it back to self.tick_font_size shrank it again on every
            # column and on every subsequent call.
            tick_font_size = numpy.minimum(
                self.tick_font_size,
                int(numpy.round(self.tick_font_size * 100.0 / nTS)))
        lines = []
        pyplot.figure(title, figsize=figsize)
        axes = []
        for icol in range(n_cols):
            col_labels = labels[icol % n_vars]
            if n_rows == 1:
                # If there are no more rows, create axis, and set its limits, labels and possible subtitle
                axes += ensure_list(pyplot.subplot(n_rows, n_cols, icol + 1, projection=projection))
                axlimits(data_lims, time, n_vars, icol)
                axlabels(col_labels, var_labels, n_vars, n_rows, 1, 0)
                # Only set a title when one exists for this column (``subtitles``
                # defaults to ``var_labels``, which is empty for array input).
                if icol < len(subtitles):
                    pyplot.gca().set_title(subtitles[icol])
            for iTS in loopfun(nTS, n_rows, icol):
                if n_rows > 1:
                    # If there are more rows, create axes, and set their limits, labels and possible subtitles
                    axes += ensure_list(pyplot.subplot(n_rows, n_cols, iTS + 1, projection=projection))
                    axlimits(data_lims, time, n_vars, icol)
                    subtitle(col_labels, iTS)
                    axlabels(col_labels, var_labels, n_vars, n_rows, (iTS % n_rows) + 1, iTS)
                lines += ensure_list(plot_lines(data_fun(data, time, icol), iTS, colors, col_labels))
            if is_raster:
                # Set yticks as labels if this is a raster plot
                axYticks(col_labels, nTS)
                yticklabels = pyplot.gca().yaxis.get_ticklabels()
                for iTS in range(nTS):
                    yticklabels[iTS].set_fontsize(tick_font_size)
                    if iTS in special_idx:
                        yticklabels[iTS].set_color(colors[iTS, :3].tolist() + [1])
                pyplot.gca().yaxis.set_ticklabels(yticklabels)
                pyplot.gca().invert_yaxis()

        # Presumably attaches an interactive hover cursor to each line when enabled
        # in the config — confirm against HighlightingDataCursor.
        if self.config.MOUSE_HOOVER:
            for line in lines:
                self.HighlightingDataCursor(line,
                                            formatter='{label}'.format,
                                            bbox=dict(fc='white'),
                                            arrowprops=dict(arrowstyle='simple', fc='white', alpha=0.5))

        self._save_figure(pyplot.gcf(), figure_name)
        self._check_show()
        return pyplot.gcf(), axes, lines