예제 #1
0
def rate_maps_from_features(env,
                            pop_name,
                            input_features_path,
                            input_features_namespace,
                            cell_index_set,
                            time_range=None,
                            n_trials=1):
    """Compute rate maps along the current trajectory from stored input selectivity features.

    Reads the selectivity features of the cells in ``cell_index_set`` (population
    ``pop_name``) from namespace ``"<input_features_namespace> <env.arena_id>"`` of
    ``input_features_path``. It then generates the linear trajectory
    ``env.trajectory_id`` of arena ``env.arena_id`` at the configured temporal
    resolution, and evaluates each cell's rate map at the trajectory's (x, y)
    positions.

    :param env: environment object providing ``stimulus_config``,
        ``selectivity_types``, ``arena_id``, ``trajectory_id``, ``comm`` and ``io_size``
    :param pop_name: str (population name)
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespace: str (namespace prefix; the arena id is appended)
    :param cell_index_set: iterable of int (gids to read)
    :param time_range: optional (start, end) pair. Only trajectory samples with
        ``start <= t < end`` are used. A ``None`` start means 0.0 and a ``None`` end
        means no upper bound. The argument itself is never modified.
    :param n_trials: int; not used by this function, kept for interface compatibility
    :return: dict mapping gid to the rate map computed by the cell's input config,
        containing only cells with at least one field
    """
    t_start = t_end = None
    if time_range is not None:
        # Work on local copies. Writing the default back into time_range would mutate
        # the caller's object, and fails outright when a tuple is passed.
        t_start, t_end = time_range[0], time_range[1]
        if t_start is None:
            t_start = 0.0

    temporal_resolution = float(env.stimulus_config['Temporal Resolution'])

    this_input_features_namespace = '%s %s' % (input_features_namespace,
                                               env.arena_id)

    input_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]

    # Inverse lookup: selectivity type index -> selectivity type name.
    selectivity_type_names = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }

    arena = env.stimulus_config['Arena'][env.arena_id]
    trajectory = arena.trajectories[env.trajectory_id]
    t, x, y, _ = stimulus.generate_linear_trajectory(
        trajectory, temporal_resolution=temporal_resolution)
    if time_range is not None:
        in_range = t >= t_start
        if t_end is not None:
            in_range = in_range & (t < t_end)
        t_range_inds = np.where(in_range)[0]
        x = x[t_range_inds]
        y = y[t_range_inds]

    input_rate_map_dict = {}
    # Read over env.comm with env.io_size; presumably every rank must take part in
    # this read (TODO confirm against scatter_read_cell_attribute_selection).
    input_features_iter = scatter_read_cell_attribute_selection(
        input_features_path,
        pop_name,
        selection=list(cell_index_set),
        namespace=this_input_features_namespace,
        mask=set(input_features_attr_names),
        comm=env.comm,
        io_size=env.io_size)
    for gid, selectivity_attr_dict in input_features_iter:
        this_selectivity_type = selectivity_attr_dict['Selectivity Type'][0]
        input_cell_config = stimulus.get_input_cell_config(
            selectivity_type=this_selectivity_type,
            selectivity_type_names=selectivity_type_names,
            selectivity_attr_dict=selectivity_attr_dict)
        # Cells without fields get no entry in the result.
        if input_cell_config.num_fields > 0:
            input_rate_map_dict[gid] = input_cell_config.get_rate_map(x=x, y=y)

    return input_rate_map_dict
def _init_output_weights_file(output_weights_path, initial_weights_path,
                              h5types_path):
    """Create the output weights file with an /H5Types group if it does not exist yet.

    The /H5Types group is copied from ``initial_weights_path`` when that file exists,
    otherwise from ``h5types_path``.

    :raises RuntimeError: if ``output_weights_path`` is None, or if the output file
        has to be created and neither source file exists.
    """
    if output_weights_path is None:
        raise RuntimeError('Missing required argument: output_weights_path.')
    if os.path.isfile(output_weights_path):
        return
    if initial_weights_path is not None and os.path.isfile(initial_weights_path):
        input_file_path = initial_weights_path
    elif h5types_path is not None and os.path.isfile(h5types_path):
        input_file_path = h5types_path
    else:
        raise RuntimeError(
            'Missing required source for h5types: either an initial_weights_path or an '
            'h5types_path must be provided.')
    with h5py.File(output_weights_path, 'a') as output_file:
        with h5py.File(input_file_path, 'r') as input_file:
            input_file.copy('/H5Types', output_file)


def _read_syn_weights_by_syn_id(weights_path, weights_namespace, destination,
                                target_gid, synapse_name):
    """Return ``{syn_id: weight}`` for one destination cell.

    The ``syn_id`` and ``synapse_name`` attributes are read from namespace
    ``weights_namespace`` of ``weights_path``.
    """
    weights_iter = read_cell_attribute_selection(weights_path, destination,
                                                 namespace=weights_namespace,
                                                 selection=[target_gid])
    syn_weight_attr_dict = dict(weights_iter)
    syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
    weights = syn_weight_attr_dict[target_gid][synapse_name]
    return {int(syn_id): float(weight) for syn_id, weight in zip(syn_ids, weights)}


