示例#1
0
def plot_cluster_h5_group(h5_group, labels_kwargs=None, centroids_kwargs=None):
    """
    Plots the cluster labels and mean response for each cluster

    Parameters
    ----------
    h5_group : h5py.Datagroup object
        H5 group containing the 'Labels' and 'Mean_Response' datasets
    labels_kwargs : dict, optional
        keyword arguments for the labels plot. NOT enabled yet.
    centroids_kwargs : dict, optional
        keyword arguments for the centroids plot. NOT enabled yet.

    Returns
    -------
    fig_labels : figure handle or None
        Figure containing the labels. None if the squeezed labels are neither 1D nor 2D.
    fig_centroids : figure handle or None
        Figure containing the centroids. None if the centroids have neither 1 nor 2 spectroscopic dimensions.

    Raises
    ------
    TypeError
        If h5_group is not a h5py.Group
    """
    if not isinstance(h5_group, h5py.Group):
        raise TypeError('h5_group should be a h5py.Group')
    h5_labels = USIDataset(h5_group['Labels'])
    h5_centroids = USIDataset(h5_group['Mean_Response'])

    # Start both handles as None so unsupported dimensionalities no longer end
    # in an UnboundLocalError at the return statement
    fig_labs = None
    fig_cent = None

    labels_mat = np.squeeze(h5_labels.get_n_dim_form())
    if labels_mat.ndim == 1:
        # NOTE(review): assumes the first position dimension is the one that
        # survived np.squeeze - confirm for datasets with singleton dimensions
        fig_labs, axis_labs = plt.subplots(figsize=(5.5, 5))
        axis_labs.plot(h5_labels.get_pos_values(h5_labels.pos_dim_labels[0]), labels_mat)
        axis_labs.set_xlabel(h5_labels.pos_dim_descriptors[0])
        axis_labs.set_ylabel('Cluster index')
        axis_labs.set_title(get_attr(h5_group, 'cluster_algorithm') + ' Labels')
    elif labels_mat.ndim == 2:
        fig_labs, _ = plot_cluster_labels(labels_mat, num_clusters=h5_centroids.shape[0],
                                          x_label=h5_labels.pos_dim_descriptors[0],
                                          y_label=h5_labels.pos_dim_descriptors[1],
                                          x_vec=h5_labels.get_pos_values(h5_labels.pos_dim_labels[0]),
                                          y_vec=h5_labels.get_pos_values(h5_labels.pos_dim_labels[1]),
                                          title=get_attr(h5_group, 'cluster_algorithm') + ' Labels')
    else:
        # Previously 0D and 3D labels fell through silently; report every unsupported case
        print('Unable to visualize {}-dimensional labels!'.format(labels_mat.ndim))

    # TODO: probably not a great idea to load the entire dataset to memory
    centroids_mat = h5_centroids.get_n_dim_form()
    num_spec_dims = len(h5_centroids.spec_dim_labels)
    if num_spec_dims == 1:
        # Fewer than 6 clusters: overlay all curves with a legend; otherwise plot them separately
        few_clusters = h5_centroids.shape[0] < 6
        fig_cent, _ = plot_cluster_centroids(centroids_mat,
                                             h5_centroids.get_spec_values(h5_centroids.spec_dim_labels[0]),
                                             legend_mode=1 if few_clusters else 2,
                                             x_label=h5_centroids.spec_dim_descriptors[0],
                                             y_label=h5_centroids.data_descriptor,
                                             overlayed=few_clusters,
                                             title=get_attr(h5_group, 'cluster_algorithm') + ' Centroid',
                                             amp_units=get_attr(h5_centroids, 'units'))
    elif num_spec_dims == 2:
        # stack of spectrograms
        # np.complex was removed in NumPy 1.24; issubdtype covers every complex dtype
        if np.issubdtype(h5_centroids.dtype, np.complexfloating):
            fig_cent, _ = plot_complex_spectra(centroids_mat, subtitle_prefix='Cluster',
                                               title=get_attr(h5_group, 'cluster_algorithm') + ' Centroid',
                                               x_label=h5_centroids.spec_dim_descriptors[0],
                                               y_label=h5_centroids.spec_dim_descriptors[1],
                                               amp_units=get_attr(h5_centroids, 'units'))
        else:
            fig_cent, _ = plot_map_stack(centroids_mat, color_bar_mode='each', evenly_spaced=True,
                                         title='Cluster',
                                         heading=get_attr(h5_group, 'cluster_algorithm') + ' Centroid')
    else:
        print('Unable to visualize centroids with {} spectroscopic dimensions!'.format(num_spec_dims))
    return fig_labs, fig_cent
