Beispiel #1
0
def main(coords_path, coords_namespace, io_size, cache_size):
    """
    Print a sample of cell coordinates for every population found in
    coords_path, then increment the 'test' counter stored in that file.

    :param coords_path: str; path to the file holding the cell coordinates
    :param coords_namespace: str; attribute namespace passed to NeuroH5CellAttrGen
    :param io_size: int; forwarded to NeuroH5CellAttrGen as io_size
    :param cache_size: int; forwarded to NeuroH5CellAttrGen as cache_size
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size

    print('Allocated %i ranks' % size)

    population_ranges = read_population_ranges(coords_path)[0]
    print(population_ranges)

    for population in population_ranges.keys():

        attr_iter = NeuroH5CellAttrGen(coords_path, population, namespace=coords_namespace,
                                       comm=comm, io_size=io_size, cache_size=cache_size)

        # Number of cells already printed for this population on this rank;
        # the loop stops after the twelfth one (the check runs after printing).
        i = 0
        for cell_gid, coords_dict in attr_iter:

            if cell_gid is not None:
                print('coords_dict: ', coords_dict)
                cell_u = coords_dict['U Coordinate']
                cell_v = coords_dict['V Coordinate']

                print('Rank %i: gid = %i u = %f v = %f' %
                      (rank, cell_gid, cell_u, cell_v))
                if i > 10:
                    break
                i = i + 1

    if rank == 0:
        import h5py
        count = 0
        # The context manager guarantees the file is flushed and closed even
        # if one of the dataset operations raises.
        with h5py.File(coords_path, 'r+') as f:
            if 'test' in f:
                count = f['test'][()]
                del f['test']
            f['test'] = count + 1
    comm.barrier()
def assign_cells_to_normalized_position(context):
    """
    Map each MPP / LPP cell gid to its arc distances rescaled to [0, 1].

    The 'U Distance' and 'V Distance' attributes are read from
    context.coords_path in namespace context.distances_namespace. Each
    distance is min-max normalized against all distances collected by this
    call.

    NOTE(review): the min/max bounds are taken over the cells this rank
    iterated over; if NeuroH5CellAttrGen hands different cells to different
    ranks, the bounds (and so the normalized values) are rank-local — confirm.

    :param context: object providing coords_path, distances_namespace, comm,
        io_size and cache_size attributes
    :return: dict; {gid: (normalized_u, normalized_v, arc_distance_u, arc_distance_v)};
        empty when no cells were read
    """
    population_distances = []
    gid_arc_distance = dict()
    gid_normed_distances = dict()

    for population in ['MPP', 'LPP']:
        attr_gen = NeuroH5CellAttrGen(context.coords_path,
                                      population,
                                      namespace=context.distances_namespace,
                                      comm=context.comm,
                                      io_size=context.io_size,
                                      cache_size=context.cache_size)

        for (gid, distances_dict) in attr_gen:
            if gid is None:
                break
            arc_distance_u = distances_dict['U Distance'][0]
            arc_distance_v = distances_dict['V Distance'][0]
            gid_arc_distance[gid] = (arc_distance_u, arc_distance_v)
            population_distances.append((arc_distance_u, arc_distance_v))

    # Without any cells the array below would be 1-D and column indexing
    # would raise IndexError; there is nothing to normalize in that case.
    if not population_distances:
        return gid_normed_distances

    population_distances = np.asarray(population_distances, dtype='float32')

    # Column 0 holds the U distances, column 1 the V distances.
    min_u, min_v = np.min(population_distances, axis=0)
    max_u, max_v = np.max(population_distances, axis=0)
    for (gid, (arc_distance_u, arc_distance_v)) in viewitems(gid_arc_distance):
        normalized_u = (arc_distance_u - min_u) / (max_u - min_u)
        normalized_v = (arc_distance_v - min_v) / (max_v - min_v)
        gid_normed_distances[gid] = (normalized_u, normalized_v,
                                     arc_distance_u, arc_distance_v)

    return gid_normed_distances
Beispiel #3
0
def generate_uv_distance_connections(comm, population_dict, connection_config, connection_prob, forest_path,
                                     synapse_seed, connectivity_seed, cluster_seed,
                                     synapse_namespace, connectivity_namespace, connectivity_path,
                                     io_size, chunk_size, value_chunk_size, cache_size, write_size=1,
                                     dry_run=False, debug=False):
    """
    Generates connectivity based on U, V distance-weighted probabilities.

    Iterates over the destination population's synapse attributes, computes
    edges for each destination gid with generate_synaptic_connections, and
    appends the accumulated edges to connectivity_path via append_graph.

    :param comm: mpi4py MPI communicator
    :param population_dict: passed through unchanged to generate_synaptic_connections
    :param connection_config: connection configuration object (instance of env.ConnectionConfig)
    :param connection_prob: ConnectionProb instance
    :param forest_path: location of file with neuronal trees and synapse information
    :param synapse_seed: random seed for synapse partitioning
    :param connectivity_seed: random seed for connectivity generation
    :param cluster_seed: random seed for determining connectivity clustering for repeated connections from the same source
    :param synapse_namespace: namespace of synapse properties
    :param connectivity_namespace: namespace of connectivity attributes (not referenced in this function body)
    :param connectivity_path: path handed to append_graph as the output file
    :param io_size: number of I/O ranks to use for parallel connectivity append; -1 selects comm.size
    :param chunk_size: HDF5 chunk size for connectivity file (pointer and index datasets) (not referenced in this function body)
    :param value_chunk_size: HDF5 chunk size for connectivity file (value datasets) (not referenced in this function body)
    :param cache_size: how many cells to read ahead
    :param write_size: how many cells to write out at the same time; values <= 0 defer all writing to the end
    :param dry_run: if True, edges are computed but append_graph is never called
    :param debug: forwarded as debug_flag to generate_synaptic_connections; also stops the loop after more than 250 iterations
    """

    rank = comm.rank

    # io_size == -1 is the sentinel for "use every rank for I/O".
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    start_time = time.time()

    # Created unseeded here; both streams are re-seeded for every destination
    # gid inside the main loop.
    ranstream_syn = np.random.RandomState()
    ranstream_con = np.random.RandomState()

    destination_population = connection_prob.destination_population

    source_populations = sorted(connection_config[destination_population].keys())

    for source_population in source_populations:
        if rank == 0:
            logger.info(f'{source_population} -> {destination_population}: \n'
                        f'{pprint.pformat(connection_config[destination_population][source_population])}')

    # Per-source tuple of (type, layers, sections, proportions, contacts)
    # taken from the projection configuration.
    projection_config = connection_config[destination_population]
    projection_synapse_dict = {source_population:
                                   (projection_config[source_population].type,
                                    projection_config[source_population].layers,
                                    projection_config[source_population].sections,
                                    projection_config[source_population].proportions,
                                    projection_config[source_population].contacts)
                               for source_population in source_populations}

    
    comm.barrier()

    it_count = 0
    total_count = 0
    gid_count = 0
    connection_dict = defaultdict(lambda: {})
    projection_dict = {}
    for destination_gid, synapse_dict in NeuroH5CellAttrGen(forest_path, \
                                                            destination_population, \
                                                            namespace=synapse_namespace, \
                                                            comm=comm, io_size=io_size, \
                                                            cache_size=cache_size):
        # A None gid skips edge generation but still falls through to the
        # write section below.
        # NOTE(review): presumably None means this rank got no cell in this
        # read round — confirm against NeuroH5CellAttrGen.
        if destination_gid is None:
            logger.info(f'Rank {rank} destination gid is None')
        else:
            logger.info(f'Rank {rank} received attributes for destination: {destination_population}, gid: {destination_gid}')

            # Seeds depend only on the gid and the base seeds.
            ranstream_con.seed(destination_gid + connectivity_seed)
            ranstream_syn.seed(destination_gid + synapse_seed)
            last_gid_time = time.time()

            projection_prob_dict = {}
            for source_population in source_populations:
                source_layers = projection_config[source_population].layers
                projection_prob_dict[source_population] = \
                    connection_prob.get_prob(destination_gid, source_population, source_layers)


                # Logging only: report the candidate source count and U distance range per layer.
                for layer, (probs, source_gids, distances_u, distances_v) in \
                        viewitems(projection_prob_dict[source_population]):
                    if len(distances_u) > 0:
                        max_u_distance = np.max(distances_u)
                        min_u_distance = np.min(distances_u)
                        if rank == 0:
                            logger.info(f'Rank {rank} has {len(source_gids)} possible sources from population {source_population} '
                                        f'for destination: {destination_population}, layer {layer}, gid: {destination_gid}; '
                                        f'max U distance: {max_u_distance:.2f} min U distance: {min_u_distance:.2f}')
                    else:
                        # Unlike the info message above, this warning is emitted on every rank.
                        logger.warning(f'Rank {rank} has {len(source_gids)} possible sources from population {source_population} '
                                       f'for destination: {destination_population}, layer {layer}, gid: {destination_gid}')

            # connection_dict is handed in as an argument; presumably it is
            # filled in place by generate_synaptic_connections — confirm.
            count = generate_synaptic_connections(rank,
                                                  destination_gid,
                                                  ranstream_syn,
                                                  ranstream_con,
                                                  cluster_seed + destination_gid,
                                                  destination_gid,
                                                  synapse_dict,
                                                  population_dict,
                                                  projection_synapse_dict,
                                                  projection_prob_dict,
                                                  connection_dict,
                                                  debug_flag=debug)
            total_count += count

            logger.info(f'Rank {rank} took {time.time() - last_gid_time:.2f} s to compute {count} edges for destination: {destination_population}, gid: {destination_gid}')

        # gid_count advances on every iteration (None gids included) and is
        # tested before being incremented, so the first iteration
        # (gid_count == 0) always enters this branch.
        if (write_size > 0) and (gid_count % write_size == 0):
            if len(connection_dict) > 0:
                projection_dict = {destination_population: connection_dict}
            else:
                projection_dict = {}
            if not dry_run:
                last_time = time.time()
                # Called even when projection_dict is empty.
                # NOTE(review): presumably append_graph is a collective call
                # that every rank must make — confirm.
                append_graph(connectivity_path, projection_dict, io_size=io_size, comm=comm)
                if rank == 0:
                    if connection_dict:
                        logger.info(f'Appending connectivity for {len(connection_dict)} projections took {time.time() - last_time:.2f} s')
            projection_dict.clear()
            connection_dict.clear()
            gc.collect()

        gid_count += 1
        it_count += 1
        if (it_count > 250) and debug:
            break


    # Final flush of whatever is still held in connection_dict after the loop.
    gc.collect()
    last_time = time.time()
    if len(connection_dict) > 0:
        projection_dict = {destination_population: connection_dict}
    else:
        projection_dict = {}
    if not dry_run:
        append_graph(connectivity_path, projection_dict, io_size=io_size, comm=comm)
        if rank == 0:
            if connection_dict:
                logger.info(f'Appending connectivity for {len(connection_dict)} projections took {time.time() - last_time:.2f} s')

    # Per-rank edge counts are collected on rank 0 for the summary message.
    global_count = comm.gather(total_count, root=0)
    if rank == 0:
        logger.info(f'{comm.size} ranks took {time.time() - start_time:.2f} s to generate {np.sum(global_count)} edges')
Beispiel #4
0
def main(forest_path, cell_attr_path, connections_path, cell_attr_namespace,
         cell_attr, destination, source, io_size, cache_size):
    """
    Compare the data yielded by the NeuroH5 generators with the data returned
    by the corresponding select_* functions and report how many gids match.

    Cell attributes are always checked; edge attributes only when
    connections_path is given, tree attributes only when forest_path is
    given. Each check loop stops once its iteration index exceeds 10.

    :param forest_path: str (path) or None to skip the tree check
    :param cell_attr_path: str (path)
    :param connections_path: str (path) or None to skip the edge check
    :param cell_attr_namespace: str
    :param cell_attr: str; name of the attribute compared between both reads
    :param destination: str; destination population name
    :param source: str; source population name
    :param io_size: int; -1 selects comm.size
    :param cache_size: int
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    # Every print below passes one pre-formatted string, so the output is the
    # same whether or not print is a function.
    if rank == 0:
        print('%s: %i ranks have been allocated' % (os.path.basename(__file__),
                                                    comm.size))
    sys.stdout.flush()

    pop_ranges, _ = read_population_ranges(cell_attr_path, comm=comm)
    destination_gid_offset = pop_ranges[destination][0]
    source_gid_offset = pop_ranges[source][0]
    maxiter = 10
    cell_attr_matched = 0
    cell_attr_processed = 0
    edge_attr_matched = 0
    edge_attr_processed = 0
    tree_attr_matched = 0
    tree_attr_processed = 0

    cell_attr_gen = NeuroH5CellAttrGen(cell_attr_path,
                                       destination,
                                       comm=comm,
                                       io_size=io_size,
                                       cache_size=cache_size,
                                       namespace=cell_attr_namespace)
    index_map = get_cell_attributes_index_map(comm, cell_attr_path,
                                              destination, cell_attr_namespace)
    for itercount, (target_gid, attr_dict) in enumerate(cell_attr_gen):
        print('Rank: %i receieved target_gid: %s from the cell attribute generator.' % (
            rank, str(target_gid)))
        attr_dict2 = select_cell_attributes(
            target_gid,
            comm,
            cell_attr_path,
            index_map,
            destination,
            cell_attr_namespace,
            population_offset=destination_gid_offset)
        if np.all(attr_dict[cell_attr][:] == attr_dict2[cell_attr][:]):
            print('Rank: %i; cell attributes match!' % rank)
            cell_attr_matched += 1
        else:
            print('Rank: %i; cell attributes do not match.' % rank)
        comm.barrier()
        cell_attr_processed += 1
        if itercount > maxiter:
            break
    # After gather, rank 0 holds a list of per-rank counts (other ranks get None).
    cell_attr_matched = comm.gather(cell_attr_matched, root=0)
    cell_attr_processed = comm.gather(cell_attr_processed, root=0)

    if connections_path is not None:
        edge_attr_gen = NeuroH5ProjectionGen(connections_path,
                                             source,
                                             destination,
                                             comm=comm,
                                             cache_size=cache_size,
                                             namespaces=['Synapses'])
        index_map = get_edge_attributes_index_map(comm, connections_path,
                                                  source, destination)
        for itercount, (target_gid, attr_package) in enumerate(edge_attr_gen):
            print('Rank: %i receieved target_gid: %s from the edge attribute generator.' % (
                rank, str(target_gid)))
            source_indexes, attr_dict = attr_package
            syn_ids = attr_dict['Synapses']['syn_id']
            source_indexes2, attr_dict2 = select_edge_attributes(
                target_gid,
                comm,
                connections_path,
                index_map,
                source,
                destination,
                namespaces=['Synapses'],
                source_offset=source_gid_offset,
                destination_offset=destination_gid_offset)
            syn_ids2 = attr_dict2['Synapses']['syn_id']
            if np.all(syn_ids == syn_ids2) and np.all(
                    source_indexes == source_indexes2):
                print('Rank: %i; edge attributes match!' % rank)
                edge_attr_matched += 1
            else:
                print('Rank: %i; attributes do not match.' % rank)
            comm.barrier()
            edge_attr_processed += 1
            if itercount > maxiter:
                break
        edge_attr_matched = comm.gather(edge_attr_matched, root=0)
        edge_attr_processed = comm.gather(edge_attr_processed, root=0)

    if forest_path is not None:
        tree_attr_gen = NeuroH5TreeGen(forest_path,
                                       destination,
                                       comm=comm,
                                       io_size=io_size)
        for itercount, (target_gid, attr_dict) in enumerate(tree_attr_gen):
            print('Rank: %i receieved target_gid: %s from the tree attribute generator.' % (
                rank, str(target_gid)))
            attr_dict2 = select_tree_attributes(target_gid, comm, forest_path,
                                                destination)
            if (attr_dict.keys() == attr_dict2.keys()) and all(
                    attr_dict['layer'] == attr_dict2['layer']):
                print('Rank: %i; tree attributes match!' % rank)
                tree_attr_matched += 1
            else:
                print('Rank: %i; tree attributes do not match.' % rank)
            comm.barrier()
            tree_attr_processed += 1
            if itercount > maxiter:
                break
        tree_attr_matched = comm.gather(tree_attr_matched, root=0)
        tree_attr_processed = comm.gather(tree_attr_processed, root=0)

    if comm.rank == 0:
        # Counters of skipped checks are still the plain integer 0, which
        # np.sum handles as well as the gathered lists.
        print('%i / %i processed gids had matching cell attributes returned by both read methods' %
              (np.sum(cell_attr_matched), np.sum(cell_attr_processed)))
        print('%i / %i processed gids had matching edge attributes returned by both read methods' %
              (np.sum(edge_attr_matched), np.sum(edge_attr_processed)))
        print('%i / %i processed gids had matching tree attributes returned by both read methods' %
              (np.sum(tree_attr_matched), np.sum(tree_attr_processed)))