def main(config, coordinates, gid, field_width, peak_rate, input_features_path,
         input_features_namespaces, output_features_namespace,
         output_weights_path, output_features_path, initial_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, reference_weights_namespace,
         output_weights_namespace, reference_weights_are_delta,
         connections_path, optimize_method, destination, sources, arena_id,
         max_delta_weight, field_width_scale, max_iter, verbose, dry_run,
         plot):
    """Generate structured synaptic weights for one destination cell with a target place field.

    The target rate map is built from fields centred at ``coordinates`` with width
    ``field_width * field_width_scale``. The cell's recorded features use the unscaled
    ``field_width``. The input rate maps of the presynaptic cells in ``sources`` are
    read from ``input_features_path`` and passed, together with the initial weights, to
    ``synapses.generate_structured_weights``.

    Unless ``dry_run`` is set, the resulting weights are appended to
    ``output_weights_path``. If ``output_features_path`` is also given, the destination
    cell's features are appended there.

    :param config: str (path to .yaml file)
    :param coordinates: sequence of (x, y) float pairs, one per field
    :param gid: int (destination cell gid; required)
    :param field_width: float
    :param peak_rate: float
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str (the arena id is appended to each)
    :param output_features_namespace: str
    :param output_weights_path: str (path to .h5 file; required unless dry_run)
    :param output_features_path: str (path to .h5 file)
    :param initial_weights_path: str (path to .h5 file)
    :param reference_weights_path: str (path to .h5 file)
    :param h5types_path: str (path to .h5 file)
    :param synapse_name: str
    :param initial_weights_namespace: str
    :param reference_weights_namespace: str
    :param output_weights_namespace: str (the arena id is appended)
    :param reference_weights_are_delta: bool
    :param connections_path: str (path to .h5 file)
    :param optimize_method: str (forwarded to generate_structured_weights)
    :param destination: str (population name)
    :param sources: list of str (population names)
    :param arena_id: str
    :param max_delta_weight: float
    :param field_width_scale: float
    :param max_iter: int (not used by this function)
    :param verbose: bool
    :param dry_run: bool
    :param plot: bool
    :raises RuntimeError: if ``gid`` is None, if a projection lacks a ``syn_id``
        synapse attribute, or (when not ``dry_run``) if the output weights file cannot
        be set up.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    # A destination gid is required. Without one there is no target cell, and the
    # previous code failed later with a NameError on an undefined `target_gid`.
    if gid is None:
        raise RuntimeError('Missing required argument: gid.')
    target_gid = gid

    env = Env(config_file=config)

    if not dry_run:
        _init_output_weights_file(output_weights_path, initial_weights_path,
                                  h5types_path)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]
    features_attr_names = ['Arena Rate Map']
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
        arena, spatial_resolution)
    dim_x = len(arena_x)
    dim_y = len(arena_y)

    # Fields of the destination cell: its own rate map uses field_width; the
    # optimization target uses the scaled width.
    num_fields = len(coordinates)
    this_field_width = np.full(num_fields, field_width, dtype=np.float32)
    this_scaled_field_width = np.full(num_fields,
                                      field_width * field_width_scale,
                                      dtype=np.float32)
    this_peak_rate = np.full(num_fields, peak_rate, dtype=np.float32)
    this_x0 = np.array([x for x, _ in coordinates], dtype=np.float32)
    this_y0 = np.array([y for _, y in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width,
                                            this_peak_rate, arena_x, arena_y),
                               dtype=np.float32)
    target_map = np.asarray(get_rate_map(this_x0, this_y0,
                                         this_scaled_field_width,
                                         this_peak_rate, arena_x, arena_y),
                            dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features = defaultdict(dict)
    dst_input_features[destination][target_gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel()
    }

    initial_weights_by_syn_id_dict = {}
    if initial_weights_path is not None:
        initial_weights_by_syn_id_dict = _read_syn_weights_by_syn_id(
            initial_weights_path, initial_weights_namespace, destination,
            target_gid, synapse_name)
        logger.info(
            'destination: %s; gid %i; read initial synaptic weights for %i synapses'
            % (destination, target_gid, len(initial_weights_by_syn_id_dict)))

    reference_weights_by_syn_id_dict = None
    if reference_weights_path is not None:
        reference_weights_by_syn_id_dict = _read_syn_weights_by_syn_id(
            reference_weights_path, reference_weights_namespace, destination,
            target_gid, synapse_name)
        logger.info(
            'destination: %s; gid %i; read reference synaptic weights for %i synapses'
            % (destination, target_gid, len(reference_weights_by_syn_id_dict)))

    # Group synapses by presynaptic gid. Synapses without an initial weight get the
    # projection's default weight from the connection config.
    source_gid_set_dict = defaultdict(set)
    syn_ids_by_source_gid_dict = defaultdict(list)
    initial_weights_by_source_gid_dict = {}
    reference_weights_by_source_gid_dict = (
        None if reference_weights_by_syn_id_dict is None else {})
    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[target_gid],
                                                   namespaces=['Synapses'])
    for source, edge_iter in viewitems(graph[destination]):
        if source not in sources:
            continue
        this_edge_attr_info = edge_attr_info[destination][source]
        syn_id_attr_index = this_edge_attr_info.get('Synapses', {}).get('syn_id')
        if syn_id_attr_index is None:
            raise RuntimeError(
                'Projection %s -> %s has no syn_id synapse attribute.' %
                (source, destination))
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == target_gid
            source_gids, edge_attrs = edges
            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            count = 0
            for raw_source_gid, raw_syn_id in zip(source_gids, syn_ids):
                this_source_gid = int(raw_source_gid)
                this_syn_id = int(raw_syn_id)
                source_gid_set_dict[source].add(this_source_gid)
                if this_syn_id not in initial_weights_by_syn_id_dict:
                    initial_weights_by_syn_id_dict[this_syn_id] = \
                        env.connection_config[destination][source].mechanisms['default'][synapse_name]['weight']
                syn_ids_by_source_gid_dict[this_source_gid].append(this_syn_id)
                # The weight of the first synapse seen from a source gid stands for
                # that gid.
                if this_source_gid not in initial_weights_by_source_gid_dict:
                    initial_weights_by_source_gid_dict[this_source_gid] = \
                        initial_weights_by_syn_id_dict[this_syn_id]
                    if reference_weights_by_source_gid_dict is not None:
                        reference_weights_by_source_gid_dict[this_source_gid] = \
                            reference_weights_by_syn_id_dict[this_syn_id]
                count += 1
            logger.info(
                'destination: %s; gid %i; set initial synaptic weights for %d inputs from source population '
                '%s' % (destination, destination_gid, count, source))

    syn_count_by_source_gid_dict = {
        source_gid: len(source_syn_ids)
        for source_gid, source_syn_ids in viewitems(syn_ids_by_source_gid_dict)
    }

    input_rate_maps_by_source_gid_dict = {}
    for source in sources:
        source_gids = list(source_gid_set_dict[source])
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                source,
                namespace=input_features_namespace,
                mask=set(features_attr_names),
                selection=source_gids)
            count = 0
            for this_gid, attr_dict in input_features_iter:
                input_rate_maps_by_source_gid_dict[this_gid] = attr_dict[
                    'Arena Rate Map'].reshape((dim_x, dim_y))
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

    # `is_interactive` and `context` are not defined in this function; presumably
    # module-level names defined elsewhere in the file (TODO confirm).
    if is_interactive:
        context.update(locals())

    normalized_delta_weights_dict, arena_LS_map = \
        synapses.generate_structured_weights(target_map=target_map,
                                             initial_weight_dict=initial_weights_by_source_gid_dict,
                                             input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                             syn_count_dict=syn_count_by_source_gid_dict,
                                             max_delta_weight=max_delta_weight, arena_x=arena_x, arena_y=arena_y,
                                             reference_weight_dict=reference_weights_by_source_gid_dict,
                                             reference_weights_are_delta=reference_weights_are_delta,
                                             reference_weights_namespace=reference_weights_namespace,
                                             optimize_method=optimize_method, verbose=verbose, plot=plot)

    # Emit exactly the synapses that received a weight. The previous code sized
    # np.empty arrays by len(initial_weights_by_syn_id_dict); that dict can hold
    # synapses that are never filled in here (e.g. ones read from the initial weights
    # file for other sources), so uninitialized values were written out.
    output_syn_id_list = []
    output_weight_list = []
    for source_gid, this_weight in viewitems(normalized_delta_weights_dict):
        for syn_id in syn_ids_by_source_gid_dict[source_gid]:
            output_syn_id_list.append(syn_id)
            output_weight_list.append(this_weight)
    output_syn_ids = np.asarray(output_syn_id_list, dtype='uint32')
    output_weights = np.asarray(output_weight_list, dtype='float32')
    output_weights_dict = {
        target_gid: {
            'syn_id': output_syn_ids,
            synapse_name: output_weights
        }
    }

    logger.info('destination: %s; gid %i; generated %s for %i synapses' %
                (destination, target_gid, output_weights_namespace,
                 len(output_weights)))

    if not dry_run:
        this_output_weights_namespace = '%s %s' % (output_weights_namespace,
                                                   arena_id)
        logger.info('Destination: %s; appending %s ...' %
                    (destination, this_output_weights_namespace))
        append_cell_attributes(output_weights_path,
                               destination,
                               output_weights_dict,
                               namespace=this_output_weights_namespace)
        logger.info('Destination: %s; appended %s' %
                    (destination, this_output_weights_namespace))
        output_weights_dict.clear()
        if output_features_path is not None:
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            cell_attr_dict = dst_input_features[destination]
            cell_attr_dict[target_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            logger.info('Destination: %s; appending %s ...' %
                        (destination, this_output_features_namespace))
            append_cell_attributes(output_features_path,
                                   destination,
                                   cell_attr_dict,
                                   namespace=this_output_features_namespace)

    if is_interactive:
        context.update(locals())
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_opt_iter, max_weight_decay_fraction, optimize_tol, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig, debug):
    """

    :param config: str (path to .yaml file)
    :param input_features_path: str (path to .h5 file)
    :param initial_weights_path: str (path to .h5 file)
    :param initial_weights_namespace: str
    :param output_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param io_size:
    :param chunk_size:
    :param value_chunk_size:
    :param write_size:
    :param verbose:
    :param dry_run:
    :return:
    """

    utils.config_logging(verbose)
    script_name = __file__
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    env = Env(comm=comm, config_file=config, io_size=io_size)
    env.comm.barrier()

    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None:
                input_file = h5py.File(initial_weights_path, 'r')
            elif h5types_path is not None:
                input_file = h5py.File(h5types_path, 'r')
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            output_file = h5py.File(output_weights_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    env.comm.barrier()

    LTD_output_weights_namespace = f'LTD {output_weights_namespace} {arena_id}'
    LTP_output_weights_namespace = f'LTP {output_weights_namespace} {arena_id}'
    this_input_features_namespaces = [
        f'{input_features_namespace} {arena_id}'
        for input_features_namespace in input_features_namespaces
    ]

    selectivity_type_index = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    features_attrs = defaultdict(dict)
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)
    projections = [(source, destination) for source in sources]
    graph_info = read_graph_info(connections_path,
                                 namespaces=['Connections', 'Synapses'],
                                 read_node_index=True)
    for projection in projections:
        if projection not in graph_info:
            raise RuntimeError(
                f'Projection {projection[0]} -> {projection[1]} is not present in connections file.'
            )
        if target_gid_set is None:
            target_gid_set = set(graph_info[projection][1])

    all_sources = sources + non_structured_sources
    src_input_features_attr_dict = {source: {} for source in all_sources}
    for source in sorted(all_sources):
        this_src_input_features_attr_dict = {}
        for this_input_features_namespace in this_input_features_namespaces:
            logger.info(
                f'Rank {rank}: Reading {this_input_features_namespace} feature data for cells in population {source}'
            )
            input_features_dict = scatter_read_cell_attributes(
                input_features_path,
                source,
                namespaces=[this_input_features_namespace],
                mask=set(source_features_attr_names),
                comm=env.comm,
                io_size=env.io_size)
            for gid, attr_dict in input_features_dict[
                    this_input_features_namespace]:
                this_src_input_features_attr_dict[gid] = attr_dict
        src_input_features_attr_dict[
            source] = this_src_input_features_attr_dict
        source_gid_count = env.comm.reduce(
            len(this_src_input_features_attr_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info(
                f'Rank {rank}: Read feature data for {source_gid_count} cells in population {source}'
            )

    dst_gids = []
    if target_gid_set is not None:
        for i, gid in enumerate(target_gid_set):
            if i % nranks == rank:
                dst_gids.append(gid)

    dst_input_features_attr_dict = {}
    for this_input_features_namespace in this_input_features_namespaces:
        feature_count = 0
        gid_count = 0
        logger.info(
            f'Rank {rank}: reading {this_input_features_namespace} feature data for {len(dst_gids)} cells in population {destination}'
        )
        input_features_iter = scatter_read_cell_attribute_selection(
            input_features_path,
            destination,
            namespace=this_input_features_namespace,
            mask=set(target_features_attr_names),
            selection=dst_gids,
            io_size=env.io_size,
            comm=env.comm)
        for gid, attr_dict in input_features_iter:
            gid_count += 1
            if (len(coordinates) > 0) or (attr_dict['Num Fields'][0] > 0):
                dst_input_features_attr_dict[gid] = attr_dict
                feature_count += 1

        logger.info(
            f'Rank {rank}: read {this_input_features_namespace} feature data for '
            f'{gid_count} / {feature_count} cells in population {destination}')
        feature_count = env.comm.reduce(feature_count, op=MPI.SUM, root=0)
        env.comm.barrier()
        if rank == 0:
            logger.info(
                f'Read {this_input_features_namespace} feature data for {feature_count} cells in population {destination}'
            )

    feature_dst_gids = list(dst_input_features_attr_dict.keys())
    all_feature_gids_per_rank = comm.allgather(feature_dst_gids)
    all_feature_gids = sorted(
        [item for sublist in all_feature_gids_per_rank for item in sublist])
    request_dst_gids = []
    for i, gid in enumerate(all_feature_gids):
        if i % nranks == rank:
            request_dst_gids.append(gid)

    dst_input_features_attr_dict = exchange_input_features(
        env.comm, request_dst_gids, dst_input_features_attr_dict)
    dst_gids = list(dst_input_features_attr_dict.keys())

    if rank == 0:
        logger.info(
            f"Rank {rank} feature dict is {dst_input_features_attr_dict}")

    dst_count = env.comm.reduce(len(dst_gids), op=MPI.SUM, root=0)

    logger.info(f"Rank {rank} has {len(dst_gids)} feature gids")
    if rank == 0:
        logger.info(f'Total {dst_count} feature gids')

    max_dst_count = env.comm.allreduce(len(dst_gids), op=MPI.MAX)
    env.comm.barrier()

    max_iter_count = max_dst_count
    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    non_structured_output_weights_dict = {}
    for iter_count in range(max_iter_count):

        gc.collect()

        local_time = time.time()
        selection = []
        if len(dst_gids) > 0:
            dst_gid = dst_gids.pop()
            selection.append(dst_gid)
            logger.info(f'Rank {rank} received gid {dst_gid}')

        env.comm.barrier()

        arena_margin_size = 0.
        arena_margin = max(arena_margin, 0.)

        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}

        for destination_gid in selection:
            arena_margin_size = init_selectivity_config(
                destination_gid,
                spatial_resolution,
                arena,
                arena_margin,
                arena_margin_size,
                coordinates,
                field_width,
                field_width_scale,
                peak_rate,
                target_selectivity_type,
                selectivity_type_index,
                dst_input_features_attr_dict,
                target_selectivity_features_dict,
                target_selectivity_config_dict,
                target_field_width_dict,
                logger=logger)

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)

        selection = list(target_selectivity_features_dict.keys())

        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())
        initial_weights_by_syn_id_dict = \
          read_weights(initial_weights_path, initial_weights_namespace, synapse_name,
                       destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                       logger=logger if rank == 0 else None)

        non_structured_weights_by_source_gid_dict = defaultdict(lambda: dict())
        non_structured_weights_by_syn_id_dict = None
        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = \
             read_weights(non_structured_weights_path, non_structured_weights_namespace, synapse_name,
                          destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                          logger=logger if rank == 0 else None)

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = \
             read_weights(reference_weights_path, reference_weights_namespace, synapse_name,
                          destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                          logger=logger if rank == 0 else None)

        source_gid_set_dict = defaultdict(set)
        syn_count_by_source_gid_dict = defaultdict(lambda: defaultdict(int))
        syn_ids_by_source_gid_dict = defaultdict(lambda: defaultdict(list))
        structured_syn_id_count = defaultdict(int)
        non_structured_syn_id_count = defaultdict(int)

        projections = [(source, destination) for source in all_sources]
        edge_iter_dict, edge_attr_info = scatter_read_graph_selection(
            connections_path,
            selection=selection,
            namespaces=['Synapses'],
            projections=projections,
            comm=env.comm,
            io_size=env.io_size)

        syn_counts_by_source = init_syn_weight_dicts(
            destination, non_structured_sources, edge_iter_dict,
            edge_attr_info, initial_weights_by_syn_id_dict,
            initial_weights_by_source_gid_dict,
            non_structured_weights_by_syn_id_dict,
            non_structured_weights_by_source_gid_dict,
            reference_weights_by_syn_id_dict,
            reference_weights_by_source_gid_dict, source_gid_set_dict,
            syn_count_by_source_gid_dict, syn_ids_by_source_gid_dict,
            structured_syn_id_count, non_structured_syn_id_count)

        for source in syn_counts_by_source:
            for this_gid in syn_counts_by_source[source]:
                count = syn_counts_by_source[source][this_gid]
                logger.info(
                    f'Rank {rank}: destination: {destination}; gid {this_gid}; '
                    f'{count} edges from source population {source}')

        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None

        for source in all_sources:
            source_gids = list(source_gid_set_dict[source])
            if rank == 0:
                logger.info(
                    f'Rank {rank}: getting feature data for {len(source_gids)} cells in population {source}'
                )
            this_src_input_features = exchange_input_features(
                env.comm, source_gids, src_input_features_attr_dict[source])

            count = 0
            for this_gid in source_gids:
                attr_dict = this_src_input_features[this_gid]
                this_selectivity_type = attr_dict['Selectivity Type'][0]
                this_selectivity_type_name = selectivity_type_index[
                    this_selectivity_type]
                input_cell_config = stimulus.get_input_cell_config(
                    this_selectivity_type,
                    selectivity_type_index,
                    selectivity_attr_dict=attr_dict)
                this_arena_rate_map = np.asarray(
                    input_cell_config.get_rate_map(arena_x, arena_y),
                    dtype=np.float32)
                if source in non_structured_sources:
                    non_structured_input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map
                else:
                    input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map
                count += 1

        for destination_gid in selection:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = f'{save_fig}/Structured Weights {destination} {destination_gid}.png'

            reference_weight_dict = None
            if reference_weights_path is not None:
                reference_weight_dict = reference_weights_by_source_gid_dict[
                    destination_gid]

            LTP_delta_weights_dict, LTD_delta_weights_dict, arena_structured_map = \
               synapses.generate_structured_weights(destination_gid,
                                                 target_map=target_selectivity_features_dict[destination_gid]['Arena Rate Map'],
                                                 initial_weight_dict=initial_weights_by_source_gid_dict[destination_gid],
                                                 #reference_weight_dict=reference_weight_dict,
                                                 #reference_weights_are_delta=reference_weights_are_delta,
                                                 #reference_weights_namespace=reference_weights_namespace,
                                                 input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                 non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                 non_structured_weights_dict=non_structured_weights_by_source_gid_dict[destination_gid],
                                                 syn_count_dict=syn_count_by_source_gid_dict[destination_gid],
                                                 max_opt_iter=max_opt_iter,
                                                 max_weight_decay_fraction=max_weight_decay_fraction,
                                                 target_amplitude=target_amplitude,
                                                 arena_x=arena_x, arena_y=arena_y,
                                                 optimize_tol=optimize_tol,
                                                 verbose=verbose if rank == 0 else False,
                                                 plot=plot, show_fig=show_fig,
                                                 save_fig=save_fig_path,
                                                 fig_kwargs={'gid': destination_gid,
                                                             'field_width': target_field_width_dict[destination_gid]})
            input_rate_maps_by_source_gid_dict.clear()

            target_map_flat = target_selectivity_features_dict[
                destination_gid]['Arena Rate Map'].flat
            arena_map_residual_mae = np.mean(
                np.abs(arena_structured_map - target_map_flat))
            output_features_dict[destination_gid] = \
               { fld: target_selectivity_features_dict[destination_gid][fld]
                 for fld in ['Selectivity Type',
                             'Num Fields',
                             'Field Width',
                             'Peak Rate',
                             'X Offset',
                             'Y Offset',]}
            output_features_dict[destination_gid][
                'Rate Map Residual Mean Error'] = np.asarray(
                    [arena_map_residual_mae], dtype=np.float32)

            this_structured_syn_id_count = structured_syn_id_count[
                destination_gid]
            output_syn_ids = np.empty(this_structured_syn_id_count,
                                      dtype='uint32')
            LTD_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            i = 0
            for source_gid in LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[destination_gid][
                        source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = LTP_delta_weights_dict[source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            LTP_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            this_non_structured_syn_id_count = non_structured_syn_id_count[
                destination_gid]
            i = 0

            logger.info(
                f'Rank {rank}; destination: {destination}; gid {destination_gid}; '
                f'generated structured weights for {len(output_syn_ids)} inputs in {time.time() - local_time:.2f} s; '
                f'residual error is {arena_map_residual_mae:.2f}')
            gid_count += 1
            gc.collect()

        env.comm.barrier()
        if (write_size > 0) and (iter_count % write_size == 0):
            if not dry_run:
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = env.comm.reduce(len(LTP_output_weights_dict),
                                        op=MPI.SUM,
                                        root=0)
                env.comm.barrier()

                if rank == 0:
                    logger.info(
                        f'Destination: {destination}; appended weights for {count} cells'
                    )
                if output_features_path is not None:
                    if output_features_namespace is None:
                        output_features_namespace = f'{target_selectivity_type_name.title()} Selectivity'
                    this_output_features_namespace = f'{output_features_namespace} {arena_id}'
                    append_cell_attributes(
                        output_features_path,
                        destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = env.comm.reduce(len(output_features_dict),
                                            op=MPI.SUM,
                                            root=0)
                    env.comm.barrier()

                    if rank == 0:
                        logger.info(
                            f'Destination: {destination}; appended selectivity features for {count} cells'
                        )

            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

        if (iter_count >= 10) and debug:
            break

    env.comm.barrier()
    if not dry_run:
        append_cell_attributes(output_weights_path,
                               destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path,
                               destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        env.comm.barrier()

        if rank == 0:
            logger.info(
                f'Destination: {destination}; appended weights for {count} cells'
            )
        if output_features_path is not None:
            if output_features_namespace is None:
                output_features_namespace = 'Selectivity Features'
            this_output_features_namespace = f'{output_features_namespace} {arena_id}'
            append_cell_attributes(output_features_path,
                                   destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = env.comm.reduce(len(output_features_dict),
                                    op=MPI.SUM,
                                    root=0)
            env.comm.barrier()

            if rank == 0:
                logger.info(
                    f'Destination: {destination}; appended selectivity features for {count} cells'
                )

    env.comm.barrier()
    global_count = env.comm.gather(gid_count, root=0)
    env.comm.barrier()

    if rank == 0:
        total_count = np.sum(global_count)
        total_time = time.time() - start_time
        logger.info(
            f'Destination: {destination}; '
            f'{env.comm.size} ranks assigned structured weights to {total_count} cells in {total_time:.2f} s'
        )
def init_selectivity_config(destination_gid,
                            spatial_resolution,
                            arena,
                            arena_margin,
                            arena_margin_size,
                            coordinates,
                            field_width,
                            field_width_scale,
                            peak_rate,
                            target_selectivity_type,
                            selectivity_type_index,
                            input_features_attr_dict,
                            target_selectivity_features_dict,
                            target_selectivity_config_dict,
                            target_field_width_dict,
                            logger=None):
    """Configure the target selectivity features of one destination cell.

    The feature dict ``input_features_attr_dict[destination_gid]`` is updated
    in place. Its 'Selectivity Type' is set to ``target_selectivity_type``.
    The optional coordinate, field width and peak rate overrides are then
    applied. If the cell ends up with at least one field:

    - an input cell configuration is built with
      ``stimulus.get_input_cell_config``;
    - the arena margin is enlarged to at least
      ``max(field_width) * arena_margin``;
    - the flattened float32 arena rate map is stored under 'Arena Rate Map'.

    :param destination_gid: gid of the destination cell; must be a key of
        ``input_features_attr_dict``.
    :param spatial_resolution: spatial resolution passed to
        ``stimulus.get_2D_arena_spatial_mesh``.
    :param arena: arena object passed to ``stimulus.get_2D_arena_spatial_mesh``.
    :param arena_margin: multiplier applied to the largest field width to
        obtain a candidate arena margin.
    :param arena_margin_size: current arena margin size; the returned value
        is never smaller than this.
    :param coordinates: sequence of (x, y) field centers. When non-empty it
        overrides 'X Offset', 'Y Offset' and 'Num Fields'. ``None`` or an
        empty sequence keeps the stored values.
    :param field_width: if not None, every field is given this width.
    :param field_width_scale: ``scale`` argument forwarded to
        ``get_rate_map``.
    :param peak_rate: if not None, every field is given this peak rate.
    :param target_selectivity_type: selectivity type code (stored as uint8).
    :param selectivity_type_index: mapping forwarded to
        ``stimulus.get_input_cell_config``.
    :param input_features_attr_dict: dict of gid -> feature attribute dict;
        the entry for ``destination_gid`` is modified in place.
    :param target_selectivity_features_dict: output dict; receives the updated
        feature dict for ``destination_gid`` only when it has fields.
    :param target_selectivity_config_dict: output dict; receives the input
        cell configuration for ``destination_gid`` only when it has fields.
    :param target_field_width_dict: output dict; receives the configuration's
        field widths for ``destination_gid`` only when it has fields.
    :param logger: accepted for interface compatibility; not used here.
    :return: the (possibly enlarged) arena margin size.
    :raises KeyError: if ``destination_gid`` is not in
        ``input_features_attr_dict``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if destination_gid not in input_features_attr_dict:
        raise KeyError(
            f'init_selectivity_config: no input features for gid {destination_gid}'
        )
    features = input_features_attr_dict[destination_gid]
    features['Selectivity Type'] = np.asarray([target_selectivity_type],
                                              dtype=np.uint8)

    # Explicit field centers take precedence over the stored field count.
    if coordinates is not None and len(coordinates) > 0:
        num_fields = len(coordinates)
        features['X Offset'] = np.asarray([xy[0] for xy in coordinates],
                                          dtype=np.float32)
        features['Y Offset'] = np.asarray([xy[1] for xy in coordinates],
                                          dtype=np.float32)
        features['Num Fields'] = np.asarray([num_fields], dtype=np.uint8)
    elif 'Num Fields' in features:
        num_fields = int(features['Num Fields'][0])
    else:
        num_fields = 0

    if field_width is not None:
        features['Field Width'] = np.asarray([field_width] * num_fields,
                                             dtype=np.float32)
    elif 'Field Width' in features:
        # Keep only as many stored widths as there are fields.
        features['Field Width'] = features['Field Width'][:num_fields]

    if peak_rate is not None:
        features['Peak Rate'] = np.asarray([peak_rate] * num_fields,
                                           dtype=np.float32)

    if num_fields > 0:
        input_cell_config = stimulus.get_input_cell_config(
            target_selectivity_type,
            selectivity_type_index,
            selectivity_attr_dict=features)
        # Widen the arena so that the largest field fits inside the margin.
        arena_margin_size = max(
            arena_margin_size,
            np.max(input_cell_config.field_width) * arena_margin)

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)

        target_map = np.asarray(input_cell_config.get_rate_map(
            arena_x, arena_y, scale=field_width_scale),
                                dtype=np.float32).flatten()

        features['Arena Rate Map'] = target_map
        target_selectivity_features_dict[destination_gid] = features
        target_field_width_dict[destination_gid] = input_cell_config.field_width
        target_selectivity_config_dict[destination_gid] = input_cell_config

    return arena_margin_size
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_delta_weight, max_opt_iter, max_weight_decay_fraction,
         optimize_method, optimize_tol, optimize_grad, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig):
    """

    :param config: str (path to .yaml file)
    :param input_features_path: str (path to .h5 file)
    :param initial_weights_path: str (path to .h5 file)
    :param initial_weights_namespace: str
    :param output_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param io_size:
    :param chunk_size:
    :param value_chunk_size:
    :param cache_size:
    :param write_size:
    :param verbose:
    :param dry_run:
    :return:
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%s: %i ranks have been allocated' % (__file__, comm.size))

    env = Env(comm=comm, config_file=config, io_size=io_size)

    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None:
                input_file = h5py.File(initial_weights_path, 'r')
            elif h5types_path is not None:
                input_file = h5py.File(h5types_path, 'r')
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            output_file = h5py.File(output_weights_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    env.comm.barrier()

    LTD_output_weights_namespace = 'LTD %s %s' % (output_weights_namespace,
                                                  arena_id)
    LTP_output_weights_namespace = 'LTP %s %s' % (output_weights_namespace,
                                                  arena_id)
    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]

    selectivity_type_index = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    features_attrs = defaultdict(dict)
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    local_random = np.random.RandomState()

    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)

    all_sources = sources + non_structured_sources

    connection_gen_list = [ NeuroH5ProjectionGen(connections_path, source, destination, namespaces=['Synapses'], comm=comm) \
                               for source in all_sources ]

    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    for iter_count, attr_gen_package in enumerate(
            zip_longest(*connection_gen_list)):

        local_time = time.time()
        this_gid = attr_gen_package[0][0]
        if not all([
                attr_gen_items[0] == this_gid
                for attr_gen_items in attr_gen_package
        ]):
            raise Exception(
                'Rank: %i; destination: %s; this_gid not matched across multiple attribute '
                'generators: %s' %
                (rank, destination,
                 [attr_gen_items[0] for attr_gen_items in attr_gen_package]))

        if (target_gid_set is not None) and (this_gid not in target_gid_set):
            continue

        if this_gid is None:
            selection = []
            logger.info('Rank: %i received None' % rank)
        else:
            selection = [this_gid]
            local_random.seed(int(this_gid + seed_offset))

        has_structured_weights = False

        dst_input_features_attr_dict = {}
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                destination,
                namespace=input_features_namespace,
                mask=set(target_features_attr_names),
                comm=env.comm,
                selection=selection)
            count = 0
            for gid, attr_dict in input_features_iter:
                dst_input_features_attr_dict[gid] = attr_dict
                count += 1
            if rank == 0:
                logger.info(
                    'Read %s feature data for %i cells in population %s' %
                    (input_features_namespace, count, destination))

        arena_margin_size = 0.
        arena_margin = max(arena_margin, 0.)
        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}
        for gid in selection:
            target_selectivity_features_dict[
                gid] = dst_input_features_attr_dict.get(gid, {})
            target_selectivity_features_dict[gid][
                'Selectivity Type'] = np.asarray([target_selectivity_type],
                                                 dtype=np.uint8)

            num_fields = target_selectivity_features_dict[gid]['Num Fields'][0]

            if coordinates[0] is not None:
                num_fields = 1
                target_selectivity_features_dict[gid]['X Offset'] = np.asarray(
                    [coordinates[0]], dtype=np.float32)
                target_selectivity_features_dict[gid]['Y Offset'] = np.asarray(
                    [coordinates[1]], dtype=np.float32)
                target_selectivity_features_dict[gid][
                    'Num Fields'] = np.asarray([num_fields], dtype=np.uint8)

            if field_width is not None:
                target_selectivity_features_dict[gid][
                    'Field Width'] = np.asarray([field_width] * num_fields,
                                                dtype=np.float32)
            else:
                this_field_width = target_selectivity_features_dict[gid][
                    'Field Width']
                target_selectivity_features_dict[gid][
                    'Field Width'] = this_field_width[:num_fields]

            if peak_rate is not None:
                target_selectivity_features_dict[gid][
                    'Peak Rate'] = np.asarray([peak_rate] * num_fields,
                                              dtype=np.float32)

            input_cell_config = stimulus.get_input_cell_config(
                target_selectivity_type,
                selectivity_type_index,
                selectivity_attr_dict=target_selectivity_features_dict[gid])
            if input_cell_config.num_fields > 0:
                arena_margin_size = max(
                    arena_margin_size,
                    np.max(input_cell_config.field_width) * arena_margin)
                target_field_width_dict[gid] = input_cell_config.field_width
                target_selectivity_config_dict[gid] = input_cell_config
                has_structured_weights = True

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)
        for gid, input_cell_config in viewitems(
                target_selectivity_config_dict):
            target_map = np.asarray(input_cell_config.get_rate_map(
                arena_x, arena_y, scale=field_width_scale),
                                    dtype=np.float32)
            target_selectivity_features_dict[gid][
                'Arena Rate Map'] = target_map

        if not has_structured_weights:
            selection = []

        initial_weights_by_syn_id_dict = defaultdict(lambda: dict())
        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())

        if initial_weights_path is not None:
            initial_weights_iter = \
              read_cell_attribute_selection(initial_weights_path, destination,
                                            namespace=initial_weights_namespace,
                                            selection=selection)

            initial_weights_gid_count = 0
            initial_weights_syn_count = 0
            for this_gid, syn_weight_attr_dict in initial_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    initial_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)
                initial_weights_gid_count += 1
                initial_weights_syn_count += len(syn_ids)

            logger.info(
                'destination: %s; read initial synaptic weights for %i gids and %i syns'
                % (destination, initial_weights_gid_count,
                   initial_weights_syn_count))

        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = defaultdict(lambda: dict())
            non_structured_weights_by_source_gid_dict = defaultdict(
                lambda: dict())
        else:
            non_structured_weights_by_syn_id_dict = None

        if non_structured_weights_path is not None:
            non_structured_weights_iter = \
                read_cell_attribute_selection(initial_weights_path, destination,
                                              namespace=non_structured_weights_namespace,
                                              selection=selection)

            non_structured_weights_gid_count = 0
            non_structured_weights_syn_count = 0
            for this_gid, syn_weight_attr_dict in non_structured_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    non_structured_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)
                non_structured_weights_gid_count += 1
                non_structured_weights_syn_count += len(syn_ids)

            logger.info(
                'destination: %s; read non-structured synaptic weights for %i gids and %i syns'
                % (
                    destination,
                    non_structured_weights_gid_count,
                    non_structured_weights_syn_count,
                ))

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = defaultdict(lambda: dict())
            reference_weights_iter = \
              read_cell_attribute_selection(reference_weights_path, destination, namespace=reference_weights_namespace,
                                            selection=selection)
            reference_weights_gid_count = 0

            for this_gid, syn_weight_attr_dict in reference_weights_iter:
                syn_ids = syn_weight_attr_dict['syn_id']
                weights = syn_weight_attr_dict[synapse_name]

                for (syn_id, weight) in zip(syn_ids, weights):
                    reference_weights_by_syn_id_dict[this_gid][int(
                        syn_id)] = float(weight)

            logger.info(
                'destination: %s; read reference synaptic weights for %i gids'
                % (destination, reference_weights_gid_count))

        syn_count_by_source_gid_dict = defaultdict(int)
        source_gid_set_dict = defaultdict(set)
        syn_ids_by_source_gid_dict = defaultdict(list)
        structured_syn_id_count = 0

        if has_structured_weights:
            for source, (destination_gid, (source_gid_array,
                                           conn_attr_dict)) in zip_longest(
                                               all_sources, attr_gen_package):
                syn_ids = conn_attr_dict['Synapses']['syn_id']
                count = 0
                this_initial_weights_by_syn_id_dict = None
                this_initial_weights_by_source_gid_dict = None
                this_reference_weights_by_syn_id_dict = None
                this_reference_weights_by_source_gid_dict = None
                this_non_structured_weights_by_syn_id_dict = None
                this_non_structured_weights_by_source_gid_dict = None
                if destination_gid is not None:
                    this_initial_weights_by_syn_id_dict = initial_weights_by_syn_id_dict[
                        destination_gid]
                    this_initial_weights_by_source_gid_dict = initial_weights_by_source_gid_dict[
                        destination_gid]
                    if reference_weights_by_syn_id_dict is not None:
                        this_reference_weights_by_syn_id_dict = reference_weights_by_syn_id_dict[
                            destination_gid]
                        this_reference_weights_by_source_gid_dict = reference_weights_by_source_gid_dict[
                            destination_gid]
                    this_non_structured_weights_by_syn_id_dict = non_structured_weights_by_syn_id_dict[
                        destination_gid]
                    this_non_structured_weights_by_source_gid_dict = non_structured_weights_by_source_gid_dict[
                        destination_gid]

                for i in range(len(source_gid_array)):
                    this_source_gid = source_gid_array[i]
                    this_syn_id = syn_ids[i]
                    if this_syn_id in this_initial_weights_by_syn_id_dict:
                        this_syn_wgt = this_initial_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_initial_weights_by_source_gid_dict:
                            this_initial_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                        if this_reference_weights_by_syn_id_dict is not None:
                            this_reference_weights_by_source_gid_dict[this_source_gid] = \
                              this_reference_weights_by_syn_id_dict[this_syn_id]
                    elif this_syn_id in this_non_structured_weights_by_syn_id_dict:
                        this_syn_wgt = this_non_structured_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_non_structured_weights_by_source_gid_dict:
                            this_non_structured_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                    source_gid_set_dict[source].add(this_source_gid)
                    syn_ids_by_source_gid_dict[this_source_gid].append(
                        this_syn_id)
                    syn_count_by_source_gid_dict[this_source_gid] += 1

                    count += 1
                if source not in non_structured_sources:
                    structured_syn_id_count += len(syn_ids)
                logger.info(
                    'Rank %i; destination: %s; gid %i; %d edges from source population %s'
                    % (rank, destination, this_gid, count, source))

        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None
        for source in all_sources:
            if has_structured_weights:
                source_gids = list(source_gid_set_dict[source])
            else:
                source_gids = []
            if rank == 0:
                logger.info(
                    'Reading %s feature data for %i cells in population %s...'
                    % (input_features_namespace, len(source_gids), source))
            for input_features_namespace in this_input_features_namespaces:
                input_features_iter = read_cell_attribute_selection(
                    input_features_path,
                    source,
                    namespace=input_features_namespace,
                    mask=set(source_features_attr_names),
                    comm=env.comm,
                    selection=source_gids)
                count = 0
                for gid, attr_dict in input_features_iter:
                    this_selectivity_type = attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_index[
                        this_selectivity_type]
                    input_cell_config = stimulus.get_input_cell_config(
                        this_selectivity_type,
                        selectivity_type_index,
                        selectivity_attr_dict=attr_dict)
                    this_arena_rate_map = np.asarray(
                        input_cell_config.get_rate_map(arena_x, arena_y),
                        dtype=np.float32)
                    if source in non_structured_sources:
                        non_structured_input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    else:
                        input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    count += 1
                if rank == 0:
                    logger.info(
                        'Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

        if has_structured_weights:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = '%s/Structured Weights %s %d.png' % (
                    save_fig, destination, this_gid)

            normalized_LTP_delta_weights_dict, LTD_delta_weights_dict, arena_LS_map = \
              synapses.generate_structured_weights(target_map=target_selectivity_features_dict[this_gid]['Arena Rate Map'],
                                                initial_weight_dict=this_initial_weights_by_source_gid_dict,
                                                reference_weight_dict=this_reference_weights_by_source_gid_dict,
                                                reference_weights_are_delta=reference_weights_are_delta,
                                                reference_weights_namespace=reference_weights_namespace,
                                                input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                non_structured_weights_dict=this_non_structured_weights_by_source_gid_dict,
                                                syn_count_dict=syn_count_by_source_gid_dict,
                                                max_delta_weight=max_delta_weight,
                                                max_opt_iter=max_opt_iter,
                                                max_weight_decay_fraction=max_weight_decay_fraction,
                                                target_amplitude=target_amplitude,
                                                arena_x=arena_x, arena_y=arena_y,
                                                optimize_method=optimize_method,
                                                optimize_tol=optimize_tol,
                                                optimize_grad=optimize_grad,
                                                verbose=verbose, plot=plot, show_fig=show_fig,
                                                save_fig=save_fig_path,
                                                fig_kwargs={'gid': this_gid,
                                                            'field_width': target_field_width_dict[this_gid]})
            gc.collect()

            this_selectivity_dict = target_selectivity_features_dict[this_gid]
            output_features_dict[this_gid] = {
                fld: this_selectivity_dict[fld]
                for fld in [
                    'Selectivity Type', 'Num Fields', 'Field Width',
                    'Peak Rate', 'X Offset', 'Y Offset'
                ]
            }
            output_features_dict[this_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            output_syn_ids = np.empty(structured_syn_id_count, dtype='uint32')
            LTD_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            i = 0
            for source_gid in normalized_LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = normalized_LTP_delta_weights_dict[
                        source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            LTP_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            logger.info(
                'Rank %i; destination: %s; gid %i; generated structured weights for %i inputs in %.2f '
                's' % (rank, destination, this_gid, len(output_syn_ids),
                       time.time() - local_time))
            gid_count += 1

        if iter_count % write_size == 0:
            if not dry_run:
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = comm.reduce(len(LTP_output_weights_dict),
                                    op=MPI.SUM,
                                    root=0)
                if rank == 0:
                    logger.info(
                        'Destination: %s; appended weights for %i cells' %
                        (destination, count))
                if output_features_path is not None:
                    if output_features_namespace is None:
                        output_features_namespace = '%s Selectivity' % target_selectivity_type_name.title(
                        )
                    this_output_features_namespace = '%s %s' % (
                        output_features_namespace, arena_id)
                    logger.info(str(output_features_dict))
                    append_cell_attributes(
                        output_features_path,
                        destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = comm.reduce(len(output_features_dict),
                                        op=MPI.SUM,
                                        root=0)
                    if rank == 0:
                        logger.info(
                            'Destination: %s; appended selectivity features for %i cells'
                            % (destination, count))

            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

    if not dry_run:
        append_cell_attributes(output_weights_path,
                               destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path,
                               destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info('Destination: %s; appended weights for %i cells' %
                        (destination, count))
        if output_features_path is not None:
            if output_features_namespace is None:
                output_features_namespace = 'Selectivity Features'
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            append_cell_attributes(output_features_path,
                                   destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = comm.reduce(len(output_features_dict), op=MPI.SUM, root=0)
            if rank == 0:
                logger.info(
                    'Destination: %s; appended selectivity features for %i cells'
                    % (destination, count))

    env.comm.barrier()
    global_count = comm.gather(gid_count, root=0)
    if rank == 0:
        logger.info(
            'destination: %s; %i ranks assigned structured weights to %i cells in %.2f s'
            % (destination, comm.size, np.sum(global_count),
               time.time() - start_time))
def _run_on_root(comm, fn, *args):
    """Evaluate ``fn(*args)`` on rank 0 only and broadcast the outcome.

    The value returned by ``fn`` on rank 0 is broadcast to, and returned on,
    every rank. If ``fn`` raises on rank 0, its message is broadcast as well
    and every rank raises, so non-root ranks never block forever in a
    broadcast that rank 0 does not reach. Rank 0 re-raises its original
    exception; the other ranks raise a ``RuntimeError`` with the same message.

    :param comm: MPI communicator
    :param fn: callable executed on rank 0 only
    :param args: positional arguments forwarded to ``fn``
    :return: the value returned by ``fn`` on rank 0
    """
    result = None
    root_error = None
    if comm.rank == 0:
        try:
            result = fn(*args)
        except Exception as err:  # deliberate: any root failure must reach all ranks
            root_error = err
    error_msg = None if root_error is None else str(root_error)
    result, error_msg = comm.bcast((result, error_msg), root=0)
    if root_error is not None:
        raise root_error
    if error_msg is not None:
        raise RuntimeError(error_msg)
    return result


def _find_valid_selectivity_namespaces(env, populations, population_ranges,
                                       selectivity_path, selectivity_namespace,
                                       arena_id):
    """Collect, per population, the selectivity namespaces for the given arena.

    A namespace qualifies when its name contains
    ``'<selectivity_namespace> <arena_id>'``.

    :return: dict mapping population name to a list of namespace names
    :raises RuntimeError: if a population is missing from ``population_ranges``,
        has no entry in the 'Selectivity Type Probabilities' stimulus config,
        or has no qualifying namespace in ``selectivity_path``
    """
    namespace_key = f'{selectivity_namespace} {arena_id}'
    valid_selectivity_namespaces = dict()
    with h5py.File(selectivity_path, 'r') as selectivity_f:
        for population in populations:
            if population not in population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided selectivity_path: %s' %
                    (population, selectivity_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'plot_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            valid_selectivity_namespaces[population] = [
                this_namespace
                for this_namespace in selectivity_f['Populations'][population]
                if namespace_key in this_namespace
            ]
            if not valid_selectivity_namespaces[population]:
                raise RuntimeError(
                    'plot_input_selectivity_features: no selectivity data in arena: %s found '
                    'for specified population: %s in provided selectivity_path: %s'
                    % (arena_id, population, selectivity_path))
    return valid_selectivity_namespaces


def _read_reference_arc_distance_bounds(populations, population_ranges,
                                        coords_population_ranges, coords_path,
                                        distances_namespace):
    """Validate the coordinates file and read the reference U/V arc distance bounds.

    For every population, checks that the coordinates file lists as many
    unique cell indexes in its 'U Distance' dataset as ``population_ranges``
    reports for that population. The bounds are taken from the attributes
    'Reference U Min/Max' and 'Reference V Min/Max' of the distances
    namespace of the first population in ``populations``.

    :return: tuple ``((u_min, u_max), (v_min, v_max))``
    :raises RuntimeError: if a population is missing from the coordinates
        file, has an unexpected number of unique cell indexes, or lacks the
        reference bound attributes
    """
    reference_u_arc_distance_bounds = None
    reference_v_arc_distance_bounds = None
    with h5py.File(coords_path, 'r') as coords_f:
        for population in populations:
            if population not in coords_population_ranges:
                raise RuntimeError(
                    'plot_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            distances_grp = coords_f['Populations'][population][
                distances_namespace]
            pop_size = population_ranges[population][1]
            unique_gid_count = len(
                set(distances_grp['U Distance']['Cell Index'][:]))
            if pop_size != unique_gid_count:
                raise RuntimeError(
                    'plot_input_selectivity_features: only %i/%i unique cell indexes found '
                    'for specified population: %s in provided coords_path: %s'
                    % (unique_gid_count, pop_size, population, coords_path))
            if reference_u_arc_distance_bounds is None:
                try:
                    reference_u_arc_distance_bounds = (
                        distances_grp.attrs['Reference U Min'],
                        distances_grp.attrs['Reference U Max'])
                except Exception as err:
                    raise RuntimeError(
                        'plot_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path)) from err
            if reference_v_arc_distance_bounds is None:
                try:
                    reference_v_arc_distance_bounds = (
                        distances_grp.attrs['Reference V Min'],
                        distances_grp.attrs['Reference V Max'])
                except Exception as err:
                    raise RuntimeError(
                        'plot_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path)) from err
    return reference_u_arc_distance_bounds, reference_v_arc_distance_bounds


def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, subset_seed, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """
    Plot input selectivity features binned over septo-temporal (U) and
    transverse (V) position, plus summed rate maps per module.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float
    :param selectivity_path: str
    :param selectivity_namespace: str; namespaces whose name contains
        '<selectivity_namespace> <arena_id>' are read
    :param subset_seed: int; currently unused by this function
    :param arena_id: str
    :param populations: tuple of str; when empty, defaults to
        ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')
    :param io_size: int; -1 uses all ranks
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plot individual rate maps on rank 0 and stop reading
        each namespace after 11 generator iterations
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_size: figure size stored in the plotting options
    :param colormap: colormap applied to all plots when not None
    :param fig_format: str
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size
    # Set once so that every plot uses it, not only those drawn after the
    # first attribute histogram.
    if colormap is not None:
        fig_options.colormap = colormap

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if not populations:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # File validation happens on rank 0 only; _run_on_root re-raises any
    # failure on every rank so that no rank hangs in the broadcast.
    valid_selectivity_namespaces = _run_on_root(
        comm, _find_valid_selectivity_namespaces, env, populations,
        population_ranges, selectivity_path, selectivity_namespace, arena_id)
    selectivity_type_names = {
        val: key for key, val in viewitems(env.selectivity_types)
    }

    reference_u_arc_distance_bounds, reference_v_arc_distance_bounds = \
        _run_on_root(comm, _read_reference_arc_distance_bounds, populations,
                     population_ranges, coords_population_ranges, coords_path,
                     distances_namespace)

    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        # '%s/%s' keeps the message intact even if config_prefix is None.
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, '%s/%s' % (config_prefix, config)))

    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh, arena_y_mesh = comm.bcast((arena_x_mesh, arena_y_mesh),
                                            root=0)

    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                'Reading %i cell positions for population %s took %.2f s' %
                (len(u_distances_by_gid), population,
                 time.time() - start_time))

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            if rank == 0:
                logger.info('Reading from %s namespace for population %s...' %
                            (this_selectivity_namespace, population))
            gid_count = 0
            # Selectivity type names seen on this rank. They are gathered
            # below so that rank 0 can label the summary plots even if it
            # read no cells from this namespace.
            local_selectivity_type_names = set()
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            start_time = time.time()
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    local_selectivity_type_names.add(
                        this_selectivity_type_name)
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                              selectivity_type_names=selectivity_type_names,
                                              selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = \
                        input_cell_config.gather_attributes()
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = '%s %s cell %i' % (
                            population, this_selectivity_type_name, gid)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title='%s\nModule: %i' %
                            (fig_title, this_module_id),
                            **fig_options())
                        # Release per-cell figures, as the summary plots do.
                        close_figure(fig)
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                if debug and iter_count >= 10:
                    break

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            # defaultdict with a lambda factory cannot be pickled for gather.
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            gathered_selectivity_type_names = comm.gather(
                local_selectivity_type_names, root=0)

            if rank == 0:
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                # Label with every type found in this namespace on any rank,
                # rather than whatever cell rank 0 happened to read last.
                selectivity_type_label = ', '.join(
                    sorted(set().union(*gathered_selectivity_type_names)))
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info('Processing %i %s %s cells took %.2f s' %
                            (gid_count, population, selectivity_type_label,
                             time.time() - start_time))

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, selectivity_type_label, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, selectivity_type_label, key)
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = '%s %s cells %s distribution' % (
                        population, selectivity_type_label, key)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    title = '%s %s cells\n%s distribution' % (
                        population, selectivity_type_label, key)
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    fig_title = '%s %s Module %i summed rate maps' % \
                                (population, selectivity_type_label, this_module_id)
                    if save_fig is not None:
                        fig_options.saveFig = '%s %s' % (save_fig, fig_title)
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title='%s %s summed rate maps\nModule %i' %
                        (population, selectivity_type_label,
                         this_module_id),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