示例#2
0
def plot_cluster_h5_group(h5_group, labels_kwargs=None, centroids_kwargs=None):
    """
    Plots the cluster labels and mean response for each cluster

    Parameters
    ----------
    h5_group : h5py.Datagroup object
        H5 group containing the 'Labels' and 'Mean_Response' datasets
    labels_kwargs : dict, optional
        keyword arguments for the labels plot. NOT enabled yet.
    centroids_kwargs : dict, optional
        keyword arguments for the centroids plot. NOT enabled yet.

    Returns
    -------
    fig_labels : figure handle or None
        Figure containing the labels. None if the squeezed labels are
        neither 1D nor 2D.
    fig_centroids : figure handle or None
        Figure containing the centroids. None if the centroids have neither
        1 nor 2 spectroscopic dimensions.

    Raises
    ------
    TypeError
        If h5_group is not a h5py.Group
    """
    if not isinstance(h5_group, h5py.Group):
        raise TypeError('h5_group should be a h5py.Group')
    h5_labels = USIDataset(h5_group['Labels'])
    h5_centroids = USIDataset(h5_group['Mean_Response'])

    # Start both handles as None so unsupported dimensionalities no longer
    # end in an UnboundLocalError at the return statement
    fig_labs = None
    fig_cent = None

    labels_mat = np.squeeze(h5_labels.get_n_dim_form())
    if labels_mat.ndim == 1:
        # NOTE(review): assumes the first position dimension is the one that
        # survived np.squeeze - confirm for datasets with singleton dimensions
        fig_labs, axis_labs = plt.subplots(figsize=(5.5, 5))
        axis_labs.plot(h5_labels.get_pos_values(h5_labels.pos_dim_labels[0]),
                       labels_mat)
        axis_labs.set_xlabel(h5_labels.pos_dim_descriptors[0])
        axis_labs.set_ylabel('Cluster index')
        axis_labs.set_title(
            get_attr(h5_group, 'cluster_algorithm') + ' Labels')
    elif labels_mat.ndim == 2:
        fig_labs, _ = plot_cluster_labels(
            labels_mat,
            num_clusters=h5_centroids.shape[0],
            x_label=h5_labels.pos_dim_descriptors[0],
            y_label=h5_labels.pos_dim_descriptors[1],
            x_vec=h5_labels.get_pos_values(h5_labels.pos_dim_labels[0]),
            y_vec=h5_labels.get_pos_values(h5_labels.pos_dim_labels[1]),
            title=get_attr(h5_group, 'cluster_algorithm') + ' Labels')
    else:
        # Previously 0D and 3D labels fell through silently; report every
        # unsupported case
        print('Unable to visualize {}-dimensional labels!'.format(
            labels_mat.ndim))

    # TODO: probably not a great idea to load the entire dataset to memory
    centroids_mat = h5_centroids.get_n_dim_form()
    num_spec_dims = len(h5_centroids.spec_dim_labels)
    if num_spec_dims == 1:
        # Fewer than 6 clusters: overlay all curves with a legend; otherwise
        # plot them separately
        few_clusters = h5_centroids.shape[0] < 6
        fig_cent, _ = plot_cluster_centroids(
            centroids_mat,
            h5_centroids.get_spec_values(h5_centroids.spec_dim_labels[0]),
            legend_mode=1 if few_clusters else 2,
            x_label=h5_centroids.spec_dim_descriptors[0],
            y_label=h5_centroids.data_descriptor,
            overlayed=few_clusters,
            title=get_attr(h5_group, 'cluster_algorithm') + ' Centroid',
            amp_units=get_attr(h5_centroids, 'units'))
    elif num_spec_dims == 2:
        # stack of spectrograms
        # np.complex was removed in NumPy 1.24; issubdtype covers every
        # complex dtype
        if np.issubdtype(h5_centroids.dtype, np.complexfloating):
            fig_cent, _ = plot_complex_spectra(
                centroids_mat,
                subtitle_prefix='Cluster',
                title=get_attr(h5_group, 'cluster_algorithm') + ' Centroid',
                x_label=h5_centroids.spec_dim_descriptors[0],
                y_label=h5_centroids.spec_dim_descriptors[1],
                amp_units=get_attr(h5_centroids, 'units'))
        else:
            fig_cent, _ = plot_map_stack(
                centroids_mat,
                color_bar_mode='each',
                evenly_spaced=True,
                title='Cluster',
                heading=get_attr(h5_group, 'cluster_algorithm') + ' Centroid')
    else:
        print('Unable to visualize centroids with {} spectroscopic '
              'dimensions!'.format(num_spec_dims))
    return fig_labs, fig_cent
