Exemple #1
0
    def load_celltypes(self):
        """
        Populate ``self.celltypes`` and ``self.cell_attribute_info`` from the data file.

        For every cell type name already present in ``self.celltypes`` that also
        has an entry in the population ranges read from ``self.data_file_path``:

        - ``'start'`` / ``'num'`` are set from the first / second element of
          the population range;
        - if a ``'mechanism file'`` entry is present, its path (relative to
          ``self.config_prefix``) is stored in ``'mech_file_path'`` and the
          parsed YAML in ``'mech_dict'``;
        - if ``'synapses'`` / ``'weights'`` is present, the weights entry is
          normalized to a list, and every weights dict carrying an ``'expr'``
          gets an ``ExprClosure`` stored under ``'closure'``.

        Cell types without a population range are left untouched. Finally
        ``self.cell_attribute_info`` is read for all population names found
        in the data file.

        :return: None; results are stored on ``self``.
        """
        rank = self.comm.Get_rank()
        celltypes = self.celltypes
        typenames = sorted(celltypes.keys())

        if rank == 0:
            self.logger.info('env.data_file_path = %s', self.data_file_path)

        (population_ranges,
         _) = read_population_ranges(self.data_file_path, self.comm)
        if rank == 0:
            self.logger.info('population_ranges = %s', population_ranges)

        for k in typenames:
            population_range = population_ranges.get(k, None)
            if population_range is None:
                # This cell type has no cells in the data file; keep its config as-is.
                continue
            celltype = celltypes[k]
            celltype['start'] = population_range[0]
            celltype['num'] = population_range[1]
            if 'mechanism file' in celltype:
                celltype['mech_file_path'] = '%s/%s' % (
                    self.config_prefix, celltype['mechanism file'])
                celltype['mech_dict'] = read_from_yaml(
                    celltype['mech_file_path'])
            if 'synapses' not in celltype:
                continue
            synapses_dict = celltype['synapses']
            if 'weights' not in synapses_dict:
                continue
            weights_config = synapses_dict['weights']
            # The config may hold a single weights dict or a list of them;
            # downstream code gets a list either way.
            if isinstance(weights_config, list):
                weights_dicts = weights_config
            else:
                weights_dicts = [weights_config]
            for weights_dict in weights_dicts:
                if 'expr' in weights_dict:
                    expr = weights_dict['expr']
                    parameter = weights_dict['parameter']
                    const = weights_dict.get('const', {})
                    weights_dict['closure'] = ExprClosure(parameter, expr,
                                                          const)
            synapses_dict['weights'] = weights_dicts

        population_names = read_population_names(self.data_file_path,
                                                 self.comm)
        if rank == 0:
            self.logger.info('population_names = %s', population_names)
        self.cell_attribute_info = read_cell_attribute_info(
            self.data_file_path, population_names, comm=self.comm)

        if rank == 0:
            self.logger.info('attribute info: %s', self.cell_attribute_info)
Exemple #2
0
def query_cell_attributes(input_file, population_names, namespace_ids=None):
    """
    Query the cell attribute metadata stored in ``input_file``.

    :param input_file: path passed to ``read_cell_attribute_info``.
    :param population_names: populations to query.
    :param namespace_ids: optional explicit list of namespaces. When None,
        the namespace keys of the *last* population returned by
        ``read_cell_attribute_info`` are used.
    :return: ``(namespace_id_lst, attr_info_dict)`` where ``attr_info_dict``
        is the raw result of ``read_cell_attribute_info`` (queried with
        ``read_cell_index=True``). ``namespace_id_lst`` is an empty list when
        no population information was returned, even if ``namespace_ids``
        was given.
    """
    logger.info('Querying cell attribute data...')

    attr_info_dict = read_cell_attribute_info(input_file,
                                              populations=population_names,
                                              read_cell_index=True)

    pop_names = list(attr_info_dict)
    if not pop_names:
        return [], attr_info_dict
    if namespace_ids is None:
        # NOTE(review): only the last population's namespaces are reported;
        # namespaces that exist only in other populations are not merged in.
        return attr_info_dict[pop_names[-1]].keys(), attr_info_dict
    return namespace_ids, attr_info_dict
