def plot_summed_spike_psth(t, trajectory_id, selectivity_type_name,
                           merged_spike_hist_sum, spike_hist_resolution,
                           fig_options):
    """
    Plot one summed spike PSTH figure per (population, selectivity type).

    :param t: array-like of time points (ms); only its min and max are used
        to construct the histogram bin edges.
    :param trajectory_id: str; included in the saved figure file name.
    :param selectivity_type_name: str or None; when not None, only the
        histograms for this selectivity type are plotted. When None, every
        selectivity type present in ``merged_spike_hist_sum`` is plotted.
    :param merged_spike_hist_sum: dict mapping population -> dict mapping
        selectivity type name -> summed spike histogram with
        ``spike_hist_resolution`` values (one per bin).
    :param spike_hist_resolution: int; number of histogram bins.
    :param fig_options: figure options object; its ``fontSize``, ``saveFig``
        and ``showFig`` attributes are read, and calling it yields keyword
        arguments for ``save_figure``.
    """
    import matplotlib.pyplot as plt
    from dentate.plot import save_figure, clean_axes, close_figure

    # N bins need N + 1 edges; each histogram value is plotted at the right
    # edge of its bin.
    spike_hist_edges = np.linspace(min(t), max(t), spike_hist_resolution + 1)
    for population, selectivity_hist_dict in viewitems(merged_spike_hist_sum):
        for this_selectivity_type_name, spike_hist_sum in viewitems(
                selectivity_hist_dict):
            # Optionally restrict plotting to the requested selectivity type.
            if (selectivity_type_name is not None and
                    this_selectivity_type_name != selectivity_type_name):
                continue

            fig, axes = plt.subplots()
            axes.plot(spike_hist_edges[1:], spike_hist_sum)
            axes.set_xlabel('Time (ms)', fontsize=fig_options.fontSize)
            axes.set_ylabel('Population spike count',
                            fontsize=fig_options.fontSize)
            max_count = np.max(spike_hist_sum)
            # Avoid a degenerate (0, 0) y-range for an all-zero histogram.
            if max_count > 0:
                axes.set_ylim(0., max_count * 1.1)
            axes.set_title('Summed spike PSTH\n%s %s cells' %
                           (population, this_selectivity_type_name),
                           fontsize=fig_options.fontSize)
            clean_axes(axes)

            if fig_options.saveFig is not None:
                save_title = 'Summed spike PSTH %s %s %s' % (
                    trajectory_id, population, this_selectivity_type_name)
                save_fig = '%s %s' % (fig_options.saveFig, save_title)
                save_figure(save_fig, fig=fig, **fig_options())

            if fig_options.showFig:
                fig.show()
            else:
                # Release figures that are not displayed so they do not
                # accumulate across populations / selectivity types.
                close_figure(fig)
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, use_noise_gen, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig,
         save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for the requested populations and
    append them as cell attributes to ``output_path``. Optionally gather
    summary data (normalized distances, summed rate maps, field offsets) to
    rank 0 for plotting or interactive inspection.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str; empty means all populations found in
        coords_path
    :param use_noise_gen: bool; create MPINoiseGenerator instances (one per
        place module for modular populations, a single one otherwise) and
        pass them to generate_input_selectivity_features
    :param io_size: int; -1 means use all ranks
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param debug_count: int; in debug mode, write and stop once the
        per-population cell loop reaches this iteration count
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name); when given, plotting is enabled
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; when True, nothing is written to output_path
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Initialize the output file with the H5Types group from the
            # coordinates file; context managers close both files even if
            # the copy fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Rank 0 validates each population and reads its reference U-distance
    # bounds; the result is then broadcast to all ranks.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path))
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # Invert env.selectivity_types (name -> id) into id -> name, and build the
    # output namespace for each type: capitalized name + ' Selectivity <arena>'.
    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}
    # Number of loop iterations between periodic writes of accumulated
    # attributes (at least 1).
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                f'Generating input selectivity features for population {population}...'
            )

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        modular = True
        if population in env.stimulus_config[
                'Non-modular Place Selectivity Populations']:
            modular = False

        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                # One noise generator per place module, with the arena bounds
                # padded by half that module's field width.
                for module_id in range(env.stimulus_config['Number Modules']):
                    extent_x, extent_y = get_2D_arena_extents(arena)
                    margin = round(
                        selectivity_config.place_module_field_widths[module_id]
                        / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                        arena, margin=margin)
                    noise_gen = MPINoiseGenerator(
                        comm=comm,
                        bounds=(arena_x_bounds, arena_y_bounds),
                        tile_rank=comm.rank,
                        bin_size=0.5,
                        mask_fraction=0.99,
                        seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                # Non-modular populations share a single generator (key -1),
                # padded by half the mean module field width.
                margin = round(
                    np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                    arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(
                    comm=comm,
                    bounds=(arena_x_bounds, arena_y_bounds),
                    tile_rank=comm.rank,
                    bin_size=0.5,
                    mask_fraction=0.99,
                    seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                if noise_gen_dict is not None:
                    # NOTE(review): this rank has no cell this iteration but
                    # still takes part in the noise generators' collective
                    # calls; presumably generate_input_selectivity_features
                    # makes the matching calls on ranks that do have a gid —
                    # confirm.
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set([]),
                                                        op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(
                                np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(
                        f'Rank {rank} generating selectivity features for gid {gid}...'
                    )
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                # Normalize U distance to the population's reference bounds.
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     noise_gen_dict=noise_gen_dict,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1
            if noise_gen_dict is not None:
                # Advance each generator's tile_rank by one (mod comm.size)
                # every iteration.
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (
                        noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()

            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                        req.wait()
                    # Drop the attributes already written and start a fresh
                    # accumulator.
                    del selectivity_attr_dict
                    selectivity_attr_dict = dict(
                        (key, dict()) for key in env.selectivity_types)
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)
        req.wait()

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        # Write whatever remains since the last periodic write.
        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum,
                                          root=0,
                                          op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(
                            f'{save_fig} {population} normalized distances histogram',
                            fig=fig,
                            **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Use the current loop's selectivity type so each
                        # figure gets its own save name.
                        fig_title = f'{population} {selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title=
                            f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}',
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())
# Exemplo n.º 3 (Example No. 3)
# 0
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, io_size, chunk_size, value_chunk_size,
         cache_size, write_size, verbose, gather, interactive, debug, plot,
         show_fig, save_fig, save_fig_dir, font_size, fig_format, dry_run):
    """

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            input_file = h5py.File(coords_path, 'r')
            output_file = h5py.File(output_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path))
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                'Generating input selectivity features for population %s...' %
                population)

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            if gid is not None:
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1

            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == 10):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        if rank == 0:
                            logger.info(
                                'writing selectivity features for %s [%s]...' %
                                (population, selectivity_type_name))
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                del selectivity_attr_dict
                selectivity_attr_dict = dict(
                    (key, dict()) for key in env.selectivity_types)

            if debug and iter_count >= 10:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = this_rate_map_sum

        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                if rank == 0:
                    logger.info('writing selectivity features for %s [%s]...' %
                                (population, selectivity_type_name))
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
        del selectivity_attr_dict
        comm.barrier()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        rate_map_sum = dict([(key, dict(val.items()))
                             for key, val in viewitems(rate_map_sum)])
        rate_map_sum = comm.gather(rate_map_sum, root=0)
        if rank == 0:
            merged_rate_map_sum = defaultdict(
                lambda: defaultdict(lambda: np.zeros_like(arena_x_mesh)))
            for each_rate_map_sum in rate_map_sum:
                for population in each_rate_map_sum:
                    for selectivity_type_name in each_rate_map_sum[population]:
                        merged_rate_map_sum[population][selectivity_type_name] = \
                            np.add(merged_rate_map_sum[population][selectivity_type_name],
                                   each_rate_map_sum[population][selectivity_type_name])
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title('Population: %s' % population)
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure('%s %s normalized distances histogram' %
                                    (save_fig, population),
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        fig_title = '%s %s summed rate maps' % (
                            population, this_selectivity_type_name)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title='Summed rate maps\n%s %s cells' %
                            (population, selectivity_type_name),
                            **fig_options())

    if interactive and rank == 0:
        context.update(locals())