def rate_maps_from_features(env, pop_name, input_features_path,
                            input_features_namespace, cell_index_set,
                            time_range=None, n_trials=1):
    """Compute rate maps along the configured trajectory from stored selectivity features.

    Reads the selectivity attributes of the selected cells of ``pop_name``
    from ``input_features_path``, builds an input cell configuration for
    each cell and evaluates its rate map at the (x, y) positions of the
    trajectory ``env.trajectory_id`` in arena ``env.arena_id``.

    :param env: environment object; ``stimulus_config``, ``arena_id``,
        ``trajectory_id``, ``selectivity_types``, ``Populations``, ``comm``
        and ``io_size`` are read from it
    :param pop_name: str (population name)
    :param input_features_path: str (path to file with selectivity features)
    :param input_features_namespace: str; the arena id is appended to form
        the namespace that is actually read
    :param cell_index_set: iterable of int (gids to read)
    :param time_range: optional two-element sequence (start, end); trajectory
        samples with ``start <= t < end`` are kept. A start of None means
        0.0 and an end of None means no upper bound. The sequence is not
        modified.
    :param n_trials: int; not used by this function
    :return: dict mapping gid to rate map, for cells with at least one field
    """
    t_start = None
    t_end = None
    if time_range is not None:
        # Resolve the bounds into locals instead of writing the default back
        # into the caller's sequence (which also failed for tuples).
        t_start, t_end = time_range[0], time_range[1]
        if t_start is None:
            t_start = 0.0

    spatial_resolution = float(env.stimulus_config['Spatial Resolution'])
    temporal_resolution = float(env.stimulus_config['Temporal Resolution'])

    this_input_features_namespace = '%s %s' % (input_features_namespace,
                                               env.arena_id)

    input_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]

    # Invert the name -> id mapping so that ids read from file map to names.
    selectivity_type_names = {
        i: n for n, i in viewitems(env.selectivity_types)
    }

    arena = env.stimulus_config['Arena'][env.arena_id]
    # NOTE(review): the arena mesh is computed but not used below; the rate
    # maps are evaluated on the trajectory positions only.
    arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
        arena=arena, spatial_resolution=spatial_resolution)

    trajectory = arena.trajectories[env.trajectory_id]
    t, x, y, d = stimulus.generate_linear_trajectory(
        trajectory, temporal_resolution=temporal_resolution)
    if time_range is not None:
        in_range = t >= t_start
        if t_end is not None:
            in_range = in_range & (t < t_end)
        t_range_inds = np.where(in_range)[0]
        t = t[t_range_inds]
        x = x[t_range_inds]
        y = y[t_range_inds]
        d = d[t_range_inds]

    input_rate_map_dict = {}
    # The lookup fails early if pop_name is not a known population.
    pop_index = int(env.Populations[pop_name])

    input_features_iter = scatter_read_cell_attribute_selection(
        input_features_path,
        pop_name,
        selection=list(cell_index_set),
        namespace=this_input_features_namespace,
        mask=set(input_features_attr_names),
        comm=env.comm,
        io_size=env.io_size)
    for gid, selectivity_attr_dict in input_features_iter:
        this_selectivity_type = selectivity_attr_dict['Selectivity Type'][0]
        # The lookup fails early for a selectivity type id that is unknown.
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        input_cell_config = stimulus.get_input_cell_config(
            selectivity_type=this_selectivity_type,
            selectivity_type_names=selectivity_type_names,
            selectivity_attr_dict=selectivity_attr_dict)
        # Cells without any field contribute no rate map.
        if input_cell_config.num_fields > 0:
            rate_map = input_cell_config.get_rate_map(x=x, y=y)
            input_rate_map_dict[gid] = rate_map

    return input_rate_map_dict
