def test_fit_bycycle(ndim, test_data):
    """Fit bycycle on an ``ndim``-dimensional fixture signal and check the output type.

    For 1d, 2d and 3d inputs the innermost result must be a DataFrame, reached by
    indexing the first element once per extra dimension. The 4d case only checks
    that fitting runs.
    """
    # Map the requested dimensionality onto the fixture key
    sig_keys = {1: 'sig_1d', 2: 'sig_2d', 3: 'sig_3d', 4: 'sig_4d'}
    sig = test_data[sig_keys[ndim]]

    # Fit
    df = fit_bycycle(sig, test_data['fs'], test_data['f_range'])

    if ndim in (1, 2, 3):
        inner = df
        for _ in range(ndim - 1):
            inner = inner[0]
        assert isinstance(inner, pd.DataFrame)
def test_save_bycycle(ndim, test_data):
    """Saving a bycycle fit writes only the expected report/result files."""
    # Load data from fixture
    if ndim == 1:
        sig = test_data['sig_1d']
    elif ndim == 2:
        sig = test_data['sig_2d']
    elif ndim == 3:
        sig = test_data['sig_3d']

    # Fit data
    fs = test_data['fs']
    f_range = test_data['f_range']
    model = fit_bycycle(sig, fs, f_range)

    # Save into a temporary directory. The context manager removes it even
    # when an assertion below fails.
    with TemporaryDirectory() as output_dir:
        save_bycycle(model, output_dir)

        saved = [os.path.join(dirpath, name)
                 for dirpath, _, names in os.walk(output_dir)
                 for name in names]

        # Guard against a vacuous pass when nothing was written
        assert saved

        for path in saved:
            # basename is portable, unlike splitting on '/'
            assert os.path.basename(path) in ['report.html', 'results.csv', '.data.js']
def bycycle_outs(test_data):
    """Yield bycycle fits of the 1d, 2d and 3d fixture signals plus a 1d plot.

    All three fits share one set of burst-detection thresholds, which is also
    yielded under ``'threshold_kwargs'``.
    """
    sig_1d, sig_2d, sig_3d = (test_data['sig_1d'], test_data['sig_2d'],
                              test_data['sig_3d'])
    fs, f_range = test_data['fs'], test_data['f_range']

    threshold_kwargs = {
        'amp_fraction_threshold': 0.5,
        'amp_consistency_threshold': 0.5,
        'monotonicity_threshold': 0.8,
        'period_consistency_threshold': 0.5,
        'min_n_cycles': 3,
    }

    def _fit(signal):
        # Same sampling rate, band and thresholds for every dimensionality
        return fit_bycycle(signal, fs, f_range, threshold_kwargs=threshold_kwargs)

    bm, bg, bgs = _fit(sig_1d), _fit(sig_2d), _fit(sig_3d)

    bm_graph = plot_bm(bm, sig_1d, fs, threshold_kwargs, 0)

    yield {'bm': bm, 'bm_graph': bm_graph, 'bg': bg, 'bgs': bgs,
           'threshold_kwargs': threshold_kwargs}
def test_extract(sim_sig, fooof_outs):
    """Exercise extract() across input types, clustering options and thresholds."""
    from copy import deepcopy

    sig = sim_sig['sig']
    fs = sim_sig['fs']

    # Work on a copy: the final case mutates peak parameters in place, which
    # would otherwise leak into any other test sharing the fixture object.
    fm = deepcopy(fooof_outs['fm'])

    # Model input
    motifs, cycles = extract(fm, sig, fs)
    _check_results(motifs, cycles, 'valid', 1)

    # Tuple input
    params = [(fm.get_params('peak_params', 'CF'), fm.get_params('peak_params', 'BW'))]
    motifs, cycles = extract(params, sig, fs)
    _check_results(motifs, cycles, 'valid', 1)

    # Dataframe input
    f_range = (params[0][0] - params[0][1], params[0][0] + params[0][1])
    df_features = fit_bycycle(sig, fs, f_range, 'peak')
    motifs, cycles = extract(params, sig, fs, df_features=df_features)
    _check_results(motifs, cycles, 'valid', 1)

    # Force multi-motif
    motifs, cycles = extract(fm, sig, fs, min_clusters=2, min_clust_score=0,
                             var_thresh=0, min_n_cycles=2)
    _check_results(motifs, cycles, 'valid', 1)

    # Force multi-motif and subthresh variance
    motifs, cycles = extract(fm, sig, fs, min_clusters=2, min_clust_score=0,
                             var_thresh=1, min_n_cycles=2)
    _check_results(motifs, cycles, 'invalid', 1)

    # Force single motif
    motifs, cycles = extract(params, sig, fs, min_clust_score=1.1)
    _check_results(motifs, cycles, 'valid', 1)

    # Minimum cycles - no cycles will survive
    motifs, cycles = extract(params, sig, fs, min_n_cycles=np.inf)
    _check_results(motifs, cycles, 'invalid', 1)

    # Sub Variance threshold - no cycles will survive
    motifs, cycles = extract(params, sig, fs, min_n_cycles=0, var_thresh=np.inf)
    _check_results(motifs, cycles, 'invalid', 1)

    # Requires lower bound step: CF - BW = 0.1 < 1, so extract floors it to 1
    fm.peak_params_[0][0] = 2
    fm.peak_params_[0][2] = 1.9
    motifs, cycles = extract(fm, sig, fs)
    _check_results(motifs, cycles, 'invalid', 1)
