def init_network_objfun(operational_config, opt_targets, param_names,
                        param_tuples, worker, **kwargs):
    """Initialize the network and return its objective function.

    :param operational_config: mapping; the keys read here are
        'objective_names', 'target_populations', 'target_features_path',
        'target_features_namespace' and 'run_ts'.
    :param opt_targets: passed through unchanged to ``network_objfun``.
    :param param_names: parameter names, paired positionally with
        ``param_tuples``; used as keys into the dict given to the
        ``from_param_dict`` closure.
    :param param_tuples: iterable of dicts, each converted with
        ``syn_param_from_dict``.
    :param worker: object whose ``worker_id`` becomes part of the results
        file id.
    :param kwargs: forwarded to ``init_network`` as ``kwargs=``; its
        'results_file_id' entry is overwritten here.
    :return: ``functools.partial`` of ``network_objfun`` pre-bound to
        (env, operational_config, opt_targets, target_trj_rate_map_dict,
        from_param_dict, t_start, t_stop, target_populations).
    """
    param_tuples = [
        syn_param_from_dict(param_tuple) for param_tuple in param_tuples
    ]

    objective_names = operational_config['objective_names']
    target_populations = operational_config['target_populations']
    target_features_path = operational_config['target_features_path']
    target_features_namespace = operational_config['target_features_namespace']
    kwargs['results_file_id'] = 'DG_optimize_network_%d_%s' % \
                                (worker.worker_id, operational_config['run_ts'])

    # NOTE(review): the returned logger is not used in this function; the
    # call is kept in case the helper configures logging as a side effect --
    # confirm and remove if it does not.
    logger = utils.get_script_logger(os.path.basename(__file__))
    env = init_network(comm=MPI.COMM_WORLD, kwargs=kwargs)
    gc.collect()

    # Start of the evaluated time window (units follow env.tstop -- TODO confirm).
    t_start = 50.
    t_stop = env.tstop
    time_range = (t_start, t_stop)

    # Target rate maps are only loaded for populations that actually have a
    # '<pop> target rate dist residual' objective configured.
    target_trj_rate_map_dict = {}
    for pop_name in target_populations:
        if ('%s target rate dist residual' % pop_name) not in objective_names:
            continue
        my_cell_index_set = set(env.biophys_cells[pop_name].keys())
        target_trj_rate_map_dict[pop_name] = rate_maps_from_features(
            env,
            pop_name,
            cell_index_set=list(my_cell_index_set),
            input_features_path=target_features_path,
            input_features_namespace=target_features_namespace,
            time_range=time_range)

    def from_param_dict(params_dict):
        """Pair each parameter tuple with its value from ``params_dict``."""
        return [(param_tuple, params_dict[param_name])
                for param_name, param_tuple in zip(param_names, param_tuples)]

    return partial(network_objfun, env, operational_config, opt_targets,
                   target_trj_rate_map_dict, from_param_dict, t_start, t_stop,
                   target_populations)