예제 #7
0
def _bcast_from_root(comm, compute):
    """Evaluate ``compute()`` on rank 0 only and broadcast its result to all ranks.

    Metadata validation is performed on rank 0 alone. If it raised on rank 0
    only, every other rank would block forever in the matching ``bcast``. Here
    the failure is broadcast as well: rank 0 re-raises its original exception,
    and every other rank raises a ``RuntimeError`` carrying the same message.

    :param comm: MPI communicator (only ``rank`` and ``bcast`` are used)
    :param compute: zero-argument callable; invoked on rank 0 only
    :return: the value returned by ``compute()`` on rank 0
    """
    result, error, root_exc = None, None, None
    if comm.rank == 0:
        try:
            result = compute()
        except Exception as exc:  # propagated to every rank below
            root_exc = exc
            error = f'{type(exc).__name__}: {exc}'
    result, error = comm.bcast((result, error), root=0)
    if root_exc is not None:
        raise root_exc
    if error is not None:
        raise RuntimeError(
            f'plot_input_selectivity_features: rank 0 failed with {error}')
    return result


def _find_valid_selectivity_namespaces(env, populations, population_ranges,
                                       selectivity_path, selectivity_namespace,
                                       arena_id):
    """Map each population to its selectivity namespaces for ``arena_id``.

    A namespace of ``selectivity_path`` matches when its name contains
    ``'<selectivity_namespace> <arena_id>'``.

    :return: dict population -> list of matching namespace names
    :raises RuntimeError: if a population is absent from ``population_ranges``,
        has no entry in the 'Selectivity Type Probabilities' stimulus config,
        or has no matching namespace
    """
    valid_selectivity_namespaces = dict()
    namespace_key = f'{selectivity_namespace} {arena_id}'
    for population in populations:
        if population not in population_ranges:
            raise RuntimeError(
                f'plot_input_selectivity_features: specified population: {population} not found in '
                f'provided selectivity_path: {selectivity_path}')
        if population not in env.stimulus_config[
                'Selectivity Type Probabilities']:
            raise RuntimeError(
                'plot_input_selectivity_features: selectivity type not specified for '
                f'population: {population}')
        with h5py.File(selectivity_path, 'r') as selectivity_f:
            matching_namespaces = [
                this_namespace
                for this_namespace in selectivity_f['Populations'][population]
                if namespace_key in this_namespace
            ]
        if not matching_namespaces:
            raise RuntimeError(
                f'plot_input_selectivity_features: no selectivity data in arena: {arena_id} found '
                f'for specified population: {population} in provided selectivity_path: {selectivity_path}'
            )
        valid_selectivity_namespaces[population] = matching_namespaces
    return valid_selectivity_namespaces


