def load_celltypes(self):
    """
    Attach data-file information to the cell type configuration.

    For every cell type in ``self.celltypes`` that has a population range in
    ``self.data_file_path``, this sets:

    - ``'start'`` / ``'num'``: the population range read from the data file;
    - ``'mech_file_path'`` / ``'mech_dict'``: if a ``'mechanism file'`` is
      configured, its path under ``self.config_prefix`` and its YAML contents;
    - for each synaptic weights entry with an ``'expr'`` key, a ``'closure'``
      built with ``ExprClosure``. A single weights mapping is normalized into
      a one-element list.

    Cell types without a population range are left untouched. Afterwards the
    cell attribute metadata of the data file is stored in
    ``self.cell_attribute_info``. Every rank performs all reads itself; log
    messages are emitted on rank 0 only.

    :return: None; updates ``self.celltypes`` in place and sets
        ``self.cell_attribute_info``.
    """
    rank = self.comm.Get_rank()
    celltypes = self.celltypes
    typenames = sorted(celltypes.keys())

    if rank == 0:
        self.logger.info('env.data_file_path = %s' % str(self.data_file_path))
    (population_ranges, _) = read_population_ranges(self.data_file_path, self.comm)
    if rank == 0:
        self.logger.info('population_ranges = %s' % str(population_ranges))

    for k in typenames:
        population_range = population_ranges.get(k, None)
        if population_range is None:
            # No population of this type in the data file: leave config as-is.
            continue
        celltype = celltypes[k]
        celltype['start'] = population_range[0]
        celltype['num'] = population_range[1]
        if 'mechanism file' in celltype:
            celltype['mech_file_path'] = '%s/%s' % (self.config_prefix, celltype['mechanism file'])
            celltype['mech_dict'] = read_from_yaml(celltype['mech_file_path'])
        if 'synapses' in celltype:
            synapses_dict = celltype['synapses']
            if 'weights' in synapses_dict:
                weights_config = synapses_dict['weights']
                if isinstance(weights_config, list):
                    weights_dicts = weights_config
                else:
                    weights_dicts = [weights_config]
                for weights_dict in weights_dicts:
                    if 'expr' in weights_dict:
                        expr = weights_dict['expr']
                        parameter = weights_dict['parameter']
                        const = weights_dict.get('const', {})
                        weights_dict['closure'] = ExprClosure(parameter, expr, const)
                # Always store the list form so consumers can iterate uniformly.
                synapses_dict['weights'] = weights_dicts

    population_names = read_population_names(self.data_file_path, self.comm)
    if rank == 0:
        self.logger.info('population_names = %s' % str(population_names))
    self.cell_attribute_info = read_cell_attribute_info(
        self.data_file_path, population_names, comm=self.comm)

    if rank == 0:
        self.logger.info('attribute info: %s' % str(self.cell_attribute_info))
def query_cell_attributes(input_file, population_names, namespace_ids=None):
    """
    Query the cell attribute metadata of a data file.

    :param input_file: str; path to the cell attribute data file
    :param population_names: populations whose attribute info is read
    :param namespace_ids: optional namespace ids to report; if None, the
        namespace keys found in the file are reported instead
    :return: tuple (namespace_id_lst, attr_info_dict), where attr_info_dict is
        the result of ``read_cell_attribute_info(..., read_cell_index=True)``.
        namespace_id_lst is ``namespace_ids`` when given, otherwise the
        namespace keys of the *last* population in attr_info_dict; it is an
        empty list when attr_info_dict has no populations.
    """
    logger.info('Querying cell attribute data...')

    attr_info_dict = read_cell_attribute_info(input_file, populations=population_names, read_cell_index=True)

    namespace_id_lst = []
    if attr_info_dict:
        if namespace_ids is None:
            # NOTE(review): only the last population's namespaces are reported;
            # confirm callers do not expect the union over all populations.
            last_pop_name = list(attr_info_dict)[-1]
            namespace_id_lst = attr_info_dict[last_pop_name].keys()
        else:
            namespace_id_lst = namespace_ids

    return namespace_id_lst, attr_info_dict
