Example #1
def spike_density_estimate(population, spkdict, time_bins, arena_id=None, trajectory_id=None, output_file_path=None,
                            progress=False, inferred_rate_attr_name='Inferred Rate Map', **kwargs):
    """
    Calculates spike density function for the given spike trains.
    :param population: str
    :param spkdict: dict of int (gid) to list of float (spike times, ms)
    :param time_bins: array of float (time bin edges, ms)
    :param arena_id: str
    :param trajectory_id: str
    :param output_file_path: str (path to file)
    :param progress: bool
    :param inferred_rate_attr_name: str
    :param kwargs: dict
    :return: dict
    """
    if progress:
        from tqdm import tqdm

    analysis_options = copy.copy(default_baks_analysis_options)
    analysis_options.update(kwargs)

    def make_spktrain(lst, t_start, t_stop):
        spkts = np.asarray(lst, dtype=np.float32)
        return spkts[(spkts >= t_start) & (spkts <= t_stop)]

    
    t_start = time_bins[0]
    t_stop = time_bins[-1]

    spktrains = {ind: make_spktrain(lst, t_start, t_stop) for (ind, lst) in viewitems(spkdict)}
    baks_args = dict()
    baks_args['a'] = analysis_options['BAKS Alpha']
    baks_args['b'] = analysis_options['BAKS Beta']
    
    if progress:
        seq = tqdm(viewitems(spktrains))
    else:
        seq = viewitems(spktrains)
        
    spk_rate_dict = {ind: baks(spkts / 1000., time_bins / 1000., **baks_args)[0].reshape((-1,))
                     if len(spkts) > 1 else np.zeros(time_bins.shape)
                     for ind, spkts in seq}

    if output_file_path is not None:
        if arena_id is None or trajectory_id is None:
            raise RuntimeError('spike_density_estimate: arena_id and trajectory_id required to write Spike Density '
                               'Function namespace')
        namespace = 'Spike Density Function %s %s' % (arena_id, trajectory_id)
        attr_dict = {ind: {inferred_rate_attr_name: np.asarray(spk_rate_dict[ind], dtype='float32')}
                     for ind in spk_rate_dict}
        write_cell_attributes(output_file_path, population, attr_dict, namespace=namespace)

    result = {ind: {'rate': rate, 'time': time_bins} for ind, rate in viewitems(spk_rate_dict)}

    return result
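A minimal, self-contained sketch of the same estimation pattern, with a fixed-bandwidth Gaussian kernel standing in for the adaptive BAKS estimator (the baks call above); the gids and spike times below are made up for illustration:

import numpy as np

def gaussian_rate_estimate(spkts, time_bins, sigma=50.0):
    # sum of Gaussian bumps centered on each spike time (ms), converted to Hz
    if len(spkts) < 2:
        return np.zeros(time_bins.shape)
    z = (time_bins[:, None] - spkts[None, :]) / sigma
    return np.exp(-0.5 * z ** 2).sum(axis=1) / (sigma * np.sqrt(2.0 * np.pi)) * 1000.0

time_bins = np.arange(0., 1000., 10.)                # ms
spkdict = {0: [120., 480., 510.], 1: [300.]}         # gid -> spike times
spk_rate_dict = {ind: gaussian_rate_estimate(np.asarray(spkts), time_bins)
                 for ind, spkts in spkdict.items()}

Example #2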
def vout(env, output_path, t_vec, v_dict):
    if not str(env.resultsId):
        namespace_id = "Intracellular Voltage"
    else:
        namespace_id = "Intracellular Voltage %s" % str(env.resultsId)

    for pop_name, gid_v_dict in viewitems(v_dict):
        start = env.celltypes[pop_name]['start']

        attr_dict = {gid - start: {'v': np.array(vs, dtype=np.float32), 't': t_vec}
                     for (gid, vs) in viewitems(gid_v_dict)}

        write_cell_attributes(output_path, pop_name, attr_dict, namespace=namespace_id, comm=env.comm)
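Example #3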
def spikeout(env, output_path, t_vec, id_vec):
    binlst = []
    typelst = list(env.celltypes.keys())
    for k in typelst:
        binlst.append(env.celltypes[k]['start'])

    binvect = np.array(binlst)
    sort_idx = np.argsort(binvect, axis=0)
    bins = binvect[sort_idx][1:]
    types = [typelst[i] for i in sort_idx]
    inds = np.digitize(id_vec, bins)

    if not str(env.resultsId):
        namespace_id = "Spike Events"
    else:
        namespace_id = "Spike Events %s" % str(env.resultsId)

    for i in range(0, len(types)):
        if i > 0:
            start = bins[i - 1]
        else:
            start = 0
        spkdict = {}
        sinds = np.where(inds == i)[0]
        if len(sinds) > 0:
            ids = id_vec[sinds]
            ts = t_vec[sinds]
            for j in range(0, len(ids)):
                gid = ids[j] - start
                t = ts[j]
                if gid in spkdict:
                    spkdict[gid]['t'].append(t)
                else:
                    spkdict[gid] = {'t': [t]}
            for j in spkdict.keys():
                spkdict[j]['t'] = np.array(spkdict[j]['t'])
        pop_name = types[i]
        write_cell_attributes(output_path, pop_name, spkdict, namespace=namespace_id, comm=env.comm)
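The population lookup above relies on np.digitize over the sorted start gids; a standalone illustration with made-up ranges:

import numpy as np

starts = np.array([0, 1000, 1500])                 # sorted population start gids
types = ['GC', 'MC', 'BC']
bins = starts[1:]                                  # boundaries between populations
gids = np.array([5, 999, 1000, 1499, 1500, 2000])
inds = np.digitize(gids, bins)
print([types[i] for i in inds])                    # ['GC', 'GC', 'MC', 'MC', 'BC', 'BC']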
Example #4
def write_input_cell_selection(env,
                               input_sources,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Writes out predefined spike trains when only a subset of the network is instantiated.

    :param env: an instance of the `Env` class
    :param input_sources: a dictionary of the form { pop_name: gid_sources }
    """

    if write_kwds is None:
        write_kwds = {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    dataset_path = env.dataset_path
    input_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.celltypes.keys())
    else:
        pop_names = populations

    for pop_name, gid_range in sorted(viewitems(input_sources)):

        gc.collect()

        if pop_name not in pop_names:
            continue

        spikes_output_dict = {}

        if (env.cell_selection is not None) and (pop_name
                                                 in env.cell_selection):
            local_gid_range = gid_range.difference(
                set(env.cell_selection[pop_name]))
        else:
            local_gid_range = gid_range

        gid_range = env.comm.allreduce(local_gid_range, op=mpi_op_set_union)
        this_gid_range = set([])
        for i, gid in enumerate(gid_range):
            if i % nhosts == rank:
                this_gid_range.add(gid)

        has_spike_train = False
        spike_input_source_loc = []
        if (env.spike_input_attribute_info is not None) and (env.spike_input_ns
                                                             is not None):
            if (pop_name in env.spike_input_attribute_info) and \
                    (env.spike_input_ns in env.spike_input_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (env.spike_input_path, env.spike_input_ns))
        if (env.cell_attribute_info is not None) and (env.spike_input_ns
                                                      is not None):
            if (pop_name in env.cell_attribute_info) and \
                    (env.spike_input_ns in env.cell_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (input_file_path, env.spike_input_ns))

        if rank == 0:
            logger.info(
                '*** Reading spike trains for population %s: %d cells: has_spike_train = %s'
                % (pop_name, len(this_gid_range), str(has_spike_train)))

        if has_spike_train:

            vecstim_attr_set = set(['t'])
            if env.spike_input_attr is not None:
                vecstim_attr_set.add(env.spike_input_attr)
            if 'spike train' in env.celltypes[pop_name]:
                vecstim_attr_set.add(
                    env.celltypes[pop_name]['spike train']['attribute'])

            cell_spikes_iters = [ scatter_read_cell_attribute_selection(input_path, pop_name, \
                                                                        list(this_gid_range), \
                                                                        namespace=input_ns, \
                                                                        mask=vecstim_attr_set, \
                                                                        comm=env.comm, io_size=env.io_size)
                                  for (input_path, input_ns) in spike_input_source_loc ]

            for cell_spikes_iter in cell_spikes_iters:
                spikes_output_dict.update(dict(list(cell_spikes_iter)))

        if rank == 0:
            logger.info('*** Writing spike trains for population %s: %s' %
                        (pop_name, str(spikes_output_dict)))


        write_cell_attributes(write_selection_file_path, pop_name, spikes_output_dict,  \
                              namespace=env.spike_input_ns, **write_kwds)
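The gid-to-rank assignment above is a simple round-robin (i % nhosts == rank); a toy check with made-up values:

nhosts = 4
gid_range = sorted({100, 101, 102, 103, 104, 105, 106, 107})
per_rank = {rank: {gid for i, gid in enumerate(gid_range) if i % nhosts == rank}
            for rank in range(nhosts)}
print(per_rank)   # {0: {100, 104}, 1: {101, 105}, 2: {102, 106}, 3: {103, 107}}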
Example #5
def write_connection_selection(env,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Loads NeuroH5 connectivity file, and writes the corresponding
    synapse and network connection mechanisms for the selected postsynaptic cells.

    :param env: an instance of the `Env` class
    """

    if write_kwds is None:
        write_kwds = {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    connectivity_file_path = env.connectivity_file_path
    forest_file_path = env.forest_file_path
    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())
    syn_attrs = env.synapse_attributes

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    input_sources = {pop_name: set([]) for pop_name in env.celltypes}

    for (postsyn_name, presyn_names) in sorted(viewitems(env.projection_dict)):

        gc.collect()

        if rank == 0:
            logger.info('*** Writing connection selection of population %s' %
                        (postsyn_name))

        if postsyn_name not in pop_names:
            continue

        gid_range = [
            gid for i, gid in enumerate(env.cell_selection[postsyn_name])
            if i % nhosts == rank
        ]

        synapse_config = env.celltypes[postsyn_name]['synapses']

        weight_dicts = []
        has_weights = False
        if 'weights' in synapse_config:
            has_weights = True
            weight_dicts = synapse_config['weights']

        if rank == 0:
            logger.info('*** Reading synaptic attributes for population %s' %
                        (postsyn_name))

        syn_attributes_iter = scatter_read_cell_attribute_selection(
            forest_file_path,
            postsyn_name,
            selection=gid_range,
            namespace='Synapse Attributes',
            comm=env.comm,
            io_size=env.io_size)

        syn_attributes_output_dict = dict(list(syn_attributes_iter))
        write_cell_attributes(write_selection_file_path,
                              postsyn_name,
                              syn_attributes_output_dict,
                              namespace='Synapse Attributes',
                              **write_kwds)
        del syn_attributes_output_dict
        del syn_attributes_iter

        if has_weights:

            for weight_dict in weight_dicts:

                weights_namespaces = weight_dict['namespace']

                if rank == 0:
                    logger.info(
                        '*** Reading synaptic weights of population %s from namespaces %s'
                        % (postsyn_name, str(weights_namespaces)))

                for weights_namespace in weights_namespaces:
                    syn_weights_iter = scatter_read_cell_attribute_selection(
                        forest_file_path,
                        postsyn_name,
                        namespace=weights_namespace,
                        selection=gid_range,
                        comm=env.comm,
                        io_size=env.io_size)

                    weight_attributes_output_dict = dict(
                        list(syn_weights_iter))
                    write_cell_attributes(write_selection_file_path,
                                          postsyn_name,
                                          weight_attributes_output_dict,
                                          namespace=weights_namespace,
                                          **write_kwds)
                    del weight_attributes_output_dict
                    del syn_weights_iter

        logger.info(
            '*** Rank %i: reading connectivity selection from file %s for postsynaptic population: %s: selection: %s'
            % (rank, connectivity_file_path, postsyn_name, str(gid_range)))

        (graph, attr_info) = scatter_read_graph_selection(connectivity_file_path, selection=gid_range, \
                                                          projections=[ (presyn_name, postsyn_name) for presyn_name in sorted(presyn_names) ], \
                                                          comm=env.comm, io_size=env.io_size, namespaces=['Synapses', 'Connections'])

        for presyn_name in sorted(presyn_names):
            gid_dict = {}
            edge_count = 0
            node_count = 0
            if postsyn_name in graph:

                if postsyn_name in attr_info and presyn_name in attr_info[
                        postsyn_name]:
                    edge_attr_info = attr_info[postsyn_name][presyn_name]
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                       (presyn_name, postsyn_name))

                if 'Synapses' in edge_attr_info and \
                        'syn_id' in edge_attr_info['Synapses'] and \
                        'Connections' in edge_attr_info and \
                        'distance' in edge_attr_info['Connections']:
                    syn_id_attr_index = edge_attr_info['Synapses']['syn_id']
                    distance_attr_index = edge_attr_info['Connections'][
                        'distance']
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                           (presyn_name, postsyn_name))

                edge_iter = compose_iter(lambda edgeset: input_sources[presyn_name].update(edgeset[1][0]), \
                                         graph[postsyn_name][presyn_name])
                for (postsyn_gid, edges) in edge_iter:

                    presyn_gids, edge_attrs = edges
                    edge_syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
                    edge_dists = edge_attrs['Connections'][distance_attr_index]

                    gid_dict[postsyn_gid] = (presyn_gids, {
                        'Synapses': {
                            'syn_id': edge_syn_ids
                        },
                        'Connections': {
                            'distance': edge_dists
                        }
                    })
                    edge_count += len(presyn_gids)
                    node_count += 1

            env.comm.barrier()
            logger.info(
                '*** Rank %d: Writing projection %s -> %s selection: %d nodes, %d edges'
                % (rank, presyn_name, postsyn_name, node_count, edge_count))
            write_graph(write_selection_file_path, \
                        src_pop_name=presyn_name, dst_pop_name=postsyn_name, \
                        edges=gid_dict, comm=env.comm, io_size=env.io_size)
            env.comm.barrier()

    return input_sources
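compose_iter above (a utility of this codebase) evidently applies a side-effecting function to each element while passing the elements through; a minimal stand-in with made-up edge data, assuming that behavior:

def compose_iter(fn, it):
    # apply fn to each item for its side effect, then yield the item unchanged
    for item in it:
        fn(item)
        yield item

input_sources = {'MPP': set()}
edges = [(10, ([1, 2], {})), (11, ([2, 3], {}))]    # (postsyn_gid, (presyn_gids, attrs))
for postsyn_gid, (presyn_gids, _) in compose_iter(
        lambda edgeset: input_sources['MPP'].update(edgeset[1][0]), edges):
    pass
print(sorted(input_sources['MPP']))                 # [1, 2, 3]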
Example #6
def write_cell_selection(env,
                         write_selection_file_path,
                         populations=None,
                         write_kwds=None):
    """
    Writes out the data necessary to instantiate the selected cells.

    :param env: an instance of the `Env` class
    """

    if write_kwds is None:
        write_kwds = {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    dataset_path = env.dataset_path
    data_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    for pop_name in pop_names:

        gid_range = [
            gid for i, gid in enumerate(env.cell_selection[pop_name])
            if i % nhosts == rank
        ]

        trees_output_dict = {}
        coords_output_dict = {}
        num_cells = 0
        if (pop_name in env.cell_attribute_info) and (
                'Trees' in env.cell_attribute_info[pop_name]):
            if rank == 0:
                logger.info("*** Reading trees for population %s" % pop_name)

            cell_tree_iter, _ = scatter_read_tree_selection(data_file_path, pop_name, selection=gid_range, \
                                                            topology=False, comm=env.comm, io_size=env.io_size)
            if rank == 0:
                logger.info("*** Done reading trees for population %s" %
                            pop_name)

            for i, (gid, tree) in enumerate(cell_tree_iter):
                trees_output_dict[gid] = tree
                num_cells += 1

            assert (len(trees_output_dict) == len(gid_range))

        elif (pop_name in env.cell_attribute_info) and (
                'Coordinates' in env.cell_attribute_info[pop_name]):
            if rank == 0:
                logger.info("*** Reading coordinates for population %s" %
                            pop_name)

            cell_attributes_iter = scatter_read_cell_attribute_selection(data_file_path, pop_name, selection=gid_range, \
                                                                         namespace='Coordinates', comm=env.comm, io_size=env.io_size)

            if rank == 0:
                logger.info("*** Done reading coordinates for population %s" %
                            pop_name)

            for i, (gid, coords) in enumerate(cell_attributes_iter):
                coords_output_dict[gid] = coords
                num_cells += 1

        if rank == 0:
            logger.info(
                "*** Writing cell selection for population %s to file %s" %
                (pop_name, write_selection_file_path))
        append_cell_trees(write_selection_file_path, pop_name,
                          trees_output_dict, **write_kwds)
        write_cell_attributes(write_selection_file_path,
                              pop_name,
                              coords_output_dict,
                              namespace='Coordinates',
                              **write_kwds)
        env.comm.barrier()
Example #7
def import_spikeraster(celltype_path,
                       spikeraster_path,
                       output_path,
                       output_npy=False,
                       namespace="Spike Data",
                       progress=False,
                       comm=None):

    if progress:
        import tqdm

    if comm is None:
        comm = MPI.COMM_WORLD

    populations = import_celltypes(celltype_path, output_path)
    n_pop = len(populations)

    start = 0
    pop_range_bins = []
    for name, idx, count in populations[:-1]:
        pop_range_bins.append(start + count)
        start = start + count

    logger.info(
        f"populations: {populations} total: {start} pop_range_bins: {pop_range_bins}"
    )

    logger.info(f"Reading spike data from file {spikeraster_path}...")

    if spikeraster_path.endswith('.npy'):
        spike_array = np.load(spikeraster_path)
    else:
        spike_array = np.loadtxt(spikeraster_path,
                                 dtype=np.dtype([("time", np.float32),
                                                 ("gid", np.uint32)]))

    if output_npy:
        np.save(f'{spikeraster_path}.npy', spike_array)

    logger.info(f"Done reading spike data from file {spikeraster_path}")

    gid_array = spike_array['gid']
    gid_bins = np.digitize(gid_array, np.asarray(pop_range_bins))

    pop_spk_dict = defaultdict(lambda: defaultdict(list))
    if progress:
        it = tqdm.tqdm(enumerate(zip(gid_array, gid_bins)), unit_scale=True)
    else:
        it = enumerate(zip(gid_array, gid_bins))

    for i, (gid, pop_idx) in it:

        pop_name = populations[pop_idx][0]
        spk_t = spike_array["time"][i]

        pop_spk_dict[pop_name][gid].append(spk_t)

    for pop_name, _, _ in populations:

        this_spk_dict = pop_spk_dict[pop_name]
        logger.info(
            f"Saving spike data for population {pop_name} gid set {sorted(this_spk_dict.keys())}"
        )
        output_dict = {
            gid: {
                't': np.asarray(spk_ts, dtype=np.float32)
            }
            for gid, spk_ts in viewitems(this_spk_dict)
        }

        write_cell_attributes(output_path,
                              pop_name,
                              output_dict,
                              namespace=namespace,
                              comm=comm)
        logger.info(
            f"Saved spike data for population {pop_name} to file {output_path}"
        )

    comm.barrier()
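The structured dtype passed to np.loadtxt above pairs each spike time with its gid; a small round-trip sketch on made-up data:

import io
import numpy as np

dt = np.dtype([("time", np.float32), ("gid", np.uint32)])
raster = io.StringIO("0.5 3\n1.25 7\n2.0 3\n")
spike_array = np.loadtxt(raster, dtype=dt)
print(spike_array["time"])    # [0.5  1.25 2.  ]
print(spike_array["gid"])     # [3 7 3]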
Example #8
def place_fields(population,
                 bin_size,
                 rate_dict,
                 trajectory,
                 arena_id=None,
                 trajectory_id=None,
                 nstdev=1.5,
                 binsteps=5,
                 baseline_fraction=None,
                 output_file_path=None,
                 progress=False,
                 **kwargs):
    """
    Estimates place fields from the given instantaneous spike rate dictionary.
    :param population: str
    :param bin_size: float
    :param rate_dict: dict
    :param trajectory: tuple of array
    :param arena_id: str
    :param trajectory_id: str
    :param nstdev: float
    :param binsteps: int
    :param baseline_fraction: float
    :param output_file_path: str (path to file)
    :param progress: bool
    :param kwargs: dict ('Minimum Width' and 'Minimum Rate' analysis options)
    :return: dict
    """

    if progress:
        from tqdm import tqdm

    analysis_options = copy.copy(default_pf_analysis_options)
    analysis_options.update(kwargs)

    min_pf_width = analysis_options['Minimum Width']
    min_pf_rate = analysis_options['Minimum Rate']

    (trj_x, trj_y, trj_d, trj_t) = trajectory

    pf_dict = {}
    pf_total_count = 0
    pf_cell_count = 0
    cell_count = 0
    pf_min = sys.maxsize
    pf_max = 0
    ncells = len(rate_dict)

    if progress:
        it = tqdm(viewitems(rate_dict))
    else:
        it = viewitems(rate_dict)

    for ind, valdict in it:
        t = valdict['time']
        rate = valdict['rate']
        m = np.mean(rate)
        rate1 = np.subtract(rate, m)
        if baseline_fraction is None:
            s = np.std(rate1)
        else:
            k = int(rate1.shape[0] // baseline_fraction)
            s = np.std(rate1[np.argpartition(rate1, k)[:k]])
        tmin = t[0]
        tmax = t[-1]
        bins = np.arange(tmin, tmax, bin_size)
        bin_rates = []
        bin_norm_rates = []
        pf_ibins = []
        for ibin in range(1, len(bins)):
            binx = np.linspace(bins[ibin - 1], bins[ibin], binsteps)
            interp_rate1 = np.interp(binx, t,
                                     np.asarray(rate1, dtype=np.float64))
            interp_rate = np.interp(binx, t, np.asarray(rate,
                                                        dtype=np.float64))
            r_n = np.mean(interp_rate1)
            r = np.mean(interp_rate)
            bin_rates.append(r)
            bin_norm_rates.append(r_n)
            if r_n > nstdev * s:
                pf_ibins.append(ibin - 1)

        bin_rates = np.asarray(bin_rates)
        bin_norm_rates = np.asarray(bin_norm_rates)

        if len(pf_ibins) > 0:
            pf_consecutive_ibins = []
            pf_consecutive_bins = []
            pf_widths = []
            pf_rates = []
            for pf_ibin_array in consecutive(pf_ibins):
                pf_ibin_range = np.asarray(
                    [np.min(pf_ibin_array),
                     np.max(pf_ibin_array)])
                pf_bin_range = np.asarray(
                    [bins[pf_ibin_range[0]], bins[pf_ibin_range[1]]])
                pf_bin_rates = [bin_rates[ibin] for ibin in pf_ibin_array]
                pf_width = np.diff(np.interp(pf_bin_range, trj_t, trj_d))[0]
                pf_consecutive_ibins.append(pf_ibin_range)
                pf_consecutive_bins.append(pf_bin_range)
                pf_widths.append(pf_width)
                pf_rates.append(np.mean(pf_bin_rates))

            if min_pf_rate is None:
                pf_filtered_ibins = [
                    pf_consecutive_ibins[i]
                    for i, pf_width in enumerate(pf_widths)
                    if pf_width >= min_pf_width
                ]
            else:
                pf_filtered_ibins = [
                    pf_consecutive_ibins[i]
                    for i, (pf_width,
                            pf_rate) in enumerate(zip(pf_widths, pf_rates))
                    if (pf_width >= min_pf_width) and (pf_rate >= min_pf_rate)
                ]

            pf_count = len(pf_filtered_ibins)
            pf_ibins = [
                list(range(pf_ibin[0], pf_ibin[1] + 1))
                for pf_ibin in pf_filtered_ibins
            ]
            pf_mean_width = []
            pf_mean_rate = []
            pf_peak_rate = []
            pf_mean_norm_rate = []
            pf_x_locs = []
            pf_y_locs = []
            for pf_ibin_iter in pf_ibins:
                pf_ibin_array = list(pf_ibin_iter)
                pf_ibin_range = np.asarray(
                    [np.min(pf_ibin_array),
                     np.max(pf_ibin_array)])
                pf_bin_range = np.asarray(
                    [bins[pf_ibin_range[0]], bins[pf_ibin_range[1]]])
                pf_mean_width.append(
                    np.mean(
                        np.asarray([
                            pf_width for pf_width in pf_widths
                            if pf_width >= min_pf_width
                        ])))
                pf_mean_rate.append(
                    np.mean(np.asarray(bin_rates[pf_ibin_array])))
                pf_peak_rate.append(
                    np.max(np.asarray(bin_rates[pf_ibin_array])))
                pf_mean_norm_rate.append(
                    np.mean(np.asarray(bin_norm_rates[pf_ibin_array])))
                pf_x_range = np.interp(pf_bin_range, trj_t, trj_x)
                pf_y_range = np.interp(pf_bin_range, trj_t, trj_y)
                pf_x_locs.append(np.mean(pf_x_range))
                pf_y_locs.append(np.mean(pf_y_range))

            pf_min = min(pf_count, pf_min)
            pf_max = max(pf_count, pf_max)
            pf_cell_count += 1
            pf_total_count += pf_count
        else:
            pf_count = 0
            pf_mean_width = []
            pf_mean_rate = []
            pf_peak_rate = []
            pf_mean_norm_rate = []
            pf_x_locs = []
            pf_y_locs = []

        cell_count += 1
        pf_dict[ind] = {
            'pf_count': np.asarray([pf_count], dtype=np.uint32),
            'pf_mean_width': np.asarray(pf_mean_width, dtype=np.float32),
            'pf_mean_rate': np.asarray(pf_mean_rate, dtype=np.float32),
            'pf_peak_rate': np.asarray(pf_peak_rate, dtype=np.float32),
            'pf_mean_norm_rate': np.asarray(pf_mean_norm_rate,
                                            dtype=np.float32),
            'pf_x_locs': np.asarray(pf_x_locs),
            'pf_y_locs': np.asarray(pf_y_locs)
        }

    logger.info('%s place fields: %i cells min %i max %i mean %f\n' %
                (population, cell_count, pf_min, pf_max,
                 float(pf_total_count) / float(cell_count)))
    if output_file_path is not None:
        if arena_id is None or trajectory_id is None:
            raise RuntimeError(
                'spikedata.place_fields: arena_id and trajectory_id required to write %s namespace'
                % 'Place Fields')
        namespace = 'Place Fields %s %s' % (arena_id, trajectory_id)
        write_cell_attributes(output_file_path,
                              population,
                              pf_dict,
                              namespace=namespace)

    return pf_dict
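Field detection above groups the above-threshold bin indices into runs via consecutive; a common implementation of that helper, shown on made-up indices:

import numpy as np

def consecutive(data):
    # split a sorted integer array into runs of consecutive values
    return np.split(data, np.where(np.diff(data) != 1)[0] + 1)

pf_ibins = np.array([2, 3, 4, 9, 10, 15])
print(consecutive(pf_ibins))    # [array([2, 3, 4]), array([ 9, 10]), array([15])]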
Example #9
def spatial_information(population,
                        trajectory,
                        spkdict,
                        time_range,
                        position_bin_size,
                        arena_id=None,
                        trajectory_id=None,
                        output_file_path=None,
                        information_attr_name='Mutual Information',
                        progress=False,
                        **kwargs):
    """
    Calculates mutual information for the given spatial trajectory and spike trains.
    :param population: str
    :param trajectory: tuple of array (x, y, distance, time)
    :param spkdict: dict of int (gid) to list of float (spike times, ms)
    :param time_range: tuple of float
    :param position_bin_size: float
    :param arena_id: str
    :param trajectory_id: str
    :param output_file_path: str (path to file)
    :param information_attr_name: str
    :return: dict
    """
    tmin = time_range[0]
    tmax = time_range[1]

    x, y, d, t = trajectory

    t_inds = np.where((t >= tmin) & (t <= tmax))
    t = t[t_inds]
    d = d[t_inds]

    d_extent = np.max(d) - np.min(d)
    position_bins = np.arange(np.min(d), np.max(d), position_bin_size)
    d_bin_inds = np.digitize(d, bins=position_bins)
    t_bin_ind_lst = [0]
    for ibin in range(1, len(position_bins) + 1):
        bin_inds = np.where(d_bin_inds == ibin)
        t_bin_ind_lst.append(np.max(bin_inds))
    t_bin_inds = np.asarray(t_bin_ind_lst)
    time_bins = t[t_bin_inds]

    d_bin_probs = {}
    prev_bin = np.min(d)
    for ibin in range(1, len(position_bins) + 1):
        d_bin = d[d_bin_inds == ibin]
        if d_bin.size > 0:
            bin_max = np.max(d_bin)
            d_prob = (bin_max - prev_bin) / d_extent
            d_bin_probs[ibin] = d_prob
            prev_bin = bin_max
        else:
            d_bin_probs[ibin] = 0.

    rate_bin_dict = spike_density_estimate(population,
                                           spkdict,
                                           time_bins,
                                           arena_id=arena_id,
                                           trajectory_id=trajectory_id,
                                           output_file_path=output_file_path,
                                           progress=progress,
                                           **kwargs)

    MI_dict = {}
    for ind, valdict in viewitems(rate_bin_dict):
        MI = 0.
        x = valdict['time']
        rates = valdict['rate']
        R = np.mean(rates)

        if R > 0.:
            for ibin in range(1, len(position_bins) + 1):
                p_i = d_bin_probs[ibin]
                R_i = rates[ibin - 1]
                if R_i > 0.:
                    MI += p_i * (R_i / R) * math.log((R_i / R), 2)

        MI_dict[ind] = MI

    if output_file_path is not None:
        if arena_id is None or trajectory_id is None:
            raise RuntimeError(
                'spikedata.spatial_information: arena_id and trajectory_id required to write Spatial '
                'Mutual Information namespace')
        namespace = 'Spatial Mutual Information %s %s' % (arena_id,
                                                          trajectory_id)
        attr_dict = {
            ind: {
                information_attr_name: np.array(MI_dict[ind], dtype='float32')
            }
            for ind in MI_dict
        }
        write_cell_attributes(output_file_path,
                              population,
                              attr_dict,
                              namespace=namespace)

    return MI_dict
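The per-cell score above is a Skaggs-style spatial information sum, MI = sum_i p_i (R_i / R) log2(R_i / R), with R the mean rate; a direct numeric check on made-up rates:

import math

p = [0.25, 0.25, 0.5]                    # occupancy probability per position bin
rates = [10.0, 2.0, 0.0]                 # rate per bin (Hz)
R = sum(rates) / len(rates)              # mean rate, as in the code above
MI = sum(p_i * (R_i / R) * math.log(R_i / R, 2)
         for p_i, R_i in zip(p, rates) if R_i > 0.)
print(MI)                                # ~0.70 bits

Example #10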
    mapping = {name: idx for name, start, count, idx in defs}
    dt = h5py.special_dtype(enum=(np.uint16, mapping))
    h5[path_population_labels] = dt

    dt = np.dtype([("Start", np.uint64), ("Count", np.uint32),
                   ("Population", h5[path_population_labels].dtype)])
    h5[path_population_range] = dt

    # retrieve the committed compound type for population ranges
    dt = h5[path_population_range].dtype

    g = h5_get_group(h5, grp_h5types)

    dset = h5_get_dataset(g, grp_populations, maxshape=(n_pop, ), dtype=dt)
    dset.resize((n_pop, ))
    a = np.zeros(n_pop, dtype=dt)
    for i, (name, start, count, idx) in enumerate(defs):
        a[i]["Start"] = start
        a[i]["Count"] = count
        a[i]["Population"] = idx

    dset[:] = a

write_cell_attributes(output_path,
                      pop_name,
                      attr_dict,
                      namespace='Test Attributes')
print(list(read_cell_attributes(output_path, pop_name, 'Test Attributes')))
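For context, a standalone sketch of committing an HDF5 enum type and a population-range compound type with h5py, as the fragment above does (the file and population names are made up):

import h5py
import numpy as np

mapping = {'GC': 0, 'MC': 1}
with h5py.File('pop_types.h5', 'w') as h5:
    # commit a named enum type, then a compound type that references it
    h5['population_labels'] = h5py.special_dtype(enum=(np.uint16, mapping))
    dt = np.dtype([("Start", np.uint64), ("Count", np.uint32),
                   ("Population", h5['population_labels'].dtype)])
    a = np.zeros(2, dtype=dt)
    a[0] = (0, 1000, 0)      # GC: gids 0..999
    a[1] = (1000, 100, 1)    # MC: gids 1000..1099
    h5.create_dataset('Populations', data=a)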