def init_network_objfun(operational_config, opt_targets, param_names, param_tuples, worker, **kwargs):
    """Initialize the network and return a partially-applied ``network_objfun``.

    ``param_tuples`` entries are converted with ``syn_param_from_dict``.
    ``kwargs['results_file_id']`` is set from the worker id and
    ``operational_config['run_ts']`` before ``init_network`` is called with
    ``kwargs``.  Target rate maps are loaded only for populations whose
    ``'<pop> target rate dist residual'`` objective appears in
    ``operational_config['objective_names']``.

    Returns:
        ``partial(network_objfun, env, operational_config, opt_targets,
        target_trj_rate_map_dict, from_param_dict, t_start, t_stop,
        target_populations)``.
    """
    param_tuples = [syn_param_from_dict(spec) for spec in param_tuples]

    objective_names = operational_config['objective_names']
    target_populations = operational_config['target_populations']
    target_features_path = operational_config['target_features_path']
    target_features_namespace = operational_config['target_features_namespace']
    kwargs['results_file_id'] = 'DG_optimize_network_%d_%s' % (worker.worker_id,
                                                               operational_config['run_ts'])

    # NOTE(review): the logger is not used in this function.
    logger = utils.get_script_logger(os.path.basename(__file__))
    env = init_network(comm=MPI.COMM_WORLD, kwargs=kwargs)
    gc.collect()

    t_start = 50.
    t_stop = env.tstop
    time_range = (t_start, t_stop)

    target_trj_rate_map_dict = {}
    # NOTE(review): arena/trajectory ids are read but not used below.
    target_features_arena = env.arena_id
    target_features_trajectory = env.trajectory_id
    for pop_name in target_populations:
        # Only load target rate maps for populations with a rate-residual objective.
        if ('%s target rate dist residual' % pop_name) in objective_names:
            cell_gids = list(set(env.biophys_cells[pop_name].keys()))
            target_trj_rate_map_dict[pop_name] = rate_maps_from_features(
                env, pop_name,
                cell_index_set=cell_gids,
                input_features_path=target_features_path,
                input_features_namespace=target_features_namespace,
                time_range=time_range)

    def from_param_dict(params_dict):
        # Pair each SynParam with its value, looked up by the matching parameter name.
        return [(syn_param, params_dict[name])
                for name, syn_param in zip(param_names, param_tuples)]

    return partial(network_objfun, env, operational_config, opt_targets,
                   target_trj_rate_map_dict, from_param_dict, t_start, t_stop,
                   target_populations)