def extract(fm, sig, fs, df_features=None, scaling=1, use_thresh=True, center='peak',
            min_clust_score=1, var_thresh=0.05, min_clusters=2, max_clusters=10,
            min_n_cycles=10, index=None, random_state=None):
    """Get the average cycle from a bycycle dataframe for all fooof peaks.

    Parameters
    ----------
    fm : fooof.FOOOF or list of tuple
        A fooof model that has been fit, or a list of (center_freq, bandwidth).
        A 2d array of (center_freq, bandwidth) rows, or a single 1d
        (center_freq, bandwidth) array, is also accepted.
    sig : 1d or 2d array
        Time series. If 2d, each timeseries is expected to correspond to each peak,
        in ascending frequency.
    fs : float
        Sampling rate, in Hz.
    df_features : pandas.DataFrame, optional, default: None
        A dataframe containing bycycle features. When given, it is used for every
        frequency range. Otherwise features are fit separately for each range.
    scaling : float, optional, default: 1
        The scaling of the bandwidth from the center frequencies to limit cycles to.
    use_thresh : bool, optional, default: True
        Restricts the dataframe to bursting cycles and rejects motifs whose
        variance is below ``var_thresh``.
    center : {'peak', 'trough'}, optional
        The center definition of cycles.
    min_clust_score : float, optional, default: 1
        The minimum silhouette score to accept k clusters. The default skips clustering.
    var_thresh : float, optional, default: 0.05
        Height threshold in variance.
    min_clusters : int, optional, default: 2
        The minimum number of clusters to evaluate.
    max_clusters : int, optional, default: 10
        The maximum number of clusters to evaluate.
    min_n_cycles : int, optional, default: 10
        The minimum number of cycles required to be considered a motif.
    index : int, optional, default: None
        Sub-selects a single frequency range to extract.
    random_state : int, optional, default: None
        Determines random number generation for centroid initialization.
        Use an int to make the randomness deterministic for reproducible results.

    Returns
    -------
    motifs : list of list of 1d arrays
        Motifs for each center frequency in ascending order. Inner list contains multiple
        arrays if multiple motifs are found at one frequency. Frequency ranges without a
        valid motif hold the placeholder appended by ``_nan_append``. Within a multi-motif
        list, an individual motif is ``np.nan`` when its variance is below ``var_thresh``.
    cycles : dict
        The timeseries, dataframes, frequency ranges, and predicted labels for each cycle.
        Valid keys include: 'sigs', 'dfs_features', 'labels', 'f_ranges'.

    Raises
    ------
    ValueError
        If ``fm`` is an array that is neither 1d nor 2d.
    """
    # Extract center freqs and bandwidths from fooof fit
    if not isinstance(fm, (list, np.ndarray)):
        cfs = fm.get_params('peak_params', 'CF')
        bws = fm.get_params('peak_params', 'BW')
        cfs = cfs if isinstance(cfs, (list, np.ndarray)) else [cfs]
        bws = bws if isinstance(bws, (list, np.ndarray)) else [bws]
    elif isinstance(fm, list):
        cfs = np.array(fm)[:, 0]
        bws = np.array(fm)[:, 1]
    elif fm.ndim == 2:
        cfs = fm[:, 0]
        bws = fm[:, 1]
    elif fm.ndim == 1:
        cfs = [fm[0]]
        bws = [fm[1]]
    else:
        raise ValueError('Array input for fm must be 1d or 2d.')

    f_ranges = [(round(cf - (scaling * bws[idx]), 1), round(cf + (scaling * bws[idx]), 1))
                for idx, cf in enumerate(cfs)]

    # Sub-select index if requested
    if index is not None:
        # A 2d sig holds one row per peak: keep the row of the selected peak so
        # it stays aligned with the single remaining frequency range.
        if sig.ndim == 2 and len(sig) == len(f_ranges):
            sig = sig[[index]]
        f_ranges = [f_ranges[index]]
        cfs = [cfs[index]]

    # Get cycles within freq ranges
    motifs = []
    cycles = {'sigs': [], 'dfs_features': [], 'labels': [], 'f_ranges': []}

    # Vertically stack: a 1d signal is reused for every frequency range
    if sig.ndim == 1:
        sig = sig.reshape(1, len(sig))
        sig = np.repeat(sig, len(f_ranges), axis=0)

    for ind, (f_range, cf) in enumerate(zip(f_ranges, cfs)):

        # Floor lower frequency bound at one
        f_range = (1, f_range[1]) if f_range[0] < 1 else f_range

        # Fit features for this range unless the caller supplied them. The fit
        # must be local to the iteration: overwriting ``df_features`` would make
        # every later peak reuse the first peak's range and signal row.
        if df_features is None:
            df_range = fit_bycycle(sig[ind], fs, f_range, center)
        else:
            df_range = df_features

        if df_range is None:
            motifs, cycles = _nan_append(motifs, cycles)
            continue

        # Restrict dataframe to frequency range
        df_osc = limit_df(df_range, fs, f_range, only_bursts=use_thresh)

        # No cycles found in frequency range
        if not isinstance(df_osc, pd.DataFrame) or len(df_osc) < min_n_cycles:
            motifs, cycles = _nan_append(motifs, cycles)
            continue

        # Split signal into 2d array of cycles
        sig_cyc = split_signal(df_osc, sig[ind], True, center, int(fs / cf))

        # Cluster cycles
        labels = cluster_cycles(sig_cyc, min_clust_score=min_clust_score,
                                min_clusters=min_clusters, max_clusters=max_clusters,
                                random_state=random_state)

        # Single cluster found (cluster_cycles did not return a label array)
        if not isinstance(labels, np.ndarray):

            motif = np.mean(sig_cyc, axis=0)

            # The variance of the motif is too small (i.e. flat line)
            if use_thresh and np.var(motif) < var_thresh:
                motifs, cycles = _nan_append(motifs, cycles)
                continue

            motifs.append([motif])

        # Multiple motifs found at the current frequency range
        else:

            multi_motifs = []

            for idx in range(max(labels) + 1):

                label_idxs = np.where(labels == idx)[0]

                motif = np.mean(sig_cyc[label_idxs], axis=0)

                if use_thresh and np.var(motif) < var_thresh:
                    multi_motifs.append(np.nan)
                else:
                    multi_motifs.append(motif)

            # Every motif fell below the variance threshold: treat the range as
            # null. Checking only for an empty list could never trigger, because
            # each label appends either a motif or np.nan.
            if all(motif is np.nan for motif in multi_motifs):
                motifs, cycles = _nan_append(motifs, cycles)
                continue

            motifs.append(multi_motifs)

        # Collect cycles
        cycles['sigs'].append(sig_cyc)
        cycles['dfs_features'].append(df_osc)
        cycles['labels'].append(labels)
        cycles['f_ranges'].append(f_range)

    return motifs, cycles