def init_selectivity_objfun(config_file, population, cell_index_set, arena_id, trajectory_id,
                            n_trials, trial_regime, problem_regime,
                            generate_weights, t_max, t_min,
                            template_paths, dataset_prefix, config_prefix, results_path,
                            spike_events_path, spike_events_namespace, spike_events_t,
                            input_features_path, input_features_namespaces,
                            param_type, param_config_name, selectivity_config_name, recording_profile,
                            state_variable, state_filter, state_baseline,
                            target_features_path, target_features_namespace,
                            target_features_arena, target_features_trajectory,
                            use_coreneuron, cooperative_init, dt, worker, **kwargs):
    """Build the per-cell selectivity objective evaluator for ``population``.

    Creates an ``Env`` from this function's own arguments, initializes the
    selected cells with ``init``, derives target rate vectors plus the
    in-field / out-of-field / peak / trough index sets from the target
    features, attaches one ``state_variable`` recording per cell, and wraps
    the nested ``eval_problem`` closure with ``opt_eval_fun``.

    For every gid, ``eval_problem`` produces a 3-tuple:

    * float32 array ``[infld_objective, state_objective]``
    * one-element structured array with the fields of ``feature_dtypes``
    * float32 array holding the single rate constraint

    All arguments (including ``worker`` and the ``kwargs`` dict, under the
    key 'kwargs') are forwarded to ``Env`` as keyword arguments. Arguments
    not listed below are only used that way.

    :param population: name of the population being optimized.
    :param cell_index_set: cells requested; ``init`` returns the set
        actually used on this rank.
    :param arena_id: passed to ``init`` and ``rate_maps_from_features``.
    :param trajectory_id: passed to ``init``.
    :param n_trials: number of trials per evaluation.
    :param trial_regime: 'mean' or 'best'; selects how per-trial residuals
        are reduced to objectives.
    :param problem_regime: passed to ``opt_eval_fun``.
    :param generate_weights: iterable converted to a set and passed to
        ``init`` as ``generate_weights_pops``.
    :param t_max: upper bound of the evaluated time range.
    :param t_min: passed to ``init``.
    :param spike_events_path: passed to ``init``; also the trajectory source
        when ``input_features_path`` is None.
    :param param_type: passed to ``optimization_params``.
    :param param_config_name: passed to ``optimization_params``.
    :param selectivity_config_name: passed to
        ``selectivity_optimization_params``.
    :param recording_profile: forwarded to ``Env`` only; the local name is
        rebound below to a generated profile for ``state_variable``, which
        is also assigned to ``env.recording_profile``.
    :param state_variable: quantity recorded on 'soma' sections.
    :param state_filter: 'lowpass' applies ``get_low_pass_filtered_trace``
        to each trial trace; any other value leaves traces unfiltered.
    :param state_baseline: target value subtracted from the mean masked
        state to form the state residual.
    :param target_features_path: source of target rate maps.
    :param target_features_namespace: namespace of target rate maps.
    :param target_features_arena: arena passed to ``read_trajectory``.
    :param target_features_trajectory: trajectory passed to
        ``read_trajectory``.
    :return: result of ``opt_eval_fun(problem_regime, my_cell_index_set,
        eval_problem)``.
    :raises RuntimeError: from ``eval_problem`` when ``trial_regime`` is
        neither 'mean' nor 'best'.
    """
    # Must stay the first statement: locals() is snapshotted before any other
    # local name exists, so exactly the call arguments are forwarded to Env.
    params = dict(locals())
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)

    my_cell_index_set = init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
                             spike_events_path, spike_events_namespace=spike_events_namespace,
                             spike_train_attr_name=spike_events_t,
                             input_features_path=input_features_path,
                             input_features_namespaces=input_features_namespaces,
                             generate_weights_pops=set(generate_weights),
                             t_min=t_min, t_max=t_max, cooperative_init=cooperative_init,
                             worker=worker)

    time_step = float(env.stimulus_config['Temporal Resolution'])
    equilibration_duration = float(env.stimulus_config.get('Equilibration Duration', 0.))

    target_rate_vector_dict = rate_maps_from_features(env, population,
                                                      cell_index_set=my_cell_index_set,
                                                      input_features_path=target_features_path,
                                                      input_features_namespace=target_features_namespace,
                                                      time_range=[0., t_max],
                                                      arena_id=arena_id)

    logger.info(f'target_rate_vector_dict = {target_rate_vector_dict}')
    # Snap near-zero target rates to exactly zero (in place) so that the
    # `> 0.` / `<= 0.` masks below split the vector cleanly.
    for gid, target_rate_vector in viewitems(target_rate_vector_dict):
        target_rate_vector[np.isclose(target_rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.

    trj_x, trj_y, trj_d, trj_t = stimulus.read_trajectory(input_features_path if input_features_path is not None else spike_events_path,
                                                          target_features_arena, target_features_trajectory)
    time_range = (0., min(np.max(trj_t), t_max))
    time_bins = np.arange(time_range[0], time_range[1]+time_step, time_step)

    def range_inds(rs):
        """Concatenate an iterable of index arrays; None when it is empty."""
        l = list(rs)
        if len(l) > 0:
            a = np.concatenate(l)
        else:
            a = None
        return a

    def time_ranges(rs):
        """Map index ranges onto (start, end) times in ``time_bins``; None when empty."""
        if len(rs) > 0:
            a = tuple( ( (time_bins[r[0]], time_bins[r[1]-1]) for r in rs ) )
        else:
            a = None
        return a

    # Per-gid index sets over the target rate vector.
    infld_idxs_dict = { gid: np.where(target_rate_vector > 1e-4)[0]
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    peak_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 80)
                         for gid, infld_idxs in viewitems(infld_idxs_dict) }
    trough_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 20)
                           for gid, infld_idxs in viewitems(infld_idxs_dict) }
    outfld_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector < 1e-4, return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid], return_indices=True))
                       for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_idxs_dict = { gid: range_inds(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid]), return_indices=True))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    # Per-gid time ranges, used later to select samples of the state traces.
    outfld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector <= 0.) )
                           for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    infld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector > 0) )
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid]))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_ranges_dict = { gid: time_ranges(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid])))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    for gid in my_cell_index_set:
        # Look the vector up per gid: the name `target_rate_vector` left over
        # from the loop above would be the last gid's vector for every cell.
        target_rate_vector = target_rate_vector_dict[gid]
        target_peak_rate_vector = target_rate_vector[peak_idxs_dict[gid]]
        target_trough_rate_vector = target_rate_vector[trough_idxs_dict[gid]]

        logger.info(f'selectivity objective: target peak/trough rate of gid {gid}: '
                    f'{peak_pctile_dict[gid]:.02f} {trough_pctile_dict[gid]:.02f}')
        logger.info(f'selectivity objective: mean target peak/trough rate of gid {gid}: '
                    f'{np.mean(target_peak_rate_vector):.02f} {np.mean(target_trough_rate_vector):.02f}')

    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    selectivity_opt_param_config = selectivity_optimization_params(env.netclamp_config.optimize_parameters, [population],
                                                                   selectivity_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    N_objectives = 2
    feature_names = ['mean_peak_rate', 'mean_trough_rate',
                     'max_infld_rate', 'min_infld_rate', 'mean_infld_rate', 'mean_outfld_rate',
                     'mean_peak_state', 'mean_trough_state', 'mean_outfld_state']
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    def from_param_dict(params_dict):
        """Pair each parameter tuple with its value from ``params_dict``."""
        result = []
        for param_pattern, param_tuple in zip(param_names, param_tuples):
            result.append((param_tuple, params_dict[param_pattern]))

        return result

    def update_run_params(input_param_tuple_vals, update_param_names, update_param_tuples):
        """Return a copy of the run parameters with the update parameters applied.

        Entries whose name appears in ``update_param_names`` take the
        ``param_range`` of the matching update tuple as their value; update
        parameters that match no existing entry are appended.
        """
        result = []
        updated_set = set()
        update_param_dict = dict(zip(update_param_names, update_param_tuples))
        for param_pattern, (param_tuple, param_val) in zip(param_names, input_param_tuple_vals):
            if param_pattern in update_param_dict:
                updated_set.add(param_pattern)
                result.append((param_tuple, update_param_dict[param_pattern].param_range))
            else:
                result.append((param_tuple, param_val))
        for update_param_name in update_param_dict:
            if update_param_name not in updated_set:
                result.append((update_param_dict[update_param_name],
                               update_param_dict[update_param_name].param_range))

        return result

    def gid_firing_rate_vectors(spkdict, cell_index_set):
        """Return {gid: [rate vector per trial]} estimated over ``time_bins``."""
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in spkdict[population]:
                    spkdict1[gid] = spkdict[population][gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate(population, spkdict1, time_bins)
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]['rate']
                rate_vector[np.isclose(rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.
                rates_dict[gid].append(rate_vector)
                # NOTE(review): min/max here span all trials accumulated so
                # far for this gid, not only trial i.
                logger.info(f'selectivity objective: trial {i} firing rate min/max of gid {gid}: '
                            f'{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz')

        return rates_dict

    def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict):
        """Return (times of the first trial, {gid: [state trace per trial]}).

        Traces shorter than the longest trial are edge-padded so that all
        trials of a gid have the same length.
        """
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        filter_fun = None
        if state_filter == 'lowpass':
            filter_fun = lambda x, t: get_low_pass_filtered_trace(x, t)
        for gid in state_recs_dict:
            state_recs = state_recs_dict[gid]
            assert(len(state_recs) == 1)
            rec = state_recs[0]
            vec = np.asarray(rec['vec'].to_python(), dtype=np.float32)
            if filter_fun is None:
                data = np.asarray([ vec[t_inds] for t_inds in t_trial_inds ])
            else:
                data = np.asarray([ filter_fun(vec[t_inds], t_vec[t_inds])
                                    for t_inds in t_trial_inds ])

            state_values = []
            max_len = np.max(np.asarray([len(a) for a in data]))
            for state_value_array in data:
                this_len = len(state_value_array)
                if this_len < max_len:
                    a = np.pad(state_value_array, (0, max_len-this_len), 'edge')
                else:
                    a = state_value_array
                state_values.append(a)

            results_dict[gid] = state_values
        return t_vec[t_trial_inds[0]], results_dict

    def trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                            rate_vectors, masked_rate_vectors, target_rate_vector):
        """Compute per-trial in-field rate residuals and rate features for one gid.

        Returns (residuals array, [per-trial mean in-field rates, per-trial
        mean out-of-field rates], rate feature list, rate constraint list).
        """
        n_trials = len(rate_vectors)
        residual_inflds = []
        trial_inflds = []
        trial_outflds = []

        target_infld = target_rate_vector[infld_idxs]
        target_max_infld = np.max(target_infld)
        target_mean_trough = np.mean(target_rate_vector[trough_idxs])
        logger.info(f'selectivity objective: target max infld/mean trough of gid {gid}: '
                    f'{target_max_infld:.02f} {target_mean_trough:.02f}')
        for trial_i in range(n_trials):

            rate_vector = rate_vectors[trial_i]
            infld_rate_vector = rate_vector[infld_idxs]
            masked_rate_vector = masked_rate_vectors[trial_i]
            # Without any out-of-field bins, the masked run stands in for
            # the out-of-field rate.
            if outfld_idxs is None:
                outfld_rate_vector = masked_rate_vector
            else:
                outfld_rate_vector = rate_vector[outfld_idxs]

            mean_peak = np.mean(rate_vector[peak_idxs])
            mean_trough = np.mean(rate_vector[trough_idxs])
            min_infld = np.min(infld_rate_vector)
            max_infld = np.max(infld_rate_vector)
            mean_infld = np.mean(infld_rate_vector)
            mean_outfld = np.mean(outfld_rate_vector)

            residual_infld = np.abs(np.sum(target_infld - infld_rate_vector))
            logger.info(f'selectivity objective: max infld/mean infld/mean peak/trough/mean outfld/residual_infld of gid {gid} trial {trial_i}: '
                        f'{max_infld:.02f} {mean_infld:.02f} {mean_peak:.02f} {mean_trough:.02f} {mean_outfld:.02f} {residual_infld:.04f}')
            residual_inflds.append(residual_infld)
            trial_inflds.append(mean_infld)
            trial_outflds.append(mean_outfld)

        trial_rate_features = [np.asarray(trial_inflds, dtype=np.float32).reshape((1, n_trials)),
                               np.asarray(trial_outflds, dtype=np.float32).reshape((1, n_trials))]
        # NOTE(review): these scalars come from the LAST trial of the loop
        # above (and are undefined when there are no trials) -- confirm that
        # this is intended rather than an average over trials.
        rate_features = [mean_peak, mean_trough, max_infld, min_infld, mean_infld, mean_outfld, ]
        #rate_constr = [ mean_peak if max_infld > 0. else -1. ]
        rate_constr = [ mean_peak - mean_trough if max_infld > 0. else -1. ]
        return (np.asarray(residual_inflds), trial_rate_features, rate_features, rate_constr)

    def trial_state_residuals(gid, target_outfld, t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs, state_values, masked_state_values):
        """Compute per-trial state residuals and mean state features for one gid.

        The residual of a trial is the mean of its masked state trace minus
        ``target_outfld``. Returns (residuals array, [mean peak state, mean
        trough state, mean out-of-field state]).
        """
        # np.vstack is the function np.row_stack aliases; row_stack is
        # deprecated as of NumPy 2.0.
        state_value_arrays = np.vstack(state_values)
        masked_state_value_arrays = None
        if masked_state_values is not None:
            masked_state_value_arrays = np.vstack(masked_state_values)

        residuals_outfld = []
        peak_inflds = []
        trough_inflds = []
        mean_outflds = []
        for i in range(state_value_arrays.shape[0]):
            state_value_array = state_value_arrays[i, :]
            peak_infld = np.mean(state_value_array[t_peak_idxs])
            trough_infld = np.mean(state_value_array[t_trough_idxs])
            mean_infld = np.mean(state_value_array[t_infld_idxs])

            # NOTE(review): fails with TypeError if masked_state_values was
            # None; eval_problem below always supplies masked traces.
            masked_state_value_array = masked_state_value_arrays[i, :]
            mean_masked = np.mean(masked_state_value_array)
            residual_masked = mean_masked - target_outfld

            mean_outfld = mean_masked
            if t_outfld_idxs is not None:
                mean_outfld = np.mean(state_value_array[t_outfld_idxs])

            peak_inflds.append(peak_infld)
            trough_inflds.append(trough_infld)
            mean_outflds.append(mean_outfld)
            residuals_outfld.append(residual_masked)
            logger.info(f'selectivity objective: state values of gid {gid}: '
                        f'peak/trough/mean in/mean out/masked: {peak_infld:.02f} / {trough_infld:.02f} / {mean_infld:.02f} / {mean_outfld:.02f} / residual masked: {residual_masked:.04f}')

        state_features = [np.mean(peak_inflds), np.mean(trough_inflds), np.mean(mean_outflds)]
        return (np.asarray(residuals_outfld), state_features)

    # Rebinds the `recording_profile` argument: Env already received the
    # caller's value through `params`; from here on the generated profile
    # for `state_variable` is used instead.
    recording_profile = { 'label': f'optimize_selectivity.{state_variable}',
                          'section quantity': {
                              state_variable: { 'swc types': ['soma'] }
                            }
                        }
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(env, population, gid, recording_profile=recording_profile)

    def eval_problem(cell_param_dict, **kwargs):
        """Run the simulation with and without masking and score every gid.

        ``cell_param_dict`` maps gid -> {parameter name: value}. Returns
        {gid: (objectives, features, constraints)}.
        """
        run_params = {population: {gid: from_param_dict(cell_param_dict[gid])
                                   for gid in my_cell_index_set}}
        masked_run_params = {population: { gid: update_run_params(run_params[population][gid],
                                                                  selectivity_opt_param_config.mask_param_names,
                                                                  selectivity_opt_param_config.mask_param_tuples)
                                           for gid in my_cell_index_set} }
        spkdict = run_with(env, run_params)
        rates_dict = gid_firing_rate_vectors(spkdict, my_cell_index_set)
        t_s, state_values_dict = gid_state_values(spkdict, equilibration_duration, n_trials, env.t_rec,
                                                  state_recs_dict)

        masked_spkdict = run_with(env, masked_run_params)
        masked_rates_dict = gid_firing_rate_vectors(masked_spkdict, my_cell_index_set)
        t_s, masked_state_values_dict = gid_state_values(masked_spkdict, equilibration_duration, n_trials, env.t_rec,
                                                         state_recs_dict)

        result = {}
        for gid in my_cell_index_set:
            infld_idxs = infld_idxs_dict[gid]
            outfld_idxs = outfld_idxs_dict[gid]
            peak_idxs = peak_idxs_dict[gid]
            trough_idxs = trough_idxs_dict[gid]

            target_rate_vector = target_rate_vector_dict[gid]

            peak_ranges = peak_ranges_dict[gid]
            trough_ranges = trough_ranges_dict[gid]
            infld_ranges = infld_ranges_dict[gid]
            outfld_ranges = outfld_ranges_dict[gid]

            # Translate the time ranges into sample indices of the recorded
            # state trace, whose time base is t_s.
            t_peak_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in peak_ranges ])
            t_trough_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in trough_ranges ])
            t_infld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in infld_ranges ])
            if outfld_ranges is not None:
                t_outfld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in outfld_ranges ])
            else:
                t_outfld_idxs = None

            masked_state_values = masked_state_values_dict.get(gid, None)
            state_values = state_values_dict[gid]
            rate_vectors = rates_dict[gid]
            masked_rate_vectors = masked_rates_dict[gid]

            logger.info(f'selectivity objective: max rates of gid {gid}: '
                        f'{list([np.max(rate_vector) for rate_vector in rate_vectors])}')

            infld_residuals, trial_rate_features, rate_features, rate_constr = \
              trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs,
                                  rate_vectors, masked_rate_vectors, target_rate_vector)
            state_residuals, state_features = trial_state_residuals(gid, state_baseline,
                                                                    t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs,
                                                                    state_values, masked_state_values)
            trial_obj_features = np.vstack((infld_residuals, state_residuals))

            if trial_regime == 'mean':
                mean_infld_residual = np.mean(infld_residuals)
                mean_state_residual = np.mean(state_residuals)
                infld_objective = mean_infld_residual
                state_objective = abs(mean_state_residual)
                logger.info(f'selectivity objective: mean peak/trough/mean infld/mean outfld/mean state residual of gid {gid}: '
                            f'{mean_infld_residual:.04f} {mean_state_residual:.04f}')
            elif trial_regime == 'best':
                min_infld_residual_index = np.argmin(infld_residuals)
                min_infld_residual = infld_residuals[min_infld_residual_index]
                infld_objective = min_infld_residual
                min_state_residual = np.min(np.abs(state_residuals))
                state_objective = min_state_residual
                logger.info(f'selectivity objective: mean peak/trough/max infld/max outfld/min state residual of gid {gid}: '
                            f'{min_infld_residual:.04f} {min_state_residual:.04f}')
            else:
                raise RuntimeError(f'selectivity_rate_objective: unknown trial regime {trial_regime}')

            logger.info(f"rate_features: {rate_features} state_features: {state_features} obj_features: {trial_obj_features}")

            result[gid] = (np.asarray([ infld_objective, state_objective ],
                                      dtype=np.float32),
                           np.array([tuple(rate_features+state_features+[trial_obj_features]+trial_rate_features)],
                                    dtype=np.dtype(feature_dtypes)),
                           np.asarray(rate_constr, dtype=np.float32))

        return result

    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