def init_selectivity_objfun(config_file, population, cell_index_set, arena_id, trajectory_id,
                            n_trials, trial_regime, problem_regime,
                            generate_weights, t_max, t_min,
                            template_paths, dataset_prefix, config_prefix, results_path,
                            spike_events_path, spike_events_namespace, spike_events_t,
                            input_features_path, input_features_namespaces,
                            param_type, param_config_name, selectivity_config_name,
                            recording_profile, state_variable, state_filter, state_baseline,
                            target_features_path, target_features_namespace,
                            target_features_arena, target_features_trajectory,
                            use_coreneuron, cooperative_init, dt, worker, **kwargs):
    """Build the selectivity objective function for the cells of ``population``.

    Steps:
      * create an ``Env`` from the call arguments and initialize the cells via ``init``;
      * load target rate maps and snap near-zero target rates to exactly 0;
      * derive per-gid index sets from each target rate vector:
        in-field (> 1e-4), out-of-field (< 1e-4), peak (>= 80th percentile of
        in-field), trough (in (0, 20th percentile of in-field]); and the
        corresponding time ranges over ``time_bins``;
      * record ``state_variable`` at the soma of every cell;
      * wrap ``eval_problem`` with ``opt_eval_fun``.

    ``eval_problem(cell_param_dict)`` runs the network twice: once with the
    candidate parameters, once with the selectivity mask parameters applied
    (``update_run_params``).  For each gid it returns
    ``(objectives, features, constraints)`` where ``objectives`` is
    ``[infld_objective, state_objective]`` (float32), ``features`` is a one-row
    structured array with dtype ``feature_dtypes``, and ``constraints`` is the
    rate constraint (float32).

    Args:
        trial_regime: ``'mean'`` or ``'best'``; how per-trial residuals are
            reduced to an objective.  Any other value raises ``RuntimeError``
            during evaluation.
        state_filter: ``'lowpass'`` applies ``get_low_pass_filtered_trace`` to
            the recorded state; any other value leaves it unfiltered.
        state_baseline: target value for the masked-run mean state.
        (All arguments, including ``kwargs``, are also forwarded to ``Env``.)

    Returns:
        The value of ``opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)``.
    """
    # Snapshot the call arguments before any other local is bound: they are
    # forwarded wholesale to the Env constructor.
    params = dict(locals())

    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)

    my_cell_index_set = init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
                             spike_events_path,
                             spike_events_namespace=spike_events_namespace,
                             spike_train_attr_name=spike_events_t,
                             input_features_path=input_features_path,
                             input_features_namespaces=input_features_namespaces,
                             generate_weights_pops=set(generate_weights),
                             t_min=t_min, t_max=t_max,
                             cooperative_init=cooperative_init,
                             worker=worker)

    time_step = float(env.stimulus_config['Temporal Resolution'])
    equilibration_duration = float(env.stimulus_config.get('Equilibration Duration', 0.))

    target_rate_vector_dict = rate_maps_from_features(env, population,
                                                      cell_index_set=my_cell_index_set,
                                                      input_features_path=target_features_path,
                                                      input_features_namespace=target_features_namespace,
                                                      time_range=[0., t_max],
                                                      arena_id=arena_id)

    logger.info(f'target_rate_vector_dict = {target_rate_vector_dict}')
    # Snap near-zero target rates to exactly zero, in place.
    for _, target_rate_vector in viewitems(target_rate_vector_dict):
        target_rate_vector[np.isclose(target_rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.

    _, _, _, trj_t = stimulus.read_trajectory(
        input_features_path if input_features_path is not None else spike_events_path,
        target_features_arena, target_features_trajectory)
    time_range = (0., min(np.max(trj_t), t_max))
    time_bins = np.arange(time_range[0], time_range[1] + time_step, time_step)

    def range_inds(rs):
        # Concatenate index arrays; None when there are none.
        l = list(rs)
        if len(l) > 0:
            return np.concatenate(l)
        return None

    def time_ranges(rs):
        # Map (start, end) index ranges to (time_bins[start], time_bins[end-1]); None when empty.
        if len(rs) > 0:
            return tuple((time_bins[r[0]], time_bins[r[1] - 1]) for r in rs)
        return None

    infld_idxs_dict = {gid: np.where(target_rate_vector > 1e-4)[0]
                       for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_pctile_dict = {gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 80)
                        for gid, infld_idxs in viewitems(infld_idxs_dict)}
    trough_pctile_dict = {gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 20)
                          for gid, infld_idxs in viewitems(infld_idxs_dict)}
    outfld_idxs_dict = {gid: range_inds(contiguous_ranges(target_rate_vector < 1e-4,
                                                          return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_idxs_dict = {gid: range_inds(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid],
                                                        return_indices=True))
                      for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    trough_idxs_dict = {gid: range_inds(contiguous_ranges(np.logical_and(target_rate_vector > 0.,
                                                                         target_rate_vector <= trough_pctile_dict[gid]),
                                                          return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}

    outfld_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector <= 0.))
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    infld_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector > 0))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    peak_ranges_dict = {gid: time_ranges(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid]))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict)}
    trough_ranges_dict = {gid: time_ranges(contiguous_ranges(np.logical_and(target_rate_vector > 0.,
                                                                            target_rate_vector <= trough_pctile_dict[gid])))
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict)}

    for gid in my_cell_index_set:
        # Look up this gid's own target vector (previously a stale loop variable was used).
        target_rate_vector = target_rate_vector_dict[gid]
        target_peak_rate_vector = target_rate_vector[peak_idxs_dict[gid]]
        target_trough_rate_vector = target_rate_vector[trough_idxs_dict[gid]]
        logger.info(f'selectivity objective: target peak/trough rate of gid {gid}: '
                    f'{peak_pctile_dict[gid]:.02f} {trough_pctile_dict[gid]:.02f}')
        logger.info(f'selectivity objective: mean target peak/trough rate of gid {gid}: '
                    f'{np.mean(target_peak_rate_vector):.02f} {np.mean(target_trough_rate_vector):.02f}')

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population],
                                           param_config_name, param_type)
    selectivity_opt_param_config = selectivity_optimization_params(env.netclamp_config.optimize_parameters,
                                                                   [population], selectivity_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    N_objectives = 2
    feature_names = ['mean_peak_rate', 'mean_trough_rate',
                     'max_infld_rate', 'min_infld_rate', 'mean_infld_rate', 'mean_outfld_rate',
                     'mean_peak_state', 'mean_trough_state', 'mean_outfld_state']
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    def from_param_dict(params_dict):
        # Pair each parameter tuple with its value in params_dict, in param_names order.
        return [(param_tuple, params_dict[param_pattern])
                for param_pattern, param_tuple in zip(param_names, param_tuples)]

    def update_run_params(input_param_tuple_vals, update_param_names, update_param_tuples):
        # Replace values of parameters named in update_param_names by their param_range;
        # update parameters not already present are appended.
        result = []
        updated_set = set()
        update_param_dict = dict(zip(update_param_names, update_param_tuples))
        for param_pattern, (param_tuple, param_val) in zip(param_names, input_param_tuple_vals):
            if param_pattern in update_param_dict:
                updated_set.add(param_pattern)
                result.append((param_tuple, update_param_dict[param_pattern].param_range))
            else:
                result.append((param_tuple, param_val))
        for update_param_name in update_param_dict:
            if update_param_name not in updated_set:
                result.append((update_param_dict[update_param_name],
                               update_param_dict[update_param_name].param_range))
        return result

    def gid_firing_rate_vectors(spkdict, cell_index_set):
        # Per-gid list (one entry per trial) of spike-density rate vectors over time_bins.
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in spkdict[population]:
                    spkdict1[gid] = spkdict[population][gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate(population, spkdict1, time_bins)
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]['rate']
                rate_vector[np.isclose(rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.
                rates_dict[gid].append(rate_vector)
                logger.info(f'selectivity objective: trial {i} firing rate min/max of gid {gid}: '
                            f'{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz')
        return rates_dict

    def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict):
        # Split each gid's single recording into per-trial segments, edge-padding
        # shorter segments to a common length. Returns (trial-0 times, {gid: segments}).
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        filter_fun = None
        if state_filter == 'lowpass':
            filter_fun = lambda x, t: get_low_pass_filtered_trace(x, t)
        for gid in state_recs_dict:
            state_recs = state_recs_dict[gid]
            assert(len(state_recs) == 1)
            rec = state_recs[0]
            vec = np.asarray(rec['vec'].to_python(), dtype=np.float32)
            # Keep segments as a list: they may differ in length, and np.asarray on
            # ragged sequences raises ValueError on NumPy >= 1.24.
            if filter_fun is None:
                data = [vec[t_inds] for t_inds in t_trial_inds]
            else:
                data = [filter_fun(vec[t_inds], t_vec[t_inds]) for t_inds in t_trial_inds]
            max_len = np.max(np.asarray([len(a) for a in data]))
            state_values = []
            for state_value_array in data:
                this_len = len(state_value_array)
                if this_len < max_len:
                    state_values.append(np.pad(state_value_array, (0, max_len - this_len), 'edge'))
                else:
                    state_values.append(state_value_array)
            results_dict[gid] = state_values
        return t_vec[t_trial_inds[0]], results_dict

    def trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                            rate_vectors, masked_rate_vectors, target_rate_vector):
        # Per-trial |sum(target_infld - infld)| residuals plus rate features/constraint.
        n_trials = len(rate_vectors)
        residual_inflds = []
        trial_inflds = []
        trial_outflds = []

        target_infld = target_rate_vector[infld_idxs]
        target_max_infld = np.max(target_infld)
        target_mean_trough = np.mean(target_rate_vector[trough_idxs])
        logger.info(f'selectivity objective: target max infld/mean trough of gid {gid}: '
                    f'{target_max_infld:.02f} {target_mean_trough:.02f}')
        for trial_i in range(n_trials):
            rate_vector = rate_vectors[trial_i]
            infld_rate_vector = rate_vector[infld_idxs]
            # Without out-of-field indices, the masked run's rate serves as out-of-field.
            if outfld_idxs is None:
                outfld_rate_vector = masked_rate_vectors[trial_i]
            else:
                outfld_rate_vector = rate_vector[outfld_idxs]

            mean_peak = np.mean(rate_vector[peak_idxs])
            mean_trough = np.mean(rate_vector[trough_idxs])
            min_infld = np.min(infld_rate_vector)
            max_infld = np.max(infld_rate_vector)
            mean_infld = np.mean(infld_rate_vector)
            mean_outfld = np.mean(outfld_rate_vector)

            residual_infld = np.abs(np.sum(target_infld - infld_rate_vector))
            logger.info(f'selectivity objective: max infld/mean infld/mean peak/trough/mean outfld/residual_infld of gid {gid} trial {trial_i}: '
                        f'{max_infld:.02f} {mean_infld:.02f} {mean_peak:.02f} {mean_trough:.02f} {mean_outfld:.02f} {residual_infld:.04f}')
            residual_inflds.append(residual_infld)
            trial_inflds.append(mean_infld)
            trial_outflds.append(mean_outfld)

        trial_rate_features = [np.asarray(trial_inflds, dtype=np.float32).reshape((1, n_trials)),
                               np.asarray(trial_outflds, dtype=np.float32).reshape((1, n_trials))]
        # NOTE(review): these use the LAST trial's values, whereas state features
        # average over trials -- confirm this is intended.
        rate_features = [mean_peak, mean_trough, max_infld, min_infld, mean_infld, mean_outfld]
        rate_constr = [mean_peak - mean_trough if max_infld > 0. else -1.]
        return (np.asarray(residual_inflds), trial_rate_features, rate_features, rate_constr)

    def trial_state_residuals(gid, target_outfld, t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs,
                              state_values, masked_state_values):
        # Per-trial residual of the masked-run mean state against target_outfld,
        # plus trial-averaged peak/trough/out-of-field state features.
        state_value_arrays = np.vstack(state_values)
        masked_state_value_arrays = None
        if masked_state_values is not None:
            masked_state_value_arrays = np.vstack(masked_state_values)
        # NOTE(review): the loop below indexes masked_state_value_arrays unconditionally,
        # so masked_state_values=None still fails here.

        residuals_outfld = []
        peak_inflds = []
        trough_inflds = []
        mean_outflds = []
        for i in range(state_value_arrays.shape[0]):
            state_value_array = state_value_arrays[i, :]
            peak_infld = np.mean(state_value_array[t_peak_idxs])
            trough_infld = np.mean(state_value_array[t_trough_idxs])
            mean_infld = np.mean(state_value_array[t_infld_idxs])

            masked_state_value_array = masked_state_value_arrays[i, :]
            mean_masked = np.mean(masked_state_value_array)
            residual_masked = mean_masked - target_outfld

            mean_outfld = mean_masked
            if t_outfld_idxs is not None:
                mean_outfld = np.mean(state_value_array[t_outfld_idxs])

            peak_inflds.append(peak_infld)
            trough_inflds.append(trough_infld)
            mean_outflds.append(mean_outfld)
            residuals_outfld.append(residual_masked)
            logger.info(f'selectivity objective: state values of gid {gid}: '
                        f'peak/trough/mean in/mean out/masked: {peak_infld:.02f} / {trough_infld:.02f} / {mean_infld:.02f} / {mean_outfld:.02f} / residual masked: {residual_masked:.04f}')

        state_features = [np.mean(peak_inflds), np.mean(trough_inflds), np.mean(mean_outflds)]
        return (np.asarray(residuals_outfld), state_features)

    def state_time_idxs(t_s, ranges):
        # Indices of sample times t_s inside any half-open [start, end) range.
        return np.concatenate([np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0]
                               for r in ranges])

    # Note: this intentionally replaces the ``recording_profile`` argument
    # (the argument value was already captured in ``params`` above).
    recording_profile = {'label': f'optimize_selectivity.{state_variable}',
                         'section quantity': {
                             state_variable: {'swc types': ['soma']}
                         }}
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(env, population, gid, recording_profile=recording_profile)

    def eval_problem(cell_param_dict, **kwargs):
        run_params = {population: {gid: from_param_dict(cell_param_dict[gid])
                                   for gid in my_cell_index_set}}
        masked_run_params = {population: {gid: update_run_params(run_params[population][gid],
                                                                 selectivity_opt_param_config.mask_param_names,
                                                                 selectivity_opt_param_config.mask_param_tuples)
                                          for gid in my_cell_index_set}}

        spkdict = run_with(env, run_params)
        rates_dict = gid_firing_rate_vectors(spkdict, my_cell_index_set)
        t_s, state_values_dict = gid_state_values(spkdict, equilibration_duration, n_trials,
                                                  env.t_rec, state_recs_dict)

        masked_spkdict = run_with(env, masked_run_params)
        masked_rates_dict = gid_firing_rate_vectors(masked_spkdict, my_cell_index_set)
        t_s, masked_state_values_dict = gid_state_values(masked_spkdict, equilibration_duration, n_trials,
                                                         env.t_rec, state_recs_dict)

        result = {}
        for gid in my_cell_index_set:
            infld_idxs = infld_idxs_dict[gid]
            outfld_idxs = outfld_idxs_dict[gid]
            peak_idxs = peak_idxs_dict[gid]
            trough_idxs = trough_idxs_dict[gid]
            target_rate_vector = target_rate_vector_dict[gid]

            outfld_ranges = outfld_ranges_dict[gid]
            t_peak_idxs = state_time_idxs(t_s, peak_ranges_dict[gid])
            t_trough_idxs = state_time_idxs(t_s, trough_ranges_dict[gid])
            t_infld_idxs = state_time_idxs(t_s, infld_ranges_dict[gid])
            t_outfld_idxs = None if outfld_ranges is None else state_time_idxs(t_s, outfld_ranges)

            masked_state_values = masked_state_values_dict.get(gid, None)
            state_values = state_values_dict[gid]
            rate_vectors = rates_dict[gid]
            masked_rate_vectors = masked_rates_dict[gid]

            logger.info(f'selectivity objective: max rates of gid {gid}: '
                        f'{list([np.max(rate_vector) for rate_vector in rate_vectors])}')

            infld_residuals, trial_rate_features, rate_features, rate_constr = \
                trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                                    rate_vectors, masked_rate_vectors, target_rate_vector)
            state_residuals, state_features = trial_state_residuals(gid, state_baseline,
                                                                    t_peak_idxs, t_trough_idxs,
                                                                    t_infld_idxs, t_outfld_idxs,
                                                                    state_values, masked_state_values)
            trial_obj_features = np.vstack((infld_residuals, state_residuals))

            if trial_regime == 'mean':
                mean_infld_residual = np.mean(infld_residuals)
                mean_state_residual = np.mean(state_residuals)
                infld_objective = mean_infld_residual
                state_objective = abs(mean_state_residual)
                logger.info(f'selectivity objective: mean peak/trough/mean infld/mean outfld/mean state residual of gid {gid}: '
                            f'{mean_infld_residual:.04f} {mean_state_residual:.04f}')
            elif trial_regime == 'best':
                min_infld_residual_index = np.argmin(infld_residuals)
                # Fixed: previously indexed with an undefined name (NameError).
                min_infld_residual = infld_residuals[min_infld_residual_index]
                infld_objective = min_infld_residual
                min_state_residual = np.min(np.abs(state_residuals))
                state_objective = min_state_residual
                logger.info(f'selectivity objective: mean peak/trough/max infld/max outfld/min state residual of gid {gid}: '
                            f'{min_infld_residual:.04f} {min_state_residual:.04f}')
            else:
                raise RuntimeError(f'selectivity_rate_objective: unknown trial regime {trial_regime}')

            logger.info(f"rate_features: {rate_features} state_features: {state_features} obj_features: {trial_obj_features}")

            result[gid] = (np.asarray([infld_objective, state_objective], dtype=np.float32),
                           np.array([tuple(rate_features + state_features + [trial_obj_features] + trial_rate_features)],
                                    dtype=np.dtype(feature_dtypes)),
                           np.asarray(rate_constr, dtype=np.float32))

        return result

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
def main(config_path, params_id, n_samples, target_features_path, target_features_namespace,
         output_file_dir, output_file_name, verbose):
    """Evaluate network features for a parameter set, or generate a parameter lattice.

    If ``params_id`` is None and ``n_samples`` is given, only
    ``generate_param_lattice`` is run and the function returns.  Otherwise the
    YAML config at ``config_path`` is read, extra ``--key=value`` / ``--flag``
    arguments from the current click context are merged into the network
    config, the network is initialized, target rate maps are loaded for
    populations whose ``'<pop> target rate dist residual'`` feature is
    requested, and ``eval_network`` is called.

    Args:
        target_features_path / target_features_namespace: when not None they
            override the corresponding entries of the YAML config.
        output_file_name: defaults to ``network_features_<timestamp>.h5``.
    """
    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    if params_id is None:
        if n_samples is not None:
            logger.info("Generating parameter lattice ...")
            generate_param_lattice(config_path, n_samples, output_file_dir, output_file_name, verbose)
            return

    # Extra command-line args: '--some-key=value' -> {'some_key': 'value'}, '--flag' -> {'flag': True}.
    network_config = {}
    for arg in click.get_current_context().args:
        # Split on the first '=' only so that values may themselves contain '='.
        kv = arg.split("=", 1)
        key = kv[0].replace('--', '').replace('-', '_')
        network_config[key] = kv[1] if len(kv) > 1 else True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')

    if output_file_name is None:
        output_file_name = f"network_features_{run_ts}.h5"
    output_path = f'{output_file_dir}/{output_file_name}'

    eval_config = read_from_yaml(config_path, include_loader=utils.IncludeLoader)
    eval_config['run_ts'] = run_ts
    if target_features_path is not None:
        eval_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        eval_config['target_features_namespace'] = target_features_namespace
    # Use the merged values (CLI override, else YAML); previously the raw CLI
    # arguments were used even when None, ignoring the YAML entries.
    target_features_path = eval_config.get('target_features_path')
    target_features_namespace = eval_config.get('target_features_namespace')

    network_param_spec_src = eval_config['param_spec']
    network_param_values = eval_config['param_values']
    feature_names = eval_config['feature_names']
    target_populations = eval_config['target_populations']

    network_config.update(eval_config.get('kwargs', {}))

    if params_id is None:
        network_config['results_file_id'] = f"DG_eval_network_{eval_config['run_ts']}"
    else:
        network_config['results_file_id'] = f"DG_eval_network_{params_id}_{eval_config['run_ts']}"

    env = init_network(comm=MPI.COMM_WORLD, kwargs=network_config)
    gc.collect()

    t_start = 50.
    t_stop = env.tstop
    time_range = (t_start, t_stop)

    target_trj_rate_map_dict = {}
    for pop_name in target_populations:
        if ('%s target rate dist residual' % pop_name) not in feature_names:
            continue
        my_cell_index_set = set(env.biophys_cells[pop_name].keys())
        target_trj_rate_map_dict[pop_name] = rate_maps_from_features(
            env, pop_name,
            cell_index_set=list(my_cell_index_set),
            input_features_path=target_features_path,
            input_features_namespace=target_features_namespace,
            time_range=time_range)

    network_param_spec = make_param_spec(target_populations, network_param_spec_src)

    def from_param_list(x):
        # Each entry: (population, source, sec_type, syn_name, param_path, value).
        result = []
        for pop_param in x:
            this_population, source, sec_type, syn_name, param_path, param_val = pop_param
            param_tuple = SynParam(this_population, source, sec_type, syn_name, param_path, None)
            result.append((param_tuple, param_val))
        return result

    def from_param_dict(x):
        # Flatten {pop: {source: {sec_type: {syn_name: {param: value | {const: value}}}}}}
        # into (SynParam, value) pairs, matching the shape returned by from_param_list.
        result = []
        keyfun = lambda kv: str(kv[0])
        for pop_name, param_specs in viewitems(x):
            for source, source_dict in sorted(viewitems(param_specs), key=keyfun):
                for sec_type, sec_type_dict in sorted(viewitems(source_dict), key=keyfun):
                    for syn_name, syn_mech_dict in sorted(viewitems(sec_type_dict), key=keyfun):
                        for param_fst, param_rst in sorted(viewitems(syn_mech_dict), key=keyfun):
                            if isinstance(param_rst, dict):
                                for const_name, const_value in sorted(viewitems(param_rst)):
                                    param_path = (param_fst, const_name)
                                    param_tuple = SynParam(pop_name, source, sec_type, syn_name,
                                                           param_path, const_value)
                                    # Fixed: list.append takes a single (tuple) argument.
                                    result.append((param_tuple, const_value))
                            else:
                                param_name = param_fst
                                param_value = param_rst
                                param_tuple = SynParam(pop_name, source, sec_type, syn_name,
                                                       param_name, param_value)
                                result.append((param_tuple, param_value))
        return result

    eval_network(env, network_config, from_param_list, from_param_dict,
                 network_param_spec, network_param_values, params_id,
                 target_trj_rate_map_dict, t_start, t_stop,
                 target_populations, output_path)