def fit(self, fm, sig, fs, ttype='affine'):
    """Robust motif extraction.

    Motifs are first extracted without thresholds. Each motif then drives
    ``motif_burst_detection`` on a bycycle fit of its signal row, and motifs are
    re-extracted from the cycles flagged as bursts. One ``MotifResult`` per
    first-pass frequency range is appended to ``self.results``.

    Parameters
    ----------
    fm : fooof.FOOOF or list of tuple
        A fooof model that has been fit, or a list of (center_freq, bandwidth).
    sig : 1d or 2d array
        Time series. If 2d, each timeseries is expected to correspond to each peak,
        in ascending frequency.
    fs : float
        Sampling rate, in Hz.
    ttype : {'euclidean', 'similarity', 'affine', 'projective', 'polynomial'}
        Transformation type.

    Raises
    ------
    ValueError
        If ``fm`` is a ``FOOOFGroup``.
    """
    from ndspflow.motif import extract

    if isinstance(fm, FOOOFGroup):
        raise ValueError(
            'Use motif.fit.MotifGroup for FOOOFGroup objects.')

    self.fm, self.sig, self.fs = fm, sig, fs
    self.results = []

    # First pass: unthresholded extraction yields the template motifs
    first_motifs, first_cycles = extract(
        self.fm, self.sig, self.fs, use_thresh=False, center=self.center,
        min_clusters=self.min_clusters, max_clusters=self.max_clusters)
    ranges = first_cycles['f_ranges']

    # One signal row per frequency range
    if sig.ndim == 1:
        sig = np.repeat(sig.reshape(1, len(sig)), len(ranges), axis=0)

    def _is_match(cyc_range, target):
        # Non-null range whose bounds both round to the target's bounds
        if isinstance(cyc_range, float):
            return False
        return (round(cyc_range[0] - target[0]) == 0
                and round(cyc_range[1] - target[1]) == 0)

    for row, (template, f_range) in enumerate(zip(first_motifs, ranges)):

        # Null first-pass entries carry a float (NaN) range
        if isinstance(f_range, float):
            self.results.append(MotifResult(f_range))
            continue

        if f_range[0] < 1:
            f_range = (1, f_range[1])

        row_sig = sig[row]

        # Motif correlation burst detection
        bm = fit_bycycle(row_sig, fs, f_range)
        bm['is_burst'] = motif_burst_detection(
            template, bm, row_sig, corr_thresh=self.corr_thresh,
            var_thresh=self.var_thresh, ttype=ttype)

        # Second pass: re-extract only from the detected bursts
        motifs_burst, cycles_burst = extract(
            fm, row_sig, fs, df_features=bm, index=row,
            center=self.center, use_thresh=True, var_thresh=self.var_thresh,
            min_clust_score=self.min_clust_score, min_clusters=self.min_clusters,
            max_clusters=self.max_clusters, min_n_cycles=self.min_n_cycles,
            random_state=self.random_state)

        hits = [pos for pos, cyc_range in enumerate(cycles_burst['f_ranges'])
                if _is_match(cyc_range, f_range)]

        # Anything other than exactly one matching range means no usable cycles
        if len(hits) != 1:
            self.results.append(MotifResult(f_range))
            continue

        hit = hits[0]
        self.results.append(MotifResult(
            f_range, motifs_burst[hit], cycles_burst['sigs'][hit],
            cycles_burst['dfs_features'][hit], cycles_burst['labels'][hit]))