def main(features_path, connectivity_path, connectivity_namespace, io_size,
         chunk_size, value_chunk_size, cache_size, trajectory_id, debug):
    """
    Compute a predicted response waveform for every 'GC' cell along a
    diagonal trajectory and append it to features_path.

    For each GC gid, the rate maps of its MPP / LPP source cells (read from
    the 'Feature Selectivity' namespace) are evaluated along the trajectory,
    weighted by the number of contacts from that source, and summed. The
    waveform, its modulation and its peak index are written to the namespace
    'Response Prediction <trajectory_id>' unless debug is set.

    :param features_path: str; file holding input features; trajectory and predictions are written here
    :param connectivity_path: str; file holding the GC connectivity attributes
    :param connectivity_namespace: str; namespace containing the 'source_gid' attribute
    :param io_size: int; number of I/O ranks, -1 selects comm.size
    :param chunk_size: int; forwarded to append_cell_attributes
    :param value_chunk_size: int; forwarded to append_cell_attributes
    :param cache_size: int; forwarded to NeuroH5CellAttrGen
    :param trajectory_id: identifier of the trajectory group (converted with str())
    :param debug: bool; if True only two generator items are processed and nothing is appended
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%i ranks have been allocated' % comm.size)
    sys.stdout.flush()

    input_populations = ['MPP', 'LPP']

    population_range_dict = read_population_ranges(comm, features_path)

    features_dict = {}
    for population in input_populations:
        features_dict[population] = bcast_cell_attributes(
            comm,
            0,
            features_path,
            population,
            namespace='Feature Selectivity')

    arena_dimension = 100.  # minimum distance from origin to boundary (cm)

    run_vel = 30.  # cm/s
    spatial_resolution = 1.  # cm
    x = np.arange(-arena_dimension, arena_dimension, spatial_resolution)
    y = np.arange(-arena_dimension, arena_dimension, spatial_resolution)
    # Cumulative path length along the (x, y) samples, starting at 0.
    distance = np.insert(
        np.cumsum(np.sqrt(np.sum(
            [np.diff(x)**2., np.diff(y)**2.], axis=0))), 0, 0.)
    interp_distance = np.arange(distance[0], distance[-1], spatial_resolution)
    # cm / (cm/s) gives seconds; multiply by 1000 for ms. The previous form
    # divided by (run_vel * 1000), which is off by a factor of 1e6.
    t = old_div(interp_distance, run_vel) * 1000.  # ms
    interp_x = np.interp(interp_distance, distance, x)
    interp_y = np.interp(interp_distance, distance, y)

    with h5py.File(features_path, 'a', driver='mpio', comm=comm) as f:
        if 'Trajectories' not in f:
            f.create_group('Trajectories')
        if str(trajectory_id) not in f['Trajectories']:
            f['Trajectories'].create_group(str(trajectory_id))
            f['Trajectories'][str(trajectory_id)].create_dataset(
                'x', dtype='float32', data=interp_x)
            f['Trajectories'][str(trajectory_id)].create_dataset(
                'y', dtype='float32', data=interp_y)
            f['Trajectories'][str(trajectory_id)].create_dataset(
                'd', dtype='float32', data=interp_distance)
            f['Trajectories'][str(trajectory_id)].create_dataset(
                't', dtype='float32', data=t)
        # Always use the stored trajectory so an existing one takes precedence
        # over the freshly computed arrays.
        x = f['Trajectories'][str(trajectory_id)]['x'][:]
        y = f['Trajectories'][str(trajectory_id)]['y'][:]
        d = f['Trajectories'][str(trajectory_id)]['d'][:]

    prediction_namespace = 'Response Prediction ' + str(trajectory_id)

    target_population = 'GC'
    count = 0
    start_time = time.time()
    attr_gen = NeuroH5CellAttrGen(comm,
                                  connectivity_path,
                                  target_population,
                                  io_size=io_size,
                                  cache_size=cache_size,
                                  namespace=connectivity_namespace)
    if debug:
        attr_gen_wrapper = (next(attr_gen) for i in range(2))
    else:
        attr_gen_wrapper = attr_gen
    for gid, connectivity_dict in attr_gen_wrapper:
        local_time = time.time()
        source_gid_counts = {}
        response_dict = {}
        response = np.zeros_like(d, dtype='float32')
        if gid is not None:
            source_gids = connectivity_dict[connectivity_namespace]['source_gid']
            # Count contacts per source gid, separately for each input population.
            for population in input_populations:
                pop_start = population_range_dict[population][0]
                pop_count = population_range_dict[population][1]
                indexes = np.where(
                    (source_gids >= pop_start)
                    & (source_gids < pop_start + pop_count))[0]
                source_gid_counts[population] = Counter(source_gids[indexes])
            for population in input_populations:
                population_features = features_dict[population]
                for source_gid in source_gid_counts[population]:
                    if source_gid not in population_features:
                        continue
                    this_feature_dict = population_features[source_gid]
                    selectivity_type = this_feature_dict['Selectivity Type'][0]
                    contact_count = source_gid_counts[population][source_gid]
                    if selectivity_type == selectivity_grid:
                        ori_offset = this_feature_dict['Grid Orientation'][0]
                        grid_spacing = this_feature_dict['Grid Spacing'][0]
                        x_offset = this_feature_dict['X Offset'][0]
                        y_offset = this_feature_dict['Y Offset'][0]
                        rate = np.vectorize(
                            grid_rate(grid_spacing, ori_offset, x_offset,
                                      y_offset))
                    elif selectivity_type == selectivity_place_field:
                        field_width = this_feature_dict['Field Width'][0]
                        x_offset = this_feature_dict['X Offset'][0]
                        y_offset = this_feature_dict['Y Offset'][0]
                        rate = np.vectorize(
                            place_rate(field_width, x_offset, y_offset))
                    else:
                        # Unknown selectivity type: skip this source instead of
                        # reusing the rate function of the previous source
                        # (or failing with NameError on the first one).
                        continue
                    response = np.add(response,
                                      contact_count * rate(x, y),
                                      dtype='float32')
            response_dict[gid] = {'waveform': response}
            # Baseline / peak are the means of the lowest / highest deciles.
            baseline = np.mean(response[np.where(
                response <= np.percentile(response, 10.))[0]])
            peak = np.mean(response[np.where(
                response >= np.percentile(response, 90.))[0]])
            modulation = 0. if peak <= 0.1 else old_div(
                (peak - baseline), peak)
            peak_index = np.where(response == np.max(response))[0][0]
            response_dict[gid]['modulation'] = np.array([modulation],
                                                        dtype='float32')
            response_dict[gid]['peak_index'] = np.array([peak_index],
                                                        dtype='uint32')
            print('Rank %i: took %.2f s to compute predicted response for %s gid %i' %
                  (rank, time.time() - local_time, target_population, gid))
            count += 1
        # Called on every iteration, with an empty response_dict when gid is None.
        if not debug:
            append_cell_attributes(comm,
                                   features_path,
                                   target_population,
                                   response_dict,
                                   namespace=prediction_namespace,
                                   io_size=io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)
        sys.stdout.flush()
        del response
        del response_dict
        del source_gid_counts
        gc.collect()

    global_count = comm.gather(count, root=0)
    if rank == 0:
        print('%i ranks took %.2f s to compute selectivity parameters for %i %s cells' %
              (comm.size, time.time() - start_time, np.sum(global_count), target_population))
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, subset_seed, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float
    :param selectivity_path: str
    :param subset_seed: int; for reproducible choice of gids to plot individual rate maps
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param cache_size: int
    :param verbose: bool
    :param debug: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided selectivity_path: %s' %
                    (population, selectivity_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
                if len(valid_selectivity_namespaces[population]) == 0:
                    raise RuntimeError(
                        'plot_input_selectivity_features: no selectivity data in arena: %s found '
                        'for specified population: %s in provided selectivity_path: %s'
                        % (arena_id, population, selectivity_path))

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'plot_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population, coords_path))
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population, coords_path))
    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))

    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                'Reading %i cell positions for population %s took %.2f s' %
                (len(u_distances_by_gid), population,
                 time.time() - start_time))

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            start_time = time.time()
            if rank == 0:
                logger.info('Reading from %s namespace for population %s...' %
                            (this_selectivity_namespace, population))
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                               selectivity_type_names=selectivity_type_names,
                                               selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = input_cell_config.gather_attributes(
                    )
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = '%s %s cell %i' % (
                            population, this_selectivity_type_name, gid)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title='%s\nModule: %i' %
                            (fig_title, this_module_id),
                            **fig_options())
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                if debug and iter_count >= 10:
                    break

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info('Processing %i %s %s cells took %.2f s' %
                            (gid_count, population, this_selectivity_type_name,
                             time.time() - start_time))

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    fig_title = '%s %s Module %i summed rate maps' % \
                                (population, this_selectivity_type_name, this_module_id)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title='%s %s summed rate maps\nModule %i' %
                        (population, this_selectivity_type_name,
                         this_module_id),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
Beispiel #7
0
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, io_size, chunk_size, value_chunk_size,
         cache_size, write_size, verbose, gather, interactive, debug, plot,
         show_fig, save_fig, save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for the cells of each requested
    population and append them to output_path.

    For every population, the per-cell 'U Distance' / 'V Distance' attributes
    are read from coords_path, the U distance is normalized against the
    reference bounds stored as attributes of the distances namespace, and
    generate_input_selectivity_features is called once per cell. Results are
    written (unless dry_run) into one namespace per selectivity type, named
    '<Selectivity type> Selectivity <arena_id>'. When gather is set, the
    normalized distances and summed rate maps are collected on rank 0 and
    optionally plotted.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str; when empty, all populations found in coords_path are used
    :param io_size: int; -1 selects the size of the communicator
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param plot: bool; forced to True when save_fig is provided
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; when True nothing is written to output_path
    :raises RuntimeError: on rank 0 when output_path is missing, a population is
        unknown or has no selectivity type probabilities, cell indexes are not
        unique, or reference bounds cannot be read; on any rank when arena_id
        is not present in the stimulus configuration
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    # Saving a figure implies plotting.
    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed a fresh output file with the type definitions of the input
            # file; context managers guarantee both files are closed even if
            # the copy fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Rank 0 validates the inputs and reads the reference U bounds, which are
    # then broadcast to all ranks.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path))
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # Invert env.selectivity_types (name -> id) into id -> name.
    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    # Output namespace per selectivity type: capitalized type name followed by
    # ' Selectivity <arena_id>'.
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))
    arena = env.stimulus_config['Arena'][arena_id]
    # The spatial mesh is computed once on rank 0 and shared with all ranks.
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    # Number of iterations between intermediate writes.
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                'Generating input selectivity features for population %s...' %
                population)

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        # Per selectivity type buffer of gid -> attributes awaiting a write.
        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            # A rank may receive gid None; it then only takes part in the
            # collective operations below.
            if gid is not None:
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1

            # Periodically report progress and flush the buffered attributes.
            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == 10):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        if rank == 0:
                            logger.info(
                                'writing selectivity features for %s [%s]...' %
                                (population, selectivity_type_name))
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                # Start a fresh buffer after each flush.
                del selectivity_attr_dict
                selectivity_attr_dict = dict(
                    (key, dict()) for key in env.selectivity_types)

            if debug and iter_count >= 10:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = this_rate_map_sum

        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        # Final write of whatever is still buffered for this population.
        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                if rank == 0:
                    logger.info('writing selectivity features for %s [%s]...' %
                                (population, selectivity_type_name))
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
        del selectivity_attr_dict
        comm.barrier()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        # Convert the defaultdicts (built with lambdas) to plain dicts before
        # sending them through comm.gather.
        rate_map_sum = dict([(key, dict(val.items()))
                             for key, val in viewitems(rate_map_sum)])
        rate_map_sum = comm.gather(rate_map_sum, root=0)
        if rank == 0:
            merged_rate_map_sum = defaultdict(
                lambda: defaultdict(lambda: np.zeros_like(arena_x_mesh)))
            for each_rate_map_sum in rate_map_sum:
                for population in each_rate_map_sum:
                    for selectivity_type_name in each_rate_map_sum[population]:
                        merged_rate_map_sum[population][selectivity_type_name] = \
                            np.add(merged_rate_map_sum[population][selectivity_type_name],
                                   each_rate_map_sum[population][selectivity_type_name])
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title('Population: %s' % population)
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure('%s %s normalized distances histogram' %
                                    (save_fig, population),
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Name each figure after the selectivity type being
                        # plotted so saved files for one population do not
                        # share a name.
                        fig_title = '%s %s summed rate maps' % (
                            population, selectivity_type_name)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title='Summed rate maps\n%s %s cells' %
                            (population, selectivity_type_name),
                            **fig_options())

    if interactive and rank == 0:
        context.update(locals())
Beispiel #8
0
def main(stimulus_path, input_stimulus_namespace, output_stimulus_namespace, io_size, chunk_size, value_chunk_size,
         cache_size, seed_offset, trajectory_id, debug):
    """
    Reassign the stimulus attributes of LPP cells to randomly permuted gids.

    Rank 0 draws a random permutation of the population indexes and broadcasts
    it; every rank then reads cells from the input namespace and appends their
    'rate', 'spiketrain', 'modulation' and 'peak_index' attributes under the
    permuted index in the output namespace. Nothing is written when debug is
    set, and only two items are pulled from the attribute generator.

    :param stimulus_path: str
    :param input_stimulus_namespace: str; ' <trajectory_id>' is appended
    :param output_stimulus_namespace: str; ' <trajectory_id>' is appended
    :param io_size: int; -1 selects the size of the communicator
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param seed_offset: int; scaled by 2e6 to seed numpy's global random state
    :param trajectory_id: int
    :param debug: bool
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%i ranks have been allocated' % comm.size)
    sys.stdout.flush()

    np.random.seed(int(seed_offset * 2e6))

    population_ranges = read_population_ranges(comm, stimulus_path)[0]

    namespace_suffix = ' ' + str(trajectory_id)
    input_stimulus_namespace = input_stimulus_namespace + namespace_suffix
    output_stimulus_namespace = output_stimulus_namespace + namespace_suffix

    for population in ['LPP']:
        this_range = population_ranges[population]
        population_start = this_range[0]
        population_count = this_range[1]

        # Only rank 0 shuffles, so all ranks share one permutation.
        shuffled_gids = None
        if rank == 0:
            shuffled_gids = np.arange(0, population_count)
            np.random.shuffle(shuffled_gids)
        shuffled_gids = comm.bcast(shuffled_gids, root=0)

        local_count = 0
        start_time = time.time()
        attr_gen = NeuroH5CellAttrGen(comm, stimulus_path, population, io_size=io_size,
                                      cache_size=cache_size, namespace=input_stimulus_namespace)
        # In debug mode only the first two items of the generator are consumed.
        cell_iter = (next(attr_gen) for i in range(2)) if debug else attr_gen

        for gid, stimulus_dict in cell_iter:
            local_time = time.time()
            new_response_dict = {}
            if gid is not None:
                target_gid = shuffled_gids[gid - population_start]
                spiketrain = np.asarray(stimulus_dict['spiketrain'], dtype=np.float32)
                new_response_dict[target_gid] = {
                    'rate': stimulus_dict['rate'],
                    'spiketrain': spiketrain,
                    'modulation': stimulus_dict['modulation'],
                    'peak_index': stimulus_dict['peak_index'],
                }

                print('Rank %i; source: %s; assigned spike trains for gid %i to gid %i in %.2f s' %
                      (rank, population, gid, target_gid + population_start, time.time() - local_time))
                local_count += 1

            # Called on every iteration, including those with an empty dict.
            if not debug:
                append_cell_attributes(comm, stimulus_path, population, new_response_dict,
                                       namespace=output_stimulus_namespace,
                                       io_size=io_size, chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
            sys.stdout.flush()
            del new_response_dict
            gc.collect()

        counts_by_rank = comm.gather(local_count, root=0)
        if rank == 0:
            print('%i ranks randomized spike trains for %i cells in %.2f s' %
                  (comm.size, np.sum(counts_by_rank), time.time() - start_time))
Beispiel #9
0
def main(features_path, prediction_namespace, io_size, chunk_size, value_chunk_size, cache_size,
         trajectory_id, debug):
    """
    Compute 'modulation' and 'peak_index' attributes from the predicted
    response waveform of each GC cell and append them to features_path.

    The baseline is the mean of the samples at or below the 10th percentile of
    the waveform, the peak is the mean of the samples at or above the 90th
    percentile, and modulation is peak / baseline - 1. Nothing is written when
    debug is set, and only two items are pulled from the attribute generator.

    :param features_path: path of the file that is read and appended to
    :param prediction_namespace: namespace prefix; ' <trajectory_id>' is appended
    :param io_size: -1 selects the size of the communicator
    :param chunk_size: passed through to append_cell_attributes
    :param value_chunk_size: passed through to append_cell_attributes
    :param cache_size: passed through to NeuroH5CellAttrGen
    :param trajectory_id: suffix for the namespace
    :param debug: bool
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%i ranks have been allocated' % comm.size)
    sys.stdout.flush()

    prediction_namespace = ' '.join([prediction_namespace, str(trajectory_id)])

    target_population = 'GC'
    local_count = 0
    start_time = time.time()
    attr_gen = NeuroH5CellAttrGen(comm, features_path, target_population, io_size=io_size,
                                  cache_size=cache_size, namespace=prediction_namespace)
    # In debug mode only the first two items of the generator are consumed.
    cell_iter = (next(attr_gen) for i in range(2)) if debug else attr_gen

    for gid, response_dict in cell_iter:
        local_time = time.time()
        response_attr_dict = {}
        response = None
        if gid is not None:
            response_attr_dict[gid] = {}
            response = response_dict[prediction_namespace]['waveform']

            low_indexes = np.where(response <= np.percentile(response, 10.))[0]
            baseline = np.mean(response[low_indexes])
            high_indexes = np.where(response >= np.percentile(response, 90.))[0]
            peak = np.mean(response[high_indexes])

            modulation = old_div(peak, baseline) - 1.
            # First index at which the waveform reaches its maximum.
            max_indexes = np.where(response == np.max(response))[0]
            peak_index = max_indexes[0]

            this_attrs = response_attr_dict[gid]
            this_attrs['modulation'] = np.array([modulation], dtype='float32')
            this_attrs['peak_index'] = np.array([peak_index], dtype='uint32')
            print('Rank %i: took %.2f s to append compute prediction attributes for %s gid %i' %
                  (rank, time.time() - local_time, target_population, gid))
            local_count += 1

        # Called on every iteration, including those with an empty dict.
        if not debug:
            append_cell_attributes(comm, features_path, target_population, response_attr_dict,
                                   namespace=prediction_namespace, io_size=io_size, chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)
        sys.stdout.flush()
        del response
        del response_attr_dict
        gc.collect()

    counts_by_rank = comm.gather(local_count, root=0)
    if rank == 0:
        print('%i ranks took %.2f s to compute selectivity parameters for %i %s cells' %
              (comm.size, time.time() - start_time, np.sum(counts_by_rank), target_population))
Beispiel #10
0
def main(file_path, namespace, attribute, population, io_size, cache_size,
         trajectory_id):
    """
    Cross-check two ways of reading the same per-cell attribute.

    First makes sure a 'Trajectory <trajectory_id>' group exists in file_path:
    it is created from stimulus.generate_linear_trajectory when missing, and
    read back otherwise. Then iterates over the cells of `population` with
    NeuroH5CellAttrGen and, for every gid, compares `attribute` with the value
    returned by select_cell_attributes for the same gid. Rank 0 finally prints
    how many of the processed gids matched.

    :param file_path: str (path)
    :param namespace: str
    :param attribute: str
    :param population: str
    :param io_size: int; -1 means "use every rank of the communicator"
    :param cache_size: int
    :param trajectory_id: int
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%s: %i ranks have been allocated' %
              (os.path.basename(__file__).split('.py')[0], comm.size))
    sys.stdout.flush()

    trajectory_namespace = 'Trajectory %s' % str(trajectory_id)

    arena_dimension = 100.  # minimum distance from origin to boundary (cm)
    default_run_vel = 30.  # cm/s
    spatial_resolution = 1.  # cm

    with h5py.File(file_path, 'a', driver='mpio', comm=comm) as f:
        if trajectory_namespace not in f:
            print('Rank: %i; Creating %s datasets' %
                  (rank, trajectory_namespace))
            group = f.create_group(trajectory_namespace)
            t, x, y, d = stimulus.generate_linear_trajectory(
                arena_dimension=arena_dimension,
                velocity=default_run_vel,
                spatial_resolution=spatial_resolution)
            for key, value in zip(['x', 'y', 'd', 't'], [x, y, d, t]):
                dataset = group.create_dataset(key, (value.shape[0], ),
                                               dtype='float32')
                with dataset.collective:
                    dataset[:] = value.astype('float32', copy=False)
        else:
            print('Rank: %i; Reading %s datasets' %
                  (rank, trajectory_namespace))
            group = f[trajectory_namespace]
            # The values are not used further in this function; the reads are
            # kept, in the original order, so every rank still performs the
            # same sequence of collective operations.
            for key in ('x', 'y', 'd', 't'):
                dataset = group[key]
                with dataset.collective:
                    _ = dataset[:]

    target = population

    pop_ranges, pop_size = read_population_ranges(file_path, comm=comm)
    target_gid_offset = pop_ranges[target][0]

    attr_gen = NeuroH5CellAttrGen(file_path,
                                  target,
                                  comm=comm,
                                  io_size=io_size,
                                  cache_size=cache_size,
                                  namespace=namespace)
    index_map = get_cell_attributes_index_map(comm, file_path, target,
                                              namespace)

    maxiter = 10
    matched = 0
    processed = 0
    for itercount, (target_gid, attr_dict) in enumerate(attr_gen):
        print(
            'Rank: %i receieved target_gid: %s from the attribute generator.' %
            (rank, str(target_gid)))
        # Called unconditionally, exactly as before, so that ranks stay in
        # lockstep. NOTE(review): presumably this call is collective over
        # comm -- confirm against the select_cell_attributes implementation.
        attr_dict2 = select_cell_attributes(
            target_gid,
            comm,
            file_path,
            index_map,
            target,
            namespace,
            population_offset=target_gid_offset)
        # Other consumers of NeuroH5CellAttrGen in this file treat a None gid
        # as "no cell for this rank in this round"; there is nothing to
        # compare (or count) in that case.
        if target_gid is not None:
            if np.all(attr_dict[attribute][:] == attr_dict2[attribute][:]):
                print('Rank: %i; cell attributes match!' % rank)
                matched += 1
            else:
                print('Rank: %i; cell attributes do not match.' % rank)
            processed += 1
        comm.barrier()
        # Stops once itercount exceeds maxiter, i.e. after maxiter + 2 rounds.
        if itercount > maxiter:
            break
    matched = comm.gather(matched, root=0)
    processed = comm.gather(processed, root=0)
    if comm.rank == 0:
        print('%i / %i processed gids had matching cell attributes returned by both read methods' % \
              (np.sum(matched), np.sum(processed)))
