def test_filter(resp_wfm,
                frequency_filters=None,
                noise_threshold=None,
                excit_wfm=None,
                show_plots=True,
                plot_title=None,
                verbose=False):
    """
    Filters the provided response with the provided filters.

    The response is Fourier transformed (with the spectrum fftshift-ed), multiplied by the composite of the
    supplied frequency filters, optionally thresholded against a noise floor, and transformed back to the time
    domain.

    Parameters
    ----------
    resp_wfm : array-like, 1D
        Raw response waveform in the time domain
    frequency_filters : (Optional) FrequencyFilter object or list of
        Frequency filters to apply to signal
    noise_threshold : (Optional) float
        Noise threshold to apply to signal. Must lie within the open interval (0, 1)
    excit_wfm : (Optional) 1D array-like
        Excitation waveform in the time domain. This waveform is necessary for plotting loops. If the length of
        resp_wfm matches excit_wfm, a single plot will be returned with the raw and filtered signals plotted against the
        excit_wfm. Else, resp_wfm and the filtered (filt_data) signal will be broken into chunks matching the length of
        excit_wfm and a figure with multiple plots (one for each chunk) with the raw and filtered signal chunks plotted
        against excit_wfm will be returned for fig_loops
    show_plots : (Optional) Boolean
        Whether or not to plot FFTs before and after filtering
    plot_title : str / unicode (Optional)
        Title for the raw vs filtered plots if requested. For example - 'Row 15'
    verbose : (Optional) Boolean
        Prints extra debugging information if True.  Default False

    Returns
    -------
    filt_data : 1D numpy float array
        Filtered signal in the time domain
    fig_fft : matplotlib.pyplot.figure object
        handle to the plotted figure if requested, else None
    fig_loops : matplotlib.pyplot.figure object
        handle to figure with the filtered signal and raw signal plotted against the excitation waveform,
        or None when loops are not plotted

    Raises
    ------
    TypeError
        If resp_wfm is not a list / numpy array, or if plot_title is not a string
    ValueError
        If neither frequency_filters nor noise_threshold is given, if noise_threshold is outside (0, 1),
        if frequency_filters are not compatible FrequencyFilter objects, or if (when plotting) the length
        of resp_wfm is not divisible by the length of excit_wfm
    """
    if not isinstance(resp_wfm, (np.ndarray, list)):
        raise TypeError('resp_wfm should be array-like')
    resp_wfm = np.array(resp_wfm)

    # Loops are only drawn when plots are requested and the response splits
    # evenly into chunks the length of the excitation waveform
    show_loops = False
    if excit_wfm is not None and show_plots:
        if len(resp_wfm) % len(excit_wfm) == 0:
            show_loops = True
        else:
            raise ValueError(
                'Length of resp_wfm should be divisibe by length of excit_wfm')
    if show_loops:
        if plot_title is None:
            plot_title = 'FFT Filtering'
        elif not isinstance(plot_title, (str, unicode)):
            # Explicit check rather than assert: asserts are stripped under -O
            raise TypeError('plot_title should be a string')

    if frequency_filters is None and noise_threshold is None:
        raise ValueError(
            'Need to specify at least some noise thresholding / frequency filter'
        )

    if noise_threshold is not None:
        if noise_threshold >= 1 or noise_threshold <= 0:
            raise ValueError('Noise threshold must be within (0 1)')

    # Defaults for the noise-threshold-only case: unit sampling rate for the
    # frequency axis and a pass-through (multiply by 1) filter
    samp_rate = 1
    composite_filter = 1
    if frequency_filters is not None:
        if not isinstance(frequency_filters, Iterable):
            frequency_filters = [frequency_filters]
        if not are_compatible_filters(frequency_filters):
            raise ValueError(
                'frequency filters must be a single or list of FrequencyFilter objects'
            )
        composite_filter = build_composite_freq_filter(frequency_filters)
        samp_rate = frequency_filters[0].samp_rate

    num_pts = resp_wfm.size

    # The composite filter is multiplied directly onto this shifted spectrum
    # below, so it is presumably built in the same fftshift ordering
    fft_pix_data = np.fft.fftshift(np.fft.fft(resp_wfm))

    if noise_threshold is not None:
        noise_floor = get_noise_floor(fft_pix_data, noise_threshold)[0]
        if verbose:
            print('The noise_floor is', noise_floor)

    fig_fft = None
    if show_plots:
        w_vec = np.linspace(-0.5 * samp_rate, 0.5 * samp_rate, num_pts) * 1E-3

        fig_fft, [ax_raw, ax_filt] = plt.subplots(figsize=(12, 8), nrows=2)
        axes_fft = [ax_raw, ax_filt]
        set_tick_font_size(axes_fft, 14)

        # Plot only positive frequencies, up to the last bin the filter passes
        r_ind = num_pts
        if isinstance(composite_filter, np.ndarray):
            pass_band = np.where(composite_filter > 0)[0]
            # np.max on an empty selection raises, so keep the full range then
            if pass_band.size > 0:
                r_ind = np.max(pass_band)

        x_lims = slice(len(w_vec) // 2, r_ind)
        amp = np.abs(fft_pix_data)
        ax_raw.semilogy(w_vec[x_lims], amp[x_lims], label='Raw')
        if frequency_filters is not None:
            # Rescale the filter into the range of the raw amplitude so that it
            # is visible on the same log axis
            ax_raw.semilogy(w_vec[x_lims],
                            (composite_filter[x_lims] + np.min(amp)) *
                            (np.max(amp) - np.min(amp)),
                            linewidth=3,
                            color='orange',
                            label='Composite Filter')
        if noise_threshold is not None:
            ax_raw.axhline(
                noise_floor,
                linewidth=2,
                color='r',
                label='Noise Threshold')
        ax_raw.legend(loc='best', fontsize=14)
        ax_raw.set_title('Raw Signal', fontsize=16)
        ax_raw.set_ylabel('Magnitude (a. u.)', fontsize=14)

    fft_pix_data *= composite_filter

    if noise_threshold is not None:
        fft_pix_data[
            np.abs(fft_pix_data) <
            noise_floor] = 1E-16  # DON'T use 0 here. ipython kernel dies

    if show_plots:
        ax_filt.semilogy(w_vec[x_lims], np.abs(fft_pix_data)[x_lims])
        ax_filt.set_title('Filtered Signal', fontsize=16)
        ax_filt.set_xlabel('Frequency(kHz)', fontsize=14)
        ax_filt.set_ylabel('Magnitude (a. u.)', fontsize=14)
        if noise_threshold is not None:
            ax_filt.set_ylim(
                bottom=noise_floor
            )  # prevents the noise threshold from messing up plots
        fig_fft.tight_layout()

    filt_data = np.real(np.fft.ifft(np.fft.ifftshift(fft_pix_data)))

    if verbose:
        print('The shape of the filtered data is {}'.format(filt_data.shape))
        # excit_wfm is optional and may be a plain list, so do not rely on .shape
        if excit_wfm is not None:
            print('The shape of the excitation waveform is {}'.format(
                np.shape(excit_wfm)))

    fig_loops = None
    if show_loops:
        if len(resp_wfm) == len(excit_wfm):
            # single plot:
            fig_loops, axis = plt.subplots(figsize=(5.5, 5))
            axis.plot(excit_wfm, resp_wfm, 'r', label='Raw')
            axis.plot(excit_wfm, filt_data, 'k', label='Filtered')
            axis.legend(fontsize=14)
            set_tick_font_size(axis, 14)
            axis.set_xlabel('Excitation', fontsize=16)
            axis.set_ylabel('Signal', fontsize=16)
            axis.set_title(plot_title, fontsize=16)
            fig_loops.tight_layout()
        else:
            # N loops: one row per excitation-length chunk
            raw_pixels = np.reshape(resp_wfm, (-1, len(excit_wfm)))
            filt_pixels = np.reshape(filt_data, (-1, len(excit_wfm)))
            if verbose:
                print(raw_pixels.shape, filt_pixels.shape)

            fig_loops, _ = plot_curves(
                excit_wfm, [raw_pixels, filt_pixels],
                line_colors=['r', 'k'],
                dataset_names=['Raw', 'Filtered'],
                x_label='Excitation',
                y_label='Signal',
                subtitle_prefix='Col ',
                num_plots=16,
                title=plot_title)

    return filt_data, fig_fft, fig_loops
Ejemplo n.º 2
0
def plot_svd(h5_main, savefig=False, num_plots=16, **kwargs):
    '''
    Replots the SVD showing the scree, abundance maps, and eigenvectors.
    If h5_main is a Dataset, it will default to the most recent SVD group from that
    Dataset.
    If h5_main is the results group, then it will plot the values for that group.

    :param h5_main: SVD results group, or the main dataset the SVD was computed on
    :type h5_main: USIDataset or h5py Dataset or h5py Group

    :param savefig: Saves the figures to disk with some default names
    :type savefig: bool, optional

    :param num_plots: Number of eigenvectors and abundance maps to read and show
    :type num_plots: int

    :param kwargs: keyword arguments passed on to plot_utils.plot_map_stack and
        plot_utils.plot_curves
    :type kwargs: dict, optional

    '''

    if isinstance(h5_main, h5py.Group):
        # Results group: use the last U / V datasets found within it
        _U = find_dataset(h5_main, 'U')[-1]
        _V = find_dataset(h5_main, 'V')[-1]
        units = 'arbitrary (a.u.)'
        h5_spec_vals = np.arange(_V.shape[1])
        h5_svd_group = _U.parent

    else:
        # Main dataset: use its most recent SVD results group
        h5_svd_group = find_results_groups(h5_main, 'SVD')[-1]
        units = h5_main.attrs['quantity']
        h5_spec_vals = h5_main.get_spec_values('Time')

    h5_U = h5_svd_group['U']
    h5_V = h5_svd_group['V']
    h5_S = h5_svd_group['S']

    usi_U = USIDataset(h5_U)
    [num_rows, num_cols] = usi_U.pos_dim_sizes

    # Read only as many components as will actually be plotted
    abun_maps = np.reshape(h5_U[:, :num_plots], (num_rows, num_cols, -1))
    eigen_vecs = h5_V[:num_plots, :]

    # Read S once; entry k - 1 is the fraction of the singular-value total
    # captured by the first k components, so the curve ends at 1.
    # NOTE(review): this uses S directly; explained variance is usually
    # proportional to S**2 - confirm which is intended.
    sing_vals = np.asarray(h5_S[:])
    scree_sum = np.cumsum(sing_vals) / np.sum(sing_vals)
    num_comps = np.arange(1, sing_vals.size + 1)

    fig_cum_var = plt.figure()
    plt.plot(num_comps, scree_sum, 'bo')
    plt.title('Cumulative Variance')
    plt.xlabel('Total Components')
    plt.ylabel('Total variance ratio (a.u.)')

    # Save through explicit figure handles so the correct figure is written
    # regardless of which figure pyplot considers current
    if savefig:
        fig_cum_var.savefig('Cumulative_variance_plot.png')

    fig_scree, _ = plot_utils.plot_scree(h5_S, title='Scree plot')
    fig_scree.tight_layout()

    if savefig:
        fig_scree.savefig('Scree_plot.png')

    fig_abun, _ = plot_utils.plot_map_stack(abun_maps,
                                            num_comps=num_plots,
                                            title='SVD Abundance Maps',
                                            color_bar_mode='single',
                                            cmap='inferno',
                                            reverse_dims=True,
                                            fig_mult=(3.5, 3.5),
                                            facecolor='white',
                                            **kwargs)
    fig_abun.tight_layout()
    if savefig:
        fig_abun.savefig('Abundance_maps.png')

    # Spectroscopic values are scaled by 1e3 to match the 'Time (ms)' label
    fig_eigvec, _ = plot_utils.plot_curves(h5_spec_vals * 1e3,
                                           eigen_vecs,
                                           use_rainbow_plots=False,
                                           x_label='Time (ms)',
                                           y_label=units,
                                           num_plots=num_plots,
                                           subtitle_prefix='Component',
                                           title='SVD Eigenvectors',
                                           evenly_spaced=False,
                                           **kwargs)
    fig_eigvec.tight_layout()
    if savefig:
        fig_eigvec.savefig('Eigenvectors.png')

    return
Ejemplo n.º 3
0
def plot_cluster_centroids(centroids,
                           x_vec,
                           legend_mode=1,
                           x_label=None,
                           y_label=None,
                           title=None,
                           axis=None,
                           overlayed=True,
                           amp_units=None,
                           **kwargs):
    """
    Plots the centroids of clusters against x_vec, either overlayed on one
    plot (two stacked plots - amplitude and phase - for complex centroids)
    or as separate curves.

    Parameters
    ----------
    centroids : numpy.ndarray
        2D array. Centroids of clusters
    x_vec : numpy.ndarray
        1D array-like (lists and tuples are accepted). Vector against which the curves are plotted
    legend_mode : int, optional. default = 1
        Appearance of legend:
            0 - inside the plot
            1 - outside the plot on the right
            else - colorbar instead of legend
    x_label : str, optional, default = None
        Label for x axis
    y_label : str, optional, default = None
        Label for y axis
    title : str, optional, default = None
        Title for the plot
    axis : matplotlib.axes.Axes object, optional.
        Axis to plot this image onto. Will create a new figure by default or will use this axis object to plot into.
        Only used for overlayed, real-valued centroids
    overlayed : bool, optional
        If True - all curves will be plotted overlayed on a single plot. Else, curves will be plotted separately
    amp_units : str, optional
        Units for amplitude
    kwargs : dict
        will be passed on to plot_line_family(), plot_complex_spectra, plot_curves

    Returns
    -------
    fig, axes
        For overlayed plots: the figure (None if drawn onto a user-supplied
        axis) and the axis plotted into - for complex centroids this is the
        lower (phase) axis. Otherwise whatever plot_complex_spectra /
        plot_curves returns.

    Raises
    ------
    TypeError
        If an argument has the wrong type
    ValueError
        If centroids is not 2D or x_vec is not 1D
    """
    if isinstance(centroids, (list, tuple)):
        centroids = np.array(centroids)
    if not isinstance(centroids, np.ndarray):
        raise TypeError('centroids should be a numpy array')
    if centroids.ndim != 2:
        raise ValueError(
            'centroids should be a 2D numpy array - i.e. - 1D spectra')
    # Convert any array-like (including lists / tuples); shape is checked next
    if not isinstance(x_vec, np.ndarray):
        x_vec = np.array(x_vec)
    if x_vec.ndim != 1:
        raise ValueError('x_vec should be a 1D array')
    if not isinstance(legend_mode, int):
        raise TypeError('legend_mode should be an integer')
    if axis is not None:
        if not isinstance(axis, mpl.axes.Axes):
            raise TypeError('axis must be a matplotlib.axes.Axes object')
    if not isinstance(overlayed, bool):
        raise TypeError('overlayed should be a boolean value')
    if amp_units is not None:
        if not (isinstance(amp_units, (str, unicode)) or
                (isinstance(amp_units, np.ndarray)
                 and amp_units.dtype.type == np.str_)):
            raise TypeError('amp_units should be a str')
    else:
        amp_units = 'a.u.'

    cmap = kwargs.get('cmap', default_cmap)
    num_clusters = centroids.shape[0]
    # np.complex was removed from NumPy; test the dtype kind instead
    is_complex = np.iscomplexobj(centroids)

    def __overlay_curves(axis, curve_stack):

        plot_line_family(axis,
                         x_vec,
                         curve_stack,
                         label_prefix='Cluster',
                         cmap=cmap)

        if legend_mode == 0:
            axis.legend(loc='best', fontsize=14)
        elif legend_mode == 1:
            axis.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=14)
        else:
            sm = make_scalar_mappable(0,
                                      num_clusters - 1,
                                      cmap=discrete_cmap(num_clusters, cmap))
            # Attach the colorbar to the axis being drawn rather than
            # whatever pyplot considers current
            axis.figure.colorbar(sm, ax=axis)

    if overlayed:
        if is_complex:
            fig, axes = plt.subplots(nrows=2,
                                     figsize=kwargs.pop(
                                         'figsize', (5.5, 2 * 5)))
            for ax, func in zip(axes.flat, [np.abs, np.angle]):
                __overlay_curves(ax, func(centroids))

            # y_label goes on both panels; x_label only on the bottom one
            for var, var_name, func in zip(
                    [y_label, y_label, x_label],
                    ['y_label', 'y_label', 'x_label'],
                    [axes[0].set_ylabel, axes[1].set_ylabel,
                     axes[1].set_xlabel]):
                if var is not None:
                    if not isinstance(var, (str, unicode)):
                        raise TypeError(var_name + ' should be a string')
                    func(var)

            if title is not None:
                if not isinstance(title, (str, unicode)):
                    raise TypeError('title should be a string')
                for ax, comp_name, units in zip(axes.flat,
                                                ['Amplitude', 'Phase'],
                                                [amp_units, 'rad']):
                    ax.set_title('{} - {} ({})'.format(
                        title, comp_name, units))

            # Preserve the historical return value: the lower (phase) axis
            axis = axes[1]

        else:

            if axis is None:
                fig, axis = plt.subplots(
                    figsize=kwargs.pop('figsize', (5.5, 5)))
            else:
                fig = None

            __overlay_curves(axis, centroids)

            for var, var_name, func in zip(
                    [title, x_label, y_label], ['title', 'x_label', 'y_label'],
                    [axis.set_title, axis.set_xlabel, axis.set_ylabel]):
                if var is not None:
                    if not isinstance(var, (str, unicode)):
                        raise TypeError(var_name + ' should be a string')
                    func(var)

        if fig is not None:
            fig.tight_layout()
        return fig, axis

    else:
        if is_complex:
            return plot_complex_spectra(centroids,
                                        x_vec=x_vec,
                                        title=title,
                                        x_label=x_label,
                                        y_label=y_label,
                                        subtitle_prefix='Cluster',
                                        amp_units=amp_units,
                                        **kwargs)
        else:
            return plot_curves(x_vec,
                               centroids,
                               x_label=x_label,
                               y_label=y_label,
                               title=title,
                               subtitle_prefix='Cluster ',
                               **kwargs)