Beispiel #1
0
def config_controller():
    """

    """
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))

    try:
        context.env = Env(comm=context.controller_comm, **context.kwargs)
    except Exception as err:
        context.logger.exception(err)
        raise err

    opt_param_config = optimization_params(
        context.env.netclamp_config.optimize_parameters,
        context.target_populations, context.param_config_name)
    param_bounds = opt_param_config.param_bounds
    param_names = opt_param_config.param_bounds
    param_initial_dict = opt_param_config.param_initial_dict
    param_tuples = opt_param_config.param_tuples
    opt_targets = opt_param_config.opt_targets

    context.param_names = param_names
    context.bounds = [param_bounds[key] for key in param_names]
    context.x0 = param_initial_dict
    context.target_val = opt_targets
    context.target_range = opt_targets
    context.param_tuples = param_tuples
    # These kwargs will be sent from the controller to each worker context
    context.kwargs['param_tuples'] = param_tuples
def main(config_file, population, dt, gid, gid_selection_file, arena_id, trajectory_id, generate_weights,
         t_max, t_min,  nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, surrogate_method,
         population_size, num_generations, resample_fraction, mutation_rate,
         template_paths, dataset_prefix, config_prefix,
         param_config_name, selectivity_config_name, param_type, recording_profile, results_file, results_path, spike_events_path,
         spike_events_namespace, spike_events_t, input_features_path, input_features_namespaces, n_trials,
         trial_regime, problem_regime, target_features_path, target_features_namespace, target_state_variable,
         target_state_filter, use_coreneuron, cooperative_init, spawn_startup_wait):
    """
    Optimize the input stimulus selectivity of the specified cell in a network clamp configuration.
    """
    init_params = dict(locals())

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    results_file_id = None
    if rank == 0:
        results_file_id = generate_results_file_id(population, gid)
        
    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()
    
    np.seterr(all='raise')
    verbose = True
    cache_queries = True

    config_logging(verbose)

    cell_index_set = set([])
    if gid_selection_file is not None:
        with open(gid_selection_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                gid = int(line)
                cell_index_set.add(gid)
    elif gid is not None:
        cell_index_set.add(gid)
    else:
        comm.barrier()
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            env = Env(**init_params, comm=comm0)
            attr_info_dict = read_cell_attribute_info(env.data_file_path, populations=[population],
                                                      read_cell_index=True, comm=comm0)
            cell_index = None
            attr_name, attr_cell_index = next(iter(attr_info_dict[population]['Trees']))
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()
    init_params['cell_index_set'] = cell_index_set
    del(init_params['gid'])

    params = dict(locals())
    env = Env(**params)
    if size == 1:
        configure_hoc_env(env)
        init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
             spike_events_path, spike_events_namespace=spike_events_namespace, 
             spike_train_attr_name=spike_events_t,
             input_features_path=input_features_path,
             input_features_namespaces=input_features_namespaces,
             generate_weights_pops=set(generate_weights), 
             t_min=t_min, t_max=t_max)
        
    if (population in env.netclamp_config.optimize_parameters[param_type]):
        opt_params = env.netclamp_config.optimize_parameters[param_type][population]
    else:
        raise RuntimeError(f'optimize_selectivity: population {population} does not have optimization configuration')

    if target_state_variable is None:
        target_state_variable = 'v'
    
    init_params['target_features_arena'] = arena_id
    init_params['target_features_trajectory'] = trajectory_id
    opt_state_baseline = opt_params['Targets']['state'][target_state_variable]['baseline']
    init_params['state_baseline'] = opt_state_baseline
    init_params['state_variable'] = target_state_variable
    init_params['state_filter'] = target_state_filter
    init_objfun_name = 'init_selectivity_objfun'
        
    best = optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun_name, problem_regime=problem_regime,
                        n_epochs=n_epochs, n_initial=n_initial, initial_maxiter=initial_maxiter, initial_method=initial_method, 
                        optimizer_method=optimizer_method, surrogate_method=surrogate_method, population_size=population_size, 
                        num_generations=num_generations, resample_fraction=resample_fraction, mutation_rate=mutation_rate, 
                        param_type=param_type, init_params=init_params, results_file=results_file, nprocs_per_worker=nprocs_per_worker, 
                        cooperative_init=cooperative_init, spawn_startup_wait=spawn_startup_wait, verbose=verbose)
    
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    if best is not None:
        if results_path is not None:
            run_ts = time.strftime("%Y%m%d_%H%M%S")
            file_path = f'{results_path}/optimize_selectivity.{run_ts}.yaml'
            param_names = opt_param_config.param_names
            param_tuples = opt_param_config.param_tuples

            if ProblemRegime[problem_regime] == ProblemRegime.every:
                results_config_dict = {}
                for gid, prms in viewitems(best):
                    n_res = prms[0][1].shape[0]
                    prms_dict = dict(prms)
                    this_results_config_dict = {}
                    for i in range(n_res):
                        results_param_list = []
                        for param_pattern, param_tuple in zip(param_names, param_tuples):
                            results_param_list.append((param_tuple.population,
                                                       param_tuple.source,
                                                       param_tuple.sec_type,
                                                       param_tuple.syn_name,
                                                       param_tuple.param_path,
                                                       float(prms_dict[param_pattern][i])))
                        this_results_config_dict[i] = results_param_list
                    results_config_dict[gid] = this_results_config_dict
                    
            else:
                prms = best[0]
                n_res = prms[0][1].shape[0]
                prms_dict = dict(prms)
                results_config_dict = {}
                for i in range(n_res):
                    results_param_list = []
                    for param_pattern, param_tuple in zip(param_names, param_tuples):
                        results_param_list.append((param_tuple.population,
                                                   param_tuple.source,
                                                   param_tuple.sec_type,
                                                   param_tuple.syn_name,
                                                   param_tuple.param_path,
                                                   float(prms_dict[param_pattern][i])))
                    results_config_dict[i] = results_param_list

            write_to_yaml(file_path, { population: results_config_dict } )

            
    comm.barrier()
def optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun, problem_regime, nprocs_per_worker=1,
                 n_epochs=10, n_initial=30, initial_maxiter=50, initial_method="slh", optimizer_method="nsga2", surrogate_method='vgp',
                 population_size=200, num_generations=200, resample_fraction=None, mutation_rate=None,
                 param_type='synaptic', init_params={}, results_file=None, cooperative_init=False, 
                 spawn_startup_wait=None, verbose=False):

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples
    
    hyperprm_space = { param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
                       for param_pattern, param_tuple in 
                           zip(param_names, param_tuples) }

    if results_file is None:
        if env.results_path is not None:
            file_path = f'{env.results_path}/dmosopt.optimize_selectivity.{env.results_file_id}.h5'
        else:
            file_path = f'dmosopt.optimize_selectivity.{env.results_file_id}.h5'
    else:
        file_path = '%s/%s' % (env.results_path, results_file)
    problem_ids = None
    reduce_fun_name = None
    if ProblemRegime[problem_regime] == ProblemRegime.every:
        reduce_fun_name = "opt_reduce_every"
        problem_ids = init_params.get('cell_index_set', None)
    elif ProblemRegime[problem_regime] == ProblemRegime.mean:
        reduce_fun_name = "opt_reduce_mean"
    elif ProblemRegime[problem_regime] == ProblemRegime.max:
        reduce_fun_name = "opt_reduce_max"
    else:
        raise RuntimeError(f'optimize_run: unknown problem regime {problem_regime}')

    n_trials = init_params.get('n_trials', 1)

    nworkers = env.comm.size-1
    if resample_fraction is None:
        resample_fraction = float(nworkers) / float(population_size)
    if resample_fraction > 1.0:
        resample_fraction = 1.0
    if resample_fraction < 0.1:
        resample_fraction = 0.1

    objective_names = ['residual_infld', 'residual_state']
    feature_names = ['mean_peak_rate', 'mean_trough_rate', 
                     'max_infld_rate', 'min_infld_rate', 'mean_infld_rate', 'mean_outfld_rate', 
                     'mean_peak_state', 'mean_trough_state', 'mean_outfld_state']
    N_objectives = 2
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', np.float32, (N_objectives, n_trials)))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    constraint_names = ['positive_rate']
    dmosopt_params = {'opt_id': 'dentate.optimize_selectivity',
                      'problem_ids': problem_ids,
                      'obj_fun_init_name': init_objfun, 
                      'obj_fun_init_module': 'dentate.optimize_selectivity',
                      'obj_fun_init_args': init_params,
                      'reduce_fun_name': reduce_fun_name,
                      'reduce_fun_module': 'dentate.optimization',
                      'problem_parameters': {},
                      'space': hyperprm_space,
                      'objective_names': objective_names,
                      'feature_dtypes': feature_dtypes,
                      'constraint_names': constraint_names,
                      'n_initial': n_initial,
                      'n_epochs': n_epochs,
                      'population_size': population_size,
                      'num_generations': num_generations,
                      'resample_fraction': resample_fraction,
                      'mutation_rate': mutation_rate,
                      'initial_maxiter': initial_maxiter,
                      'initial_method': initial_method,
                      'optimizer': optimizer_method,
                      'surrogate_method': surrogate_method,
                      'file_path': file_path,
                      'save': True,
                      'save_eval' : 5,
                      }


    opt_results = dmosopt.run(dmosopt_params, verbose=verbose, collective_mode="sendrecv",
                              spawn_workers=True, nprocs_per_worker=nprocs_per_worker, 
                              spawn_startup_wait=spawn_startup_wait
                              )
    if opt_results is not None:
        if ProblemRegime[problem_regime] == ProblemRegime.every:
            gid_results_config_dict = {}
            for gid, opt_result in viewitems(opt_results):
                params_dict = dict(opt_result[0])
                result_value = opt_result[1]
                results_config_tuples = []
                for param_pattern, param_tuple in zip(param_names, param_tuples):
                    results_config_tuples.append((param_pattern, params_dict[param_pattern]))
                gid_results_config_dict[int(gid)] = results_config_tuples

            logger.info('Optimized parameters and objective function: '
                        f'{pprint.pformat(gid_results_config_dict)} @'
                        f'{result_value}')
            return gid_results_config_dict
        else:
            params_dict = dict(opt_results[0])
            result_value = opt_results[1]
            results_config_tuples = []
            for param_pattern, param_tuple in zip(param_names, param_tuples):
                results_config_tuples.append((param_pattern, params_dict[param_pattern]))
            logger.info('Optimized parameters and objective function: '
                        f'{pprint.pformat(results_config_tuples)} @'
                        f'{result_value}')
            return results_config_tuples
    else:
        return None