Beispiel #11
0
def init_context():
    """
    Load spike trains, the spatial trajectory and imposed rates into `context`.

    Reads the NWB file at context.nwb_spikes_file_path, groups the spike
    trains by cell type and gid, reads the 'rate' attribute of the 'MPP' and
    'LPP' populations from context.neuroh5_rates_file_path, optionally plots
    the trajectory and a sample of the rates, and finally copies every local
    variable into `context` via context.update(locals()).

    Because of that last step the names of the locals below are visible to
    whatever reads `context` afterwards; do not rename them.
    """
    if 'plot' not in context():
        context.plot = False
    pop_names = ['MPP', 'LPP']
    nwb_spikes_file = NWBHDF5IO(context.nwb_spikes_file_path, 'r')
    nwb_spikes = nwb_spikes_file.read()

    t, pos = get_position(nwb_spikes)
    d = compute_distance_travelled(pos)
    if context.plot:
        fig, axes = plt.subplots()
        axes.plot(t, d)
        axes.set_xlabel('Time (s)')
        axes.set_ylabel('Distance (m)')
        axes.set_title('Spatial trajectory')
        clean_axes(axes)
        fig.show()

    start_time = time.time()
    spike_trains = defaultdict(dict)
    nwb_gids = nwb_spikes.units.id.data[:]
    nwb_cell_types = nwb_spikes.units['cell_type'][:]
    nwb_spike_trains = nwb_spikes.units['spike_times'][:]
    for i, gid in enumerate(nwb_gids):
        pop_name = nwb_cell_types[i]
        spike_trains[pop_name][gid] = nwb_spike_trains[i]
    # Drop the references into the NWB file before closing it.
    del nwb_spikes, nwb_gids, nwb_cell_types, nwb_spike_trains
    nwb_spikes_file.close()

    count = sum(len(spike_trains[pop_name]) for pop_name in spike_trains)
    if context.verbose > 1:
        print('optimize_baks: pid: %i; loading spikes for %i gids from cell populations: %s took %.1f s' %
              (os.getpid(), count, ', '.join(str(pop_name) for pop_name in spike_trains), time.time() - start_time))
        sys.stdout.flush()

    if 'block_size' not in context():
        context.block_size = context.num_workers
    gid_block_size = int(math.ceil(float(count) / context.block_size))

    imposed_rates = defaultdict(dict)
    start_time = time.time()
    for pop_name in pop_names:
        count = 0
        cell_attr_gen = NeuroH5CellAttrGen(
            context.neuroh5_rates_file_path,
            pop_name,
            comm=context.comm,
            namespace=context.neuroh5_rates_namespace)
        for gid, attr_dict in cell_attr_gen:
            # The generator can hand out a None gid; such entries are skipped.
            if gid is not None:
                imposed_rates[pop_name][gid] = attr_dict['rate']
                count += 1
        if context.verbose > 1:
            # start_time is not reset per population, so the reported time is
            # cumulative since the first population started loading.
            print('optimize_baks: pid: %i; loading imposed rates for %i gids from cell population: %s took %.1f s' %
                  (os.getpid(), count, pop_name, time.time() - start_time))
            sys.stdout.flush()

    if context.plot:
        t_bins = np.linspace(0., max(t), 100)
        d_bins = np.linspace(0., max(d), 100)
        fig, axes = plt.subplots(2, 2)
        for i, pop_name in enumerate(['MPP', 'LPP']):
            axes[0, i].set_title('Imposed rates: %s' % pop_name)
            axes[0, i].set_xlabel('Distance (m)')
            axes[0, i].set_ylabel('Firing rate (Hz)')
            axes[1, i].set_title('Binned spike counts: %s' % pop_name)
            axes[1, i].set_xlabel('Distance (m)')
            axes[1, i].set_ylabel('Count')
            count = 0
            for gid in imposed_rates[pop_name]:
                rate = imposed_rates[pop_name][gid]
                # Only plot cells whose imposed rate exceeds 5 somewhere.
                if np.max(rate) > 5.:
                    hist, edges = np.histogram(spike_trains[pop_name][gid],
                                               bins=t_bins)
                    axes[0, i].plot(d, rate)
                    axes[1, i].plot(d_bins[1:], hist)
                    count += 1
                if count > 20:
                    break
        clean_axes(axes)
        fig.tight_layout()
        fig.show()

    plotted = {pop_name: False for pop_name in pop_names}
    min_field_len = int(context.min_field_width / max(d) * len(d))
    # Publish every local defined above (including the ones that look unused
    # here, such as gid_block_size, plotted and min_field_len) on the context.
    context.update(locals())