Exemple #3
0
def _sample_gids(cell_set, max_units):
    """Return a set of ``int(max_units)`` distinct ids drawn uniformly at random from ``cell_set``."""
    cell_lst = list(cell_set)
    # replace=False guarantees distinct ids; every element (including the
    # last one) can be drawn.
    picks = np.random.choice(len(cell_lst), size=int(max_units), replace=False)
    return set(cell_lst[i] for i in picks)


def _split_trials(tvals, svals, n_trials):
    """
    Split a concatenated multi-trial recording into per-trial arrays.

    A new trial is taken to start at every sample whose time is close
    (``atol=1e-4``) to the time of the first sample.

    :param tvals: 1-D array of sample times.
    :param svals: 1-D array of state values aligned with ``tvals``.
    :param n_trials: maximum number of trials to keep; -1 keeps all of them.
    :return: ``(t_trials, s_trials)``: lists holding the first
        ``max(1, min(n_found, n_trials))`` trials (all trials when
        ``n_trials == -1``). An empty recording yields a single empty trial.
    """
    if len(tvals) == 0:
        return [tvals], [svals]
    trial_starts = list(np.where(np.isclose(tvals, tvals[0], atol=1e-4))[0])
    n_found = len(trial_starts)
    if n_trials == -1:
        n_keep = n_found
    else:
        n_keep = min(n_found, n_trials)
    # At least the first trial is always returned.
    n_keep = max(n_keep, 1)
    bounds = trial_starts + [len(tvals)]
    t_trials = [tvals[bounds[i]:bounds[i + 1]] for i in range(n_keep)]
    s_trials = [svals[bounds[i]:bounds[i + 1]] for i in range(n_keep)]
    return t_trials, s_trials


