Example 1
    def __init__(self, destination_population, soma_coords, soma_distances, extents):
        """
        Warning: This method does not produce an absolute probability. It must be normalized so that the total area
        (volume) under the distribution is 1 before sampling.
        :param destination_population: post-synaptic population name
        :param soma_coords: a dictionary that contains per-population dicts of u, v, l coordinates of cell somas
        :param soma_distances: a dictionary that contains per-population dicts of u, v distances of cell somas
        :param extents: dict of the form {source: {layer: {'width': (tuple of float), 'offset': (tuple of float)}}}
        """
        self.destination_population = destination_population
        self.soma_coords = soma_coords
        self.soma_distances = soma_distances
        self.p_dist = defaultdict(dict)
        self.width = defaultdict(dict)
        self.offset = defaultdict(dict)
        self.scale_factor = defaultdict(dict)

        for source_population, layer_extents in viewitems(extents):

            for layer, layer_extent in viewitems(layer_extents):

                extent_width = layer_extent['width']
                if 'offset' in layer_extent:
                    extent_offset = layer_extent['offset']
                else:
                    extent_offset = (0., 0.)

                u_extent = (float(extent_width[0]) / 2.0) - float(extent_offset[0])
                v_extent = (float(extent_width[1]) / 2.0) - float(extent_offset[1])
                self.width[source_population][layer] = {'u': u_extent, 'v': v_extent}

                self.scale_factor[source_population][layer] = \
                    {axis: self.width[source_population][layer][axis] / 3. \
                     for axis in self.width[source_population][layer]}

                self.offset[source_population][layer] = {'u': float(extent_offset[0]),
                                                         'v': float(extent_offset[1])}

                self.p_dist[source_population][layer] = \
                    (lambda source_population, layer: \
                         np.vectorize(lambda distance_u, distance_v: \
                                          (norm.pdf(np.abs(distance_u) - self.offset[source_population][layer]['u'], \
                                                    scale=self.scale_factor[source_population][layer]['u']) * \
                                           norm.pdf(np.abs(distance_v) - self.offset[source_population][layer]['v'], \
                                                    scale=self.scale_factor[source_population][layer]['v'])), \
                                      otypes=[float]))(source_population, layer)

                logger.info(f"population {source_population}: layer: {layer}: \n"
                            f"u width: {self.width[source_population][layer]['u']}\n"
                            f"v width: {self.width[source_population][layer]['v']}\n"
                            f"u scale_factor: {self.scale_factor[source_population][layer]['u']}\n"
                            f"v scale_factor: {self.scale_factor[source_population][layer]['v']}\n")
Example 2
def import_celltypes(celltype_path, output_path):

    import csv

    population_dict = {}

    with open(celltype_path, mode='r') as infile:

        reader = csv.DictReader(infile, delimiter="\t")
        for row in reader:
            celltype = row['celltype']
            type_index = int(row['typeIndex'])
            range_start = int(row['rangeStart'])
            range_end = int(row['rangeEnd'])
            count = range_end - range_start + 1
            population_dict[celltype] = (type_index, count)

    populations = []
    for pop_name, pop_info in viewitems(population_dict):
        pop_idx = pop_info[0]
        pop_count = pop_info[1]
        populations.append((pop_name, pop_idx, pop_count))
    populations.sort(key=lambda x: x[1])
    min_pop_idx = populations[0][1]

    # create an HDF5 enumerated type for the population label
    mapping = {name: idx for name, idx, count in populations}
    dt_population_labels = h5py.special_dtype(enum=(np.uint16, mapping))

    with h5py.File(output_path, "x") as h5:

        h5[path_population_labels] = dt_population_labels

        dt_populations = np.dtype([("Start", np.uint64), ("Count", np.uint32),
                                   ("Population",
                                    h5[path_population_labels].dtype)])
        h5[path_population_range] = dt_populations

        # create an HDF5 compound type for population ranges
        dt = h5[path_population_range].dtype

        g = h5_get_group(h5, grp_h5types)

        dset = h5_get_dataset(g,
                              grp_populations,
                              maxshape=(len(populations), ),
                              dtype=dt)
        dset.resize((len(populations), ))
        a = np.zeros(len(populations), dtype=dt)
        start = 0
        for name, idx, count in populations:
            a[idx - min_pop_idx]["Start"] = start
            a[idx - min_pop_idx]["Count"] = count
            a[idx - min_pop_idx]["Population"] = idx
            start += count
        dset[:] = a

    return populations
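A hypothetical usage sketch (file contents, population names, and paths are invented): import_celltypes expects a tab-separated file whose header carries the columns read above, i.e. celltype, typeIndex, rangeStart, rangeEnd.

import csv
import os
import tempfile

# Two invented populations: GC occupies gids 0-99, MC occupies gids 100-149.
rows = [("GC", "0", "0", "99"), ("MC", "1", "100", "149")]
fd, celltype_path = tempfile.mkstemp(suffix=".tsv")
with os.fdopen(fd, "w", newline="") as f:
    w = csv.writer(f, delimiter="\t")
    w.writerow(["celltype", "typeIndex", "rangeStart", "rangeEnd"])
    w.writerows(rows)

# import_celltypes(celltype_path, "celltypes.h5") would then return
# [("GC", 0, 100), ("MC", 1, 50)], sorted by type index.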
Example 3
    def filter_by_distance(self, destination_gid, source_population, source_layer):
        """
        Given the id of a target neuron, returns the distances along u and v
        and the gids of source neurons whose axons potentially contact the target neuron.

        :param destination_gid: int
        :param source_population: string
        :param source_layer: layer key into the per-population width/offset configuration ('default' is used as a fallback)
        :return: tuple of (destination u, destination v, source u array, source v array,
                 u distance array, v distance array, source gid array)
        """
        destination_coords = self.soma_coords[self.destination_population][destination_gid]
        source_coords = self.soma_coords[source_population]

        destination_distances = self.soma_distances[self.destination_population][destination_gid]

        source_distances = self.soma_distances[source_population]

        destination_u, destination_v, destination_l = destination_coords
        destination_distance_u, destination_distance_v = destination_distances

        distance_u_lst = []
        distance_v_lst = []
        source_u_lst = []
        source_v_lst = []
        source_gid_lst = []

        if source_layer in self.width[source_population]:
            layer_key = source_layer
        elif 'default' in self.width[source_population]:
            layer_key = 'default'
        else:
            raise RuntimeError(f'connection_generator.filter_by_distance: gid {destination_gid}: missing configuration for {source_population} layer {source_layer}')

        source_width = self.width[source_population][layer_key]
        source_offset = self.offset[source_population][layer_key]

        max_distance_u = source_width['u'] + source_offset['u']
        max_distance_v = source_width['v'] + source_offset['v']

        for (source_gid, coords) in viewitems(source_coords):

            source_u, source_v, source_l = coords

            source_distance_u, source_distance_v = source_distances[source_gid]

            distance_u = abs(destination_distance_u - source_distance_u)
            distance_v = abs(destination_distance_v - source_distance_v)

            if ((max_distance_u - distance_u) > 0.0) and ((max_distance_v - distance_v) > 0.0):
                source_u_lst.append(source_u)
                source_v_lst.append(source_v)
                distance_u_lst.append(distance_u)
                distance_v_lst.append(distance_v)
                source_gid_lst.append(source_gid)

        return destination_u, destination_v, np.asarray(source_u_lst), np.asarray(source_v_lst), np.asarray(
            distance_u_lst), np.asarray(distance_v_lst), np.asarray(source_gid_lst, dtype=np.uint32)
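The per-gid Python loop above is straightforward to vectorize once the candidate distances are stacked into arrays. A sketch of that alternative (function name and parameters are illustrative; it assumes parallel arrays over candidate sources):

import numpy as np

def filter_by_distance_arrays(dest_du, dest_dv, src_gids, src_du, src_dv,
                              max_du, max_dv):
    """src_gids, src_du, src_dv are parallel arrays over candidate sources."""
    distance_u = np.abs(dest_du - np.asarray(src_du))
    distance_v = np.abs(dest_dv - np.asarray(src_dv))
    in_range = (distance_u < max_du) & (distance_v < max_dv)
    return (distance_u[in_range], distance_v[in_range],
            np.asarray(src_gids, dtype=np.uint32)[in_range])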
Example 4
def spike_bin_counts(spkdict, time_bins):
    bin_dict = {}
    for (ind, lst) in viewitems(spkdict):

        if len(lst) > 0:
            spkts = np.asarray(lst, dtype=np.float32)
            bins, bin_edges = np.histogram(spkts, bins=time_bins)

            bin_dict[ind] = bins

    return bin_dict
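A small usage illustration with made-up spike times, assuming spike_bin_counts and its viewitems helper are importable; counts follow numpy.histogram semantics over the given bin edges:

import numpy as np

spkdict = {0: [1.0, 2.5, 12.0], 1: [15.0]}
time_bins = np.arange(0.0, 30.0, 10.0)      # bin edges at 0, 10, 20 ms
counts = spike_bin_counts(spkdict, time_bins)
# counts[0] -> array([2, 1]); counts[1] -> array([0, 1])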
Example 5
def choose_synapse_projection(ranstream_syn, syn_layer, swc_type, syn_type, population_dict, projection_synapse_dict,
                              log=False):
    """
    Given a synapse projection, SWC synapse location, and synapse type,
    chooses a projection from the given projection dictionary based on
    1) whether the projection properties match the given synapse
    properties and 2) random choice between all the projections that
    satisfy the given criteria.

    :param ranstream_syn: random state object
    :param syn_layer: synapse layer
    :param swc_type: SWC location for synapse (soma, axon, apical, basal)
    :param syn_type: synapse type (excitatory, inhibitory, neuromodulatory)
    :param population_dict: mapping of population names to population indices
    :param projection_synapse_dict: mapping of projection names to a tuple of the form: <type, layers, swc sections, proportions, contacts>

    """
    ivd = {v: k for k, v in viewitems(population_dict)}
    projection_lst = []
    projection_prob_lst = []
    for k, (syn_config_type, syn_config_layers, syn_config_sections,
            syn_config_proportions, syn_config_contacts) in viewitems(projection_synapse_dict):
        if (syn_type == syn_config_type) and (swc_type in syn_config_sections):
            ord_indices = list_find_all(lambda x: x == swc_type, syn_config_sections)
            for ord_index in ord_indices:
                if syn_layer == syn_config_layers[ord_index]:
                    projection_lst.append(population_dict[k])
                    projection_prob_lst.append(syn_config_proportions[ord_index])
    if len(projection_lst) > 1:
        candidate_projections = np.asarray(projection_lst)
        candidate_probs = np.asarray(projection_prob_lst)
        if log:
            logger.info(f"candidate_projections: {candidate_projections} candidate_probs: {candidate_probs}")
        projection = ranstream_syn.choice(candidate_projections, 1, p=candidate_probs)[0]
    elif len(projection_lst) > 0:
        projection = projection_lst[0]
    else:
        projection = None

    if projection is None:
        logger.error(f'Projection is none for syn_type {syn_type}, syn_layer {syn_layer} swc_type {swc_type}\n'
                     f'projection synapse dict: {pprint.pformat(projection_synapse_dict)}')
        return None

    return ivd[projection]
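A toy illustration with invented population and projection entries; each projection entry follows the five-element form unpacked above (type, layers, sections, proportions, contacts):

import numpy as np

population_dict = {"MPP": 0, "LPP": 1}
projection_synapse_dict = {
    "MPP": (0, [2], [4], [0.7], 1),   # excitatory, layer 2, apical dendrite (swc type 4)
    "LPP": (0, [2], [4], [0.3], 1),
}
ranstream_syn = np.random.RandomState(42)
# choose_synapse_projection(ranstream_syn, 2, 4, 0,
#                           population_dict, projection_synapse_dict)
# would return "MPP" with probability 0.7 and "LPP" with probability 0.3.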
Example 6
def import_morphology_from_hoc(cell, hoc_cell, section_content=None):
    """
    Append sections from an existing instance of a NEURON cell template to a Python cell wrapper.
    :param cell: :class:'BiophysCell'
    :param hoc_cell: :class:'h.hocObject': instance of a NEURON cell template
    :param section_content: optional dict mapping section indexes to content to be attached to each section
    """
    sec_info_dict = {}
    root_sec = None
    for sec_type, sec_index_list in viewitems(default_hoc_sec_lists):
        hoc_sec_attr_name = sec_type
        if not hasattr(hoc_cell, hoc_sec_attr_name):
            hoc_sec_attr_name = f'{sec_type}_list'
        if hasattr(hoc_cell, hoc_sec_attr_name) and (getattr(
                hoc_cell, hoc_sec_attr_name) is not None):
            sec_list = list(getattr(hoc_cell, hoc_sec_attr_name))
            if hasattr(hoc_cell, sec_index_list):
                sec_indexes = list(getattr(hoc_cell, sec_index_list))
            else:
                raise AttributeError(
                    'import_morphology_from_hoc: %s is not an attribute of the hoc cell'
                    % sec_index_list)
            if sec_type == 'soma':
                root_sec = sec_list[0]
            for sec, index in zip(sec_list, sec_indexes):
                if section_content is not None:
                    sec_info_dict[sec] = {
                        'section_type': sec_type,
                        'section_index': int(index),
                        'section_content': section_content[index]
                    }
                else:
                    sec_info_dict[sec] = {
                        'section_type': sec_type,
                        'section_index': int(index)
                    }
    if root_sec:
        insert_section_tree(cell, [root_sec], sec_info_dict)
    else:
        raise RuntimeError(
            'import_morphology_from_hoc: unable to locate root section')
Example 7
def write_input_cell_selection(env,
                               input_sources,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Writes out predefined spike trains when only a subset of the network is instantiated.

    :param env: an instance of the `Env` class
    :param input_sources: a dictionary of the form { pop_name: gid_sources }
    """

    # copy to avoid mutating a shared (default) argument
    write_kwds = {} if write_kwds is None else dict(write_kwds)
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    input_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.celltypes.keys())
    else:
        pop_names = populations

    for pop_name, gid_range in sorted(viewitems(input_sources)):

        gc.collect()

        if pop_name not in pop_names:
            continue

        spikes_output_dict = {}

        if (env.cell_selection is not None) and (pop_name
                                                 in env.cell_selection):
            local_gid_range = gid_range.difference(
                set(env.cell_selection[pop_name]))
        else:
            local_gid_range = gid_range

        gid_range = env.comm.allreduce(local_gid_range, op=mpi_op_set_union)
        this_gid_range = set([])
        for i, gid in enumerate(gid_range):
            if i % nhosts == rank:
                this_gid_range.add(gid)

        has_spike_train = False
        spike_input_source_loc = []
        if (env.spike_input_attribute_info is not None) and (env.spike_input_ns
                                                             is not None):
            if (pop_name in env.spike_input_attribute_info) and \
                    (env.spike_input_ns in env.spike_input_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (env.spike_input_path, env.spike_input_ns))
        if (env.cell_attribute_info is not None) and (env.spike_input_ns
                                                      is not None):
            if (pop_name in env.cell_attribute_info) and \
                    (env.spike_input_ns in env.cell_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (input_file_path, env.spike_input_ns))

        if rank == 0:
            logger.info(
                '*** Reading spike trains for population %s: %d cells: has_spike_train = %s'
                % (pop_name, len(this_gid_range), str(has_spike_train)))

        if has_spike_train:

            vecstim_attr_set = set(['t'])
            if env.spike_input_attr is not None:
                vecstim_attr_set.add(env.spike_input_attr)
            if 'spike train' in env.celltypes[pop_name]:
                vecstim_attr_set.add(
                    env.celltypes[pop_name]['spike train']['attribute'])

            cell_spikes_iters = [ scatter_read_cell_attribute_selection(input_path, pop_name, \
                                                                        list(this_gid_range), \
                                                                        namespace=input_ns, \
                                                                        mask=vecstim_attr_set, \
                                                                        comm=env.comm, io_size=env.io_size)
                                  for (input_path, input_ns) in spike_input_source_loc ]

            for cell_spikes_iter in cell_spikes_iters:
                spikes_output_dict.update(dict(list(cell_spikes_iter)))

        if rank == 0:
            logger.info('*** Writing spike trains for population %s: %s' %
                        (pop_name, str(spikes_output_dict)))


        write_cell_attributes(write_selection_file_path, pop_name, spikes_output_dict,  \
                              namespace=env.spike_input_ns, **write_kwds)
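The gid partitioning above follows a simple round-robin rule: gid number i in the globally agreed range goes to the rank where i % nhosts == rank, so each rank handles a disjoint subset. A standalone sketch of that rule (sorting added here for determinism):

def round_robin(gids, rank, nhosts):
    return {gid for i, gid in enumerate(sorted(gids)) if i % nhosts == rank}

# round_robin(range(10), rank=0, nhosts=4) -> {0, 4, 8}
# round_robin(range(10), rank=1, nhosts=4) -> {1, 5, 9}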
Example 8
def write_connection_selection(env,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Loads a NeuroH5 connectivity file, and writes the corresponding
    synapse and network connection mechanisms for the selected postsynaptic cells.

    :param env: an instance of the `Env` class
    """

    # copy to avoid mutating a shared (default) argument
    write_kwds = {} if write_kwds is None else dict(write_kwds)
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    connectivity_file_path = env.connectivity_file_path
    forest_file_path = env.forest_file_path
    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())
    syn_attrs = env.synapse_attributes

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    input_sources = {pop_name: set([]) for pop_name in env.celltypes}

    for (postsyn_name, presyn_names) in sorted(viewitems(env.projection_dict)):

        gc.collect()

        if rank == 0:
            logger.info('*** Writing connection selection of population %s' %
                        (postsyn_name))

        if postsyn_name not in pop_names:
            continue

        gid_range = [
            gid for i, gid in enumerate(env.cell_selection[postsyn_name])
            if i % nhosts == rank
        ]

        synapse_config = env.celltypes[postsyn_name]['synapses']

        weight_dicts = []
        has_weights = False
        if 'weights' in synapse_config:
            has_weights = True
            weight_dicts = synapse_config['weights']

        if rank == 0:
            logger.info('*** Reading synaptic attributes for population %s' %
                        (postsyn_name))

        syn_attributes_iter = scatter_read_cell_attribute_selection(
            forest_file_path,
            postsyn_name,
            selection=gid_range,
            namespace='Synapse Attributes',
            comm=env.comm,
            io_size=env.io_size)

        syn_attributes_output_dict = dict(list(syn_attributes_iter))
        write_cell_attributes(write_selection_file_path,
                              postsyn_name,
                              syn_attributes_output_dict,
                              namespace='Synapse Attributes',
                              **write_kwds)
        del syn_attributes_output_dict
        del syn_attributes_iter

        if has_weights:

            for weight_dict in weight_dicts:

                weights_namespaces = weight_dict['namespace']

                if rank == 0:
                    logger.info(
                        '*** Reading synaptic weights of population %s from namespaces %s'
                        % (postsyn_name, str(weights_namespaces)))

                for weights_namespace in weights_namespaces:
                    syn_weights_iter = scatter_read_cell_attribute_selection(
                        forest_file_path,
                        postsyn_name,
                        namespace=weights_namespace,
                        selection=gid_range,
                        comm=env.comm,
                        io_size=env.io_size)

                    weight_attributes_output_dict = dict(
                        list(syn_weights_iter))
                    write_cell_attributes(write_selection_file_path,
                                          postsyn_name,
                                          weight_attributes_output_dict,
                                          namespace=weights_namespace,
                                          **write_kwds)
                    del weight_attributes_output_dict
                    del syn_weights_iter

        logger.info(
            '*** Rank %i: reading connectivity selection from file %s for postsynaptic population: %s: selection: %s'
            % (rank, connectivity_file_path, postsyn_name, str(gid_range)))

        (graph, attr_info) = scatter_read_graph_selection(connectivity_file_path, selection=gid_range, \
                                                          projections=[ (presyn_name, postsyn_name) for presyn_name in sorted(presyn_names) ], \
                                                          comm=env.comm, io_size=env.io_size, namespaces=['Synapses', 'Connections'])

        for presyn_name in sorted(presyn_names):
            gid_dict = {}
            edge_count = 0
            node_count = 0
            if postsyn_name in graph:

                if postsyn_name in attr_info and presyn_name in attr_info[
                        postsyn_name]:
                    edge_attr_info = attr_info[postsyn_name][presyn_name]
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                       (presyn_name, postsyn_name))

                if 'Synapses' in edge_attr_info and \
                        'syn_id' in edge_attr_info['Synapses'] and \
                        'Connections' in edge_attr_info and \
                        'distance' in edge_attr_info['Connections']:
                    syn_id_attr_index = edge_attr_info['Synapses']['syn_id']
                    distance_attr_index = edge_attr_info['Connections'][
                        'distance']
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                           (presyn_name, postsyn_name))

                edge_iter = compose_iter(lambda edgeset: input_sources[presyn_name].update(edgeset[1][0]), \
                                         graph[postsyn_name][presyn_name])
                for (postsyn_gid, edges) in edge_iter:

                    presyn_gids, edge_attrs = edges
                    edge_syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
                    edge_dists = edge_attrs['Connections'][distance_attr_index]

                    gid_dict[postsyn_gid] = (presyn_gids, {
                        'Synapses': {
                            'syn_id': edge_syn_ids
                        },
                        'Connections': {
                            'distance': edge_dists
                        }
                    })
                    edge_count += len(presyn_gids)
                    node_count += 1

            env.comm.barrier()
            logger.info(
                '*** Rank %d: Writing projection %s -> %s selection: %d nodes, %d edges'
                % (rank, presyn_name, postsyn_name, node_count, edge_count))
            write_graph(write_selection_file_path, \
                        src_pop_name=presyn_name, dst_pop_name=postsyn_name, \
                        edges=gid_dict, comm=env.comm, io_size=env.io_size)
            env.comm.barrier()

    return input_sources
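For reference, the per-projection edges argument assembled above for write_graph maps each postsynaptic gid to a pair of (presynaptic gids, per-namespace edge attributes). A shape sketch with made-up gids and values:

import numpy as np

gid_dict = {
    1000: (np.asarray([10, 11, 12], dtype=np.uint32),          # presynaptic gids
           {"Synapses":    {"syn_id":   np.asarray([0, 1, 2])},
            "Connections": {"distance": np.asarray([50.0, 75.0, 120.0])}}),
}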
Example 9
def make_h5types(env, output_path, gap_junctions=False):
    populations = []
    for pop_name, pop_idx in viewitems(env.Populations):
        layer_counts = env.geometry['Cell Distribution'][pop_name]
        pop_count = 0
        for layer_name, layer_count in viewitems(layer_counts):
            pop_count += layer_count
        populations.append((pop_name, pop_idx, pop_count))
    populations.sort(key=lambda x: x[1])
    min_pop_idx = populations[0][1]

    projections = []
    if gap_junctions:
        for (post, pre), connection_dict in viewitems(env.gapjunctions):
            projections.append((env.Populations[pre], env.Populations[post]))
    else:
        for post, connection_dict in viewitems(env.connection_config):
            for pre, _ in viewitems(connection_dict):
                projections.append(
                    (env.Populations[pre], env.Populations[post]))

    # create an HDF5 enumerated type for the population label
    mapping = {name: idx for name, idx in viewitems(env.Populations)}
    dt_population_labels = h5py.special_dtype(enum=(np.uint16, mapping))

    with h5py.File(output_path, "a") as h5:

        h5[path_population_labels] = dt_population_labels

        dt_populations = np.dtype([("Start", np.uint64), ("Count", np.uint32),
                                   ("Population",
                                    h5[path_population_labels].dtype)])
        h5[path_population_range] = dt_populations

        # create an HDF5 compound type for population ranges
        dt = h5[path_population_range].dtype

        g = h5_get_group(h5, grp_h5types)

        dset = h5_get_dataset(g,
                              grp_populations,
                              maxshape=(len(populations), ),
                              dtype=dt)
        dset.resize((len(populations), ))
        a = np.zeros(len(populations), dtype=dt)
        start = 0
        for name, idx, count in populations:
            a[idx - min_pop_idx]["Start"] = start
            a[idx - min_pop_idx]["Count"] = count
            a[idx - min_pop_idx]["Population"] = idx
            start += count
        dset[:] = a

        dt_projections = np.dtype([
            ("Source", h5[path_population_labels].dtype),
            ("Destination", h5[path_population_labels].dtype)
        ])

        h5[path_population_projections] = dt_projections

        dt = h5[path_population_projections].dtype
        dset = h5_get_dataset(g,
                              grp_valid_population_projections,
                              maxshape=(len(projections), ),
                              dtype=dt)
        dset.resize((len(projections), ))
        a = np.zeros(len(projections), dtype=dt)
        for i, prj in enumerate(projections):
            src, dst = prj
            a[i]["Source"] = int(src)
            a[i]["Destination"] = int(dst)

        dset[:] = a

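The enum/compound pattern above is the core of the H5Types layout. A minimal standalone illustration (file, dataset names, and values are invented):

import h5py
import numpy as np

mapping = {"GC": 0, "MC": 1}
dt_labels = h5py.special_dtype(enum=(np.uint16, mapping))
with h5py.File("h5types_demo.h5", "w") as h5:
    h5["population_labels"] = dt_labels        # commit a named enum type
    dt_range = np.dtype([("Start", np.uint64), ("Count", np.uint32),
                         ("Population", h5["population_labels"].dtype)])
    a = np.zeros(2, dtype=dt_range)
    a[0] = (0, 100, 0)    # GC: gids 0-99
    a[1] = (100, 50, 1)   # MC: gids 100-149
    h5.create_dataset("Populations", data=a)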
Example 10
def import_spikeraster(celltype_path,
                       spikeraster_path,
                       output_path,
                       output_npy=False,
                       namespace="Spike Data",
                       progress=False,
                       comm=None):

    if progress:
        import tqdm

    if comm is None:
        comm = MPI.COMM_WORLD

    populations = import_celltypes(celltype_path, output_path)
    n_pop = len(populations)

    pop_counts = [count for _, _, count in populations]
    # bin boundaries between consecutive populations (cumulative gid counts)
    pop_range_bins = list(np.cumsum(pop_counts))[:-1]

    logger.info(
        f"populations: {populations} total: {sum(pop_counts)} pop_range_bins: {pop_range_bins}"
    )

    logger.info(f"Reading spike data from file {spikeraster_path}...")

    if spikeraster_path.endswith('.npy'):
        spike_array = np.load(spikeraster_path)
    else:
        spike_array = np.loadtxt(spikeraster_path,
                                 dtype=np.dtype([("time", np.float32),
                                                 ("gid", np.uint32)]))

    if output_npy:
        np.save(f'{spikeraster_path}.npy', spike_array)

    logger.info(f"Done reading spike data from file {spikeraster_path}")

    gid_array = spike_array['gid']
    gid_bins = np.digitize(gid_array, np.asarray(pop_range_bins))

    pop_spk_dict = defaultdict(lambda: defaultdict(list))
    if progress:
        it = tqdm.tqdm(enumerate(zip(gid_array, gid_bins)), unit_scale=True)
    else:
        it = enumerate(zip(gid_array, gid_bins))

    for i, (gid, pop_idx) in it:

        pop_name = populations[pop_idx][0]
        spk_t = spike_array["time"][i]

        pop_spk_dict[pop_name][gid].append(spk_t)

    for pop_name, _, _ in populations:

        this_spk_dict = pop_spk_dict[pop_name]
        logger.info(
            f"Saving spike data for population {pop_name} gid set {sorted(this_spk_dict.keys())}"
        )
        output_dict = {
            gid: {
                't': np.asarray(spk_ts, dtype=np.float32)
            }
            for gid, spk_ts in viewitems(this_spk_dict)
        }

        write_cell_attributes(output_path,
                              pop_name,
                              output_dict,
                              namespace=namespace,
                              comm=comm)
        logger.info(
            f"Saved spike data for population {pop_name} to file {output_path}"
        )

    comm.barrier()
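A worked illustration of the gid-to-population binning performed above with np.digitize, using invented ranges (two populations of 100 and 50 cells give a single boundary at gid 100):

import numpy as np

pop_range_bins = [100]                    # boundaries between populations
gids = np.asarray([0, 99, 100, 149])
np.digitize(gids, np.asarray(pop_range_bins))
# -> array([0, 0, 1, 1]): gids below 100 fall in population 0, the rest in 1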
Example 11
def main(config, config_prefix, types_path, geometry_path, output_path,
         output_namespace, populations, resolution, alpha_radius, nodeiter,
         dispersion_delta, snap_delta, io_size, chunk_size, value_chunk_size,
         verbose):

    config_logging(verbose)
    logger = get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size

    np.seterr(all='raise')

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if rank == 0:
        if not os.path.isfile(output_path):
            input_file = h5py.File(types_path, 'r')
            output_file = h5py.File(output_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix)

    random_seed = int(env.model_config['Random Seeds']['Soma Locations'])
    random.seed(random_seed)

    layer_extents = env.geometry['Parametric Surface']['Layer Extents']
    rotate = env.geometry['Parametric Surface']['Rotation']

    (extent_u, extent_v, extent_l) = get_total_extents(layer_extents)
    vol = make_CA1_volume(extent_u,
                          extent_v,
                          extent_l,
                          rotate=rotate,
                          resolution=resolution)
    layer_alpha_shape_path = 'Layer Alpha Shape/%d/%d/%d' % resolution
    if rank == 0:
        logger.info("Constructing alpha shape for volume: extents: %s..." %
                    str((extent_u, extent_v, extent_l)))
        vol_alpha_shape_path = '%s/all' % (layer_alpha_shape_path)
        vol_alpha_shape = None
        if geometry_path:
            vol_alpha_shape = load_alpha_shape(geometry_path,
                                               vol_alpha_shape_path)
        if vol_alpha_shape is None:
            vol_alpha_shape = make_alpha_shape(vol, alpha_radius=alpha_radius)
            if geometry_path:
                save_alpha_shape(geometry_path, vol_alpha_shape_path,
                                 vol_alpha_shape)
        vert = vol_alpha_shape.points
        smp = np.asarray(vol_alpha_shape.bounds, dtype=np.int64)
        vol_domain = (vert, smp)

    layer_alpha_shapes = {}
    layer_extent_vals = {}
    layer_extent_transformed_vals = {}
    if rank == 0:
        for layer, extents in viewitems(layer_extents):
            (extent_u, extent_v,
             extent_l) = get_layer_extents(layer_extents, layer)
            layer_extent_vals[layer] = (extent_u, extent_v, extent_l)
            layer_extent_transformed_vals[layer] = CA1_volume_transform(
                extent_u, extent_v, extent_l)
            has_layer_alpha_shape = False
            if geometry_path:
                this_layer_alpha_shape_path = '%s/%s' % (
                    layer_alpha_shape_path, layer)
                this_layer_alpha_shape = load_alpha_shape(
                    geometry_path, this_layer_alpha_shape_path)
                layer_alpha_shapes[layer] = this_layer_alpha_shape
                if this_layer_alpha_shape is not None:
                    has_layer_alpha_shape = True
            if not has_layer_alpha_shape:
                logger.info(
                    "Constructing alpha shape for layer %s: extents: %s..." %
                    (layer, str(extents)))
                layer_vol = make_CA1_volume(extent_u,
                                            extent_v,
                                            extent_l,
                                            rotate=rotate,
                                            resolution=resolution)
                this_layer_alpha_shape = make_alpha_shape(
                    layer_vol, alpha_radius=alpha_radius)
                layer_alpha_shapes[layer] = this_layer_alpha_shape
                if geometry_path:
                    save_alpha_shape(geometry_path,
                                     this_layer_alpha_shape_path,
                                     this_layer_alpha_shape)

    comm.barrier()
    population_ranges = read_population_ranges(output_path, comm)[0]
    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    total_count = 0
    for population in populations:
        (population_start, population_count) = population_ranges[population]
        total_count += population_count

    all_xyz_coords1 = None
    generated_coords_count_dict = defaultdict(int)
    if rank == 0:
        all_xyz_coords_lst = []
        for population in populations:
            gc.collect()

            (population_start,
             population_count) = population_ranges[population]

            pop_layers = env.geometry['Cell Distribution'][population]
            pop_constraint = None
            if 'Cell Constraints' in env.geometry:
                if population in env.geometry['Cell Constraints']:
                    pop_constraint = env.geometry['Cell Constraints'][
                        population]
            logger.info("Population %s: layer distribution is %s" %
                        (population, str(pop_layers)))

            pop_layer_count = 0
            for layer, count in viewitems(pop_layers):
                pop_layer_count += count
            assert (population_count == pop_layer_count)

            xyz_coords_lst = []
            for layer, count in viewitems(pop_layers):
                if count <= 0:
                    continue

                alpha = layer_alpha_shapes[layer]

                vert = alpha.points
                smp = np.asarray(alpha.bounds, dtype=np.int64)

                extents_xyz = layer_extent_transformed_vals[layer]
                for (vvi, vv) in enumerate(vert):
                    for (vi, v) in enumerate(vv):
                        if v < extents_xyz[vi][0]:
                            vert[vvi][vi] = extents_xyz[vi][0]
                        elif v > extents_xyz[vi][1]:
                            vert[vvi][vi] = extents_xyz[vi][1]

                N = int(count * 2)  # layer-specific number of nodes
                node_count = 0

                logger.info(
                    "Generating %i nodes in layer %s for population %s..." %
                    (N, layer, population))
                if verbose:
                    rbf_logger = logging.Logger.manager.loggerDict[
                        'rbf.pde.nodes']
                    rbf_logger.setLevel(logging.DEBUG)

                min_energy_constraint = None
                if pop_constraint is not None and layer in pop_constraint:
                    min_energy_constraint = pop_constraint[layer]

                nodes = gen_min_energy_nodes(count, (vert, smp),
                                             min_energy_constraint, nodeiter,
                                             dispersion_delta, snap_delta)

                xyz_coords_lst.append(nodes.reshape(-1, 3))

            for this_xyz_coords in xyz_coords_lst:
                all_xyz_coords_lst.append(this_xyz_coords)
                generated_coords_count_dict[population] += len(this_xyz_coords)

        # Additional dispersion step to ensure no overlapping cell positions
        all_xyz_coords = np.vstack(all_xyz_coords_lst)
        mask = np.ones((all_xyz_coords.shape[0], ), dtype=bool)
        # distance to nearest neighbor
        while True:
            kdt = cKDTree(all_xyz_coords[mask, :])
            nndist, nnindices = kdt.query(all_xyz_coords[mask, :], k=2)
            nndist, nnindices = nndist[:, 1:], nnindices[:, 1:]

            zindices = nnindices[np.argwhere(
                np.isclose(nndist, 0.0, atol=1e-3, rtol=1e-3))]
            if len(zindices) > 0:
                mask[np.argwhere(mask)[zindices]] = False
            else:
                break

        coords_offset = 0
        for population in populations:
            pop_coords_count = generated_coords_count_dict[population]
            pop_mask = mask[coords_offset:coords_offset + pop_coords_count]
            generated_coords_count_dict[population] = np.count_nonzero(
                pop_mask)
            coords_offset += pop_coords_count

        logger.info("Dispersion of %i nodes..." % np.count_nonzero(mask))
        all_xyz_coords1 = disperse(all_xyz_coords[mask, :],
                                   vol_domain,
                                   delta=dispersion_delta)

    if rank == 0:
        logger.info("Computing UVL coordinates of %i nodes..." %
                    len(all_xyz_coords1))

    all_xyz_coords_interp = None
    all_uvl_coords_interp = None

    if rank == 0:
        all_uvl_coords_interp = vol.inverse(all_xyz_coords1)
        all_xyz_coords_interp = vol(all_uvl_coords_interp[:, 0],
                                    all_uvl_coords_interp[:, 1],
                                    all_uvl_coords_interp[:, 2],
                                    mesh=False).reshape(3, -1).T

    if rank == 0:
        logger.info("Broadcasting generated nodes...")

    xyz_coords = comm.bcast(all_xyz_coords1, root=0)
    all_xyz_coords_interp = comm.bcast(all_xyz_coords_interp, root=0)
    all_uvl_coords_interp = comm.bcast(all_uvl_coords_interp, root=0)
    generated_coords_count_dict = comm.bcast(dict(generated_coords_count_dict),
                                             root=0)

    coords_offset = 0
    pop_coords_dict = {}
    for population in populations:
        xyz_error = np.asarray([0.0, 0.0, 0.0])

        pop_layers = env.geometry['Cell Distribution'][population]

        pop_start, pop_count = population_ranges[population]
        coords = []

        gen_coords_count = generated_coords_count_dict[population]

        for i, coord_ind in enumerate(
                range(coords_offset, coords_offset + gen_coords_count)):

            if i % size == rank:

                uvl_coords = all_uvl_coords_interp[coord_ind, :].ravel()
                xyz_coords1 = all_xyz_coords_interp[coord_ind, :].ravel()
                if uvl_in_bounds(all_uvl_coords_interp[coord_ind, :],
                                 layer_extents, pop_layers):
                    xyz_error = np.add(
                        xyz_error,
                        np.abs(
                            np.subtract(xyz_coords[coord_ind, :],
                                        xyz_coords1)))

                    logger.info('Rank %i: %s cell %i: %f %f %f' %
                                (rank, population, i, uvl_coords[0],
                                 uvl_coords[1], uvl_coords[2]))

                    coords.append(
                        (xyz_coords1[0], xyz_coords1[1], xyz_coords1[2],
                         uvl_coords[0], uvl_coords[1], uvl_coords[2]))
                else:
                    logger.debug(
                        'Rank %i: %s cell %i not in bounds: %f %f %f' %
                        (rank, population, i, uvl_coords[0], uvl_coords[1],
                         uvl_coords[2]))
                    uvl_coords = None
                    xyz_coords1 = None

        total_xyz_error = np.zeros((3, ))
        comm.Allreduce(xyz_error, total_xyz_error, op=MPI.SUM)

        coords_count = np.sum(np.asarray(comm.allgather(len(coords))))

        mean_xyz_error = np.asarray([(total_xyz_error[0] / coords_count), \
                                     (total_xyz_error[1] / coords_count), \
                                     (total_xyz_error[2] / coords_count)])

        pop_coords_dict[population] = coords
        coords_offset += gen_coords_count

        if rank == 0:
            logger.info(
                'Total %i coordinates generated for population %s: mean XYZ error: %f %f %f'
                % (coords_count, population, mean_xyz_error[0],
                   mean_xyz_error[1], mean_xyz_error[2]))

    if rank == 0:
        color = 1
    else:
        color = 0

    ## comm0 includes only rank 0
    comm0 = comm.Split(color, 0)

    for population in populations:

        pop_start, pop_count = population_ranges[population]
        pop_layers = env.geometry['Cell Distribution'][population]
        pop_constraint = None
        if 'Cell Constraints' in env.geometry:
            if population in env.geometry['Cell Constraints']:
                pop_constraint = env.geometry['Cell Constraints'][population]

        coords_lst = comm.gather(pop_coords_dict[population], root=0)
        if rank == 0:
            all_coords = []
            for sublist in coords_lst:
                for item in sublist:
                    all_coords.append(item)
            coords_count = len(all_coords)

            if coords_count < pop_count:
                logger.warning(
                    "Generating additional %i coordinates for population %s..."
                    % (pop_count - len(all_coords), population))

                safety = 0.01
                delta = pop_count - len(all_coords)
                for i in range(delta):
                    for layer, count in viewitems(pop_layers):
                        if count > 0:
                            min_extent = layer_extents[layer][0]
                            max_extent = layer_extents[layer][1]
                            coord_u = np.random.uniform(
                                min_extent[0] + safety, max_extent[0] - safety)
                            coord_v = np.random.uniform(
                                min_extent[1] + safety, max_extent[1] - safety)
                            if pop_constraint is None:
                                coord_l = np.random.uniform(
                                    min_extent[2] + safety,
                                    max_extent[2] - safety)
                            else:
                                coord_l = np.random.uniform(
                                    pop_constraint[layer][0] + safety,
                                    pop_constraint[layer][1] - safety)
                            xyz_coords = CA1_volume(coord_u,
                                                    coord_v,
                                                    coord_l,
                                                    rotate=rotate).ravel()
                            all_coords.append(
                                (xyz_coords[0], xyz_coords[1], xyz_coords[2],
                                 coord_u, coord_v, coord_l))

            sampled_coords = random_subset(all_coords, int(pop_count))
            sampled_coords.sort(
                key=lambda coord: coord[3])  ## sort on U coordinate

            coords_dict = {
                pop_start + i: {
                    'X Coordinate': np.asarray([x_coord], dtype=np.float32),
                    'Y Coordinate': np.asarray([y_coord], dtype=np.float32),
                    'Z Coordinate': np.asarray([z_coord], dtype=np.float32),
                    'U Coordinate': np.asarray([u_coord], dtype=np.float32),
                    'V Coordinate': np.asarray([v_coord], dtype=np.float32),
                    'L Coordinate': np.asarray([l_coord], dtype=np.float32)
                }
                for (i, (x_coord, y_coord, z_coord, u_coord, v_coord,
                         l_coord)) in enumerate(sampled_coords)
            }

            append_cell_attributes(output_path,
                                   population,
                                   coords_dict,
                                   namespace=output_namespace,
                                   io_size=io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size,
                                   comm=comm0)

        comm.barrier()

    comm0.Free()
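The duplicate-removal loop above queries a cKDTree for each point's nearest neighbor and masks out near-coincident points until all nearest-neighbor distances are positive. A small self-contained variant of that idea (points invented; unlike the loop above, this version keeps one member of each coincident pair):

import numpy as np
from scipy.spatial import cKDTree

pts = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0001], [1.0, 1.0, 1.0]])
mask = np.ones(len(pts), dtype=bool)
while True:
    kdt = cKDTree(pts[mask])
    nndist, nnidx = kdt.query(pts[mask], k=2)
    near = np.argwhere(np.isclose(nndist[:, 1], 0.0, atol=1e-3, rtol=1e-3)).ravel()
    if len(near) == 0:
        break
    # drop the higher-indexed member of each near-coincident pair
    drop = np.unique(np.maximum(near, nnidx[near, 1]))
    mask[np.argwhere(mask).ravel()[drop]] = False
# mask -> [ True False  True]: one point of the coincident pair is removed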
Example 12
def generate_uv_distance_connections(comm, population_dict, connection_config, connection_prob, forest_path,
                                     synapse_seed, connectivity_seed, cluster_seed,
                                     synapse_namespace, connectivity_namespace, connectivity_path,
                                     io_size, chunk_size, value_chunk_size, cache_size, write_size=1,
                                     dry_run=False, debug=False):
    """
    Generates connectivity based on U, V distance-weighted probabilities.

    :param comm: mpi4py MPI communicator
    :param population_dict: mapping of population names to population indices
    :param connection_config: connection configuration object (instance of env.ConnectionConfig)
    :param connection_prob: ConnectionProb instance
    :param forest_path: location of file with neuronal trees and synapse information
    :param synapse_seed: random seed for synapse partitioning
    :param connectivity_seed: random seed for connectivity generation
    :param cluster_seed: random seed for determining connectivity clustering for repeated connections from the same source
    :param synapse_namespace: namespace of synapse properties
    :param connectivity_namespace: namespace of connectivity attributes
    :param io_size: number of I/O ranks to use for parallel connectivity append
    :param chunk_size: HDF5 chunk size for connectivity file (pointer and index datasets)
    :param value_chunk_size: HDF5 chunk size for connectivity file (value datasets)
    :param cache_size: how many cells to read ahead
    :param write_size: how many cells to write out at the same time
    """

    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    start_time = time.time()

    ranstream_syn = np.random.RandomState()
    ranstream_con = np.random.RandomState()

    destination_population = connection_prob.destination_population

    source_populations = sorted(connection_config[destination_population].keys())

    for source_population in source_populations:
        if rank == 0:
            logger.info(f'{source_population} -> {destination_population}: \n'
                        f'{pprint.pformat(connection_config[destination_population][source_population])}')

    projection_config = connection_config[destination_population]
    projection_synapse_dict = {source_population:
                                   (projection_config[source_population].type,
                                    projection_config[source_population].layers,
                                    projection_config[source_population].sections,
                                    projection_config[source_population].proportions,
                                    projection_config[source_population].contacts)
                               for source_population in source_populations}

    
    comm.barrier()

    it_count = 0
    total_count = 0
    gid_count = 0
    connection_dict = defaultdict(lambda: {})
    projection_dict = {}
    for destination_gid, synapse_dict in NeuroH5CellAttrGen(forest_path, \
                                                            destination_population, \
                                                            namespace=synapse_namespace, \
                                                            comm=comm, io_size=io_size, \
                                                            cache_size=cache_size):
        if destination_gid is None:
            logger.info(f'Rank {rank} destination gid is None')
        else:
            logger.info(f'Rank {rank} received attributes for destination: {destination_population}, gid: {destination_gid}')

            ranstream_con.seed(destination_gid + connectivity_seed)
            ranstream_syn.seed(destination_gid + synapse_seed)
            last_gid_time = time.time()

            projection_prob_dict = {}
            for source_population in source_populations:
                source_layers = projection_config[source_population].layers
                projection_prob_dict[source_population] = \
                    connection_prob.get_prob(destination_gid, source_population, source_layers)


                for layer, (probs, source_gids, distances_u, distances_v) in \
                        viewitems(projection_prob_dict[source_population]):
                    if len(distances_u) > 0:
                        max_u_distance = np.max(distances_u)
                        min_u_distance = np.min(distances_u)
                        if rank == 0:
                            logger.info(f'Rank {rank} has {len(source_gids)} possible sources from population {source_population} '
                                        f'for destination: {destination_population}, layer {layer}, gid: {destination_gid}; '
                                        f'max U distance: {max_u_distance:.2f} min U distance: {min_u_distance:.2f}')
                    else:
                        logger.warning(f'Rank {rank} has {len(source_gids)} possible sources from population {source_population} '
                                       f'for destination: {destination_population}, layer {layer}, gid: {destination_gid}')

            count = generate_synaptic_connections(rank,
                                                  destination_gid,
                                                  ranstream_syn,
                                                  ranstream_con,
                                                  cluster_seed + destination_gid,
                                                  destination_gid,
                                                  synapse_dict,
                                                  population_dict,
                                                  projection_synapse_dict,
                                                  projection_prob_dict,
                                                  connection_dict,
                                                  debug_flag=debug)
            total_count += count

            logger.info(f'Rank {rank} took {time.time() - last_gid_time:.2f} s to compute {count} edges for destination: {destination_population}, gid: {destination_gid}')

        if (write_size > 0) and (gid_count % write_size == 0):
            if len(connection_dict) > 0:
                projection_dict = {destination_population: connection_dict}
            else:
                projection_dict = {}
            if not dry_run:
                last_time = time.time()
                append_graph(connectivity_path, projection_dict, io_size=io_size, comm=comm)
                if rank == 0:
                    if connection_dict:
                        logger.info(f'Appending connectivity for {len(connection_dict)} projections took {time.time() - last_time:.2f} s')
            projection_dict.clear()
            connection_dict.clear()
            gc.collect()

        gid_count += 1
        it_count += 1
        if (it_count > 250) and debug:
            break


    gc.collect()
    last_time = time.time()
    if len(connection_dict) > 0:
        projection_dict = {destination_population: connection_dict}
    else:
        projection_dict = {}
    if not dry_run:
        append_graph(connectivity_path, projection_dict, io_size=io_size, comm=comm)
        if rank == 0:
            if connection_dict:
                logger.info(f'Appending connectivity for {len(connection_dict)} projections took {time.time() - last_time:.2f} s')

    global_count = comm.gather(total_count, root=0)
    if rank == 0:
        logger.info(f'{comm.size} ranks took {time.time() - start_time:.2f} s to generate {np.sum(global_count)} edges')
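Note the per-gid reseeding above (ranstream_con.seed(destination_gid + connectivity_seed), and likewise for the synapse stream): it makes the draws for a given cell reproducible regardless of which rank processes it or in what order. A minimal sketch of that pattern:

import numpy as np

def draws_for_gid(gid, base_seed, n=3):
    rs = np.random.RandomState()
    rs.seed(gid + base_seed)
    return rs.random_sample(n)

# identical results for the same gid and base seed, on any rank, at any time
assert np.allclose(draws_for_gid(42, 1000), draws_for_gid(42, 1000))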
Example 13
def generate_synaptic_connections(rank,
                                  gid,
                                  ranstream_syn,
                                  ranstream_con,
                                  cluster_seed,
                                  destination_gid,
                                  synapse_dict,
                                  population_dict,
                                  projection_synapse_dict,
                                  projection_prob_dict,
                                  connection_dict,
                                  random_choice=random_choice_w_replacement,
                                  debug_flag=False):
    """
    Given a set of synapses for a particular gid, projection
    configuration, projection and connection probability dictionaries,
    generates a set of possible connections for each synapse. The
    procedure first assigns each synapse to a projection, using the
    given proportions of each synapse type, and then chooses source
    gids for each synapse using the given projection probability
    dictionary.

    :param ranstream_syn: random stream for the synapse partitioning step
    :param ranstream_con: random stream for the choosing source gids step
    :param destination_gid: destination gid
    :param synapse_dict: synapse configurations, a dictionary with fields: 1) syn_ids (synapse ids), 2) syn_types (excitatory, inhibitory, etc.),
                        3) swc_types (SWC type(s) of the synapse location in the neuronal morphology), 4) syn_layers (synapse layer placement),
                        5) syn_cdists (synapse distances, used to order synapses within a projection)
    :param population_dict: mapping of population names to population indices
    :param projection_synapse_dict: mapping of projection names to a tuple of the form: <type, layers, swc sections, proportions, contacts>
    :param projection_prob_dict: mapping of presynaptic population names to sets of source probabilities and source gids
    :param connection_dict: output connection dictionary
    :param random_choice: random choice procedure (default: random_choice_w_replacement)

    """
    num_projections = len(projection_synapse_dict)
    source_populations = sorted(projection_synapse_dict)
    prj_pop_index = {population: i for (i, population) in enumerate(source_populations)}
    synapse_prj_counts = np.zeros((num_projections,))
    synapse_prj_partition = defaultdict(lambda: defaultdict(list))
    maxit = 10
    it = 0
    syn_cdist_dict = {}
    ## assign each synapse to a projection
    while (np.count_nonzero(synapse_prj_counts) < num_projections) and (it < maxit):
        log_flag = it > 1
        if log_flag or debug_flag:
            logger.info(f"generate_synaptic_connections: gid {gid}: iteration {it}: "
                        f"source_populations = {source_populations} "
                        f"synapse_prj_counts = {synapse_prj_counts}")
        if debug_flag:
            logger.info(f'synapse_dict = {synapse_dict}')
        synapse_prj_counts.fill(0)
        synapse_prj_partition.clear()
        for (syn_id, syn_cdist, syn_type, swc_type, syn_layer) in zip(synapse_dict['syn_ids'],
                                                                      synapse_dict['syn_cdists'],
                                                                      synapse_dict['syn_types'],
                                                                      synapse_dict['swc_types'],
                                                                      synapse_dict['syn_layers']):
            syn_cdist_dict[syn_id] = syn_cdist
            projection = choose_synapse_projection(ranstream_syn, syn_layer, swc_type, syn_type, \
                                                   population_dict, projection_synapse_dict, log=log_flag)
            if log_flag or debug_flag:
                logger.info(f'generate_synaptic_connections: gid {gid}: '
                            f'syn_id = {syn_id} syn_type = {syn_type} swc_type = {swc_type} '
                            f'syn_layer = {syn_layer} source = {projection}')
            log_flag = False
            assert (projection is not None)
            synapse_prj_counts[prj_pop_index[projection]] += 1
            synapse_prj_partition[projection][syn_layer].append(syn_id)
        it += 1

    empty_projections = []

    for projection in projection_synapse_dict:
        logger.debug(f'Rank {rank}: gid {destination_gid}: source {projection} has {len(synapse_prj_partition[projection])} synapses')
        if not (len(synapse_prj_partition[projection]) > 0):
            empty_projections.append(projection)

    if len(empty_projections) > 0:
        logger.warning(f"Rank {rank}: gid {destination_gid}: projections {empty_projections} have an empty synapse list; "
                       f"swc types are {set(synapse_dict['swc_types'].flat)} layers are {set(synapse_dict['syn_layers'].flat)}")
    assert (len(empty_projections) == 0)

    ## Choose source connections based on distance-weighted probability
    count = 0
    for projection, prj_layer_dict in viewitems(synapse_prj_partition):
        (syn_config_type, syn_config_layers, syn_config_sections, syn_config_proportions, syn_config_contacts) = \
            projection_synapse_dict[projection]
        gid_dict = connection_dict[projection]
        prj_source_vertices = []
        prj_syn_ids = []
        prj_distances = []
        for prj_layer, syn_ids in viewitems(prj_layer_dict):
            source_probs, source_gids, distances_u, distances_v = \
                projection_prob_dict[projection][prj_layer]
            distance_dict = {source_gid: distance_u + distance_v \
                             for (source_gid, distance_u, distance_v) in \
                             zip(source_gids, distances_u, distances_v)}
            if len(source_gids) > 0:
                ordered_syn_ids = sorted(syn_ids, key=lambda x: syn_cdist_dict[x])
                n_syn_groups = int(math.ceil(float(len(syn_ids)) / float(syn_config_contacts)))
                source_gid_counts = random_choice(ranstream_con, n_syn_groups, source_probs)
                total_count = 0
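                # when the projection is configured with multiple contacts per
                # connection, scale each chosen source gid's count so that every
                # connection contributes ncontacts synaptic contacts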
                if syn_config_contacts > 1:
                    ncontacts = int(math.ceil(syn_config_contacts))
                    for i in range(0, len(source_gid_counts)):
                        if source_gid_counts[i] > 0:
                            source_gid_counts[i] *= ncontacts
                if len(source_gid_counts) == 0:
                    logger.warning(f'Rank {rank}: source vertices list is empty for gid: {destination_gid} ' 
                                   f'source: {projection} layer: {prj_layer} '
                                   f'source probs: {source_probs} distances_u: {distances_u} distances_v: {distances_v}')

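                # draw a clustered ordering of source gids, so that synapses
                # adjacent in the distance-ordered list tend to be assigned to
                # the same or nearby source gids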
                source_vertices = np.asarray(random_clustered_shuffle(len(source_gids), \
                                                                      source_gid_counts, \
                                                                      center_ids=source_gids, \
                                                                      cluster_std=2.0, \
                                                                      random_seed=cluster_seed), \
                                             dtype=np.uint32)[0:len(syn_ids)]
                assert (len(source_vertices) == len(syn_ids))
                distances = np.asarray([distance_dict[gid] for gid in source_vertices], \
                                       dtype=np.float32).reshape(-1, )
                prj_source_vertices.append(source_vertices)
                prj_syn_ids.append(ordered_syn_ids)
                prj_distances.append(distances)
                # placeholder entry for this gid; overwritten with the full
                # arrays once all layers of this projection have been processed
                gid_dict[destination_gid] = (np.asarray([], dtype=np.uint32),
                                             {'Synapses': {'syn_id': np.asarray([], dtype=np.uint32)},
                                              'Connections': {'distance': np.asarray([], dtype=np.float32)}
                                              })
                cluster_seed += 1
        if len(prj_source_vertices) > 0:
            prj_source_vertices_array = np.concatenate(prj_source_vertices)
        else:
            prj_source_vertices_array = np.asarray([], dtype=np.uint32)
        del (prj_source_vertices)
        if len(prj_syn_ids) > 0:
            prj_syn_ids_array = np.concatenate(prj_syn_ids)
        else:
            prj_syn_ids_array = np.asarray([], dtype=np.uint32)
        del (prj_syn_ids)
        if len(prj_distances) > 0:
            prj_distances_array = np.concatenate(prj_distances)
        else:
            prj_distances_array = np.asarray([], dtype=np.float32)
        del (prj_distances)
        if len(prj_source_vertices_array) == 0:
            logger.warning(f'Rank {rank}: source gid list is empty for gid: {destination_gid} source: {projection}')
        count += len(prj_source_vertices_array)
        gid_dict[destination_gid] = (prj_source_vertices_array,
                                     {'Synapses': {'syn_id': np.asarray(prj_syn_ids_array, \
                                                                        dtype=np.uint32)},
                                      'Connections': {'distance': prj_distances_array}
                                      })

    return count
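
# A self-contained sketch of the distance-weighted source selection at the
# heart of the procedure above: per-source counts are drawn from a multinomial
# over the normalized source probabilities and then expanded into a list of
# chosen gids. This mirrors what the default random_choice procedure,
# random_choice_w_replacement, is assumed to do; the gids and probabilities
# below are illustrative.
import numpy as np

def multinomial_choice_sketch(ranstream, n_draws, source_gids, source_probs):
    # one count per candidate gid; counts sum to n_draws
    counts = ranstream.multinomial(n_draws, source_probs)
    # expand the counts into a flat list of chosen gids
    return np.repeat(source_gids, counts)

# multinomial_choice_sketch(np.random.RandomState(0), 5,
#                           np.asarray([10, 11, 12, 13], dtype=np.uint32),
#                           np.asarray([0.1, 0.4, 0.4, 0.1]))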
Example n. 14
0
def spike_density_estimate(population,
                           spkdict,
                           time_bins,
                           arena_id=None,
                           trajectory_id=None,
                           output_file_path=None,
                           progress=False,
                           inferred_rate_attr_name='Inferred Rate Map',
                           **kwargs):
    """
    Calculates spike density function for the given spike trains.
    :param population:
    :param spkdict:
    :param time_bins:
    :param arena_id: str
    :param trajectory_id: str
    :param output_file_path:
    :param progress:
    :param inferred_rate_attr_name: str
    :param kwargs: dict
    :return: dict
    """
    if progress:
        from tqdm import tqdm

    analysis_options = copy.copy(default_baks_analysis_options)
    analysis_options.update(kwargs)

    def make_spktrain(lst, t_start, t_stop):
        spkts = np.asarray(lst, dtype=np.float32)
        return spkts[(spkts >= t_start) & (spkts <= t_stop)]

    t_start = time_bins[0]
    t_stop = time_bins[-1]

    spktrains = {
        ind: make_spktrain(lst, t_start, t_stop)
        for (ind, lst) in viewitems(spkdict)
    }
    baks_args = dict()
    baks_args['a'] = analysis_options['BAKS Alpha']
    baks_args['b'] = analysis_options['BAKS Beta']

    if progress:
        seq = tqdm(viewitems(spktrains))
    else:
        seq = viewitems(spktrains)

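    # BAKS operates on times in seconds; spike times and time bins are assumed
    # to be in milliseconds here, hence the division by 1000.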
    spk_rate_dict = {
        ind: baks(spkts / 1000., time_bins / 1000., **baks_args)[0].reshape(
            (-1, )) if len(spkts) > 1 else np.zeros(time_bins.shape)
        for ind, spkts in seq
    }

    if output_file_path is not None:
        if arena_id is None or trajectory_id is None:
            raise RuntimeError(
                'spike_density_estimate: arena_id and trajectory_id are required to write the '
                'Spike Density Function namespace')
        namespace = f'Spike Density Function {arena_id} {trajectory_id}'
        attr_dict = {
            ind: {
                inferred_rate_attr_name:
                np.asarray(spk_rate_dict[ind], dtype='float32')
            }
            for ind in spk_rate_dict
        }
        write_cell_attributes(output_file_path,
                              population,
                              attr_dict,
                              namespace=namespace)

    result = {
        ind: {
            'rate': rate,
            'time': time_bins
        }
        for ind, rate in viewitems(spk_rate_dict)
    }

    return result
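
# A minimal usage sketch for spike_density_estimate, assuming spike times in
# milliseconds (the function divides by 1000 before calling BAKS); the
# population name, gids, and spike times below are illustrative.
def example_spike_density():
    time_bins = np.linspace(0., 1000., 101)                # 1 s window, 10 ms bins
    spkdict = {1: [12.5, 40.0, 41.2, 300.0], 2: [250.0]}   # gid -> spike times [ms]
    rates = spike_density_estimate('GC', spkdict, time_bins)
    # rates[1]['rate'] holds one estimate per time bin; gids with fewer than
    # two spikes (here gid 2) fall back to an all-zero rate
    return rates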
Example n. 15
0
def main(config, coords_path, coords_namespace, geometry_path, populations, interp_chunk_size, resolution, alpha_radius, nsample, io_size, chunk_size, value_chunk_size, cache_size, verbose):

    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)
    
    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config)
    output_path = coords_path

    soma_coords = {}

    if rank == 0:
        logger.info('Reading population coordinates...')
        
    for population in sorted(populations):
        coords = bcast_cell_attributes(coords_path, population, 0, \
                                       namespace=coords_namespace, comm=comm)

        soma_coords[population] = { k: (v['U Coordinate'][0], v['V Coordinate'][0], v['L Coordinate'][0]) 
                                    for (k,v) in coords }
        del coords
        gc.collect()

    
    has_ip_dist = False
    origin_ranges = None
    ip_dist_u = None
    ip_dist_v = None
    ip_dist_path = 'Distance Interpolant/%d/%d/%d' % resolution
    if rank == 0:
        if geometry_path is not None:
            f = h5py.File(geometry_path,"a")
            pkl_path = f'{ip_dist_path}/ip_dist.pkl'
            if pkl_path in f:
                has_ip_dist = True
                ip_dist_dset = f[pkl_path]
                origin_ranges, ip_dist_u, ip_dist_v = pickle.loads(base64.b64decode(ip_dist_dset[()]))
            f.close()
    has_ip_dist = env.comm.bcast(has_ip_dist, root=0)
    
    if not has_ip_dist:
        if rank == 0:
            logger.info('Creating distance interpolant...')
        (origin_ranges, ip_dist_u, ip_dist_v) = make_distance_interpolant(env.comm, geometry_config=env.geometry,
                                                                          make_volume=make_CA1_volume,
                                                                          resolution=resolution, nsample=nsample)
        if rank == 0:
            if geometry_path is not None:
                f = h5py.File(geometry_path, 'a')
                pkl_path = f'{ip_dist_path}/ip_dist.pkl'
                pkl = pickle.dumps((origin_ranges, ip_dist_u, ip_dist_v))
                pklstr = base64.b64encode(pkl)
                f[pkl_path] = pklstr
                f.close()
                
    # it is assumed here that the interpolant must be broadcast explicitly,
    # since only rank 0 reads it from the geometry file
    origin_ranges = env.comm.bcast(origin_ranges, root=0)
    ip_dist_u = env.comm.bcast(ip_dist_u, root=0)
    ip_dist_v = env.comm.bcast(ip_dist_v, root=0)
    ip_dist = (origin_ranges, ip_dist_u, ip_dist_v)
    if rank == 0:
        logger.info('Measuring soma distances...')

    soma_distances = measure_distances(env.comm, env.geometry, soma_coords, ip_dist, resolution=resolution)
                                       
    for population in sorted(soma_distances):

        if rank == 0:
            logger.info(f'Writing distances for population {population}...')

        dist_dict = soma_distances[population]
        attr_dict = {}
        for k, v in viewitems(dist_dict):
            attr_dict[k] = { 'U Distance': np.asarray([v[0]],dtype=np.float32), \
                             'V Distance': np.asarray([v[1]],dtype=np.float32) }
        append_cell_attributes(output_path, population, attr_dict,
                               namespace='Arc Distances', comm=comm,
                               io_size=io_size, chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size, cache_size=cache_size)
        if rank == 0:
            f = h5py.File(output_path, 'a')
            f['Populations'][population]['Arc Distances'].attrs['Reference U Min'] = origin_ranges[0][0]
            f['Populations'][population]['Arc Distances'].attrs['Reference U Max'] = origin_ranges[0][1]
            f['Populations'][population]['Arc Distances'].attrs['Reference V Min'] = origin_ranges[1][0]
            f['Populations'][population]['Arc Distances'].attrs['Reference V Max'] = origin_ranges[1][1]
            f.close()

    comm.Barrier()
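
# A hypothetical helper that reads back the per-population reference ranges
# written above; the HDF5 layout follows the attribute writes in the loop.
def read_arc_distance_ranges(output_path, population):
    with h5py.File(output_path, 'r') as f:
        attrs = f['Populations'][population]['Arc Distances'].attrs
        return ((attrs['Reference U Min'], attrs['Reference U Max']),
                (attrs['Reference V Min'], attrs['Reference V Max']))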
Example n. 16
0
def make_morph_graph(biophys_cell, node_filters=None):
    """
    Creates a graph of 3d points that follows the morphological organization of the given neuron.
    :param biophys_cell: cell object with a morphology tree
    :param node_filters: dict; optional filters passed to filter_nodes to select morphology nodes
    :return: NetworkX.Graph
    """
    import networkx as nx

    if node_filters is None:
        node_filters = {}
    nodes = filter_nodes(biophys_cell, **node_filters)
    tree = biophys_cell.tree

    src_sec = []
    dst_sec = []
    connection_locs = []
    pt_xs = []
    pt_ys = []
    pt_zs = []
    pt_locs = []
    pt_idxs = []
    pt_layers = []
    pt_idx = 0
    sec_pts = collections.defaultdict(list)

    for node in nodes:
        sec = node.sec
        nn = sec.n3d()
        L = sec.L
        for i in range(nn):
            pt_xs.append(sec.x3d(i))
            pt_ys.append(sec.y3d(i))
            pt_zs.append(sec.z3d(i))
            loc = sec.arc3d(i) / L
            pt_locs.append(loc)
            pt_layers.append(node.get_layer(loc))
            pt_idxs.append(pt_idx)
            sec_pts[node.index].append(pt_idx)
            pt_idx += 1

        for child in tree.successors(node):
            src_sec.append(node.index)
            dst_sec.append(child.index)
            connection_locs.append(h.parent_connection(sec=child.sec))

    sec_pt_idxs = {}
    edges = []
    for sec, pts in viewitems(sec_pts):
        sec_pt_idxs[pts[0]] = sec
        for i in range(1, len(pts)):
            sec_pt_idxs[pts[i]] = sec
            src_pt = pts[i - 1]
            dst_pt = pts[i]
            edges.append((src_pt, dst_pt))

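    # connect each parent section to its children: the first parent point at or
    # beyond the recorded connection location is linked to the child's first point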
    for (s, d, parent_loc) in zip(src_sec, dst_sec, connection_locs):
        for src_pt in sec_pts[s]:
            if pt_locs[src_pt] >= parent_loc:
                break
        dst_pt = sec_pts[d][0]
        edges.append((src_pt, dst_pt))

    morph_graph = nx.Graph()
    morph_graph.add_nodes_from([(i, {
        'x': x,
        'y': y,
        'z': z,
        'sec': sec_pt_idxs[i],
        'loc': loc,
        'layer': layer
    }) for (i, x, y, z, loc, layer) in zip(range(len(pt_idxs)), pt_xs, pt_ys,
                                           pt_zs, pt_locs, pt_layers)])
    for i, j in edges:
        morph_graph.add_edge(i, j)

    return morph_graph
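
# A minimal usage sketch, assuming a biophys_cell object compatible with
# filter_nodes; the summary statistics computed here are illustrative.
def morph_graph_summary(biophys_cell):
    import networkx as nx
    g = make_morph_graph(biophys_cell)
    # each node carries x/y/z coordinates, a section index, an arc-length
    # location in [0, 1], and a layer assignment
    return (g.number_of_nodes(), g.number_of_edges(),
            nx.number_connected_components(g))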