Beispiel #12
0
def _flush_connections(comm, connectivity_path, destination_population,
                       connection_dict, io_size, dry_run):
    """
    Append the buffered connections to connectivity_path and log a summary.

    append_graph is called on every rank (unless dry_run), with an empty
    projection dictionary when this rank has nothing buffered. Rank 0 logs
    the buffered keys per projection followed by one summary line. The caller
    is responsible for clearing connection_dict afterwards.
    """
    flush_start = time.time()
    if len(connection_dict) > 0:
        projection_dict = {destination_population: connection_dict}
    else:
        projection_dict = {}
    if not dry_run:
        append_graph(connectivity_path,
                     projection_dict,
                     io_size=io_size,
                     comm=comm)
    if comm.rank == 0 and connection_dict:
        for (prj, prj_dict) in viewitems(connection_dict):
            logger.info("%s: %s" % (prj, str(list(prj_dict.keys()))))
        logger.info('Appending connectivity for %i projections took %i s' %
                    (len(connection_dict), time.time() - flush_start))


def generate_uv_distance_connections(comm,
                                     population_dict,
                                     connection_config,
                                     connection_prob,
                                     forest_path,
                                     synapse_seed,
                                     connectivity_seed,
                                     cluster_seed,
                                     synapse_namespace,
                                     connectivity_namespace,
                                     connectivity_path,
                                     io_size,
                                     chunk_size,
                                     value_chunk_size,
                                     cache_size,
                                     write_size=1,
                                     dry_run=False):
    """
    Generates connectivity based on U, V distance-weighted probabilities.

    :param comm: mpi4py MPI communicator
    :param population_dict: passed through unchanged to generate_synaptic_connections
    :param connection_config: connection configuration object (instance of env.ConnectionConfig)
    :param connection_prob: ConnectionProb instance
    :param forest_path: location of file with neuronal trees and synapse information
    :param synapse_seed: random seed for synapse partitioning
    :param connectivity_seed: random seed for connectivity generation
    :param cluster_seed: random seed for determining connectivity clustering for repeated connections from the same source
    :param synapse_namespace: namespace of synapse properties
    :param connectivity_namespace: namespace of connectivity attributes (not used by this function)
    :param connectivity_path: file that the generated connectivity is appended to
    :param io_size: number of I/O ranks to use for parallel connectivity append; -1 means all ranks
    :param chunk_size: HDF5 chunk size for connectivity file (pointer and index datasets) (not used by this function)
    :param value_chunk_size: HDF5 chunk size for connectivity file (value datasets) (not used by this function)
    :param cache_size: how many cells to read ahead
    :param write_size: how many cells to write out at the same time; a value <= 0 defers all
        writes to the single flush performed after the last cell
    :param dry_run: when True, connections are computed and logged but never appended to connectivity_path
    """

    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    start_time = time.time()

    ranstream_syn = np.random.RandomState()
    ranstream_con = np.random.RandomState()

    destination_population = connection_prob.destination_population
    projection_config = connection_config[destination_population]
    source_populations = sorted(projection_config.keys())

    if rank == 0:
        for source_population in source_populations:
            logger.info('%s -> %s:' %
                        (source_population, destination_population))
            logger.info(str(projection_config[source_population]))

    projection_synapse_dict = {
        source_population: (projection_config[source_population].type,
                            projection_config[source_population].layers,
                            projection_config[source_population].sections,
                            projection_config[source_population].proportions,
                            projection_config[source_population].contacts)
        for source_population in source_populations
    }

    comm.barrier()

    total_count = 0
    gid_count = 0
    connection_dict = defaultdict(lambda: {})
    for destination_gid, synapse_dict in NeuroH5CellAttrGen(forest_path, \
                                                            destination_population, \
                                                            namespace=synapse_namespace, \
                                                            comm=comm, io_size=io_size, \
                                                            cache_size=cache_size):
        last_time = time.time()
        if destination_gid is None:
            logger.info('Rank %i destination gid is None' % rank)
        else:
            logger.info(
                'Rank %i received attributes for destination: %s, gid: %i' %
                (rank, destination_population, destination_gid))
            # Seeding per gid makes the result for a cell independent of
            # which rank happens to process it.
            ranstream_con.seed(destination_gid + connectivity_seed)
            ranstream_syn.seed(destination_gid + synapse_seed)

            projection_prob_dict = {}
            for source_population in source_populations:
                source_layers = projection_config[source_population].layers
                projection_prob_dict[source_population] = \
                    connection_prob.get_prob(destination_gid, source_population, source_layers)

                for layer, (probs, source_gids, distances_u, distances_v) in \
                        viewitems(projection_prob_dict[source_population]):
                    if len(distances_u) > 0:
                        if rank == 0:
                            max_u_distance = np.max(distances_u)
                            min_u_distance = np.min(distances_u)
                            logger.info(
                                'Rank %i has %d possible sources from population %s for destination: %s, layer %s, gid: %i; max U distance: %f min U distance: %f'
                                % (rank, len(source_gids),
                                   source_population, destination_population,
                                   str(layer), destination_gid, max_u_distance,
                                   min_u_distance))
                    else:
                        logger.warning(
                            'Rank %i has %d possible sources from population %s for destination: %s, layer %s, gid: %i'
                            % (rank, len(source_gids),
                               source_population, destination_population,
                               str(layer), destination_gid))
            count = generate_synaptic_connections(
                rank, destination_gid, ranstream_syn, ranstream_con,
                cluster_seed + destination_gid, destination_gid, synapse_dict,
                population_dict, projection_synapse_dict, projection_prob_dict,
                connection_dict)
            total_count += count

            logger.info(
                'Rank %i took %i s to compute %d edges for destination: %s, gid: %i'
                % (rank, time.time() - last_time, count,
                   destination_population, destination_gid))

        # Every rank runs this flush in the same iterations, including ranks
        # that received a None gid, so the append stays in lockstep. The
        # write_size > 0 guard avoids a ZeroDivisionError for write_size=0.
        if write_size > 0 and gid_count % write_size == 0:
            _flush_connections(comm, connectivity_path, destination_population,
                               connection_dict, io_size, dry_run)
            connection_dict.clear()
            gc.collect()

        gid_count += 1

    # Final flush for whatever is still buffered (possibly nothing).
    _flush_connections(comm, connectivity_path, destination_population,
                       connection_dict, io_size, dry_run)

    global_count = comm.gather(total_count, root=0)
    if rank == 0:
        logger.info(
            '%i ranks took %i s to generate %i edges' %
            (comm.size, time.time() - start_time, np.sum(global_count)))