def _run_interface(self, runtime):
    """Fit bycycle to a saved signal, save the features and record the outputs.

    The signal is loaded from ``input_dir``/``sig``. The axis string is parsed,
    thresholds are selected according to ``burst_method``, and ``fit_bycycle`` is
    run. The features are saved with ``save_bycycle`` into ``output_dir``.
    ``self._results`` receives ``'df_features'``, ``'bm_results'`` and
    ``'_fit_args'``.

    Returns
    -------
    runtime
        The ``runtime`` argument, unchanged.

    Raises
    ------
    ValueError
        If the axis input is not "0", "1", "(0, 1)" or contains "None".
    """
    sig = np.load(
        os.path.join(os.getcwd(), self.inputs.input_dir, self.inputs.sig))

    # Infer axis type from string (traits doesn't support multi-type)
    axis_error = ValueError("Axis must be 0, 1, (0, 1), or None.")

    axis = None if 'None' in self.inputs.axis else self.inputs.axis
    axis = (0, 1) if self.inputs.axis.replace(' ', '') == '(0,1)' else axis

    if axis is not None and axis != (0, 1):
        try:
            axis = int(self.inputs.axis)
        except (TypeError, ValueError) as exc:
            # Raise the descriptive error itself. ``raise ValueError from
            # axis_error`` produced a message-less ValueError.
            raise axis_error from exc

    if axis not in [0, 1, (0, 1), None]:
        raise axis_error

    # Get thresholds
    if self.inputs.burst_method == 'cycles':
        threshold_kwargs = dict(
            amp_fraction_threshold=self.inputs.amp_fraction_threshold,
            amp_consistency_threshold=self.inputs.amp_consistency_threshold,
            period_consistency_threshold=self.inputs.period_consistency_threshold,
            monotonicity_threshold=self.inputs.monotonicity_threshold,
            min_n_cycles=self.inputs.min_n_cycles)
    else:
        threshold_kwargs = dict(
            burst_fraction_threshold=self.inputs.burst_fraction_threshold,
            min_n_cycles=self.inputs.min_n_cycles)

    # Organize all kwargs
    fit_kwargs = dict(center_extrema=self.inputs.center_extrema,
                      burst_method=self.inputs.burst_method,
                      threshold_kwargs=threshold_kwargs,
                      axis=axis, n_jobs=self.inputs.n_jobs)

    # Fit
    df_features = fit_bycycle(sig, self.inputs.fs, self.inputs.f_range_bycycle,
                              **fit_kwargs)

    # Save dataframes
    save_bycycle(df_features, self.inputs.output_dir)

    fit_args = dict(sig=sig, fs=self.inputs.fs,
                    f_range=self.inputs.f_range_bycycle, **fit_kwargs)

    self._results["df_features"] = df_features
    self._results["bm_results"] = os.path.join(self.inputs.output_dir, 'bycycle')
    self._results["_fit_args"] = fit_args

    return runtime