def main(config, coordinates, gid, field_width, peak_rate, input_features_path,
         input_features_namespaces, output_features_namespace,
         output_weights_path, output_features_path, initial_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, reference_weights_namespace,
         output_weights_namespace, reference_weights_are_delta,
         connections_path, optimize_method, destination, sources, arena_id,
         max_delta_weight, field_width_scale, max_iter, verbose, dry_run, plot):
    """Generate structured synaptic weights for a single target cell.

    Builds a target place-field rate map from ``coordinates``, gathers the
    rate maps and initial weights of the cell's inputs from ``sources`` and
    passes them to ``synapses.generate_structured_weights``. Unless
    ``dry_run`` is set, the resulting weights (and optionally the target
    features) are appended to the output files.

    :param config: str (path to .yaml file)
    :param coordinates: tuple of (x, y) float pairs, one per place field
    :param gid: int (target cell gid; required)
    :param field_width: float
    :param peak_rate: float
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str; the arena id is appended
        to each
    :param output_features_namespace: str
    :param output_weights_path: str (path to .h5 file)
    :param output_features_path: str (path to .h5 file)
    :param initial_weights_path: str (path to .h5 file)
    :param reference_weights_path: str (path to .h5 file)
    :param h5types_path: str (path to .h5 file)
    :param synapse_name: str
    :param initial_weights_namespace: str
    :param reference_weights_namespace: str
    :param output_weights_namespace: str
    :param reference_weights_are_delta: bool
    :param connections_path: str (path to .h5 file)
    :param optimize_method: passed through to
        synapses.generate_structured_weights
    :param destination: str (population name)
    :param sources: list of str (population name)
    :param arena_id: str
    :param max_delta_weight: float
    :param field_width_scale: float
    :param max_iter: int (not used by this function)
    :param verbose: bool
    :param dry_run: bool
    :param plot: bool
    :raises RuntimeError: if ``gid`` is None, or if output is requested
        without ``output_weights_path`` or without a usable source for the
        /H5Types group
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    # The body operates on exactly one target cell. Fail with a clear message
    # instead of a NameError on the previously unbound ``target_gid``.
    if gid is None:
        raise RuntimeError('Missing required argument: gid.')
    target_gid = gid
    target_gids = [target_gid]

    env = Env(config_file=config)

    if not dry_run:
        if output_weights_path is None:
            raise RuntimeError(
                'Missing required argument: output_weights_path.')
        if not os.path.isfile(output_weights_path):
            # A new output file needs the /H5Types group copied from an
            # existing file.
            if initial_weights_path is not None and os.path.isfile(
                    initial_weights_path):
                input_file_path = initial_weights_path
            elif h5types_path is not None and os.path.isfile(h5types_path):
                input_file_path = h5types_path
            else:
                raise RuntimeError(
                    'Missing required source for h5types: either an initial_weights_path or an '
                    'h5types_path must be provided.')
            with h5py.File(output_weights_path, 'a') as output_file:
                with h5py.File(input_file_path, 'r') as input_file:
                    input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]
    features_attr_names = ['Arena Rate Map']

    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm
    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s
    arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
        arena, spatial_resolution)
    dim_x = len(arena_x)
    dim_y = len(arena_y)

    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width] * num_fields, dtype=np.float32)
    this_scaled_field_width = np.array(
        [field_width * field_width_scale] * num_fields, dtype=np.float32)
    this_peak_rate = np.array([peak_rate] * num_fields, dtype=np.float32)
    this_x0 = np.array([x for x, y in coordinates], dtype=np.float32)
    this_y0 = np.array([y for x, y in coordinates], dtype=np.float32)
    # The unscaled map is stored as the cell's feature; the map with scaled
    # field width is the optimization target.
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0,
                                            this_field_width, this_peak_rate,
                                            arena_x, arena_y),
                               dtype=np.float32)
    target_map = np.asarray(get_rate_map(this_x0, this_y0,
                                         this_scaled_field_width,
                                         this_peak_rate, arena_x, arena_y),
                            dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][target_gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel()
    }

    initial_weights_by_syn_id_dict = dict()
    selection = [target_gid]
    if initial_weights_path is not None:
        initial_weights_iter = read_cell_attribute_selection(
            initial_weights_path, destination,
            namespace=initial_weights_namespace, selection=selection)
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            initial_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read initial synaptic weights for %i synapses'
            % (destination, target_gid, len(initial_weights_by_syn_id_dict)))

    reference_weights_by_syn_id_dict = None
    if reference_weights_path is not None:
        reference_weights_by_syn_id_dict = dict()
        reference_weights_iter = read_cell_attribute_selection(
            reference_weights_path, destination,
            namespace=reference_weights_namespace, selection=selection)
        syn_weight_attr_dict = dict(reference_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            reference_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read reference synaptic weights for %i synapses'
            % (destination, target_gid,
               len(reference_weights_by_syn_id_dict)))

    source_gid_set_dict = defaultdict(set)
    syn_ids_by_source_gid_dict = defaultdict(list)
    initial_weights_by_source_gid_dict = dict()
    if reference_weights_by_syn_id_dict is None:
        reference_weights_by_source_gid_dict = None
    else:
        reference_weights_by_source_gid_dict = dict()

    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[target_gid],
                                                   namespaces=['Synapses'])
    syn_id_attr_index = None
    for source, edge_iter in viewitems(graph[destination]):
        if source not in sources:
            continue
        this_edge_attr_info = edge_attr_info[destination][source]
        if 'Synapses' in this_edge_attr_info and \
                'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == target_gid
            source_gids, edge_attrs = edges
            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            count = 0
            for i in range(len(source_gids)):
                this_source_gid = int(source_gids[i])
                source_gid_set_dict[source].add(this_source_gid)
                this_syn_id = int(syn_ids[i])
                # Synapses without a stored initial weight fall back to the
                # default weight of the projection's mechanism.
                if this_syn_id not in initial_weights_by_syn_id_dict:
                    this_weight = \
                        env.connection_config[destination][source].mechanisms['default'][synapse_name]['weight']
                    initial_weights_by_syn_id_dict[this_syn_id] = this_weight
                syn_ids_by_source_gid_dict[this_source_gid].append(this_syn_id)
                # One weight per source cell: the first synapse encountered
                # for a source gid determines it.
                if this_source_gid not in initial_weights_by_source_gid_dict:
                    initial_weights_by_source_gid_dict[this_source_gid] = \
                        initial_weights_by_syn_id_dict[this_syn_id]
                    if reference_weights_by_source_gid_dict is not None:
                        reference_weights_by_source_gid_dict[this_source_gid] = \
                            reference_weights_by_syn_id_dict[this_syn_id]
                count += 1
            logger.info(
                'destination: %s; gid %i; set initial synaptic weights for %d inputs from source population '
                '%s' % (destination, destination_gid, count, source))

    syn_count_by_source_gid_dict = dict()
    for source_gid in syn_ids_by_source_gid_dict:
        syn_count_by_source_gid_dict[source_gid] = len(
            syn_ids_by_source_gid_dict[source_gid])

    input_rate_maps_by_source_gid_dict = dict()
    for source in sources:
        source_gids = list(source_gid_set_dict[source])
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path, source,
                namespace=input_features_namespace,
                mask=set(features_attr_names), selection=source_gids)
            count = 0
            # ``this_gid`` rather than ``gid`` so that the function argument
            # is not shadowed.
            for this_gid, attr_dict in input_features_iter:
                input_rate_maps_by_source_gid_dict[this_gid] = attr_dict[
                    'Arena Rate Map'].reshape((dim_x, dim_y))
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

    if is_interactive:
        context.update(locals())

    normalized_delta_weights_dict, arena_LS_map = \
        synapses.generate_structured_weights(target_map=target_map,
                                             initial_weight_dict=initial_weights_by_source_gid_dict,
                                             input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                             syn_count_dict=syn_count_by_source_gid_dict,
                                             max_delta_weight=max_delta_weight,
                                             arena_x=arena_x, arena_y=arena_y,
                                             reference_weight_dict=reference_weights_by_source_gid_dict,
                                             reference_weights_are_delta=reference_weights_are_delta,
                                             reference_weights_namespace=reference_weights_namespace,
                                             optimize_method=optimize_method,
                                             verbose=verbose, plot=plot)

    # Collect only the synapses that received a structured weight. The
    # initial-weights dict can hold more synapse ids than are written here
    # (e.g. synapses from populations not in ``sources``), so sizing the
    # arrays by its length would leave uninitialized entries in the output.
    output_syn_id_list = []
    output_weight_list = []
    for source_gid, this_weight in viewitems(normalized_delta_weights_dict):
        for syn_id in syn_ids_by_source_gid_dict[source_gid]:
            output_syn_id_list.append(syn_id)
            output_weight_list.append(this_weight)
    output_syn_ids = np.asarray(output_syn_id_list, dtype='uint32')
    output_weights = np.asarray(output_weight_list, dtype='float32')
    output_weights_dict = {
        target_gid: {
            'syn_id': output_syn_ids,
            synapse_name: output_weights
        }
    }

    logger.info('destination: %s; gid %i; generated %s for %i synapses' %
                (destination, target_gid, output_weights_namespace,
                 len(output_weights)))

    if not dry_run:
        this_output_weights_namespace = '%s %s' % (output_weights_namespace,
                                                   arena_id)
        logger.info('Destination: %s; appending %s ...' %
                    (destination, this_output_weights_namespace))
        append_cell_attributes(output_weights_path, destination,
                               output_weights_dict,
                               namespace=this_output_weights_namespace)
        logger.info('Destination: %s; appended %s' %
                    (destination, this_output_weights_namespace))
        output_weights_dict.clear()
        if output_features_path is not None:
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            cell_attr_dict = dst_input_features[destination]
            cell_attr_dict[target_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            logger.info('Destination: %s; appending %s ...' %
                        (destination, this_output_features_namespace))
            append_cell_attributes(output_features_path, destination,
                                   cell_attr_dict,
                                   namespace=this_output_features_namespace)

    if is_interactive:
        context.update(locals())
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_opt_iter, max_weight_decay_fraction, optimize_tol, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig, debug):
    """Generate structured (LTP/LTD) synaptic weights for target cells in parallel.

    Target gids are distributed over the MPI ranks. In each iteration every
    rank takes at most one gid, builds its target rate map, gathers the rate
    maps and weights of its inputs and calls
    ``synapses.generate_structured_weights``. Results are appended to the
    output files every ``write_size`` iterations and once more after the
    loop.

    :param config: str (path to .yaml file)
    :param coordinates: sequence of (x, y) pairs; when non-empty it overrides
        the field positions stored in the input features
    :param field_width: float or None
    :param gid: sequence of int; when empty, the target gids are taken from
        the connections file
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str; the arena id is appended
        to each
    :param initial_weights_path: str (path to .h5 file)
    :param output_features_namespace: str or None
    :param output_features_path: str (path to .h5 file) or None
    :param output_weights_path: str (path to .h5 file)
    :param reference_weights_path: str (path to .h5 file) or None
    :param h5types_path: str (path to .h5 file); source of /H5Types when no
        initial weights path is given
    :param synapse_name: str
    :param initial_weights_namespace: str
    :param output_weights_namespace: str
    :param reference_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param non_structured_sources: list of str
    :param non_structured_weights_namespace: str
    :param non_structured_weights_path: str (path to .h5 file)
    :param arena_id: str
    :param field_width_scale: float
    :param max_opt_iter: passed to synapses.generate_structured_weights
    :param max_weight_decay_fraction: passed to
        synapses.generate_structured_weights
    :param optimize_tol: passed to synapses.generate_structured_weights
    :param peak_rate: float or None
    :param reference_weights_are_delta: not used by this function
    :param arena_margin: float; negative values are treated as 0
    :param target_amplitude: passed to synapses.generate_structured_weights
    :param io_size: int; -1 means use all ranks
    :param chunk_size:
    :param value_chunk_size:
    :param cache_size: not used by this function
    :param write_size: int; outputs are written every ``write_size``
        iterations when positive
    :param verbose:
    :param dry_run: when set, nothing is written
    :param plot:
    :param show_fig:
    :param save_fig: directory for figures, or None
    :param debug: when set, the main loop stops after iteration 10
    :return:
    """
    utils.config_logging(verbose)
    script_name = __file__
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    env = Env(comm=comm, config_file=config, io_size=io_size)
    env.comm.barrier()

    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            # A new output file needs the /H5Types group copied from an
            # existing file.
            if initial_weights_path is not None:
                h5types_source_path = initial_weights_path
            elif h5types_path is not None:
                h5types_source_path = h5types_path
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            # Context managers close both files even if the copy fails.
            with h5py.File(h5types_source_path, 'r') as input_file:
                with h5py.File(output_weights_path, 'w') as output_file:
                    input_file.copy('/H5Types', output_file)
    env.comm.barrier()

    LTD_output_weights_namespace = f'LTD {output_weights_namespace} {arena_id}'
    LTP_output_weights_namespace = f'LTP {output_weights_namespace} {arena_id}'

    this_input_features_namespaces = [
        f'{input_features_namespace} {arena_id}'
        for input_features_namespace in input_features_namespaces
    ]

    # Invert the name -> id mapping so that ids read from file map to names.
    selectivity_type_index = {
        i: n for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    features_attrs = defaultdict(dict)
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm
    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    # Number of cells this rank has generated weights for.
    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)

    projections = [(source, destination) for source in sources]
    graph_info = read_graph_info(connections_path,
                                 namespaces=['Connections', 'Synapses'],
                                 read_node_index=True)
    for projection in projections:
        if projection not in graph_info:
            raise RuntimeError(
                f'Projection {projection[0]} -> {projection[1]} is not present in connections file.'
            )
        if target_gid_set is None:
            target_gid_set = set(graph_info[projection][1])

    all_sources = sources + non_structured_sources
    src_input_features_attr_dict = {source: {} for source in all_sources}
    for source in sorted(all_sources):
        this_src_input_features_attr_dict = {}
        for this_input_features_namespace in this_input_features_namespaces:
            logger.info(
                f'Rank {rank}: Reading {this_input_features_namespace} feature data for cells in population {source}'
            )
            input_features_dict = scatter_read_cell_attributes(
                input_features_path, source,
                namespaces=[this_input_features_namespace],
                mask=set(source_features_attr_names),
                comm=env.comm, io_size=env.io_size)
            for source_gid, attr_dict in input_features_dict[
                    this_input_features_namespace]:
                this_src_input_features_attr_dict[source_gid] = attr_dict
        src_input_features_attr_dict[
            source] = this_src_input_features_attr_dict
        source_gid_count = env.comm.reduce(
            len(this_src_input_features_attr_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info(
                f'Rank {rank}: Read feature data for {source_gid_count} cells in population {source}'
            )

    # Round-robin assignment of target gids to ranks.
    dst_gids = []
    if target_gid_set is not None:
        for i, target_gid in enumerate(target_gid_set):
            if i % nranks == rank:
                dst_gids.append(target_gid)

    dst_input_features_attr_dict = {}
    for this_input_features_namespace in this_input_features_namespaces:
        feature_count = 0
        # Kept separate from ``gid_count`` so that the number of features
        # read does not leak into the final count of processed cells.
        read_gid_count = 0
        logger.info(
            f'Rank {rank}: reading {this_input_features_namespace} feature data for {len(dst_gids)} cells in population {destination}'
        )
        input_features_iter = scatter_read_cell_attribute_selection(
            input_features_path, destination,
            namespace=this_input_features_namespace,
            mask=set(target_features_attr_names),
            selection=dst_gids, io_size=env.io_size, comm=env.comm)
        for this_gid, attr_dict in input_features_iter:
            read_gid_count += 1
            # Keep cells that either get fields from ``coordinates`` or
            # already have at least one field.
            if (len(coordinates) > 0) or (attr_dict['Num Fields'][0] > 0):
                dst_input_features_attr_dict[this_gid] = attr_dict
                feature_count += 1

        logger.info(
            f'Rank {rank}: read {this_input_features_namespace} feature data for '
            f'{read_gid_count} / {feature_count} cells in population {destination}')
        feature_count = env.comm.reduce(feature_count, op=MPI.SUM, root=0)
        env.comm.barrier()
        if rank == 0:
            logger.info(
                f'Read {this_input_features_namespace} feature data for {feature_count} cells in population {destination}'
            )

    # Redistribute the cells that have features evenly over the ranks.
    feature_dst_gids = list(dst_input_features_attr_dict.keys())
    all_feature_gids_per_rank = comm.allgather(feature_dst_gids)
    all_feature_gids = sorted(
        [item for sublist in all_feature_gids_per_rank for item in sublist])
    request_dst_gids = []
    for i, feature_gid in enumerate(all_feature_gids):
        if i % nranks == rank:
            request_dst_gids.append(feature_gid)

    dst_input_features_attr_dict = exchange_input_features(
        env.comm, request_dst_gids, dst_input_features_attr_dict)
    dst_gids = list(dst_input_features_attr_dict.keys())

    if rank == 0:
        logger.info(
            f"Rank {rank} feature dict is {dst_input_features_attr_dict}")

    dst_count = env.comm.reduce(len(dst_gids), op=MPI.SUM, root=0)

    logger.info(f"Rank {rank} has {len(dst_gids)} feature gids")
    if rank == 0:
        logger.info(f'Total {dst_count} feature gids')

    # Every rank runs the same number of iterations so that the collective
    # calls inside the loop match up; ranks that run out of gids continue
    # with an empty selection.
    max_dst_count = env.comm.allreduce(len(dst_gids), op=MPI.MAX)
    env.comm.barrier()

    max_iter_count = max_dst_count
    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    non_structured_output_weights_dict = {}
    arena_margin = max(arena_margin, 0.)
    for iter_count in range(max_iter_count):

        gc.collect()

        local_time = time.time()
        selection = []
        if len(dst_gids) > 0:
            dst_gid = dst_gids.pop()
            selection.append(dst_gid)
            logger.info(f'Rank {rank} received gid {dst_gid}')

        env.comm.barrier()

        arena_margin_size = 0.

        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}

        for destination_gid in selection:
            arena_margin_size = init_selectivity_config(
                destination_gid, spatial_resolution, arena, arena_margin,
                arena_margin_size, coordinates, field_width,
                field_width_scale, peak_rate, target_selectivity_type,
                selectivity_type_index, dst_input_features_attr_dict,
                target_selectivity_features_dict,
                target_selectivity_config_dict, target_field_width_dict,
                logger=logger)

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)

        # Only cells for which a target map was built are processed further.
        selection = list(target_selectivity_features_dict.keys())

        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())
        initial_weights_by_syn_id_dict = \
            read_weights(initial_weights_path, initial_weights_namespace, synapse_name,
                         destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                         logger=logger if rank == 0 else None)

        non_structured_weights_by_source_gid_dict = defaultdict(lambda: dict())
        non_structured_weights_by_syn_id_dict = None
        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = \
                read_weights(non_structured_weights_path, non_structured_weights_namespace, synapse_name,
                             destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                             logger=logger if rank == 0 else None)

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = \
                read_weights(reference_weights_path, reference_weights_namespace, synapse_name,
                             destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                             logger=logger if rank == 0 else None)

        source_gid_set_dict = defaultdict(set)
        syn_count_by_source_gid_dict = defaultdict(lambda: defaultdict(int))
        syn_ids_by_source_gid_dict = defaultdict(lambda: defaultdict(list))
        structured_syn_id_count = defaultdict(int)
        non_structured_syn_id_count = defaultdict(int)

        projections = [(source, destination) for source in all_sources]
        edge_iter_dict, edge_attr_info = scatter_read_graph_selection(
            connections_path, selection=selection, namespaces=['Synapses'],
            projections=projections, comm=env.comm, io_size=env.io_size)

        syn_counts_by_source = init_syn_weight_dicts(
            destination, non_structured_sources, edge_iter_dict,
            edge_attr_info, initial_weights_by_syn_id_dict,
            initial_weights_by_source_gid_dict,
            non_structured_weights_by_syn_id_dict,
            non_structured_weights_by_source_gid_dict,
            reference_weights_by_syn_id_dict,
            reference_weights_by_source_gid_dict, source_gid_set_dict,
            syn_count_by_source_gid_dict, syn_ids_by_source_gid_dict,
            structured_syn_id_count, non_structured_syn_id_count)

        for source in syn_counts_by_source:
            for this_gid in syn_counts_by_source[source]:
                count = syn_counts_by_source[source][this_gid]
                logger.info(
                    f'Rank {rank}: destination: {destination}; gid {this_gid}; '
                    f'{count} edges from source population {source}')

        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None
        for source in all_sources:
            source_gids = list(source_gid_set_dict[source])
            if rank == 0:
                logger.info(
                    f'Rank {rank}: getting feature data for {len(source_gids)} cells in population {source}'
                )
            this_src_input_features = exchange_input_features(
                env.comm, source_gids, src_input_features_attr_dict[source])

            count = 0
            for this_gid in source_gids:
                attr_dict = this_src_input_features[this_gid]
                this_selectivity_type = attr_dict['Selectivity Type'][0]
                # The lookup fails early for an unknown selectivity type id.
                this_selectivity_type_name = selectivity_type_index[
                    this_selectivity_type]
                input_cell_config = stimulus.get_input_cell_config(
                    this_selectivity_type, selectivity_type_index,
                    selectivity_attr_dict=attr_dict)
                this_arena_rate_map = np.asarray(
                    input_cell_config.get_rate_map(arena_x, arena_y),
                    dtype=np.float32)
                if source in non_structured_sources:
                    non_structured_input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map
                else:
                    input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map
                count += 1

        for destination_gid in selection:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = f'{save_fig}/Structured Weights {destination} {destination_gid}.png'

            reference_weight_dict = None
            if reference_weights_path is not None:
                reference_weight_dict = reference_weights_by_source_gid_dict[
                    destination_gid]

            # NOTE(review): the reference weights are read above but are not
            # passed on; the corresponding keyword arguments were commented
            # out in the original call.
            LTP_delta_weights_dict, LTD_delta_weights_dict, arena_structured_map = \
                synapses.generate_structured_weights(destination_gid,
                                                     target_map=target_selectivity_features_dict[destination_gid]['Arena Rate Map'],
                                                     initial_weight_dict=initial_weights_by_source_gid_dict[destination_gid],
                                                     input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                     non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                     non_structured_weights_dict=non_structured_weights_by_source_gid_dict[destination_gid],
                                                     syn_count_dict=syn_count_by_source_gid_dict[destination_gid],
                                                     max_opt_iter=max_opt_iter,
                                                     max_weight_decay_fraction=max_weight_decay_fraction,
                                                     target_amplitude=target_amplitude,
                                                     arena_x=arena_x, arena_y=arena_y,
                                                     optimize_tol=optimize_tol,
                                                     verbose=verbose if rank == 0 else False,
                                                     plot=plot, show_fig=show_fig,
                                                     save_fig=save_fig_path,
                                                     fig_kwargs={'gid': destination_gid,
                                                                 'field_width': target_field_width_dict[destination_gid]})
            input_rate_maps_by_source_gid_dict.clear()

            target_map_flat = target_selectivity_features_dict[
                destination_gid]['Arena Rate Map'].flat
            arena_map_residual_mae = np.mean(
                np.abs(arena_structured_map - target_map_flat))
            output_features_dict[destination_gid] = {
                fld: target_selectivity_features_dict[destination_gid][fld]
                for fld in ['Selectivity Type', 'Num Fields', 'Field Width',
                            'Peak Rate', 'X Offset', 'Y Offset']
            }
            output_features_dict[destination_gid][
                'Rate Map Residual Mean Error'] = np.asarray(
                    [arena_map_residual_mae], dtype=np.float32)

            this_structured_syn_id_count = structured_syn_id_count[
                destination_gid]
            output_syn_ids = np.empty(this_structured_syn_id_count,
                                      dtype='uint32')
            LTD_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            i = 0
            for source_gid in LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[destination_gid][
                        source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = LTP_delta_weights_dict[source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            # np.empty leaves unassigned entries uninitialized; keep only the
            # entries that were filled (a no-op when all of them were).
            output_syn_ids = output_syn_ids[:i]
            LTP_output_weights = LTP_output_weights[:i]
            LTD_output_weights = LTD_output_weights[:i]
            LTP_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            logger.info(
                f'Rank {rank}; destination: {destination}; gid {destination_gid}; '
                f'generated structured weights for {len(output_syn_ids)} inputs in {time.time() - local_time:.2f} s; '
                f'residual error is {arena_map_residual_mae:.2f}')
            gid_count += 1
            gc.collect()

        env.comm.barrier()
        if (write_size > 0) and (iter_count % write_size == 0):
            if not dry_run:
                append_cell_attributes(output_weights_path, destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm, io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path, destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm, io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = env.comm.reduce(len(LTP_output_weights_dict),
                                        op=MPI.SUM, root=0)
                env.comm.barrier()

                if rank == 0:
                    logger.info(
                        f'Destination: {destination}; appended weights for {count} cells'
                    )
                if output_features_path is not None:
                    if output_features_namespace is None:
                        output_features_namespace = f'{target_selectivity_type_name.title()} Selectivity'
                    this_output_features_namespace = f'{output_features_namespace} {arena_id}'
                    append_cell_attributes(
                        output_features_path, destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = env.comm.reduce(len(output_features_dict),
                                            op=MPI.SUM, root=0)
                    env.comm.barrier()

                    if rank == 0:
                        logger.info(
                            f'Destination: {destination}; appended selectivity features for {count} cells'
                        )

            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

        if (iter_count >= 10) and debug:
            break

    env.comm.barrier()
    # Final flush of whatever has accumulated since the last periodic write.
    if not dry_run:
        append_cell_attributes(output_weights_path, destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm, io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path, destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm, io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        env.comm.barrier()

        if rank == 0:
            logger.info(
                f'Destination: {destination}; appended weights for {count} cells'
            )
        if output_features_path is not None:
            # NOTE(review): this default differs from the one used by the
            # periodic write inside the loop; it only takes effect when no
            # periodic write assigned the namespace before. Confirm which
            # default is intended.
            if output_features_namespace is None:
                output_features_namespace = 'Selectivity Features'
            this_output_features_namespace = f'{output_features_namespace} {arena_id}'
            append_cell_attributes(output_features_path, destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = env.comm.reduce(len(output_features_dict), op=MPI.SUM,
                                    root=0)
            env.comm.barrier()

            if rank == 0:
                logger.info(
                    f'Destination: {destination}; appended selectivity features for {count} cells'
                )

    env.comm.barrier()
    global_count = env.comm.gather(gid_count, root=0)
    env.comm.barrier()

    if rank == 0:
        total_count = np.sum(global_count)
        total_time = time.time() - start_time
        logger.info(
            f'Destination: {destination}; '
            f'{env.comm.size} ranks assigned structured weights to {total_count} cells in {total_time:.2f} s'
        )
def init_selectivity_config(destination_gid, spatial_resolution, arena,
                            arena_margin, arena_margin_size, coordinates,
                            field_width, field_width_scale, peak_rate,
                            target_selectivity_type, selectivity_type_index,
                            input_features_attr_dict,
                            target_selectivity_features_dict,
                            target_selectivity_config_dict,
                            target_field_width_dict, logger=None):
    """Prepare the target selectivity features and rate map of one cell.

    The feature dict stored for ``destination_gid`` in
    ``input_features_attr_dict`` is updated in place: the selectivity type is
    set to ``target_selectivity_type`` and field positions, widths and peak
    rates are overridden by ``coordinates``, ``field_width`` and
    ``peak_rate`` when those are given. If the cell ends up with at least one
    field, its target rate map is computed on the arena mesh and the cell is
    registered in the three ``target_*`` output dicts.

    :param destination_gid: gid; must be a key of input_features_attr_dict
    :param spatial_resolution: passed to stimulus.get_2D_arena_spatial_mesh
    :param arena: passed to stimulus.get_2D_arena_spatial_mesh
    :param arena_margin: factor applied to the largest field width
    :param arena_margin_size: current margin size; the returned value is
        never smaller than this
    :param coordinates: sequence of (x, y) pairs; may be empty
    :param field_width: float or None
    :param field_width_scale: passed as ``scale`` to get_rate_map
    :param peak_rate: float or None
    :param target_selectivity_type: selectivity type id
    :param selectivity_type_index: passed to stimulus.get_input_cell_config
    :param input_features_attr_dict: dict gid -> feature dict (modified)
    :param target_selectivity_features_dict: output dict gid -> features
    :param target_selectivity_config_dict: output dict gid -> cell config
    :param target_field_width_dict: output dict gid -> field width
    :param logger: not used
    :return: the margin size, enlarged if this cell requires a wider margin
    """
    assert (destination_gid in input_features_attr_dict)
    features = input_features_attr_dict[destination_gid]
    features['Selectivity Type'] = np.asarray([target_selectivity_type],
                                              dtype=np.uint8)

    # Explicit coordinates take precedence over stored field positions.
    if len(coordinates) > 0:
        num_fields = len(coordinates)
        features['X Offset'] = np.asarray([point[0] for point in coordinates],
                                          dtype=np.float32)
        features['Y Offset'] = np.asarray([point[1] for point in coordinates],
                                          dtype=np.float32)
        features['Num Fields'] = np.asarray([num_fields], dtype=np.uint8)
    else:
        num_fields = (features['Num Fields'][0]
                      if 'Num Fields' in features else 0)

    # One width per field: either the override, or the stored widths cut to
    # the number of fields.
    if field_width is not None:
        features['Field Width'] = np.asarray([field_width] * num_fields,
                                             dtype=np.float32)
    elif 'Field Width' in features:
        stored_widths = features['Field Width']
        features['Field Width'] = stored_widths[:num_fields]

    if peak_rate is not None:
        features['Peak Rate'] = np.asarray([peak_rate] * num_fields,
                                           dtype=np.float32)

    # A cell without fields has no target map and is not registered.
    if not (num_fields > 0):
        return arena_margin_size

    cell_config = stimulus.get_input_cell_config(
        target_selectivity_type, selectivity_type_index,
        selectivity_attr_dict=features)
    arena_margin_size = max(arena_margin_size,
                            np.max(cell_config.field_width) * arena_margin)
    mesh_x, mesh_y = stimulus.get_2D_arena_spatial_mesh(
        arena, spatial_resolution, margin=arena_margin_size)
    rate_map = cell_config.get_rate_map(mesh_x, mesh_y,
                                        scale=field_width_scale)
    features['Arena Rate Map'] = np.asarray(rate_map,
                                            dtype=np.float32).flatten()

    target_selectivity_features_dict[destination_gid] = features
    target_field_width_dict[destination_gid] = cell_config.field_width
    target_selectivity_config_dict[destination_gid] = cell_config

    return arena_margin_size
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_delta_weight, max_opt_iter, max_weight_decay_fraction,
         optimize_method, optimize_tol, optimize_grad, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig):
    """
    Generate structured (LTP/LTD) synaptic weights for the cells of one
    destination population and append them as cell attributes.

    For every destination gid yielded by the projection generators this
    function reads the target selectivity features, the initial, reference and
    non-structured weights and the rate maps of the presynaptic cells, calls
    synapses.generate_structured_weights, and buffers the results.  Buffered
    results are appended to the output files every ``write_size`` iterations
    and once more after the loop.

    :param config: str (path to .yaml file)
    :param coordinates: pair (x, y); when coordinates[0] is not None the target cell gets one field at that offset
    :param field_width: if not None, overrides the stored field widths
    :param gid: sequence of destination gids to process; empty means no restriction
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str; each is suffixed with arena_id
    :param initial_weights_path: str (path to .h5 file) or None
    :param output_features_namespace: str or None
    :param output_features_path: str (path to .h5 file) or None
    :param output_weights_path: str (path to .h5 file)
    :param reference_weights_path: str (path to .h5 file) or None
    :param h5types_path: str (path to .h5 file); source of /H5Types when initial_weights_path is None
    :param synapse_name: str; attribute name holding the weights
    :param initial_weights_namespace: str
    :param output_weights_namespace: str; written as 'LTP <ns> <arena_id>' and 'LTD <ns> <arena_id>'
    :param reference_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param non_structured_sources: list of str
    :param non_structured_weights_namespace: str
    :param non_structured_weights_path: str (path to .h5 file) or None
    :param arena_id: str
    :param field_width_scale: passed as ``scale`` to get_rate_map for the target map
    :param max_delta_weight: forwarded to synapses.generate_structured_weights
    :param max_opt_iter: forwarded to synapses.generate_structured_weights
    :param max_weight_decay_fraction: forwarded to synapses.generate_structured_weights
    :param optimize_method: forwarded to synapses.generate_structured_weights
    :param optimize_tol: forwarded to synapses.generate_structured_weights
    :param optimize_grad: forwarded to synapses.generate_structured_weights
    :param peak_rate: if not None, overrides the stored peak rates
    :param reference_weights_are_delta: forwarded to synapses.generate_structured_weights
    :param arena_margin: factor applied to the largest field width; negative values are clamped to 0
    :param target_amplitude: forwarded to synapses.generate_structured_weights
    :param io_size: int; -1 means use the communicator size
    :param chunk_size: forwarded to append_cell_attributes
    :param value_chunk_size: forwarded to append_cell_attributes
    :param cache_size: not used by this function
    :param write_size: int; buffered results are flushed every write_size iterations (0 disables periodic flushing)
    :param verbose: forwarded to utils.config_logging and generate_structured_weights
    :param dry_run: bool; when true nothing is written
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (directory for figures) or None
    :return:
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%s: %i ranks have been allocated' %
                    (__file__, comm.size))

    env = Env(comm=comm, config_file=config, io_size=io_size)

    # Plotting was requested without a destination: default to showing it.
    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    # Rank 0 seeds a missing output file with the /H5Types group.
    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None:
                h5types_source_path = initial_weights_path
            elif h5types_path is not None:
                h5types_source_path = h5types_path
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            # Context managers close both files even if the copy fails.
            with h5py.File(h5types_source_path, 'r') as input_file, \
                    h5py.File(output_weights_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    env.comm.barrier()

    LTD_output_weights_namespace = 'LTD %s %s' % (output_weights_namespace,
                                                  arena_id)
    LTP_output_weights_namespace = 'LTP %s %s' % (output_weights_namespace,
                                                  arena_id)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]

    # Inverse mapping: selectivity type id -> name.
    selectivity_type_index = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    features_attrs = defaultdict(dict)
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    local_random = np.random.RandomState()

    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)

    all_sources = sources + non_structured_sources
    connection_gen_list = [
        NeuroH5ProjectionGen(connections_path, source, destination,
                             namespaces=['Synapses'], comm=comm)
        for source in all_sources
    ]

    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    for iter_count, attr_gen_package in enumerate(
            zip_longest(*connection_gen_list)):

        local_time = time.time()
        this_gid = attr_gen_package[0][0]
        # All projection generators must be positioned on the same gid.
        if not all([
                attr_gen_items[0] == this_gid
                for attr_gen_items in attr_gen_package
        ]):
            raise Exception(
                'Rank: %i; destination: %s; this_gid not matched across multiple attribute '
                'generators: %s' %
                (rank, destination,
                 [attr_gen_items[0] for attr_gen_items in attr_gen_package]))

        if (target_gid_set is not None) and (this_gid not in target_gid_set):
            continue

        if this_gid is None:
            selection = []
            logger.info('Rank: %i received None' % rank)
        else:
            selection = [this_gid]
            local_random.seed(int(this_gid + seed_offset))

        has_structured_weights = False

        # Target selectivity features of the destination cell.
        dst_input_features_attr_dict = {}
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                destination,
                namespace=input_features_namespace,
                mask=set(target_features_attr_names),
                comm=env.comm,
                selection=selection)
            count = 0
            for gid, attr_dict in input_features_iter:
                dst_input_features_attr_dict[gid] = attr_dict
                count += 1
            if rank == 0:
                logger.info(
                    'Read %s feature data for %i cells in population %s' %
                    (input_features_namespace, count, destination))

        arena_margin_size = 0.
        arena_margin = max(arena_margin, 0.)

        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}

        for gid in selection:
            target_selectivity_features_dict[
                gid] = dst_input_features_attr_dict.get(gid, {})
            target_selectivity_features_dict[gid][
                'Selectivity Type'] = np.asarray([target_selectivity_type],
                                                 dtype=np.uint8)

            # NOTE(review): raises KeyError when no features were read for
            # this gid, even though .get(gid, {}) above tolerates that case.
            num_fields = target_selectivity_features_dict[gid]['Num Fields'][0]

            if coordinates[0] is not None:
                num_fields = 1
                target_selectivity_features_dict[gid]['X Offset'] = np.asarray(
                    [coordinates[0]], dtype=np.float32)
                target_selectivity_features_dict[gid]['Y Offset'] = np.asarray(
                    [coordinates[1]], dtype=np.float32)
                target_selectivity_features_dict[gid][
                    'Num Fields'] = np.asarray([num_fields], dtype=np.uint8)

            if field_width is not None:
                target_selectivity_features_dict[gid][
                    'Field Width'] = np.asarray([field_width] * num_fields,
                                                dtype=np.float32)
            else:
                this_field_width = target_selectivity_features_dict[gid][
                    'Field Width']
                target_selectivity_features_dict[gid][
                    'Field Width'] = this_field_width[:num_fields]

            if peak_rate is not None:
                target_selectivity_features_dict[gid][
                    'Peak Rate'] = np.asarray([peak_rate] * num_fields,
                                              dtype=np.float32)

            input_cell_config = stimulus.get_input_cell_config(
                target_selectivity_type,
                selectivity_type_index,
                selectivity_attr_dict=target_selectivity_features_dict[gid])
            if input_cell_config.num_fields > 0:
                arena_margin_size = max(
                    arena_margin_size,
                    np.max(input_cell_config.field_width) * arena_margin)
                target_field_width_dict[gid] = input_cell_config.field_width
                target_selectivity_config_dict[gid] = input_cell_config
                has_structured_weights = True

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)
        for gid, input_cell_config in viewitems(
                target_selectivity_config_dict):
            target_map = np.asarray(input_cell_config.get_rate_map(
                arena_x, arena_y, scale=field_width_scale),
                                    dtype=np.float32)
            target_selectivity_features_dict[gid][
                'Arena Rate Map'] = target_map

        # The reads below are still issued with an empty selection when this
        # cell has no fields.
        if not has_structured_weights:
            selection = []

        initial_weights_by_syn_id_dict = defaultdict(lambda: dict())
        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())

        if initial_weights_path is not None:
            initial_weights_iter = \
                read_cell_attribute_selection(initial_weights_path, destination,
                                              namespace=initial_weights_namespace,
                                              selection=selection)

            initial_weights_gid_count = 0
            initial_weights_syn_count = 0
            for this_gid, syn_weight_attr_dict in initial_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    initial_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)
                initial_weights_gid_count += 1
                initial_weights_syn_count += len(syn_ids)

            logger.info(
                'destination: %s; read initial synaptic weights for %i gids and %i syns'
                % (destination, initial_weights_gid_count,
                   initial_weights_syn_count))

        # Both non-structured dicts are defined on every path; None means
        # "no non-structured sources" and is checked before any subscripting.
        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = defaultdict(lambda: dict())
            non_structured_weights_by_source_gid_dict = defaultdict(
                lambda: dict())
        else:
            non_structured_weights_by_syn_id_dict = None
            non_structured_weights_by_source_gid_dict = None

        if (non_structured_weights_path is not None) and \
                (non_structured_weights_by_syn_id_dict is not None):
            # Read from the non-structured weights file (the original code
            # read from initial_weights_path here).
            non_structured_weights_iter = \
                read_cell_attribute_selection(non_structured_weights_path, destination,
                                              namespace=non_structured_weights_namespace,
                                              selection=selection)

            non_structured_weights_gid_count = 0
            non_structured_weights_syn_count = 0
            for this_gid, syn_weight_attr_dict in non_structured_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    non_structured_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)
                non_structured_weights_gid_count += 1
                non_structured_weights_syn_count += len(syn_ids)

            logger.info(
                'destination: %s; read non-structured synaptic weights for %i gids and %i syns'
                % (
                    destination,
                    non_structured_weights_gid_count,
                    non_structured_weights_syn_count,
                ))

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = defaultdict(lambda: dict())
            reference_weights_iter = \
                read_cell_attribute_selection(reference_weights_path, destination,
                                              namespace=reference_weights_namespace,
                                              selection=selection)
            reference_weights_gid_count = 0

            for this_gid, syn_weight_attr_dict in reference_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    reference_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)
                # Count the gid so the log below reports the real total.
                reference_weights_gid_count += 1

            logger.info(
                'destination: %s; read reference synaptic weights for %i gids'
                % (destination, reference_weights_gid_count))

        syn_count_by_source_gid_dict = defaultdict(int)
        source_gid_set_dict = defaultdict(set)
        syn_ids_by_source_gid_dict = defaultdict(list)
        structured_syn_id_count = 0

        if has_structured_weights:
            for source, (destination_gid,
                         (source_gid_array, conn_attr_dict)) in zip_longest(
                             all_sources, attr_gen_package):
                syn_ids = conn_attr_dict['Synapses']['syn_id']
                count = 0
                this_initial_weights_by_syn_id_dict = None
                this_initial_weights_by_source_gid_dict = None
                this_reference_weights_by_syn_id_dict = None
                this_reference_weights_by_source_gid_dict = None
                this_non_structured_weights_by_syn_id_dict = None
                this_non_structured_weights_by_source_gid_dict = None
                if destination_gid is not None:
                    this_initial_weights_by_syn_id_dict = initial_weights_by_syn_id_dict[
                        destination_gid]
                    this_initial_weights_by_source_gid_dict = initial_weights_by_source_gid_dict[
                        destination_gid]
                    if reference_weights_by_syn_id_dict is not None:
                        this_reference_weights_by_syn_id_dict = reference_weights_by_syn_id_dict[
                            destination_gid]
                        this_reference_weights_by_source_gid_dict = reference_weights_by_source_gid_dict[
                            destination_gid]
                    if non_structured_weights_by_syn_id_dict is not None:
                        this_non_structured_weights_by_syn_id_dict = non_structured_weights_by_syn_id_dict[
                            destination_gid]
                        this_non_structured_weights_by_source_gid_dict = non_structured_weights_by_source_gid_dict[
                            destination_gid]

                # Collapse per-synapse weights to one weight per source gid
                # (the first synapse seen for a source gid wins).
                for i in range(len(source_gid_array)):
                    this_source_gid = source_gid_array[i]
                    this_syn_id = syn_ids[i]
                    if this_syn_id in this_initial_weights_by_syn_id_dict:
                        this_syn_wgt = this_initial_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_initial_weights_by_source_gid_dict:
                            this_initial_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                        if this_reference_weights_by_syn_id_dict is not None:
                            this_reference_weights_by_source_gid_dict[this_source_gid] = \
                                this_reference_weights_by_syn_id_dict[this_syn_id]
                    elif (this_non_structured_weights_by_syn_id_dict is not None) and \
                            (this_syn_id in this_non_structured_weights_by_syn_id_dict):
                        this_syn_wgt = this_non_structured_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_non_structured_weights_by_source_gid_dict:
                            this_non_structured_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                    source_gid_set_dict[source].add(this_source_gid)
                    syn_ids_by_source_gid_dict[this_source_gid].append(
                        this_syn_id)
                    syn_count_by_source_gid_dict[this_source_gid] += 1
                    count += 1
                if source not in non_structured_sources:
                    structured_syn_id_count += len(syn_ids)
                logger.info(
                    'Rank %i; destination: %s; gid %i; %d edges from source population %s'
                    % (rank, destination, this_gid, count, source))

        # Arena rate maps of the presynaptic cells.
        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None
        for source in all_sources:
            if has_structured_weights:
                source_gids = list(source_gid_set_dict[source])
            else:
                source_gids = []
            for input_features_namespace in this_input_features_namespaces:
                # Logged per namespace so the message names the namespace
                # that is actually about to be read.
                if rank == 0:
                    logger.info(
                        'Reading %s feature data for %i cells in population %s...'
                        % (input_features_namespace, len(source_gids), source))
                input_features_iter = read_cell_attribute_selection(
                    input_features_path,
                    source,
                    namespace=input_features_namespace,
                    mask=set(source_features_attr_names),
                    comm=env.comm,
                    selection=source_gids)
                count = 0
                for gid, attr_dict in input_features_iter:
                    this_selectivity_type = attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_index[
                        this_selectivity_type]
                    input_cell_config = stimulus.get_input_cell_config(
                        this_selectivity_type,
                        selectivity_type_index,
                        selectivity_attr_dict=attr_dict)
                    this_arena_rate_map = np.asarray(
                        input_cell_config.get_rate_map(arena_x, arena_y),
                        dtype=np.float32)
                    if source in non_structured_sources:
                        non_structured_input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    else:
                        input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    count += 1
                if rank == 0:
                    logger.info(
                        'Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

        if has_structured_weights:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = '%s/Structured Weights %s %d.png' % (
                    save_fig, destination, this_gid)

            normalized_LTP_delta_weights_dict, LTD_delta_weights_dict, arena_LS_map = \
                synapses.generate_structured_weights(target_map=target_selectivity_features_dict[this_gid]['Arena Rate Map'],
                                                     initial_weight_dict=this_initial_weights_by_source_gid_dict,
                                                     reference_weight_dict=this_reference_weights_by_source_gid_dict,
                                                     reference_weights_are_delta=reference_weights_are_delta,
                                                     reference_weights_namespace=reference_weights_namespace,
                                                     input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                     non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                     non_structured_weights_dict=this_non_structured_weights_by_source_gid_dict,
                                                     syn_count_dict=syn_count_by_source_gid_dict,
                                                     max_delta_weight=max_delta_weight,
                                                     max_opt_iter=max_opt_iter,
                                                     max_weight_decay_fraction=max_weight_decay_fraction,
                                                     target_amplitude=target_amplitude,
                                                     arena_x=arena_x, arena_y=arena_y,
                                                     optimize_method=optimize_method,
                                                     optimize_tol=optimize_tol,
                                                     optimize_grad=optimize_grad,
                                                     verbose=verbose, plot=plot, show_fig=show_fig,
                                                     save_fig=save_fig_path,
                                                     fig_kwargs={'gid': this_gid,
                                                                 'field_width': target_field_width_dict[this_gid]})
            gc.collect()

            this_selectivity_dict = target_selectivity_features_dict[this_gid]
            output_features_dict[this_gid] = {
                fld: this_selectivity_dict[fld]
                for fld in [
                    'Selectivity Type', 'Num Fields', 'Field Width',
                    'Peak Rate', 'X Offset', 'Y Offset'
                ]
            }
            output_features_dict[this_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)

            # Expand the per-source-gid weights back to one entry per synapse.
            output_syn_ids = np.empty(structured_syn_id_count, dtype='uint32')
            LTD_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            i = 0
            for source_gid in normalized_LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = normalized_LTP_delta_weights_dict[
                        source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            LTP_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            logger.info(
                'Rank %i; destination: %s; gid %i; generated structured weights for %i inputs in %.2f '
                's' % (rank, destination, this_gid, len(output_syn_ids),
                       time.time() - local_time))
            gid_count += 1

        # Periodic flush; write_size == 0 disables it instead of raising
        # ZeroDivisionError (the flush after the loop still runs).
        if (write_size > 0) and (iter_count % write_size == 0):
            if not dry_run:
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = comm.reduce(len(LTP_output_weights_dict),
                                    op=MPI.SUM,
                                    root=0)
                if rank == 0:
                    logger.info(
                        'Destination: %s; appended weights for %i cells' %
                        (destination, count))
                if output_features_path is not None:
                    if output_features_namespace is None:
                        output_features_namespace = '%s Selectivity' % target_selectivity_type_name.title(
                        )
                    this_output_features_namespace = '%s %s' % (
                        output_features_namespace, arena_id)
                    logger.info(str(output_features_dict))
                    append_cell_attributes(
                        output_features_path,
                        destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = comm.reduce(len(output_features_dict),
                                        op=MPI.SUM,
                                        root=0)
                    if rank == 0:
                        logger.info(
                            'Destination: %s; appended selectivity features for %i cells'
                            % (destination, count))

            # Buffers are dropped even in a dry run to bound memory use.
            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

    # Final flush of whatever is still buffered.
    if not dry_run:
        append_cell_attributes(output_weights_path,
                               destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path,
                               destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info('Destination: %s; appended weights for %i cells' %
                        (destination, count))
        if output_features_path is not None:
            # NOTE(review): this default differs from the '<Type> Selectivity'
            # default used by the periodic flush above; confirm which is
            # intended.
            if output_features_namespace is None:
                output_features_namespace = 'Selectivity Features'
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            append_cell_attributes(output_features_path,
                                   destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = comm.reduce(len(output_features_dict),
                                op=MPI.SUM,
                                root=0)
            if rank == 0:
                logger.info(
                    'Destination: %s; appended selectivity features for %i cells'
                    % (destination, count))

    env.comm.barrier()
    global_count = comm.gather(gid_count, root=0)
    if rank == 0:
        logger.info(
            'destination: %s; %i ranks assigned structured weights to %i cells in %.2f s'
            % (destination, comm.size, np.sum(global_count),
               time.time() - start_time))
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, subset_seed, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """
    Plot spatial distributions of input selectivity features.

    For each requested population and each matching selectivity namespace,
    every rank accumulates 2D histograms (binned by U/V arc distance) of the
    cell and component attributes of the cells it receives, plus per-module
    sums of arena rate maps.  The partial results are gathered on rank 0,
    merged and plotted there.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; bin width of the U/V histograms
    :param selectivity_path: str
    :param selectivity_namespace: str; namespaces containing '<selectivity_namespace> <arena_id>' are used
    :param subset_seed: int; for reproducible choice of gids to plot individual rate maps
        (not referenced in this function body)
    :param arena_id: str
    :param populations: tuple of str; empty selects a built-in default list
    :param io_size: int; -1 means use the communicator size
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; stops after about 10 iterations per namespace and plots individual rate maps
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_size: stored as fig_options.figSize
    :param colormap: if not None, stored as fig_options.colormap
    :param fig_format: str
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    # Work on a copy so the module-level default figure options stay untouched.
    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # Rank 0 validates the populations and collects the matching namespaces;
    # the result is broadcast to all ranks below.
    # NOTE(review): the RuntimeErrors in this section are raised on rank 0
    # only, before the bcast; confirm how the other ranks are expected to
    # terminate in that case.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided selectivity_path: %s' %
                    (population, selectivity_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
            if len(valid_selectivity_namespaces[population]) == 0:
                raise RuntimeError(
                    'plot_input_selectivity_features: no selectivity data in arena: %s found '
                    'for specified population: %s in provided selectivity_path: %s'
                    % (arena_id, population, selectivity_path))

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    # Inverse mapping: selectivity type id -> name.
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    # Rank 0 checks the coordinates file and reads the reference U/V bounds
    # from the first population that provides them.
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'plot_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        % (unique_gid_count, pop_size, population, coords_path))
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population, coords_path))
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population, coords_path))
    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    # Histogram bin edges; the half-bin added to the stop value keeps the
    # upper reference bound inside the range produced by np.arange.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))

    arena = env.stimulus_config['Arena'][arena_id]
    # The arena mesh is computed once on rank 0 and broadcast.
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:

        start_time = time.time()
        # Per-gid lookup tables of the first 'U Distance' / 'V Distance' value.
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                'Reading %i cell positions for population %s took %.2f s' %
                (len(u_distances_by_gid), population,
                 time.time() - start_time))

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            start_time = time.time()
            if rank == 0:
                logger.info('Reading from %s namespace for population %s...' %
                            (this_selectivity_namespace, population))

            # Rank-local accumulators for this namespace.
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))

            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                # The generator may yield gid None; such items are skipped.
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                              selectivity_type_names=selectivity_type_names,
                                              selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = input_cell_config.gather_attributes(
                    )
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        # One position entry per component so positions line
                        # up with the component attribute lists.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']

                    if debug and rank == 0:
                        fig_title = '%s %s cell %i' % (
                            population, this_selectivity_type_name, gid)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title='%s\nModule: %i' %
                            (fig_title, this_module_id),
                            **fig_options())
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                if debug and iter_count >= 10:
                    break

            # Unweighted counts; passed as ``norm`` to plot_2D_histogram below.
            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Attribute sums per bin (histograms weighted by attribute value).
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])

            # Collect every rank's partial results on rank 0; from here on
            # these names hold lists of per-rank values on rank 0.
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            # Converted to a plain dict before the gather.
            # NOTE(review): presumably because the lambda default factory
            # cannot be pickled - confirm.
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)

            if rank == 0:
                # Merge the per-rank partial results.
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                # NOTE(review): this_selectivity_type_name is whatever the
                # last cell processed on rank 0 assigned; it is only bound
                # inside the `gid is not None` branch above, so it may be
                # stale or unbound if rank 0 received no cells - confirm.
                logger.info('Processing %i %s %s cells took %.2f s' %
                            (gid_count, population,
                             this_selectivity_type_name,
                             time.time() - start_time))

                if debug:
                    context.update(locals())

                # One figure per cell-level attribute, normalized by cell count.
                for key in merged_cell_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                # One figure per component-level attribute, normalized by
                # component count.
                for key in merged_component_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                # One figure per module with the summed rate maps.
                for this_module_id in merged_rate_map_sum_by_module:
                    fig_title = '%s %s Module %i summed rate maps' % \
                                (population, this_selectivity_type_name, this_module_id)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title='%s %s summed rate maps\nModule %i' %
                        (population, this_selectivity_type_name,
                         this_module_id),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, spatial_resolution, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """
    Plot binned input selectivity attributes and per-module summed rate maps.

    For every requested population and every matching selectivity namespace,
    each rank reads its share of cells, bins their attributes by (U, V)
    distance, and accumulates rate maps per module. The per-rank results are
    gathered/reduced to rank 0, which produces the figures.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; bin width used for both the U and V histogram edges
    :param selectivity_path: str
    :param selectivity_namespace: str; namespaces containing
        '<selectivity_namespace> <arena_id>' are plotted
    :param spatial_resolution: float or None; when None the 'Spatial Resolution'
        entry of the stimulus configuration is used
    :param arena_id: str
    :param populations: tuple of str; when empty a built-in default list is used
    :param io_size: int; -1 means use the communicator size
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plots each cell's rate map on rank 0 and stops reading a
        namespace once the iteration index reaches 10
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_size: stored in fig_options.figSize (presumably a (width, height)
        pair -- confirm against dentate.plot)
    :param colormap: stored in fig_options.colormap when not None
    :param fig_format: str
    :raises RuntimeError: if a population, its selectivity data, its reference
        bounds or the arena cannot be found in the provided files/configuration
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = f'{save_fig} {arena_id}'
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # Rank 0 validates the inputs and discovers the namespaces to plot; the
    # result is broadcast to the other ranks below.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided selectivity_path: {selectivity_path}')
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    f'population: {population}')
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
            if len(valid_selectivity_namespaces[population]) == 0:
                raise RuntimeError(
                    f'plot_input_selectivity_features: no selectivity data in arena: {arena_id} found '
                    f'for specified population: {population} in provided selectivity_path: {selectivity_path}'
                )

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    # The reference bounds are taken from the first population that provides
    # them; later populations only go through the consistency checks.
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided coords_path: {coords_path}')
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: only {unique_gid_count}/{pop_size} unique cell indexes found '
                        f'for specified population: {population} in provided coords_path: {coords_path}'
                    )
                if reference_u_arc_distance_bounds is None:
                    try:
                        distances_attrs = coords_f['Populations'][population][
                            distances_namespace].attrs
                        reference_u_arc_distance_bounds = (
                            distances_attrs['Reference U Min'],
                            distances_attrs['Reference U Max'])
                    except Exception as err:
                        # Chain the cause so the underlying lookup failure is
                        # still visible in the traceback.
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        ) from err
                if reference_v_arc_distance_bounds is None:
                    try:
                        distances_attrs = coords_f['Populations'][population][
                            distances_namespace].attrs
                        reference_v_arc_distance_bounds = (
                            distances_attrs['Reference V Min'],
                            distances_attrs['Reference V Max'])
                    except Exception as err:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        ) from err

    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    # The half-bin padding on the stop value keeps the upper bound as an edge
    # when the range is an exact multiple of bin_distance.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )

    if spatial_resolution is None:
        spatial_resolution = env.stimulus_config['Spatial Resolution']
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = get_2D_arena_spatial_mesh(
            arena=arena, spatial_resolution=spatial_resolution)
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:
        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = bcast_cell_attributes(
            coords_path, population, root=0, namespace=distances_namespace,
            comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                f'Reading {len(u_distances_by_gid)} cell positions for population {population} took '
                f'{time.time() - start_time:.2f} s')

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            if rank == 0:
                logger.info(
                    f'Reading from {this_selectivity_namespace} namespace for population {population}...'
                )
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            count_by_module = defaultdict(int)
            x0_list_by_module = defaultdict(list)
            y0_list_by_module = defaultdict(list)
            # Stays None on ranks that are not handed any cell of this
            # namespace; resolved on rank 0 after the gather below.
            this_selectivity_type_name = None

            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path, population,
                namespace=this_selectivity_namespace, comm=comm,
                io_size=io_size, cache_size=cache_size)
            for iter_count, (gid, selectivity_attr_dict) in enumerate(
                    selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = get_input_cell_config(
                        selectivity_type=this_selectivity_type,
                        selectivity_type_names=selectivity_type_names,
                        selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = \
                        input_cell_config.gather_attributes()
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        # One position entry per component so the component
                        # attribute lists line up with the distance lists.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']

                    if debug and rank == 0:
                        fig_title = f'{population} {this_selectivity_type_name} cell {gid}'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh, y=arena_y_mesh, rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title=f'{fig_title}\nModule: {this_module_id}',
                            **fig_options())
                        # Release the figure like every other plot in this
                        # function does.
                        close_figure(fig)

                    x0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['X Offset'])
                    y0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['Y Offset'])
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                    count_by_module[this_module_id] += 1
                # Evaluated for every iteration, including those where this
                # rank received no cell, so all ranks stop at the same index.
                if debug and iter_count >= 10:
                    break

            if rank == 0:
                logger.info(
                    f'Done reading from {this_selectivity_namespace} namespace for population {population}...'
                )

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component, v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Weighted histograms hold the per-bin sum of each attribute; the
            # count histograms are passed as `norm` when plotting.
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_cell, v_distances_by_cell,
                    bins=[u_edges, v_edges],
                    weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = np.histogram2d(
                    u_distances_by_component, v_distances_by_component,
                    bins=[u_edges, v_edges],
                    weights=gathered_component_attributes[key])

            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            # Collect the type name seen by each rank so rank 0 can label the
            # output even if it processed no cells itself.
            selectivity_type_name_by_rank = comm.gather(
                this_selectivity_type_name, root=0)
            # defaultdicts with lambda factories are converted to plain dicts
            # before being sent through MPI.
            x0_list_by_module = dict(x0_list_by_module)
            y0_list_by_module = dict(y0_list_by_module)
            x0_list_by_module = comm.reduce(x0_list_by_module,
                                            op=mpi_op_merge_list_dict, root=0)
            y0_list_by_module = comm.reduce(y0_list_by_module,
                                            op=mpi_op_merge_list_dict, root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            count_by_module = dict(count_by_module)
            count_by_module = comm.reduce(count_by_module,
                                          op=mpi_op_merge_count_dict, root=0)

            if rank == 0:
                # Rank 0's own value comes first in the gathered list, so it
                # wins whenever rank 0 did process a cell.
                this_selectivity_type_name = next(
                    (name for name in selectivity_type_name_by_rank
                     if name is not None), this_selectivity_namespace)

                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)

                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = np.add(
                            merged_rate_map_sum_by_module[this_module_id],
                            each_rate_map_sum_by_module[this_module_id])

                logger.info(
                    f'Processing {gid_count} {population} {this_selectivity_type_name} cells '
                    f'took {time.time() - start_time:.2f} s')

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key], x_edges=u_edges,
                        y_edges=v_edges, norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key], x_edges=u_edges,
                        y_edges=v_edges, norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)', title=title,
                        cbar_label='Mean value per bin', cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    num_cells = count_by_module[this_module_id]
                    x0 = np.concatenate(x0_list_by_module[this_module_id])
                    y0 = np.concatenate(y0_list_by_module[this_module_id])
                    fig_title = f'{population} {this_selectivity_type_name} Module {this_module_id} rate map'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh, y=arena_y_mesh, x0=x0, y0=y0,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title=(
                            f'{population} {this_selectivity_type_name} rate map\n'
                            f'Module {this_module_id} ({num_cells} cells)'),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, use_noise_gen, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig,
         save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for each cell of the given populations.

    Cell positions are read from coords_path, features are generated per cell
    with generate_input_selectivity_features, and (unless dry_run) appended to
    output_path in batches. Optionally the per-rank results are reduced to
    rank 0 for plotting or interactive inspection.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str; when empty, all populations found in
        coords_path are used
    :param use_noise_gen: bool; when set, MPINoiseGenerator instances are built
        per population and passed to the feature generator
    :param io_size: int; -1 means use the communicator size
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int; features are written roughly every
        write_size / comm.size iterations (at least every iteration)
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param debug_count: int; in debug mode, iteration index at which a write is
        triggered and the per-population loop stops
    :param plot: bool; forced to True when save_fig is given
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; when set, nothing is written to output_path
    :raises RuntimeError: if output_path is missing, or a population, its
        reference bounds or the arena cannot be found
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    if save_fig is not None:
        plot = True

    if plot:
        # Plotting dependencies are only imported when actually needed.
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

        if save_fig is not None:
            save_fig = '%s %s' % (save_fig, arena_id)
            fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed a new output file with the type definitions of the coords
            # file; the context managers close both files even if the copy
            # fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    population_ranges = read_population_ranges(coords_path, comm)[0]
    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Rank 0 validates the inputs and collects the reference U bounds per
    # population; the dictionary is broadcast afterwards.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        % (unique_gid_count, pop_size, population,
                           coords_path))
                try:
                    distances_attrs = coords_f['Populations'][population][
                        distances_namespace].attrs
                    reference_u_arc_distance_bounds_dict[population] = (
                        distances_attrs['Reference U Min'],
                        distances_attrs['Reference U Max'])
                except Exception as err:
                    # Chain the cause so the underlying lookup failure is
                    # still visible in the traceback.
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population,
                         coords_path)) from err
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    selectivity_type_names = {
        val: key for (key, val) in viewitems(env.selectivity_types)
    }
    # Output namespace per selectivity type: the type name with its first
    # character upper-cased, followed by ' Selectivity <arena_id>'.
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = get_2D_arena_spatial_mesh(
            arena=arena,
            spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                f'Generating input selectivity features for population {population}...'
            )

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        modular = True
        if population in env.stimulus_config[
                'Non-modular Place Selectivity Populations']:
            modular = False

        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                # One generator per module, keyed by module id.
                for module_id in range(env.stimulus_config['Number Modules']):
                    extent_x, extent_y = get_2D_arena_extents(arena)
                    margin = round(
                        selectivity_config.place_module_field_widths[module_id]
                        / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                        arena, margin=margin)
                    noise_gen = MPINoiseGenerator(
                        comm=comm, bounds=(arena_x_bounds, arena_y_bounds),
                        tile_rank=comm.rank, bin_size=0.5, mask_fraction=0.99,
                        seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                # Non-modular populations share a single generator under
                # key -1.
                margin = round(
                    np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                    arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(
                    comm=comm, bounds=(arena_x_bounds, arena_y_bounds),
                    tile_rank=comm.rank, bin_size=0.5, mask_fraction=0.99,
                    seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path, population,
                                                namespace=distances_namespace,
                                                comm=comm, io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid, distances_attr_dict) in enumerate(
                distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                # This rank has no cell in this iteration but still issues the
                # noise generator calls (presumably to match the collective
                # calls made by ranks that do -- confirm against
                # MPINoiseGenerator).
                if noise_gen_dict is not None:
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set(),
                                                        op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(
                                np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(
                        f'Rank {rank} generating selectivity features for gid {gid}...'
                    )
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                # Position along U rescaled by the population's reference
                # bounds.
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                    generate_input_selectivity_features(
                        env, population, arena, arena_x_mesh, arena_y_mesh,
                        gid, (norm_u_arc_distance, v_arc_distance),
                        selectivity_config, selectivity_type_names,
                        selectivity_type_namespaces,
                        noise_gen_dict=noise_gen_dict,
                        rate_map_sum=this_rate_map_sum,
                        debug=(debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1
            if noise_gen_dict is not None:
                # Advance every generator's tile rank round-robin over the
                # communicator.
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (
                        noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()
            if (iter_count > 0 and iter_count % write_every == 0) or (
                    debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict, root=0,
                                                    op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path, population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace, comm=comm,
                            io_size=io_size, chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                        req.wait()
                    # Drop what has been written and start a fresh buffer.
                    del selectivity_attr_dict
                    selectivity_attr_dict = dict(
                        (key, dict()) for key in env.selectivity_types)
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        # Final report and write of whatever remains buffered for this
        # population.
        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict, root=0,
                                            op=mpi_op_merge_count_dict)
        req.wait()

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path, population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace, comm=comm,
                    io_size=io_size, chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = comm.reduce(
                pop_norm_distances[population], root=0, op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum, root=0,
                                          op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict, root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict, root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(
                            f'{save_fig} {population} normalized distances histogram',
                            fig=fig, **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Name the output after the type being plotted; the
                        # title below already uses the same loop variable.
                        fig_title = f'{population} {selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh, y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title=
                            f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                        close_figure(fig)
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}', fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, io_size, chunk_size, value_chunk_size,
         cache_size, write_size, verbose, gather, interactive, debug, plot,
         show_fig, save_fig, save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for every cell of the requested
    populations and append them to ``output_path``.

    For each cell read from ``distances_namespace`` in ``coords_path`` the
    U arc distance is normalized against the population's reference bounds
    and passed, together with the V arc distance, to
    ``generate_input_selectivity_features``. Results are appended in batches,
    one namespace per selectivity type. When ``gather`` is set, normalized
    distances and summed rate maps are collected on rank 0 and optionally
    plotted.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool
    :raises RuntimeError: on a missing output path, an unknown population or
        arena, a population without selectivity type probabilities, duplicate
        cell indexes, or missing reference bound attributes.
    """
    # NOTE: local variable names are kept stable on purpose; locals() is
    # exported to `context` for interactive/debug sessions below.
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    # Saving a figure implies plotting.
    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed a new output file with the type definitions from the
            # coordinates file; both handles are closed even if copy fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    population_ranges = read_population_ranges(coords_path, comm)[0]
    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Rank 0 validates the inputs and reads the reference U bounds, which are
    # then broadcast to all ranks.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        % (unique_gid_count, pop_size, population,
                           coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                        coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                        coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception as e:
                    # Chain the cause so the underlying lookup error is kept.
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path)) from e
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # Map selectivity type id -> name, and name -> output namespace
    # ("<Capitalized name> Selectivity <arena_id>").
    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena,
                                      spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    # Number of generator iterations between batched writes.
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                'Generating input selectivity features for population %s...'
                % population)

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid, distances_attr_dict) in enumerate(distances_attr_gen):
            if gid is not None:
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                    generate_input_selectivity_features(env, population, arena, arena_x_mesh, arena_y_mesh,
                                                        gid, (norm_u_arc_distance, v_arc_distance),
                                                        selectivity_config, selectivity_type_names,
                                                        selectivity_type_namespaces,
                                                        rate_map_sum=this_rate_map_sum,
                                                        debug=(debug_callback, context) if debug else False)
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1

            # The reduce and the append below are executed by every rank,
            # including ranks whose generator yielded gid None this iteration.
            if (iter_count > 0 and iter_count % write_every == 0) or \
                    (debug and iter_count == 10):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        if rank == 0:
                            logger.info(
                                'writing selectivity features for %s [%s]...'
                                % (population, selectivity_type_name))
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                # Start a fresh batch.
                del selectivity_attr_dict
                selectivity_attr_dict = dict(
                    (key, dict()) for key in env.selectivity_types)

            if debug and iter_count >= 10:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = this_rate_map_sum

        # Final report and write of whatever remains in the last batch.
        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)
        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                if rank == 0:
                    logger.info('writing selectivity features for %s [%s]...'
                                % (population, selectivity_type_name))
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
        del selectivity_attr_dict
        comm.barrier()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
                comm.reduce(pop_norm_distances[population], root=0,
                            op=mpi_op_merge_dict)
        # Convert the defaultdicts (built with lambdas) to plain dicts before
        # gathering them on rank 0.
        rate_map_sum = dict([(key, dict(val.items()))
                             for key, val in viewitems(rate_map_sum)])
        rate_map_sum = comm.gather(rate_map_sum, root=0)
        if rank == 0:
            merged_rate_map_sum = defaultdict(
                lambda: defaultdict(lambda: np.zeros_like(arena_x_mesh)))
            for each_rate_map_sum in rate_map_sum:
                for population in each_rate_map_sum:
                    for selectivity_type_name in each_rate_map_sum[population]:
                        merged_rate_map_sum[population][selectivity_type_name] = \
                            np.add(merged_rate_map_sum[population][selectivity_type_name],
                                   each_rate_map_sum[population][selectivity_type_name])
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title('Population: %s' % population)
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure('%s %s normalized distances histogram' %
                                    (save_fig, population),
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Use the loop variable here: the previously used
                        # this_selectivity_type_name is a leftover from the
                        # generation loop and gave every figure of a
                        # population the same (wrong) title and save name.
                        fig_title = '%s %s summed rate maps' % (
                            population, selectivity_type_name)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title='Summed rate maps\n%s %s cells' %
                            (population, selectivity_type_name),
                            **fig_options())

    if interactive and rank == 0:
        context.update(locals())
def main(config, coordinates, gid, field_width, peak_rate, input_features_path,
         input_features_namespaces, output_weights_path, output_features_path,
         weights_path, h5types_path, synapse_name, initial_weights_namespace,
         structured_weights_namespace, connections_path, destination, sources,
         arena_id, baseline_weight, field_width_scale, max_iter, verbose,
         dry_run, interactive):
    """
    Generate structured synaptic weights for a single destination cell whose
    place fields are given explicitly, and append them to an output file.

    :param config: str (path to .yaml file)
    :param coordinates: iterable of (x, y) pairs; one place field per pair
    :param gid: int; gid of the destination cell
    :param field_width: width assigned to every field
    :param peak_rate: peak rate assigned to every field
    :param input_features_path: str (path to file with source cell features)
    :param input_features_namespaces: list of str; arena_id is appended to each
    :param output_weights_path: str; defaults to weights_path when None
    :param output_features_path: str or None; when set, the destination
        features are appended there as well
    :param weights_path: str (path to .h5 file)
    :param h5types_path: str; source of /H5Types when weights_path is None
    :param synapse_name: str; attribute name of the initial weights
    :param initial_weights_namespace: str
    :param structured_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param arena_id: str
    :param baseline_weight: passed to synapses.generate_structured_weights
    :param field_width_scale: passed to synapses.generate_structured_weights
    :param max_iter: not used in this function body
    :param verbose:
    :param dry_run: when set, nothing is written
    :param interactive: passed to synapses.generate_structured_weights
    :raises RuntimeError: when no output weights path can be determined, no
        /H5Types source is available for a new output file, or the initial
        weights namespace is missing from weights_path.
    :return:
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    env = Env(config_file=config)

    if output_weights_path is None:
        if weights_path is None:
            raise RuntimeError('Output weights path must be specified when weights path is not specified.')
        output_weights_path = weights_path

    if not dry_run and not os.path.isfile(output_weights_path):
        # A new output file needs the /H5Types group from an existing file.
        if weights_path is not None:
            h5types_source_path = weights_path
        elif h5types_path is not None:
            h5types_source_path = h5types_path
        else:
            raise RuntimeError('h5types input path must be specified when weights path is not specified.')
        # Context managers close both files even if the copy fails.
        with h5py.File(h5types_source_path, 'r') as input_file, \
                h5py.File(output_weights_path, 'w') as output_file:
            input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]

    initial_weights_dict = None
    if weights_path is not None:
        logger.info('Reading initial weights data from %s...' % weights_path)
        cell_attributes_dict = read_cell_attribute_selection(
            weights_path, destination,
            namespaces=[initial_weights_namespace],
            selection=[gid])
        if initial_weights_namespace in cell_attributes_dict:
            initial_weights_iter = cell_attributes_dict[initial_weights_namespace]
            initial_weights_dict = dict(initial_weights_iter)
        else:
            raise RuntimeError('Initial weights namespace %s was not found in file %s' %
                               (initial_weights_namespace, weights_path))

        logger.info('Rank %i; destination: %s; read synaptic weights for %i cells' %
                    (env.comm.rank, destination, len(initial_weights_dict)))

    features_attr_names = ['Num Fields', 'Field Width', 'Peak Rate',
                           'X Offset', 'Y Offset', 'Arena Rate Map']

    local_random = np.random.RandomState()
    seed_offset = int(env.model_config['Random Seeds']['GC Structured Weights'])
    local_random.seed(int(gid + seed_offset))

    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    x, y = stimulus.get_2D_arena_spatial_mesh(arena, spatial_resolution)

    plasticity_kernel = lambda x, y, x_loc, y_loc, sx, sy: gauss2d(x-x_loc, y-y_loc, sx=sx, sy=sy)
    # Vectorize over x and y only; the remaining positional arguments are
    # passed through unchanged.
    plasticity_kernel = np.vectorize(plasticity_kernel, excluded=[2,3,4,5])

    # Build the place-field description of the destination cell.
    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width]*num_fields, dtype=np.float32)
    this_peak_rate = np.array([peak_rate]*num_fields, dtype=np.float32)
    # Distinct comprehension names so the arena mesh x, y are not shadowed.
    this_x0 = np.array([x0 for x0, _ in coordinates], dtype=np.float32)
    this_y0 = np.array([y0 for _, y0 in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width,
                                            this_peak_rate, x, y),
                               dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel()
    }

    selection = [gid]
    structured_weights_dict = {}
    source_syn_dict = defaultdict(lambda: defaultdict(list))
    syn_weight_dict = {}
    if weights_path is not None:
        initial_weights_iter = read_cell_attribute_selection(
            weights_path, destination,
            namespace=initial_weights_namespace,
            selection=selection)
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[gid]['syn_id']
        weights = syn_weight_attr_dict[gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            syn_weight_dict[int(syn_id)] = float(weight)

        logger.info('destination: %s; gid %i; received synaptic weights for %i synapses' %
                    (destination, gid, len(syn_weight_dict)))

    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[gid],
                                                   namespaces=['Synapses'])
    syn_id_attr_index = None
    for source, edge_iter in viewitems(graph[destination]):
        this_edge_attr_info = edge_attr_info[destination][source]
        if 'Synapses' in this_edge_attr_info and \
           'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']

        for (destination_gid, edges) in edge_iter:
            assert destination_gid == gid

            source_gids, edge_attrs = edges
            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            this_source_syn_dict = source_syn_dict[source]

            # Group (syn_id, initial weight) pairs by presynaptic gid;
            # synapses without an initial weight get 0.0.
            count = 0
            for this_source_gid, this_syn_id in zip(source_gids, syn_ids):
                this_syn_wgt = syn_weight_dict.get(this_syn_id, 0.0)
                this_source_syn_dict[this_source_gid].append((this_syn_id, this_syn_wgt))
                count += 1

            logger.info('destination: %s; gid %i; %d synaptic weights from source population %s' %
                        (destination, gid, count, source))

    src_input_features = defaultdict(dict)
    for source in sources:
        source_gids = list(source_syn_dict[source].keys())
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path, source,
                namespace=input_features_namespace,
                mask=set(features_attr_names),
                selection=source_gids)
            this_src_input_features = src_input_features[source]
            count = 0
            # Dedicated loop name: the original reused `gid` here and thereby
            # clobbered the destination gid parameter.
            for source_gid, attr_dict in input_features_iter:
                this_src_input_features[source_gid] = attr_dict
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

    # NOTE(review): destination_gid is only bound if the graph selection above
    # yielded at least one edge set for this gid - confirm that an empty
    # selection cannot occur.
    this_syn_weights = \
        synapses.generate_structured_weights(destination_gid,
                                             destination,
                                             synapse_name,
                                             sources,
                                             dst_input_features,
                                             src_input_features,
                                             source_syn_dict,
                                             spatial_mesh=(x,y),
                                             plasticity_kernel=plasticity_kernel,
                                             field_width_scale=field_width_scale,
                                             baseline_weight=baseline_weight,
                                             local_random=local_random,
                                             interactive=interactive)

    assert this_syn_weights is not None
    structured_weights_dict[destination_gid] = this_syn_weights
    logger.info('destination: %s; gid %i; generated structured weights for %i inputs' %
                (destination, destination_gid, len(this_syn_weights['syn_id'])))
    gc.collect()
    if not dry_run:
        logger.info('Destination: %s; appending structured weights...' % (destination))
        this_structured_weights_namespace = '%s %s' % (structured_weights_namespace, arena_id)
        append_cell_attributes(output_weights_path, destination, structured_weights_dict,
                               namespace=this_structured_weights_namespace)
        logger.info('Destination: %s; appended structured weights' % (destination))
        structured_weights_dict.clear()
        if output_features_path is not None:
            output_features_namespace = 'Place Selectivity %s' % arena_id
            cell_attr_dict = dst_input_features[destination]
            logger.info('Destination: %s; appending features...' % (destination))
            append_cell_attributes(output_features_path, destination,
                                   cell_attr_dict, namespace=output_features_namespace)
        gc.collect()

    del(syn_weight_dict)
    del(src_input_features)
    del(dst_input_features)
def main(config, config_prefix, arena_id, populations, module_ids,
         target_fraction_active, normalize_scale, verbose, interactive, debug,
         plot, show_fig, save_fig, save_fig_dir, font_size, fig_format):
    """
    Calibrate the num_place_field_probabilities distribution for each
    requested population and module, and print every modified distribution
    in a yaml-like form.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param arena_id: str
    :param populations: tuple of str; a fixed default set is used when empty
    :param module_ids: tuple of int; all configured modules are used when empty
    :param target_fraction_active: float
    :param normalize_scale: bool; whether to interpret the scale of the num_place_field_probabilities distribution as
        normalized to the scale of the mean place field width
    :param verbose: bool
    :param interactive: bool
    :param debug: bool; when set, return after the first calibration
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir: str (path to dir)
    :param font_size: float
    :param fig_format: str
    :raises RuntimeError: on an unknown arena, invalid module ids, or a
        population lacking place field probabilities or a peak rate.
    """
    # Local names are left untouched throughout: locals() is pushed into
    # `context` in interactive/debug mode.
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import clean_axes

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    # NOTE(review): fig_options is bound only under `plot` in this layout;
    # confirm how save_fig is expected to behave when plot is not set.
    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not populations:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))

    # Inverse mapping: selectivity type id -> name.
    selectivity_type_names = {
        val: key for (key, val) in viewitems(env.selectivity_types)
    }

    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = get_2D_arena_spatial_mesh(
        arena=arena,
        spatial_resolution=env.stimulus_config['Spatial Resolution'])

    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random = np.random.RandomState(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)

    # Only place selectivity is calibrated here.
    this_selectivity_type_name = 'place'
    this_selectivity_type = env.selectivity_types['place']

    if interactive:
        context.update(locals())

    if not module_ids:
        module_ids = selectivity_config.module_ids
    elif any(module_id not in selectivity_config.module_ids
             for module_id in module_ids):
        raise RuntimeError(
            'calibrate_DG_num_place_field_probabilities: invalid module_ids provided: %s'
            % str(module_ids))

    for population in populations:
        if population not in env.stimulus_config[
                'Num Place Field Probabilities']:
            raise RuntimeError(
                'calibrate_DG_num_place_field_probabilities: probabilities for number of place fields '
                'not specified for population: %s' % population)
        num_place_field_probabilities = env.stimulus_config[
            'Num Place Field Probabilities'][population]

        if not (population in env.stimulus_config['Peak Rate'] and
                this_selectivity_type in env.stimulus_config['Peak Rate'][population]):
            raise RuntimeError(
                'calibrate_DG_num_place_field_probabilities: peak rate not specified for population: '
                '%s, selectivity type: %s' %
                (population, this_selectivity_type_name))
        peak_rate = env.stimulus_config['Peak Rate'][population][
            this_selectivity_type]

        start_time = time.time()
        for module_id in module_ids:
            field_width = selectivity_config.place_module_field_widths[
                module_id]
            logger.info(
                'Calibrating distribution of num_place_field_probabilities for population: %s, module: %i, '
                'field width: %.2f' % (population, module_id, field_width))
            modified_num_place_field_probabilities = calibrate_num_place_field_probabilities(
                num_place_field_probabilities,
                field_width,
                peak_rate=peak_rate,
                selectivity_type=this_selectivity_type,
                arena=arena,
                normalize_scale=normalize_scale,
                selectivity_config=selectivity_config,
                target_fraction_active=target_fraction_active,
                random_seed=selectivity_seed_offset + module_id,
                plot=plot and show_fig)
            logger.info(
                'Modified num_place_field_probabilities for population: %s, module: %i, field width: %.2f'
                % (population, module_id, field_width))
            print_param_dict_like_yaml(modified_num_place_field_probabilities)
            sys.stdout.flush()
            # Debug mode stops after the first calibration.
            if debug:
                context.update(locals())
                return

    if interactive:
        context.update(locals())