# Example #3
# 0
def main(config_path, params_id, n_samples, target_features_path, target_features_namespace, output_file_dir, output_file_name, verbose):
    """Evaluate the network for a configured set of parameter values.

    When ``params_id`` is None and ``n_samples`` is given, only a parameter
    lattice is generated and the function returns. Otherwise the evaluation
    configuration is read from ``config_path``, the network is initialized,
    target rate maps are loaded, and ``eval_network`` is called.

    :param config_path: YAML evaluation configuration; the keys read here
        are 'param_spec', 'param_values', 'feature_names',
        'target_populations' and, optionally, 'kwargs'.
    :param params_id: identifier passed to ``eval_network``; when not None
        it also becomes part of the results file id.
    :param n_samples: sample count for ``generate_param_lattice``.
    :param target_features_path: when not None, stored in the evaluation
        configuration; also passed to ``rate_maps_from_features``.
    :param target_features_namespace: as ``target_features_path``, for the
        namespace.
    :param output_file_dir: directory part of the output path.
    :param output_file_name: file name of the output; defaults to
        ``network_features_<run timestamp>.h5``.
    :param verbose: passed to ``config_logging``.
    """
    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    if params_id is None and n_samples is not None:
        logger.info("Generating parameter lattice ...")
        generate_param_lattice(config_path, n_samples, output_file_dir, output_file_name, verbose)
        return

    # Extra command-line arguments become network configuration entries:
    # '--some-opt=value' -> {'some_opt': 'value'}, '--flag' -> {'flag': True}.
    network_args = click.get_current_context().args
    network_config = {}
    for arg in network_args:
        # Split on the first '=' only, so values may themselves contain '='.
        k, sep, v = arg.partition("=")
        config_key = k.replace('--', '').replace('-', '_')
        network_config[config_key] = v if sep else True

    run_ts = datetime.datetime.today().strftime('%Y%m%d_%H%M')

    if output_file_name is None:
        output_file_name=f"network_features_{run_ts}.h5"
    output_path = f'{output_file_dir}/{output_file_name}'

    eval_config = read_from_yaml(config_path, include_loader=utils.IncludeLoader)
    eval_config['run_ts'] = run_ts
    if target_features_path is not None:
        eval_config['target_features_path'] = target_features_path
    if target_features_namespace is not None:
        eval_config['target_features_namespace'] = target_features_namespace

    network_param_spec_src = eval_config['param_spec']
    network_param_values = eval_config['param_values']

    feature_names = eval_config['feature_names']
    target_populations = eval_config['target_populations']

    # Entries from the configuration file override command-line entries.
    network_config.update(eval_config.get('kwargs', {}))

    if params_id is None:
        network_config['results_file_id'] = f"DG_eval_network_{eval_config['run_ts']}"
    else:
        network_config['results_file_id'] = f"DG_eval_network_{params_id}_{eval_config['run_ts']}"

    env = init_network(comm=MPI.COMM_WORLD, kwargs=network_config)
    gc.collect()

    # Start of the evaluated time window (units follow env.tstop -- TODO confirm).
    t_start = 50.
    t_stop = env.tstop
    time_range = (t_start, t_stop)

    # Target rate maps are only loaded for populations that have a
    # '<pop> target rate dist residual' feature configured.
    target_trj_rate_map_dict = {}
    for pop_name in target_populations:
        if ('%s target rate dist residual' % pop_name) not in feature_names:
            continue
        my_cell_index_set = set(env.biophys_cells[pop_name].keys())
        # NOTE(review): this passes the raw function arguments, which may be
        # None, rather than eval_config['target_features_path'] /
        # ['target_features_namespace'] -- confirm which one is intended.
        target_trj_rate_map_dict[pop_name] = rate_maps_from_features(
            env, pop_name,
            cell_index_set=list(my_cell_index_set),
            input_features_path=target_features_path,
            input_features_namespace=target_features_namespace,
            time_range=time_range)

    network_param_spec = make_param_spec(target_populations, network_param_spec_src)

    def from_param_list(x):
        """Convert 6-tuples (population, source, sec_type, syn_name,
        param_path, value) into (SynParam, value) pairs."""
        result = []
        for pop_param in x:
            this_population, source, sec_type, syn_name, param_path, param_val = pop_param
            param_tuple = SynParam(this_population, source, sec_type, syn_name, param_path, None)
            result.append((param_tuple, param_val))

        return result

    def keyfun(kv):
        """Sort key: string form of an item's key."""
        return str(kv[0])

    def from_param_dict(x):
        """Flatten a nested population/source/section/synapse/parameter
        dict into (SynParam, value) pairs, in sorted key order."""
        result = []
        for pop_name, param_specs in viewitems(x):
            for source, source_dict in sorted(viewitems(param_specs), key=keyfun):
                for sec_type, sec_type_dict in sorted(viewitems(source_dict), key=keyfun):
                    for syn_name, syn_mech_dict in sorted(viewitems(sec_type_dict), key=keyfun):
                        for param_fst, param_rst in sorted(viewitems(syn_mech_dict), key=keyfun):
                            if isinstance(param_rst, dict):
                                for const_name, const_value in sorted(viewitems(param_rst)):
                                    param_path = (param_fst, const_name)
                                    param_tuple = SynParam(pop_name, source, sec_type, syn_name, param_path, const_value)
                                    # list.append takes one argument: append the pair as a tuple.
                                    result.append((param_tuple, const_value))
                            else:
                                param_name = param_fst
                                param_value = param_rst
                                param_tuple = SynParam(pop_name, source, sec_type, syn_name, param_name, param_value)
                                result.append((param_tuple, param_value))
        return result

    eval_network(env, network_config,
                 from_param_list, from_param_dict,
                 network_param_spec, network_param_values, params_id,
                 target_trj_rate_map_dict, t_start, t_stop,
                 target_populations, output_path)
