Esempio n. 1
0
    def load_celltypes(self):
        """
        Populate ``self.celltypes`` with data read from ``self.data_file_path``.

        For each configured cell type that has an entry in the population
        ranges of the data file, this sets ``'start'`` and ``'num'``. It also
        loads the YAML file named by ``'mechanism file'`` (joined onto
        ``self.config_prefix``) into ``'mech_dict'``. Every synaptic weight
        rule that has an ``'expr'`` is wrapped in an ``ExprClosure`` stored
        under ``'closure'``, and the ``'weights'`` entry is normalized to a
        list. Configured cell types absent from the data file are left
        untouched. Afterwards the population names are read and
        ``self.cell_attribute_info`` is assigned.

        All reads go through ``self.comm`` on every rank; only rank 0 logs.

        :return: None; ``self.celltypes`` is updated in place and
            ``self.cell_attribute_info`` is assigned.
        """
        rank = self.comm.Get_rank()
        celltypes = self.celltypes
        typenames = sorted(celltypes.keys())

        if rank == 0:
            self.logger.info('env.data_file_path = %s' %
                             str(self.data_file_path))

        (population_ranges,
         _) = read_population_ranges(self.data_file_path, self.comm)
        if rank == 0:
            self.logger.info('population_ranges = %s' % str(population_ranges))

        for k in typenames:
            population_range = population_ranges.get(k, None)
            if population_range is None:
                continue
            celltype = celltypes[k]
            # The range is indexed as (start, count).
            celltype['start'] = population_range[0]
            celltype['num'] = population_range[1]

            if 'mechanism file' in celltype:
                celltype['mech_file_path'] = '%s/%s' % (
                    self.config_prefix, celltype['mechanism file'])
                celltype['mech_dict'] = read_from_yaml(celltype['mech_file_path'])

            if 'synapses' in celltype:
                synapses_dict = celltype['synapses']
                if 'weights' in synapses_dict:
                    weights_config = synapses_dict['weights']
                    # A single rule may be given as a bare mapping; normalize
                    # so that the stored 'weights' entry is always a list.
                    if isinstance(weights_config, list):
                        weights_dicts = weights_config
                    else:
                        weights_dicts = [weights_config]
                    for weights_dict in weights_dicts:
                        if 'expr' in weights_dict:
                            expr = weights_dict['expr']
                            parameter = weights_dict['parameter']
                            const = weights_dict.get('const', {})
                            weights_dict['closure'] = ExprClosure(parameter, expr, const)
                    synapses_dict['weights'] = weights_dicts

        population_names = read_population_names(self.data_file_path,
                                                 self.comm)
        if rank == 0:
            self.logger.info('population_names = %s' % str(population_names))
        self.cell_attribute_info = read_cell_attribute_info(
            self.data_file_path, population_names, comm=self.comm)

        if rank == 0:
            self.logger.info('attribute info: %s' %
                             str(self.cell_attribute_info))