def _split_trials(tvals, svals, n_trials):
    """Split back-to-back trial recordings into per-trial arrays.

    A trial is taken to start at every sample whose time value is within
    1e-4 of the first time sample ``tvals[0]``.

    :param tvals: 1-D array of time samples
    :param svals: 1-D array of state samples aligned with ``tvals``
    :param n_trials: int; maximum number of trials to keep, or -1 to keep all
    :return: tuple (list of per-trial time arrays, list of per-trial state
        arrays); at least one (possibly empty) trial is always returned
    """
    if len(tvals) == 0:
        # No samples (e.g. nothing inside the requested time range).
        return [tvals], [svals]
    trial_starts = list(np.where(np.isclose(tvals, tvals[0], atol=1e-4))[0])
    if not trial_starts:
        # Only possible if tvals[0] is NaN; treat the whole trace as one trial.
        trial_starts = [0]
    if n_trials == -1:
        n_keep = len(trial_starts)
    else:
        n_keep = min(len(trial_starts), n_trials)
    # The first trial is always returned, even when n_trials < 1.
    n_keep = max(n_keep, 1)
    bounds = trial_starts + [len(tvals)]
    t_trials = [tvals[bounds[i]:bounds[i + 1]] for i in range(n_keep)]
    s_trials = [svals[bounds[i]:bounds[i + 1]] for i in range(n_keep)]
    return t_trials, s_trials