def init_selectivity_objfun(config_file, population, cell_index_set, arena_id, trajectory_id,
                            n_trials, trial_regime, problem_regime,
                            generate_weights, t_max, t_min,
                            template_paths, dataset_prefix, config_prefix, results_path,
                            spike_events_path, spike_events_namespace, spike_events_t,
                            input_features_path, input_features_namespaces,
                            param_type, param_config_name, selectivity_config_name, recording_profile, 
                            state_variable, state_filter, state_baseline,
                            target_features_path, target_features_namespace,
                            target_features_arena, target_features_trajectory,   
                            use_coreneuron, cooperative_init, dt, worker, **kwargs):
    
    params = dict(locals())
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)

    my_cell_index_set = init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
                             spike_events_path, spike_events_namespace=spike_events_namespace, 
                             spike_train_attr_name=spike_events_t,
                             input_features_path=input_features_path,
                             input_features_namespaces=input_features_namespaces,
                             generate_weights_pops=set(generate_weights), 
                             t_min=t_min, t_max=t_max, cooperative_init=cooperative_init,
                             worker=worker)

    time_step = float(env.stimulus_config['Temporal Resolution'])
    equilibration_duration = float(env.stimulus_config.get('Equilibration Duration', 0.))
    
    target_rate_vector_dict = rate_maps_from_features (env, population,
                                                       cell_index_set=my_cell_index_set, 
                                                       input_features_path=target_features_path,
                                                       input_features_namespace=target_features_namespace, 
                                                       time_range=[0., t_max], 
                                                       arena_id=arena_id)


    logger.info(f'target_rate_vector_dict = {target_rate_vector_dict}')
    for gid, target_rate_vector in viewitems(target_rate_vector_dict):
        target_rate_vector[np.isclose(target_rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.

    trj_x, trj_y, trj_d, trj_t = stimulus.read_trajectory(input_features_path if input_features_path is not None else spike_events_path, 
                                                          target_features_arena, target_features_trajectory)
    time_range = (0., min(np.max(trj_t), t_max))
    time_bins = np.arange(time_range[0], time_range[1]+time_step, time_step)
    state_time_bins = np.arange(time_range[0], time_range[1], time_step)[:-1]

    def range_inds(rs):
        l = list(rs)
        if len(l) > 0:
            a = np.concatenate(l)
        else:
            a = None
        return a

    def time_ranges(rs):
        if len(rs) > 0:
            a = tuple( ( (time_bins[r[0]], time_bins[r[1]-1]) for r in rs ) )
        else:
            a = None
        return a
        
    
    infld_idxs_dict = { gid: np.where(target_rate_vector > 1e-4)[0] 
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    peak_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 80)
                         for gid, infld_idxs in viewitems(infld_idxs_dict) }
    trough_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 20)
                           for gid, infld_idxs in viewitems(infld_idxs_dict) }
    outfld_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector < 1e-4, return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid], return_indices=True)) 
                       for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_idxs_dict = { gid: range_inds(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid]), return_indices=True))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    outfld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector <= 0.) ) 
                           for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    infld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector > 0) ) 
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid]))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_ranges_dict = { gid: time_ranges(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid])))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    large_fld_gids = []
    for gid in my_cell_index_set:

        infld_idxs = infld_idxs_dict[gid]

        target_infld_rate_vector = target_rate_vector[infld_idxs]
        target_peak_rate_vector = target_rate_vector[peak_idxs_dict[gid]]
        target_trough_rate_vector = target_rate_vector[trough_idxs_dict[gid]]

        logger.info(f'selectivity objective: target peak/trough rate of gid {gid}: '
                    f'{peak_pctile_dict[gid]:.02f} {trough_pctile_dict[gid]:.02f}')
        logger.info(f'selectivity objective: mean target peak/trough rate of gid {gid}: '
                    f'{np.mean(target_peak_rate_vector):.02f} {np.mean(target_trough_rate_vector):.02f}')
        
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    selectivity_opt_param_config = selectivity_optimization_params(env.netclamp_config.optimize_parameters, [population],
                                                                   selectivity_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    N_objectives = 2
    feature_names = ['mean_peak_rate', 'mean_trough_rate', 
                     'max_infld_rate', 'min_infld_rate', 'mean_infld_rate', 'mean_outfld_rate', 
                     'mean_peak_state', 'mean_trough_state', 'mean_outfld_state']
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    def from_param_dict(params_dict):
        result = []
        for param_pattern, param_tuple in zip(param_names, param_tuples):
            result.append((param_tuple, params_dict[param_pattern]))

        return result

    def update_run_params(input_param_tuple_vals, update_param_names, update_param_tuples):
        result = []
        updated_set = set([])
        update_param_dict = dict(zip(update_param_names, update_param_tuples))
        for param_pattern, (param_tuple, param_val) in zip(param_names, input_param_tuple_vals):
            if param_pattern in update_param_dict:
                updated_set.add(param_pattern)
                result.append((param_tuple, update_param_dict[param_pattern].param_range))
            else:
                result.append((param_tuple, param_val))
        for update_param_name in update_param_dict:
            if update_param_name not in updated_set:
                result.append((update_param_dict[update_param_name], 
                               update_param_dict[update_param_name].param_range))

        return result
        
    
    def gid_firing_rate_vectors(spkdict, cell_index_set):
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in spkdict[population]:
                    spkdict1[gid] = spkdict[population][gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate (population, spkdict1, time_bins)
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]['rate']
                rate_vector[np.isclose(rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.
                rates_dict[gid].append(rate_vector)
                logger.info(f'selectivity objective: trial {i} firing rate min/max of gid {gid}: '
                            f'{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz')

        return rates_dict

    def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict):
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        filter_fun = None
        if state_filter == 'lowpass':
            filter_fun = lambda x, t: get_low_pass_filtered_trace(x, t)
        for gid in state_recs_dict:
            state_values = None
            state_recs = state_recs_dict[gid]
            assert(len(state_recs) == 1)
            rec = state_recs[0]
            vec = np.asarray(rec['vec'].to_python(), dtype=np.float32)
            if filter_fun is None:
                data = np.asarray([ vec[t_inds] for t_inds in t_trial_inds ])
            else:
                data = np.asarray([ filter_fun(vec[t_inds], t_vec[t_inds])
                                    for t_inds in t_trial_inds ])

            state_values = []
            max_len = np.max(np.asarray([len(a) for a in data]))
            for state_value_array in data:
                this_len = len(state_value_array)
                if this_len < max_len:
                    a = np.pad(state_value_array, (0, max_len-this_len), 'edge')
                else:
                    a = state_value_array
                state_values.append(a)

            results_dict[gid] = state_values
        return t_vec[t_trial_inds[0]], results_dict


    def trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs, 
                            rate_vectors, masked_rate_vectors, target_rate_vector):

        n_trials = len(rate_vectors)
        residual_inflds = []
        trial_inflds = []
        trial_outflds = []

        target_infld = target_rate_vector[infld_idxs]
        target_max_infld = np.max(target_infld)
        target_mean_trough = np.mean(target_rate_vector[trough_idxs])
        logger.info(f'selectivity objective: target max infld/mean trough of gid {gid}: '
                    f'{target_max_infld:.02f} {target_mean_trough:.02f}')
        for trial_i in range(n_trials):

            rate_vector = rate_vectors[trial_i]
            infld_rate_vector = rate_vector[infld_idxs]
            masked_rate_vector = masked_rate_vectors[trial_i]
            if outfld_idxs is None:
                outfld_rate_vector = masked_rate_vectors[trial_i]
            else:
                outfld_rate_vector = rate_vector[outfld_idxs]

            mean_peak = np.mean(rate_vector[peak_idxs])
            mean_trough = np.mean(rate_vector[trough_idxs])
            min_infld = np.min(infld_rate_vector)
            max_infld = np.max(infld_rate_vector)
            mean_infld = np.mean(infld_rate_vector)
            mean_outfld = np.mean(outfld_rate_vector)

            residual_infld = np.abs(np.sum(target_infld - infld_rate_vector))
            logger.info(f'selectivity objective: max infld/mean infld/mean peak/trough/mean outfld/residual_infld of gid {gid} trial {trial_i}: '
                        f'{max_infld:.02f} {mean_infld:.02f} {mean_peak:.02f} {mean_trough:.02f} {mean_outfld:.02f} {residual_infld:.04f}')
            residual_inflds.append(residual_infld)
            trial_inflds.append(mean_infld)
            trial_outflds.append(mean_outfld)

        trial_rate_features = [np.asarray(trial_inflds, dtype=np.float32).reshape((1, n_trials)), 
                               np.asarray(trial_outflds, dtype=np.float32).reshape((1, n_trials))]
        rate_features = [mean_peak, mean_trough, max_infld, min_infld, mean_infld, mean_outfld, ]
        #rate_constr = [ mean_peak if max_infld > 0. else -1. ]
        rate_constr = [ mean_peak - mean_trough if max_infld > 0. else -1. ]
        return (np.asarray(residual_inflds), trial_rate_features, rate_features, rate_constr)

    
    def trial_state_residuals(gid, target_outfld, t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs, state_values, masked_state_values):

        state_value_arrays = np.row_stack(state_values)
        masked_state_value_arrays = None
        if masked_state_values is not None:
            masked_state_value_arrays = np.row_stack(masked_state_values)
        
        residuals_outfld = []
        peak_inflds = []
        trough_inflds = []
        mean_outflds = []
        for i in range(state_value_arrays.shape[0]):
            state_value_array = state_value_arrays[i, :]
            peak_infld = np.mean(state_value_array[t_peak_idxs])
            trough_infld = np.mean(state_value_array[t_trough_idxs])
            mean_infld = np.mean(state_value_array[t_infld_idxs])

            masked_state_value_array = masked_state_value_arrays[i, :]
            mean_masked = np.mean(masked_state_value_array)
            residual_masked = np.mean(masked_state_value_array) - target_outfld

            mean_outfld = mean_masked
            if t_outfld_idxs is not None:
                mean_outfld = np.mean(state_value_array[t_outfld_idxs])
                
            peak_inflds.append(peak_infld)
            trough_inflds.append(trough_infld)
            mean_outflds.append(mean_outfld)
            residuals_outfld.append(residual_masked)
            logger.info(f'selectivity objective: state values of gid {gid}: '
                        f'peak/trough/mean in/mean out/masked: {peak_infld:.02f} / {trough_infld:.02f} / {mean_infld:.02f} / {mean_outfld:.02f} / residual masked: {residual_masked:.04f}')

        state_features = [np.mean(peak_inflds), np.mean(trough_inflds), np.mean(mean_outflds)]
        return (np.asarray(residuals_outfld), state_features)

    
    recording_profile = { 'label': f'optimize_selectivity.{state_variable}',
                          'section quantity': {
                              state_variable: { 'swc types': ['soma'] }
                            }
                        }
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(env, population, gid, recording_profile=recording_profile)

        
    def eval_problem(cell_param_dict, **kwargs):

        run_params = {population: {gid: from_param_dict(cell_param_dict[gid])
                                   for gid in my_cell_index_set}}
        masked_state_values_dict = {}
        masked_run_params = {population: { gid: update_run_params(run_params[population][gid],
                                                                  selectivity_opt_param_config.mask_param_names,
                                                                  selectivity_opt_param_config.mask_param_tuples)
                                           for gid in my_cell_index_set} }
        spkdict = run_with(env, run_params)
        rates_dict = gid_firing_rate_vectors(spkdict, my_cell_index_set)
        t_s, state_values_dict = gid_state_values(spkdict, equilibration_duration, n_trials, env.t_rec, 
                                                  state_recs_dict)

        masked_spkdict = run_with(env, masked_run_params)
        masked_rates_dict = gid_firing_rate_vectors(masked_spkdict, my_cell_index_set)
        t_s, masked_state_values_dict = gid_state_values(masked_spkdict, equilibration_duration, n_trials, env.t_rec, 
                                                         state_recs_dict)
        
        
        result = {}
        for gid in my_cell_index_set:
            infld_idxs = infld_idxs_dict[gid]
            outfld_idxs = outfld_idxs_dict[gid]
            peak_idxs = peak_idxs_dict[gid]
            trough_idxs = trough_idxs_dict[gid]
            
            target_rate_vector = target_rate_vector_dict[gid]

            peak_ranges = peak_ranges_dict[gid]
            trough_ranges = trough_ranges_dict[gid]
            infld_ranges = infld_ranges_dict[gid]
            outfld_ranges = outfld_ranges_dict[gid]
            
            t_peak_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in peak_ranges ])
            t_trough_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in trough_ranges ])
            t_infld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in infld_ranges ])
            if outfld_ranges is not None:
                t_outfld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in outfld_ranges ])
            else:
                t_outfld_idxs = None
            
            masked_state_values = masked_state_values_dict.get(gid, None)
            state_values = state_values_dict[gid]
            rate_vectors = rates_dict[gid]
            masked_rate_vectors = masked_rates_dict[gid]
            
            logger.info(f'selectivity objective: max rates of gid {gid}: '
                        f'{list([np.max(rate_vector) for rate_vector in rate_vectors])}')

            infld_residuals, trial_rate_features, rate_features, rate_constr = \
              trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs, 
                                  rate_vectors, masked_rate_vectors, target_rate_vector)
            state_residuals, state_features = trial_state_residuals(gid, state_baseline,
                                                                    t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs,
                                                                    state_values, masked_state_values)
            trial_obj_features = np.row_stack((infld_residuals, state_residuals))
            
            if trial_regime == 'mean':
                mean_infld_residual = np.mean(infld_residuals)
                mean_state_residual = np.mean(state_residuals)
                infld_objective = mean_infld_residual
                state_objective = abs(mean_state_residual)
                logger.info(f'selectivity objective: mean peak/trough/mean infld/mean outfld/mean state residual of gid {gid}: '
                            f'{mean_infld_residual:.04f} {mean_state_residual:.04f}')
            elif trial_regime == 'best':
                min_infld_residual_index = np.argmin(infld_residuals)
                min_infld_residual = infld_residuals[min_infld_index]
                infld_objective = min_infld_residual
                min_state_residual = np.min(np.abs(state_residuals))
                state_objective = min_state_residual
                logger.info(f'selectivity objective: mean peak/trough/max infld/max outfld/min state residual of gid {gid}: '
                            f'{min_infld_residual:.04f} {min_state_residual:.04f}')
            else:
                raise RuntimeError(f'selectivity_rate_objective: unknown trial regime {trial_regime}')

            logger.info(f"rate_features: {rate_features} state_features: {state_features} obj_features: {trial_obj_features}")

            result[gid] = (np.asarray([ infld_objective, state_objective ], 
                                      dtype=np.float32), 
                           np.array([tuple(rate_features+state_features+[trial_obj_features]+trial_rate_features)], 
                                    dtype=np.dtype(feature_dtypes)),
                           np.asarray(rate_constr, dtype=np.float32))
                           
        return result
    
    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