Esempio n. 2
0
def main(config_path, params_id, output_file_name, verbose):
    """Write one network parameter set as a nested YAML mapping.

    Reads the evaluation config at ``config_path`` (keys ``param_spec``,
    ``param_values`` and ``target_populations``). Each value of
    ``param_values[params_id]`` is paired, by position, with the matching
    parameter tuple of the spec. The values are then nested as
    ``population -> source -> sec_type -> syn_name -> param_path``. A tuple
    ``param_path`` contributes one nesting level per element, with its last
    element holding the value. The resulting dict is pretty-printed and
    written to ``output_file_name`` with ``write_to_yaml``.
    """

    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    eval_config = read_from_yaml(config_path,
                                 include_loader=utils.IncludeLoader)
    network_param_spec_src = eval_config['param_spec']
    network_param_values = eval_config['param_values']
    target_populations = eval_config['target_populations']

    network_param_spec = make_param_spec(target_populations,
                                         network_param_spec_src)

    def from_param_list(x):
        # Pair each parameter tuple with the value at the same position.
        # Values are not validated against param_tuple.param_range (the
        # corresponding assertion was disabled in the original code).
        return [(param_tuple, x[i])
                for i, (_, param_tuple) in enumerate(
                    zip(network_param_spec.param_names,
                        network_param_spec.param_tuples))]

    x = network_param_values[params_id]
    param_tuple_values = from_param_list(x)

    def rec_dd():
        # Auto-vivifying nested dict: missing keys create further levels.
        return defaultdict(rec_dd)

    def dd2dict(d):
        # Recursively convert nested (default)dicts into plain dicts before
        # printing/writing.
        return {k: dd2dict(v) if isinstance(v, dict) else v
                for k, v in d.items()}

    param_output_ddict = rec_dd()

    for param_tuple, param_value in param_tuple_values:
        node = param_output_ddict[param_tuple.population][param_tuple.source][
            param_tuple.sec_type][param_tuple.syn_name]
        param_path = param_tuple.param_path
        if isinstance(param_path, tuple):
            # Every element but the last adds a nesting level; the last
            # element is the key that receives the value.
            *parents, leaf = param_path
            for key in parents:
                node = node[key]
        else:
            leaf = param_path
        node[leaf] = param_value

    param_output_dict = dd2dict(param_output_ddict)
    pprint.pprint(param_output_dict)
    write_to_yaml(output_file_name, param_output_dict)
Esempio n. 3
0
def main(config_path, params_id, output_file_name, verbose):
    """Print, and optionally save, network parameter sets as flat records.

    The evaluation config at ``config_path`` supplies ``param_spec``,
    ``param_values`` and ``target_populations``. The selected ids are
    ``[params_id]``, or every key of ``param_values`` when ``params_id`` is
    None. For each selected id, every value is paired by position with the
    spec's parameter tuple. The pair is then flattened to
    ``(population, source, sec_type, syn_name, param_path, value)``.
    The mapping from id to list of records is pretty-printed and, if
    ``output_file_name`` is given, written with ``write_to_yaml``.
    """

    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    eval_config = read_from_yaml(config_path,
                                 include_loader=utils.IncludeLoader)
    spec_src = eval_config['param_spec']
    values_by_id = eval_config['param_values']
    populations = eval_config['target_populations']

    param_spec = make_param_spec(populations, spec_src)

    def from_param_list(values):
        pairs = []
        named_tuples = zip(param_spec.param_names, param_spec.param_tuples)
        for idx, (_name, ptuple) in enumerate(named_tuples):
            # Range is read but not checked (the bound assertion is disabled).
            _range = ptuple.param_range
            pairs.append((ptuple, values[idx]))
        return pairs

    if params_id is not None:
        selected_ids = [params_id]
    else:
        selected_ids = list(values_by_id.keys())

    param_output_dict = {}
    for pid in selected_ids:
        param_output_dict[pid] = [
            (pt.population, pt.source, pt.sec_type, pt.syn_name,
             pt.param_path, value)
            for pt, value in from_param_list(values_by_id[pid])
        ]

    pprint.pprint(param_output_dict)
    if output_file_name is not None:
        write_to_yaml(output_file_name, param_output_dict)
