def debug_callback(context):
    """
    Plot the 2D rate map of the single cell currently described by `context`,
    then close the resulting figure.

    Reads the following attributes from `context`: population,
    this_selectivity_type_name, this_selectivity_type, gid, fig_options,
    save_fig, arena_x_mesh, arena_y_mesh, rate_map, env and
    norm_u_arc_distance.

    :param context: object exposing the attributes listed above
    """
    from dentate.plot import plot_2D_rate_map, close_figure

    cell_label = '%s %s cell %i' % (
        context.population, context.this_selectivity_type_name, context.gid)

    # Work on a shallow copy so that the per-cell save name set below is not
    # written onto context.fig_options itself.
    local_fig_options = copy.copy(context.fig_options)
    save_prefix = context.save_fig
    if save_prefix is not None:
        local_fig_options.saveFig = '%s %s' % (save_prefix, cell_label)

    x_mesh = context.arena_x_mesh
    y_mesh = context.arena_y_mesh
    cell_rate_map = context.rate_map
    # Peak rate is looked up per population and per selectivity type.
    peak_rate_by_type = context.env.stimulus_config['Peak Rate'][
        context.population]
    cell_peak_rate = peak_rate_by_type[context.this_selectivity_type]
    plot_title = '%s\nNormalized cell position: %.3f' % (
        cell_label, context.norm_u_arc_distance)

    # fig_options objects are callable; the call result is unpacked as
    # keyword arguments (presumably a mapping of plot settings).
    rate_map_fig = plot_2D_rate_map(x=x_mesh,
                                    y=y_mesh,
                                    rate_map=cell_rate_map,
                                    peak_rate=cell_peak_rate,
                                    title=plot_title,
                                    **local_fig_options())
    close_figure(rate_map_fig)
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, subset_seed, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """
    Plot how input selectivity features are distributed over the U/V arc
    distance axes, and plot the summed rate maps per module, for each
    requested population.

    For every population and every matching selectivity namespace, each MPI
    rank iterates over the cells yielded by NeuroH5CellAttrGen, computes the
    cell's rate map on the arena mesh and bins cell-level and component-level
    attributes by U/V distance. The per-rank histograms and per-module rate
    map sums are gathered on rank 0, which merges them and draws the figures.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; bin width used for both the U and V axes
    :param selectivity_path: str
    :param selectivity_namespace: str; a namespace in selectivity_path is used
        when '<selectivity_namespace> <arena_id>' is a substring of its name
    :param subset_seed: int; for reproducible choice of gids to plot individual rate maps
        (not referenced in the body of this function)
    :param arena_id: str
    :param populations: tuple of str; an empty tuple selects a built-in default list
    :param io_size: int; -1 means use comm.size
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plot per-cell rate maps on rank 0 and stop each
        namespace after roughly the first ten iterations
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_size: stored as fig_options.figSize
    :param colormap: if not None, stored as fig_options.colormap
    :param fig_format: str
    :raises RuntimeError: if a population, its selectivity data, the reference
        distance bounds or the arena cannot be found in the provided inputs
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size
    # Applied once up front so that every plot made with fig_options sees the
    # requested colormap (previously it was only set inside the first
    # histogram loop, after the debug rate maps had already been drawn).
    if colormap is not None:
        fig_options.colormap = colormap

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # Rank 0 inspects the selectivity file and broadcasts the result.
    # NOTE(review): if rank 0 raises inside one of the rank-0-only validation
    # sections, the other ranks are presumably left waiting in the following
    # bcast unless the MPI launcher aborts the whole job - confirm.
    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided selectivity_path: %s' %
                    (population, selectivity_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
                if len(valid_selectivity_namespaces[population]) == 0:
                    raise RuntimeError(
                        'plot_input_selectivity_features: no selectivity data in arena: %s found '
                        'for specified population: %s in provided selectivity_path: %s'
                        % (arena_id, population, selectivity_path))

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    # Invert env.selectivity_types (name -> id) into id -> name.
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    # Rank 0 validates the coordinates file and reads the reference U/V
    # bounds from the first population that is processed.
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            with h5py.File(coords_path, 'r') as coords_f:
                distances_group = \
                    coords_f['Populations'][population][distances_namespace]
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(distances_group['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'plot_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            distances_group.attrs['Reference U Min'], \
                            distances_group.attrs['Reference U Max']
                    except Exception as err:
                        # Chain the original error so its cause stays visible.
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population,
                             coords_path)) from err
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            distances_group.attrs['Reference V Min'], \
                            distances_group.attrs['Reference V Max']
                    except Exception as err:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            'containing reference bounds in namespace: %s for population: %s from '
                            'coords_path: %s' %
                            (distances_namespace, population,
                             coords_path)) from err
    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    # Histogram bin edges; the stop value is padded by half a bin so that the
    # upper reference bound is still covered by np.arange.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))

    # The arena mesh is built on rank 0 and shared with every rank.
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    # NOTE: this loop is deliberately kept inline rather than split into
    # helpers, because the context.update(locals()) calls below export the
    # local variables of this frame.
    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                'Reading %i cell positions for population %s took %.2f s' %
                (len(u_distances_by_gid), population,
                 time.time() - start_time))

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            if rank == 0:
                logger.info('Reading from %s namespace for population %s...' %
                            (this_selectivity_namespace, population))
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            # Reset per namespace: a rank that receives no cell for this
            # namespace must not reuse the label of a previous namespace.
            this_selectivity_type_name = None
            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                # The generator can yield gid=None; such iterations carry no
                # cell for this rank and are skipped.
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                               selectivity_type_names=selectivity_type_names,
                                               selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = \
                        input_cell_config.gather_attributes()
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        # One position entry per component, so that positions
                        # line up with the flattened component attributes.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = '%s %s cell %i' % (
                            population, this_selectivity_type_name, gid)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title='%s\nModule: %i' %
                            (fig_title, this_module_id),
                            **fig_options())
                        # Close the per-cell figure like every other plot made
                        # here, so debug runs do not accumulate open figures.
                        close_figure(fig)
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                if debug and iter_count >= 10:
                    break

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Attribute-weighted histograms; rank 0 later plots the merged
            # versions with the count histograms passed as `norm`.
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            # Convert to a plain dict before gathering: the defaultdict's
            # lambda factory cannot be pickled for transport.
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            # Collect the selectivity type name seen by each rank, so rank 0
            # has a label even if it processed no cell of this namespace.
            gathered_type_names = comm.gather(this_selectivity_type_name,
                                              root=0)

            if rank == 0:
                if this_selectivity_type_name is None:
                    # Use the first name reported by any rank; fall back to
                    # the namespace name when no rank saw a cell.
                    this_selectivity_type_name = next(
                        (name for name in gathered_type_names
                         if name is not None), this_selectivity_namespace)

                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info('Processing %i %s %s cells took %.2f s' %
                            (gid_count, population, this_selectivity_type_name,
                             time.time() - start_time))

                if debug:
                    context.update(locals())

                # One figure per cell-level attribute.
                for key in merged_cell_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                # One figure per component-level attribute.
                for key in merged_component_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, this_selectivity_type_name, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, this_selectivity_type_name, key)
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                # One summed rate map figure per module.
                for this_module_id in merged_rate_map_sum_by_module:
                    fig_title = '%s %s Module %i summed rate maps' % \
                                (population, this_selectivity_type_name, this_module_id)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title='%s %s summed rate maps\nModule %i' %
                        (population, this_selectivity_type_name,
                         this_module_id),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
# Example no. 3
# 0
def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, spatial_resolution, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float
    :param selectivity_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param cache_size: int
    :param verbose: bool
    :param debug: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size

    if save_fig is not None:
        save_fig = f'{save_fig} {arena_id}'
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    valid_selectivity_namespaces = dict()
    if rank == 0:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided selectivity_path: {selectivity_path}')
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    f'population: {population}')
            valid_selectivity_namespaces[population] = []
            with h5py.File(selectivity_path, 'r') as selectivity_f:
                for this_namespace in selectivity_f['Populations'][population]:
                    if f'{selectivity_namespace} {arena_id}' in this_namespace:
                        valid_selectivity_namespaces[population].append(
                            this_namespace)
                if len(valid_selectivity_namespaces[population]) == 0:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: no selectivity data in arena: {arena_id} found '
                        f'for specified population: {population} in provided selectivity_path: {selectivity_path}'
                    )

    valid_selectivity_namespaces = comm.bcast(valid_selectivity_namespaces,
                                              root=0)
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    if rank == 0:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    f'plot_input_selectivity_features: specified population: {population} not found in '
                    f'provided coords_path: {coords_path}')
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        f'plot_input_selectivity_features: only {unique_gid_count}/{pop_size} unique cell indexes found '
                        f'for specified population: {population} in provided coords_path: {coords_path}'
                    )
                if reference_u_arc_distance_bounds is None:
                    try:
                        reference_u_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        )
                if reference_v_arc_distance_bounds is None:
                    try:
                        reference_v_arc_distance_bounds = \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Min'], \
                            coords_f['Populations'][population][distances_namespace].attrs['Reference V Max']
                    except Exception:
                        raise RuntimeError(
                            'plot_input_selectivity_features: problem locating attributes '
                            f'containing reference bounds in namespace: {distances_namespace} '
                            f'for population: {population} from coords_path: {coords_path}'
                        )
    reference_u_arc_distance_bounds = comm.bcast(
        reference_u_arc_distance_bounds, root=0)
    reference_v_arc_distance_bounds = comm.bcast(
        reference_v_arc_distance_bounds, root=0)

    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )

    if spatial_resolution is None:
        spatial_resolution = env.stimulus_config['Spatial Resolution']
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=spatial_resolution)
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)
    x0_dict = {}
    y0_dict = {}

    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                f'Reading {len(u_distances_by_gid)} cell positions for population {population} took '
                f'{time.time() - start_time:.2f} s')

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            start_time = time.time()
            if rank == 0:
                logger.info(
                    f'Reading from {this_selectivity_namespace} namespace for population {population}...'
                )
            gid_count = 0
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            count_by_module = defaultdict(int)
            start_time = time.time()
            x0_list_by_module = defaultdict(list)
            y0_list_by_module = defaultdict(list)
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                               selectivity_type_names=selectivity_type_names,
                                               selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = input_cell_config.gather_attributes(
                    )
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = f'{population} {this_selectivity_type_name} cell {gid}'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title=f'{fig_title}\nModule: {this_module_id}',
                            **fig_options())
                    x0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['X Offset'])
                    y0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['Y Offset'])
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                    count_by_module[this_module_id] += 1
                if debug and iter_count >= 10:
                    break

            if rank == 0:
                logger.info(
                    f'Done reading from {this_selectivity_namespace} namespace for population {population}...'
                )

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            x0_list_by_module = dict(x0_list_by_module)
            y0_list_by_module = dict(y0_list_by_module)
            x0_list_by_module = comm.reduce(x0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            y0_list_by_module = comm.reduce(y0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            count_by_module = dict(count_by_module)
            count_by_module = comm.reduce(count_by_module,
                                          op=mpi_op_merge_count_dict,
                                          root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info(
                    f'Processing {gid_count} {population} {this_selectivity_type_name} cells '
                    f'took {time.time() - start_time:.2f} s')

                if debug:
                    context.update(locals())

                fig_title = f'{population} {this_selectivity_type_name} field offsets'
                if save_fig is not None:
                    fig_options.saveFig = f'{save_fig} {fig_title}'

                for key in merged_cell_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    if colormap is not None:
                        fig_options.colormap = colormap
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    num_cells = count_by_module[this_module_id]
                    x0 = np.concatenate(x0_list_by_module[this_module_id])
                    y0 = np.concatenate(y0_list_by_module[this_module_id])
                    fig_title = f'{population} {this_selectivity_type_name} Module {this_module_id} rate map'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        x0=x0,
                        y0=y0,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title=
                        (f'{population} {this_selectivity_type_name} rate map\n'
                         f'Module {this_module_id} ({num_cells} cells)'),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, use_noise_gen, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig,
         save_fig_dir, font_size, fig_format, dry_run):
    """
    Generate input selectivity features for the cells of each requested
    population and, unless dry_run is set, append them to output_path under a
    per-selectivity-type namespace. Optionally gathers summary data to rank 0
    and plots it.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str; if empty, all populations found in coords_path are used
    :param use_noise_gen: bool; whether to construct MPINoiseGenerator objects and pass them to
        generate_input_selectivity_features
    :param io_size: int; -1 means use comm.size
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int; features are flushed to file every max(1, write_size // comm.size) iterations
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param debug_count: int; in debug mode, iteration over cells stops once this iteration count is reached
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; if True, nothing is written to output_path
    :raises RuntimeError: on missing output_path, unknown population, missing selectivity type
        configuration, inconsistent cell indexes, missing reference bounds, or unknown arena_id
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    # Requesting a saved figure implies plotting; this also guarantees that
    # fig_options exists whenever save_fig is set.
    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed a new output file with the '/H5Types' group of the coords
            # file. Context managers close both handles even if copy() fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    # Validation and reading of reference bounds happen on rank 0 only; the
    # result is broadcast below.
    # NOTE(review): an exception raised here occurs on rank 0 only, while the
    # other ranks proceed to the barrier below -- confirm the launcher aborts
    # all ranks in that case.
    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception as e:
                    # Chain the original exception so the underlying cause
                    # remains visible in the traceback.
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path)) from e
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # Invert env.selectivity_types (name -> type value) into type value -> name.
    selectivity_type_names = {
        val: key for (key, val) in viewitems(env.selectivity_types)
    }
    # Output namespace per selectivity type name: the name with its first
    # character upper-cased, followed by ' Selectivity <arena_id>'.
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]
    # The spatial mesh is computed once on rank 0 and broadcast to all ranks.
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}
    # Number of iterations between intermediate writes of accumulated features.
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                f'Generating input selectivity features for population {population}...'
            )

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        modular = True
        if population in env.stimulus_config[
                'Non-modular Place Selectivity Populations']:
            modular = False

        # One noise generator per module for modular populations; a single
        # generator under key -1 otherwise.
        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                for module_id in range(env.stimulus_config['Number Modules']):
                    # NOTE(review): extent_x / extent_y are not used below --
                    # confirm whether this call is still needed.
                    extent_x, extent_y = get_2D_arena_extents(arena)
                    margin = round(
                        selectivity_config.place_module_field_widths[module_id]
                        / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                        arena, margin=margin)
                    noise_gen = MPINoiseGenerator(
                        comm=comm,
                        bounds=(arena_x_bounds, arena_y_bounds),
                        tile_rank=comm.rank,
                        bin_size=0.5,
                        mask_fraction=0.99,
                        seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                margin = round(
                    np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                    arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(
                    comm=comm,
                    bounds=(arena_x_bounds, arena_y_bounds),
                    tile_rank=comm.rank,
                    bin_size=0.5,
                    mask_fraction=0.99,
                    seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        # Per-selectivity-type buffer of gid -> attribute dict, flushed to
        # file periodically and at the end of the population.
        selectivity_attr_dict = {key: {} for key in env.selectivity_types}
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                # This rank received no cell on this iteration, but it still
                # takes part in the collective calls below.
                # NOTE(review): presumably these mirror the collective calls
                # made inside generate_input_selectivity_features on ranks
                # that do have a cell -- verify against that function.
                if noise_gen_dict is not None:
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set(),
                                                        op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(
                                np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(
                        f'Rank {rank} generating selectivity features for gid {gid}...'
                    )
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                # Normalize the U distance by the reference bounds read from
                # the coords file.
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     noise_gen_dict=noise_gen_dict,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1
            if noise_gen_dict is not None:
                # Advance each generator's tile_rank by one, wrapping at
                # comm.size, after every iteration.
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (
                        noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()

            # Periodic progress report and flush of accumulated features; in
            # debug mode also triggered on the last iteration before the break.
            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count)
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                        req.wait()
                    # Reset the buffer so already-written features are not
                    # appended again and their memory can be reclaimed.
                    del selectivity_attr_dict
                    selectivity_attr_dict = {
                        key: {} for key in env.selectivity_types
                    }
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        # Final progress report for this population.
        total_gid_count = 0
        gid_count_dict = dict(gid_count)
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)
        req.wait()

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        # Final flush of whatever remains in the buffer for this population.
        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        # Collect per-rank results on rank 0 for plotting / interactive use.
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum,
                                          root=0,
                                          op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(
                            f'{save_fig} {population} normalized distances histogram',
                            fig=fig,
                            **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Name the figure after the selectivity type being
                        # plotted so that each type gets its own saved file.
                        fig_title = f'{population} {selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title=
                            f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                        # Release the figure, as done at the other
                        # plot_2D_rate_map call sites in this file.
                        close_figure(fig)
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}',
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())