def read_state(input_file, population_names, namespace_id, time_variable='t', state_variable='v', time_range=None,
               max_units=None, gid=None, comm=None, n_trials=-1):
    """
    Read recorded state traces (e.g. membrane voltage) for the given populations.

    :param input_file: str; path to the data file containing the recordings
    :param population_names: list of population names to read
    :param namespace_id: cell attribute namespace holding the recordings
    :param time_variable: str; attribute name of the time samples
    :param state_variable: str; attribute name of the state samples
    :param time_range: optional (t_start, t_end); only samples with
        t_start <= t <= t_end are kept
    :param max_units: optional int; when gid is None and more than max_units
        cells have recordings, a random sample of max_units distinct cells
        (drawn on rank 0 and broadcast) is kept
    :param gid: optional iterable of gids; when given, only those cells are read
    :param comm: MPI communicator; defaults to MPI.COMM_WORLD
    :param n_trials: int; maximum number of trials returned per cell, -1 for all
    :return: dict with keys 'states' (population name -> {gid: (list of
        per-trial time arrays, list of per-trial state arrays, distance,
        section, loc)}), 'time_variable' and 'state_variable'
    :raises RuntimeError: if a population has no recordings of
        state_variable in namespace_id
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    pop_state_dict = {}

    logger.info(
        'Reading state data from populations %s, namespace %s gid = %s...' % (str(population_names), namespace_id, str(gid)))

    attr_info_dict = read_cell_attribute_info(input_file, populations=population_names, read_cell_index=True,
                                              comm=comm)

    for pop_name in population_names:
        cell_index = None
        for attr_name, attr_cell_index in attr_info_dict.get(pop_name, {}).get(namespace_id, []):
            if state_variable == attr_name:
                cell_index = attr_cell_index
        if cell_index is None:
            raise RuntimeError(
                'read_state: Unable to find recordings for state variable %s in population %s namespace %s' %
                (state_variable, pop_name, str(namespace_id)))
        cell_set = set(cell_index)

        if gid is None:
            gid_set = cell_set
            # Limit to max_units
            if (max_units is not None) and (len(cell_set) > max_units):
                logger.info(
                    ' Reading only randomly sampled %i out of %i units for population %s' %
                    (max_units, len(cell_set), pop_name))
                sampled = None
                if comm.Get_rank() == 0:
                    cell_set_lst = list(cell_set)
                    # Without replacement: exactly max_units distinct cells.
                    sample_inds = np.random.choice(len(cell_set_lst), size=int(max_units), replace=False)
                    sampled = set(cell_set_lst[i] for i in sample_inds)
                # Draw once and share, so that every rank filters against the
                # same sample instead of its own independent random draw.
                gid_set = comm.bcast(sampled, root=0)
            valiter = read_cell_attributes(input_file, pop_name, namespace=namespace_id, comm=comm)
        else:
            gid_set = set(gid)
            valiter = read_cell_attribute_selection(input_file, pop_name, namespace=namespace_id,
                                                    selection=list(gid_set), comm=comm)

        state_dict = {}
        for cellind, vals in valiter:
            # Skip empty entries and cells outside the (possibly sampled) selection.
            if cellind is None or cellind not in gid_set:
                continue
            distance = vals.get('distance', [None])[0]
            section = vals.get('section', [None])[0]
            loc = vals.get('loc', [None])[0]
            if time_range is None:
                tvals = np.asarray(vals[time_variable], dtype=np.float32)
                svals = np.asarray(vals[state_variable], dtype=np.float32)
            else:
                raw_t = np.asarray(vals[time_variable])
                in_range = (raw_t <= time_range[1]) & (raw_t >= time_range[0])
                tvals = np.asarray(raw_t[in_range], dtype=np.float32)
                svals = np.asarray(np.asarray(vals[state_variable])[in_range], dtype=np.float32)
            t_trials, s_trials = _split_trials(tvals, svals, n_trials)
            state_dict[cellind] = (t_trials, s_trials, distance, section, loc)

        pop_state_dict[pop_name] = state_dict

    return {'states': pop_state_dict,
            'time_variable': time_variable,
            'state_variable': state_variable}
def main(config_file, population, dt, gid, gid_selection_file, arena_id, trajectory_id, generate_weights, t_max, t_min,
         nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, surrogate_method,
         population_size, num_generations, resample_fraction, mutation_rate,
         template_paths, dataset_prefix, config_prefix,
         param_config_name, selectivity_config_name, param_type, recording_profile, results_file, results_path,
         spike_events_path, spike_events_namespace, spike_events_t, input_features_path, input_features_namespaces,
         n_trials, trial_regime, problem_regime, target_features_path, target_features_namespace,
         target_state_variable, target_state_filter, use_coreneuron, cooperative_init, spawn_startup_wait):
    """
    Optimize the input stimulus selectivity of the specified cell in a network clamp configuration.

    The cells to optimize are taken, in order of precedence, from
    ``gid_selection_file`` (one integer gid per line), from ``gid``, or
    otherwise from every cell listed under the first attribute of the
    'Trees' namespace of ``population`` in the data file.

    After optimization, if a result is returned and ``results_path`` is set,
    the best parameters are written to
    ``{results_path}/optimize_selectivity.{timestamp}.yaml`` under the key
    ``population``.

    :raises RuntimeError: if ``population`` has no entry in the network clamp
        ``optimize_parameters[param_type]`` configuration
    """
    # Snapshot of the arguments only (taken before any other local exists);
    # forwarded to Env and to optimize_run.
    init_params = dict(locals())
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # Rank 0 generates the results file id and shares it with every rank.
    results_file_id = None
    if rank == 0:
        results_file_id = generate_results_file_id(population, gid)

    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()

    # Turn all numpy floating-point warnings into exceptions.
    np.seterr(all='raise')
    verbose = True
    cache_queries = True
    config_logging(verbose)

    cell_index_set = set([])
    if gid_selection_file is not None:
        with open(gid_selection_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # NOTE(review): this rebinds the `gid` argument; the last gid
                # read is what later ends up in `params` via locals().
                gid = int(line)
                cell_index_set.add(gid)
    elif gid is not None:
        cell_index_set.add(gid)
    else:
        comm.barrier()
        # Split off a communicator containing only rank 0 (color 2) so that
        # rank 0 alone can construct an Env and read the data file.
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            env = Env(**init_params, comm=comm0)
            attr_info_dict = read_cell_attribute_info(env.data_file_path, populations=[population],
                                                      read_cell_index=True, comm=comm0)
            cell_index = None
            # Use the cell index of the first attribute in the 'Trees' namespace.
            attr_name, attr_cell_index = next(iter(attr_info_dict[population]['Trees']))
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()
    init_params['cell_index_set'] = cell_index_set
    del(init_params['gid'])

    # NOTE(review): locals() here captures every local defined so far (comm,
    # size, rank, results_file_id, verbose, cache_queries, init_params, and
    # branch-dependent names such as gid, env, comm0, lines). All of them are
    # passed to Env as keyword arguments; Env keeps unknown ones in **kwargs.
    # Renaming or removing any local above therefore changes Env's inputs.
    params = dict(locals())
    env = Env(**params)
    if size == 1:
        # Serial run: set up hoc and the network clamp directly in this process.
        configure_hoc_env(env)
        init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
             spike_events_path, spike_events_namespace=spike_events_namespace,
             spike_train_attr_name=spike_events_t,
             input_features_path=input_features_path,
             input_features_namespaces=input_features_namespaces,
             generate_weights_pops=set(generate_weights),
             t_min=t_min, t_max=t_max)

    if (population in env.netclamp_config.optimize_parameters[param_type]):
        opt_params = env.netclamp_config.optimize_parameters[param_type][population]
    else:
        raise RuntimeError(f'optimize_selectivity: population {population} does not have optimization configuration')

    # Default target state variable is 'v'.
    if target_state_variable is None:
        target_state_variable = 'v'

    # Extra settings consumed by the objective-function initializer.
    init_params['target_features_arena'] = arena_id
    init_params['target_features_trajectory'] = trajectory_id
    opt_state_baseline = opt_params['Targets']['state'][target_state_variable]['baseline']
    init_params['state_baseline'] = opt_state_baseline
    init_params['state_variable'] = target_state_variable
    init_params['state_filter'] = target_state_filter
    init_objfun_name = 'init_selectivity_objfun'

    best = optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun_name,
                        problem_regime=problem_regime, n_epochs=n_epochs, n_initial=n_initial,
                        initial_maxiter=initial_maxiter, initial_method=initial_method,
                        optimizer_method=optimizer_method, surrogate_method=surrogate_method,
                        population_size=population_size, num_generations=num_generations,
                        resample_fraction=resample_fraction, mutation_rate=mutation_rate,
                        param_type=param_type, init_params=init_params,
                        results_file=results_file, nprocs_per_worker=nprocs_per_worker,
                        cooperative_init=cooperative_init, spawn_startup_wait=spawn_startup_wait,
                        verbose=verbose)

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population],
                                           param_config_name, param_type)
    if best is not None:
        if results_path is not None:
            run_ts = time.strftime("%Y%m%d_%H%M%S")
            file_path = f'{results_path}/optimize_selectivity.{run_ts}.yaml'
            param_names = opt_param_config.param_names
            param_tuples = opt_param_config.param_tuples

            if ProblemRegime[problem_regime] == ProblemRegime.every:
                # One optimization result per gid: {gid: {result index: [param tuples]}}.
                results_config_dict = {}
                for gid, prms in viewitems(best):
                    # presumably prms is a list of (param name, array of values),
                    # with one value per result; n_res taken from the first entry.
                    n_res = prms[0][1].shape[0]
                    prms_dict = dict(prms)
                    this_results_config_dict = {}
                    for i in range(n_res):
                        results_param_list = []
                        for param_pattern, param_tuple in zip(param_names, param_tuples):
                            results_param_list.append((param_tuple.population,
                                                       param_tuple.source,
                                                       param_tuple.sec_type,
                                                       param_tuple.syn_name,
                                                       param_tuple.param_path,
                                                       float(prms_dict[param_pattern][i])))
                        this_results_config_dict[i] = results_param_list
                    results_config_dict[gid] = this_results_config_dict
            else:
                # Shared optimization result: {result index: [param tuples]}.
                prms = best[0]
                n_res = prms[0][1].shape[0]
                prms_dict = dict(prms)
                results_config_dict = {}
                for i in range(n_res):
                    results_param_list = []
                    for param_pattern, param_tuple in zip(param_names, param_tuples):
                        results_param_list.append((param_tuple.population,
                                                   param_tuple.source,
                                                   param_tuple.sec_type,
                                                   param_tuple.syn_name,
                                                   param_tuple.param_path,
                                                   float(prms_dict[param_pattern][i])))
                    results_config_dict[i] = results_param_list

            write_to_yaml(file_path, { population: results_config_dict } )

    comm.barrier()
def __init__(self, comm=None, config_file=None, template_paths="templates", hoc_lib_path=None,
             configure_nrn=True, dataset_prefix=None, config_prefix=None,
             results_path=None, results_file_id=None, results_namespace_id=None,
             node_rank_file=None, io_size=0, recording_profile=None, recording_fraction=1.0,
             coredat=False, tstop=0., v_init=-65, stimulus_onset=0.0, n_trials=1,
             max_walltime_hours=0.5, checkpoint_interval=500.0, checkpoint_clear_data=True,
             results_write_time=0, dt=0.025, ldbal=False, lptbal=False, transfer_debug=False,
             cell_selection_path=None, spike_input_path=None, spike_input_namespace=None,
             spike_input_attr=None, cleanup=True, cache_queries=False, profile_memory=False,
             verbose=False, **kwargs):
    """
    Build the simulation environment from a model configuration file.

    :param comm: MPI communicator; defaults to MPI.COMM_WORLD
    :param config_file: str; model configuration file name (required)
    :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
    :param hoc_lib_path: str; path to directory containing required hoc libraries
    :param configure_nrn: bool; if true, import dentate.neuron_utils during construction
    :param dataset_prefix: str; path to directory containing required neuroh5 data files
    :param config_prefix: str; path to directory containing network and cell mechanism config files;
        prepended (with '/') to config_file when given
    :param results_path: str; path to directory to export output files
    :param results_file_id: str; label for neuroh5 files to write spike and voltage trace data
    :param results_namespace_id: str; label for neuroh5 namespaces to write spike and voltage trace data
    :param node_rank_file: str; name of file specifying assignment of node gids to MPI ranks
    :param io_size: int; the number of MPI ranks to be used for I/O operations
    :param recording_profile: str; intracellular recording configuration to use
    :param recording_fraction: float in [0, 1]; stored as self.recording_fraction
    :param coredat: bool; Save CoreNEURON data
    :param tstop: float; physical time to simulate (ms)
    :param v_init: float; initialization membrane potential (mV)
    :param stimulus_onset: float; starting time of stimulus (ms)
    :param n_trials: int; number of trials
    :param max_walltime_hours: float; maximum wall time (hours)
    :param checkpoint_interval: float; checkpoint interval in ms of simulation time (at least 1.0)
    :param checkpoint_clear_data: bool; stored as self.checkpoint_clear_data
    :param results_write_time: float; time to write out results at end of simulation
    :param dt: float; simulation time step
    :param ldbal: bool; estimate load balance based on cell complexity
    :param lptbal: bool; calculate load balance with LPT algorithm
    :param transfer_debug: bool; stored as self.transfer_debug
    :param cell_selection_path: str; optional YAML file loaded into self.cell_selection
    :param spike_input_path: str; optional file whose cell attribute info is read into
        self.spike_input_attribute_info
    :param spike_input_namespace: str; stored as self.spike_input_ns
    :param spike_input_attr: str; stored as self.spike_input_attr
    :param cleanup: bool; clean up auxiliary cell and synapse structures after network init
    :param cache_queries: bool; whether to use a cache to speed up queries to filter_synapses
    :param profile_memory: bool; profile memory usage
    :param verbose: bool; print verbose diagnostic messages while constructing the network
    :param kwargs: stored as self.kwargs and passed to self.init_stimulus_config when the
        configuration has a 'Stimulus' section
    :raises RuntimeError: if config_file is None or the resolved configuration file does not exist
    """
    self.kwargs = kwargs

    self.SWC_Types = {}
    self.SWC_Type_index = {}
    self.Synapse_Types = {}
    self.layers = {}
    self.globals = {}

    self.gidset = set([])
    self.gjlist = []
    self.cells = defaultdict(list)
    self.artificial_cells = defaultdict(dict)
    self.biophys_cells = defaultdict(dict)
    self.spike_onset_delay = {}
    self.recording_sets = {}

    self.pc = None
    if comm is None:
        self.comm = MPI.COMM_WORLD
    else:
        self.comm = comm
    rank = self.comm.Get_rank()

    if configure_nrn:
        # NOTE(review): h and find_template are not used in this method;
        # presumably imported for module-level side effects — confirm.
        from dentate.neuron_utils import h, find_template

    # If true, the biophysical cells and synapses dictionary will be freed
    # as synapses and connections are instantiated.
    self.cleanup = cleanup

    # If true, compute and print memory usage at various points
    # during simulation initialization
    self.profile_memory = profile_memory

    # print verbose diagnostic messages
    self.verbose = verbose
    config_logging(verbose)
    self.logger = get_root_logger()

    # Directories for cell templates
    if template_paths is not None:
        self.template_paths = template_paths.split(':')
    else:
        self.template_paths = []
    self.template_dict = {}

    # The location of required hoc libraries
    self.hoc_lib_path = hoc_lib_path

    # Checkpoint interval in ms of simulation time (clamped to at least 1 ms)
    self.checkpoint_interval = max(float(checkpoint_interval), 1.0)
    self.checkpoint_clear_data = checkpoint_clear_data
    self.last_checkpoint = 0.

    # The location of all datasets
    self.dataset_prefix = dataset_prefix

    # The path where results files should be written
    self.results_path = results_path

    # Identifier used to construct results data namespaces
    self.results_namespace_id = results_namespace_id

    # Identifier used to construct results data files
    self.results_file_id = results_file_id

    # Number of MPI ranks to be used for I/O operations
    self.io_size = int(io_size)

    # Initialization voltage
    self.v_init = float(v_init)

    # simulation time [ms]
    self.tstop = float(tstop)

    # stimulus onset time [ms]
    self.stimulus_onset = float(stimulus_onset)

    # number of trials
    self.n_trials = int(n_trials)

    # maximum wall time in hours
    self.max_walltime_hours = float(max_walltime_hours)

    # time to write out results at end of simulation
    self.results_write_time = float(results_write_time)

    # time step
    self.dt = float(dt)

    # used to estimate cell complexity
    self.cxvec = None

    # measure/perform load balancing
    self.optldbal = ldbal
    self.optlptbal = lptbal

    self.transfer_debug = transfer_debug

    # Save CoreNEURON data
    self.coredat = coredat

    # cache queries to filter_synapses
    self.cache_queries = cache_queries

    self.config_prefix = config_prefix
    if config_file is not None:
        if config_prefix is not None:
            config_file_path = self.config_prefix + '/' + config_file
        else:
            config_file_path = config_file
        if not os.path.isfile(config_file_path):
            raise RuntimeError("configuration file %s was not found" % config_file_path)
        with open(config_file_path) as fp:
            # NOTE(review): yaml.load with a custom loader; only safe on
            # untrusted files if IncludeLoader derives from SafeLoader — confirm.
            self.model_config = yaml.load(fp, IncludeLoader)
    else:
        raise RuntimeError("missing configuration file")

    if 'Definitions' in self.model_config:
        self.parse_definitions()
        # Reverse mapping of SWC_Types (value -> key).
        self.SWC_Type_index = dict([(item[1], item[0]) for item in viewitems(self.SWC_Types)])

    if 'Global Parameters' in self.model_config:
        self.parse_globals()

    self.geometry = None
    if 'Geometry' in self.model_config:
        self.geometry = self.model_config['Geometry']
        if 'Origin' in self.geometry['Parametric Surface']:
            self.parse_origin_coords()

    self.celltypes = self.model_config['Cell Types']
    self.cell_attribute_info = {}

    # The name of this model
    # NOTE(review): modelName is only set when 'Model Name' is present, but the
    # results file path below always reads it (AttributeError otherwise).
    if 'Model Name' in self.model_config:
        self.modelName = self.model_config['Model Name']
    # The dataset to use for constructing the network
    # NOTE(review): likewise, datasetName is read below whenever dataset_prefix is set.
    if 'Dataset Name' in self.model_config:
        self.datasetName = self.model_config['Dataset Name']

    if rank == 0:
        self.logger.info('env.dataset_prefix = %s' % str(self.dataset_prefix))

    # Cell selection for simulations of subsets of the network
    self.cell_selection = None
    self.cell_selection_path = cell_selection_path
    if rank == 0:
        self.logger.info('env.cell_selection_path = %s' % str(self.cell_selection_path))
    if cell_selection_path is not None:
        with open(cell_selection_path) as fp:
            self.cell_selection = yaml.load(fp, IncludeLoader)

    # Spike input path
    self.spike_input_path = spike_input_path
    self.spike_input_ns = spike_input_namespace
    self.spike_input_attr = spike_input_attr
    self.spike_input_attribute_info = None
    if self.spike_input_path is not None:
        if rank == 0:
            self.logger.info('env.spike_input_path = %s' % str(self.spike_input_path))
        # self.Populations is presumably populated by parse_definitions() — verify.
        self.spike_input_attribute_info = \
            read_cell_attribute_info(self.spike_input_path, sorted(self.Populations.keys()), comm=self.comm)
        if rank == 0:
            self.logger.info('env.spike_input_attribute_info = %s' % str(self.spike_input_attribute_info))

    # Results file: [results_path/]<modelName>_results[_<results_file_id>].h5
    if results_path:
        if self.results_file_id is None:
            self.results_file_path = "%s/%s_results.h5" % (
                self.results_path, self.modelName)
        else:
            self.results_file_path = "%s/%s_results_%s.h5" % (
                self.results_path, self.modelName, self.results_file_id)
    else:
        if self.results_file_id is None:
            self.results_file_path = "%s_results.h5" % (self.modelName)
        else:
            self.results_file_path = "%s_results_%s.h5" % (
                self.modelName, self.results_file_id)

    if 'Connection Generator' in self.model_config:
        self.parse_connection_config()
        self.parse_gapjunction_config()

    if self.dataset_prefix is not None:
        self.dataset_path = os.path.join(self.dataset_prefix, self.datasetName)
        if 'Cell Data' in self.model_config:
            self.data_file_path = os.path.join(
                self.dataset_path, self.model_config['Cell Data'])
            # Morphologies are read from the same 'Cell Data' file.
            self.forest_file_path = os.path.join(
                self.dataset_path, self.model_config['Cell Data'])
            self.load_celltypes()
        else:
            self.data_file_path = None
            self.forest_file_path = None
        if rank == 0:
            self.logger.info('env.data_file_path = %s' % self.data_file_path)
        if 'Connection Data' in self.model_config:
            self.connectivity_file_path = os.path.join(
                self.dataset_path, self.model_config['Connection Data'])
        else:
            self.connectivity_file_path = None
        if 'Gap Junction Data' in self.model_config:
            self.gapjunctions_file_path = os.path.join(
                self.dataset_path, self.model_config['Gap Junction Data'])
        else:
            self.gapjunctions_file_path = None
    else:
        self.dataset_path = None
        self.data_file_path = None
        self.connectivity_file_path = None
        self.forest_file_path = None
        self.gapjunctions_file_path = None

    self.node_ranks = None
    if node_rank_file:
        self.load_node_ranks(node_rank_file)

    self.netclamp_config = None
    if 'Network Clamp' in self.model_config:
        self.parse_netclamp_config()

    self.stimulus_config = None
    self.arena_id = None
    self.trajectory_id = None
    if 'Stimulus' in self.model_config:
        self.parse_stimulus_config()
        # Extra constructor keyword arguments are forwarded here.
        self.init_stimulus_config(**kwargs)

    self.analysis_config = None
    if 'Analysis' in self.model_config:
        self.analysis_config = self.model_config['Analysis']

    # Maps each destination population to the list of its source populations.
    self.projection_dict = defaultdict(list)
    if self.dataset_prefix is not None:
        if rank == 0:
            self.logger.info('env.connectivity_file_path = %s' % str(self.connectivity_file_path))
        if self.connectivity_file_path is not None:
            for (src, dst) in read_projection_names(self.connectivity_file_path, comm=self.comm):
                self.projection_dict[dst].append(src)
            if rank == 0:
                self.logger.info('projection_dict = %s' % str(self.projection_dict))

    # Configuration profile for recording intracellular quantities
    # NOTE(review): assert-based validation is stripped under `python -O`.
    assert ((recording_fraction >= 0.0) and (recording_fraction <= 1.0))
    self.recording_fraction = recording_fraction
    self.recording_profile = None
    if ('Recording' in self.model_config) and (recording_profile is not None):
        self.recording_profile = self.model_config['Recording'][
            'Intracellular'][recording_profile]
        self.recording_profile['label'] = recording_profile
        # Translate each synaptic-quantity entry's config keys into
        # get_syn_filter_dict filter names.
        for recvar, recdict in viewitems(
                self.recording_profile.get('synaptic quantity', {})):
            filters = {}
            if 'syn types' in recdict:
                filters['syn_types'] = recdict['syn types']
            if 'swc types' in recdict:
                filters['swc_types'] = recdict['swc types']
            if 'layers' in recdict:
                filters['layers'] = recdict['layers']
            if 'sources' in recdict:
                filters['sources'] = recdict['sources']
            syn_filters = get_syn_filter_dict(self, filters, convert=True)
            recdict['syn_filters'] = syn_filters

    # Configuration profile for recording local field potentials
    # NOTE(review): raises KeyError if 'Recording' exists without an 'LFP' entry.
    self.LFP_config = {}
    if 'Recording' in self.model_config:
        for label, config in viewitems(
                self.model_config['Recording']['LFP']):
            self.LFP_config[label] = {
                'position': tuple(config['position']),
                'maxEDist': config['maxEDist'],
                'fraction': config['fraction'],
                'rho': config['rho'],
                'dt': config['dt']
            }

    self.t_vec = None
    self.id_vec = None
    self.t_rec = None
    self.recs_dict = {}  # Intracellular samples on this host
    for pop_name, _ in viewitems(self.Populations):
        self.recs_dict[pop_name] = defaultdict(list)

    # used to calculate model construction times and run time
    self.mkcellstime = 0
    self.mkstimtime = 0
    self.connectcellstime = 0
    self.connectgjstime = 0

    self.simtime = None
    self.lfp = {}

    self.edge_count = defaultdict(dict)
    self.syns_set = defaultdict(set)
def load_celltypes(self):
    """
    Load cell type metadata from the cell data file into ``self.celltypes``.

    Rank 0 alone opens ``self.data_file_path``, through a communicator split
    off ``self.comm`` that contains only rank 0. It reads the population
    names, the population ranges and the cell attribute info; the results are
    then broadcast to every rank of ``self.comm``.

    For each population found in the file, the matching cell type entry
    (created empty if it is absent from the configuration) receives:

    - ``'start'`` / ``'num'``: the population range read from the file;
    - ``'mech_file_path'`` / ``'mech_dict'``: if a ``'mechanism file'`` is
      configured, its path under ``self.config_prefix`` and its YAML contents
      (read on rank 0, then broadcast);
    - for each synaptic weights entry with an ``'expr'`` key, a ``'closure'``
      built with ``ExprClosure``. A single weights mapping is normalized into
      a one-element list.

    Uses collective operations on ``self.comm``, so every rank must call it.

    :return: None; updates ``self.celltypes`` in place and sets
        ``self.cell_attribute_info``.
    """
    rank = self.comm.Get_rank()
    celltypes = self.celltypes

    # comm0 includes only rank 0 (color 1); the other ranks share color 0.
    color = 1 if rank == 0 else 0
    comm0 = self.comm.Split(color, 0)

    if rank == 0:
        self.logger.info('env.data_file_path = %s' % str(self.data_file_path))

    self.cell_attribute_info = None
    population_ranges = None
    population_names = None
    type_names = None
    try:
        if rank == 0:
            population_names = read_population_names(self.data_file_path, comm0)
            (population_ranges, _) = read_population_ranges(self.data_file_path, comm0)
            type_names = sorted(population_ranges.keys())
            self.cell_attribute_info = read_cell_attribute_info(
                self.data_file_path, population_names, comm=comm0)
            self.logger.info('population_names = %s' % str(population_names))
            self.logger.info('population_ranges = %s' % str(population_ranges))
            self.logger.info('attribute info: %s' % str(self.cell_attribute_info))
        population_ranges = self.comm.bcast(population_ranges, root=0)
        population_names = self.comm.bcast(population_names, root=0)
        type_names = self.comm.bcast(type_names, root=0)
        self.cell_attribute_info = self.comm.bcast(self.cell_attribute_info, root=0)
    finally:
        # Release the split communicator even if reading the data file fails.
        comm0.Free()

    for k in type_names:
        population_range = population_ranges.get(k, None)
        if population_range is None:
            continue
        celltype = celltypes.setdefault(k, {})
        celltype['start'] = population_range[0]
        celltype['num'] = population_range[1]
        if 'mechanism file' in celltype:
            celltype['mech_file_path'] = '%s/%s' % (self.config_prefix, celltype['mechanism file'])
            mech_dict = None
            if rank == 0:
                mech_dict = read_from_yaml(celltype['mech_file_path'])
            # Only rank 0 touches the file system; the others receive a copy.
            celltype['mech_dict'] = self.comm.bcast(mech_dict, root=0)
        if 'synapses' in celltype:
            synapses_dict = celltype['synapses']
            if 'weights' in synapses_dict:
                weights_config = synapses_dict['weights']
                if isinstance(weights_config, list):
                    weights_dicts = weights_config
                else:
                    weights_dicts = [weights_config]
                for weights_dict in weights_dicts:
                    if 'expr' in weights_dict:
                        expr = weights_dict['expr']
                        parameter = weights_dict['parameter']
                        const = weights_dict.get('const', None)
                        weights_dict['closure'] = ExprClosure(parameters=parameter, expr=expr, consts=const)
                # Always store the list form so consumers can iterate uniformly.
                synapses_dict['weights'] = weights_dicts