Esempio n. 4
0
def generate_param_lattice(config_path, n_samples, output_file_dir, output_file_name, maxiter=5, verbose=False):
    """Sample a lattice of network parameter sets and print or save it.

    Draws ``n_params * n_samples`` points with ``dmosopt.sampling.glp`` and
    rescales each column onto its parameter's ``param_range`` as
    ``x * (upper - lower) + lower``. The result maps sample index to a list
    of floats.

    :param config_path: evaluation config with ``param_spec`` and
        ``target_populations``.
    :param n_samples: samples per parameter.
    :param output_file_dir: directory of the output file; when ``None`` the
        file name is used as given.
    :param output_file_name: YAML output file name; when ``None`` the result
        is pretty-printed instead of written.
    :param maxiter: forwarded to ``sampling.glp``.
    :param verbose: not used by this function.
    """
    from dmosopt import sampling

    logger = utils.get_script_logger(os.path.basename(__file__))

    output_path = None
    if output_file_name is not None:
        if output_file_dir is None:
            output_path = output_file_name
        else:
            output_path = f'{output_file_dir}/{output_file_name}'
    eval_config = read_from_yaml(config_path, include_loader=utils.IncludeLoader)
    network_param_spec_src = eval_config['param_spec']

    target_populations = eval_config['target_populations']
    network_param_spec = make_param_spec(target_populations, network_param_spec_src)
    param_tuples = network_param_spec.param_tuples
    n_params = len(param_tuples)

    n_init = n_params * n_samples
    Xinit = sampling.glp(n_init, n_params, maxiter=maxiter)

    lb = np.asarray([param_tuple.param_range[0] for param_tuple in param_tuples])
    ub = np.asarray([param_tuple.param_range[1] for param_tuple in param_tuples])

    # Affine map of every row onto [lb, ub] (glp samples are presumably in
    # the unit hypercube); broadcasting applies it to all rows at once.
    Xinit = Xinit * (ub - lb) + lb

    output_dict = {i: [float(x) for x in row] for i, row in enumerate(Xinit)}

    if output_path is not None:
        write_to_yaml(output_path, output_dict)
    else:
        pprint.pprint(output_dict)