Beispiel #13
0
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, spatial_resolution, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float
    :param selectivity_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param cache_size: int
    :param verbose: bool
    :param debug: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = f'{save_fig} {arena_id}'
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided selectivity_path: {selectivity_path}')
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    f'population: {population}')
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
                if len(valid_selectivity_namespaces[population]) == 0:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: no selectivity data in arena: {arena_id} found '
                        f'for specified population: {population} in provided selectivity_path: {selectivity_path}'
                    )

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided coords_path: {coords_path}')
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: only {unique_gid_count}/{pop_size} unique cell indexes found '
                        f'for specified population: {population} in provided coords_path: {coords_path}'
                    )
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        )
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        )
    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )

    if spatial_resolution is None:
        spatial_resolution = env.stimulus_config['Spatial Resolution']
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=spatial_resolution)
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)
    x0_dict = {}
    y0_dict = {}

    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                f'Reading {len(u_distances_by_gid)} cell positions for population {population} took '
                f'{time.time() - start_time:.2f} s')

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            start_time = time.time()
            if rank == 0:
                logger.info(
                    f'Reading from {this_selectivity_namespace} namespace for population {population}...'
                )
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            count_by_module = defaultdict(int)
            start_time = time.time()
            x0_list_by_module = defaultdict(list)
            y0_list_by_module = defaultdict(list)
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                               selectivity_type_names=selectivity_type_names,
                                               selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = input_cell_config.gather_attributes(
                    )
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = f'{population} {this_selectivity_type_name} cell {gid}'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title=f'{fig_title}\nModule: {this_module_id}',
                            **fig_options())
                    x0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['X Offset'])
                    y0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['Y Offset'])
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                    count_by_module[this_module_id] += 1
                if debug and iter_count >= 10:
                    break

            if rank == 0:
                logger.info(
                    f'Done reading from {this_selectivity_namespace} namespace for population {population}...'
                )

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            x0_list_by_module = dict(x0_list_by_module)
            y0_list_by_module = dict(y0_list_by_module)
            x0_list_by_module = comm.reduce(x0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            y0_list_by_module = comm.reduce(y0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            count_by_module = dict(count_by_module)
            count_by_module = comm.reduce(count_by_module,
                                          op=mpi_op_merge_count_dict,
                                          root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info(
                    f'Processing {gid_count} {population} {this_selectivity_type_name} cells '
                    f'took {time.time() - start_time:.2f} s')

                if debug:
                    context.update(locals())

                fig_title = f'{population} {this_selectivity_type_name} field offsets'
                if save_fig is not None:
                    fig_options.saveFig = f'{save_fig} {fig_title}'

                for key in merged_cell_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    num_cells = count_by_module[this_module_id]
                    x0 = np.concatenate(x0_list_by_module[this_module_id])
                    y0 = np.concatenate(y0_list_by_module[this_module_id])
                    fig_title = f'{population} {this_selectivity_type_name} Module {this_module_id} rate map'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        x0=x0,
                        y0=y0,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title=
                        (f'{population} {this_selectivity_type_name} rate map\n'
                         f'Module {this_module_id} ({num_cells} cells)'),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
# Stand-alone check: walk the 'Synapse Attributes' namespace of the 'MC'
# population and report how many cells were delivered across all ranks.
comm = MPI.COMM_WORLD
rank = comm.rank  # index of this process within COMM_WORLD

if rank == 0:
    print('%i ranks have been allocated' % comm.size)
sys.stdout.flush()

# Alternative inputs that were used with this script:
#   neurotrees_dir = os.environ['SCRATCH']+'/dentate/Full_Scale_Control/'
#   forest_file = 'DGC_forest_test_syns_20171019.h5'
#   forest_file = 'DGC_forest_syns_compressed_20180306.h5'
neurotrees_dir = "./tests/"
forest_file = "MC_BC_trees_20180817.h5"

g = NeuroH5CellAttrGen(
    neurotrees_dir + forest_file,
    'MC',
    namespace='Synapse Attributes',
    comm=comm,
    io_size=2,
)

global_count = 0
count = 0
for destination_gid, synapse_dict in g:
    if destination_gid is not None:
        # Report the running count first, then include this cell in it.
        print('Rank: %i, gid: %i, count: %i' % (rank, destination_gid, count))
        count += 1
    else:
        # The generator handed this rank no cell on this step.
        print('Rank %i destination gid is None' % rank)

# Collect the per-rank counts on rank 0 and print their sum.
global_count = comm.gather(count, root=0)
if rank == 0:
    print('Total: %i' % np.sum(global_count))
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, use_noise_gen, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig,
         save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for every cell of the requested
    populations and, unless dry_run is set, append them to output_path under
    one namespace per selectivity type. Optionally gathers per-population
    summaries to rank 0 and plots them.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str; when empty, all populations found in coords_path are used
    :param use_noise_gen: bool; when true, MPINoiseGenerator instances are built (one per module, or a
        single one keyed -1 for non-modular populations) and passed to generate_input_selectivity_features
    :param io_size: int; -1 selects comm.size
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int; intermediate writes happen every max(1, floor(write_size / comm.size)) iterations
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param debug_count: int; in debug mode, the iteration index at which results are written and the
        per-population loop stops
    :param plot: bool; forced to True when save_fig is given
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; when true, nothing is written to output_path
    :raises RuntimeError: on rank 0 when output_path is missing, a population is unknown or has no
        selectivity type, the cell indexes are incomplete, or the reference bounds cannot be read;
        on all ranks when arena_id is not configured
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    # Saving figures implies plotting.
    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    # Rank 0 creates the output file if it does not exist yet, seeding it with
    # the '/H5Types' group copied from the coordinates file.
    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            input_file = h5py.File(coords_path, 'r')
            output_file = h5py.File(output_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Rank 0 validates every population and reads its reference U bounds; the
    # resulting dict is broadcast to the other ranks below.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path))
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # Invert env.selectivity_types so that it maps value -> name.
    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    # Output namespace per type: '<Capitalized type name> Selectivity <arena_id>'.
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]
    # The arena mesh is computed once on rank 0 and broadcast.
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}
    # Number of loop iterations between intermediate writes.
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                f'Generating input selectivity features for population {population}...'
            )

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        modular = True
        if population in env.stimulus_config[
                'Non-modular Place Selectivity Populations']:
            modular = False

        # Optional noise generators: one per module id for modular
        # populations, otherwise a single generator stored under key -1.
        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                for module_id in range(env.stimulus_config['Number Modules']):
                    extent_x, extent_y = get_2D_arena_extents(arena)
                    margin = round(
                        selectivity_config.place_module_field_widths[module_id]
                        / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                        arena, margin=margin)
                    noise_gen = MPINoiseGenerator(
                        comm=comm,
                        bounds=(arena_x_bounds, arena_y_bounds),
                        tile_rank=comm.rank,
                        bin_size=0.5,
                        mask_fraction=0.99,
                        seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                margin = round(
                    np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                    arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(
                    comm=comm,
                    bounds=(arena_x_bounds, arena_y_bounds),
                    tile_rank=comm.rank,
                    bin_size=0.5,
                    mask_fraction=0.99,
                    seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        # Per-type buffer of features awaiting the next write.
        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                # This rank received no cell on this step.
                # NOTE(review): the calls below presumably mirror collective
                # operations performed inside
                # generate_input_selectivity_features on ranks that did
                # receive a cell -- TODO confirm.
                if noise_gen_dict is not None:
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set([]),
                                                        op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(
                                np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(
                        f'Rank {rank} generating selectivity features for gid {gid}...'
                    )
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                # Rescale the U distance relative to the reference bounds.
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     noise_gen_dict=noise_gen_dict,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1
            # Advance every generator's tile_rank by one, wrapping at comm.size.
            if noise_gen_dict is not None:
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (
                        noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()

            # Periodic progress report and intermediate write.
            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                        req.wait()
                    # Start a fresh buffer once the pending features are written.
                    del selectivity_attr_dict
                    selectivity_attr_dict = dict(
                        (key, dict()) for key in env.selectivity_types)
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        # Final progress report for this population.
        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)
        req.wait()

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        # Write whatever is still buffered for this population.
        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        # Merge the per-rank summaries on rank 0.
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum,
                                          root=0,
                                          op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(
                            f'{save_fig} {population} normalized distances histogram',
                            fig=fig,
                            **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Build the file name from the loop variable rather
                        # than the stale this_selectivity_type_name left over
                        # from earlier loops, so each type gets its own file.
                        fig_title = f'{population} {selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title=
                            f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}',
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())
def main(forest_path, connectivity_namespace, coords_path, coords_namespace, io_size, chunk_size, value_chunk_size,
         cache_size):
    """
    For each 'GC' target cell read from forest_path, draw one source gid per
    matching synapse and append the resulting source_gid / syn_id arrays to
    forest_path under connectivity_namespace.

    :param forest_path: file whose 'Synapse_Attributes' namespace is iterated and to which results are appended
    :param connectivity_namespace: namespace under which the connections are appended
    :param coords_path: file from which population ranges and cell coordinates are read
    :param coords_namespace: namespace of the cell coordinates in coords_path
    :param io_size: passed to NeuroH5CellAttrGen and append_cell_attributes; -1 selects comm.size
    :param chunk_size: passed to append_cell_attributes
    :param value_chunk_size: passed to append_cell_attributes
    :param cache_size: passed to NeuroH5CellAttrGen
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank  # index of this process within COMM_WORLD

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%i ranks have been allocated' % comm.size)
    sys.stdout.flush()

    start_time = time.time()

    # The I/O helpers used here take the communicator's address.
    comm_handle = MPI._addressof(comm)

    # Load the coordinates of every population found in coords_path.
    soma_coords = {}
    for population in list(read_population_ranges(comm_handle, coords_path).keys()):
        soma_coords[population] = bcast_cell_attributes(
            comm_handle, 0, coords_path, population, namespace=coords_namespace)

    # Annotate each cell with the array indexes of its U and V coordinates.
    for population_cells in soma_coords.values():
        for cell in viewvalues(population_cells):
            cell['u_index'] = get_array_index(u, cell['U Coordinate'][0])
            cell['v_index'] = get_array_index(v, cell['V Coordinate'][0])

    target = 'GC'

    # Union, over all sources of the target, of the configured layers,
    # swc types and synapse types.
    layer_set = set()
    swc_type_set = set()
    syn_type_set = set()
    for source in layers[target]:
        layer_set.update(layers[target][source])
        swc_type_set.update(swc_types[target][source])
        syn_type_set.update(syn_types[target][source])

    count = 0
    synapse_attr_gen = NeuroH5CellAttrGen(comm_handle, forest_path, target, io_size=io_size,
                                          cache_size=cache_size, namespace='Synapse_Attributes')
    for target_gid, attributes_dict in synapse_attr_gen:
        last_time = time.time()
        connection_dict = {}
        p_dict = {}
        source_gid_dict = {}
        if target_gid is not None:
            print('Rank %i received attributes for target: %s, gid: %i' % (rank, target, target_gid))
            synapse_dict = attributes_dict['Synapse_Attributes']
            # Seed per target gid before any sampling for this cell.
            local_np_random.seed(target_gid + connectivity_seed_offset)
            gid_connections = {
                'source_gid': np.array([], dtype='uint32'),
                'syn_id': np.array([], dtype='uint32'),
            }
            connection_dict[target_gid] = gid_connections

            for layer in layer_set:
                for swc_type in swc_type_set:
                    for syn_type in syn_type_set:
                        sources, this_proportions = filter_sources(target, layer, swc_type, syn_type)
                        if not sources:
                            continue
                        if rank == 0 and count == 0:
                            source_list_str = '[' + ', '.join(['%s' % name for name in sources]) + ']'
                            print('Connections to target: %s in layer: %i '
                                  '(swc_type: %i, syn_type: %i): %s' %
                                  (target, layer, swc_type, syn_type, source_list_str))
                        p = np.array([])
                        source_gid = np.array([])
                        for source, this_proportion in zip(sources, this_proportions):
                            if source in source_gid_dict:
                                # Reuse what was already computed for this source and target cell.
                                this_source_gid = source_gid_dict[source]
                                this_p = p_dict[source]
                            else:
                                this_p, this_source_gid = p_connect.get_p(target, source, target_gid, soma_coords,
                                                                          distance_U, distance_V)
                                source_gid_dict[source] = this_source_gid
                                p_dict[source] = this_p
                            p = np.append(p, this_p * this_proportion)
                            source_gid = np.append(source_gid, this_source_gid)
                        syn_indexes = filter_synapses(synapse_dict, layer, swc_type, syn_type)
                        gid_connections['syn_id'] = np.append(
                            gid_connections['syn_id'],
                            synapse_dict['syn_id'][syn_indexes]).astype('uint32', copy=False)
                        # One draw per selected synapse, weighted by p.
                        this_source_gid = local_np_random.choice(source_gid, len(syn_indexes), p=p)
                        gid_connections['source_gid'] = np.append(
                            gid_connections['source_gid'],
                            this_source_gid).astype('uint32', copy=False)
            count += 1
            print('Rank %i took %i s to compute connectivity for target: %s, gid: %i' %
                  (rank, time.time() - last_time, target, target_gid))
            sys.stdout.flush()
        else:
            print('Rank %i target gid is None' % rank)

        # Every rank reaches the append, with an empty dict when it had no cell.
        last_time = time.time()
        append_cell_attributes(comm_handle, forest_path, target, connection_dict,
                               namespace=connectivity_namespace, io_size=io_size, chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        if rank == 0:
            print('Appending connectivity attributes for target: %s took %i s' % (target, time.time() - last_time))
        sys.stdout.flush()
        del connection_dict, p_dict, source_gid_dict
        gc.collect()

    global_count = comm.gather(count, root=0)
    if rank == 0:
        elapsed = time.time() - start_time
        print('%i ranks took took %i s to compute connectivity for %i cells' %
              (comm.size, elapsed, np.sum(global_count)))
Beispiel #17
0
def main(weights_path, weights_namespace, structured_weights_namespace, io_size, chunk_size, value_chunk_size,
         cache_size, debug):
    """
    Flag, for every 'GC' cell, whether its structured weights differ from its reference weights.

    Cell attributes are read pairwise from weights_namespace and structured_weights_namespace. For each gid the
    'weight' arrays of both namespaces are ordered by 'syn_id' and compared; a one-element uint32 'structured'
    attribute (1 if any weight differs, otherwise 0) is appended to structured_weights_namespace. In debug mode
    at most 10 pairs are checked and nothing is written.

    :param weights_path: path handed to NeuroH5CellAttrGen and append_cell_attributes
    :param weights_namespace: namespace providing the reference 'syn_id' and 'weight' attributes
    :param structured_weights_namespace: namespace providing the structured 'syn_id' and 'weight' attributes; also
        the namespace the 'structured' flag is appended to
    :param io_size: forwarded to the neuroh5 calls; -1 is replaced by the communicator size
    :param chunk_size: forwarded to append_cell_attributes
    :param value_chunk_size: forwarded to append_cell_attributes
    :param cache_size: forwarded to NeuroH5CellAttrGen
    :param debug: if true, inspect at most 10 cells and synchronize with a barrier instead of writing
    :raises Exception: if the two generators yield different gids, or the sorted syn_ids of a gid differ
    """
    from itertools import islice

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        print('%i ranks have been allocated' % comm.size)
    sys.stdout.flush()

    population = 'GC'
    count = 0
    structured_count = 0
    start_time = time.time()
    weights_gen = NeuroH5CellAttrGen(MPI._addressof(comm), weights_path, population, io_size=io_size,
                                     cache_size=cache_size, namespace=weights_namespace)
    structured_weights_gen = NeuroH5CellAttrGen(MPI._addressof(comm), weights_path, population, io_size=io_size,
                                                cache_size=cache_size,
                                                namespace=structured_weights_namespace)
    paired_gen = zip(weights_gen, structured_weights_gen)
    if debug:
        # islice stops cleanly when fewer than 10 pairs exist; a bare next() inside a generator
        # expression would surface as RuntimeError once either generator is exhausted.
        attr_gen = islice(paired_gen, 10)
    else:
        # Every pair is read before any attribute is appended in the loop below.
        attr_gen = list(paired_gen)
    for (gid, weights_dict), (structured_weights_gid, structured_weights_dict) in attr_gen:
        local_time = time.time()
        modified_dict = {}
        sorted_indexes = None
        sorted_weights = None
        sorted_structured_indexes = None
        sorted_structured_weights = None
        if gid is not None:
            if gid != structured_weights_gid:
                # The message must be formatted before it is handed to Exception; %s keeps the
                # formatting valid even when the structured gid is None.
                raise Exception('gid %s from weights_gen does not match gid %s from structured_weights_gen' %
                                (gid, structured_weights_gid))
            sorted_indexes = weights_dict[weights_namespace]['syn_id'].argsort()
            sorted_weights = weights_dict[weights_namespace]['weight'][sorted_indexes]
            sorted_structured_indexes = structured_weights_dict[structured_weights_namespace]['syn_id'].argsort()
            sorted_structured_weights = \
                structured_weights_dict[structured_weights_namespace]['weight'][sorted_structured_indexes]
            if not np.all(weights_dict[weights_namespace]['syn_id'][sorted_indexes] ==
                          structured_weights_dict[structured_weights_namespace]['syn_id'][sorted_structured_indexes]):
                raise Exception('gid %i: sorted syn_ids from weights_namespace do not match '
                                'structured_weights_namespace' % gid)
            # A cell counts as structured as soon as a single weight differs between the namespaces.
            modify_weights = not np.all(sorted_weights == sorted_structured_weights)
            modified_dict[gid] = {'structured': np.array([int(modify_weights)], dtype='uint32')}
            print('Rank %i: %s gid %i took %.2f s to check for structured weights: %s' % \
                  (rank, population, gid, time.time() - local_time, str(modify_weights)))
            if modify_weights:
                structured_count += 1
            count += 1
        if not debug:
            # Called on every iteration, including those where gid is None and modified_dict is empty.
            append_cell_attributes(MPI._addressof(comm), weights_path, population, modified_dict,
                                   namespace=structured_weights_namespace, io_size=io_size, chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)
        else:
            comm.barrier()
        del sorted_indexes
        del sorted_weights
        del sorted_structured_indexes
        del sorted_structured_weights
        del modified_dict
        gc.collect()
        sys.stdout.flush()

    global_count = comm.gather(count, root=0)
    global_structured_count = comm.gather(structured_count, root=0)
    if rank == 0:
        print('%i ranks processed %i %s cells (%i assigned structured weights) in %.2f s' % \
              (comm.size, np.sum(global_count), population, np.sum(global_structured_count),
               time.time() - start_time))
def main(config, coords_path, coords_namespace, distance_namespace, layers,
         npoints, spatial_resolution, io_size, verbose):
    """
    Compute U/V arc distances for every cell and append them to coords_path.

    One surface is built per selected layer with make_surface. For each
    population and layer, the cell coordinates read from coords_namespace
    are turned into distances via the surface's point_distance method and
    appended as 'U Distance' / 'V Distance' attributes, keyed by
    ``cell_gid - population_start``, to the namespace
    '<distance_namespace> Layer <layer_name>'.

    :param config: configuration file passed to Env
    :param coords_path: file the coordinates are read from and the
        distances are appended to
    :param coords_namespace: namespace holding 'U Coordinate' and
        'V Coordinate'
    :param distance_namespace: prefix of the output namespace
    :param layers: layer names to process
    :param npoints: number of samples between the surface origin and the
        cell along each axis
    :param spatial_resolution: forwarded to make_surface
    :param io_size: forwarded to the neuroh5 calls
    :param verbose: if true, print the result for every cell
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config)

    # NOTE(review): max_extents is read from the 'Minimum Extent' key and
    # min_extents from 'Maximum Extent'; the names look swapped, which
    # determines the sign of `mid` below. Behavior kept as-is -- confirm
    # which is intended.
    max_extents = env.geometry['Parametric Surface']['Minimum Extent']
    min_extents = env.geometry['Parametric Surface']['Maximum Extent']

    # zip/items work on both Python 2 and 3 (izip/iteritems are Python 2 only).
    layer_mids = []
    for ((layer_name, max_extent),
         (_, min_extent)) in zip(max_extents.items(), min_extents.items()):
        if layer_name in layers:
            mid = (max_extent[2] - min_extent[2]) / 2.
            layer_mids.append(mid)

    population_ranges = read_population_ranges(comm, coords_path)[0]

    ip_surfaces = [
        make_surface(l=layer, spatial_resolution=spatial_resolution)
        for layer in layer_mids
    ]

    for population in population_ranges:
        (population_start, _) = population_ranges[population]

        # ip_surfaces has one entry per layer_mids entry, so pairing layers
        # with ip_surfaces visits the same items as zipping all three lists.
        for layer_name, ip_surface in zip(layers, ip_surfaces):

            origin_u = np.min(ip_surface.su[0])
            origin_v = np.min(ip_surface.sv[0])

            for cell_gid, cell_coords_dict in NeuroH5CellAttrGen(
                    comm,
                    coords_path,
                    population,
                    io_size=io_size,
                    namespace=coords_namespace):
                arc_distance_dict = {}
                if cell_gid is None:
                    print('Rank %i cell gid is None' % rank)
                else:
                    cell_u = cell_coords_dict['U Coordinate']
                    cell_v = cell_coords_dict['V Coordinate']

                    # Sample points from the surface origin to the cell along
                    # each axis while the other coordinate is held fixed.
                    U = np.linspace(origin_u, cell_u, npoints)
                    V = np.linspace(origin_v, cell_v, npoints)

                    arc_distance_u = ip_surface.point_distance(
                        U, cell_v, normalize_uv=True)
                    arc_distance_v = ip_surface.point_distance(
                        cell_u, V, normalize_uv=True)

                    arc_distance_dict[cell_gid - population_start] = {
                        'U Distance': np.asarray([arc_distance_u],
                                                 dtype='float32'),
                        'V Distance': np.asarray([arc_distance_v],
                                                 dtype='float32')
                    }

                    if verbose:
                        print('Rank %i: gid = %i u = %f v = %f dist u = %f dist v = %f' % (
                            rank, cell_gid, cell_u, cell_v, arc_distance_u,
                            arc_distance_v))

                # Called for every item, with an empty dict when gid is None.
                append_cell_attributes(comm,
                                       coords_path,
                                       population,
                                       arc_distance_dict,
                                       namespace='%s Layer %s' %
                                       (distance_namespace, layer_name),
                                       io_size=io_size)
def main(config, config_prefix, selectivity_path, selectivity_namespace,
         arena_id, populations, n_trials, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, output_path,
         spikes_namespace, spike_train_attr_name, gather, debug, plot,
         show_fig, save_fig, save_fig_dir, font_size, fig_format, verbose,
         dry_run):
    """
    Generate input spike trains for each trajectory of an arena.

    For every trajectory of the arena identified by arena_id, the
    trajectory is generated on rank 0 and broadcast, its datasets are
    stored in output_path, and spike trains are generated for every cell
    of every requested population from the selectivity attributes found in
    selectivity_path. Results are appended to output_path in batches
    unless dry_run is set. With gather, spike histograms are collected on
    rank 0 and optionally plotted.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param selectivity_path: str (path to file)
    :param selectivity_namespace: str (not read by this function; the
        namespaces are discovered from selectivity_path)
    :param arena_id: str
    :param populations: str
    :param n_trials: int
    :param io_size: int (-1 is replaced by the communicator size)
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param output_path: str (path to file)
    :param spikes_namespace: str
    :param spike_train_attr_name: str
    :param gather: bool
    :param debug: bool
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param verbose: bool
    :param dry_run: bool
    :raises RuntimeError: if the arena, a population or its selectivity
        data is missing, if output_path is missing, or if an existing
        trajectory namespace has a different length than the generated one
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    # Asking for a saved figure implies plotting.
    if save_fig is not None:
        plot = True

    if plot:
        from dentate.plot import default_fig_options

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))
    arena = env.stimulus_config['Arena'][arena_id]

    # Rank 0 validates the inputs and discovers the selectivity namespaces
    # of each population; the result is broadcast to all ranks below.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_spike_trains: specified population: %s not found in '
                    'provided selectivity_path: %s' %
                    (population, selectivity_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_spike_trains: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if 'Selectivity %s' % arena_id in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
                if len(valid_selectivity_namespaces[population]) == 0:
                    raise RuntimeError(
                        'generate_input_spike_trains: no selectivity data in arena: %s found '
                        'for specified population: %s in provided selectivity_path: %s'
                        % (arena_id, population, selectivity_path))
    comm.barrier()

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    # Inverse mapping of env.selectivity_types (value -> key).
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    equilibrate = get_equilibration(env)

    for trajectory_id in sorted(arena.trajectories.keys()):
        trajectory = arena.trajectories[trajectory_id]
        t, x, y, d = None, None, None, None
        if rank == 0:
            t, x, y, d = generate_linear_trajectory(
                trajectory,
                temporal_resolution=env.stimulus_config['Temporal Resolution'],
                equilibration_duration=env.
                stimulus_config['Equilibration Duration'])

        t = comm.bcast(t, root=0)
        x = comm.bcast(x, root=0)
        y = comm.bcast(y, root=0)
        d = comm.bcast(d, root=0)

        trajectory = t, x, y, d
        trajectory_namespace = 'Trajectory %s %s' % (arena_id, trajectory_id)
        output_namespace = '%s %s %s' % (spikes_namespace, arena_id,
                                         trajectory_id)

        if not dry_run and rank == 0:
            if output_path is None:
                raise RuntimeError(
                    'generate_input_spike_trains: missing output_path')
            if not os.path.isfile(output_path):
                # Both handles are context-managed so the input file is
                # closed even if the copy fails.
                with h5py.File(output_path, 'w') as output_file, \
                        h5py.File(selectivity_path, 'r') as input_file:
                    input_file.copy('/H5Types', output_file)
            with h5py.File(output_path, 'a') as f:
                if trajectory_namespace not in f:
                    # Create the group only when it is missing; creating
                    # it unconditionally fails on an existing group.
                    logger.info('Appending %s datasets to file at path: %s' %
                                (trajectory_namespace, output_path))
                    group = f.create_group(trajectory_namespace)
                    for key, value in zip(['t', 'x', 'y', 'd'], [t, x, y, d]):
                        group.create_dataset(key, data=value, dtype='float32')
                else:
                    # The namespace already exists: make sure it matches
                    # the trajectory generated from the configuration.
                    loaded_t = f[trajectory_namespace]['t'][:]
                    if len(t) != len(loaded_t):
                        raise RuntimeError(
                            'generate_input_spike_trains: file at path: %s already contains the '
                            'namespace: %s, but the dataset sizes are inconsistent with the provided input '
                            'configuration' %
                            (output_path, trajectory_namespace))
        comm.barrier()

        if rank == 0:
            context.update(locals())

        spike_hist_sum_dict = {}
        spike_hist_resolution = 1000

        # Number of generator items between two batched writes.
        write_every = max(1, int(math.floor(write_size / comm.size)))
        for population in populations:

            this_spike_hist_sum = defaultdict(
                lambda: np.zeros(spike_hist_resolution))

            for this_selectivity_namespace in sorted(
                    valid_selectivity_namespaces[population]):

                if rank == 0:
                    logger.info(
                        'Generating input source spike trains for population %s [%s]...'
                        % (population, this_selectivity_namespace))

                start_time = time.time()
                selectivity_attr_gen = NeuroH5CellAttrGen(
                    selectivity_path,
                    population,
                    namespace=this_selectivity_namespace,
                    comm=comm,
                    io_size=io_size,
                    cache_size=cache_size)
                spikes_attr_dict = dict()
                gid_count = 0
                for iter_count, (gid, selectivity_attr_dict
                                 ) in enumerate(selectivity_attr_gen):
                    if gid is not None:
                        context.update(locals())
                        spikes_attr_dict[gid] = \
                            generate_input_spike_trains(env, selectivity_type_names, trajectory,
                                                        gid, selectivity_attr_dict, n_trials=n_trials,
                                                        spike_train_attr_name=spike_train_attr_name,
                                                        spike_hist_resolution=spike_hist_resolution,
                                                        equilibrate=equilibrate,
                                                        spike_hist_sum=this_spike_hist_sum,
                                                        debug= (debug_callback, context) if debug else False)
                        gid_count += 1

                    if (iter_count > 0 and iter_count % write_every
                            == 0) or (debug and iter_count == 10):
                        total_gid_count = comm.reduce(gid_count,
                                                      root=0,
                                                      op=MPI.SUM)
                        if rank == 0:
                            logger.info(
                                'generated spike trains for %i %s cells' %
                                (total_gid_count, population))

                        if not dry_run:
                            append_cell_attributes(
                                output_path,
                                population,
                                spikes_attr_dict,
                                namespace=output_namespace,
                                comm=comm,
                                io_size=io_size,
                                chunk_size=chunk_size,
                                value_chunk_size=value_chunk_size)
                        del spikes_attr_dict
                        spikes_attr_dict = dict()

                        if debug and iter_count == 10:
                            break

                # Write what is left since the last batch before the next
                # namespace resets spikes_attr_dict; otherwise only the last
                # namespace's remainder would ever be written.
                if not dry_run:
                    append_cell_attributes(output_path,
                                           population,
                                           spikes_attr_dict,
                                           namespace=output_namespace,
                                           comm=comm,
                                           io_size=io_size,
                                           chunk_size=chunk_size,
                                           value_chunk_size=value_chunk_size)
                    del spikes_attr_dict
                    spikes_attr_dict = dict()

            process_time = time.time() - start_time

            total_gid_count = comm.reduce(gid_count, root=0, op=MPI.SUM)
            if rank == 0:
                logger.info(
                    'generated spike trains for %i %s cells in %.2f s' %
                    (total_gid_count, population, process_time))

            if gather:
                spike_hist_sum_dict[population] = this_spike_hist_sum

        if gather:
            # Convert the defaultdicts (built with lambdas) to plain dicts
            # before handing them to comm.gather.
            this_spike_hist_sum = {
                key: dict(val.items())
                for key, val in viewitems(spike_hist_sum_dict)
            }
            spike_hist_sum = comm.gather(this_spike_hist_sum, root=0)

            if rank == 0:
                merged_spike_hist_sum = defaultdict(lambda: defaultdict(
                    lambda: np.zeros(spike_hist_resolution)))
                for each_spike_hist_sum in spike_hist_sum:
                    for population in each_spike_hist_sum:
                        for selectivity_type_name in each_spike_hist_sum[
                                population]:
                            merged_spike_hist_sum[population][selectivity_type_name] = \
                                np.add(merged_spike_hist_sum[population][selectivity_type_name],
                                       each_spike_hist_sum[population][selectivity_type_name])

                if plot:

                    if save_fig is not None:
                        fig_options.saveFig = save_fig

                    # Plot whenever plotting is requested, not only when a
                    # file name for saving was given.
                    # NOTE(review): selectivity_type_name is whatever the
                    # merge loop above left behind -- confirm this is the
                    # argument plot_summed_spike_psth expects.
                    plot_summed_spike_psth(t, trajectory_id,
                                           selectivity_type_name,
                                           merged_spike_hist_sum,
                                           spike_hist_resolution,
                                           fig_options)

        comm.barrier()

    if is_interactive and rank == 0:
        context.update(locals())
Beispiel #20
0
from mpi4py import MPI
from neuroh5.io import read_population_ranges, NeuroH5CellAttrGen

# import mkl
import sys, os, gc
import numpy as np
# Communicator and rank used for the read below.
comm = MPI.COMM_WORLD
rank = comm.rank

# Input file, attribute namespace and reader settings handed to the generator.
input_path = "/scratch1/03320/iraikov/striped/dentate/Full_Scale_Control/DGC_forest_syns_20201217_compressed.h5"
synapse_namespace = 'Synapse Attributes'
io_size = 20
cache_size = 1

# Only rank 0 reports how many ranks take part.
if rank == 0:
    print("%d ranks allocated" % comm.size)
    sys.stdout.flush()

it = NeuroH5CellAttrGen(
    input_path,
    'GC',
    namespace=synapse_namespace,
    comm=comm,
    io_size=io_size,
    cache_size=cache_size,
)

# Walk the 'GC' synapse attributes and report every gid this rank receives;
# items whose gid is None are skipped.
for gid, synapse_dict in it:
    if gid is None:
        continue
    print('rank %i: gid = %i' % (rank, gid))
    sys.stdout.flush()