def config_controller():
    """Configure the optimization parameter context on the controller."""
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))
    try:
        context.env = Env(comm=context.controller_comm, **context.kwargs)
    except Exception as err:
        context.logger.exception(err)
        raise

    opt_param_config = optimization_params(context.env.netclamp_config.optimize_parameters,
                                           context.target_populations,
                                           context.param_config_name)

    param_bounds = opt_param_config.param_bounds
    param_names = opt_param_config.param_names
    param_initial_dict = opt_param_config.param_initial_dict
    param_tuples = opt_param_config.param_tuples
    opt_targets = opt_param_config.opt_targets

    context.param_names = param_names
    context.bounds = [param_bounds[key] for key in param_names]
    context.x0 = param_initial_dict
    context.target_val = opt_targets
    context.target_range = opt_targets
    context.param_tuples = param_tuples
    # These kwargs will be sent from the controller to each worker context
    context.kwargs['param_tuples'] = param_tuples
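# A minimal sketch of the shape `optimization_params` is assumed to return,
# based on the fields consumed in config_controller above; the actual dentate
# implementation may differ, and the parameter name/values are placeholders:
#
#   from collections import namedtuple
#   OptConfig = namedtuple('OptConfig',
#                          ['param_bounds', 'param_names', 'param_initial_dict',
#                           'param_tuples', 'opt_targets'])
#   opt_param_config = OptConfig(param_bounds={'GC.weight': (0.0, 1.0)},
#                                param_names=['GC.weight'],
#                                param_initial_dict={'GC.weight': 0.5},
#                                param_tuples=[], opt_targets={})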
def main(config_file, population, dt, gid, gid_selection_file, arena_id, trajectory_id, generate_weights,
         t_max, t_min, nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method,
         optimizer_method, surrogate_method, population_size, num_generations, resample_fraction,
         mutation_rate, template_paths, dataset_prefix, config_prefix, param_config_name,
         selectivity_config_name, param_type, recording_profile, results_file, results_path,
         spike_events_path, spike_events_namespace, spike_events_t, input_features_path,
         input_features_namespaces, n_trials, trial_regime, problem_regime, target_features_path,
         target_features_namespace, target_state_variable, target_state_filter, use_coreneuron,
         cooperative_init, spawn_startup_wait):
    """
    Optimize the input stimulus selectivity of the specified cell in a network clamp configuration.
    """
    init_params = dict(locals())

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    results_file_id = None
    if rank == 0:
        results_file_id = generate_results_file_id(population, gid)
    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()

    np.seterr(all='raise')
    verbose = True
    cache_queries = True
    config_logging(verbose)

    cell_index_set = set()
    if gid_selection_file is not None:
        with open(gid_selection_file, 'r') as f:
            for line in f.readlines():
                cell_index_set.add(int(line))
    elif gid is not None:
        cell_index_set.add(gid)
    else:
        comm.barrier()
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            env = Env(**init_params, comm=comm0)
            attr_info_dict = read_cell_attribute_info(env.data_file_path, populations=[population],
                                                      read_cell_index=True, comm=comm0)
            attr_name, attr_cell_index = next(iter(attr_info_dict[population]['Trees']))
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()

    init_params['cell_index_set'] = cell_index_set
    del init_params['gid']

    params = dict(locals())
    env = Env(**params)
    if size == 1:
        configure_hoc_env(env)
        init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
             spike_events_path, spike_events_namespace=spike_events_namespace,
             spike_train_attr_name=spike_events_t, input_features_path=input_features_path,
             input_features_namespaces=input_features_namespaces,
             generate_weights_pops=set(generate_weights), t_min=t_min, t_max=t_max)

    if population in env.netclamp_config.optimize_parameters[param_type]:
        opt_params = env.netclamp_config.optimize_parameters[param_type][population]
    else:
        raise RuntimeError(f'optimize_selectivity: population {population} does not have '
                           f'an optimization configuration')

    if target_state_variable is None:
        target_state_variable = 'v'
    init_params['target_features_arena'] = arena_id
    init_params['target_features_trajectory'] = trajectory_id
    opt_state_baseline = opt_params['Targets']['state'][target_state_variable]['baseline']
    init_params['state_baseline'] = opt_state_baseline
    init_params['state_variable'] = target_state_variable
    init_params['state_filter'] = target_state_filter

    init_objfun_name = 'init_selectivity_objfun'
    best = optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun_name,
                        problem_regime=problem_regime, n_epochs=n_epochs, n_initial=n_initial,
                        initial_maxiter=initial_maxiter, initial_method=initial_method,
                        optimizer_method=optimizer_method, surrogate_method=surrogate_method,
                        population_size=population_size, num_generations=num_generations,
                        resample_fraction=resample_fraction, mutation_rate=mutation_rate,
                        param_type=param_type, init_params=init_params, results_file=results_file,
                        nprocs_per_worker=nprocs_per_worker, cooperative_init=cooperative_init,
                        spawn_startup_wait=spawn_startup_wait, verbose=verbose)

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population],
                                           param_config_name, param_type)
    if best is not None:
        if results_path is not None:
            run_ts = time.strftime("%Y%m%d_%H%M%S")
            file_path = f'{results_path}/optimize_selectivity.{run_ts}.yaml'
            param_names = opt_param_config.param_names
            param_tuples = opt_param_config.param_tuples
            if ProblemRegime[problem_regime] == ProblemRegime.every:
                results_config_dict = {}
                for gid, prms in viewitems(best):
                    n_res = prms[0][1].shape[0]
                    prms_dict = dict(prms)
                    this_results_config_dict = {}
                    for i in range(n_res):
                        results_param_list = []
                        for param_pattern, param_tuple in zip(param_names, param_tuples):
                            results_param_list.append((param_tuple.population, param_tuple.source,
                                                       param_tuple.sec_type, param_tuple.syn_name,
                                                       param_tuple.param_path,
                                                       float(prms_dict[param_pattern][i])))
                        this_results_config_dict[i] = results_param_list
                    results_config_dict[gid] = this_results_config_dict
            else:
                prms = best[0]
                n_res = prms[0][1].shape[0]
                prms_dict = dict(prms)
                results_config_dict = {}
                for i in range(n_res):
                    results_param_list = []
                    for param_pattern, param_tuple in zip(param_names, param_tuples):
                        results_param_list.append((param_tuple.population, param_tuple.source,
                                                   param_tuple.sec_type, param_tuple.syn_name,
                                                   param_tuple.param_path,
                                                   float(prms_dict[param_pattern][i])))
                    results_config_dict[i] = results_param_list
            write_to_yaml(file_path, {population: results_config_dict})

    comm.barrier()
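# Example invocation of the selectivity optimization entry point (illustrative
# only: the exact CLI flags come from click decorators not shown in this
# excerpt, and the paths and values below are placeholders):
#
#   mpirun -n 8 python -m dentate.optimize_selectivity \
#       --config-file config.yaml --population GC --gid 118936 \
#       --arena-id A --trajectory-id Diag \
#       --param-config-name <param config> --results-path results/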
def optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun, problem_regime,
                 nprocs_per_worker=1, n_epochs=10, n_initial=30, initial_maxiter=50, initial_method="slh",
                 optimizer_method="nsga2", surrogate_method='vgp', population_size=200, num_generations=200,
                 resample_fraction=None, mutation_rate=None, param_type='synaptic', init_params={},
                 results_file=None, cooperative_init=False, spawn_startup_wait=None, verbose=False):
    """
    Construct the dmosopt problem description for the given population and
    parameter configuration, run the multi-objective optimization, and return
    the optimized parameters.
    """
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population],
                                           param_config_name, param_type)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    hyperprm_space = {param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
                      for param_pattern, param_tuple in zip(param_names, param_tuples)}

    if results_file is None:
        if env.results_path is not None:
            file_path = f'{env.results_path}/dmosopt.optimize_selectivity.{env.results_file_id}.h5'
        else:
            file_path = f'dmosopt.optimize_selectivity.{env.results_file_id}.h5'
    else:
        file_path = f'{env.results_path}/{results_file}'

    problem_ids = None
    reduce_fun_name = None
    if ProblemRegime[problem_regime] == ProblemRegime.every:
        reduce_fun_name = "opt_reduce_every"
        problem_ids = init_params.get('cell_index_set', None)
    elif ProblemRegime[problem_regime] == ProblemRegime.mean:
        reduce_fun_name = "opt_reduce_mean"
    elif ProblemRegime[problem_regime] == ProblemRegime.max:
        reduce_fun_name = "opt_reduce_max"
    else:
        raise RuntimeError(f'optimize_run: unknown problem regime {problem_regime}')

    n_trials = init_params.get('n_trials', 1)
    nworkers = env.comm.size - 1
    if resample_fraction is None:
        resample_fraction = float(nworkers) / float(population_size)
    if resample_fraction > 1.0:
        resample_fraction = 1.0
    if resample_fraction < 0.1:
        resample_fraction = 0.1

    objective_names = ['residual_infld', 'residual_state']
    feature_names = ['mean_peak_rate', 'mean_trough_rate', 'max_infld_rate', 'min_infld_rate',
                     'mean_infld_rate', 'mean_outfld_rate', 'mean_peak_state', 'mean_trough_state',
                     'mean_outfld_state']
    N_objectives = 2
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))
    constraint_names = ['positive_rate']

    dmosopt_params = {'opt_id': 'dentate.optimize_selectivity',
                      'problem_ids': problem_ids,
                      'obj_fun_init_name': init_objfun,
                      'obj_fun_init_module': 'dentate.optimize_selectivity',
                      'obj_fun_init_args': init_params,
                      'reduce_fun_name': reduce_fun_name,
                      'reduce_fun_module': 'dentate.optimization',
                      'problem_parameters': {},
                      'space': hyperprm_space,
                      'objective_names': objective_names,
                      'feature_dtypes': feature_dtypes,
                      'constraint_names': constraint_names,
                      'n_initial': n_initial,
                      'n_epochs': n_epochs,
                      'population_size': population_size,
                      'num_generations': num_generations,
                      'resample_fraction': resample_fraction,
                      'mutation_rate': mutation_rate,
                      'initial_maxiter': initial_maxiter,
                      'initial_method': initial_method,
                      'optimizer': optimizer_method,
                      'surrogate_method': surrogate_method,
                      'file_path': file_path,
                      'save': True,
                      'save_eval': 5}

    opt_results = dmosopt.run(dmosopt_params, verbose=verbose, collective_mode="sendrecv",
                              spawn_workers=True, nprocs_per_worker=nprocs_per_worker,
                              spawn_startup_wait=spawn_startup_wait)
    if opt_results is not None:
        if ProblemRegime[problem_regime] == ProblemRegime.every:
            gid_results_config_dict = {}
            for gid, opt_result in viewitems(opt_results):
                params_dict = dict(opt_result[0])
                result_value = opt_result[1]
                results_config_tuples = []
                for param_pattern, param_tuple in zip(param_names, param_tuples):
                    results_config_tuples.append((param_pattern, params_dict[param_pattern]))
                gid_results_config_dict[int(gid)] = results_config_tuples
                logger.info('Optimized parameters and objective function: '
                            f'{pprint.pformat(gid_results_config_dict)} @'
                            f'{result_value}')
            return gid_results_config_dict
        else:
            params_dict = dict(opt_results[0])
            result_value = opt_results[1]
            results_config_tuples = []
            for param_pattern, param_tuple in zip(param_names, param_tuples):
                results_config_tuples.append((param_pattern, params_dict[param_pattern]))
            logger.info('Optimized parameters and objective function: '
                        f'{pprint.pformat(results_config_tuples)} @'
                        f'{result_value}')
            return results_config_tuples
    else:
        return None
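# Shape of `opt_results` as consumed above (inferred from this code rather
# than from dmosopt documentation): under ProblemRegime.every, a mapping of
# gid -> ((name, value) parameter pairs, objective value); otherwise a single
# such pair, e.g. (parameter name and values illustrative only):
#
#   ([('GC.MPP.apical.AMPA.weight', array([1.3e-3]))], array([0.042, 0.011]))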
def init_selectivity_objfun(config_file, population, cell_index_set, arena_id, trajectory_id, n_trials,
                            trial_regime, problem_regime, generate_weights, t_max, t_min, template_paths,
                            dataset_prefix, config_prefix, results_path, spike_events_path,
                            spike_events_namespace, spike_events_t, input_features_path,
                            input_features_namespaces, param_type, param_config_name,
                            selectivity_config_name, recording_profile, state_variable, state_filter,
                            state_baseline, target_features_path, target_features_namespace,
                            target_features_arena, target_features_trajectory, use_coreneuron,
                            cooperative_init, dt, worker, **kwargs):
    """
    Initialize the selectivity objective function on a worker: set up the
    network clamp environment, derive target in-field/out-of-field rate
    segments, and return a per-gid evaluation function for dmosopt.
    """
    params = dict(locals())
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)

    my_cell_index_set = init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
                             spike_events_path, spike_events_namespace=spike_events_namespace,
                             spike_train_attr_name=spike_events_t,
                             input_features_path=input_features_path,
                             input_features_namespaces=input_features_namespaces,
                             generate_weights_pops=set(generate_weights), t_min=t_min, t_max=t_max,
                             cooperative_init=cooperative_init, worker=worker)

    time_step = float(env.stimulus_config['Temporal Resolution'])
    equilibration_duration = float(env.stimulus_config.get('Equilibration Duration', 0.))

    target_rate_vector_dict = rate_maps_from_features(env, population, cell_index_set=my_cell_index_set,
                                                      input_features_path=target_features_path,
                                                      input_features_namespace=target_features_namespace,
                                                      time_range=[0., t_max], arena_id=arena_id)
    logger.info(f'target_rate_vector_dict = {target_rate_vector_dict}')
    for gid, target_rate_vector in viewitems(target_rate_vector_dict):
        target_rate_vector[np.isclose(target_rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.

    trj_x, trj_y, trj_d, trj_t = stimulus.read_trajectory(
        input_features_path if input_features_path is not None else spike_events_path,
        target_features_arena, target_features_trajectory)
    time_range = (0., min(np.max(trj_t), t_max))
    time_bins = np.arange(time_range[0], time_range[1] + time_step, time_step)
    state_time_bins = np.arange(time_range[0], time_range[1], time_step)[:-1]

    def range_inds(rs):
        rs_list = list(rs)
        if len(rs_list) > 0:
            a = np.concatenate(rs_list)
        else:
            a = None
        return a

    def time_ranges(rs):
        if len(rs) > 0:
            a = tuple((time_bins[r[0]], time_bins[r[1] - 1]) for r in rs)
        else:
            a = None
        return a

    infld_idxs_dict = {gid: np.where(target_rate_vector > 1e-4)[0]
                       for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_pctile_dict = {gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 80)
                        for gid, infld_idxs in viewitems(infld_idxs_dict)}
    trough_pctile_dict = {gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 20)
                          for gid, infld_idxs in viewitems(infld_idxs_dict)}
    outfld_idxs_dict = {gid: range_inds(contiguous_ranges(target_rate_vector < 1e-4, return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_idxs_dict = {gid: range_inds(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid],
                                                        return_indices=True))
                      for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    trough_idxs_dict = {gid: range_inds(contiguous_ranges(np.logical_and(target_rate_vector > 0.,
                                                                         target_rate_vector <= trough_pctile_dict[gid]),
                                                          return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    outfld_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector <= 0.))
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    infld_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector > 0))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid]))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    trough_ranges_dict = {gid: time_ranges(contiguous_ranges(np.logical_and(target_rate_vector > 0.,
                                                                            target_rate_vector <= trough_pctile_dict[gid])))
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict)}

    large_fld_gids = []
    for gid in my_cell_index_set:
        infld_idxs = infld_idxs_dict[gid]
        target_rate_vector = target_rate_vector_dict[gid]
        target_infld_rate_vector = target_rate_vector[infld_idxs]
        target_peak_rate_vector = target_rate_vector[peak_idxs_dict[gid]]
        target_trough_rate_vector = target_rate_vector[trough_idxs_dict[gid]]
        logger.info(f'selectivity objective: target peak/trough rate of gid {gid}: '
                    f'{peak_pctile_dict[gid]:.02f} {trough_pctile_dict[gid]:.02f}')
        logger.info(f'selectivity objective: mean target peak/trough rate of gid {gid}: '
                    f'{np.mean(target_peak_rate_vector):.02f} {np.mean(target_trough_rate_vector):.02f}')

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population],
                                           param_config_name, param_type)
    selectivity_opt_param_config = selectivity_optimization_params(env.netclamp_config.optimize_parameters,
                                                                   [population], selectivity_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    N_objectives = 2
    feature_names = ['mean_peak_rate', 'mean_trough_rate', 'max_infld_rate', 'min_infld_rate',
                     'mean_infld_rate', 'mean_outfld_rate', 'mean_peak_state', 'mean_trough_state',
                     'mean_outfld_state']
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    def from_param_dict(params_dict):
        result = []
        for param_pattern, param_tuple in zip(param_names, param_tuples):
            result.append((param_tuple, params_dict[param_pattern]))
        return result

    def update_run_params(input_param_tuple_vals, update_param_names, update_param_tuples):
        result = []
        updated_set = set()
        update_param_dict = dict(zip(update_param_names, update_param_tuples))
        for param_pattern, (param_tuple, param_val) in zip(param_names, input_param_tuple_vals):
            if param_pattern in update_param_dict:
                updated_set.add(param_pattern)
                result.append((param_tuple, update_param_dict[param_pattern].param_range))
            else:
                result.append((param_tuple, param_val))
        for update_param_name in update_param_dict:
            if update_param_name not in updated_set:
                result.append((update_param_dict[update_param_name],
                               update_param_dict[update_param_name].param_range))
        return result

    def gid_firing_rate_vectors(spkdict, cell_index_set):
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in spkdict[population]:
                    spkdict1[gid] = spkdict[population][gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate(population, spkdict1, time_bins)
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]['rate']
                rate_vector[np.isclose(rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.
                rates_dict[gid].append(rate_vector)
                logger.info(f'selectivity objective: trial {i} firing rate min/max of gid {gid}: '
                            f'{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz')
        return rates_dict

    def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict):
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        filter_fun = None
        if state_filter == 'lowpass':
            filter_fun = lambda x, t: get_low_pass_filtered_trace(x, t)
        for gid in state_recs_dict:
            state_values = None
            state_recs = state_recs_dict[gid]
            assert len(state_recs) == 1
            rec = state_recs[0]
            vec = np.asarray(rec['vec'].to_python(), dtype=np.float32)
            if filter_fun is None:
                data = np.asarray([vec[t_inds] for t_inds in t_trial_inds])
            else:
                data = np.asarray([filter_fun(vec[t_inds], t_vec[t_inds]) for t_inds in t_trial_inds])
            state_values = []
            max_len = np.max(np.asarray([len(a) for a in data]))
            for state_value_array in data:
                this_len = len(state_value_array)
                if this_len < max_len:
                    a = np.pad(state_value_array, (0, max_len - this_len), 'edge')
                else:
                    a = state_value_array
                state_values.append(a)
            results_dict[gid] = state_values
        return t_vec[t_trial_inds[0]], results_dict

    def trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                            rate_vectors, masked_rate_vectors, target_rate_vector):
        n_trials = len(rate_vectors)
        residual_inflds = []
        trial_inflds = []
        trial_outflds = []
        target_infld = target_rate_vector[infld_idxs]
        target_max_infld = np.max(target_infld)
        target_mean_trough = np.mean(target_rate_vector[trough_idxs])
        logger.info(f'selectivity objective: target max infld/mean trough of gid {gid}: '
                    f'{target_max_infld:.02f} {target_mean_trough:.02f}')
        for trial_i in range(n_trials):
            rate_vector = rate_vectors[trial_i]
            infld_rate_vector = rate_vector[infld_idxs]
            masked_rate_vector = masked_rate_vectors[trial_i]
            if outfld_idxs is None:
                outfld_rate_vector = masked_rate_vector
            else:
                outfld_rate_vector = rate_vector[outfld_idxs]
            mean_peak = np.mean(rate_vector[peak_idxs])
            mean_trough = np.mean(rate_vector[trough_idxs])
            min_infld = np.min(infld_rate_vector)
            max_infld = np.max(infld_rate_vector)
            mean_infld = np.mean(infld_rate_vector)
            mean_outfld = np.mean(outfld_rate_vector)
            residual_infld = np.abs(np.sum(target_infld - infld_rate_vector))
            logger.info(f'selectivity objective: max infld/mean infld/mean peak/trough/mean outfld/'
                        f'residual_infld of gid {gid} trial {trial_i}: '
                        f'{max_infld:.02f} {mean_infld:.02f} {mean_peak:.02f} {mean_trough:.02f} '
                        f'{mean_outfld:.02f} {residual_infld:.04f}')
            residual_inflds.append(residual_infld)
            trial_inflds.append(mean_infld)
            trial_outflds.append(mean_outfld)
        trial_rate_features = [np.asarray(trial_inflds, dtype=np.float32).reshape((1, n_trials)),
                               np.asarray(trial_outflds, dtype=np.float32).reshape((1, n_trials))]
        rate_features = [mean_peak, mean_trough, max_infld, min_infld, mean_infld, mean_outfld]
        # rate_constr = [mean_peak if max_infld > 0. else -1.]
        rate_constr = [mean_peak - mean_trough if max_infld > 0. else -1.]
        return (np.asarray(residual_inflds), trial_rate_features, rate_features, rate_constr)

    def trial_state_residuals(gid, target_outfld, t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs,
                              state_values, masked_state_values):
        state_value_arrays = np.vstack(state_values)
        # masked_state_values is assumed to be non-None here; eval_problem below
        # always supplies masked state values for each recorded gid.
        masked_state_value_arrays = None
        if masked_state_values is not None:
            masked_state_value_arrays = np.vstack(masked_state_values)
        residuals_outfld = []
        peak_inflds = []
        trough_inflds = []
        mean_outflds = []
        for i in range(state_value_arrays.shape[0]):
            state_value_array = state_value_arrays[i, :]
            peak_infld = np.mean(state_value_array[t_peak_idxs])
            trough_infld = np.mean(state_value_array[t_trough_idxs])
            mean_infld = np.mean(state_value_array[t_infld_idxs])
            masked_state_value_array = masked_state_value_arrays[i, :]
            mean_masked = np.mean(masked_state_value_array)
            residual_masked = np.mean(masked_state_value_array) - target_outfld
            mean_outfld = mean_masked
            if t_outfld_idxs is not None:
                mean_outfld = np.mean(state_value_array[t_outfld_idxs])
            peak_inflds.append(peak_infld)
            trough_inflds.append(trough_infld)
            mean_outflds.append(mean_outfld)
            residuals_outfld.append(residual_masked)
            logger.info(f'selectivity objective: state values of gid {gid}: '
                        f'peak/trough/mean in/mean out/masked: {peak_infld:.02f} / {trough_infld:.02f} / '
                        f'{mean_infld:.02f} / {mean_outfld:.02f} / residual masked: {residual_masked:.04f}')
        state_features = [np.mean(peak_inflds), np.mean(trough_inflds), np.mean(mean_outflds)]
        return (np.asarray(residuals_outfld), state_features)

    recording_profile = {'label': f'optimize_selectivity.{state_variable}',
                         'section quantity': {state_variable: {'swc types': ['soma']}}}
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(env, population, gid, recording_profile=recording_profile)

    def eval_problem(cell_param_dict, **kwargs):
        run_params = {population: {gid: from_param_dict(cell_param_dict[gid])
                                   for gid in my_cell_index_set}}
        masked_state_values_dict = {}
        masked_run_params = {population: {gid: update_run_params(run_params[population][gid],
                                                                 selectivity_opt_param_config.mask_param_names,
                                                                 selectivity_opt_param_config.mask_param_tuples)
                                          for gid in my_cell_index_set}}
        spkdict = run_with(env, run_params)
        rates_dict = gid_firing_rate_vectors(spkdict, my_cell_index_set)
        t_s, state_values_dict = gid_state_values(spkdict, equilibration_duration, n_trials,
                                                  env.t_rec, state_recs_dict)
        masked_spkdict = run_with(env, masked_run_params)
        masked_rates_dict = gid_firing_rate_vectors(masked_spkdict, my_cell_index_set)
        t_s, masked_state_values_dict = gid_state_values(masked_spkdict, equilibration_duration, n_trials,
                                                         env.t_rec, state_recs_dict)

        result = {}
        for gid in my_cell_index_set:
            infld_idxs = infld_idxs_dict[gid]
            outfld_idxs = outfld_idxs_dict[gid]
            peak_idxs = peak_idxs_dict[gid]
            trough_idxs = trough_idxs_dict[gid]
            target_rate_vector = target_rate_vector_dict[gid]
            peak_ranges = peak_ranges_dict[gid]
            trough_ranges = trough_ranges_dict[gid]
            infld_ranges = infld_ranges_dict[gid]
            outfld_ranges = outfld_ranges_dict[gid]

            t_peak_idxs = np.concatenate([np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0]
                                          for r in peak_ranges])
            t_trough_idxs = np.concatenate([np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0]
                                            for r in trough_ranges])
            t_infld_idxs = np.concatenate([np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0]
                                           for r in infld_ranges])
            if outfld_ranges is not None:
                t_outfld_idxs = np.concatenate([np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0]
                                                for r in outfld_ranges])
            else:
                t_outfld_idxs = None

            masked_state_values = masked_state_values_dict.get(gid, None)
            state_values = state_values_dict[gid]
            rate_vectors = rates_dict[gid]
            masked_rate_vectors = masked_rates_dict[gid]
            logger.info(f'selectivity objective: max rates of gid {gid}: '
                        f'{list([np.max(rate_vector) for rate_vector in rate_vectors])}')

            infld_residuals, trial_rate_features, rate_features, rate_constr = \
                trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                                    rate_vectors, masked_rate_vectors, target_rate_vector)
            state_residuals, state_features = trial_state_residuals(gid, state_baseline, t_peak_idxs,
                                                                    t_trough_idxs, t_infld_idxs,
                                                                    t_outfld_idxs, state_values,
                                                                    masked_state_values)
            trial_obj_features = np.vstack((infld_residuals, state_residuals))

            if trial_regime == 'mean':
                mean_infld_residual = np.mean(infld_residuals)
                mean_state_residual = np.mean(state_residuals)
                infld_objective = mean_infld_residual
                state_objective = abs(mean_state_residual)
                logger.info(f'selectivity objective: mean infld/mean state residual of gid {gid}: '
                            f'{mean_infld_residual:.04f} {mean_state_residual:.04f}')
            elif trial_regime == 'best':
                min_infld_residual_index = np.argmin(infld_residuals)
                min_infld_residual = infld_residuals[min_infld_residual_index]
                infld_objective = min_infld_residual
                min_state_residual = np.min(np.abs(state_residuals))
                state_objective = min_state_residual
                logger.info(f'selectivity objective: min infld/min state residual of gid {gid}: '
                            f'{min_infld_residual:.04f} {min_state_residual:.04f}')
            else:
                raise RuntimeError(f'selectivity_rate_objective: unknown trial regime {trial_regime}')

            logger.info(f"rate_features: {rate_features} state_features: {state_features} "
                        f"obj_features: {trial_obj_features}")
            result[gid] = (np.asarray([infld_objective, state_objective], dtype=np.float32),
                           np.array([tuple(rate_features + state_features + [trial_obj_features] +
                                           trial_rate_features)],
                                    dtype=np.dtype(feature_dtypes)),
                           np.asarray(rate_constr, dtype=np.float32))
        return result

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
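# `eval_problem` returns, per gid, an (objectives, features, constraints)
# triple matching the `objective_names`, `feature_dtypes`, and
# `constraint_names` declared in optimize_run. `opt_eval_fun` (assumed here to
# come from dentate.optimization, alongside the opt_reduce_* functions) wraps
# it so that dmosopt sees a per-gid evaluator under ProblemRegime.every and a
# single aggregated evaluator under the mean/max regimes.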
def main(config_path, target_features_path, target_features_namespace, optimize_file_dir,
         optimize_file_name, nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method,
         optimizer_method, population_size, num_generations, resample_fraction, mutation_rate,
         collective_mode, spawn_startup_wait, verbose):
    """
    Optimize network parameters with dmosopt according to the operational
    configuration loaded from config_path.
    """
    network_args = click.get_current_context().args
    network_config = {}
    for arg in network_args:
        kv = arg.split("=", 1)
        if len(kv) > 1:
            k, v = kv
            network_config[k.replace('--', '').replace('-', '_')] = v
        else:
            k = kv[0]
            network_config[k.replace('--', '').replace('-', '_')] = True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')
    if optimize_file_name is None:
        optimize_file_name = f"dmosopt.optimize_network_{run_ts}.h5"

    operational_config = read_from_yaml(config_path)
    operational_config['run_ts'] = run_ts
    if target_features_path is not None:
        operational_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        operational_config['target_features_namespace'] = target_features_namespace

    network_config.update(operational_config.get('kwargs', {}))
    env = Env(**network_config)

    objective_names = operational_config['objective_names']
    param_config_name = operational_config['param_config_name']
    target_populations = operational_config['target_populations']
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, target_populations,
                                           param_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    hyperprm_space = {param_pattern: [param_tuple.param_range[0], param_tuple.param_range[1]]
                      for param_pattern, param_tuple in zip(param_names, param_tuples)}

    init_objfun = 'init_network_objfun'
    init_params = {'operational_config': operational_config,
                   'opt_targets': opt_targets,
                   'param_tuples': [param_tuple._asdict() for param_tuple in param_tuples],
                   'param_names': param_names}
    init_params.update(network_config.items())

    nworkers = env.comm.size - 1
    if resample_fraction is None:
        resample_fraction = float(nworkers) / float(population_size)
    if resample_fraction > 1.0:
        resample_fraction = 1.0
    if resample_fraction < 0.1:
        resample_fraction = 0.1

    # Create an optimizer
    feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names]
    constraint_names = [f'{target_pop_name} positive rate' for target_pop_name in target_populations]
    dmosopt_params = {'opt_id': 'dentate.optimize_network',
                      'obj_fun_init_name': init_objfun,
                      'obj_fun_init_module': 'dentate.optimize_network',
                      'obj_fun_init_args': init_params,
                      'reduce_fun_name': 'compute_objectives',
                      'reduce_fun_module': 'dentate.optimize_network',
                      'reduce_fun_args': (operational_config, opt_targets),
                      'problem_parameters': {},
                      'space': hyperprm_space,
                      'objective_names': objective_names,
                      'feature_dtypes': feature_dtypes,
                      'constraint_names': constraint_names,
                      'n_initial': n_initial,
                      'initial_maxiter': initial_maxiter,
                      'initial_method': initial_method,
                      'optimizer': optimizer_method,
                      'n_epochs': n_epochs,
                      'population_size': population_size,
                      'num_generations': num_generations,
                      'resample_fraction': resample_fraction,
                      'mutation_rate': mutation_rate,
                      'file_path': f'{optimize_file_dir}/{optimize_file_name}',
                      'termination_conditions': True,
                      'save_surrogate_eval': True,
                      'save': True,
                      'save_eval': 5}

    #dmosopt_params['broker_fun_name'] = 'dmosopt_broker_init'
    #dmosopt_params['broker_module_name'] = 'dentate.optimize_network'

    best = dmosopt.run(dmosopt_params, spawn_workers=True, sequential_spawn=False,
                       spawn_startup_wait=spawn_startup_wait, nprocs_per_worker=nprocs_per_worker,
                       collective_mode=collective_mode, verbose=True, worker_debug=True)
    if best is not None:
        if optimize_file_dir is not None:
            results_file_id = f'DG_optimize_network_{run_ts}'
            yaml_file_path = f'{optimize_file_dir}/optimize_network.{results_file_id}.yaml'
            prms = best[0]
            prms_dict = dict(prms)
            n_res = prms[0][1].shape[0]
            results_config_dict = {}
            for i in range(n_res):
                result_param_list = []
                for param_pattern, param_tuple in zip(param_names, param_tuples):
                    result_param_list.append([param_pattern, float(prms_dict[param_pattern][i])])
                results_config_dict[i] = result_param_list
            write_to_yaml(yaml_file_path, results_config_dict)
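# The YAML written above maps each non-dominated solution index to its
# parameter assignments; illustrative layout only (parameter names and values
# are hypothetical):
#
#   0:
#   - [GC.MPP.apical.AMPA.weight, 0.0013]
#   - [GC.LPP.apical.AMPA.weight, 0.0009]
#   1:
#   - [GC.MPP.apical.AMPA.weight, 0.0011]
#   - [GC.LPP.apical.AMPA.weight, 0.0010]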