示例#3
0
def plot_cluster_centroids(centroids, x_vec, legend_mode=1, x_label=None, y_label=None, title=None, axis=None,
                           overlayed=True, amp_units=None, **kwargs):
    """
    Plots the centroids of clusters as curves against x_vec

    Parameters
    ----------
    centroids : numpy.ndarray
        2D array. Centroids of clusters - one curve per row
    x_vec : array-like
        1D array. Vector against which the curves are plotted
    legend_mode : int, optional. default = 1
        Appearance of legend (only used when overlayed is True):
            0 - inside the plot
            1 - outside the plot on the right
            else - colorbar instead of legend
    x_label : str, optional, default = None
        Label for x axis
    y_label : str, optional, default = None
        Label for y axis
    title : str, optional, default = None
        Title for the plot
    axis : matplotlib.axes.Axes object, optional.
        Axis to plot this image onto. Will create a new figure by default or will use this axis object to plot into.
        Only honored for overlayed, real-valued centroids; complex centroids always get a new two-panel figure.
    overlayed : bool, optional
        If True - all curves will be plotted overlayed on a single plot. Else, curves will be plotted separately
    amp_units : str, optional
        Units for amplitude. Defaults to 'a.u.'
    kwargs : dict
        'cmap' colors the overlaid curves; 'figsize' sizes the new overlay figure.
        When overlayed is False, all kwargs are passed on to plot_complex_spectra / plot_curves

    Returns
    -------
    fig, axes
        Overlayed real-valued centroids: the figure (None if `axis` was supplied) and the axis plotted into.
        Overlayed complex centroids: the new figure and its lower (phase) axis.
        Otherwise: whatever plot_complex_spectra / plot_curves return.

    Raises
    ------
    TypeError
        If an argument has the wrong type
    ValueError
        If centroids is not 2D or x_vec is not 1D
    """
    if isinstance(centroids, (list, tuple)):
        centroids = np.array(centroids)
    if not isinstance(centroids, np.ndarray):
        raise TypeError('centroids should be a numpy array')
    if centroids.ndim != 2:
        raise ValueError('centroids should be a 2D numpy array - i.e. - 1D spectra')
    # Accept any array-like. The previous condition was inverted
    # (``not isinstance(x_vec, (list, tuple))``), so lists / tuples were rejected.
    x_vec = np.array(x_vec)
    if x_vec.ndim != 1:
        raise ValueError('x_vec should be a 1D array')
    if not isinstance(legend_mode, int):
        raise TypeError('legend_mode should be an integer')
    if axis is not None:
        if not isinstance(axis, mpl.axes.Axes):
            raise TypeError('axis must be a matplotlib.axes.Axes object')
    if not isinstance(overlayed, bool):
        raise TypeError('overlayed should be a boolean value')
    if amp_units is not None:
        if not (isinstance(amp_units, (str, unicode)) or
                (isinstance(amp_units, np.ndarray) and amp_units.dtype.type == np.str_)):
            raise TypeError('amp_units should be a str')
    else:
        amp_units = 'a.u.'

    # np.complex was removed in NumPy 1.24; iscomplexobj covers every complex dtype
    is_complex = np.iscomplexobj(centroids)
    cmap = kwargs.get('cmap', default_cmap)
    num_clusters = centroids.shape[0]

    def __overlay_curves(target_axis, curve_stack):
        plot_line_family(target_axis, x_vec, curve_stack, label_prefix='Cluster', cmap=cmap)

        if legend_mode == 0:
            target_axis.legend(loc='best', fontsize=14)
        elif legend_mode == 1:
            target_axis.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=14)
        else:
            sm = make_scalar_mappable(0, num_clusters - 1, cmap=discrete_cmap(num_clusters, cmap))
            # Attach the colorbar to the axis being drawn, not to pyplot's "current" axes
            plt.colorbar(sm, ax=target_axis)

    if not overlayed:
        if is_complex:
            return plot_complex_spectra(centroids, x_vec=x_vec, title=title, x_label=x_label, y_label=y_label,
                                        subtitle_prefix='Cluster', amp_units=amp_units, **kwargs)
        return plot_curves(x_vec, centroids, x_label=x_label, y_label=y_label, title=title,
                           subtitle_prefix='Cluster ', **kwargs)

    if is_complex:
        # Amplitude on top, phase below - always in a new figure
        fig, axes = plt.subplots(nrows=2, figsize=kwargs.pop('figsize', (5.5, 2 * 5)))
        for ax, func in zip(axes.flat, [np.abs, np.angle]):
            __overlay_curves(ax, func(centroids))

        # Three labels used to be zipped with only two setters, which dropped x_label
        # and wrote y_label onto the x axis
        for var, var_name, func in zip([y_label, y_label, x_label], ['y_label', 'y_label', 'x_label'],
                                       [axes[0].set_ylabel, axes[1].set_ylabel, axes[1].set_xlabel]):
            if var is not None:
                if not isinstance(var, (str, unicode)):
                    raise TypeError(var_name + ' should be a string')
                func(var)

        if title is not None:
            if not isinstance(title, (str, unicode)):
                raise TypeError('title should be a string')
            for ax, comp_name, units in zip(axes.flat, ['Amplitude', 'Phase'], [amp_units, 'rad']):
                ax.set_title('{} - {} ({})'.format(title, comp_name, units))

        fig.tight_layout()
        # The lower (phase) axis has always been returned here; kept for compatibility
        return fig, axes[-1]

    if axis is None:
        fig, axis = plt.subplots(figsize=kwargs.pop('figsize', (5.5, 5)))
    else:
        fig = None

    __overlay_curves(axis, centroids)

    for var, var_name, func in zip([title, x_label, y_label], ['title', 'x_label', 'y_label'],
                                   [axis.set_title, axis.set_xlabel, axis.set_ylabel]):
        if var is not None:
            if not isinstance(var, (str, unicode)):
                raise TypeError(var_name + ' should be a string')
            func(var)

    if fig is not None:
        fig.tight_layout()
    return fig, axis