Esempio n. 5
0
def main(config_path, target_features_path, target_features_namespace, optimize_file_dir, optimize_file_name, nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, population_size, num_generations, resample_fraction, mutation_rate, collective_mode, spawn_startup_wait, verbose):
    """Run a dmosopt optimization of network parameters.

    Extra command-line arguments from the click context are collected into
    a network config. ``--key=value`` is stored as a string; a bare
    ``--flag`` is stored as ``True``. In both cases ``--`` is removed from
    the key and ``-`` becomes ``_``. The ``kwargs`` section of the YAML
    config at ``config_path`` is merged over this config before it is used
    to construct an ``Env``.

    Optimized parameters come from ``env.netclamp_config.optimize_parameters``
    and the search is delegated to ``dmosopt.run``. When no
    ``resample_fraction`` is given it defaults to
    ``(comm.size - 1) / population_size``; it is always clamped to
    ``[0.1, 1.0]``. If ``dmosopt.run`` returns a result and
    ``optimize_file_dir`` is set, the best parameter values are written to a
    YAML file in that directory.
    """

    network_args = click.get_current_context().args
    network_config = {}
    for arg in network_args:
        # Split on the first '=' only, so values may themselves contain '='.
        kv = arg.split("=", 1)
        key = kv[0].replace('--', '').replace('-', '_')
        network_config[key] = kv[1] if len(kv) > 1 else True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')

    if optimize_file_name is None:
        optimize_file_name = f"dmosopt.optimize_network_{run_ts}.h5"
    if optimize_file_dir is None:
        # Avoid producing a literal "None/<name>" path.
        optimize_file_path = optimize_file_name
    else:
        optimize_file_path = f'{optimize_file_dir}/{optimize_file_name}'
    operational_config = read_from_yaml(config_path)
    operational_config['run_ts'] = run_ts
    if target_features_path is not None:
        operational_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        operational_config['target_features_namespace'] = target_features_namespace

    # Values from the config file's 'kwargs' override command-line values.
    network_config.update(operational_config.get('kwargs', {}))
    env = Env(**network_config)

    objective_names = operational_config['objective_names']
    param_config_name = operational_config['param_config_name']
    target_populations = operational_config['target_populations']
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, target_populations, param_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples
    hyperprm_space = {param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
                      for param_pattern, param_tuple in zip(param_names, param_tuples)}

    init_objfun = 'init_network_objfun'
    init_params = {'operational_config': operational_config,
                   'opt_targets': opt_targets,
                   'param_tuples': [param_tuple._asdict() for param_tuple in param_tuples],
                   'param_names': param_names
                   }
    init_params.update(network_config.items())

    nworkers = env.comm.size - 1
    if resample_fraction is None:
        resample_fraction = float(nworkers) / float(population_size)
    if resample_fraction > 1.0:
        resample_fraction = 1.0
    if resample_fraction < 0.1:
        resample_fraction = 0.1

    # Create an optimizer
    feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names]
    constraint_names = [f'{target_pop_name} positive rate' for target_pop_name in target_populations]
    dmosopt_params = {'opt_id': 'dentate.optimize_network',
                      'obj_fun_init_name': init_objfun,
                      'obj_fun_init_module': 'dentate.optimize_network',
                      'obj_fun_init_args': init_params,
                      'reduce_fun_name': 'compute_objectives',
                      'reduce_fun_module': 'dentate.optimize_network',
                      'reduce_fun_args': (operational_config, opt_targets),
                      'problem_parameters': {},
                      'space': hyperprm_space,
                      'objective_names': objective_names,
                      'feature_dtypes': feature_dtypes,
                      'constraint_names': constraint_names,
                      'n_initial': n_initial,
                      'initial_maxiter': initial_maxiter,
                      'initial_method': initial_method,
                      'optimizer': optimizer_method,
                      'n_epochs': n_epochs,
                      'population_size': population_size,
                      'num_generations': num_generations,
                      'resample_fraction': resample_fraction,
                      'mutation_rate': mutation_rate,
                      'file_path': optimize_file_path,
                      'termination_conditions': True,
                      'save_surrogate_eval': True,
                      'save': True,
                      'save_eval': 5
                      }

    #dmosopt_params['broker_fun_name'] = 'dmosopt_broker_init'
    #dmosopt_params['broker_module_name'] = 'dentate.optimize_network'

    best = dmosopt.run(dmosopt_params, spawn_workers=True, sequential_spawn=False,
                       spawn_startup_wait=spawn_startup_wait,
                       nprocs_per_worker=nprocs_per_worker,
                       collective_mode=collective_mode,
                       verbose=True, worker_debug=True)

    if best is not None:
        if optimize_file_dir is not None:
            results_file_id = 'DG_optimize_network_%s' % run_ts
            yaml_file_path = '%s/optimize_network.%s.yaml' % (optimize_file_dir, str(results_file_id))
            # best[0] is used as a sequence of (param_name, values_array)
            # pairs, one array entry per result.
            prms = best[0]
            prms_dict = dict(prms)
            n_res = prms[0][1].shape[0]
            results_config_dict = {}
            for i in range(n_res):
                results_config_dict[i] = [
                    [param_pattern, float(prms_dict[param_pattern][i])]
                    for param_pattern, _ in zip(param_names, param_tuples)
                ]
            write_to_yaml(yaml_file_path, results_config_dict)
