def debug_callback(context):
    """
    Plot, then close, the 2D rate map of the single cell described by ``context``.

    ``context`` is read for: population, this_selectivity_type_name,
    this_selectivity_type, gid, fig_options, save_fig, arena_x_mesh,
    arena_y_mesh, rate_map, env and norm_u_arc_distance.
    """
    from dentate.plot import plot_2D_rate_map, close_figure

    label = '%s %s cell %i' % (
        context.population, context.this_selectivity_type_name, context.gid)

    # The per-cell saveFig name is set on a shallow copy, so the attribute is
    # not reassigned on context.fig_options itself.
    options = copy.copy(context.fig_options)
    if context.save_fig is not None:
        options.saveFig = '%s %s' % (context.save_fig, label)

    peak_rates = context.env.stimulus_config['Peak Rate']
    cell_peak_rate = peak_rates[context.population][context.this_selectivity_type]
    plot_title = '%s\nNormalized cell position: %.3f' % (
        label, context.norm_u_arc_distance)

    figure = plot_2D_rate_map(x=context.arena_x_mesh,
                              y=context.arena_y_mesh,
                              rate_map=context.rate_map,
                              peak_rate=cell_peak_rate,
                              title=plot_title,
                              **options())
    close_figure(figure)
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, subset_seed, arena_id, populations,
         io_size, cache_size, verbose, debug, show_fig, save_fig, save_fig_dir,
         font_size, fig_size, colormap, fig_format):
    """
    Plot the spatial distribution of input selectivity features and per-module
    summed rate maps.

    For each requested population, every selectivity namespace in
    ``selectivity_path`` whose name contains ``'<selectivity_namespace> <arena_id>'``
    is read in parallel. Each rank:
      - builds the rate map of every cell it receives;
      - histograms per-cell and per-component attributes over (U, V) arc
        distance bins;
      - sums rate maps per module.
    Rank 0 gathers the results, plots attribute histograms normalized by the
    cell/component count histograms, and plots the summed rate map of each
    module.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; bin width along both U and V arc distance axes
    :param selectivity_path: str
    :param selectivity_namespace: str; prefix of the selectivity namespaces to read
    :param subset_seed: int; for reproducible choice of gids to plot individual rate maps
        (NOTE(review): not referenced anywhere in this function body)
    :param arena_id: str
    :param populations: tuple of str; when empty, defaults to
        ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')
    :param io_size: int; -1 means use all ranks
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plot each cell's rate map on rank 0 and stop reading each
        namespace after the iteration with index 10
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_size: assigned to fig_options.figSize
    :param colormap: str or None; applied to the attribute histograms when given
    :param fig_format: str
    :raises RuntimeError: on rank 0 for one of these conditions:
        - a population is missing from the selectivity or coords file;
        - a population has no selectivity type probabilities or no matching
          namespaces;
        - cell coordinates are incomplete;
        - the reference bound attributes are missing;
        - the arena is unknown.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # Rank 0 validates the requested populations and collects the selectivity
    # namespaces belonging to this arena. The result is broadcast so that every
    # rank iterates over the same namespaces below.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided selectivity_path: %s' % (population, selectivity_path))
            if population not in env.stimulus_config['Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(this_namespace)
            if len(valid_selectivity_namespaces[population]) == 0:
                raise RuntimeError(
                    'plot_input_selectivity_features: no selectivity data in arena: %s found '
                    'for specified population: %s in provided selectivity_path: %s' %
                    (arena_id, population, selectivity_path))
    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces, root=0)
    # Inverse mapping of env.selectivity_types: type id -> type name.
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    # Rank 0 checks that each population has one coordinate entry per cell. It
    # reads the reference U/V arc distance bounds from the first population's
    # namespace attributes. These bounds define the histogram bin edges on all
    # ranks.
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population][distances_namespace]
                        ['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'plot_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s' %
                        (unique_gid_count, pop_size, population, coords_path))
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception as exc:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' % (distances_namespace, population, coords_path)) from exc
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception as exc:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' % (distances_namespace, population, coords_path)) from exc
    reference_u_arc_distance_bounds = comm.bcast(reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(reference_v_arc_distance_bounds, root=0)

    # Half a bin is added to the upper bound so np.arange keeps the upper bound as
    # the last edge when the span is a whole number of bins.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2., bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2., bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        # %-formatting (not '+') so a None config_prefix cannot raise TypeError here.
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s' %
            (arena_id, '%s/%s' % (config_prefix, config)))
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = get_2D_arena_spatial_mesh(
            arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:
        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = bcast_cell_attributes(
            coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]
        if rank == 0:
            logger.info(
                'Reading %i cell positions for population %s took %.2f s' %
                (len(u_distances_by_gid), population, time.time() - start_time))

        for this_selectivity_namespace in valid_selectivity_namespaces[population]:
            if rank == 0:
                logger.info('Reading from %s namespace for population %s...' %
                            (this_selectivity_namespace, population))
            # Reset per namespace, so a rank that receives no cell here does not
            # keep a stale (or unbound) type name from an earlier namespace.
            this_selectivity_type_name = None
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(lambda: np.zeros_like(arena_x_mesh))
            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path, population, namespace=this_selectivity_namespace,
                comm=comm, io_size=io_size, cache_size=cache_size)
            for iter_count, (gid, selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[this_selectivity_type]
                    input_cell_config = get_input_cell_config(
                        selectivity_type=this_selectivity_type,
                        selectivity_type_names=selectivity_type_names,
                        selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh, y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = \
                        input_cell_config.gather_attributes()
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(np.mean(rate_map))
                    if component_count > 0:
                        # Each component is binned at its cell's own (U, V) position.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(this_component_attrs):
                            gathered_component_attributes[attr_name].extend(attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = '%s %s cell %i' % (
                            population, this_selectivity_type_name, gid)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh, y=arena_y_mesh, rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate'][population][this_selectivity_type],
                            title='%s\nModule: %i' % (fig_title, this_module_id),
                            **fig_options())
                        # Close each per-cell figure, as every other plot here does,
                        # so debug runs do not accumulate open figures.
                        close_figure(fig)
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                # Checked independently of gid, so ranks that received no cell in
                # this iteration stop at the same iteration count as the others.
                if debug and iter_count >= 10:
                    break

            cell_count_hist, _, _ = np.histogram2d(
                u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Per-rank sums of each attribute per (U, V) bin. They are later plotted
            # with the matching count histogram passed as norm.
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                    weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_component, v_distances_by_component,
                    bins=[u_edges, v_edges], weights=gathered_component_attributes[key])

            # Every rank takes part in these collectives, in the same order.
            gid_count = comm.gather(gid_count, root=0)
            gathered_type_names = comm.gather(this_selectivity_type_name, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist, root=0)
            gathered_component_attr_hist = comm.gather(gathered_component_attr_hist, root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module, root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                # The gathered list is in rank order: use rank 0's own name if it
                # saw a cell, otherwise the first rank's that did.
                this_selectivity_type_name = next(
                    (name for name in gathered_type_names if name is not None), None)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key], each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key], each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = np.add(
                            merged_rate_map_sum_by_module[this_module_id],
                            each_rate_map_sum_by_module[this_module_id])

                logger.info('Processing %i %s %s cells took %.2f s' %
                            (gid_count, population, this_selectivity_type_name,
                             time.time() - start_time))

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key], x_edges=u_edges, y_edges=v_edges,
                        norm=cell_count_hist, ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True, **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key], x_edges=u_edges, y_edges=v_edges,
                        norm=component_count_hist, ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True, **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    fig_title = '%s %s Module %i summed rate maps' % \
                        (population, this_selectivity_type_name, this_module_id)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh, y=arena_y_mesh,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title='%s %s summed rate maps\nModule %i' %
                        (population, this_selectivity_type_name, this_module_id),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, spatial_resolution, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """
    Plot the spatial distribution of input selectivity features, plus per-module
    summed rate maps with the field offsets of the contributing cells.

    For each requested population, every selectivity namespace in
    ``selectivity_path`` whose name contains ``'<selectivity_namespace> <arena_id>'``
    is read in parallel. Each rank:
      - builds the rate map of every cell it receives;
      - histograms per-cell and per-component attributes over (U, V) arc
        distance bins;
      - sums rate maps, counts cells and collects 'X Offset'/'Y Offset' values
        per module.
    Rank 0 combines the results and plots:
      - attribute histograms normalized by the count histograms;
      - one summed rate map per module, with the module's offsets passed as
        x0/y0.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; bin width along both U and V arc distance axes
    :param selectivity_path: str
    :param selectivity_namespace: str; prefix of the selectivity namespaces to read
    :param spatial_resolution: resolution of the arena mesh; when None,
        env.stimulus_config['Spatial Resolution'] is used
    :param arena_id: str
    :param populations: tuple of str; when empty, defaults to
        ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')
    :param io_size: int; -1 means use all ranks
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plot each cell's rate map on rank 0 and stop reading each
        namespace after the iteration with index 10
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_size: assigned to fig_options.figSize
    :param colormap: str or None; applied to the attribute histograms when given
    :param fig_format: str
    :raises RuntimeError: on rank 0 for one of these conditions:
        - a population is missing from the selectivity or coords file;
        - a population has no selectivity type probabilities or no matching
          namespaces;
        - cell coordinates are incomplete;
        - the reference bound attributes are missing;
        - the arena is unknown.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = f'{save_fig} {arena_id}'
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # Rank 0 validates the requested populations and collects the selectivity
    # namespaces belonging to this arena. The result is broadcast so that every
    # rank iterates over the same namespaces below.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided selectivity_path: {selectivity_path}')
            if population not in env.stimulus_config['Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    f'population: {population}')
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(this_namespace)
            if len(valid_selectivity_namespaces[population]) == 0:
                raise RuntimeError(
                    f'plot_input_selectivity_features: no selectivity data in arena: {arena_id} found '
                    f'for specified population: {population} in provided selectivity_path: {selectivity_path}'
                )
    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces, root=0)
    # Inverse mapping of env.selectivity_types: type id -> type name.
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    # Rank 0 checks that each population has one coordinate entry per cell. It
    # reads the reference U/V arc distance bounds from the first population's
    # namespace attributes. These bounds define the histogram bin edges on all
    # ranks.
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided coords_path: {coords_path}')
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population][distances_namespace]
                        ['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: only {unique_gid_count}/{pop_size} unique cell indexes found '
                        f'for specified population: {population} in provided coords_path: {coords_path}'
                    )
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception as exc:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        ) from exc
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception as exc:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        ) from exc
    reference_u_arc_distance_bounds = comm.bcast(reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(reference_v_arc_distance_bounds, root=0)

    # Half a bin is added to the upper bound so np.arange keeps the upper bound as
    # the last edge when the span is a whole number of bins.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2., bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2., bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )

    if spatial_resolution is None:
        spatial_resolution = env.stimulus_config['Spatial Resolution']
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = get_2D_arena_spatial_mesh(
            arena=arena, spatial_resolution=spatial_resolution)
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:
        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = bcast_cell_attributes(
            coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]
        if rank == 0:
            logger.info(
                f'Reading {len(u_distances_by_gid)} cell positions for population {population} took '
                f'{time.time() - start_time:.2f} s')

        for this_selectivity_namespace in valid_selectivity_namespaces[population]:
            if rank == 0:
                logger.info(
                    f'Reading from {this_selectivity_namespace} namespace for population {population}...'
                )
            # Reset per namespace, so a rank that receives no cell here does not
            # keep a stale (or unbound) type name from an earlier namespace.
            this_selectivity_type_name = None
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(lambda: np.zeros_like(arena_x_mesh))
            count_by_module = defaultdict(int)
            start_time = time.time()
            x0_list_by_module = defaultdict(list)
            y0_list_by_module = defaultdict(list)
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path, population, namespace=this_selectivity_namespace,
                comm=comm, io_size=io_size, cache_size=cache_size)
            for iter_count, (gid, selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[this_selectivity_type]
                    input_cell_config = get_input_cell_config(
                        selectivity_type=this_selectivity_type,
                        selectivity_type_names=selectivity_type_names,
                        selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh, y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = \
                        input_cell_config.gather_attributes()
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(np.mean(rate_map))
                    if component_count > 0:
                        # Each component is binned at its cell's own (U, V) position.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(this_component_attrs):
                            gathered_component_attributes[attr_name].extend(attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = f'{population} {this_selectivity_type_name} cell {gid}'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh, y=arena_y_mesh, rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate'][population][this_selectivity_type],
                            title=f'{fig_title}\nModule: {this_module_id}',
                            **fig_options())
                        # Close each per-cell figure, as every other plot here does,
                        # so debug runs do not accumulate open figures.
                        close_figure(fig)
                    x0_list_by_module[this_module_id].append(selectivity_attr_dict['X Offset'])
                    y0_list_by_module[this_module_id].append(selectivity_attr_dict['Y Offset'])
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                    count_by_module[this_module_id] += 1
                # Checked independently of gid, so ranks that received no cell in
                # this iteration stop at the same iteration count as the others.
                if debug and iter_count >= 10:
                    break

            if rank == 0:
                logger.info(
                    f'Done reading from {this_selectivity_namespace} namespace for population {population}...'
                )

            cell_count_hist, _, _ = np.histogram2d(
                u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Per-rank sums of each attribute per (U, V) bin. They are later plotted
            # with the matching count histogram passed as norm.
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                    weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_component, v_distances_by_component,
                    bins=[u_edges, v_edges], weights=gathered_component_attributes[key])

            # Every rank takes part in these collectives, in the same order.
            gid_count = comm.gather(gid_count, root=0)
            gathered_type_names = comm.gather(this_selectivity_type_name, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist, root=0)
            gathered_component_attr_hist = comm.gather(gathered_component_attr_hist, root=0)
            x0_list_by_module = dict(x0_list_by_module)
            y0_list_by_module = dict(y0_list_by_module)
            x0_list_by_module = comm.reduce(x0_list_by_module, op=mpi_op_merge_list_dict, root=0)
            y0_list_by_module = comm.reduce(y0_list_by_module, op=mpi_op_merge_list_dict, root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module, root=0)
            count_by_module = dict(count_by_module)
            count_by_module = comm.reduce(count_by_module, op=mpi_op_merge_count_dict, root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                # The gathered list is in rank order: use rank 0's own name if it
                # saw a cell, otherwise the first rank's that did.
                this_selectivity_type_name = next(
                    (name for name in gathered_type_names if name is not None), None)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key], each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key], each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = np.add(
                            merged_rate_map_sum_by_module[this_module_id],
                            each_rate_map_sum_by_module[this_module_id])

                logger.info(
                    f'Processing {gid_count} {population} {this_selectivity_type_name} cells '
                    f'took {time.time() - start_time:.2f} s')

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key], x_edges=u_edges, y_edges=v_edges,
                        norm=cell_count_hist, ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True, **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key], x_edges=u_edges, y_edges=v_edges,
                        norm=component_count_hist, ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True, **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    num_cells = count_by_module[this_module_id]
                    x0 = np.concatenate(x0_list_by_module[this_module_id])
                    y0 = np.concatenate(y0_list_by_module[this_module_id])
                    fig_title = f'{population} {this_selectivity_type_name} Module {this_module_id} rate map'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh, y=arena_y_mesh, x0=x0, y0=y0,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title=(f'{population} {this_selectivity_type_name} rate map\n'
                               f'Module {this_module_id} ({num_cells} cells)'),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, output_path, arena_id, populations,
         use_noise_gen, io_size, chunk_size, value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig, save_fig_dir, font_size, fig_format,
         dry_run):
    """
    Generate input selectivity features for every cell of the requested populations and append them
    to ``output_path`` as cell attributes, using one namespace per selectivity type
    (``'<Type name> Selectivity <arena_id>'``). Optionally gathers per-population summaries to rank 0
    for plotting and/or interactive inspection.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file); must contain ``distances_namespace`` with 'U Distance' and
        'V Distance' attributes and the 'Reference U Min' / 'Reference U Max' namespace attributes
    :param distances_namespace: str
    :param output_path: str; created (seeded with '/H5Types' from coords_path) if it does not exist
    :param arena_id: str; key into the 'Arena' section of the stimulus configuration
    :param populations: tuple of str; all populations in coords_path when empty
    :param use_noise_gen: bool; when True, an MPINoiseGenerator is created per module (or a single one
        under key -1 for non-modular populations) and passed to generate_input_selectivity_features
    :param io_size: int; -1 means use all ranks
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int; accumulated features are written every max(1, write_size // comm.size)
        iterations
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or
        plotting
    :param interactive: bool; update the module-level ``context`` with local variables
    :param debug: bool; enables debug_callback and stops each population after iteration debug_count
    :param debug_count: int
    :param plot: bool; forced to True when save_fig is given
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; when True, nothing is created or written to output_path
    :raises RuntimeError: on rank 0, for a missing output_path, an unknown population or arena, a
        population without selectivity type probabilities, duplicate/missing cell indexes in
        coords_path, or missing reference U bounds
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix, template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    # Saving figures implies generating them.
    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure
        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed a new output file with the HDF5 type definitions from the coordinates file.
            # Context managers guarantee both handles are closed even if the copy fails.
            with h5py.File(coords_path, 'r') as input_file, h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Validate the requested populations and read their reference U-distance bounds on rank 0 only,
    # then broadcast the bounds to all ranks.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config['Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population][distances_namespace]
                        ['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s' %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                        coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                        coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception as err:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' % (distances_namespace, population, coords_path)) from err
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(reference_u_arc_distance_bounds_dict, root=0)

    # Invert env.selectivity_types (name -> id) into id -> name, and derive the output namespace for
    # each selectivity type name: first character upper-cased, plus ' Selectivity <arena_id>'.
    selectivity_type_names = {val: key for (key, val) in viewitems(env.selectivity_types)}
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]

    # The spatial mesh is computed once on rank 0 and broadcast so every rank uses identical arrays.
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config, local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}

    # Number of iterations between incremental writes on each rank.
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(f'Generating input selectivity features for population {population}...')

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[population]

        modular = True
        if population in env.stimulus_config['Non-modular Place Selectivity Populations']:
            modular = False

        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                # One noise generator per module, with a margin of half that module's field width.
                for module_id in range(env.stimulus_config['Number Modules']):
                    margin = round(selectivity_config.place_module_field_widths[module_id] / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(arena, margin=margin)
                    noise_gen = MPINoiseGenerator(comm=comm, bounds=(arena_x_bounds, arena_y_bounds),
                                                  tile_rank=comm.rank, bin_size=0.5, mask_fraction=0.99,
                                                  seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                # A single generator (key -1) using the mean field width across modules.
                margin = round(np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(comm=comm, bounds=(arena_x_bounds, arena_y_bounds),
                                                       tile_rank=comm.rank, bin_size=0.5, mask_fraction=0.99,
                                                       seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path, population, namespace=distances_namespace,
                                                comm=comm, io_size=io_size, cache_size=cache_size)

        selectivity_attr_dict = dict((key, dict()) for key in env.selectivity_types)
        for iter_count, (gid, distances_attr_dict) in enumerate(distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                # No cell on this rank for this iteration; still take part in the noise generators'
                # collective operations (presumably matched by the calls that ranks with a cell make
                # inside generate_input_selectivity_features -- TODO confirm).
                if noise_gen_dict is not None:
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set([]), op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(f'Rank {rank} generating selectivity features for gid {gid}...')
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                # Normalize the U position to the population's reference [min, max] bounds.
                norm_u_arc_distance = ((u_arc_distance - reference_u_arc_distance_bounds[0]) /
                                       (reference_u_arc_distance_bounds[1] - reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                    generate_input_selectivity_features(env, population, arena,
                                                        arena_x_mesh, arena_y_mesh,
                                                        gid, (norm_u_arc_distance, v_arc_distance),
                                                        selectivity_config, selectivity_type_names,
                                                        selectivity_type_namespaces,
                                                        noise_gen_dict=noise_gen_dict,
                                                        rate_map_sum=this_rate_map_sum,
                                                        debug=(debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1

            # Every rank advances each noise generator's tile_rank by one (mod comm.size) per iteration.
            if noise_gen_dict is not None:
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()

            # Periodically report progress and flush accumulated features to output_path.
            if (iter_count > 0 and iter_count % write_every == 0) or (debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict, root=0, op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info('generated selectivity features for %i/%i %s %s cells in %.2f s' %
                                    (selectivity_gid_count[selectivity_type_name], total_gid_count,
                                     population, selectivity_type_name, (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[selectivity_type_name]
                        append_cell_attributes(output_path, population,
                                               selectivity_attr_dict[selectivity_type_name],
                                               namespace=selectivity_type_namespace, comm=comm,
                                               io_size=io_size, chunk_size=chunk_size,
                                               value_chunk_size=value_chunk_size)
                        req.wait()
                    # Start a fresh accumulator so already-written cells are not appended again.
                    del selectivity_attr_dict
                    selectivity_attr_dict = dict((key, dict()) for key in env.selectivity_types)
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        # Final progress report and flush of whatever remains after the last periodic write.
        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict, root=0, op=mpi_op_merge_count_dict)
        req.wait()
        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info('generated selectivity features for %i/%i %s %s cells in %.2f s' %
                            (selectivity_gid_count[selectivity_type_name], total_gid_count,
                             population, selectivity_type_name, (time.time() - start_time)))

        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[selectivity_type_name]
                append_cell_attributes(output_path, population,
                                       selectivity_attr_dict[selectivity_type_name],
                                       namespace=selectivity_type_namespace, comm=comm,
                                       io_size=io_size, chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
                comm.reduce(pop_norm_distances[population], root=0, op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum, root=0, op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict, root=0, op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict, root=0, op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                # Histogram of normalized cell positions per population.
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {population} normalized distances histogram',
                                    fig=fig, **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

                # Summed rate maps per population and selectivity type. The title/file name must use
                # the inner loop variable; the stale this_selectivity_type_name would mislabel them.
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[population]:
                        fig_title = f'{population} {selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population][selectivity_type_name],
                            title=f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                        close_figure(fig)

                # Scatter of field offsets per population.
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}', fig=fig, **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())