def _read_reference_arc_distance_bounds(populations, population_ranges,
                                        coords_population_ranges, coords_path,
                                        distances_namespace):
    """Validate cell coordinates and read the reference U/V arc-distance bounds.

    Each population must be present in ``coords_path``. Its distances namespace
    must contain as many unique 'U Distance' cell indexes as the population size
    recorded in ``population_ranges``. The bounds come from the
    'Reference U/V Min/Max' attributes of the first population's distances
    namespace.

    :return: tuple (u_bounds, v_bounds), each a (min, max) pair
    :raises RuntimeError: on a missing population, incomplete coordinates, or
        missing reference-bound attributes
    """
    u_bounds, v_bounds = None, None
    for population in populations:
        if population not in coords_population_ranges:
            raise RuntimeError(
                f'plot_input_selectivity_features: specified population: {population} not found in '
                f'provided coords_path: {coords_path}')
        with h5py.File(coords_path, 'r') as coords_f:
            pop_size = population_ranges[population][1]
            distances_group = coords_f['Populations'][population][
                distances_namespace]
            unique_gid_count = len(
                set(distances_group['U Distance']['Cell Index'][:]))
            if pop_size != unique_gid_count:
                raise RuntimeError(
                    f'plot_input_selectivity_features: only {unique_gid_count}/{pop_size} unique cell indexes found '
                    f'for specified population: {population} in provided coords_path: {coords_path}'
                )
            attrs = distances_group.attrs
            if u_bounds is None:
                try:
                    u_bounds = attrs['Reference U Min'], attrs['Reference U Max']
                except KeyError as err:
                    raise RuntimeError(
                        'plot_input_selectivity_features: problem locating attributes '
                        f'containing reference bounds in namespace: {distances_namespace} '
                        f'for population: {population} from coords_path: {coords_path}'
                    ) from err
            if v_bounds is None:
                try:
                    v_bounds = attrs['Reference V Min'], attrs['Reference V Max']
                except KeyError as err:
                    raise RuntimeError(
                        'plot_input_selectivity_features: problem locating attributes '
                        f'containing reference bounds in namespace: {distances_namespace} '
                        f'for population: {population} from coords_path: {coords_path}'
                    ) from err
    return u_bounds, v_bounds