示例#4
0
def plot_cluster_centroids(centroids,
                           x_vec,
                           legend_mode=1,
                           x_label=None,
                           y_label=None,
                           title=None,
                           axis=None,
                           overlayed=True,
                           amp_units=None,
                           **kwargs):
    """
    Plots the centroids of clusters as curves against x_vec

    Parameters
    ----------
    centroids : numpy.ndarray
        2D array. Centroids of clusters - one curve per row
    x_vec : array-like
        1D array. Vector against which the curves are plotted
    legend_mode : int, optional. default = 1
        Appearance of legend (only used when overlayed is True):
            0 - inside the plot
            1 - outside the plot on the right
            else - colorbar instead of legend
    x_label : str, optional, default = None
        Label for x axis
    y_label : str, optional, default = None
        Label for y axis
    title : str, optional, default = None
        Title for the plot
    axis : matplotlib.axes.Axes object, optional.
        Axis to plot this image onto. Will create a new figure by default or
        will use this axis object to plot into. Only honored for overlayed,
        real-valued centroids; complex centroids always get a new two-panel
        figure.
    overlayed : bool, optional
        If True - all curves will be plotted overlayed on a single plot.
        Else, curves will be plotted separately
    amp_units : str, optional
        Units for amplitude. Defaults to 'a.u.'
    kwargs : dict
        'cmap' colors the overlaid curves; 'figsize' sizes the new overlay
        figure. When overlayed is False, all kwargs are passed on to
        plot_complex_spectra / plot_curves

    Returns
    -------
    fig, axes
        Overlayed real-valued centroids: the figure (None if `axis` was
        supplied) and the axis plotted into.
        Overlayed complex centroids: the new figure and its lower (phase)
        axis.
        Otherwise: whatever plot_complex_spectra / plot_curves return.

    Raises
    ------
    TypeError
        If an argument has the wrong type
    ValueError
        If centroids is not 2D or x_vec is not 1D
    """
    if isinstance(centroids, (list, tuple)):
        centroids = np.array(centroids)
    if not isinstance(centroids, np.ndarray):
        raise TypeError('centroids should be a numpy array')
    if centroids.ndim != 2:
        raise ValueError(
            'centroids should be a 2D numpy array - i.e. - 1D spectra')
    # Accept any array-like. The previous condition was inverted
    # (``not isinstance(x_vec, (list, tuple))``), so lists / tuples were
    # rejected.
    x_vec = np.array(x_vec)
    if x_vec.ndim != 1:
        raise ValueError('x_vec should be a 1D array')
    if not isinstance(legend_mode, int):
        raise TypeError('legend_mode should be an integer')
    if axis is not None:
        if not isinstance(axis, mpl.axes.Axes):
            raise TypeError('axis must be a matplotlib.axes.Axes object')
    if not isinstance(overlayed, bool):
        raise TypeError('overlayed should be a boolean value')
    if amp_units is not None:
        if not (isinstance(amp_units, (str, unicode)) or
                (isinstance(amp_units, np.ndarray)
                 and amp_units.dtype.type == np.str_)):
            raise TypeError('amp_units should be a str')
    else:
        amp_units = 'a.u.'

    # np.complex was removed in NumPy 1.24; iscomplexobj covers every
    # complex dtype
    is_complex = np.iscomplexobj(centroids)
    cmap = kwargs.get('cmap', default_cmap)
    num_clusters = centroids.shape[0]

    def __overlay_curves(target_axis, curve_stack):
        plot_line_family(target_axis,
                         x_vec,
                         curve_stack,
                         label_prefix='Cluster',
                         cmap=cmap)

        if legend_mode == 0:
            target_axis.legend(loc='best', fontsize=14)
        elif legend_mode == 1:
            target_axis.legend(loc='upper left',
                               bbox_to_anchor=(1, 1),
                               fontsize=14)
        else:
            sm = make_scalar_mappable(0,
                                      num_clusters - 1,
                                      cmap=discrete_cmap(num_clusters, cmap))
            # Attach the colorbar to the axis being drawn, not to pyplot's
            # "current" axes
            plt.colorbar(sm, ax=target_axis)

    if not overlayed:
        if is_complex:
            return plot_complex_spectra(centroids,
                                        x_vec=x_vec,
                                        title=title,
                                        x_label=x_label,
                                        y_label=y_label,
                                        subtitle_prefix='Cluster',
                                        amp_units=amp_units,
                                        **kwargs)
        return plot_curves(x_vec,
                           centroids,
                           x_label=x_label,
                           y_label=y_label,
                           title=title,
                           subtitle_prefix='Cluster ',
                           **kwargs)

    if is_complex:
        # Amplitude on top, phase below - always in a new figure
        fig, axes = plt.subplots(nrows=2,
                                 figsize=kwargs.pop('figsize',
                                                    (5.5, 2 * 5)))
        for ax, func in zip(axes.flat, [np.abs, np.angle]):
            __overlay_curves(ax, func(centroids))

        # Three labels used to be zipped with only two setters, which
        # dropped x_label and wrote y_label onto the x axis
        for var, var_name, func in zip(
                [y_label, y_label, x_label], ['y_label', 'y_label', 'x_label'],
                [axes[0].set_ylabel, axes[1].set_ylabel, axes[1].set_xlabel]):
            if var is not None:
                if not isinstance(var, (str, unicode)):
                    raise TypeError(var_name + ' should be a string')
                func(var)

        if title is not None:
            if not isinstance(title, (str, unicode)):
                raise TypeError('title should be a string')
            for ax, comp_name, units in zip(axes.flat,
                                            ['Amplitude', 'Phase'],
                                            [amp_units, 'rad']):
                ax.set_title('{} - {} ({})'.format(title, comp_name, units))

        fig.tight_layout()
        # The lower (phase) axis has always been returned here; kept for
        # compatibility
        return fig, axes[-1]

    if axis is None:
        fig, axis = plt.subplots(figsize=kwargs.pop('figsize', (5.5, 5)))
    else:
        fig = None

    __overlay_curves(axis, centroids)

    for var, var_name, func in zip(
            [title, x_label, y_label], ['title', 'x_label', 'y_label'],
            [axis.set_title, axis.set_xlabel, axis.set_ylabel]):
        if var is not None:
            if not isinstance(var, (str, unicode)):
                raise TypeError(var_name + ' should be a string')
            func(var)

    if fig is not None:
        fig.tight_layout()
    return fig, axis