def read_state(input_file,
               population_names,
               namespace_id,
               time_variable='t',
               state_variable='v',
               time_range=None,
               max_units=None,
               gid=None,
               comm=None,
               n_trials=-1):
    """
    Read recorded state traces (e.g. voltage) for the given populations.

    :param input_file: path of the file containing the recordings.
    :param population_names: populations to read.
    :param namespace_id: attribute namespace holding the recordings.
    :param time_variable: attribute name of the sample times.
    :param state_variable: attribute name of the recorded state.
    :param time_range: optional ``(t_start, t_end)``; only samples with
        ``t_start <= t <= t_end`` are kept.
    :param max_units: optional cap on the number of cells per population;
        when exceeded (and ``gid`` is None) only a random subset of
        ``max_units`` distinct cells is read.
    :param gid: optional iterable of cell ids to read instead of all cells.
    :param comm: MPI communicator; defaults to ``MPI.COMM_WORLD``.
    :param n_trials: maximum number of trials per cell; -1 means all.
    :return: dict with ``'states'`` (``{pop_name: {cell_id: (t_trials,
        s_trials, distance, section, loc)}}``), ``'time_variable'`` and
        ``'state_variable'``.
    :raises RuntimeError: if ``state_variable`` has no recordings in a population.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    pop_state_dict = {}

    logger.info(
        'Reading state data from populations %s, namespace %s gid = %s...' %
        (str(population_names), namespace_id, str(gid)))

    attr_info_dict = read_cell_attribute_info(input_file,
                                              populations=population_names,
                                              read_cell_index=True)

    for pop_name in population_names:
        cell_index = None
        for attr_name, attr_cell_index in attr_info_dict[pop_name][
                namespace_id]:
            if state_variable == attr_name:
                cell_index = attr_cell_index

        if cell_index is None:
            raise RuntimeError(
                'read_state: Unable to find recordings for state variable %s in population %s namespace %s'
                % (state_variable, pop_name, str(namespace_id)))
        cell_set = set(cell_index)

        # Which cells to read: an explicit gid list, a random sample of
        # max_units cells, or (selection is None) every recorded cell.
        selection = None
        if gid is not None:
            selection = set(gid)
        elif (max_units is not None) and (len(cell_set) > max_units):
            logger.info(
                '  Reading only randomly sampled %i out of %i units for population %s'
                % (max_units, len(cell_set), pop_name))
            # NOTE(review): each rank draws its own sample unless the numpy
            # RNG is seeded identically on all ranks.
            selection = _sample_gids(cell_set, max_units)

        if selection is None:
            valiter = read_cell_attributes(input_file,
                                           pop_name,
                                           namespace=namespace_id,
                                           comm=comm)
        else:
            valiter = read_cell_attribute_selection(input_file,
                                                    pop_name,
                                                    namespace=namespace_id,
                                                    selection=list(selection),
                                                    comm=comm)

        state_dict = {}
        for cellind, vals in valiter:
            if cellind is None:
                continue
            distance = vals.get('distance', [None])[0]
            section = vals.get('section', [None])[0]
            loc = vals.get('loc', [None])[0]
            if time_range is None:
                tvals = np.asarray(vals[time_variable], dtype=np.float32)
                svals = np.asarray(vals[state_variable], dtype=np.float32)
            else:
                all_t = np.asarray(vals[time_variable])
                tinds = np.flatnonzero((all_t <= time_range[1]) &
                                       (all_t >= time_range[0]))
                tvals = np.asarray(all_t[tinds], dtype=np.float32)
                svals = np.asarray(np.asarray(vals[state_variable])[tinds],
                                   dtype=np.float32)
            t_trials, s_trials = _split_trials(tvals, svals, n_trials)
            state_dict[cellind] = (t_trials, s_trials, distance, section, loc)

        pop_state_dict[pop_name] = state_dict

    return {
        'states': pop_state_dict,
        'time_variable': time_variable,
        'state_variable': state_variable
    }
def main(config_file, population, dt, gid, gid_selection_file, arena_id, trajectory_id, generate_weights,
         t_max, t_min,  nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, surrogate_method,
         population_size, num_generations, resample_fraction, mutation_rate,
         template_paths, dataset_prefix, config_prefix,
         param_config_name, selectivity_config_name, param_type, recording_profile, results_file, results_path, spike_events_path,
         spike_events_namespace, spike_events_t, input_features_path, input_features_namespaces, n_trials,
         trial_regime, problem_regime, target_features_path, target_features_namespace, target_state_variable,
         target_state_filter, use_coreneuron, cooperative_init, spawn_startup_wait):
    """
    Optimize the input stimulus selectivity of the specified cell in a network clamp configuration.

    The cells to optimize are taken from ``gid_selection_file`` (one integer
    gid per line), else from ``gid``, else from the cell index of the first
    attribute in the 'Trees' namespace of ``population`` in the data file.
    Optimization is delegated to ``optimize_run`` with the objective
    initializer ``'init_selectivity_objfun'``. If ``optimize_run`` returns a
    result and ``results_path`` is set, the best parameters are written to
    ``<results_path>/optimize_selectivity.<YYYYmmdd_HHMMSS>.yaml``.

    NOTE(review): this function forwards snapshots of ``locals()`` as keyword
    arguments (``init_params`` at entry, ``params`` before creating the main
    ``Env``), so renaming, adding or reordering local variables changes what
    is passed on.

    :raises RuntimeError: if ``population`` has no entry under ``param_type``
        in the network clamp ``optimize_parameters`` configuration.
    """
    # Snapshot of the arguments (the only locals at this point); later
    # forwarded to Env(...) and to optimize_run(init_params=...).
    init_params = dict(locals())

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # Generate the results file id once on rank 0 and share it so every
    # rank writes under the same id.
    results_file_id = None
    if rank == 0:
        results_file_id = generate_results_file_id(population, gid)
        
    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()
    
    np.seterr(all='raise')
    verbose = True
    cache_queries = True

    config_logging(verbose)

    cell_index_set = set([])
    if gid_selection_file is not None:
        with open(gid_selection_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # NOTE(review): rebinds the `gid` parameter; a blank line in
                # the file makes int() raise ValueError.
                gid = int(line)
                cell_index_set.add(gid)
    elif gid is not None:
        cell_index_set.add(gid)
    else:
        comm.barrier()
        # Rank 0 (color 2) gets a communicator of its own so it alone can
        # read the cell index; all other ranks share color 1.
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            env = Env(**init_params, comm=comm0)
            attr_info_dict = read_cell_attribute_info(env.data_file_path, populations=[population],
                                                      read_cell_index=True, comm=comm0)
            cell_index = None
            attr_name, attr_cell_index = next(iter(attr_info_dict[population]['Trees']))
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()
    # Downstream code receives the resolved set of gids instead of `gid`.
    init_params['cell_index_set'] = cell_index_set
    del(init_params['gid'])

    # NOTE(review): captures every local defined so far (comm, size, rank,
    # results_file_id, init_params, cell_index_set, and on some paths gid,
    # f, lines, line, comm0, env, attr_info_dict, ...) and passes them all to
    # Env; names Env does not declare presumably land in its **kwargs -- confirm.
    params = dict(locals())
    env = Env(**params)
    # With a single MPI process the network clamp is initialized here, in
    # this process.
    if size == 1:
        configure_hoc_env(env)
        init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
             spike_events_path, spike_events_namespace=spike_events_namespace, 
             spike_train_attr_name=spike_events_t,
             input_features_path=input_features_path,
             input_features_namespaces=input_features_namespaces,
             generate_weights_pops=set(generate_weights), 
             t_min=t_min, t_max=t_max)
        
    if (population in env.netclamp_config.optimize_parameters[param_type]):
        opt_params = env.netclamp_config.optimize_parameters[param_type][population]
    else:
        raise RuntimeError(f'optimize_selectivity: population {population} does not have optimization configuration')

    # Default target state variable is 'v'.
    if target_state_variable is None:
        target_state_variable = 'v'
    
    # Extra settings consumed by the objective initializer named below.
    init_params['target_features_arena'] = arena_id
    init_params['target_features_trajectory'] = trajectory_id
    opt_state_baseline = opt_params['Targets']['state'][target_state_variable]['baseline']
    init_params['state_baseline'] = opt_state_baseline
    init_params['state_variable'] = target_state_variable
    init_params['state_filter'] = target_state_filter
    init_objfun_name = 'init_selectivity_objfun'
        
    best = optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun_name, problem_regime=problem_regime,
                        n_epochs=n_epochs, n_initial=n_initial, initial_maxiter=initial_maxiter, initial_method=initial_method, 
                        optimizer_method=optimizer_method, surrogate_method=surrogate_method, population_size=population_size, 
                        num_generations=num_generations, resample_fraction=resample_fraction, mutation_rate=mutation_rate, 
                        param_type=param_type, init_params=init_params, results_file=results_file, nprocs_per_worker=nprocs_per_worker, 
                        cooperative_init=cooperative_init, spawn_startup_wait=spawn_startup_wait, verbose=verbose)
    
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    # `best` is None on ranks that did not receive a result; only ranks
    # with a result write the YAML summary.
    if best is not None:
        if results_path is not None:
            run_ts = time.strftime("%Y%m%d_%H%M%S")
            file_path = f'{results_path}/optimize_selectivity.{run_ts}.yaml'
            param_names = opt_param_config.param_names
            param_tuples = opt_param_config.param_tuples

            if ProblemRegime[problem_regime] == ProblemRegime.every:
                # One result set per gid: {gid: {i: [param tuples]}}.
                results_config_dict = {}
                for gid, prms in viewitems(best):
                    # Number of results is taken from the first parameter's array.
                    n_res = prms[0][1].shape[0]
                    prms_dict = dict(prms)
                    this_results_config_dict = {}
                    for i in range(n_res):
                        results_param_list = []
                        for param_pattern, param_tuple in zip(param_names, param_tuples):
                            results_param_list.append((param_tuple.population,
                                                       param_tuple.source,
                                                       param_tuple.sec_type,
                                                       param_tuple.syn_name,
                                                       param_tuple.param_path,
                                                       float(prms_dict[param_pattern][i])))
                        this_results_config_dict[i] = results_param_list
                    results_config_dict[gid] = this_results_config_dict
                    
            else:
                # Single shared result set: {i: [param tuples]}.
                prms = best[0]
                n_res = prms[0][1].shape[0]
                prms_dict = dict(prms)
                results_config_dict = {}
                for i in range(n_res):
                    results_param_list = []
                    for param_pattern, param_tuple in zip(param_names, param_tuples):
                        results_param_list.append((param_tuple.population,
                                                   param_tuple.source,
                                                   param_tuple.sec_type,
                                                   param_tuple.syn_name,
                                                   param_tuple.param_path,
                                                   float(prms_dict[param_pattern][i])))
                    results_config_dict[i] = results_param_list

            write_to_yaml(file_path, { population: results_config_dict } )

            
    comm.barrier()
Exemple #5
0
    def __init__(self,
                 comm=None,
                 config_file=None,
                 template_paths="templates",
                 hoc_lib_path=None,
                 configure_nrn=True,
                 dataset_prefix=None,
                 config_prefix=None,
                 results_path=None,
                 results_file_id=None,
                 results_namespace_id=None,
                 node_rank_file=None,
                 io_size=0,
                 recording_profile=None,
                 recording_fraction=1.0,
                 coredat=False,
                 tstop=0.,
                 v_init=-65,
                 stimulus_onset=0.0,
                 n_trials=1,
                 max_walltime_hours=0.5,
                 checkpoint_interval=500.0,
                 checkpoint_clear_data=True,
                 results_write_time=0,
                 dt=0.025,
                 ldbal=False,
                 lptbal=False,
                 transfer_debug=False,
                 cell_selection_path=None,
                 spike_input_path=None,
                 spike_input_namespace=None,
                 spike_input_attr=None,
                 cleanup=True,
                 cache_queries=False,
                 profile_memory=False,
                 verbose=False,
                 **kwargs):
        """
        Build the simulation environment from a model configuration file.

        :param comm: MPI communicator; defaults to ``MPI.COMM_WORLD``
        :param config_file: str; model configuration file name (joined to ``config_prefix`` when that is given)
        :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
        :param hoc_lib_path: str; path to directory containing required hoc libraries
        :param configure_nrn: bool; import ``h`` and ``find_template`` from ``dentate.neuron_utils``
        :param dataset_prefix: str; path to directory containing required neuroh5 data files
        :param config_prefix: str; path to directory containing network and cell mechanism config files
        :param results_path: str; path to directory to export output files
        :param results_file_id: str; label for neuroh5 files to write spike and voltage trace data
        :param results_namespace_id: str; label for neuroh5 namespaces to write spike and voltage trace data
        :param node_rank_file: str; name of file specifying assignment of node gids to MPI ranks
        :param io_size: int; the number of MPI ranks to be used for I/O operations
        :param recording_profile: str; intracellular recording configuration to use
        :param recording_fraction: float in [0, 1]; stored as ``self.recording_fraction``
        :param coredat: bool; Save CoreNEURON data
        :param tstop: float; physical time to simulate (ms)
        :param v_init: float; initialization membrane potential (mV)
        :param stimulus_onset: float; starting time of stimulus (ms)
        :param n_trials: int; number of trials
        :param max_walltime_hours: float; maximum wall time (hours)
        :param checkpoint_interval: float; checkpoint interval in ms of simulation time (clamped to >= 1.0)
        :param checkpoint_clear_data: bool; stored as ``self.checkpoint_clear_data``
        :param results_write_time: float; time to write out results at end of simulation
        :param dt: float; simulation time step
        :param ldbal: bool; estimate load balance based on cell complexity
        :param lptbal: bool; calculate load balance with LPT algorithm
        :param transfer_debug: bool; stored as ``self.transfer_debug``
        :param cell_selection_path: str; YAML file selecting a subset of the network's cells
        :param spike_input_path: str; file whose cell attribute info is read as spike input
        :param spike_input_namespace: str; stored as ``self.spike_input_ns``
        :param spike_input_attr: str; stored as ``self.spike_input_attr``
        :param cleanup: bool; clean up auxiliary cell and synapse structures after network init
        :param cache_queries: bool; whether to use a cache to speed up queries to filter_synapses
        :param profile_memory: bool; profile memory usage
        :param verbose: bool; print verbose diagnostic messages while constructing the network
        :param kwargs: extra keyword arguments; stored as ``self.kwargs`` and passed to
            ``init_stimulus_config`` when the configuration has a 'Stimulus' section
        :raises RuntimeError: if ``config_file`` is None or the file does not exist
        :raises ValueError: if ``recording_fraction`` is outside [0, 1]
        """
        self.kwargs = kwargs

        self.SWC_Types = {}
        self.SWC_Type_index = {}
        self.Synapse_Types = {}
        self.layers = {}
        self.globals = {}

        self.gidset = set([])
        self.gjlist = []
        self.cells = defaultdict(list)
        self.artificial_cells = defaultdict(dict)
        self.biophys_cells = defaultdict(dict)
        self.spike_onset_delay = {}
        self.recording_sets = {}

        self.pc = None
        if comm is None:
            self.comm = MPI.COMM_WORLD
        else:
            self.comm = comm
        rank = self.comm.Get_rank()

        if configure_nrn:
            from dentate.neuron_utils import h, find_template

        # If true, the biophysical cells and synapses dictionary will be freed
        # as synapses and connections are instantiated.
        self.cleanup = cleanup

        # If true, compute and print memory usage at various points
        # during simulation initialization
        self.profile_memory = profile_memory

        # print verbose diagnostic messages
        self.verbose = verbose
        config_logging(verbose)
        self.logger = get_root_logger()

        # Directories for cell templates
        if template_paths is not None:
            self.template_paths = template_paths.split(':')
        else:
            self.template_paths = []
        self.template_dict = {}

        # The location of required hoc libraries
        self.hoc_lib_path = hoc_lib_path

        # Checkpoint interval in ms of simulation time
        self.checkpoint_interval = max(float(checkpoint_interval), 1.0)
        self.checkpoint_clear_data = checkpoint_clear_data
        self.last_checkpoint = 0.

        # The location of all datasets
        self.dataset_prefix = dataset_prefix

        # The path where results files should be written
        self.results_path = results_path

        # Identifier used to construct results data namespaces
        self.results_namespace_id = results_namespace_id
        # Identifier used to construct results data files
        self.results_file_id = results_file_id

        # Number of MPI ranks to be used for I/O operations
        self.io_size = int(io_size)

        # Initialization voltage
        self.v_init = float(v_init)

        # simulation time [ms]
        self.tstop = float(tstop)

        # stimulus onset time [ms]
        self.stimulus_onset = float(stimulus_onset)

        # number of trials
        self.n_trials = int(n_trials)

        # maximum wall time in hours
        self.max_walltime_hours = float(max_walltime_hours)

        # time to write out results at end of simulation
        self.results_write_time = float(results_write_time)

        # time step
        self.dt = float(dt)

        # used to estimate cell complexity
        self.cxvec = None

        # measure/perform load balancing
        self.optldbal = ldbal
        self.optlptbal = lptbal

        self.transfer_debug = transfer_debug

        # Save CoreNEURON data
        self.coredat = coredat

        # cache queries to filter_synapses
        self.cache_queries = cache_queries

        self.config_prefix = config_prefix
        if config_file is not None:
            if config_prefix is not None:
                config_file_path = self.config_prefix + '/' + config_file
            else:
                config_file_path = config_file
            if not os.path.isfile(config_file_path):
                raise RuntimeError("configuration file %s was not found" %
                                   config_file_path)
            with open(config_file_path) as fp:
                self.model_config = yaml.load(fp, IncludeLoader)
        else:
            raise RuntimeError("missing configuration file")

        if 'Definitions' in self.model_config:
            self.parse_definitions()
            self.SWC_Type_index = dict([(item[1], item[0])
                                        for item in viewitems(self.SWC_Types)])

        if 'Global Parameters' in self.model_config:
            self.parse_globals()

        self.geometry = None
        if 'Geometry' in self.model_config:
            self.geometry = self.model_config['Geometry']

        # Only look for origin coordinates when a geometry section exists;
        # previously a config without 'Geometry' crashed here on None.
        if (self.geometry is not None) and \
           ('Origin' in self.geometry.get('Parametric Surface', {})):
            self.parse_origin_coords()

        self.celltypes = self.model_config['Cell Types']
        self.cell_attribute_info = {}

        # The name of this model
        if 'Model Name' in self.model_config:
            self.modelName = self.model_config['Model Name']
        # The dataset to use for constructing the network
        if 'Dataset Name' in self.model_config:
            self.datasetName = self.model_config['Dataset Name']

        if rank == 0:
            self.logger.info('env.dataset_prefix = %s' %
                             str(self.dataset_prefix))

        # Cell selection for simulations of subsets of the network
        self.cell_selection = None
        self.cell_selection_path = cell_selection_path
        if rank == 0:
            self.logger.info('env.cell_selection_path = %s' %
                             str(self.cell_selection_path))
        if cell_selection_path is not None:
            with open(cell_selection_path) as fp:
                self.cell_selection = yaml.load(fp, IncludeLoader)

        # Spike input path
        self.spike_input_path = spike_input_path
        self.spike_input_ns = spike_input_namespace
        self.spike_input_attr = spike_input_attr
        self.spike_input_attribute_info = None
        if self.spike_input_path is not None:
            if rank == 0:
                self.logger.info('env.spike_input_path = %s' %
                                 str(self.spike_input_path))
            self.spike_input_attribute_info = \
              read_cell_attribute_info(self.spike_input_path, sorted(self.Populations.keys()), comm=self.comm)
            if rank == 0:
                self.logger.info('env.spike_input_attribute_info = %s' %
                                 str(self.spike_input_attribute_info))
        if results_path:
            if self.results_file_id is None:
                self.results_file_path = "%s/%s_results.h5" % (
                    self.results_path, self.modelName)
            else:
                self.results_file_path = "%s/%s_results_%s.h5" % (
                    self.results_path, self.modelName, self.results_file_id)
        else:
            if self.results_file_id is None:
                self.results_file_path = "%s_results.h5" % (self.modelName)
            else:
                self.results_file_path = "%s_results_%s.h5" % (
                    self.modelName, self.results_file_id)

        if 'Connection Generator' in self.model_config:
            self.parse_connection_config()
            self.parse_gapjunction_config()

        if self.dataset_prefix is not None:
            self.dataset_path = os.path.join(self.dataset_prefix,
                                             self.datasetName)
            if 'Cell Data' in self.model_config:
                self.data_file_path = os.path.join(
                    self.dataset_path, self.model_config['Cell Data'])
                self.forest_file_path = os.path.join(
                    self.dataset_path, self.model_config['Cell Data'])
                self.load_celltypes()
            else:
                self.data_file_path = None
                self.forest_file_path = None
            if rank == 0:
                self.logger.info('env.data_file_path = %s' %
                                 self.data_file_path)
            if 'Connection Data' in self.model_config:
                self.connectivity_file_path = os.path.join(
                    self.dataset_path, self.model_config['Connection Data'])
            else:
                self.connectivity_file_path = None
            if 'Gap Junction Data' in self.model_config:
                self.gapjunctions_file_path = os.path.join(
                    self.dataset_path, self.model_config['Gap Junction Data'])
            else:
                self.gapjunctions_file_path = None
        else:
            self.dataset_path = None
            self.data_file_path = None
            self.connectivity_file_path = None
            self.forest_file_path = None
            self.gapjunctions_file_path = None

        self.node_ranks = None
        if node_rank_file:
            self.load_node_ranks(node_rank_file)

        self.netclamp_config = None
        if 'Network Clamp' in self.model_config:
            self.parse_netclamp_config()

        self.stimulus_config = None
        self.arena_id = None
        self.trajectory_id = None
        if 'Stimulus' in self.model_config:
            self.parse_stimulus_config()
            self.init_stimulus_config(**kwargs)

        self.analysis_config = None
        if 'Analysis' in self.model_config:
            self.analysis_config = self.model_config['Analysis']

        self.projection_dict = defaultdict(list)
        if self.dataset_prefix is not None:
            if rank == 0:
                self.logger.info('env.connectivity_file_path = %s' %
                                 str(self.connectivity_file_path))
            if self.connectivity_file_path is not None:
                for (src,
                     dst) in read_projection_names(self.connectivity_file_path,
                                                   comm=self.comm):
                    self.projection_dict[dst].append(src)
                if rank == 0:
                    self.logger.info('projection_dict = %s' %
                                     str(self.projection_dict))

        # Configuration profile for recording intracellular quantities.
        # Validated with an explicit raise (an assert is stripped under -O).
        if not (0.0 <= recording_fraction <= 1.0):
            raise ValueError('recording_fraction must be between 0 and 1, got %s' %
                             str(recording_fraction))
        self.recording_fraction = recording_fraction
        self.recording_profile = None
        if ('Recording' in self.model_config) and (recording_profile
                                                   is not None):
            self.recording_profile = self.model_config['Recording'][
                'Intracellular'][recording_profile]
            self.recording_profile['label'] = recording_profile
            for recvar, recdict in viewitems(
                    self.recording_profile.get('synaptic quantity', {})):
                filters = {}
                if 'syn types' in recdict:
                    filters['syn_types'] = recdict['syn types']
                if 'swc types' in recdict:
                    filters['swc_types'] = recdict['swc types']
                if 'layers' in recdict:
                    filters['layers'] = recdict['layers']
                if 'sources' in recdict:
                    filters['sources'] = recdict['sources']
                syn_filters = get_syn_filter_dict(self, filters, convert=True)
                recdict['syn_filters'] = syn_filters

        # Configuration profile for recording local field potentials; a
        # 'Recording' section without an 'LFP' entry yields no LFP configs.
        self.LFP_config = {}
        if 'Recording' in self.model_config:
            for label, config in viewitems(
                    self.model_config['Recording'].get('LFP', {})):
                self.LFP_config[label] = {
                    'position': tuple(config['position']),
                    'maxEDist': config['maxEDist'],
                    'fraction': config['fraction'],
                    'rho': config['rho'],
                    'dt': config['dt']
                }

        self.t_vec = None
        self.id_vec = None
        self.t_rec = None
        self.recs_dict = {}  # Intracellular samples on this host
        for pop_name, _ in viewitems(self.Populations):
            self.recs_dict[pop_name] = defaultdict(list)

        # used to calculate model construction times and run time
        self.mkcellstime = 0
        self.mkstimtime = 0
        self.connectcellstime = 0
        self.connectgjstime = 0

        self.simtime = None
        self.lfp = {}

        self.edge_count = defaultdict(dict)
        self.syns_set = defaultdict(set)
Exemple #6
0
    def load_celltypes(self):
        """
        Populate ``self.celltypes`` and ``self.cell_attribute_info`` from the data file.

        Rank 0 reads the population names, population ranges and cell
        attribute info of ``self.data_file_path`` on a communicator that
        contains only itself. It then broadcasts the ranges and attribute
        info over ``self.comm``.

        Each population in the ranges gets a cell type entry (an empty one is
        created if the configuration has none), which is filled in as follows:

        - ``'start'`` / ``'num'`` are set from the first / second element of
          the population range;
        - if a ``'mechanism file'`` entry is present, its path (relative to
          ``self.config_prefix``) is stored in ``'mech_file_path'``. Rank 0
          parses the YAML and broadcasts the result into ``'mech_dict'``;
        - if ``'synapses'`` / ``'weights'`` is present, the weights entry is
          normalized to a list. Every weights dict carrying an ``'expr'`` gets
          an ``ExprClosure`` stored under ``'closure'``.

        Uses MPI collectives (``Split``, ``bcast``) on ``self.comm``, so every
        rank of that communicator must call this method.

        :return: None; results are stored on ``self``.
        """
        rank = self.comm.Get_rank()
        celltypes = self.celltypes

        # Rank 0 (color 1) gets a communicator of its own for the file reads;
        # the other ranks (color 0) share one that they only free.
        comm0 = self.comm.Split(1 if rank == 0 else 0, 0)

        if rank == 0:
            self.logger.info('env.data_file_path = %s', self.data_file_path)

        self.cell_attribute_info = None
        population_ranges = None
        type_names = None
        if rank == 0:
            population_names = read_population_names(self.data_file_path,
                                                     comm0)
            (population_ranges,
             _) = read_population_ranges(self.data_file_path, comm0)
            type_names = sorted(population_ranges.keys())
            self.cell_attribute_info = read_cell_attribute_info(
                self.data_file_path, population_names, comm=comm0)
            self.logger.info('population_names = %s', population_names)
            self.logger.info('population_ranges = %s', population_ranges)
            self.logger.info('attribute info: %s', self.cell_attribute_info)
        population_ranges = self.comm.bcast(population_ranges, root=0)
        type_names = self.comm.bcast(type_names, root=0)
        self.cell_attribute_info = self.comm.bcast(self.cell_attribute_info,
                                                   root=0)
        comm0.Free()

        for k in type_names:
            population_range = population_ranges.get(k, None)
            if population_range is None:
                continue
            celltype = celltypes.setdefault(k, {})
            celltype['start'] = population_range[0]
            celltype['num'] = population_range[1]
            if 'mechanism file' in celltype:
                celltype['mech_file_path'] = '%s/%s' % (
                    self.config_prefix, celltype['mechanism file'])
                # Only rank 0 reads the file; the parsed dict is broadcast.
                mech_dict = None
                if rank == 0:
                    mech_dict = read_from_yaml(celltype['mech_file_path'])
                celltype['mech_dict'] = self.comm.bcast(mech_dict, root=0)
            if 'synapses' not in celltype:
                continue
            synapses_dict = celltype['synapses']
            if 'weights' not in synapses_dict:
                continue
            weights_config = synapses_dict['weights']
            # The config may hold a single weights dict or a list of them.
            if isinstance(weights_config, list):
                weights_dicts = weights_config
            else:
                weights_dicts = [weights_config]
            for weights_dict in weights_dicts:
                if 'expr' in weights_dict:
                    expr = weights_dict['expr']
                    parameter = weights_dict['parameter']
                    const = weights_dict.get('const', None)
                    weights_dict['closure'] = ExprClosure(parameters=parameter,
                                                          expr=expr,
                                                          consts=const)
            synapses_dict['weights'] = weights_dicts