def _merge_selectivity_type_names(gathered_type_names, default):
    """Combine per-rank sets of selectivity type names into one figure label.

    :param gathered_type_names: iterable of sets (one per rank) of type names
    :param default: label used when no rank saw any cell
    :return: the single name, or '/'-joined sorted names if several were seen
    """
    merged = set()
    for names in gathered_type_names:
        merged.update(names)
    if not merged:
        return default
    return '/'.join(sorted(merged))


def main(config, config_prefix, coords_path, distances_namespace, bin_distance,
         selectivity_path, selectivity_namespace, spatial_resolution, arena_id,
         populations, io_size, cache_size, verbose, debug, show_fig, save_fig,
         save_fig_dir, font_size, fig_size, colormap, fig_format):
    """Plot the spatial distribution of input selectivity features.

    For every population and every selectivity namespace of ``arena_id`` the
    routine does the following:
    - Read the per-cell selectivity attributes in parallel.
    - Bin the cell and component attributes over the septo-temporal (U) and
      transverse (V) arc distances.
    - Merge the histograms and summed rate maps on rank 0.
    - Plot the mean value per bin of each attribute, plus the summed rate map
      of each module.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param bin_distance: float; histogram bin width along U and V
    :param selectivity_path: str (path to file)
    :param selectivity_namespace: str; namespaces containing
        '<selectivity_namespace> <arena_id>' are plotted
    :param spatial_resolution: float or None; None uses the configured
        'Spatial Resolution'
    :param arena_id: str
    :param populations: tuple of str; empty selects a default population list
    :param io_size: int; -1 uses comm.size
    :param cache_size: int
    :param verbose: bool
    :param debug: bool; plots individual cell rate maps on rank 0 and stops
        reading each namespace once iter_count reaches 10
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_size: figure size passed through to the plotting options
    :param colormap: colormap name or None; applied to every figure when given
    :param fig_format: str
    :raises RuntimeError: on every rank, when input validation fails
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig
    fig_options.figSize = fig_size
    # Set once so every figure (including component histograms and rate maps)
    # uses the requested colormap.
    if colormap is not None:
        fig_options.colormap = colormap

    if save_fig is not None:
        save_fig = f'{save_fig} {arena_id}'
    fig_options.saveFig = save_fig

    population_ranges = read_population_ranges(selectivity_path, comm)[0]
    coords_population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    # File validation runs on rank 0; failures are raised on all ranks.
    valid_selectivity_namespaces = _bcast_from_root(
        comm, lambda: _find_valid_selectivity_namespaces(
            env, populations, population_ranges, selectivity_path,
            selectivity_namespace, arena_id))
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    reference_u_arc_distance_bounds, reference_v_arc_distance_bounds = \
        _bcast_from_root(
            comm, lambda: _read_reference_arc_distance_bounds(
                populations, population_ranges, coords_population_ranges,
                coords_path, distances_namespace))

    # The half-bin overshoot makes the upper bound an included edge.
    u_edges = np.arange(reference_u_arc_distance_bounds[0],
                        reference_u_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)
    v_edges = np.arange(reference_v_arc_distance_bounds[0],
                        reference_v_arc_distance_bounds[1] + bin_distance / 2.,
                        bin_distance)

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )

    if spatial_resolution is None:
        spatial_resolution = env.stimulus_config['Spatial Resolution']
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
            get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=spatial_resolution)
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    for population in populations:

        start_time = time.time()
        u_distances_by_gid = dict()
        v_distances_by_gid = dict()
        distances_attr_gen = \
            bcast_cell_attributes(coords_path, population, root=0, namespace=distances_namespace, comm=comm)
        for gid, distances_attr_dict in distances_attr_gen:
            u_distances_by_gid[gid] = distances_attr_dict['U Distance'][0]
            v_distances_by_gid[gid] = distances_attr_dict['V Distance'][0]

        if rank == 0:
            logger.info(
                f'Reading {len(u_distances_by_gid)} cell positions for population {population} took '
                f'{time.time() - start_time:.2f} s')

        for this_selectivity_namespace in valid_selectivity_namespaces[
                population]:
            start_time = time.time()
            if rank == 0:
                logger.info(
                    f'Reading from {this_selectivity_namespace} namespace for population {population}...'
                )
            gid_count = 0
            # Type names seen by this rank; merged on rank 0 to label figures.
            seen_selectivity_type_names = set()
            gathered_cell_attributes = defaultdict(list)
            gathered_component_attributes = defaultdict(list)
            u_distances_by_cell = list()
            v_distances_by_cell = list()
            u_distances_by_component = list()
            v_distances_by_component = list()
            rate_map_sum_by_module = defaultdict(
                lambda: np.zeros_like(arena_x_mesh))
            count_by_module = defaultdict(int)
            x0_list_by_module = defaultdict(list)
            y0_list_by_module = defaultdict(list)
            selectivity_attr_gen = NeuroH5CellAttrGen(
                selectivity_path,
                population,
                namespace=this_selectivity_namespace,
                comm=comm,
                io_size=io_size,
                cache_size=cache_size)
            for iter_count, (
                    gid,
                    selectivity_attr_dict) in enumerate(selectivity_attr_gen):
                if gid is not None:
                    gid_count += 1
                    this_selectivity_type = selectivity_attr_dict[
                        'Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_names[
                        this_selectivity_type]
                    seen_selectivity_type_names.add(this_selectivity_type_name)
                    input_cell_config = \
                        get_input_cell_config(selectivity_type=this_selectivity_type,
                                               selectivity_type_names=selectivity_type_names,
                                               selectivity_attr_dict=selectivity_attr_dict)
                    rate_map = input_cell_config.get_rate_map(x=arena_x_mesh,
                                                              y=arena_y_mesh)
                    u_distances_by_cell.append(u_distances_by_gid[gid])
                    v_distances_by_cell.append(v_distances_by_gid[gid])
                    this_cell_attrs, component_count, this_component_attrs = input_cell_config.gather_attributes(
                    )
                    for attr_name, attr_val in viewitems(this_cell_attrs):
                        gathered_cell_attributes[attr_name].append(attr_val)
                    gathered_cell_attributes['Mean Rate'].append(
                        np.mean(rate_map))
                    if component_count > 0:
                        # Every component of a cell is binned at the cell's position.
                        u_distances_by_component.extend(
                            [u_distances_by_gid[gid]] * component_count)
                        v_distances_by_component.extend(
                            [v_distances_by_gid[gid]] * component_count)
                        for attr_name, attr_val in viewitems(
                                this_component_attrs):
                            gathered_component_attributes[attr_name].extend(
                                attr_val)
                    this_module_id = this_cell_attrs['Module ID']
                    if debug and rank == 0:
                        fig_title = f'{population} {this_selectivity_type_name} cell {gid}'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        fig = plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=rate_map,
                            peak_rate=env.stimulus_config['Peak Rate']
                            [population][this_selectivity_type],
                            title=f'{fig_title}\nModule: {this_module_id}',
                            **fig_options())
                        close_figure(fig)
                    x0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['X Offset'])
                    y0_list_by_module[this_module_id].append(
                        selectivity_attr_dict['Y Offset'])
                    rate_map_sum_by_module[this_module_id] = np.add(
                        rate_map, rate_map_sum_by_module[this_module_id])
                    count_by_module[this_module_id] += 1
                # iter_count advances in lockstep on all ranks, so every rank
                # leaves the collective generator at the same iteration.
                if debug and iter_count >= 10:
                    break

            if rank == 0:
                logger.info(
                    f'Done reading from {this_selectivity_namespace} namespace for population {population}...'
                )

            cell_count_hist, _, _ = np.histogram2d(u_distances_by_cell,
                                                   v_distances_by_cell,
                                                   bins=[u_edges, v_edges])
            component_count_hist, _, _ = np.histogram2d(
                u_distances_by_component,
                v_distances_by_component,
                bins=[u_edges, v_edges])

            if debug:
                context.update(locals())

            # Attribute-weighted histograms; divided by the count histograms at
            # plot time (norm=...) to give the mean value per bin.
            gathered_cell_attr_hist = dict()
            gathered_component_attr_hist = dict()
            for key in gathered_cell_attributes:
                gathered_cell_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_cell, v_distances_by_cell, bins=[u_edges, v_edges],
                                   weights=gathered_cell_attributes[key])
            for key in gathered_component_attributes:
                gathered_component_attr_hist[key], _, _ = \
                    np.histogram2d(u_distances_by_component, v_distances_by_component, bins=[u_edges, v_edges],
                                   weights=gathered_component_attributes[key])
            gid_count = comm.gather(gid_count, root=0)
            gathered_type_names = comm.gather(seen_selectivity_type_names,
                                              root=0)
            cell_count_hist = comm.gather(cell_count_hist, root=0)
            component_count_hist = comm.gather(component_count_hist, root=0)
            gathered_cell_attr_hist = comm.gather(gathered_cell_attr_hist,
                                                  root=0)
            gathered_component_attr_hist = comm.gather(
                gathered_component_attr_hist, root=0)
            x0_list_by_module = dict(x0_list_by_module)
            y0_list_by_module = dict(y0_list_by_module)
            x0_list_by_module = comm.reduce(x0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            y0_list_by_module = comm.reduce(y0_list_by_module,
                                            op=mpi_op_merge_list_dict,
                                            root=0)
            rate_map_sum_by_module = dict(rate_map_sum_by_module)
            rate_map_sum_by_module = comm.gather(rate_map_sum_by_module,
                                                 root=0)
            count_by_module = dict(count_by_module)
            count_by_module = comm.reduce(count_by_module,
                                          op=mpi_op_merge_count_dict,
                                          root=0)

            if rank == 0:
                # Label from the cells read on all ranks: rank 0 may have read
                # none itself, leaving its per-cell name unset or stale.
                this_selectivity_type_name = _merge_selectivity_type_names(
                    gathered_type_names, default=this_selectivity_namespace)
                gid_count = sum(gid_count)
                cell_count_hist = np.sum(cell_count_hist, axis=0)
                component_count_hist = np.sum(component_count_hist, axis=0)
                merged_cell_attr_hist = defaultdict(
                    lambda: np.zeros_like(cell_count_hist))
                merged_component_attr_hist = defaultdict(
                    lambda: np.zeros_like(component_count_hist))
                for each_cell_attr_hist in gathered_cell_attr_hist:
                    for key in each_cell_attr_hist:
                        merged_cell_attr_hist[key] = np.add(
                            merged_cell_attr_hist[key],
                            each_cell_attr_hist[key])
                for each_component_attr_hist in gathered_component_attr_hist:
                    for key in each_component_attr_hist:
                        merged_component_attr_hist[key] = np.add(
                            merged_component_attr_hist[key],
                            each_component_attr_hist[key])
                merged_rate_map_sum_by_module = defaultdict(
                    lambda: np.zeros_like(arena_x_mesh))
                for each_rate_map_sum_by_module in rate_map_sum_by_module:
                    for this_module_id in each_rate_map_sum_by_module:
                        merged_rate_map_sum_by_module[this_module_id] = \
                            np.add(merged_rate_map_sum_by_module[this_module_id],
                                   each_rate_map_sum_by_module[this_module_id])

                logger.info(
                    f'Processing {gid_count} {population} {this_selectivity_type_name} cells '
                    f'took {time.time() - start_time:.2f} s')

                if debug:
                    context.update(locals())

                for key in merged_cell_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_cell_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=cell_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for key in merged_component_attr_hist:
                    fig_title = f'{population} {this_selectivity_type_name} cells {key} distribution'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    title = f'{population} {this_selectivity_type_name} cells\n{key} distribution'
                    fig = plot_2D_histogram(
                        merged_component_attr_hist[key],
                        x_edges=u_edges,
                        y_edges=v_edges,
                        norm=component_count_hist,
                        ylabel='Transverse position (um)',
                        xlabel='Septo-temporal position (um)',
                        title=title,
                        cbar_label='Mean value per bin',
                        cbar=True,
                        **fig_options())
                    close_figure(fig)

                for this_module_id in merged_rate_map_sum_by_module:
                    num_cells = count_by_module[this_module_id]
                    x0 = np.concatenate(x0_list_by_module[this_module_id])
                    y0 = np.concatenate(y0_list_by_module[this_module_id])
                    fig_title = f'{population} {this_selectivity_type_name} Module {this_module_id} rate map'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    fig = plot_2D_rate_map(
                        x=arena_x_mesh,
                        y=arena_y_mesh,
                        x0=x0,
                        y0=y0,
                        rate_map=merged_rate_map_sum_by_module[this_module_id],
                        title=
                        (f'{population} {this_selectivity_type_name} rate map\n'
                         f'Module {this_module_id} ({num_cells} cells)'),
                        **fig_options())
                    close_figure(fig)

    if is_interactive and rank == 0:
        context.update(locals())
def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, use_noise_gen, io_size, chunk_size,
         value_chunk_size, cache_size, write_size, verbose, gather,
         interactive, debug, debug_count, plot, show_fig, save_fig,
         save_fig_dir, font_size, fig_format, dry_run):
    """

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes, close_figure

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            input_file = h5py.File(coords_path, 'r')
            output_file = h5py.File(output_path, 'w')
            input_file.copy('/H5Types', output_file)
            input_file.close()
            output_file.close()
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path))
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            f'Arena with ID: {arena_id} not specified by configuration at file path: {config_prefix}/{config}'
        )
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    pop_norm_distances = {}
    rate_map_sum = {}
    x0_dict = {}
    y0_dict = {}
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                f'Generating input selectivity features for population {population}...'
            )

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        modular = True
        if population in env.stimulus_config[
                'Non-modular Place Selectivity Populations']:
            modular = False

        noise_gen_dict = None
        if use_noise_gen:
            noise_gen_dict = {}
            if modular:
                for module_id in range(env.stimulus_config['Number Modules']):
                    extent_x, extent_y = get_2D_arena_extents(arena)
                    margin = round(
                        selectivity_config.place_module_field_widths[module_id]
                        / 2.)
                    arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                        arena, margin=margin)
                    noise_gen = MPINoiseGenerator(
                        comm=comm,
                        bounds=(arena_x_bounds, arena_y_bounds),
                        tile_rank=comm.rank,
                        bin_size=0.5,
                        mask_fraction=0.99,
                        seed=int(selectivity_seed_offset + module_id * 1e6))
                    noise_gen_dict[module_id] = noise_gen
            else:
                margin = round(
                    np.mean(selectivity_config.place_module_field_widths) / 2.)
                arena_x_bounds, arena_y_bounds = get_2D_arena_bounds(
                    arena, margin=margin)
                noise_gen_dict[-1] = MPINoiseGenerator(
                    comm=comm,
                    bounds=(arena_x_bounds, arena_y_bounds),
                    tile_rank=comm.rank,
                    bin_size=0.5,
                    mask_fraction=0.99,
                    seed=selectivity_seed_offset)

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        this_x0_list = []
        this_y0_list = []
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            req = comm.Ibarrier()
            if gid is None:
                if noise_gen_dict is not None:
                    all_module_ids = [-1]
                    if modular:
                        all_module_ids = comm.allreduce(set([]),
                                                        op=mpi_op_set_union)
                    for module_id in all_module_ids:
                        this_noise_gen = noise_gen_dict[module_id]
                        global_num_fields = this_noise_gen.sync(0)
                        for i in range(global_num_fields):
                            this_noise_gen.add(
                                np.empty(shape=(0, 0), dtype=np.float32), None)
            else:
                if rank == 0:
                    logger.info(
                        f'Rank {rank} generating selectivity features for gid {gid}...'
                    )
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     noise_gen_dict=noise_gen_dict,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                if 'X Offset' in this_selectivity_attr_dict:
                    this_x0_list.append(this_selectivity_attr_dict['X Offset'])
                    this_y0_list.append(this_selectivity_attr_dict['Y Offset'])
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1
            if noise_gen_dict is not None:
                for m in noise_gen_dict:
                    noise_gen_dict[m].tile_rank = (
                        noise_gen_dict[m].tile_rank + 1) % comm.size
            req.wait()

            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == debug_count):
                total_gid_count = 0
                gid_count_dict = dict(gid_count.items())
                req = comm.Ibarrier()
                selectivity_gid_count = comm.reduce(gid_count_dict,
                                                    root=0,
                                                    op=mpi_op_merge_count_dict)
                req.wait()
                if rank == 0:
                    for selectivity_type_name in selectivity_gid_count:
                        total_gid_count += selectivity_gid_count[
                            selectivity_type_name]
                    for selectivity_type_name in selectivity_gid_count:
                        logger.info(
                            'generated selectivity features for %i/%i %s %s cells in %.2f s'
                            % (selectivity_gid_count[selectivity_type_name],
                               total_gid_count, population,
                               selectivity_type_name,
                               (time.time() - start_time)))

                if not dry_run:
                    for selectivity_type_name in sorted(
                            selectivity_attr_dict.keys()):
                        req = comm.Ibarrier()
                        if rank == 0:
                            logger.info(
                                f'writing selectivity features for {population} [{selectivity_type_name}]...'
                            )
                        selectivity_type_namespace = selectivity_type_namespaces[
                            selectivity_type_name]
                        append_cell_attributes(
                            output_path,
                            population,
                            selectivity_attr_dict[selectivity_type_name],
                            namespace=selectivity_type_namespace,
                            comm=comm,
                            io_size=io_size,
                            chunk_size=chunk_size,
                            value_chunk_size=value_chunk_size)
                        req.wait()
                    del selectivity_attr_dict
                    selectivity_attr_dict = dict(
                        (key, dict()) for key in env.selectivity_types)
                    gc.collect()

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = dict(this_rate_map_sum)
        if len(this_x0_list) > 0:
            x0_dict[population] = np.concatenate(this_x0_list, axis=None)
            y0_dict[population] = np.concatenate(this_y0_list, axis=None)

        total_gid_count = 0
        gid_count_dict = dict(gid_count.items())
        req = comm.Ibarrier()
        selectivity_gid_count = comm.reduce(gid_count_dict,
                                            root=0,
                                            op=mpi_op_merge_count_dict)
        req.wait()

        if rank == 0:
            for selectivity_type_name in selectivity_gid_count:
                total_gid_count += selectivity_gid_count[selectivity_type_name]
            for selectivity_type_name in selectivity_gid_count:
                logger.info(
                    'generated selectivity features for %i/%i %s %s cells in %.2f s'
                    % (selectivity_gid_count[selectivity_type_name],
                       total_gid_count, population, selectivity_type_name,
                       (time.time() - start_time)))

        if not dry_run:
            for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
                req = comm.Ibarrier()
                if rank == 0:
                    logger.info(
                        f'writing selectivity features for {population} [{selectivity_type_name}]...'
                    )
                selectivity_type_namespace = selectivity_type_namespaces[
                    selectivity_type_name]
                append_cell_attributes(
                    output_path,
                    population,
                    selectivity_attr_dict[selectivity_type_name],
                    namespace=selectivity_type_namespace,
                    comm=comm,
                    io_size=io_size,
                    chunk_size=chunk_size,
                    value_chunk_size=value_chunk_size)
                req.wait()
            del selectivity_attr_dict
            gc.collect()
        req = comm.Ibarrier()
        req.wait()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        merged_rate_map_sum = comm.reduce(rate_map_sum,
                                          root=0,
                                          op=mpi_op_merge_rate_map_dict)
        merged_x0 = comm.reduce(x0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        merged_y0 = comm.reduce(y0_dict,
                                root=0,
                                op=mpi_op_concatenate_ndarray_dict)
        if rank == 0:
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title(f'Population: {population}')
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure(
                            f'{save_fig} {population} normalized distances histogram',
                            fig=fig,
                            **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        fig_title = f'{population} {this_selectivity_type_name} summed rate maps'
                        if save_fig is not None:
                            fig_options.saveFig = f'{save_fig} {fig_title}'
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title=
                            f'Summed rate maps\n{population} {selectivity_type_name} cells',
                            **fig_options())
                for population in merged_x0:
                    fig_title = f'{population} field offsets'
                    if save_fig is not None:
                        fig_options.saveFig = f'{save_fig} {fig_title}'
                    x0 = merged_x0[population]
                    y0 = merged_y0[population]
                    fig, axes = plt.subplots(1)
                    axes.scatter(x0, y0)
                    if save_fig is not None:
                        save_figure(f'{save_fig} {fig_title}',
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    close_figure(fig)

    if interactive and rank == 0:
        context.update(locals())
예제 #9
0
def _log_selectivity_feature_counts(comm, gid_count, population, start_time):
    """Reduce the per-rank selectivity-type gid counts to rank 0 and log them.

    Every rank of ``comm`` must call this function, because it performs a
    ``comm.reduce``.

    :param comm: MPI communicator
    :param gid_count: mapping of selectivity type name -> number of gids
        processed on this rank
    :param population: str; population name (only used in the log message)
    :param start_time: float; ``time.time()`` value the elapsed time is
        measured from
    :return: the merged count mapping on rank 0, ``None`` on other ranks
    """
    selectivity_gid_count = comm.reduce(dict(gid_count.items()),
                                        root=0,
                                        op=mpi_op_merge_count_dict)
    if comm.rank == 0:
        total_gid_count = sum(selectivity_gid_count[selectivity_type_name]
                              for selectivity_type_name in selectivity_gid_count)
        for selectivity_type_name in selectivity_gid_count:
            logger.info(
                'generated selectivity features for %i/%i %s %s cells in %.2f s'
                % (selectivity_gid_count[selectivity_type_name],
                   total_gid_count, population, selectivity_type_name,
                   (time.time() - start_time)))
    return selectivity_gid_count


def _append_selectivity_features(comm, output_path, population,
                                 selectivity_attr_dict,
                                 selectivity_type_namespaces, io_size,
                                 chunk_size, value_chunk_size):
    """Append accumulated selectivity features to ``output_path``, one namespace per type.

    Selectivity types are written in sorted order. ``append_cell_attributes``
    is given ``comm``, so presumably every rank must call this with the same
    set of selectivity types (NOTE(review): confirm against the neuroh5 API).

    :param comm: MPI communicator
    :param output_path: str (path to output .h5 file)
    :param population: str
    :param selectivity_attr_dict: dict of selectivity type name -> {gid: attribute dict}
    :param selectivity_type_namespaces: dict of selectivity type name -> output namespace
    :param io_size: int
    :param chunk_size: int
    :param value_chunk_size: int
    """
    for selectivity_type_name in sorted(selectivity_attr_dict.keys()):
        if comm.rank == 0:
            logger.info('writing selectivity features for %s [%s]...' %
                        (population, selectivity_type_name))
        append_cell_attributes(
            output_path,
            population,
            selectivity_attr_dict[selectivity_type_name],
            namespace=selectivity_type_namespaces[selectivity_type_name],
            comm=comm,
            io_size=io_size,
            chunk_size=chunk_size,
            value_chunk_size=value_chunk_size)


def main(config, config_prefix, coords_path, distances_namespace, output_path,
         arena_id, populations, io_size, chunk_size, value_chunk_size,
         cache_size, write_size, verbose, gather, interactive, debug, plot,
         show_fig, save_fig, save_fig_dir, font_size, fig_format, dry_run):
    """Generate input selectivity features for the cells of each population.

    For every gid, the U arc distance (normalized to the population's reference
    U bounds) and the V arc distance are passed to
    ``generate_input_selectivity_features``. The resulting attributes are
    appended to ``output_path`` (one namespace per selectivity type) every
    ``write_every`` iterations and once more at the end of each population.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param coords_path: str (path to file)
    :param distances_namespace: str
    :param output_path: str
    :param arena_id: str
    :param populations: tuple of str
    :param io_size: int
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param write_size: int
    :param verbose: bool
    :param gather: bool; whether to gather population attributes to rank 0 for interactive analysis or plotting
    :param interactive: bool
    :param debug: bool; if True, stop after a fixed small number of iterations per population
    :param plot: bool
    :param show_fig: bool
    :param save_fig: str (base file name)
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :param dry_run: bool; if True, nothing is written to output_path
    :raises RuntimeError: (on rank 0) if output_path is missing when not a dry run,
        if a population is not in coords_path, lacks selectivity type
        probabilities, has missing cell indexes, or lacks reference U bounds;
        and on every rank if arena_id is not configured.
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)
    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if save_fig is not None:
        plot = True

    if plot:
        import matplotlib.pyplot as plt
        from dentate.plot import plot_2D_rate_map, default_fig_options, save_figure, clean_axes

        fig_options = copy.copy(default_fig_options)
        fig_options.saveFigDir = save_fig_dir
        fig_options.fontSize = font_size
        fig_options.figFormat = fig_format
        fig_options.showFig = show_fig

    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
        fig_options.saveFig = save_fig

    if not dry_run and rank == 0:
        if output_path is None:
            raise RuntimeError(
                'generate_input_selectivity_features: missing output_path')
        if not os.path.isfile(output_path):
            # Seed the new output file with the H5Types group of coords_path;
            # the context managers close both files even if the copy fails.
            with h5py.File(coords_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()
    population_ranges = read_population_ranges(coords_path, comm)[0]

    if len(populations) == 0:
        populations = sorted(population_ranges.keys())

    reference_u_arc_distance_bounds_dict = {}
    if rank == 0:
        for population in sorted(populations):
            if population not in population_ranges:
                raise RuntimeError(
                    'generate_input_selectivity_features: specified population: %s not found in '
                    'provided coords_path: %s' % (population, coords_path))
            if population not in env.stimulus_config[
                    'Selectivity Type Probabilities']:
                raise RuntimeError(
                    'generate_input_selectivity_features: selectivity type not specified for '
                    'population: %s' % population)
            with h5py.File(coords_path, 'r') as coords_f:
                pop_size = population_ranges[population][1]
                unique_gid_count = len(
                    set(coords_f['Populations'][population]
                        [distances_namespace]['U Distance']['Cell Index'][:]))
                if pop_size != unique_gid_count:
                    raise RuntimeError(
                        'generate_input_selectivity_features: only %i/%i unique cell indexes found '
                        'for specified population: %s in provided coords_path: %s'
                        %
                        (unique_gid_count, pop_size, population, coords_path))
                try:
                    reference_u_arc_distance_bounds_dict[population] = \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Min'], \
                      coords_f['Populations'][population][distances_namespace].attrs['Reference U Max']
                except Exception as e:
                    raise RuntimeError(
                        'generate_input_selectivity_features: problem locating attributes '
                        'containing reference bounds in namespace: %s for population: %s from '
                        'coords_path: %s' %
                        (distances_namespace, population, coords_path)) from e
    comm.barrier()
    reference_u_arc_distance_bounds_dict = comm.bcast(
        reference_u_arc_distance_bounds_dict, root=0)

    # env.selectivity_types maps name -> type id; invert it to id -> name.
    selectivity_type_names = dict([
        (val, key) for (key, val) in viewitems(env.selectivity_types)
    ])
    # Output namespace per type: the type name with its first character
    # upper-cased, followed by ' Selectivity <arena_id>'.
    selectivity_type_namespaces = dict()
    for this_selectivity_type in selectivity_type_names:
        this_selectivity_type_name = selectivity_type_names[
            this_selectivity_type]
        chars = list(this_selectivity_type_name)
        chars[0] = chars[0].upper()
        selectivity_type_namespaces[this_selectivity_type_name] = ''.join(
            chars) + ' Selectivity %s' % arena_id

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))
    arena = env.stimulus_config['Arena'][arena_id]
    arena_x_mesh, arena_y_mesh = None, None
    if rank == 0:
        arena_x_mesh, arena_y_mesh = \
             get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])
    arena_x_mesh = comm.bcast(arena_x_mesh, root=0)
    arena_y_mesh = comm.bcast(arena_y_mesh, root=0)

    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)
    if plot and rank == 0:
        selectivity_config.plot_module_probabilities(**fig_options())

    if (debug or interactive) and rank == 0:
        context.update(dict(locals()))

    # Number of iterations processed per population before stopping in debug mode.
    debug_count = 10
    pop_norm_distances = {}
    rate_map_sum = {}
    write_every = max(1, int(math.floor(write_size / comm.size)))
    for population in sorted(populations):
        if rank == 0:
            logger.info(
                'Generating input selectivity features for population %s...' %
                population)

        reference_u_arc_distance_bounds = reference_u_arc_distance_bounds_dict[
            population]

        this_pop_norm_distances = {}
        this_rate_map_sum = defaultdict(lambda: np.zeros_like(arena_x_mesh))
        start_time = time.time()
        gid_count = defaultdict(lambda: 0)
        distances_attr_gen = NeuroH5CellAttrGen(coords_path,
                                                population,
                                                namespace=distances_namespace,
                                                comm=comm,
                                                io_size=io_size,
                                                cache_size=cache_size)

        selectivity_attr_dict = dict(
            (key, dict()) for key in env.selectivity_types)
        for iter_count, (gid,
                         distances_attr_dict) in enumerate(distances_attr_gen):
            # gid is None on iterations where this rank has no cell to process.
            if gid is not None:
                u_arc_distance = distances_attr_dict['U Distance'][0]
                v_arc_distance = distances_attr_dict['V Distance'][0]
                norm_u_arc_distance = (
                    (u_arc_distance - reference_u_arc_distance_bounds[0]) /
                    (reference_u_arc_distance_bounds[1] -
                     reference_u_arc_distance_bounds[0]))

                this_pop_norm_distances[gid] = norm_u_arc_distance

                this_selectivity_type_name, this_selectivity_attr_dict = \
                 generate_input_selectivity_features(env, population, arena,
                                                     arena_x_mesh, arena_y_mesh,
                                                     gid, (norm_u_arc_distance, v_arc_distance),
                                                     selectivity_config, selectivity_type_names,
                                                     selectivity_type_namespaces,
                                                     rate_map_sum=this_rate_map_sum,
                                                     debug= (debug_callback, context) if debug else False)
                selectivity_attr_dict[this_selectivity_type_name][
                    gid] = this_selectivity_attr_dict
                gid_count[this_selectivity_type_name] += 1

            if (iter_count > 0 and iter_count % write_every
                    == 0) or (debug and iter_count == debug_count):
                selectivity_gid_count = _log_selectivity_feature_counts(
                    comm, gid_count, population, start_time)
                if not dry_run:
                    _append_selectivity_features(
                        comm, output_path, population, selectivity_attr_dict,
                        selectivity_type_namespaces, io_size, chunk_size,
                        value_chunk_size)
                # Start a fresh accumulator so already-flushed gids are not
                # appended a second time.
                selectivity_attr_dict = dict(
                    (key, dict()) for key in env.selectivity_types)

            if debug and iter_count >= debug_count:
                break

        pop_norm_distances[population] = this_pop_norm_distances
        rate_map_sum[population] = this_rate_map_sum

        selectivity_gid_count = _log_selectivity_feature_counts(
            comm, gid_count, population, start_time)

        if not dry_run:
            _append_selectivity_features(comm, output_path, population,
                                         selectivity_attr_dict,
                                         selectivity_type_namespaces, io_size,
                                         chunk_size, value_chunk_size)
        del selectivity_attr_dict
        comm.barrier()

    if gather:
        merged_pop_norm_distances = {}
        for population in sorted(populations):
            merged_pop_norm_distances[population] = \
              comm.reduce(pop_norm_distances[population], root=0,
                          op=mpi_op_merge_dict)
        # The per-population accumulators are defaultdicts with a lambda
        # factory, which pickle (used by mpi4py's lowercase comm.gather)
        # cannot serialize, so convert them to plain dicts first.
        rate_map_sum = dict([(key, dict(val.items()))
                             for key, val in viewitems(rate_map_sum)])
        rate_map_sum = comm.gather(rate_map_sum, root=0)
        if rank == 0:
            merged_rate_map_sum = defaultdict(
                lambda: defaultdict(lambda: np.zeros_like(arena_x_mesh)))
            for each_rate_map_sum in rate_map_sum:
                for population in each_rate_map_sum:
                    for selectivity_type_name in each_rate_map_sum[population]:
                        merged_rate_map_sum[population][selectivity_type_name] = \
                            np.add(merged_rate_map_sum[population][selectivity_type_name],
                                   each_rate_map_sum[population][selectivity_type_name])
            if plot:
                for population in merged_pop_norm_distances:
                    norm_distance_values = np.asarray(
                        list(merged_pop_norm_distances[population].values()))
                    hist, edges = np.histogram(norm_distance_values, bins=100)
                    fig, axes = plt.subplots(1)
                    axes.plot(edges[1:], hist)
                    axes.set_title('Population: %s' % population)
                    axes.set_xlabel('Normalized cell position')
                    axes.set_ylabel('Cell count')
                    clean_axes(axes)
                    if save_fig is not None:
                        save_figure('%s %s normalized distances histogram' %
                                    (save_fig, population),
                                    fig=fig,
                                    **fig_options())
                    if fig_options.showFig:
                        fig.show()
                    # Release the figure so one per population does not accumulate.
                    plt.close(fig)
                for population in merged_rate_map_sum:
                    for selectivity_type_name in merged_rate_map_sum[
                            population]:
                        # Title must use the current loop's type, not a stale
                        # variable left over from feature generation.
                        fig_title = '%s %s summed rate maps' % (
                            population, selectivity_type_name)
                        if save_fig is not None:
                            fig_options.saveFig = '%s %s' % (save_fig,
                                                             fig_title)
                        plot_2D_rate_map(
                            x=arena_x_mesh,
                            y=arena_y_mesh,
                            rate_map=merged_rate_map_sum[population]
                            [selectivity_type_name],
                            title='Summed rate maps\n%s %s cells' %
                            (population, selectivity_type_name),
                            **fig_options())

    if interactive and rank == 0:
        context.update(locals())
예제 #10
0
def main(config, coordinates, gid, field_width, peak_rate, input_features_path, input_features_namespaces,
         output_weights_path, output_features_path, weights_path, h5types_path, synapse_name, initial_weights_namespace,
         structured_weights_namespace, connections_path, destination, sources, arena_id, baseline_weight,
         field_width_scale, max_iter, verbose, dry_run, interactive):
    """Generate structured synaptic weights for a single destination cell.

    A place-selectivity feature set for ``gid`` is built from ``coordinates``,
    ``field_width`` and ``peak_rate``. The input features of the presynaptic
    ``sources`` are read from ``input_features_path``, and
    ``synapses.generate_structured_weights`` combines them with the initial
    synaptic weights (read from ``weights_path`` if given, else 0.0) to produce
    structured weights. Unless ``dry_run`` is set, the weights are appended to
    ``output_weights_path``. If ``output_features_path`` is given, the
    destination's place features are appended there as well.

    :param config: str (path to .yaml file)
    :param coordinates: sequence of (x, y) place field centers
    :param gid: int; destination cell gid
    :param field_width: number; width assigned to every place field
    :param peak_rate: number; peak rate assigned to every place field
    :param input_features_path: str (path to .h5 file with source input features)
    :param input_features_namespaces: list of str; feature namespaces (arena_id is appended)
    :param output_weights_path: str or None; defaults to weights_path
    :param output_features_path: str or None; where to append the destination's features
    :param weights_path: str (path to .h5 file) or None
    :param h5types_path: str or None; source of /H5Types when weights_path is None
    :param synapse_name: str; attribute holding the initial weights
    :param initial_weights_namespace: str
    :param structured_weights_namespace: str (arena_id is appended)
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param arena_id: str
    :param baseline_weight: float
    :param field_width_scale: float
    :param max_iter: currently unused
    :param verbose: bool
    :param dry_run: bool; if True, nothing is written
    :param interactive: bool; forwarded to generate_structured_weights
    :raises RuntimeError: if neither output_weights_path nor weights_path is given,
        if the output file must be created but no /H5Types source is given,
        if the initial weights namespace is missing, or if the connection
        data lacks a 'Synapses'/'syn_id' edge attribute.
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    env = Env(config_file=config)

    if output_weights_path is None:
        if weights_path is None:
            raise RuntimeError('Output weights path must be specified when weights path is not specified.')
        output_weights_path = weights_path

    if not dry_run and not os.path.isfile(output_weights_path):
        if weights_path is not None:
            h5types_source_path = weights_path
        elif h5types_path is not None:
            h5types_source_path = h5types_path
        else:
            raise RuntimeError('h5types input path must be specified when weights path is not specified.')
        # Seed the new output file with the H5Types group; the context
        # managers close both files even if the copy fails.
        with h5py.File(h5types_source_path, 'r') as input_file, \
                h5py.File(output_weights_path, 'w') as output_file:
            input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = ['%s %s' % (input_features_namespace, arena_id)
                                      for input_features_namespace in input_features_namespaces]

    initial_weights_dict = None
    if weights_path is not None:
        logger.info('Reading initial weights data from %s...' % weights_path)
        cell_attributes_dict = read_cell_attribute_selection(weights_path, destination,
                                                             namespaces=[initial_weights_namespace],
                                                             selection=[gid])

        if initial_weights_namespace in cell_attributes_dict:
            initial_weights_iter = cell_attributes_dict[initial_weights_namespace]
            initial_weights_dict = {cell_gid: attr_dict for cell_gid, attr_dict in initial_weights_iter}
        else:
            raise RuntimeError('Initial weights namespace %s was not found in file %s' % (initial_weights_namespace, weights_path))

        logger.info('Rank %i; destination: %s; read synaptic weights for %i cells' %
                    (env.comm.rank, destination, len(initial_weights_dict)))

    features_attr_names = ['Num Fields', 'Field Width', 'Peak Rate', 'X Offset', 'Y Offset', 'Arena Rate Map']

    local_random = np.random.RandomState()

    seed_offset = int(env.model_config['Random Seeds']['GC Structured Weights'])
    local_random.seed(int(gid + seed_offset))

    spatial_resolution = env.stimulus_config['Spatial Resolution'] # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    x, y = stimulus.get_2D_arena_spatial_mesh(arena, spatial_resolution)

    # Gaussian kernel centered on (x_loc, y_loc); the center and widths are
    # excluded from vectorization so only the mesh arguments broadcast.
    plasticity_kernel = lambda x, y, x_loc, y_loc, sx, sy: gauss2d(x-x_loc, y-y_loc, sx=sx, sy=sy)
    plasticity_kernel = np.vectorize(plasticity_kernel, excluded=[2,3,4,5])

    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width]*num_fields, dtype=np.float32)
    this_peak_rate = np.array([peak_rate]*num_fields, dtype=np.float32)
    this_x0 = np.array([x0 for x0, _ in coordinates], dtype=np.float32)
    this_y0 = np.array([y0 for _, y0 in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width, this_peak_rate, x, y),
                               dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel() }

    selection = [gid]
    structured_weights_dict = {}
    # source population -> source gid -> list of (syn_id, initial weight)
    source_syn_dict = defaultdict(lambda: defaultdict(list))
    syn_weight_dict = {}
    if weights_path is not None:
        initial_weights_iter = read_cell_attribute_selection(weights_path, destination,
                                                             namespace=initial_weights_namespace,
                                                             selection=selection)
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[gid]['syn_id']
        weights = syn_weight_attr_dict[gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            syn_weight_dict[int(syn_id)] = float(weight)

        logger.info('destination: %s; gid %i; received synaptic weights for %i synapses' %
                        (destination, gid, len(syn_weight_dict)))

    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[gid],
                                                   namespaces=['Synapses'])
    syn_id_attr_index = None
    for source, edge_iter in viewitems(graph[destination]):
        this_edge_attr_info = edge_attr_info[destination][source]
        if 'Synapses' in this_edge_attr_info and \
           'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == gid
            if syn_id_attr_index is None:
                raise RuntimeError('syn_id attribute not found in Synapses namespace for projection %s -> %s in %s'
                                   % (source, destination, connections_path))
            source_gids, edge_attrs = edges

            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            this_source_syn_dict = source_syn_dict[source]
            count = 0
            for i in range(len(source_gids)):
                this_source_gid = source_gids[i]
                this_syn_id = syn_ids[i]
                # Synapses without an initial weight default to 0.0.
                this_syn_wgt = syn_weight_dict.get(this_syn_id, 0.0)
                this_source_syn_dict[this_source_gid].append((this_syn_id, this_syn_wgt))
                count += 1
            logger.info('destination: %s; gid %i; %d synaptic weights from source population %s' %
                        (destination, gid, count, source))

    src_input_features = defaultdict(dict)
    for source in sources:
        source_gids = list(source_syn_dict[source].keys())
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(input_features_path, source,
                                                                namespace=input_features_namespace,
                                                                mask=set(features_attr_names),
                                                                selection=source_gids)
            this_src_input_features = src_input_features[source]
            count = 0
            # Use a distinct name so the destination `gid` parameter is not rebound.
            for src_gid, attr_dict in input_features_iter:
                this_src_input_features[src_gid] = attr_dict
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' % (input_features_namespace, count, source))

    # Use `gid` rather than the edge-loop variable `destination_gid`, which is
    # unbound when the destination cell has no incoming edges.
    this_syn_weights = \
      synapses.generate_structured_weights(gid,
                                           destination,
                                           synapse_name,
                                           sources,
                                           dst_input_features,
                                           src_input_features,
                                           source_syn_dict,
                                           spatial_mesh=(x,y),
                                           plasticity_kernel=plasticity_kernel,
                                           field_width_scale=field_width_scale,
                                           baseline_weight=baseline_weight,
                                           local_random=local_random,
                                           interactive=interactive)

    assert this_syn_weights is not None
    structured_weights_dict[gid] = this_syn_weights
    logger.info('destination: %s; gid %i; generated structured weights for %i inputs'
                   % (destination, gid, len(this_syn_weights['syn_id'])))
    gc.collect()
    if not dry_run:

        logger.info('Destination: %s; appending structured weights...' % (destination))
        this_structured_weights_namespace = '%s %s' % (structured_weights_namespace, arena_id)
        append_cell_attributes(output_weights_path, destination, structured_weights_dict,
                               namespace=this_structured_weights_namespace)
        logger.info('Destination: %s; appended structured weights' % (destination))
        structured_weights_dict.clear()
        if output_features_path is not None:
            output_features_namespace = 'Place Selectivity %s' % arena_id
            cell_attr_dict = dst_input_features[destination]
            logger.info('Destination: %s; appending features...' % (destination))
            append_cell_attributes(output_features_path, destination,
                                   cell_attr_dict, namespace=output_features_namespace)

        gc.collect()

    del syn_weight_dict
    del src_input_features
    del dst_input_features