Esempio n. 6
0
    def load_celltypes(self):
        """
        Populate ``self.celltypes`` from the network data file.

        Rank 0 alone reads the population names, population ranges and cell
        attribute info from ``self.data_file_path``. It does so through a
        communicator ``comm0`` that contains only rank 0, then broadcasts the
        results to all ranks of ``self.comm``.

        Every population in the ranges gets a ``self.celltypes`` entry
        (created if missing) with ``'start'`` and ``'num'``. A configured
        ``'mechanism file'`` is read on rank 0 and broadcast as
        ``'mech_dict'``. Synaptic weight rules with an ``'expr'`` are
        wrapped in ``ExprClosure`` objects stored under ``'closure'``, and
        ``'weights'`` is normalized to a list.

        This performs collective operations on ``self.comm`` and must be
        called on all of its ranks.

        :return: None; ``self.celltypes`` is updated in place and
            ``self.cell_attribute_info`` is assigned.
        """
        rank = self.comm.Get_rank()
        celltypes = self.celltypes

        # comm0 includes only rank 0; the remaining ranks share a separate
        # communicator that is simply freed below.
        color = 1 if rank == 0 else 0
        comm0 = self.comm.Split(color, 0)

        if rank == 0:
            self.logger.info('env.data_file_path = %s' %
                             str(self.data_file_path))

        metadata = None
        if rank == 0:
            population_names = read_population_names(self.data_file_path,
                                                     comm0)
            (population_ranges,
             _) = read_population_ranges(self.data_file_path, comm0)
            type_names = sorted(population_ranges.keys())
            cell_attribute_info = read_cell_attribute_info(
                self.data_file_path, population_names, comm=comm0)
            self.logger.info('population_names = %s' % str(population_names))
            self.logger.info('population_ranges = %s' % str(population_ranges))
            self.logger.info('attribute info: %s' %
                             str(cell_attribute_info))
            metadata = (population_ranges, type_names, cell_attribute_info)
        # A single broadcast replaces one collective per item.
        (population_ranges, type_names,
         self.cell_attribute_info) = self.comm.bcast(metadata, root=0)
        comm0.Free()

        for k in type_names:
            population_range = population_ranges.get(k, None)
            if population_range is None:
                continue
            celltype = celltypes.setdefault(k, {})
            # The range is indexed as (start, count).
            celltype['start'] = population_range[0]
            celltype['num'] = population_range[1]

            if 'mechanism file' in celltype:
                celltype['mech_file_path'] = '%s/%s' % (
                    self.config_prefix, celltype['mechanism file'])
                mech_dict = None
                if rank == 0:
                    mech_dict = read_from_yaml(celltype['mech_file_path'])
                # Collective call: assumes every rank has the same
                # 'mechanism file' entries in self.celltypes, so all ranks
                # reach this bcast for the same types in the same order.
                celltype['mech_dict'] = self.comm.bcast(mech_dict, root=0)

            if 'synapses' in celltype:
                synapses_dict = celltype['synapses']
                if 'weights' in synapses_dict:
                    weights_config = synapses_dict['weights']
                    if isinstance(weights_config, list):
                        weights_dicts = weights_config
                    else:
                        weights_dicts = [weights_config]
                    for weights_dict in weights_dicts:
                        if 'expr' in weights_dict:
                            expr = weights_dict['expr']
                            parameter = weights_dict['parameter']
                            const = weights_dict.get('const', None)
                            weights_dict['closure'] = ExprClosure(parameters=parameter,
                                                                  expr=expr,
                                                                  consts=const)
                    synapses_dict['weights'] = weights_dicts
