# Imports these excerpts rely on. Project-local helpers (estimate_noise,
# detect_spikes, extract_waveforms, load_yaml, run_prb, write_features_fet,
# write_feature_fd, the feature_* functions, scale_feature, and constants such
# as MINIMUM_NOISE_THRESHOLD, AVAILABLE_FEATURES, N_SAMPLES, N_CHANNELS and
# PRECISION) are assumed to come from the surrounding dataman package.
import argparse
import logging
import os
import shutil
import subprocess
import sys
from datetime import datetime as dt
from itertools import combinations
from pathlib import Path

import datashader as ds
import h5py
import hdf5storage as h5s  # assumed source of h5s.loadmat (MATLAB v7.3 files)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pkg_resources
from datashader import transfer_functions as ds_tf
from scipy.ndimage import gaussian_filter
from tqdm import tqdm

import dataman.lib.report
from dataman.lib import report
# fig2html, ds_plot_features, ds_shade_feature etc. are assumed to be exported
# by dataman.lib.report, which the code below references explicitly.
from dataman.lib.report import (ds_plot_features, ds_plot_waveforms,
                                ds_shade_feature, ds_shade_waveforms, fig2html)

logger = logging.getLogger(__name__)


def main(args):
    parser = argparse.ArgumentParser('Detect spikes in .dat files')
    parser.add_argument('-v', '--verbose', action='store_true', help='Verbose (debug) output')
    parser.add_argument('target', default='.', help='Directory with tetrode files.')
    parser.add_argument('-o', '--out_path', help='Output file path. Defaults to current working directory.')
    parser.add_argument('--sampling-rate', type=float, help='Sampling rate. Default: 30000 Hz', default=3e4)
    parser.add_argument('--noise_percentile', type=int, help='Noise percentile. Default: 5', default=5)
    parser.add_argument('--threshold', type=float, help='Threshold factor (in noise SDs). Default: 4.5', default=4.5)
    parser.add_argument('-t', '--tetrodes', nargs='*', help='0-indexed list of tetrodes to look at. Default: all.')
    parser.add_argument('-f', '--force', action='store_true', help='Force overwrite of existing files.')
    parser.add_argument('-a', '--align', help='Alignment method. Default: min', default='min')
    parser.add_argument('--start', type=float, help='Segment start in seconds', default=0)
    parser.add_argument('--end', type=float, help='Segment end in seconds')
    cli_args = parser.parse_args(args)
    logger.debug('Arguments: {}'.format(cli_args))

    stddev_factor = cli_args.threshold
    logger.debug('Threshold factor : {}'.format(stddev_factor))

    fs = cli_args.sampling_rate
    logger.debug('Sampling rate    : {}'.format(fs))

    noise_percentile = cli_args.noise_percentile
    logger.debug('Noise percentile : {}'.format(noise_percentile))

    alignment_method = cli_args.align
    logger.debug('Alignment method : {}'.format(alignment_method))

    target = Path(cli_args.target)
    if target.is_file():
        # Single-file mode: operate on the given file, write outputs next to it
        tetrode_files = [target]
        target = target.parent
        logger.debug('Using single file mode with {}'.format(target))
    else:
        tetrode_files = sorted(target.glob('tetrode*.dat'))

    # Convert segment boundaries from seconds to samples; None slices to the
    # end of the recording
    start = int(cli_args.start * fs) if cli_args.start is not None else 0
    end = int(cli_args.end * fs) if cli_args.end is not None else None

    now = dt.today().strftime('%Y%m%d_%H%M%S')
    report_path = target / f'dataman_detect_report_{now}.html'

    for tt, tetrode_file in tqdm(enumerate(tetrode_files), desc='Progress', total=len(tetrode_files), unit='TT'):
        report_string = ''  # General report on shape, lengths etc.
        tqdm.write(f'-> Starting spike detection for {tetrode_file.name}')
        if not tetrode_file.exists():
            logger.info(f'{tetrode_file} not found. Skipping.')
            continue

        matpath = tetrode_file.with_suffix('.mat')
        if matpath.exists() and not cli_args.force:
            logger.error(f'{matpath} already exists. Delete or use --force to overwrite.')
            sys.exit(1)
        elif matpath.exists() and cli_args.force:
            logger.warning(f'{matpath} already exists, deleting it.')
            os.remove(matpath)

        # Memory-map the raw int16 data and reshape to (n_samples, 4 channels)
        raw_memmap = np.memmap(tetrode_file, dtype='int16')
        logger.debug(f'loading {start}:{end} from memmap {raw_memmap}')
        wb = raw_memmap.reshape((-1, 4))[start:end]
        del raw_memmap

        logger.debug('Creating waveform figure...')
        report_string += '<h1>Recording</h1>\n'
        report_string += 'Length: {:.2f} MSamples, {:.2f} minutes'.format(wb.shape[0] / 1e6, wb.shape[0] / fs / 60)
        report_string += f'<h1>{tetrode_file.name}</h1>'
        report_string += str(tetrode_file) + '<br>'
        fig = report.plot_raw(wb)
        report_string += report.fig2html(fig) + '<br>'
        plt.close(fig)
        del fig

        logger.debug('Creating noise estimation figure...')
        # Noise estimation for threshold calculation
        noise = estimate_noise(wb)

        # Calculate the threshold based on all segments with a minimum amount
        # of noise, so that zeroed-out segments are not incorporated
        ne_nz = noise.sum(axis=1) > MINIMUM_NOISE_THRESHOLD
        non_zero_ne = noise[ne_nz, :]
        noise_perc = np.percentile(non_zero_ne, noise_percentile, axis=0)
        ne_min = np.min(noise, axis=0)
        ne_max = np.max(noise, axis=0)
        ne_std = np.std(noise, axis=0)

        # Report noise amplitudes
        report_string += '<h2>Noise estimation</h2>'
        fig = report.plot_noise(noise, thresholds=noise_perc, tetrode=tetrode_file.name)
        report_string += report.fig2html(fig) + '<br>'
        plt.close(fig)
        del fig

        # Per-channel detection threshold: noise percentile scaled by the threshold factor
        thr = noise_perc * stddev_factor
        for ch in range(4):
            info_line = f'<b>Channel {ch}:</b> Thr {thr[ch]:.1f} = {noise_perc[ch]:.1f} uV * {stddev_factor:.1f} nSD' \
                        f' ({noise_percentile}th NE percentile, min: {ne_min[ch]:.1f}, max: {ne_max[ch]:.1f},' \
                        f' std: {ne_std[ch]:.1f})<br>'
            report_string += info_line

        # Spike Timestamp Detection ##################################################
        report_string += '<h2>Spike Detection</h2>'
        timestamps = detect_spikes(wb, thr, align=cli_args.align, fs=fs)
        sps = len(timestamps) / (wb.shape[0] / fs)
        report_string += '<b>{} spikes</b> ({:.1f} sps) <br>'.format(len(timestamps), sps)
        logger.info(f'{tetrode_file.name}: {len(timestamps)} spikes, {sps:.1f} sps')

        # Spike Waveform Extraction ##################################################
        waveforms = extract_waveforms(timestamps, wb, outpath=matpath, s_pre=10, s_post=22, fs=fs)

        # Create waveform plots
        logger.debug('Creating waveform plots')
        density_agg = 'log'
        images = dataman.lib.report.ds_shade_waveforms(waveforms, how=density_agg)
        fig = dataman.lib.report.ds_plot_waveforms(images, density_agg)
        report_string += report.fig2html(fig) + '<br>'
        plt.close(fig)
        del fig

        # Tetrode done!
        report_string += '<hr>'
        with open(report_path, 'a') as rf:
            rf.write(report_string)

    logger.info('Done!')
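# --- Usage sketch -----------------------------------------------------------
# `main` takes an argv-style list, so the detector can be driven
# programmatically as well as from the command line (paths hypothetical):
#
#     main(['/data/session_01', '--threshold', '4.5', '--start', '60', '--end', '120'])
#
# The per-channel threshold derived above is just a low percentile of the
# noise estimate scaled by the --threshold factor; a minimal standalone
# restatement of that rule (not the project's estimate_noise itself):

def _threshold_from_noise(noise, percentile=5, factor=4.5):
    """noise: (n_segments, n_channels) noise estimates -> per-channel thresholds."""
    return np.percentile(noise, percentile, axis=0) * factor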
def main(args):
    parser = argparse.ArgumentParser('Clustering with KlustaKwik')
    parser.add_argument('target', help='Target path, either a directory containing tetrode files '
                                       'or a single tetrodeXX.fet.0')
    parser.add_argument('--KK', help='Path to KlustaKwik executable')
    parser.add_argument('--features', nargs='*', help='List of features to use for clustering')
    parser.add_argument('--config', help='Path to configuration file')
    parser.add_argument('--cluster', action='store_true', help='Directly run KlustaKwik')
    parser.add_argument('--force', action='store_true', help='Overwrite existing files.')
    parser.add_argument('--skip', action='store_true', help='Skip if clu file exists already')
    parser.add_argument('--no_spread', action='store_true', help='Shade report plots without static spread')
    cli_args = parser.parse_args(args)

    # Load the default configuration yaml file
    default_cfg_path = Path(pkg_resources.resource_filename(__name__, '../resources/cluster_defaults.yml')).resolve()
    if not default_cfg_path.exists():
        logger.error('Could not find default config file.')
        raise FileNotFoundError(default_cfg_path)
    logger.debug('Loading default configuration')
    cfg = load_yaml(default_cfg_path)

    # Load the local config file if it exists
    local_cfg_path = Path(pkg_resources.resource_filename(__name__,
                                                          '../resources/cluster_defaults_local.yml')).resolve()
    if local_cfg_path.exists():
        logger.debug('Loading and updating with local configuration')
        local_cfg = load_yaml(local_cfg_path)
        cfg.update(local_cfg)

    # Load a custom config file if one was given
    custom_cfg_path = Path(cli_args.config).resolve() if cli_args.config else None
    if custom_cfg_path:
        if custom_cfg_path.exists():
            logger.debug('Loading and updating with custom configuration')
            cfg.update(load_yaml(custom_cfg_path))
        else:
            raise FileNotFoundError(f'Could not load configuration file {custom_cfg_path}')

    # Command line arguments take highest precedence
    logger.debug('Parsing and updating configuration with CLI arguments')
    cfg.update(vars(cli_args))
    logger.debug(cfg)

    # Try to find the KlustaKwik executable if none was given
    if cli_args.KK is None:
        cli_args.KK = shutil.which('KlustaKwik') or shutil.which('klustakwik') or shutil.which('Klustakwik')
    if cli_args.KK is None and cli_args.cluster:
        raise FileNotFoundError('Could not find the KlustaKwik executable on the path, and none given.')

    # Building KlustaKwik Command
    # 1) Find KlustaKwik executable (dead MClust-based lookup kept for reference)
    # mclust_path = Path('C:/Users/reichler/src/MClustPipeline/MClust/KlustaKwik')
    # pf_system = platform.system()
    # logger.debug(f'Platform: {pf_system}')
    # if pf_system == 'Linux':
    #     kk_executable = mclust_path / cfg['KLUSTAKWIK_PATH_LINUX']
    # elif pf_system == 'Windows':
    #     kk_executable = mclust_path / cfg['KLUSTAKWIK_PATH_WINDOWS']
    # else:
    #     raise NotImplementedError(f'No KlustaKwik executable defined for platform {pf_system}')
    # logger.debug(kk_executable)
    kk_executable = cli_args.KK

    # 2) Find target file stem
    working_dir = Path(cli_args.target).resolve()
    if working_dir.is_file():
        tetrode_files = [working_dir.name]
        working_dir = working_dir.parent
        logger.debug(f'Using single file mode with {tetrode_files[0]}')
    else:
        tetrode_files = sorted([tf.name for tf in working_dir.glob(cfg['TARGET_FILE_GLOB'])])
    logger.debug(f'Working dir: {working_dir}')

    # No parallel/serial execution supported right now
    if len(tetrode_files) > 1:
        raise NotImplementedError('Currently only one target file per call is supported!')
    logger.debug(f'Target found: {tetrode_files}')

    tetrode_file_stem = tetrode_files[0].split('.')[0]
    tetrode_file_elecno = tetrode_files[0].split('.')[-1]

    # 3) Check if the output file already exists
    clu_file = (working_dir / tetrode_file_stem).with_suffix(f'.clu.{tetrode_file_elecno}')
    if clu_file.exists() and not (cli_args.force or cli_args.skip):
        raise FileExistsError('Clu file already exists. Use --force to overwrite.')

    # 4) Combine executable and arguments
    kk_cmd = f'{kk_executable} {tetrode_file_stem} -ElecNo {tetrode_file_elecno}'
    kk_cmd_list = kk_cmd.split(' ')
    logger.debug(f'KK COMMAND: {kk_cmd}')
    logger.debug(f'KK COMMAND LIST: {kk_cmd_list}')

    # Call KlustaKwik and gather output
    # TODO: Use communicate() to interact with KK, i.e. write to log and monitor progress,
    #       see https://stackoverflow.com/questions/21953835/run-subprocess-and-print-output-to-logging
    logger.info('Starting KlustaKwik process')
    if cfg['PRINT_KK_OUTPUT']:
        stdout = None  # inherit the parent's stdout so KK output prints live
    else:
        stdout = subprocess.PIPE

    # EXECUTE KLUSTAKWIK
    if not clu_file.exists() or cli_args.force:
        kk_call = subprocess.run(kk_cmd_list, stderr=subprocess.PIPE, stdout=stdout)
        kk_error = kk_call.returncode
        logger.debug('Writing KlustaKwik log file')
        with open(clu_file.with_suffix('.log'), 'w') as log_file:
            log_file.write(kk_call.stderr.decode('ascii'))

        # Check call return code and output
        if kk_error:
            logger.error(f'KlustaKwik error code: {kk_error}')
            sys.exit(kk_error)
        else:
            logger.debug(f'KlustaKwik successful: {kk_error}')

    # Load the clu file (first line holds the cluster count, hence skiprows=1)
    logger.debug(f'Loading {clu_file}')
    clu_df = pd.read_csv(clu_file, dtype='category', names=['cluster_id'], skiprows=1)
    cluster_labels = clu_df['cluster_id'].cat.categories
    num_clusters = len(cluster_labels)
    logger.info(f'{len(clu_df)} spikes in {num_clusters} clusters')

    # Find all feature .fd files, ordered by modification time
    feature_files = list(working_dir.glob(tetrode_file_stem + '_*.fd'))
    ff_mtimes = [ff.stat().st_mtime for ff in feature_files]
    feature_files = [f for t, f in sorted(zip(ff_mtimes, feature_files))]
    if not len(feature_files):
        raise FileNotFoundError(f'No feature files found in {working_dir}')

    # TODO: Redundant, the feature names are in the .fd file already
    feature_names = [str(ff.name).split(tetrode_file_stem + '_')[1].split('.')[0] for ff in feature_files]
    logger.info(f'Loading features: {feature_names}')

    color_keys = cfg['CLUSTER_COLORS']
    with open(clu_file.with_suffix('.html'), 'w') as crf:
        crf.write('<head></head><body><h1>{}</h1>'.format(clu_file.name))
        for fd_file, fet_name in zip(feature_files, feature_names):
            crf.write('<h3>Feature: {}</h3>\n'.format(fet_name))
            logger.info(f'Generating images for feature {fet_name}')
            if not fd_file.exists():
                continue
            logger.debug(f'Loading {fd_file}')
            mat_fet = h5s.loadmat(str(fd_file), appendmat=False)
            fd_df = pd.DataFrame(mat_fet['FeatureData'])
            # Numerical column names are an issue with datashader, stringify them
            fd_df.rename(columns={c: str(c) for c in fd_df.columns}, inplace=True)
            if not len(clu_df) == len(fd_df):
                raise ValueError(f'Number of cluster labels ({len(clu_df)}) does not match number of spikes '
                                 f'in {fd_file} ({len(fd_df)})')
            fd_df['clu_id'] = clu_df.cluster_id.astype('category')
            logger.debug(f'Feature {fet_name} loaded with {len(fd_df)} spikes, {fd_df.shape[1] - 1} dimensions')

            # Shade each 2D projection of the feature space, colored by cluster id
            images = []
            titles = []
            for cc in combinations(map(str, range(len(fd_df.columns) - 1)), r=2):
                fet_title = f'{fet_name}:{cc[1]} vs {fet_name}:{cc[0]}'
                x_range = (np.percentile(fd_df[cc[0]], 0.01), np.percentile(fd_df[cc[0]], 99.9))
                y_range = (np.percentile(fd_df[cc[1]], 0.01), np.percentile(fd_df[cc[1]], 99.9))
                logger.debug(f'shading {len(fd_df)} points in {fd_df.shape[1] - 1} dimensions')
                canvas = ds.Canvas(plot_width=300, plot_height=300, x_range=x_range, y_range=y_range)
                agg = canvas.points(fd_df, x=cc[0], y=cc[1], agg=ds.count_cat('clu_id'))
                with np.errstate(invalid='ignore'):
                    img = ds_tf.shade(agg, how='log', color_key=color_keys)
                img = img if cli_args.no_spread else ds_tf.spread(img, px=1)
                images.append(img)
                titles.append(fet_title)

            logger.debug(f'Creating plot for {fet_name}')
            fet_fig = ds_plot_features(images, how='log', fet_titles=titles)
            crf.write(fig2html(fet_fig) + '<br>\n')
            plt.close(fet_fig)
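# --- Configuration layering sketch -------------------------------------------
# The config resolution above is plain dict.update() precedence:
# package defaults < local overrides < --config file < CLI arguments.
# An equivalent minimal helper (later mappings win); names hypothetical:

def _layer_configs(*cfgs):
    """Merge configuration dicts left to right; later values take precedence."""
    merged = {}
    for c in cfgs:
        merged.update(c or {})
    return merged

# e.g. cfg = _layer_configs(defaults, local_cfg, custom_cfg, vars(cli_args))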
def main(args):
    parser = argparse.ArgumentParser('Generate .fet and .fd files for features from spike waveforms')
    parser.add_argument('-v', '--verbose', action='store_true', help='Verbose (debug) output')
    parser.add_argument('target', default='.', help='Directory with waveform .mat files.')
    parser.add_argument('-o', '--out_path', help='Output file path. Defaults to current working directory.')
    parser.add_argument('--sampling-rate', type=float, help='Sampling rate. Default: 30000 Hz', default=3e4)
    parser.add_argument('-f', '--force', action='store_true', help='Force overwrite of existing files.')
    parser.add_argument('-a', '--align', help='Alignment method. Default: min', default='min')
    parser.add_argument('-F', '--features', nargs='*', help='Features to use. Default: energy', default=['energy'])
    # Default is a list so the 'all'/'none' checks below see a single token,
    # not the characters of a bare string
    parser.add_argument('--to_fet', nargs='*', help='Features to include in the .fet file. Default: all',
                        default=['all'])
    parser.add_argument('--ignore-prb', action='store_true',
                        help='Do not load channel validity from dead channels in .prb files')
    parser.add_argument('--no-report', action='store_true', help='Do not generate report file (saves time)')
    cli_args = parser.parse_args(args)

    matpath = Path(cli_args.target).resolve()
    if matpath.is_file():
        matfiles = [matpath]
    else:
        matfiles = sorted(list(map(Path.resolve, matpath.glob('tetrode??.mat'))))
    logger.debug(f'Target files: {[mf.name for mf in matfiles]}')
    logger.info('Found {} waveform files'.format(len(matfiles)))
    logger.debug(f'Requested to fet: {cli_args.to_fet}')

    # TODO: per-feature arguments
    sigma = 0.8

    for nt, matfile in tqdm(enumerate(matfiles), total=len(matfiles)):
        outpath = matfile.parent / 'FD'
        if not outpath.exists():
            outpath.mkdir()

        # Load the .prb file if it exists and set channel validity based on dead channels
        prb_path = matfile.with_suffix('.prb')
        if prb_path.exists():
            prb = run_prb(prb_path)
        else:
            logger.warning(f'No probe file found for {matfile.name} and no channel validity given.')
            prb = None

        if prb is None or 'dead_channels' not in prb:
            channel_validity = [1, 1, 1, 1]
        else:
            channel_validity = [int(ch not in prb['dead_channels'])
                                for ch in prb['channel_groups'][0]['channels']]
        logger.debug('Channel validity: {}'.format(channel_validity) +
                     ('' if all(channel_validity) else f', {4 - sum(channel_validity)} dead channel(s)'))

        with h5py.File(matfile, 'r') as hf:
            waveforms = np.array(hf['spikes'], dtype=PRECISION).reshape([N_SAMPLES, N_CHANNELS, -1])
            timestamps = np.array(hf['index'], dtype='double')
        gauss = gaussian_filter(waveforms, sigma)  # smoothed copy, currently unused in this excerpt
        # indices = timestamps * sampling_rate / 1e4

        features = {}

        # Allow calculating all available features
        if len(cli_args.features) == 1 and cli_args.features[0].lower() == 'all':
            cli_args.features = AVAILABLE_FEATURES

        for fet_name in map(str.lower, cli_args.features):
            if fet_name == 'energy':
                logger.debug(f'Calculating {fet_name} feature')
                features['energy'] = scale_feature(feature_energy(waveforms))
            elif fet_name == 'energy24':
                logger.debug(f'Calculating {fet_name} feature')
                features['energy24'] = scale_feature(feature_energy24(waveforms))
            elif fet_name == 'peak':
                logger.debug(f'Calculating {fet_name} feature')
                features['peak'] = feature_peak(waveforms)
            elif fet_name == 'cpca':
                logger.debug(f'Calculating {fet_name} feature')
                cpca = scale_feature(feature_cPCA(waveforms))
                logger.debug('cPCA shape {}'.format(cpca.shape))
                features['cPCA'] = cpca
            elif fet_name == 'cpca24':
                logger.debug(f'Calculating {fet_name} feature')
                cpca24 = scale_feature(feature_cPCA24(waveforms))
                logger.debug('cPCA24 shape {}'.format(cpca24.shape))
                features['cPCA24'] = cpca24
            elif fet_name == 'chwpca':
                logger.debug(f'Calculating {fet_name} feature')
                chwpca = scale_feature(feature_chwPCA(waveforms))
                logger.debug('chwPCA shape {}'.format(chwpca.shape))
                features['chwPCA'] = chwpca
            else:
                raise NotImplementedError('Unknown feature: {}'.format(fet_name))

        # TODO:
        # fet_cpca_4 = fet_cpca[:, :4]
        # # Position feature
        # n_bytes = [250154314, 101099824, 237970294]
        # fet_pos = feature_position(matpath / 'XY_data.mat', dat_offsets=n_bytes,
        #                            timestamps=timestamps, indices=indices)

        # Generate the .fet file used for clustering
        # TODO: Best move this out into the cluster module?
        if 'none' in map(str.lower, cli_args.to_fet):
            logger.warning('Skipping fet file generation')
        else:
            fet_file_path = outpath / matfile.with_suffix('.fet.0').name
            if len(cli_args.to_fet) == 1 and cli_args.to_fet[0].lower() == 'all':
                logger.debug('Writing all features to fet file.')
                included_features = list(map(str.lower, features.keys()))
            else:
                included_features = [fn for fn in map(str.lower, features.keys())
                                     if fn in list(map(str.lower, cli_args.to_fet))]
            logger.info(f'Writing features {list(included_features)} to .fet')
            fet_data = [fd for fn, fd in features.items() if fn.lower() in included_features]
            logger.debug(f'Writing .fet file {fet_file_path}')
            write_features_fet(feature_data=fet_data, outpath=fet_file_path)

        # Write a .fd file for each feature
        for fet_name, fet_data in features.items():
            logger.debug(f'Writing feature {fet_name}.fd file')
            write_feature_fd(feature_names=fet_name, feature_data=fet_data,
                             timestamps=timestamps, outpath=outpath,
                             tetrode_path=matfile, channel_validity=channel_validity)

        logger.debug('Generating waveform graphic')
        with open(matfile.with_suffix('.html'), 'w') as frf:
            frf.write('<head></head><body><h1>{}</h1>'.format(matfile.name))
            frf.write('<h2>Waveforms (n={})</h2>'.format(waveforms.shape[2]))
            density_agg = 'log'
            with np.errstate(invalid='ignore'):  # ignore some matplotlib colormap usage errors
                images = ds_shade_waveforms(waveforms, how=density_agg)
            fig = ds_plot_waveforms(images, density_agg)
            frf.write(fig2html(fig) + '<br>')
            del fig

            for fet_name, fet_data in features.items():
                frf.write('<h3>Feature: {}</h3>\n'.format(fet_name))
                df_fet = pd.DataFrame(fet_data)
                # Numerical column names are an issue with datashader, stringify 'em
                df_fet.rename(columns={k: str(k) for k in df_fet.columns}, inplace=True)
                df_fet['time'] = timestamps
                fet_columns = df_fet.columns[:-1]

                # Features vs. features
                images = []
                titles = []
                for cc in list(combinations(fet_columns, 2)):
                    fet_title = f'{fet_name}:{cc[1]} vs {fet_name}:{cc[0]}'
                    logger.debug(f'plotting feature {fet_title}')
                    # Calculate display limits, try to exclude outliers
                    # TODO: correct axis labeling
                    perc_lower = 0.05
                    perc_upper = 99.9
                    x_range = (np.percentile(df_fet[cc[0]], perc_lower), np.percentile(df_fet[cc[0]], perc_upper))
                    y_range = (np.percentile(df_fet[cc[1]], perc_lower), np.percentile(df_fet[cc[1]], perc_upper))
                    with np.errstate(invalid='ignore'):
                        shade = ds_shade_feature(df_fet[[cc[0], cc[1]]], x_range=x_range, y_range=y_range,
                                                 color_map='inferno')
                    images.append(shade)
                    titles.append(fet_title)

                fet_fig = ds_plot_features(images, how='log', fet_titles=titles)
                frf.write(fig2html(fet_fig) + '<br>\n')
                del fet_fig

                # Features over time
                t_images = []
                t_titles = []
                x_range = (0, df_fet['time'].max())
                # Calculate display limits, try to exclude outliers; computed
                # per column inside the loop
                # TODO: correct axis labeling
                perc_lower = 0.1
                perc_upper = 99.9
                for cc in fet_columns:
                    t_title = f'{fet_name}:{cc} vs. time'
                    logger.debug(f'plotting {t_title}')
                    y_range = (np.percentile(df_fet[cc], perc_lower), np.percentile(df_fet[cc], perc_upper))
                    with np.errstate(invalid='ignore'):
                        shade = ds_shade_feature(df_fet[['time', cc]], x_range=x_range, y_range=y_range,
                                                 color_map='viridis')
                    t_images.append(shade)
                    t_titles.append(t_title)

                t_fig = ds_plot_features(t_images, how='log', fet_titles=t_titles)
                frf.write(fig2html(t_fig) + '<br>\n')
                del t_fig

            frf.write('<hr>\n')
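# --- .fet format sketch -------------------------------------------------------
# For orientation: KlustaKwik's classic .fet layout is a first line holding the
# number of feature columns, followed by one whitespace-separated row per
# spike. A minimal writer in that spirit (a sketch, not the project's
# write_features_fet; assumes feature_data is a list of (n_spikes, n_dims)
# arrays):

def _write_fet_sketch(feature_data, outpath):
    data = np.hstack(feature_data)     # (n_spikes, total_dims)
    with open(outpath, 'w') as f:
        f.write(f'{data.shape[1]}\n')  # header: number of feature columns
        np.savetxt(f, data, fmt='%f')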
def run_kk(params, run_kk=True):
    cfg, maxc, target_path = params
    tt_fname = target_path.name
    tetrode_file_stem = tt_fname.split('.')[0]
    tetrode_file_elecno = tt_fname.split('.')[-1]
    working_dir = target_path.parent
    logger.debug(f'Tetrode name: {tt_fname}, stem: {tetrode_file_stem}, ElecNo: {tetrode_file_elecno}')

    clu_file = working_dir / (tetrode_file_stem + f'.clu.{tetrode_file_elecno}')
    if clu_file.exists() and cfg['skip']:
        logger.error(f'Clu file {clu_file} exists. Skipping.')
        run_kk = False

    # Read in feature validity (which features KK should use)
    validity_path = target_path.with_suffix('.validity')
    if validity_path.exists():
        with open(validity_path) as vfp:
            validity_string = vfp.readline().strip()  # strip newline so the command splits cleanly
        logger.debug(f'Feature validity: {validity_string}')
    else:
        # Without an explicit validity file, omit -UseFeatures so KlustaKwik
        # falls back to its default of using all features
        logger.warning('No explicit feature validity given, falling back to default = all used.')
        validity_string = None

    # Combine executable and arguments
    kk_executable = cfg['kk_executable']
    kk_cmd = f'{kk_executable} {tetrode_file_stem} -ElecNo {tetrode_file_elecno}'
    if validity_string:
        kk_cmd += f' -UseFeatures {validity_string}'
    kk_cmd += f' -MaxPossibleClusters {maxc}'
    if cfg['KKv3']:
        kk_cmd += ' -UseDistributional 0'
    kk_cmd_list = kk_cmd.split(' ')
    logger.debug(f'KK COMMAND: {kk_cmd}')
    logger.debug(f'KK COMMAND LIST: {kk_cmd_list}')

    # Call KlustaKwik and gather output
    # TODO: Use communicate() to interact with KK, i.e. write to log and monitor progress,
    #       see https://stackoverflow.com/questions/21953835/run-subprocess-and-print-output-to-logging
    logger.info('Starting KlustaKwik process')
    if cfg['PRINT_KK_OUTPUT']:
        stdout = None  # inherit the parent's stdout so KK output prints live
    else:
        stdout = subprocess.PIPE

    if run_kk:
        kk_call = subprocess.run(kk_cmd_list, stderr=subprocess.STDOUT, stdout=stdout)
        kk_error = kk_call.returncode
        logger.debug('Writing KlustaKwik log file')
        logger.debug('Clu File: ' + str(clu_file))
        if kk_call.stdout is not None:
            with open(clu_file.with_suffix('.log'), 'w') as log_file:
                log_file.write(kk_call.stdout.decode('ascii'))
        else:
            logger.warning('Missing stdout, not writing log file!')

        # Check call return code and output
        if kk_error:
            logger.error(f'KlustaKwik error code: {kk_error}')
            sys.exit(kk_error)
        else:
            logger.debug(f'KlustaKwik successful: {kk_error}')

    # Load the clu file (first line holds the cluster count, hence skiprows=1)
    logger.debug(f'Loading {clu_file}')
    clu_df = pd.read_csv(clu_file, dtype='category', names=['cluster_id'], skiprows=1)
    cluster_labels = clu_df['cluster_id'].cat.categories
    num_clusters = len(cluster_labels)
    logger.info(f'{len(clu_df)} spikes in {num_clusters} clusters')

    # Find all feature .fd files, ordered by modification time
    feature_files = list(working_dir.glob(tetrode_file_stem + '_*.fd'))
    ff_mtimes = [ff.stat().st_mtime for ff in feature_files]
    feature_files = [f for t, f in sorted(zip(ff_mtimes, feature_files))]
    if not len(feature_files):
        raise FileNotFoundError(f'No feature files found in {working_dir}')

    # TODO: Redundant, the feature names are in the .fd file already
    feature_names = [str(ff.name).split(tetrode_file_stem + '_')[1].split('.')[0] for ff in feature_files]
    logger.info(f'Loading features: {feature_names}')

    color_keys = cfg['CLUSTER_COLORS']
    with open(clu_file.with_suffix('.html'), 'w') as crf:
        crf.write('<head></head><body><h1>{}</h1>'.format(clu_file.name))
        for fd_file, fet_name in zip(feature_files, feature_names):
            crf.write('<h3>Feature: {}</h3>\n'.format(fet_name))
            logger.info(f'Generating images for feature {fet_name}')
            if not fd_file.exists():
                continue
            logger.debug(f'Loading {fd_file}')
            mat_fet = h5s.loadmat(str(fd_file), appendmat=False)
            fd_df = pd.DataFrame(mat_fet['FeatureData'])
            # Numerical column names are an issue with datashader, stringify them
            fd_df.rename(columns={c: str(c) for c in fd_df.columns}, inplace=True)
            if not len(clu_df) == len(fd_df):
                raise ValueError(f'Number of cluster labels ({len(clu_df)}) does not match number of spikes '
                                 f'in {fd_file} ({len(fd_df)})')
            fd_df['clu_id'] = clu_df.cluster_id.astype('category')
            logger.debug(f'Feature {fet_name} loaded with {len(fd_df)} spikes, {fd_df.shape[1] - 1} dimensions')

            # Shade each 2D projection of the feature space, colored by cluster id
            images = []
            titles = []
            for cc in combinations(map(str, range(len(fd_df.columns) - 1)), r=2):
                fet_title = f'x: {fet_name}:{cc[0]} vs y: {fet_name}:{cc[1]}'
                x_range = (np.percentile(fd_df[cc[0]], 0.01), np.percentile(fd_df[cc[0]], 99.9))
                y_range = (np.percentile(fd_df[cc[1]], 0.01), np.percentile(fd_df[cc[1]], 99.9))
                logger.debug(f'shading {len(fd_df)} points in {fd_df.shape[1] - 1} dimensions')
                canvas = ds.Canvas(plot_width=300, plot_height=300, x_range=x_range, y_range=y_range)
                try:
                    agg = canvas.points(fd_df, x=cc[0], y=cc[1], agg=ds.count_cat('clu_id'))
                    with np.errstate(invalid='ignore'):
                        img = ds_tf.shade(agg, how='log', color_key=color_keys)
                    img = img if cfg['no_spread'] else ds_tf.spread(img, px=1)
                except ZeroDivisionError:
                    img = None
                images.append(img)
                titles.append(fet_title)

            logger.debug(f'Creating plot for {fet_name}')
            fet_fig = ds_plot_features(images, how='log', fet_titles=titles)
            crf.write(fig2html(fet_fig) + '<br>\n')
            plt.close(fet_fig)
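# --- Parallel invocation sketch ------------------------------------------------
# `run_kk` takes a single (cfg, maxc, target_path) tuple, which makes it easy to
# fan out over tetrodes with a process pool (cfg and paths hypothetical):
#
#     from multiprocessing import Pool
#
#     targets = sorted(Path('/data/session_01').glob('tetrode??.fet.*'))
#     jobs = [(cfg, 12, t) for t in targets]   # maxc=12 -> -MaxPossibleClusters 12
#     with Pool(processes=4) as pool:
#         pool.map(run_kk, jobs)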