Example #1
def main(config_file, config_prefix, population, gid, template_paths, dataset_prefix, results_path, results_file_id, results_namespace_id, v_init):

    if results_file_id is None:
        results_file_id = uuid.uuid4()
    if results_namespace_id is None:
        results_namespace_id = 'Cell Clamp Results'
    comm = MPI.COMM_WORLD
    np.seterr(all='raise')
    verbose = True
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)
    io_utils.mkout(env, env.results_file_path)
    env.cell_selection = {}

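    # Each measure_* call runs a clamp protocol and returns a dict of
    # per-cell attributes; accumulate them all under this gid.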
    attr_dict = {}
    attr_dict[gid] = {}
    attr_dict[gid].update(measure_passive(gid, population, v_init, env))
    attr_dict[gid].update(measure_ap(gid, population, v_init, env))
    attr_dict[gid].update(measure_ap_rate(gid, population, v_init, env))
    attr_dict[gid].update(measure_fi(gid, population, v_init, env))

    pprint.pprint(attr_dict)

    if results_path is not None:
        append_cell_attributes(env.results_file_path, population, attr_dict,
                               namespace=env.results_namespace_id,
                               comm=env.comm, io_size=env.io_size)
Example #2
def init_env():
    """

    """
    context.env = Env(comm=context.comm, results_file_id=context.results_file_id, **context.init_params)
    configure_hoc_env(context.env)
    context.gid = int(context.init_params['gid'])
    context.population = context.init_params['population']
    context.target_val = {}
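
A minimal sketch of the setup pattern shared by all of these examples (assuming the same Env and configure_hoc_env names imported by the scripts above; the config file name is a placeholder):

# Minimal setup sketch. Env and configure_hoc_env are the names used in the
# examples above; the import path depends on the package layout, and
# 'config.yaml' is a hypothetical file name.
from mpi4py import MPI

env = Env(comm=MPI.COMM_WORLD, config_file='config.yaml', verbose=True)
configure_hoc_env(env)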
Example #3
def main(gid, pop_name, config_file, template_paths, hoc_lib_path, dataset_prefix, config_prefix, mech_file,
         load_edges, load_weights, correct_for_spines, verbose):
    """

    :param gid: int
    :param pop_name: str
    :param config_file: str; model configuration file name
    :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
    :param hoc_lib_path: str; path to directory containing required hoc libraries
    :param dataset_prefix: str; path to directory containing required neuroh5 data files
    :param config_prefix: str; path to directory containing network and cell mechanism config files
    :param mech_file: str; cell mechanism config file name
    :param load_edges: bool; whether to attempt to load connections from a neuroh5 file
    :param load_weights: bool; whether to attempt to load synaptic weights from a neuroh5 file
    :param correct_for_spines: bool
    :param verbose: bool
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    np.seterr(all='raise')
    env = Env(comm=comm, config_file=config_file, template_paths=template_paths, hoc_lib_path=hoc_lib_path,
              dataset_prefix=dataset_prefix, config_prefix=config_prefix, verbose=verbose)
    configure_hoc_env(env)

    mech_file_path = config_prefix + '/' + mech_file
    template_name = env.celltypes[pop_name]['template']
    if template_name.lower() == 'izhikevich':
        cell = make_izhikevich_cell(env, pop_name=pop_name, gid=gid,
                                    load_synapses=True, load_connections=True,
                                    load_edges=load_edges, load_weights=load_weights,
                                    mech_file_path=mech_file_path)
    elif template_name.lower() == 'pr_nrn':
        cell = make_PR_cell(env, pop_name=pop_name, gid=gid,
                            load_synapses=True, load_connections=True,
                            load_edges=load_edges, load_weights=load_weights,
                            mech_file_path=mech_file_path)
    else:
        cell = make_biophys_cell(env, pop_name=pop_name, gid=gid,
                                 load_synapses=True, load_connections=True,
                                 load_edges=load_edges, load_weights=load_weights,
                                 mech_file_path=mech_file_path)
    context.update(locals())

    init_biophysics(cell, reset_cable=True, correct_cm=correct_for_spines, correct_g_pas=correct_for_spines,
                    env=env, verbose=verbose)

    init_syn_mech_attrs(cell, env)
    config_biophys_cell_syns(env, gid, pop_name, insert=True, insert_netcons=True, insert_vecstims=True,
                             verbose=verbose)

    if verbose:
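        # Dump every section's properties via NEURON's psection.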
        for sec in list(cell.hoc_cell.all if hasattr(cell, 'hoc_cell') else cell.all):
            h.psection(sec=sec)
        report_topology(cell, env)
Example #4
def main(gid, pop_name, config_file, template_paths, hoc_lib_path,
         dataset_prefix, config_prefix, mech_file, load_edges, load_weights,
         correct_for_spines, verbose):
    """

    :param gid: int
    :param pop_name: str
    :param config_file: str; model configuration file name
    :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
    :param hoc_lib_path: str; path to directory containing required hoc libraries
    :param dataset_prefix: str; path to directory containing required neuroh5 data files
    :param config_prefix: str; path to directory containing network and cell mechanism config files
    :param mech_file: str; cell mechanism config file name
    :param load_edges: bool; whether to attempt to load connections from a neuroh5 file
    :param load_weights: bool; whether to attempt to load synaptic weights from a neuroh5 file
    :param correct_for_spines: bool
    :param verbose: bool
    """
    comm = MPI.COMM_WORLD
    np.seterr(all='raise')
    env = Env(comm=comm,
              config_file=config_file,
              template_paths=template_paths,
              hoc_lib_path=hoc_lib_path,
              dataset_prefix=dataset_prefix,
              config_prefix=config_prefix,
              verbose=verbose)
    configure_hoc_env(env)

    mech_file_path = config_prefix + '/' + mech_file
    cell = get_biophys_cell(env,
                            pop_name=pop_name,
                            gid=gid,
                            load_edges=load_edges,
                            load_weights=load_weights,
                            mech_file_path=mech_file_path)
    context.update(locals())

    init_biophysics(cell,
                    reset_cable=True,
                    correct_cm=correct_for_spines,
                    correct_g_pas=correct_for_spines,
                    env=env,
                    verbose=verbose)
    init_syn_mech_attrs(cell, env)
    config_biophys_cell_syns(env,
                             gid,
                             pop_name,
                             insert=True,
                             insert_netcons=True,
                             insert_vecstims=True,
                             verbose=verbose)

    if verbose:
        report_topology(cell, env)
Example #5
def main(config, template_path, prototype_gid, prototype_path, forest_path, population, io_size, verbose):
    """

    :param config:
    :param template_path:
    :param prototype_gid:
    :param prototype_path:
    :param forest_path:
    :param population:
    :param io_size:
    :param verbose:
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)
        
    comm = MPI.COMM_WORLD
    rank = comm.rank
    
    env = Env(comm=MPI.COMM_WORLD, config_file=config, template_paths=template_path)
    configure_hoc_env(env)
    
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)
    
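    # Invert the layer name -> index mapping for the dentate layers of interest.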
    layers = env.layers
    layer_idx_dict = { layers[layer_name]: layer_name 
                       for layer_name in ['GCL', 'IML', 'MML', 'OML', 'Hilus'] }

    (tree_iter, _) = read_tree_selection(prototype_path, population, selection=[prototype_gid])
    (_, prototype_morph_dict) = next(tree_iter)
    prototype_x = prototype_morph_dict['x']
    prototype_y = prototype_morph_dict['y']
    prototype_z = prototype_morph_dict['z']
    prototype_xyz = (prototype_x, prototype_y, prototype_z)

    (pop_ranges, _) = read_population_ranges(forest_path, comm=comm)
    start_time = time.time()

    (population_start, _) = pop_ranges[population]
    template_class = load_cell_template(env, population, bcast_template=True)
    # Alternative bulk read:
    # trees, _ = scatter_read_trees(forest_path, population, io_size=io_size, comm=comm, topology=True)
    # for gid, morph_dict in trees:
    for gid, morph_dict in NeuroH5TreeGen(forest_path, population, io_size=io_size, cache_size=1, comm=comm, topology=True):
        if gid is not None:
            logger.info('Rank %i gid: %i' % (rank, gid))
            secnodes_dict = morph_dict['section_topology']['nodes']
            vx = morph_dict['x']
            vy = morph_dict['y']
            vz = morph_dict['z']
            if compare_points((vx,vy,vz), prototype_xyz):
                logger.info('Possible match: gid %i' % gid)
    MPI.Finalize()
Example #6
def init_network_clamp():
    """

    """
    np.seterr(all='raise')
    if context.env is None:
        context.env = Env(comm=context.comm,
                          results_file_id=context.results_file_id,
                          **context.init_params)
        configure_hoc_env(context.env)

    context.gid = int(context.init_params['gid'])
    context.population = context.init_params['population']
    context.target_val = {}
    network_clamp.init(
        context.env,
        context.init_params['population'],
        set([context.gid]),
        arena_id=context.init_params['arena_id'],
        trajectory_id=context.init_params['trajectory_id'],
        n_trials=int(context.init_params['n_trials']),
        spike_events_path=context.init_params.get('spike_events_path', None),
        spike_events_namespace=context.init_params.get(
            'spike_events_namespace', None),
        spike_train_attr_name=context.init_params.get('spike_train_attr_name',
                                                      None),
        input_features_path=context.init_params.get('input_features_path',
                                                    None),
        input_features_namespaces=context.init_params.get(
            'input_features_namespaces', None),
        t_min=0.,
        t_max=context.init_params['tstop'])

    context.equilibration_duration = float(
        context.env.stimulus_config['Equilibration Duration'])

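    # Somatic voltage recording profile: sample 'v' at 0.1 ms from soma sections.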
    state_variable = 'v'
    context.recording_profile = {
        'label': 'optimize_network_clamp.state.%s' % state_variable,
        'dt': 0.1,
        'section quantity': {
            state_variable: {
                'swc types': ['soma']
            }
        }
    }
    context.state_recs_dict = {}
    context.state_recs_dict[context.gid] = cells.record_cell(
        context.env,
        context.population,
        context.gid,
        recording_profile=context.recording_profile)
Example #7
def main(config_file, population, gid, template_paths, dataset_prefix,
         config_prefix, load_synapses, syn_types, syn_sources,
         syn_source_threshold, font_size, bgcolor, colormap, verbose):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)

    ## Determine if a mechanism configuration file exists for this cell type
    if 'mech_file_path' in env.celltypes[population]:
        mech_file_path = env.celltypes[population]['mech_file_path']
    else:
        mech_file_path = None

    logger.info('loading cell %i' % gid)

    load_weights = False
    biophys_cell = get_biophys_cell(env,
                                    population,
                                    gid,
                                    load_synapses=load_synapses,
                                    load_weights=load_weights,
                                    load_edges=load_synapses,
                                    mech_file_path=mech_file_path)

    if len(syn_types) == 0:
        syn_types = None
    else:
        syn_types = list(syn_types)
    if len(syn_sources) == 0:
        syn_sources = None
    else:
        syn_sources = list(syn_sources)

    plot.plot_biophys_cell_tree(env,
                                biophys_cell,
                                saveFig=True,
                                syn_source_threshold=syn_source_threshold,
                                synapse_filters={
                                    'syn_types': syn_types,
                                    'sources': syn_sources
                                },
                                bgcolor=bgcolor,
                                colormap=colormap)
Example #8
def main(config_path, template_paths, forest_path, synapses_path):

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    env = Env(comm=comm,
              config_file=config_path,
              template_paths=template_paths)

    neuron_utils.configure_hoc_env(env)

    h.pc = h.ParallelContext()

    v_init = -65.0
    popName = "MC"
    gid = 1000000

    env.load_cell_template(popName)

    (trees, _) = read_tree_selection(forest_path, popName, [gid], comm=comm)
    if synapses_path is not None:
        synapses_iter = read_cell_attribute_selection(synapses_path, popName, [gid],
                                                      "Synapse Attributes", comm=comm)
    else:
        synapses_iter = None

    gid, tree = next(trees)
    if synapses_iter is not None:
        (_, synapses) = next(synapses_iter)
    else:
        synapses = None

    if 'mech_file' in env.celltypes[popName]:
        mech_file_path = env.config_prefix + '/' + env.celltypes[popName]['mech_file']
    else:
        mech_file_path = None

    template_class = getattr(h, "MossyCell")

    if synapses is not None:
        synapse_test(template_class, mech_file_path, gid, tree, synapses,
                     v_init, env)
Example #9
def main(config_file, config_prefix, erev, population, presyn_name, gid,
         load_weights, measurements, template_paths, dataset_prefix,
         results_path, results_file_id, results_namespace_id, syn_mech_name,
         syn_weight, syn_count, syn_layer, swc_type, stim_amp, v_init, dt,
         use_cvode, verbose):

    config_logging(verbose)

    if results_file_id is None:
        results_file_id = uuid.uuid4()
    if results_namespace_id is None:
        results_namespace_id = 'Cell Clamp Results'
    comm = MPI.COMM_WORLD
    np.seterr(all='raise')
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)
    io_utils.mkout(env, env.results_file_path)
    env.cell_selection = {}

    if measurements is not None:
        measurements = [x.strip() for x in measurements.split(",")]
    else:
        measurements = []

    attr_dict = {}
    attr_dict[gid] = {}
    if 'passive' in measurements:
        attr_dict[gid].update(measure_passive(gid, population, v_init, env))
    if 'ap' in measurements:
        attr_dict[gid].update(measure_ap(gid, population, v_init, env))
    if 'ap_rate' in measurements:
        logger.info('ap_rate')
        attr_dict[gid].update(
            measure_ap_rate(gid, population, v_init, env, stim_amp=stim_amp))
    if 'fi' in measurements:
        attr_dict[gid].update(measure_fi(gid, population, v_init, env))
    if 'gap' in measurements:
        measure_gap_junction_coupling(gid, population, v_init, env)
    if 'psp' in measurements:
        assert (presyn_name is not None)
        assert (syn_mech_name is not None)
        assert (erev is not None)
        assert (syn_weight is not None)
        attr_dict[gid].update(
            measure_psp(gid,
                        population,
                        presyn_name,
                        syn_mech_name,
                        swc_type,
                        env,
                        v_init,
                        erev,
                        syn_layer=syn_layer,
                        syn_count=syn_count,
                        weight=syn_weight,
                        load_weights=load_weights))

    if results_path is not None:
        append_cell_attributes(env.results_file_path,
                               population,
                               attr_dict,
                               namespace=env.results_namespace_id,
                               comm=env.comm,
                               io_size=env.io_size)
Example #10
def main(config, template_path, output_path, forest_path, populations,
         distance_bin_size, io_size, chunk_size, value_chunk_size, cache_size,
         verbose):
    """

    :param config:
    :param template_path:
    :param forest_path:
    :param populations:
    :param io_size:
    :param chunk_size:
    :param value_chunk_size:
    :param cache_size:
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=MPI.COMM_WORLD,
              config_file=config,
              template_paths=template_path)
    configure_hoc_env(env)

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if output_path is None:
        output_path = forest_path

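    # Rank 0 creates the output file and copies the /H5Types definitions from
    # the forest file before any rank appends attributes.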
    if rank == 0:
        if not os.path.isfile(output_path):
            input_file = h5py.File(forest_path, 'r')
            output_file = h5py.File(output_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()

    layers = env.layers
    layer_idx_dict = {
        layers[layer_name]: layer_name
        for layer_name in ['GCL', 'IML', 'MML', 'OML', 'Hilus']
    }

    (pop_ranges, _) = read_population_ranges(forest_path, comm=comm)
    start_time = time.time()
    for population in populations:
        logger.info('Rank %i population: %s' % (rank, population))
        count = 0
        (population_start, _) = pop_ranges[population]
        template_class = load_cell_template(env,
                                            population,
                                            bcast_template=True)
        measures_dict = {}
        for gid, morph_dict in NeuroH5TreeGen(forest_path,
                                              population,
                                              io_size=io_size,
                                              comm=comm,
                                              topology=True):
            if gid is not None:
                logger.info('Rank %i gid: %i' % (rank, gid))
                cell = cells.make_neurotree_cell(template_class,
                                                 neurotree_dict=morph_dict,
                                                 gid=gid)
                secnodes_dict = morph_dict['section_topology']['nodes']

                apicalidx = set(cell.apicalidx)
                basalidx = set(cell.basalidx)

                dendrite_area_dict = {k: 0.0 for k in layer_idx_dict}
                dendrite_length_dict = {k: 0.0 for k in layer_idx_dict}
                dendrite_distances = []
                dendrite_diams = []
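                # Walk apical/basal sections segment by segment, accumulating
                # per-layer length/area and per-segment diameter/distance samples.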
                for (i, sec) in enumerate(cell.sections):
                    if (i in apicalidx) or (i in basalidx):
                        secnodes = secnodes_dict[i]
                        for seg in sec.allseg():
                            L = seg.sec.L
                            nseg = seg.sec.nseg
                            seg_l = L / nseg
                            seg_area = h.area(seg.x)
                            seg_diam = seg.diam
                            seg_distance = get_distance_to_node(
                                cell,
                                list(cell.soma)[0], seg.sec, seg.x)
                            dendrite_diams.append(seg_diam)
                            dendrite_distances.append(seg_distance)
                            layer = synapses.get_node_attribute(
                                'layer', morph_dict, seg.sec, secnodes, seg.x)
                            dendrite_length_dict[layer] += seg_l
                            dendrite_area_dict[layer] += seg_area

                dendrite_distance_array = np.asarray(dendrite_distances)
                dendrite_diam_array = np.asarray(dendrite_diams)
                dendrite_distance_bin_range = int(
                    ((np.max(dendrite_distance_array)) -
                     np.min(dendrite_distance_array)) / distance_bin_size) + 1
                dendrite_distance_counts, dendrite_distance_edges = np.histogram(
                    dendrite_distance_array,
                    bins=dendrite_distance_bin_range,
                    density=False)
                dendrite_diam_sums, _ = np.histogram(
                    dendrite_distance_array,
                    weights=dendrite_diam_array,
                    bins=dendrite_distance_bin_range,
                    density=False)
                dendrite_mean_diam_hist = np.zeros_like(dendrite_diam_sums)
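                # Mean diameter per distance bin; empty bins keep the zero
                # initialized above (the where mask skips them).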
                np.divide(dendrite_diam_sums,
                          dendrite_distance_counts,
                          where=dendrite_distance_counts > 0,
                          out=dendrite_mean_diam_hist)

                dendrite_area_per_layer = np.asarray(
                    [dendrite_area_dict[k] for k in sorted(dendrite_area_dict.keys())],
                    dtype=np.float32)
                dendrite_length_per_layer = np.asarray(
                    [dendrite_length_dict[k] for k in sorted(dendrite_length_dict.keys())],
                    dtype=np.float32)

                measures_dict[gid] = {
                    'dendrite_distance_hist_edges':
                    np.asarray(dendrite_distance_edges, dtype=np.float32),
                    'dendrite_distance_counts':
                    np.asarray(dendrite_distance_counts, dtype=np.int32),
                    'dendrite_mean_diam_hist':
                    np.asarray(dendrite_mean_diam_hist, dtype=np.float32),
                    'dendrite_area_per_layer':
                    dendrite_area_per_layer,
                    'dendrite_length_per_layer':
                    dendrite_length_per_layer
                }

                del cell
                count += 1
            else:
                logger.info('Rank %i gid is None' % rank)
        append_cell_attributes(output_path,
                               population,
                               measures_dict,
                               namespace='Tree Measurements',
                               comm=comm,
                               io_size=io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size,
                               cache_size=cache_size)
    MPI.Finalize()
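
The distance-binned mean diameter above is a binned average built from two calls to np.histogram. A self-contained sketch of the same technique on synthetic data (plain NumPy; the sample values are made up):

import numpy as np

distances = np.asarray([10., 12., 35., 40., 41., 90.])   # binning key
diams = np.asarray([2.0, 1.8, 1.2, 1.0, 1.1, 0.6])       # values to average
bin_size = 25.0
n_bins = int((distances.max() - distances.min()) / bin_size) + 1
counts, edges = np.histogram(distances, bins=n_bins)
sums, _ = np.histogram(distances, bins=n_bins, weights=diams)
mean_diam = np.zeros_like(sums)
# Bins with zero samples keep their zero initialization; the where mask
# prevents division by zero.
np.divide(sums, counts, where=counts > 0, out=mean_diam)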
Example #11
def main(config_file, config_prefix, input_path, population, template_paths,
         dataset_prefix, results_path, results_file_id, results_namespace_id,
         v_init, io_size, chunk_size, value_chunk_size, write_size, verbose):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if io_size == -1:
        io_size = comm.size

    if results_file_id is None:
        if rank == 0:
            results_file_id = uuid.uuid4()
        results_file_id = comm.bcast(results_file_id, root=0)
    if results_namespace_id is None:
        results_namespace_id = 'Cell Clamp Results'
    np.seterr(all='raise')
    verbose = True
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)
    if rank == 0:
        io_utils.mkout(env, env.results_file_path)
    env.comm.barrier()
    env.cell_selection = {}
    template_class = load_cell_template(env, population)

    if input_path is not None:
        env.data_file_path = input_path
        env.load_celltypes()

    synapse_config = env.celltypes[population]['synapses']

    weights_namespaces = []
    if 'weights' in synapse_config:
        has_weights = synapse_config['weights']
        if has_weights:
            if 'weights namespace' in synapse_config:
                weights_namespaces.append(synapse_config['weights namespace'])
            elif 'weights namespaces' in synapse_config:
                weights_namespaces.extend(synapse_config['weights namespaces'])
            else:
                weights_namespaces.append('Weights')
    else:
        has_weights = False

    start_time = time.time()
    count = 0
    gid_count = 0
    attr_dict = {}
    if input_path is None:
        cell_path = env.data_file_path
        connectivity_path = env.connectivity_file_path
    else:
        cell_path = input_path
        connectivity_path = input_path

    for gid, morph_dict in NeuroH5TreeGen(cell_path,
                                          population,
                                          io_size=io_size,
                                          comm=env.comm,
                                          topology=True):
        local_time = time.time()
        if gid is not None:
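            # Split the communicator so that ranks holding a gid (color 0) can
            # perform collective selection reads independently of idle ranks.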
            color = 0
            comm0 = comm.Split(color, 0)

            logger.info('Rank %i gid: %i' % (rank, gid))
            cell_dict = {'morph': morph_dict}
            synapses_iter = read_cell_attribute_selection(cell_path,
                                                          population, [gid],
                                                          'Synapse Attributes',
                                                          comm=comm0)
            _, synapse_dict = next(synapses_iter)
            cell_dict['synapse'] = synapse_dict

            if has_weights:
                cell_weights_iters = [
                    read_cell_attribute_selection(cell_path,
                                                  population, [gid],
                                                  weights_namespace,
                                                  comm=comm0)
                    for weights_namespace in weights_namespaces
                ]
                weight_dict = dict(
                    zip_longest(weights_namespaces, cell_weights_iters))
                cell_dict['weight'] = weight_dict

            (graph,
             a) = read_graph_selection(file_name=connectivity_path,
                                       selection=[gid],
                                       namespaces=['Synapses', 'Connections'],
                                       comm=comm0)
            cell_dict['connectivity'] = (graph, a)

            gid_count += 1

            attr_dict[gid] = {}
            attr_dict[gid].update(
                cell_clamp.measure_passive(gid,
                                           population,
                                           v_init,
                                           env,
                                           cell_dict=cell_dict))
            attr_dict[gid].update(
                cell_clamp.measure_ap(gid,
                                      population,
                                      v_init,
                                      env,
                                      cell_dict=cell_dict))
            attr_dict[gid].update(
                cell_clamp.measure_ap_rate(gid,
                                           population,
                                           v_init,
                                           env,
                                           cell_dict=cell_dict))
            attr_dict[gid].update(
                cell_clamp.measure_fi(gid,
                                      population,
                                      v_init,
                                      env,
                                      cell_dict=cell_dict))

        else:
            color = 1
            comm0 = comm.Split(color, 0)
            logger.info('Rank %i gid is None' % (rank))
        comm0.Free()

        count += 1
        if (results_path is not None) and (count % write_size == 0):
            append_cell_attributes(env.results_file_path,
                                   population,
                                   attr_dict,
                                   namespace=env.results_namespace_id,
                                   comm=env.comm,
                                   io_size=env.io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)
            attr_dict = {}

    env.comm.barrier()
    if results_path is not None:
        append_cell_attributes(env.results_file_path,
                               population,
                               attr_dict,
                               namespace=env.results_namespace_id,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
    global_count = env.comm.gather(gid_count, root=0)

    MPI.Finalize()
Example #12
def main(config_file, population, dt, gid, gid_selection_file, arena_id, trajectory_id, generate_weights,
         t_max, t_min,  nprocs_per_worker, n_epochs, n_initial, initial_maxiter, initial_method, optimizer_method, surrogate_method,
         population_size, num_generations, resample_fraction, mutation_rate,
         template_paths, dataset_prefix, config_prefix,
         param_config_name, selectivity_config_name, param_type, recording_profile, results_file, results_path, spike_events_path,
         spike_events_namespace, spike_events_t, input_features_path, input_features_namespaces, n_trials,
         trial_regime, problem_regime, target_features_path, target_features_namespace, target_state_variable,
         target_state_filter, use_coreneuron, cooperative_init, spawn_startup_wait):
    """
    Optimize the input stimulus selectivity of the specified cell in a network clamp configuration.
    """
    init_params = dict(locals())

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    results_file_id = None
    if rank == 0:
        results_file_id = generate_results_file_id(population, gid)
        
    results_file_id = comm.bcast(results_file_id, root=0)
    comm.barrier()
    
    np.seterr(all='raise')
    verbose = True
    cache_queries = True

    config_logging(verbose)

    cell_index_set = set([])
    if gid_selection_file is not None:
        with open(gid_selection_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                gid = int(line)
                cell_index_set.add(gid)
    elif gid is not None:
        cell_index_set.add(gid)
    else:
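        # No gid or selection file given: rank 0 reads the full cell index set
        # from the data file and broadcasts it to all ranks.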
        comm.barrier()
        comm0 = comm.Split(2 if rank == 0 else 1, 0)
        if rank == 0:
            env = Env(**init_params, comm=comm0)
            attr_info_dict = read_cell_attribute_info(env.data_file_path, populations=[population],
                                                      read_cell_index=True, comm=comm0)
            cell_index = None
            attr_name, attr_cell_index = next(iter(attr_info_dict[population]['Trees']))
            cell_index_set = set(attr_cell_index)
        comm.barrier()
        cell_index_set = comm.bcast(cell_index_set, root=0)
        comm.barrier()
        comm0.Free()
    init_params['cell_index_set'] = cell_index_set
    del init_params['gid']

    params = dict(locals())
    env = Env(**params)
    if size == 1:
        configure_hoc_env(env)
        init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
             spike_events_path, spike_events_namespace=spike_events_namespace, 
             spike_train_attr_name=spike_events_t,
             input_features_path=input_features_path,
             input_features_namespaces=input_features_namespaces,
             generate_weights_pops=set(generate_weights), 
             t_min=t_min, t_max=t_max)
        
    if population in env.netclamp_config.optimize_parameters[param_type]:
        opt_params = env.netclamp_config.optimize_parameters[param_type][population]
    else:
        raise RuntimeError(f'optimize_selectivity: population {population} does not have optimization configuration')

    if target_state_variable is None:
        target_state_variable = 'v'
    
    init_params['target_features_arena'] = arena_id
    init_params['target_features_trajectory'] = trajectory_id
    opt_state_baseline = opt_params['Targets']['state'][target_state_variable]['baseline']
    init_params['state_baseline'] = opt_state_baseline
    init_params['state_variable'] = target_state_variable
    init_params['state_filter'] = target_state_filter
    init_objfun_name = 'init_selectivity_objfun'
        
    best = optimize_run(env, population, param_config_name, selectivity_config_name, init_objfun_name, problem_regime=problem_regime,
                        n_epochs=n_epochs, n_initial=n_initial, initial_maxiter=initial_maxiter, initial_method=initial_method, 
                        optimizer_method=optimizer_method, surrogate_method=surrogate_method, population_size=population_size, 
                        num_generations=num_generations, resample_fraction=resample_fraction, mutation_rate=mutation_rate, 
                        param_type=param_type, init_params=init_params, results_file=results_file, nprocs_per_worker=nprocs_per_worker, 
                        cooperative_init=cooperative_init, spawn_startup_wait=spawn_startup_wait, verbose=verbose)
    
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    if best is not None:
        if results_path is not None:
            run_ts = time.strftime("%Y%m%d_%H%M%S")
            file_path = f'{results_path}/optimize_selectivity.{run_ts}.yaml'
            param_names = opt_param_config.param_names
            param_tuples = opt_param_config.param_tuples

            if ProblemRegime[problem_regime] == ProblemRegime.every:
                results_config_dict = {}
                for gid, prms in viewitems(best):
                    n_res = prms[0][1].shape[0]
                    prms_dict = dict(prms)
                    this_results_config_dict = {}
                    for i in range(n_res):
                        results_param_list = []
                        for param_pattern, param_tuple in zip(param_names, param_tuples):
                            results_param_list.append((param_tuple.population,
                                                       param_tuple.source,
                                                       param_tuple.sec_type,
                                                       param_tuple.syn_name,
                                                       param_tuple.param_path,
                                                       float(prms_dict[param_pattern][i])))
                        this_results_config_dict[i] = results_param_list
                    results_config_dict[gid] = this_results_config_dict
                    
            else:
                prms = best[0]
                n_res = prms[0][1].shape[0]
                prms_dict = dict(prms)
                results_config_dict = {}
                for i in range(n_res):
                    results_param_list = []
                    for param_pattern, param_tuple in zip(param_names, param_tuples):
                        results_param_list.append((param_tuple.population,
                                                   param_tuple.source,
                                                   param_tuple.sec_type,
                                                   param_tuple.syn_name,
                                                   param_tuple.param_path,
                                                   float(prms_dict[param_pattern][i])))
                    results_config_dict[i] = results_param_list

            write_to_yaml(file_path, { population: results_config_dict } )

            
    comm.barrier()
Example #13
def init_selectivity_objfun(config_file, population, cell_index_set, arena_id, trajectory_id,
                            n_trials, trial_regime, problem_regime,
                            generate_weights, t_max, t_min,
                            template_paths, dataset_prefix, config_prefix, results_path,
                            spike_events_path, spike_events_namespace, spike_events_t,
                            input_features_path, input_features_namespaces,
                            param_type, param_config_name, selectivity_config_name, recording_profile, 
                            state_variable, state_filter, state_baseline,
                            target_features_path, target_features_namespace,
                            target_features_arena, target_features_trajectory,   
                            use_coreneuron, cooperative_init, dt, worker, **kwargs):
    
    params = dict(locals())
    env = Env(**params)
    env.results_file_path = None
    configure_hoc_env(env, bcast_template=True)

    my_cell_index_set = init(env, population, cell_index_set, arena_id, trajectory_id, n_trials,
                             spike_events_path, spike_events_namespace=spike_events_namespace, 
                             spike_train_attr_name=spike_events_t,
                             input_features_path=input_features_path,
                             input_features_namespaces=input_features_namespaces,
                             generate_weights_pops=set(generate_weights), 
                             t_min=t_min, t_max=t_max, cooperative_init=cooperative_init,
                             worker=worker)

    time_step = float(env.stimulus_config['Temporal Resolution'])
    equilibration_duration = float(env.stimulus_config.get('Equilibration Duration', 0.))
    
    target_rate_vector_dict = rate_maps_from_features(env, population,
                                                      cell_index_set=my_cell_index_set,
                                                      input_features_path=target_features_path,
                                                      input_features_namespace=target_features_namespace,
                                                      time_range=[0., t_max],
                                                      arena_id=arena_id)


    logger.info(f'target_rate_vector_dict = {target_rate_vector_dict}')
    for gid, target_rate_vector in viewitems(target_rate_vector_dict):
        target_rate_vector[np.isclose(target_rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.

    trj_x, trj_y, trj_d, trj_t = stimulus.read_trajectory(input_features_path if input_features_path is not None else spike_events_path, 
                                                          target_features_arena, target_features_trajectory)
    time_range = (0., min(np.max(trj_t), t_max))
    time_bins = np.arange(time_range[0], time_range[1]+time_step, time_step)
    state_time_bins = np.arange(time_range[0], time_range[1], time_step)[:-1]

    def range_inds(rs):
        l = list(rs)
        if len(l) > 0:
            a = np.concatenate(l)
        else:
            a = None
        return a

    def time_ranges(rs):
        if len(rs) > 0:
            a = tuple( ( (time_bins[r[0]], time_bins[r[1]-1]) for r in rs ) )
        else:
            a = None
        return a
        
    
    infld_idxs_dict = { gid: np.where(target_rate_vector > 1e-4)[0] 
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    peak_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 80)
                         for gid, infld_idxs in viewitems(infld_idxs_dict) }
    trough_pctile_dict = { gid: np.percentile(target_rate_vector_dict[gid][infld_idxs], 20)
                           for gid, infld_idxs in viewitems(infld_idxs_dict) }
    outfld_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector < 1e-4, return_indices=True))
                        for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_idxs_dict = { gid: range_inds(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid], return_indices=True)) 
                       for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_idxs_dict = { gid: range_inds(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid]), return_indices=True))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    outfld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector <= 0.) ) 
                           for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    infld_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector > 0) ) 
                          for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    peak_ranges_dict = { gid: time_ranges(contiguous_ranges(target_rate_vector >= peak_pctile_dict[gid]))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }
    trough_ranges_dict = { gid: time_ranges(contiguous_ranges(np.logical_and(target_rate_vector > 0., target_rate_vector <= trough_pctile_dict[gid])))
                         for gid, target_rate_vector in viewitems(target_rate_vector_dict) }

    large_fld_gids = []
    for gid in my_cell_index_set:

        infld_idxs = infld_idxs_dict[gid]
        target_rate_vector = target_rate_vector_dict[gid]

        target_infld_rate_vector = target_rate_vector[infld_idxs]
        target_peak_rate_vector = target_rate_vector[peak_idxs_dict[gid]]
        target_trough_rate_vector = target_rate_vector[trough_idxs_dict[gid]]

        logger.info(f'selectivity objective: target peak/trough rate of gid {gid}: '
                    f'{peak_pctile_dict[gid]:.02f} {trough_pctile_dict[gid]:.02f}')
        logger.info(f'selectivity objective: mean target peak/trough rate of gid {gid}: '
                    f'{np.mean(target_peak_rate_vector):.02f} {np.mean(target_trough_rate_vector):.02f}')
        
    opt_param_config = optimization_params(env.netclamp_config.optimize_parameters, [population], param_config_name, param_type)
    selectivity_opt_param_config = selectivity_optimization_params(env.netclamp_config.optimize_parameters, [population],
                                                                   selectivity_config_name)

    opt_targets = opt_param_config.opt_targets
    param_names = opt_param_config.param_names
    param_tuples = opt_param_config.param_tuples

    N_objectives = 2
    feature_names = ['mean_peak_rate', 'mean_trough_rate', 
                     'max_infld_rate', 'min_infld_rate', 'mean_infld_rate', 'mean_outfld_rate', 
                     'mean_peak_state', 'mean_trough_state', 'mean_outfld_state']
    feature_dtypes = [(feature_name, np.float32) for feature_name in feature_names]
    feature_dtypes.append(('trial_objs', (np.float32, (N_objectives, n_trials))))
    feature_dtypes.append(('trial_mean_infld_rate', (np.float32, (1, n_trials))))
    feature_dtypes.append(('trial_mean_outfld_rate', (np.float32, (1, n_trials))))

    def from_param_dict(params_dict):
        result = []
        for param_pattern, param_tuple in zip(param_names, param_tuples):
            result.append((param_tuple, params_dict[param_pattern]))

        return result

    def update_run_params(input_param_tuple_vals, update_param_names, update_param_tuples):
        result = []
        updated_set = set([])
        update_param_dict = dict(zip(update_param_names, update_param_tuples))
        for param_pattern, (param_tuple, param_val) in zip(param_names, input_param_tuple_vals):
            if param_pattern in update_param_dict:
                updated_set.add(param_pattern)
                result.append((param_tuple, update_param_dict[param_pattern].param_range))
            else:
                result.append((param_tuple, param_val))
        for update_param_name in update_param_dict:
            if update_param_name not in updated_set:
                result.append((update_param_dict[update_param_name], 
                               update_param_dict[update_param_name].param_range))

        return result
        
    
    def gid_firing_rate_vectors(spkdict, cell_index_set):
        rates_dict = defaultdict(list)
        for i in range(n_trials):
            spkdict1 = {}
            for gid in cell_index_set:
                if gid in spkdict[population]:
                    spkdict1[gid] = spkdict[population][gid][i]
                else:
                    spkdict1[gid] = np.asarray([], dtype=np.float32)
            spike_density_dict = spikedata.spike_density_estimate(population, spkdict1, time_bins)
            for gid in cell_index_set:
                rate_vector = spike_density_dict[gid]['rate']
                rate_vector[np.isclose(rate_vector, 0., atol=1e-3, rtol=1e-3)] = 0.
                rates_dict[gid].append(rate_vector)
                logger.info(f'selectivity objective: trial {i} firing rate min/max of gid {gid}: '
                            f'{np.min(rates_dict[gid]):.02f} / {np.max(rates_dict[gid]):.02f} Hz')

        return rates_dict

    def gid_state_values(spkdict, t_offset, n_trials, t_rec, state_recs_dict):
        t_vec = np.asarray(t_rec.to_python(), dtype=np.float32)
        t_trial_inds = get_trial_time_indices(t_vec, n_trials, t_offset)
        results_dict = {}
        filter_fun = None
        if state_filter == 'lowpass':
            filter_fun = lambda x, t: get_low_pass_filtered_trace(x, t)
        for gid in state_recs_dict:
            state_values = None
            state_recs = state_recs_dict[gid]
            assert(len(state_recs) == 1)
            rec = state_recs[0]
            vec = np.asarray(rec['vec'].to_python(), dtype=np.float32)
            if filter_fun is None:
                data = np.asarray([ vec[t_inds] for t_inds in t_trial_inds ])
            else:
                data = np.asarray([ filter_fun(vec[t_inds], t_vec[t_inds])
                                    for t_inds in t_trial_inds ])

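            # Trials can differ in length by a sample; pad shorter traces by
            # repeating the edge value so they stack into one array.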
            state_values = []
            max_len = np.max(np.asarray([len(a) for a in data]))
            for state_value_array in data:
                this_len = len(state_value_array)
                if this_len < max_len:
                    a = np.pad(state_value_array, (0, max_len-this_len), 'edge')
                else:
                    a = state_value_array
                state_values.append(a)

            results_dict[gid] = state_values
        return t_vec[t_trial_inds[0]], results_dict


    def trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs, 
                            rate_vectors, masked_rate_vectors, target_rate_vector):

        n_trials = len(rate_vectors)
        residual_inflds = []
        trial_inflds = []
        trial_outflds = []

        target_infld = target_rate_vector[infld_idxs]
        target_max_infld = np.max(target_infld)
        target_mean_trough = np.mean(target_rate_vector[trough_idxs])
        logger.info(f'selectivity objective: target max infld/mean trough of gid {gid}: '
                    f'{target_max_infld:.02f} {target_mean_trough:.02f}')
        for trial_i in range(n_trials):

            rate_vector = rate_vectors[trial_i]
            infld_rate_vector = rate_vector[infld_idxs]
            masked_rate_vector = masked_rate_vectors[trial_i]
            if outfld_idxs is None:
                outfld_rate_vector = masked_rate_vectors[trial_i]
            else:
                outfld_rate_vector = rate_vector[outfld_idxs]

            mean_peak = np.mean(rate_vector[peak_idxs])
            mean_trough = np.mean(rate_vector[trough_idxs])
            min_infld = np.min(infld_rate_vector)
            max_infld = np.max(infld_rate_vector)
            mean_infld = np.mean(infld_rate_vector)
            mean_outfld = np.mean(outfld_rate_vector)

            residual_infld = np.abs(np.sum(target_infld - infld_rate_vector))
            logger.info(f'selectivity objective: max infld/mean infld/mean peak/trough/mean outfld/residual_infld of gid {gid} trial {trial_i}: '
                        f'{max_infld:.02f} {mean_infld:.02f} {mean_peak:.02f} {mean_trough:.02f} {mean_outfld:.02f} {residual_infld:.04f}')
            residual_inflds.append(residual_infld)
            trial_inflds.append(mean_infld)
            trial_outflds.append(mean_outfld)

        trial_rate_features = [np.asarray(trial_inflds, dtype=np.float32).reshape((1, n_trials)), 
                               np.asarray(trial_outflds, dtype=np.float32).reshape((1, n_trials))]
        rate_features = [mean_peak, mean_trough, max_infld, min_infld, mean_infld, mean_outfld, ]
        #rate_constr = [ mean_peak if max_infld > 0. else -1. ]
        rate_constr = [ mean_peak - mean_trough if max_infld > 0. else -1. ]
        return (np.asarray(residual_inflds), trial_rate_features, rate_features, rate_constr)

    
    def trial_state_residuals(gid, target_outfld, t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs, state_values, masked_state_values):

        state_value_arrays = np.row_stack(state_values)
        masked_state_value_arrays = None
        if masked_state_values is not None:
            masked_state_value_arrays = np.row_stack(masked_state_values)
        
        residuals_outfld = []
        peak_inflds = []
        trough_inflds = []
        mean_outflds = []
        for i in range(state_value_arrays.shape[0]):
            state_value_array = state_value_arrays[i, :]
            peak_infld = np.mean(state_value_array[t_peak_idxs])
            trough_infld = np.mean(state_value_array[t_trough_idxs])
            mean_infld = np.mean(state_value_array[t_infld_idxs])

            masked_state_value_array = masked_state_value_arrays[i, :]
            mean_masked = np.mean(masked_state_value_array)
            residual_masked = np.mean(masked_state_value_array) - target_outfld

            mean_outfld = mean_masked
            if t_outfld_idxs is not None:
                mean_outfld = np.mean(state_value_array[t_outfld_idxs])
                
            peak_inflds.append(peak_infld)
            trough_inflds.append(trough_infld)
            mean_outflds.append(mean_outfld)
            residuals_outfld.append(residual_masked)
            logger.info(f'selectivity objective: state values of gid {gid}: '
                        f'peak/trough/mean in/mean out/masked: {peak_infld:.02f} / {trough_infld:.02f} / {mean_infld:.02f} / {mean_outfld:.02f} / residual masked: {residual_masked:.04f}')

        state_features = [np.mean(peak_inflds), np.mean(trough_inflds), np.mean(mean_outflds)]
        return (np.asarray(residuals_outfld), state_features)

    
    recording_profile = { 'label': f'optimize_selectivity.{state_variable}',
                          'section quantity': {
                              state_variable: { 'swc types': ['soma'] }
                            }
                        }
    env.recording_profile = recording_profile
    state_recs_dict = {}
    for gid in my_cell_index_set:
        state_recs_dict[gid] = record_cell(env, population, gid, recording_profile=recording_profile)

        
    def eval_problem(cell_param_dict, **kwargs):

        run_params = {population: {gid: from_param_dict(cell_param_dict[gid])
                                   for gid in my_cell_index_set}}
        masked_state_values_dict = {}
        masked_run_params = {population: { gid: update_run_params(run_params[population][gid],
                                                                  selectivity_opt_param_config.mask_param_names,
                                                                  selectivity_opt_param_config.mask_param_tuples)
                                           for gid in my_cell_index_set} }
        spkdict = run_with(env, run_params)
        rates_dict = gid_firing_rate_vectors(spkdict, my_cell_index_set)
        t_s, state_values_dict = gid_state_values(spkdict, equilibration_duration, n_trials, env.t_rec, 
                                                  state_recs_dict)

        masked_spkdict = run_with(env, masked_run_params)
        masked_rates_dict = gid_firing_rate_vectors(masked_spkdict, my_cell_index_set)
        t_s, masked_state_values_dict = gid_state_values(masked_spkdict, equilibration_duration, n_trials, env.t_rec, 
                                                         state_recs_dict)
        
        
        result = {}
        for gid in my_cell_index_set:
            infld_idxs = infld_idxs_dict[gid]
            outfld_idxs = outfld_idxs_dict[gid]
            peak_idxs = peak_idxs_dict[gid]
            trough_idxs = trough_idxs_dict[gid]
            
            target_rate_vector = target_rate_vector_dict[gid]

            peak_ranges = peak_ranges_dict[gid]
            trough_ranges = trough_ranges_dict[gid]
            infld_ranges = infld_ranges_dict[gid]
            outfld_ranges = outfld_ranges_dict[gid]
            
            t_peak_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in peak_ranges ])
            t_trough_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in trough_ranges ])
            t_infld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in infld_ranges ])
            if outfld_ranges is not None:
                t_outfld_idxs = np.concatenate([ np.where(np.logical_and(t_s >= r[0], t_s < r[1]))[0] for r in outfld_ranges ])
            else:
                t_outfld_idxs = None
            
            masked_state_values = masked_state_values_dict.get(gid, None)
            state_values = state_values_dict[gid]
            rate_vectors = rates_dict[gid]
            masked_rate_vectors = masked_rates_dict[gid]
            
            logger.info(f'selectivity objective: max rates of gid {gid}: '
                        f'{list([np.max(rate_vector) for rate_vector in rate_vectors])}')

            infld_residuals, trial_rate_features, rate_features, rate_constr = \
              trial_snr_residuals(gid, peak_idxs, trough_idxs, infld_idxs, outfld_idxs, 
                                  rate_vectors, masked_rate_vectors, target_rate_vector)
            state_residuals, state_features = trial_state_residuals(gid, state_baseline,
                                                                    t_peak_idxs, t_trough_idxs, t_infld_idxs, t_outfld_idxs,
                                                                    state_values, masked_state_values)
            trial_obj_features = np.row_stack((infld_residuals, state_residuals))
            
            if trial_regime == 'mean':
                mean_infld_residual = np.mean(infld_residuals)
                mean_state_residual = np.mean(state_residuals)
                infld_objective = mean_infld_residual
                state_objective = abs(mean_state_residual)
                logger.info(f'selectivity objective: mean infld/state residual of gid {gid}: '
                            f'{mean_infld_residual:.04f} {mean_state_residual:.04f}')
            elif trial_regime == 'best':
                min_infld_residual_index = np.argmin(infld_residuals)
                min_infld_residual = infld_residuals[min_infld_residual_index]
                infld_objective = min_infld_residual
                min_state_residual = np.min(np.abs(state_residuals))
                state_objective = min_state_residual
                logger.info(f'selectivity objective: min infld/state residual of gid {gid}: '
                            f'{min_infld_residual:.04f} {min_state_residual:.04f}')
            else:
                raise RuntimeError(f'selectivity_rate_objective: unknown trial regime {trial_regime}')

            logger.info(f"rate_features: {rate_features} state_features: {state_features} obj_features: {trial_obj_features}")

            result[gid] = (np.asarray([ infld_objective, state_objective ], 
                                      dtype=np.float32), 
                           np.array([tuple(rate_features+state_features+[trial_obj_features]+trial_rate_features)], 
                                    dtype=np.dtype(feature_dtypes)),
                           np.asarray(rate_constr, dtype=np.float32))
                           
        return result
    
    return opt_eval_fun(problem_regime, my_cell_index_set, eval_problem)
Ejemplo n.º 14
0
def main(config, config_prefix, include, forest_path, connectivity_path,
         connectivity_namespace, coords_path, coords_namespace,
         synapses_namespace, distances_namespace, resolution,
         interp_chunk_size, io_size, chunk_size, value_chunk_size, cache_size,
         write_size, verbose, dry_run, debug):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix)
    configure_hoc_env(env)

    connection_config = env.connection_config
    extent = {}

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(connectivity_path):
            input_file = h5py.File(coords_path, 'r')
            output_file = h5py.File(connectivity_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
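    # All ranks wait at the barrier below until rank 0 has created the output file.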
    comm.barrier()

    population_ranges = read_population_ranges(coords_path)[0]
    populations = sorted(list(population_ranges.keys()))

    color = 0
    if rank == 0:
        color = 1
    comm0 = comm.Split(color, 0)
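    # comm0 isolates rank 0 so that only it reads the coordinate and distance
    # attributes below; the results are broadcast to the other ranks afterwards.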

    soma_distances = {}
    soma_coords = {}
    for population in populations:
        if rank == 0:
            logger.info(f'Reading {population} coordinates...')
            coords_iter = read_cell_attributes(
                coords_path,
                population,
                comm=comm0,
                mask=set(['U Coordinate', 'V Coordinate', 'L Coordinate']),
                namespace=coords_namespace)
            distances_iter = read_cell_attributes(
                coords_path,
                population,
                comm=comm0,
                mask=set(['U Distance', 'V Distance']),
                namespace=distances_namespace)

            soma_coords[population] = {
                k: (float(v['U Coordinate'][0]), float(v['V Coordinate'][0]),
                    float(v['L Coordinate'][0]))
                for (k, v) in coords_iter
            }

            distances = {
                k: (float(v['U Distance'][0]), float(v['V Distance'][0]))
                for (k, v) in distances_iter
            }

            if len(distances) > 0:
                soma_distances[population] = distances

            gc.collect()

    comm.barrier()
    comm0.Free()

    soma_distances = comm.bcast(soma_distances, root=0)
    soma_coords = comm.bcast(soma_coords, root=0)

    forest_populations = sorted(read_population_names(forest_path))
    if (include is None) or (len(include) == 0):
        destination_populations = forest_populations
    else:
        destination_populations = []
        for p in include:
            if p in forest_populations:
                destination_populations.append(p)
    if rank == 0:
        logger.info(
            f'Generating connectivity for populations {destination_populations}...'
        )

    if len(soma_distances) == 0:
        # nsample is not among this script's arguments; 1000 is an assumed
        # sample count for constructing the distance interpolant.
        nsample = 1000
        (origin_ranges, ip_dist_u,
         ip_dist_v) = make_distance_interpolant(env,
                                                resolution=resolution,
                                                nsample=nsample)
        ip_dist = (origin_ranges, ip_dist_u, ip_dist_v)
        soma_distances = measure_distances(env,
                                           soma_coords,
                                           ip_dist,
                                           resolution=resolution)

    for destination_population in destination_populations:

        if rank == 0:
            logger.info(
                f'Generating connection probabilities for population {destination_population}...'
            )

        connection_prob = ConnectionProb(destination_population, soma_coords, soma_distances, \
                                         env.connection_extents)

        synapse_seed = int(
            env.model_config['Random Seeds']['Synapse Projection Partitions'])

        connectivity_seed = int(env.model_config['Random Seeds']
                                ['Distance-Dependent Connectivity'])
        cluster_seed = int(
            env.model_config['Random Seeds']['Connectivity Clustering'])

        if rank == 0:
            logger.info(
                f'Generating connections for population {destination_population}...'
            )

        populations_dict = env.model_config['Definitions']['Populations']
        generate_uv_distance_connections(comm,
                                         populations_dict,
                                         connection_config,
                                         connection_prob,
                                         forest_path,
                                         synapse_seed,
                                         connectivity_seed,
                                         cluster_seed,
                                         synapses_namespace,
                                         connectivity_namespace,
                                         connectivity_path,
                                         io_size,
                                         chunk_size,
                                         value_chunk_size,
                                         cache_size,
                                         write_size,
                                         dry_run=dry_run,
                                         debug=debug)
    MPI.Finalize()
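The read-on-rank-0, broadcast-to-all pattern used above keeps file I/O confined to a single rank. A minimal self-contained sketch of the same idiom with mpi4py; the population name and coordinate values are placeholders:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.rank

# Only rank 0 produces the data; every other rank starts with an empty dict.
soma_coords = {}
if rank == 0:
    soma_coords = {'GC': {0: (1.0, 2.0, 3.0)}}  # placeholder coordinates

# Collective broadcast: afterwards every rank holds an identical copy.
soma_coords = comm.bcast(soma_coords, root=0)
print(f'rank {rank}: {soma_coords}')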
Ejemplo n.º 15
0
def main(config, template_path, types_path, forest_path, connectivity_path,
         connectivity_namespace, coords_path, coords_namespace, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config, template_paths=template_path)
    configure_hoc_env(env)

    gj_config = env.gapjunctions
    gj_seed = int(env.model_config['Random Seeds']['Gap Junctions'])

    soma_coords = {}

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(connectivity_path):
            input_file = h5py.File(types_path, 'r')
            output_file = h5py.File(connectivity_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()

    population_ranges = read_population_ranges(coords_path)[0]
    populations = sorted(population_ranges.keys())

    if rank == 0:
        logger.info('Reading population coordinates...')

    soma_distances = {}
    for population in populations:
        coords_iter = bcast_cell_attributes(coords_path,
                                            population,
                                            0,
                                            namespace=coords_namespace)

        soma_coords[population] = {
            k:
            (v['X Coordinate'][0], v['Y Coordinate'][0], v['Z Coordinate'][0])
            for (k, v) in coords_iter
        }

        gc.collect()

    generate_gj_connections(env,
                            forest_path,
                            soma_coords,
                            gj_config,
                            gj_seed,
                            connectivity_namespace,
                            connectivity_path,
                            io_size,
                            chunk_size,
                            value_chunk_size,
                            cache_size,
                            dry_run=dry_run)

    MPI.Finalize()
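The guarded /H5Types copy at the top of this and the preceding script is identical boilerplate; it could be factored into a small helper. A sketch assuming only h5py, with a helper name of our choosing:

import os
import h5py

def ensure_h5types(input_path, output_path):
    """Create output_path with /H5Types copied from input_path,
    unless the output file already exists."""
    if not os.path.isfile(output_path):
        with h5py.File(input_path, 'r') as input_file, \
             h5py.File(output_path, 'w') as output_file:
            input_file.copy('/H5Types', output_file)

As in the scripts above, only rank 0 should call this, followed by comm.barrier() before any collective reads or writes.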
Ejemplo n.º 16
0
def main(config_file, population, gid, template_paths, dataset_prefix,
         config_prefix, data_file, load_synapses, syn_types, syn_sources,
         syn_source_threshold, font_size, bgcolor, colormap, verbose):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    if dataset_prefix is None and data_file is None:
        raise RuntimeError(
            'Either --dataset-prefix or --data-file must be provided.')

    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)

    if env.data_file_path is None:
        env.data_file_path = data_file
        env.load_celltypes()

    ## Determine if a mechanism configuration file exists for this cell type
    if 'mech_file_path' in env.celltypes[population]:
        mech_file_path = env.celltypes[population]['mech_file_path']
    else:
        mech_file_path = None

    ## Determine if correct_for_spines flag has been specified for this cell type
    synapse_config = env.celltypes[population]['synapses']
    if 'correct_for_spines' in synapse_config:
        correct_for_spines_flag = synapse_config['correct_for_spines']
    else:
        correct_for_spines_flag = False

    logger.info('loading cell %i' % gid)

    load_weights = False
    biophys_cell = make_biophys_cell(env,
                                     population,
                                     gid,
                                     load_synapses=load_synapses,
                                     load_weights=load_weights,
                                     load_edges=load_synapses,
                                     mech_file_path=mech_file_path)

    init_biophysics(biophys_cell,
                    reset_cable=True,
                    correct_cm=correct_for_spines_flag,
                    correct_g_pas=correct_for_spines_flag,
                    env=env)
    report_topology(biophys_cell, env)

    if len(syn_types) == 0:
        syn_types = None
    else:
        syn_types = list(syn_types)
    if len(syn_sources) == 0:
        syn_sources = None
    else:
        syn_sources = list(syn_sources)

    plot.plot_biophys_cell_tree(env,
                                biophys_cell,
                                saveFig=True,
                                syn_source_threshold=syn_source_threshold,
                                synapse_filters={
                                    'syn_types': syn_types,
                                    'sources': syn_sources
                                },
                                bgcolor=bgcolor,
                                colormap=colormap)
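The empty-tuple-to-None normalization of syn_types and syn_sources above turns unset multi-value CLI options into "no filter". The same logic as a small helper; the helper name and the 'MPP' source label are hypothetical:

def list_or_none(values):
    """Return None for an empty multi-value option (meaning 'no filter'),
    otherwise a plain list of the given values."""
    return list(values) if len(values) > 0 else None

assert list_or_none(()) is None            # no filtering
assert list_or_none(('MPP',)) == ['MPP']   # restrict to one source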
Ejemplo n.º 17
0
def main(config, config_prefix, template_path, output_path, forest_path,
         populations, distribution, io_size, chunk_size, value_chunk_size,
         cache_size, write_size, verbose, dry_run):
    """

    :param config:
    :param config_prefix:
    :param template_path:
    :param forest_path:
    :param populations:
    :param distribution:
    :param io_size:
    :param chunk_size:
    :param value_chunk_size:
    :param cache_size:
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    env = Env(comm=MPI.COMM_WORLD,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=template_path)

    configure_hoc_env(env)

    if io_size == -1:
        io_size = comm.size

    if output_path is None:
        output_path = forest_path

    if not dry_run:
        if rank == 0:
            if not os.path.isfile(output_path):
                input_file = h5py.File(forest_path, 'r')
                output_file = h5py.File(output_path, 'w')
                input_file.copy('/H5Types', output_file)
                input_file.close()
                output_file.close()
        comm.barrier()

    (pop_ranges, _) = read_population_ranges(forest_path, comm=comm)
    start_time = time.time()
    syn_stats = {}
    for population in populations:
        logger.info('Rank %i population: %s' % (rank, population))
        (population_start, _) = pop_ranges[population]
        template_class = load_cell_template(env, population)

        density_dict = env.celltypes[population]['synapses']['density']
        layer_set_dict = defaultdict(set)
        swc_set_dict = defaultdict(set)
        for sec_name, sec_dict in viewitems(density_dict):
            for syn_type, syn_dict in viewitems(sec_dict):
                swc_set_dict[syn_type].add(env.SWC_Types[sec_name])
                for layer_name in syn_dict:
                    if layer_name != 'default':
                        layer = env.layers[layer_name]
                        layer_set_dict[syn_type].add(layer)

        syn_stats_dict = { 'section': defaultdict(lambda: { 'excitatory': 0, 'inhibitory': 0 }), \
                           'layer': defaultdict(lambda: { 'excitatory': 0, 'inhibitory': 0 }), \
                           'swc_type': defaultdict(lambda: { 'excitatory': 0, 'inhibitory': 0 }), \
                           'total': { 'excitatory': 0, 'inhibitory': 0 } }

        count = 0
        gid_count = 0
        synapse_dict = {}
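        # NeuroH5TreeGen yields (gid, morph_dict) pairs distributed across ranks;
        # gid is None on iterations where this rank has no tree to process.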
        for gid, morph_dict in NeuroH5TreeGen(forest_path,
                                              population,
                                              io_size=io_size,
                                              comm=comm,
                                              topology=True):
            local_time = time.time()
            if gid is not None:
                logger.info('Rank %i gid: %i' % (rank, gid))
                cell = cells.make_neurotree_cell(template_class,
                                                 neurotree_dict=morph_dict,
                                                 gid=gid)
                cell_sec_dict = {
                    'apical': (cell.apical, None),
                    'basal': (cell.basal, None),
                    'soma': (cell.soma, None),
                    'ais': (cell.ais, None),
                    'hillock': (cell.hillock, None)
                }
                cell_secidx_dict = {
                    'apical': cell.apicalidx,
                    'basal': cell.basalidx,
                    'soma': cell.somaidx,
                    'ais': cell.aisidx,
                    'hillock': cell.hilidx
                }

                random_seed = env.model_config['Random Seeds'][
                    'Synapse Locations'] + gid
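                # Offsetting the base seed by gid keeps synapse placement reproducible per cell.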
                if distribution == 'uniform':
                    syn_dict, seg_density_per_sec = synapses.distribute_uniform_synapses(
                        random_seed, env.Synapse_Types, env.SWC_Types,
                        env.layers, density_dict, morph_dict, cell_sec_dict,
                        cell_secidx_dict)

                elif distribution == 'poisson':
                    syn_dict, seg_density_per_sec = synapses.distribute_poisson_synapses(
                        random_seed, env.Synapse_Types, env.SWC_Types,
                        env.layers, density_dict, morph_dict, cell_sec_dict,
                        cell_secidx_dict)
                else:
                    raise Exception('Unknown distribution type: %s' %
                                    distribution)

                synapse_dict[gid] = syn_dict
                this_syn_stats = update_syn_stats(env, syn_stats_dict,
                                                  syn_dict)
                check_syns(gid, morph_dict, this_syn_stats,
                           seg_density_per_sec, layer_set_dict, swc_set_dict,
                           env, logger)

                del cell
                num_syns = len(synapse_dict[gid]['syn_ids'])
                logger.info(
                    'Rank %i took %i s to compute %d synapse locations for %s gid: %i'
                    % (rank, time.time() - local_time, num_syns, population,
                       gid))
                logger.info(
                    '%s gid %i synapses: %s' %
                    (population, gid, local_syn_summary(this_syn_stats)))
                gid_count += 1
            else:
                logger.info('Rank %i gid is None' % rank)
            if (not dry_run) and (gid_count % write_size == 0):
                append_cell_attributes(output_path,
                                       population,
                                       synapse_dict,
                                       namespace='Synapse Attributes',
                                       comm=comm,
                                       io_size=io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size,
                                       cache_size=cache_size)
                synapse_dict = {}
                gc.collect()
            count += 1

        syn_stats[population] = syn_stats_dict

        if not dry_run:
            append_cell_attributes(output_path,
                                   population,
                                   synapse_dict,
                                   namespace='Synapse Attributes',
                                   comm=comm,
                                   io_size=io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size,
                                   cache_size=cache_size)

        # allgather (rather than gather) so that every participating rank can
        # compute the total below when calling global_syn_summary.
        global_count = comm.allgather(gid_count)

        if gid_count > 0:
            color = 1
        else:
            color = 0

        comm0 = comm.Split(color, 0)
        if color == 1:
            summary = global_syn_summary(comm0,
                                         syn_stats,
                                         np.sum(global_count),
                                         root=0)
            if rank == 0:
                logger.info(
                    'target: %s, %i ranks took %i s to compute synapse locations for %i cells'
                    % (population, comm.size, time.time() - start_time,
                       np.sum(global_count)))
                logger.info(summary)
        comm0.Free()
        comm.barrier()

    MPI.Finalize()
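The write loop above buffers per-gid attributes and flushes every write_size cells, with one final append for the remainder. The batching logic in isolation; append_batch is a hypothetical stand-in for append_cell_attributes:

def append_batch(batch):
    print(f'writing {len(batch)} items')  # stand-in for the collective write

write_size = 10
buffered = {}
for gid in range(25):
    buffered[gid] = {'syn_ids': []}  # placeholder attributes
    if len(buffered) % write_size == 0:
        append_batch(buffered)       # flush a full batch
        buffered = {}
append_batch(buffered)               # flush the final partial batch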