def main(config, config_prefix, arena_id, populations, module_ids,
         target_fraction_active, normalize_scale, verbose, interactive, debug,
         plot, show_fig, save_fig, save_fig_dir, font_size, fig_format):
    """
    Calibrate the distribution of the number of place fields per cell for each
    requested population and place-cell module.

    For every population and every module ID, the population's configured
    'Num Place Field Probabilities' are passed, together with the module's
    place field width and the population's 'place' peak rate, to
    calibrate_num_place_field_probabilities. The returned (modified)
    probabilities are logged and printed to stdout in a YAML-like format.

    :param config: str (.yaml file name)
    :param config_prefix: str (path to dir)
    :param arena_id: str; must be a key of env.stimulus_config['Arena']
    :param populations: tuple of str; if empty, defaults to
                        ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')
    :param module_ids: tuple of int; if empty, all module IDs of the input
                       selectivity configuration are used
    :param target_fraction_active: float; forwarded unchanged to
                                   calibrate_num_place_field_probabilities
    :param normalize_scale: bool; whether to interpret the scale of the num_place_field_probabilities distribution
                                    as normalized to the scale of the mean place field width
    :param verbose: bool; passed to config_logging
    :param interactive: bool; if True, all local variables are copied into the
                        module-level ``context`` after setup and again at the end
    :param debug: bool; if True, local variables are copied into ``context`` and
                  the function returns right after calibrating the first module
                  of the first population
    :param plot: bool; if True, matplotlib is imported. Plotting is requested from
                 the calibration routine only when both ``plot`` and
                 ``show_fig`` are True
    :param show_fig: bool
    :param save_fig: str (base file name); when not None, the arena ID is
                     appended to it before it is stored in fig_options
    :param save_fig_dir:  str (path to dir)
    :param font_size: float
    :param fig_format: str
    :raises RuntimeError: if arena_id is not configured; if any requested
                          module ID is not a known module ID; if a population
                          has no 'Num Place Field Probabilities' entry; or if a
                          population has no 'Peak Rate' entry for the 'place'
                          selectivity type.

    NOTE(review): save_fig, save_fig_dir, font_size, fig_format and show_fig
    only populate ``fig_options``. fig_options is not passed to anything in
    this function and is reachable only via the ``context`` export. Confirm
    that this is intended.
    """
    comm = MPI.COMM_WORLD
    # rank is not read below; it is only exposed through the locals() export
    # in interactive/debug mode.
    rank = comm.rank

    config_logging(verbose)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=None)

    if plot:
        # NOTE(review): plt and clean_axes are not referenced below in this
        # function; presumably kept for interactive use — confirm.
        import matplotlib.pyplot as plt
        from dentate.plot import clean_axes

    # Figure options are assembled here but (see docstring note) not consumed
    # by anything in this function.
    fig_options = copy.copy(default_fig_options)
    fig_options.saveFigDir = save_fig_dir
    fig_options.fontSize = font_size
    fig_options.figFormat = fig_format
    fig_options.showFig = show_fig

    # Tag the output figure base name with the arena ID.
    if save_fig is not None:
        save_fig = '%s %s' % (save_fig, arena_id)
    fig_options.saveFig = save_fig

    # Default set of populations when none are given on the command line.
    if len(populations) == 0:
        populations = ('MC', 'ConMC', 'LPP', 'GC', 'MPP', 'CA3c')

    if arena_id not in env.stimulus_config['Arena']:
        raise RuntimeError(
            'Arena with ID: %s not specified by configuration at file path: %s'
            % (arena_id, config_prefix + '/' + config))

    # Inverse mapping of env.selectivity_types (type value -> type name). Not
    # read below; only exposed via the locals() export.
    selectivity_type_names = dict(
        (val, key) for (key, val) in viewitems(env.selectivity_types))

    arena = env.stimulus_config['Arena'][arena_id]
    # The spatial mesh is not read below in this function; it is only exposed
    # via the locals() export.
    arena_x_mesh, arena_y_mesh = \
        get_2D_arena_spatial_mesh(arena=arena, spatial_resolution=env.stimulus_config['Spatial Resolution'])

    # RNG used to build the input selectivity configuration. It is seeded with
    # (seed offset - 1), presumably so that it differs from the per-module
    # calibration seeds (seed offset + module_id) used below. TODO confirm.
    local_random = np.random.RandomState()
    selectivity_seed_offset = int(
        env.model_config['Random Seeds']['Input Selectivity'])
    local_random.seed(selectivity_seed_offset - 1)

    selectivity_config = InputSelectivityConfig(env.stimulus_config,
                                                local_random)

    # Only the 'place' selectivity type is calibrated by this script.
    this_selectivity_type_name = 'place'
    this_selectivity_type = env.selectivity_types['place']

    if interactive:
        context.update(locals())

    # Default to every configured module; otherwise reject unknown module IDs.
    if len(module_ids) == 0:
        module_ids = selectivity_config.module_ids
    elif not all([
            module_id in selectivity_config.module_ids
            for module_id in module_ids
    ]):
        raise RuntimeError(
            'calibrate_DG_num_place_field_probabilities: invalid module_ids provided: %s'
            % str(module_ids))

    for population in populations:

        # Per-population prior distribution over the number of place fields.
        if population not in env.stimulus_config[
                'Num Place Field Probabilities']:
            raise RuntimeError(
                'calibrate_DG_num_place_field_probabilities: probabilities for number of place fields '
                'not specified for population: %s' % population)
        num_place_field_probabilities = env.stimulus_config[
            'Num Place Field Probabilities'][population]

        # Peak firing rate for this population's 'place' selectivity type.
        if population not in env.stimulus_config['Peak Rate'] or \
                this_selectivity_type not in env.stimulus_config['Peak Rate'][population]:
            raise RuntimeError(
                'calibrate_DG_num_place_field_probabilities: peak rate not specified for population: '
                '%s, selectivity type: %s' %
                (population, this_selectivity_type_name))
        peak_rate = env.stimulus_config['Peak Rate'][population][
            this_selectivity_type]

        # NOTE(review): start_time is never read in this function (only
        # exposed via the locals() export).
        start_time = time.time()
        for module_id in module_ids:
            field_width = selectivity_config.place_module_field_widths[
                module_id]
            logger.info(
                'Calibrating distribution of num_place_field_probabilities for population: %s, module: %i, '
                'field width: %.2f' % (population, module_id, field_width))
            # Each module gets its own deterministic seed derived from the
            # configured 'Input Selectivity' seed offset. Plotting is requested
            # only when both plot and show_fig are set.
            modified_num_place_field_probabilities = \
                calibrate_num_place_field_probabilities(num_place_field_probabilities, field_width,
                                                        peak_rate=peak_rate, selectivity_type=this_selectivity_type,
                                                        arena=arena, normalize_scale=normalize_scale,
                                                        selectivity_config=selectivity_config,
                                                        target_fraction_active=target_fraction_active,
                                                        random_seed=selectivity_seed_offset + module_id,
                                                        plot=plot and show_fig)
            logger.info(
                'Modified num_place_field_probabilities for population: %s, module: %i, field width: %.2f'
                % (population, module_id, field_width))
            print_param_dict_like_yaml(modified_num_place_field_probabilities)
            sys.stdout.flush()
            # In debug mode, stop after the first calibrated module and expose
            # all local state for inspection.
            if debug:
                context.update(locals())
                return

    if interactive:
        context.update(locals())