# Example #4
# 0
def config_worker():
    """Prepare the module-level ``context`` for a worker.

    Defaults ``context.debug`` to False, configures logging, sets
    ``context.logger`` and (if absent) ``context.results_file_id``, calls
    ``init_network()``, fills in ``context.t_start`` / ``context.t_stop``,
    and builds ``context.target_trj_rate_map_dict`` with one entry per
    target population for which rate maps were found.

    Exceptions raised by ``init_network()`` are logged and re-raised.
    """
    if 'debug' not in context():
        context.debug = False

    if context.debug and context.comm.rank == 1:
        print('# of parameters: %i' % len(context.param_names))
        print('param_names: ', context.param_names)
        print('target_val: ', context.target_val)
        print('target_range: ', context.target_range)
        print('param_tuples: ', context.param_tuples)
        sys.stdout.flush()

    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))
    # TODO: Do you want this to be identical on all ranks in a subworld? You can use context.comm.bcast
    if 'results_file_id' not in context():
        context.results_file_id = 'DG_optimize_network_subworlds_%s_%s' % \
                             (context.interface.worker_id, datetime.datetime.today().strftime('%Y%m%d_%H%M'))

    # 'env' might be in context on controller, but it needs to be re-built when the controller is in a worker subworld
    try:
        if context.debug:
            print(
                'debug: config_worker; local_comm.rank: %i/%i; global_comm.rank: %i/%i'
                % (context.comm.rank, context.comm.size,
                   context.global_comm.rank, context.global_comm.size))
            sys.stdout.flush()
        init_network()
    except Exception as err:
        context.logger.exception(err)
        # Bare raise keeps the original traceback intact.
        raise

    # Values already present in the context are coerced to float.
    if 't_start' not in context():
        context.t_start = 50.
    else:
        context.t_start = float(context.t_start)
    if 't_stop' not in context():
        context.t_stop = context.env.tstop
    else:
        context.t_stop = float(context.t_stop)
    time_range = (context.t_start, context.t_stop)

    try:
        if context.debug and context.global_comm.rank == 0:
            print('t_start: %.1f, t_stop: %.1f' %
                  (context.t_start, context.t_stop))
    except Exception as err:
        context.logger.exception(err)
        raise

    context.target_trj_rate_map_dict = {}
    target_rate_map_path = context.target_rate_map_path
    target_rate_map_namespace = context.target_rate_map_namespace
    for pop_name in context.target_populations:
        my_cell_index_set = set(context.env.biophys_cells[pop_name].keys())
        trj_rate_maps = rate_maps_from_features(
            context.env,
            pop_name,
            cell_index_set=list(my_cell_index_set),
            input_features_path=target_rate_map_path,
            input_features_namespace=target_rate_map_namespace,
            time_range=time_range)
        # Populations without any rate maps get no entry at all.
        if len(trj_rate_maps) > 0:
            context.target_trj_rate_map_dict[pop_name] = trj_rate_maps