示例#5
0
 def test_not_stdevs(self):
     """plot_complex_spectra must raise TypeError when called with stdevs=-1."""
     images = [self.get_complex_2d_image(f) for f in 2 ** np.arange(4)]
     with self.assertRaises(TypeError):
         plot_utils.plot_complex_spectra(np.array(images), stdevs=-1)
示例#6
0
 def test_is_x_vec(self):
     """An x_vec shaped like the whole image stack (np.zeros_like of it) must raise ValueError."""
     images = [self.get_complex_2d_image(f) for f in 2 ** np.arange(4)]
     bad_x_vec = np.zeros_like(images)
     with self.assertRaises(ValueError):
         plot_utils.plot_complex_spectra(np.array(images), bad_x_vec)
示例#7
0
 def test_is_not_dim_x_vec(self):
     """A single-element x_vec ([1]) passed to plot_complex_spectra must raise ValueError."""
     images = [self.get_complex_2d_image(f) for f in 2 ** np.arange(4)]
     with self.assertRaises(ValueError):
         plot_utils.plot_complex_spectra(np.array(images), [1])
示例#8
0
 def test_not_map_stack(self):
     """A plain string passed as the map stack must raise TypeError."""
     bogus_input = 'wrongthing'
     with self.assertRaises(TypeError):
         plot_utils.plot_complex_spectra(bogus_input)