Esempio n. 7
0
def main(config_path, params_id, n_samples, target_features_path, target_features_namespace, output_file_dir, output_file_name, verbose):
    """Evaluate network parameter sets, or generate a parameter lattice.

    If ``params_id`` is None and ``n_samples`` is given, this only calls
    ``generate_param_lattice`` and returns. Otherwise, the following happens:

    - Extra command-line arguments (``--key=value`` or a bare ``--flag``,
      stored as ``True``) are merged with the ``kwargs`` section of the
      evaluation config at ``config_path``.
    - The network is initialized on ``MPI.COMM_WORLD``.
    - Target rate maps are computed for each target population whose
      ``'<pop> target rate dist residual'`` feature is requested.
    - Everything is handed to ``eval_network`` together with two decoders
      that turn parameter lists/dicts into ``(SynParam, value)`` pairs.
    """

    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    if params_id is None and n_samples is not None:
        logger.info("Generating parameter lattice ...")
        # verbose must be passed by keyword: the fifth positional parameter
        # of generate_param_lattice is maxiter.
        generate_param_lattice(config_path, n_samples, output_file_dir,
                               output_file_name, verbose=verbose)
        return

    network_args = click.get_current_context().args
    network_config = {}
    for arg in network_args:
        # Split on the first '=' only, so values may themselves contain '='.
        kv = arg.split("=", 1)
        key = kv[0].replace('--', '').replace('-', '_')
        network_config[key] = kv[1] if len(kv) > 1 else True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')

    if output_file_name is None:
        output_file_name = f"network_features_{run_ts}.h5"
    if output_file_dir is None:
        # Avoid producing a literal "None/<name>" path.
        output_path = output_file_name
    else:
        output_path = f'{output_file_dir}/{output_file_name}'

    eval_config = read_from_yaml(config_path, include_loader=utils.IncludeLoader)
    eval_config['run_ts'] = run_ts
    if target_features_path is not None:
        eval_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        eval_config['target_features_namespace'] = target_features_namespace

    network_param_spec_src = eval_config['param_spec']
    network_param_values = eval_config['param_values']

    feature_names = eval_config['feature_names']
    target_populations = eval_config['target_populations']

    # Values from the config file's 'kwargs' override command-line values.
    network_config.update(eval_config.get('kwargs', {}))

    if params_id is None:
        network_config['results_file_id'] = f"DG_eval_network_{eval_config['run_ts']}"
    else:
        network_config['results_file_id'] = f"DG_eval_network_{params_id}_{eval_config['run_ts']}"

    env = init_network(comm=MPI.COMM_WORLD, kwargs=network_config)
    gc.collect()

    t_start = 50.
    t_stop = env.tstop
    time_range = (t_start, t_stop)

    target_trj_rate_map_dict = {}
    for pop_name in target_populations:
        if ('%s target rate dist residual' % pop_name) not in feature_names:
            continue
        my_cell_index_set = set(env.biophys_cells[pop_name].keys())
        target_trj_rate_map_dict[pop_name] = rate_maps_from_features(
            env, pop_name,
            cell_index_set=list(my_cell_index_set),
            input_features_path=target_features_path,
            input_features_namespace=target_features_namespace,
            time_range=time_range)

    network_param_spec = make_param_spec(target_populations, network_param_spec_src)

    def from_param_list(x):
        # Each entry of x is
        # (population, source, sec_type, syn_name, param_path, value).
        result = []
        for this_population, source, sec_type, syn_name, param_path, param_val in x:
            param_tuple = SynParam(this_population, source, sec_type, syn_name, param_path, None)
            result.append((param_tuple, param_val))
        return result

    def keyfun(kv):
        # Order mapping items by the string form of the key, so keys of
        # mixed types can still be sorted.
        return str(kv[0])

    def from_param_dict(x):
        # x is nested as pop -> source -> sec_type -> syn_name -> param;
        # a dict-valued param expands into (param, const_name) paths.
        result = []
        for pop_name, param_specs in viewitems(x):
            for source, source_dict in sorted(viewitems(param_specs), key=keyfun):
                for sec_type, sec_type_dict in sorted(viewitems(source_dict), key=keyfun):
                    for syn_name, syn_mech_dict in sorted(viewitems(sec_type_dict), key=keyfun):
                        for param_fst, param_rst in sorted(viewitems(syn_mech_dict), key=keyfun):
                            if isinstance(param_rst, dict):
                                for const_name, const_value in sorted(viewitems(param_rst)):
                                    param_path = (param_fst, const_name)
                                    param_tuple = SynParam(pop_name, source, sec_type, syn_name, param_path, const_value)
                                    result.append((param_tuple, const_value))
                            else:
                                param_tuple = SynParam(pop_name, source, sec_type, syn_name, param_fst, param_rst)
                                result.append((param_tuple, param_rst))
        return result

    eval_network(env, network_config,
                 from_param_list, from_param_dict,
                 network_param_spec, network_param_values, params_id,
                 target_trj_rate_map_dict, t_start, t_stop,
                 target_populations, output_path)