Beispiel #5
0
def main(config_path, target_features_path, target_features_namespace, optimize_file_dir, optimize_file_name, nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, population_size, num_generations, resample_fraction, mutation_rate, collective_mode, spawn_startup_wait, verbose):

    network_args = click.get_current_context().args
    network_config = {}
    for arg in network_args:
        kv = arg.split("=")
        if len(kv) > 1:
            k,v = kv
            network_config[k.replace('--', '').replace('-', '_')] = v
        else:
            k = kv[0]
            network_config[k.replace('--', '').replace('-', '_')] = True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')

    if optimize_file_name is None:
        optimize_file_name=f"dmosopt.optimize_network_{run_ts}.h5"
    operational_config = read_from_yaml(config_path)
    operational_config['run_ts'] = run_ts
    if target_features_path is not None:
        operational_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        operational_config['target_features_namespace'] = target_features_namespace

    network_config.update(operational_config.get('kwargs', {}))
    env = Env(**network_config)

    objective_names = operational_config['objective_names']
    param_config_name = operational_config['param_config_name']
    target_populations = operational_config['target_populations']
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, target_populations, param_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples
    hyperprm_space = { param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
                       for param_pattern, param_tuple in 
                           zip(param_names, param_tuples) }

    init_objfun = 'init_network_objfun'
    init_params = { 'operational_config': operational_config,
                    'opt_targets': opt_targets,
                    'param_tuples': [ param_tuple._asdict() for param_tuple in param_tuples ],
                    'param_names': param_names
                    }
    init_params.update(network_config.items())
    
    nworkers = env.comm.size-1
    if resample_fraction is None:
        resample_fraction = float(nworkers) / float(population_size)
    if resample_fraction > 1.0:
        resample_fraction = 1.0
    if resample_fraction < 0.1:
        resample_fraction = 0.1
    
    # Create an optimizer
    feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names]
    constraint_names = [f'{target_pop_name} positive rate' for target_pop_name in target_populations ]
    dmosopt_params = {'opt_id': 'dentate.optimize_network',
                      'obj_fun_init_name': init_objfun, 
                      'obj_fun_init_module': 'dentate.optimize_network',
                      'obj_fun_init_args': init_params,
                      'reduce_fun_name': 'compute_objectives',
                      'reduce_fun_module': 'dentate.optimize_network',
                      'reduce_fun_args': (operational_config, opt_targets),
                      'problem_parameters': {},
                      'space': hyperprm_space,
                      'objective_names': objective_names,
                      'feature_dtypes': feature_dtypes,
                      'constraint_names': constraint_names,
                      'n_initial': n_initial,
                      'initial_maxiter': initial_maxiter,
                      'initial_method': initial_method,
                      'optimizer': optimizer_method,
                      'n_epochs': n_epochs,
                      'population_size': population_size,
                      'num_generations': num_generations,
                      'resample_fraction': resample_fraction,
                      'mutation_rate': mutation_rate,
                      'file_path': f'{optimize_file_dir}/{optimize_file_name}',
                      'termination_conditions': True,
                      'save_surrogate_eval': True,
                      'save': True,
                      'save_eval': 5
                      }
    
    #dmosopt_params['broker_fun_name'] = 'dmosopt_broker_init'
    #dmosopt_params['broker_module_name'] = 'dentate.optimize_network'

    best = dmosopt.run(dmosopt_params, spawn_workers=True, sequential_spawn=False,
                       spawn_startup_wait=spawn_startup_wait,
                       nprocs_per_worker=nprocs_per_worker,
                       collective_mode=collective_mode,
                       verbose=True, worker_debug=True)
    
    if best is not None:
        if optimize_file_dir is not None:
            results_file_id = 'DG_optimize_network_%s' % run_ts
            yaml_file_path = '%s/optimize_network.%s.yaml' % (optimize_file_dir, str(results_file_id))
            prms = best[0]
            prms_dict = dict(prms)
            n_res = prms[0][1].shape[0]
            results_config_dict = {}
            for i in range(n_res):
                result_param_list = []
                for param_pattern, param_tuple in zip(param_names, param_tuples):
                    result_param_list.append([param_pattern, float(prms_dict[param_pattern][i])])
                results_config_dict[i] = result_param_list
            write_to_yaml(yaml_file_path, results_config_dict)