def config_worker():
    """Prepare the optimization ``context`` on a worker.

    Defaults ``context.debug`` to False, configures logging and
    ``context.logger``, sets ``context.results_file_id`` if absent, calls
    ``init_network()``, normalizes ``context.t_start`` (default 50.) and
    ``context.t_stop`` (default ``context.env.tstop``), and fills
    ``context.target_trj_rate_map_dict`` with the non-empty target rate maps
    of each population in ``context.target_populations``.

    Exceptions from ``init_network()`` are logged via ``context.logger`` and re-raised.
    """
    if 'debug' not in context():
        context.debug = False

    if context.debug and context.comm.rank == 1:
        print('# of parameters: %i' % len(context.param_names))
        print('param_names: ', context.param_names)
        print('target_val: ', context.target_val)
        print('target_range: ', context.target_range)
        print('param_tuples: ', context.param_tuples)
        sys.stdout.flush()

    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))

    # TODO: Do you want this to be identical on all ranks in a subworld? You can use context.comm.bcast
    if 'results_file_id' not in context():
        worker_id = context.interface.worker_id
        stamp = datetime.datetime.today().strftime('%Y%m%d_%H%M')
        context.results_file_id = 'DG_optimize_network_subworlds_%s_%s' % (worker_id, stamp)

    # 'env' might be in context on controller, but it needs to be re-built when
    # the controller is in a worker subworld.
    # NOTE(review): init_network() is called without arguments and its return value
    # is ignored; context.env is assumed to be set by it -- confirm.
    try:
        if context.debug:
            rank_info = (context.comm.rank, context.comm.size,
                         context.global_comm.rank, context.global_comm.size)
            print('debug: config_worker; local_comm.rank: %i/%i; global_comm.rank: %i/%i' % rank_info)
            sys.stdout.flush()
        init_network()
    except Exception as err:
        context.logger.exception(err)
        raise err

    context.t_start = float(context.t_start) if 't_start' in context() else 50.
    context.t_stop = float(context.t_stop) if 't_stop' in context() else context.env.tstop
    time_range = (context.t_start, context.t_stop)

    try:
        if context.debug and context.global_comm.rank == 0:
            print('t_start: %.1f, t_stop: %.1f' % (context.t_start, context.t_stop))
    except Exception as err:
        context.logger.exception(err)
        raise err

    context.target_trj_rate_map_dict = {}
    rate_map_path = context.target_rate_map_path
    rate_map_namespace = context.target_rate_map_namespace
    # NOTE(review): arena/trajectory ids are read but not used below.
    rate_map_arena = context.env.arena_id
    rate_map_trajectory = context.env.trajectory_id
    for pop_name in context.target_populations:
        cell_gids = list(set(context.env.biophys_cells[pop_name].keys()))
        pop_rate_maps = rate_maps_from_features(context.env, pop_name,
                                                cell_index_set=cell_gids,
                                                input_features_path=rate_map_path,
                                                input_features_namespace=rate_map_namespace,
                                                time_range=time_range)
        # Only keep populations for which some rate maps were found.
        if len(pop_rate_maps) > 0:
            context.target_trj_rate_map_dict[pop_name] = pop_rate_maps