def main(population, forest_path, forest_measurement_namespace, attr_name,
         selection_path):
    """
    Write the per-cell sum of one tree measurement attribute to a CSV file.

    The gids to export are read from ``selection_path`` (one integer per
    line; blank lines are ignored). For every selected cell, ``attr_name`` is
    read from ``forest_measurement_namespace`` and reduced with ``np.sum``.
    The table is written to ``tree.<attr_name>.<population>.csv`` in the
    current working directory, with one row per gid in selection order; gids
    for which no attribute data was returned become NaN rows.

    :param population: str (population name)
    :param forest_path: str (path to the file the measurements are read from)
    :param forest_measurement_namespace: str (attribute namespace to read)
    :param attr_name: str (attribute to sum; also used as the CSV column name)
    :param selection_path: str (path to a text file with one gid per line)
    """
    # The returned ranges are not used below.
    # NOTE(review): the call is kept so that an unreadable forest_path still
    # fails at the same point as before -- confirm whether it can be dropped.
    read_population_ranges(forest_path)

    # Context manager guarantees the handle is closed even if a line does not
    # parse; blank (e.g. trailing) lines are skipped instead of raising.
    with open(selection_path, 'r') as f:
        selection = [int(line) for line in f if line.strip()]

    it = read_cell_attribute_selection(forest_path,
                                       population,
                                       namespace=forest_measurement_namespace,
                                       selection=selection)

    df_dict = {
        cell_gid: [np.sum(meas_dict[attr_name])]
        for cell_gid, meas_dict in it
    }

    df = pd.DataFrame.from_dict(df_dict, orient='index', columns=[attr_name])
    # Restore the order of the selection file.
    df = df.reindex(selection)
    df.to_csv('tree.%s.%s.csv' % (attr_name, population))
Exemplo n.º 2
0
def main(template_path, forest_path, synapses_path, connections_path, config_path):
    """
    Run the single-cell test protocols on gid 1043250 of population "HCC".

    The tree for that gid is read from ``forest_path`` and, together with the
    ``HICAPCell`` hoc template class and an initial voltage of -67, is handed
    to passive_test, ap_test, ap_rate_test and fi_test, in that order. When
    both ``synapses_path`` and ``connections_path`` are given, the
    "Synapse Attributes" and the connections of the same gid are read as well
    and synapse_test is run.

    :param template_path: str (directory containing HICAPCell.hoc)
    :param forest_path: str (path to the file the tree is read from)
    :param synapses_path: str or None (path to the synapse attributes file)
    :param connections_path: str or None (path to the connections file)
    :param config_path: str (configuration file passed to Env)
    """
    world_comm = MPI.COMM_WORLD
    rank = world_comm.Get_rank()

    env = Env(comm=world_comm, config_file=config_path, template_paths=template_path)

    # Set up the hoc interpreter before the cell template is loaded.
    h('objref nil, pc, tlog, Vlog, spikelog')
    h.load_file("nrngui.hoc")
    h.xopen("./tests/rn.hoc")
    h.xopen(template_path+'/HICAPCell.hoc')

    pop_name = "HCC"
    gid = 1043250
    tree_iter, _ = read_tree_selection(forest_path, pop_name, [gid], comm=env.comm)
    _, tree = next(tree_iter)

    v_init = -67
    template_class = h.HICAPCell

    # Protocols that only need the template, the morphology and v_init.
    for run_protocol in (passive_test, ap_test, ap_rate_test, fi_test):
        run_protocol(template_class, tree, v_init)

    # The synapse test needs both optional inputs.
    if not (synapses_path and connections_path):
        return

    synapses_iter = read_cell_attribute_selection(synapses_path, pop_name, [gid],
                                                  "Synapse Attributes", comm=env.comm)
    _, synapses_dict = next(synapses_iter)
    connections = read_graph_selection(file_name=connections_path, selection=[gid],
                                       namespaces=['Synapses', 'Connections'],
                                       comm=env.comm)

    synapse_test(template_class, gid, tree, synapses_dict, connections, v_init, env)
Exemplo n.º 3
0
def main(population, features_path, features_namespace, selection_path):
    """
    Write selected per-cell feature attributes to a CSV file.

    The gids to export are read from ``selection_path`` (one integer per
    line; blank lines are ignored). For every selected cell the first element
    of the 'Field Width', 'X Offset' and 'Y Offset' attributes in
    ``features_namespace`` is collected. The table is written to
    ``features.<population>.csv`` in the current working directory, with one
    row per gid in selection order; gids for which no attribute data was
    returned become NaN rows.

    :param population: str (population name)
    :param features_path: str (path to the file the features are read from)
    :param features_namespace: str (attribute namespace to read)
    :param selection_path: str (path to a text file with one gid per line)
    """
    # The returned ranges are not used below.
    # NOTE(review): the call is kept so that an unreadable features_path still
    # fails at the same point as before -- confirm whether it can be dropped.
    read_population_ranges(features_path)

    # Context manager guarantees the handle is closed even if a line does not
    # parse; blank (e.g. trailing) lines are skipped instead of raising.
    with open(selection_path, 'r') as f:
        selection = [int(line) for line in f if line.strip()]

    columns = ['Field Width', 'X Offset', 'Y Offset']
    it = read_cell_attribute_selection(features_path, population,
                                       namespace=features_namespace,
                                       selection=selection)

    # The attribute names double as CSV column names; only the first element
    # of each attribute is exported.
    df_dict = {
        cell_gid: [features_dict[column][0] for column in columns]
        for cell_gid, features_dict in it
    }

    df = pd.DataFrame.from_dict(df_dict, orient='index', columns=columns)
    # Restore the order of the selection file.
    df = df.reindex(selection)
    df.to_csv('features.%s.csv' % population)
Exemplo n.º 4
0
def read_target_rate_vector(context, eps=1e-2):
    """
    Read the trajectory rate map of ``context.gid`` and resample it onto the
    state time bins.

    Side effects: sets ``context.time_bins`` to
    ``np.arange(0, t_end, dt)`` and ``context.state_time_bins`` to the same
    range without its last element, where ``t_end`` is the smaller of the
    last trajectory time and ``context.init_params['tstop']`` and ``dt`` is
    ``context.env.stimulus_config['Temporal Resolution']``.

    :param context: object providing init_params, env, population, gid, comm
    :param eps: float (rates with absolute value below eps are set to 0)
    :return: array of rates interpolated at context.state_time_bins
    """
    params = context.init_params
    arena_id = params['target_rate_map_arena']
    trajectory_id = params['target_rate_map_trajectory']
    rate_map_path = params['target_rate_map_path']
    namespace_prefix = params.get('target_rate_map_namespace', 'Input Spikes')

    # Only the time axis of the trajectory is needed here.
    _, _, _, trj_t = stimulus.read_trajectory(rate_map_path, arena_id, trajectory_id)

    t_start = 0.
    t_end = min(np.max(trj_t), params['tstop'])
    dt = context.env.stimulus_config['Temporal Resolution']
    context.time_bins = np.arange(t_start, t_end, dt)
    context.state_time_bins = np.arange(t_start, t_end, dt)[:-1]

    input_namespace = '%s %s %s' % (namespace_prefix, arena_id, trajectory_id)
    attr_iter = read_cell_attribute_selection(rate_map_path,
                                              context.population,
                                              namespace=input_namespace,
                                              selection=[context.gid],
                                              mask={'Trajectory Rate Map'},
                                              comm=context.comm)
    _, attr_dict = next(attr_iter)

    rate_vector = np.interp(context.state_time_bins, trj_t,
                            attr_dict['Trajectory Rate Map'])

    # Clamp near-zero rates to exactly zero.
    near_zero = np.abs(rate_vector) < eps
    rate_vector[near_zero] = 0.

    return rate_vector
Exemplo n.º 5
0
def main(template_path, forest_path, synapses_path, config_path):
    """
    Run the single-cell test protocols on gid 1039000 of population "BC".

    The tree is read from ``forest_path`` and the "Synapse Attributes" from
    ``synapses_path``. With the ``BasketCell`` hoc template class and an
    initial voltage of -60, the following are run in order: ap_test,
    passive_test, ap_rate_test, fi_test, gap_junction_test and synapse_test.

    :param template_path: str (directory containing BasketCell.hoc)
    :param forest_path: str (path to the file the tree is read from)
    :param synapses_path: str (path to the synapse attributes file)
    :param config_path: str (configuration file passed to Env)
    """
    world_comm = MPI.COMM_WORLD
    rank = world_comm.Get_rank()

    env = Env(comm=world_comm, config_file=config_path, template_paths=template_path)

    # Set up the hoc interpreter before the cell template is loaded.
    h('objref nil, pc, tlog, Vlog, spikelog')
    h.load_file("nrngui.hoc")
    h.xopen("./tests/rn.hoc")
    h.xopen(template_path + '/BasketCell.hoc')

    pop_name = "BC"
    gid = 1039000

    # Both readers are opened before either iterator is advanced.
    tree_iter, _ = read_tree_selection(forest_path, pop_name, [gid],
                                       comm=env.comm)
    synapses_iter = read_cell_attribute_selection(synapses_path, pop_name,
                                                  [gid], "Synapse Attributes",
                                                  comm=env.comm)
    _, tree = next(tree_iter)
    _, synapses = next(synapses_iter)

    v_init = -60
    template_class = h.BasketCell

    # Protocols that only need the template, the morphology and v_init.
    for run_protocol in (ap_test, passive_test, ap_rate_test, fi_test):
        run_protocol(template_class, tree, v_init)

    gap_junction_test(env, template_class, tree, v_init)
    synapse_test(template_class, gid, tree, synapses, v_init, env)
Exemplo n.º 6
0
def main(population, coords_path, coords_namespace, distances_namespace,
         selection_path):
    """
    Write the coordinates (and optionally distances) of selected cells to CSV.

    The gids to export are read from ``selection_path`` (one integer per
    line; blank lines are ignored). For every selected cell the first element
    of 'U Coordinate', 'V Coordinate' and 'L Coordinate' in
    ``coords_namespace`` is exported as columns 'U', 'V', 'L'. When
    ``distances_namespace`` is not None, the first element of 'U Distance'
    and 'V Distance' from that namespace is appended as two more columns.
    The table is written to ``coords.<population>.csv`` in the current
    working directory, with one row per gid in selection order.

    :param population: str (population name)
    :param coords_path: str (path to the file both namespaces are read from)
    :param coords_namespace: str (namespace holding the coordinates)
    :param distances_namespace: str or None (namespace holding the distances)
    :param selection_path: str (path to a text file with one gid per line)
    """
    # The returned ranges are not used below.
    # NOTE(review): the call is kept so that an unreadable coords_path still
    # fails at the same point as before -- confirm whether it can be dropped.
    read_population_ranges(coords_path)

    # Context manager guarantees the handle is closed even if a line does not
    # parse; blank (e.g. trailing) lines are skipped instead of raising.
    with open(selection_path, 'r') as f:
        selection = [int(line) for line in f if line.strip()]

    columns = ['U', 'V', 'L']
    it = read_cell_attribute_selection(coords_path,
                                       population,
                                       namespace=coords_namespace,
                                       selection=selection)
    df_dict = {
        cell_gid: [coords_dict['U Coordinate'][0],
                   coords_dict['V Coordinate'][0],
                   coords_dict['L Coordinate'][0]]
        for cell_gid, coords_dict in it
    }

    if distances_namespace is not None:
        columns.extend(['U Distance', 'V Distance'])
        it = read_cell_attribute_selection(coords_path,
                                           population,
                                           namespace=distances_namespace,
                                           selection=selection)
        for cell_gid, distances_dict in it:
            # Rows were created from the coordinates pass above.
            df_dict[cell_gid].extend([distances_dict['U Distance'][0],
                                      distances_dict['V Distance'][0]])

    df = pd.DataFrame.from_dict(df_dict, orient='index', columns=columns)
    # Restore the order of the selection file.
    df = df.reindex(selection)
    df.to_csv('coords.%s.csv' % population)
Exemplo n.º 7
0
def main(config_path, template_paths, forest_path, synapses_path):
    """
    Run synapse_test on gid 1000000 of population "MC".

    The tree is read from ``forest_path``. When ``synapses_path`` is given,
    the "Synapse Attributes" of the same gid are read from it and
    synapse_test is run with the ``MossyCell`` hoc template class and an
    initial voltage of -65.0; otherwise no test is run.

    :param config_path: str (configuration file passed to Env)
    :param template_paths: template search path(s) passed to Env
    :param forest_path: str (path to the file the tree is read from)
    :param synapses_path: str or None (path to the synapse attributes file)
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    env = Env(comm=comm, config_file=config_path, template_paths=template_paths)

    neuron_utils.configure_hoc_env(env)
    h.pc = h.ParallelContext()

    v_init = -65.0
    popName = "MC"
    gid = 1000000

    env.load_cell_template(popName)

    tree_iter, _ = read_tree_selection(forest_path, popName, [gid], comm=comm)

    # The synapse reader is opened before the tree iterator is advanced.
    synapses_iter = None
    if synapses_path is not None:
        synapses_iter = read_cell_attribute_selection(synapses_path, popName,
                                                      [gid],
                                                      "Synapse Attributes",
                                                      comm=comm)

    # From here on gid is the one reported by the tree reader.
    gid, tree = next(tree_iter)

    synapses = None
    if synapses_iter is not None:
        _, synapses = next(synapses_iter)

    celltype_config = env.celltypes[popName]
    mech_file_path = None
    if 'mech_file' in celltype_config:
        mech_file_path = env.config_prefix + '/' + celltype_config['mech_file']

    template_class = h.MossyCell

    if synapses is None:
        return
    synapse_test(template_class, mech_file_path, gid, tree, synapses, v_init,
                 env)
Exemplo n.º 8
0
def main(arena_id, config, config_prefix, dataset_prefix, distances_namespace, spike_input_path, spike_input_namespace, spike_input_attr, input_features_namespaces, input_features_path, selection_path, output_path, io_size, trajectory_id, verbose):
    """
    Write the data belonging to a selection of cells to a selection file.

    The gids are read from ``selection_path`` (one integer per line; blank
    lines are ignored). Rank 0 groups them by population, keeping for each
    population only the selected gids that have an entry in
    ``distances_namespace`` of ``env.data_file_path``; populations without
    any entry in that namespace are left out. The grouping is broadcast to
    all ranks and the cell selection, the connection selection and, if
    requested, the input cells and input features are written to
    ``<env.results_path>/<env.modelName>_selection.h5``.

    :param arena_id: str (passed to Env; also appended to each input
        features namespace)
    :param config: str (configuration file passed to Env)
    :param config_prefix: str (passed to Env)
    :param dataset_prefix: str (passed to Env)
    :param distances_namespace: str (namespace holding the cell distances)
    :param spike_input_path: str or None (passed to Env; input cell
        selection is written only when set)
    :param spike_input_namespace: str (passed to Env)
    :param spike_input_attr: str (passed to Env)
    :param input_features_namespaces: list of str
    :param input_features_path: str or None (input features are copied only
        when set)
    :param selection_path: str (path to a text file with one gid per line)
    :param output_path: str (passed to Env as results_path)
    :param io_size: int (-1 means the size of MPI.COMM_WORLD)
    :param trajectory_id: str (passed to Env)
    :param verbose: bool
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank
    if io_size == -1:
        io_size = comm.size

    env = Env(comm=comm, config_file=config,
              config_prefix=config_prefix, dataset_prefix=dataset_prefix,
              results_path=output_path, spike_input_path=spike_input_path,
              spike_input_namespace=spike_input_namespace, spike_input_attr=spike_input_attr,
              arena_id=arena_id, trajectory_id=trajectory_id, io_size=io_size)

    # Context manager guarantees the handle is closed even if a line does not
    # parse; blank (e.g. trailing) lines are skipped instead of raising.
    with open(selection_path, 'r') as f:
        selection = {int(line) for line in f if line.strip()}

    pop_ranges, _ = read_population_ranges(env.connectivity_file_path, comm=comm)

    selection_dict = defaultdict(set)

    # Rank 0 gets a communicator of its own so that it can read the
    # distances without the participation of the other ranks.
    comm0 = env.comm.Split(2 if rank == 0 else 0, 0)

    if rank == 0:
        for population in pop_ranges:
            distances = read_cell_attributes(env.data_file_path, population, namespace=distances_namespace, comm=comm0)
            # Only the gids are needed; the distance values themselves are
            # not used for the selection.
            gids_with_distances = [gid for (gid, _) in distances]
            del distances

            if not gids_with_distances:
                continue

            selection_dict[population] = {gid for gid in gids_with_distances
                                          if gid in selection}

    env.comm.barrier()

    write_selection_file_path =  "%s/%s_selection.h5" % (env.results_path, env.modelName)

    if rank == 0:
        io_utils.mkout(env, write_selection_file_path)
    env.comm.barrier()
    # Convert to a plain dict before broadcasting the grouping to all ranks.
    selection_dict = env.comm.bcast(dict(selection_dict), root=0)
    env.cell_selection = selection_dict
    io_utils.write_cell_selection(env, write_selection_file_path)
    input_selection = io_utils.write_connection_selection(env, write_selection_file_path)
    if spike_input_path:
        io_utils.write_input_cell_selection(env, input_selection, write_selection_file_path)
    if input_features_path:
        for this_input_features_namespace in sorted(input_features_namespaces):
            for population in sorted(input_selection):
                logger.info(f"Extracting input features {this_input_features_namespace} for population {population}...")
                it = read_cell_attribute_selection(input_features_path, population,
                                                   namespace=f"{this_input_features_namespace} {arena_id}",
                                                   selection=input_selection[population], comm=env.comm)
                output_features_dict = { cell_gid : cell_features_dict for cell_gid, cell_features_dict in it }
                append_cell_attributes(write_selection_file_path, population, output_features_dict,
                                       namespace=f"{this_input_features_namespace} {arena_id}",
                                       io_size=io_size, comm=env.comm)
    env.comm.barrier()
def main(config, coordinates, gid, field_width, peak_rate, input_features_path,
         input_features_namespaces, output_features_namespace,
         output_weights_path, output_features_path, initial_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, reference_weights_namespace,
         output_weights_namespace, reference_weights_are_delta,
         connections_path, optimize_method, destination, sources, arena_id,
         max_delta_weight, field_width_scale, max_iter, verbose, dry_run,
         plot):
    """
    Generate structured synaptic weights for a single target cell.

    A target rate map is built from the place fields given by ``coordinates``,
    ``field_width`` (scaled by ``field_width_scale``) and ``peak_rate``. The
    initial (and optionally reference) weights of the target cell, its
    incoming connections from ``sources`` and the 'Arena Rate Map' of the
    connected source cells are read and passed to
    ``synapses.generate_structured_weights``. Unless ``dry_run`` is set, the
    resulting weights are appended to ``output_weights_path`` and, if
    ``output_features_path`` is given, the features of the target cell are
    appended to that file.

    :param config: str (path to .yaml file)
    :param coordinates: tuple of float
    :param gid: int
    :param field_width: float
    :param peak_rate: float
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: str
    :param output_features_namespace: str
    :param output_weights_path: str (path to .h5 file)
    :param output_features_path: str (path to .h5 file)
    :param initial_weights_path: str (path to .h5 file)
    :param reference_weights_path: str (path to .h5 file)
    :param h5types_path: str (path to .h5 file)
    :param synapse_name: str
    :param initial_weights_namespace: str
    :param reference_weights_namespace: str
    :param output_weights_namespace: str
    :param reference_weights_are_delta: bool
    :param connections_path: str (path to .h5 file)
    :param optimize_method: passed through to
        synapses.generate_structured_weights
    :param destination: str (population name)
    :param sources: list of str (population name)
    :param arena_id: str
    :param max_delta_weight: float
    :param field_width_scale: float
    :param max_iter: int (not used by this function)
    :param verbose: bool
    :param dry_run: bool
    :param plot: bool
    :raises RuntimeError: if gid is None, or if output weights are to be
        written and neither an output path nor a source for /H5Types exists
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    # Everything below is keyed by the target gid; it used to be referenced
    # as `target_gid` without ever being bound.
    if gid is None:
        raise RuntimeError('Missing required argument: gid.')
    target_gid = gid

    env = Env(config_file=config)

    if not dry_run:
        if output_weights_path is None:
            raise RuntimeError(
                'Missing required argument: output_weights_path.')
        if not os.path.isfile(output_weights_path):
            # A new output file needs the /H5Types group; copy it from the
            # initial weights file if there is one, else from h5types_path.
            if initial_weights_path is not None and os.path.isfile(
                    initial_weights_path):
                input_file_path = initial_weights_path
            elif h5types_path is not None and os.path.isfile(h5types_path):
                input_file_path = h5types_path
            else:
                raise RuntimeError(
                    'Missing required source for h5types: either an initial_weights_path or an '
                    'h5types_path must be provided.')
            with h5py.File(output_weights_path, 'a') as output_file:
                with h5py.File(input_file_path, 'r') as input_file:
                    input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]
    features_attr_names = ['Arena Rate Map']
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm
    arena = env.stimulus_config['Arena'][arena_id]
    # Not used below, but kept because all locals are exported through
    # context.update(locals()) in interactive sessions.
    default_run_vel = arena.properties['default run velocity']  # cm/s
    arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
        arena, spatial_resolution)
    dim_x = len(arena_x)
    dim_y = len(arena_y)

    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width] * num_fields, dtype=np.float32)
    this_scaled_field_width = np.array([field_width * field_width_scale] *
                                       num_fields,
                                       dtype=np.float32)
    this_peak_rate = np.array([peak_rate] * num_fields, dtype=np.float32)
    this_x0 = np.array([x for x, y in coordinates], dtype=np.float32)
    this_y0 = np.array([y for x, y in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width,
                                            this_peak_rate, arena_x, arena_y),
                               dtype=np.float32)
    # The optimisation target uses the scaled field width.
    target_map = np.asarray(get_rate_map(this_x0, this_y0,
                                         this_scaled_field_width,
                                         this_peak_rate, arena_x, arena_y),
                            dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][target_gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel()
    }

    initial_weights_by_syn_id_dict = dict()
    selection = [target_gid]
    if initial_weights_path is not None:
        initial_weights_iter = \
            read_cell_attribute_selection(initial_weights_path, destination, namespace=initial_weights_namespace,
                                          selection=selection)
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            initial_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read initial synaptic weights for %i synapses'
            % (destination, target_gid, len(initial_weights_by_syn_id_dict)))

    reference_weights_by_syn_id_dict = None
    if reference_weights_path is not None:
        reference_weights_by_syn_id_dict = dict()
        reference_weights_iter = \
            read_cell_attribute_selection(reference_weights_path, destination, namespace=reference_weights_namespace,
                                          selection=selection)
        syn_weight_attr_dict = dict(reference_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            reference_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read reference synaptic weights for %i synapses'
            % (destination, target_gid, len(reference_weights_by_syn_id_dict)))

    source_gid_set_dict = defaultdict(set)
    syn_ids_by_source_gid_dict = defaultdict(list)
    initial_weights_by_source_gid_dict = dict()
    if reference_weights_by_syn_id_dict is None:
        reference_weights_by_source_gid_dict = None
    else:
        reference_weights_by_source_gid_dict = dict()
    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[target_gid],
                                                   namespaces=['Synapses'])
    syn_id_attr_index = None
    for source, edge_iter in viewitems(graph[destination]):
        if source not in sources:
            continue
        this_edge_attr_info = edge_attr_info[destination][source]
        if 'Synapses' in this_edge_attr_info and \
           'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == target_gid
            source_gids, edge_attrs = edges
            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            count = 0
            for i, raw_source_gid in enumerate(source_gids):
                this_source_gid = int(raw_source_gid)
                source_gid_set_dict[source].add(this_source_gid)
                this_syn_id = int(syn_ids[i])
                if this_syn_id not in initial_weights_by_syn_id_dict:
                    # No initial weight on file for this synapse: fall back
                    # to the configured default weight of the projection.
                    this_weight = \
                        env.connection_config[destination][source].mechanisms['default'][synapse_name]['weight']
                    initial_weights_by_syn_id_dict[this_syn_id] = this_weight
                syn_ids_by_source_gid_dict[this_source_gid].append(this_syn_id)
                # The weight of the first synapse seen from a source cell
                # represents that source cell.
                if this_source_gid not in initial_weights_by_source_gid_dict:
                    initial_weights_by_source_gid_dict[this_source_gid] = \
                        initial_weights_by_syn_id_dict[this_syn_id]
                    if reference_weights_by_source_gid_dict is not None:
                        reference_weights_by_source_gid_dict[this_source_gid] = \
                            reference_weights_by_syn_id_dict[this_syn_id]
                count += 1
            logger.info(
                'destination: %s; gid %i; set initial synaptic weights for %d inputs from source population '
                '%s' % (destination, destination_gid, count, source))

    syn_count_by_source_gid_dict = dict()
    for source_gid in syn_ids_by_source_gid_dict:
        syn_count_by_source_gid_dict[source_gid] = len(
            syn_ids_by_source_gid_dict[source_gid])

    input_rate_maps_by_source_gid_dict = dict()
    for source in sources:
        source_gids = list(source_gid_set_dict[source])
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                source,
                namespace=input_features_namespace,
                mask=set(features_attr_names),
                selection=source_gids)
            count = 0
            # The loop variable must not reuse the name of the `gid`
            # parameter.
            for source_gid, attr_dict in input_features_iter:
                input_rate_maps_by_source_gid_dict[source_gid] = attr_dict[
                    'Arena Rate Map'].reshape((dim_x, dim_y))
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

    if is_interactive:
        context.update(locals())

    normalized_delta_weights_dict, arena_LS_map = \
        synapses.generate_structured_weights(target_map=target_map,
                                             initial_weight_dict=initial_weights_by_source_gid_dict,
                                             input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                             syn_count_dict=syn_count_by_source_gid_dict,
                                             max_delta_weight=max_delta_weight, arena_x=arena_x, arena_y=arena_y,
                                             reference_weight_dict=reference_weights_by_source_gid_dict,
                                             reference_weights_are_delta=reference_weights_are_delta,
                                             reference_weights_namespace=reference_weights_namespace,
                                             optimize_method=optimize_method, verbose=verbose, plot=plot)

    # Collect the output in lists so that the arrays hold exactly the
    # synapses that received a weight. Pre-sizing them from
    # initial_weights_by_syn_id_dict left uninitialised entries whenever that
    # dict also covered synapses from populations outside `sources`.
    output_syn_id_list = []
    output_weight_list = []
    for source_gid, this_weight in viewitems(normalized_delta_weights_dict):
        for syn_id in syn_ids_by_source_gid_dict[source_gid]:
            output_syn_id_list.append(syn_id)
            output_weight_list.append(this_weight)
    output_syn_ids = np.asarray(output_syn_id_list, dtype='uint32')
    output_weights = np.asarray(output_weight_list, dtype='float32')
    output_weights_dict = {
        target_gid: {
            'syn_id': output_syn_ids,
            synapse_name: output_weights
        }
    }

    logger.info('destination: %s; gid %i; generated %s for %i synapses' %
                (destination, target_gid, output_weights_namespace,
                 len(output_weights)))

    if not dry_run:
        this_output_weights_namespace = '%s %s' % (output_weights_namespace,
                                                   arena_id)
        logger.info('Destination: %s; appending %s ...' %
                    (destination, this_output_weights_namespace))
        append_cell_attributes(output_weights_path,
                               destination,
                               output_weights_dict,
                               namespace=this_output_weights_namespace)
        logger.info('Destination: %s; appended %s' %
                    (destination, this_output_weights_namespace))
        output_weights_dict.clear()
        if output_features_path is not None:
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            cell_attr_dict = dst_input_features[destination]
            cell_attr_dict[target_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            logger.info('Destination: %s; appending %s ...' %
                        (destination, this_output_features_namespace))
            append_cell_attributes(output_features_path,
                                   destination,
                                   cell_attr_dict,
                                   namespace=this_output_features_namespace)

    if is_interactive:
        context.update(locals())
Exemplo n.º 10
0
def _split_state_trials(tvals, svals, n_trials):
    """
    Split one recording into per-trial time and state arrays.

    A trial starts at every sample whose time value is within 1e-4 of the
    first time value of the recording.

    :param tvals: 1-D array of time values
    :param svals: 1-D array of state values, aligned with tvals
    :param n_trials: int (maximum number of trials to return; -1 for all)
    :return: tuple (list of time arrays, list of state arrays), one array
        per returned trial
    """
    trial_bounds = list(np.where(np.isclose(tvals, tvals[0], atol=1e-4))[0])
    n_trial_bounds = len(trial_bounds)
    # Closing bound of the last trial.
    trial_bounds.append(len(tvals))
    if n_trials == -1:
        this_n_trials = n_trial_bounds
    else:
        this_n_trials = min(n_trial_bounds, n_trials)

    if this_n_trials > 1:
        # Split at the end of each requested trial and drop what follows, so
        # that exactly this_n_trials single-trial arrays are returned.
        split_points = trial_bounds[1:this_n_trials + 1]
        return (np.split(tvals, split_points)[:this_n_trials],
                np.split(svals, split_points)[:this_n_trials])

    return [tvals[:trial_bounds[1]]], [svals[:trial_bounds[1]]]


def read_state(input_file,
               population_names,
               namespace_id,
               time_variable='t',
               state_variable='v',
               time_range=None,
               max_units=None,
               gid=None,
               comm=None,
               n_trials=-1):
    """
    Read recorded state traces of the given populations, split into trials.

    :param input_file: str (path to the file holding the recordings)
    :param population_names: list of str (population names)
    :param namespace_id: str (namespace holding the recordings)
    :param time_variable: str (name of the time attribute)
    :param state_variable: str (name of the state attribute)
    :param time_range: None or (start, end); when given, only samples with
        start <= t <= end are kept
    :param max_units: None or int; when gid is None and a population has
        more recorded units than this, only a random sample of max_units
        distinct units is returned
    :param gid: None or iterable of int (restrict reading to these gids)
    :param comm: MPI communicator (defaults to MPI.COMM_WORLD)
    :param n_trials: int (maximum number of trials per unit; -1 for all)
    :return: dict with keys 'states', 'time_variable' and 'state_variable';
        'states' maps population name -> gid -> (list of time arrays, list
        of state arrays, distance, section, loc)
    :raises RuntimeError: if state_variable is not recorded for a population
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    pop_state_dict = {}

    logger.info(
        'Reading state data from populations %s, namespace %s gid = %s...' %
        (str(population_names), namespace_id, str(gid)))

    attr_info_dict = read_cell_attribute_info(input_file,
                                              populations=population_names,
                                              read_cell_index=True)

    for pop_name in population_names:
        cell_index = None
        pop_state_dict[pop_name] = {}
        for attr_name, attr_cell_index in attr_info_dict[pop_name][
                namespace_id]:
            if state_variable == attr_name:
                cell_index = attr_cell_index

        if cell_index is None:
            raise RuntimeError(
                'read_state: Unable to find recordings for state variable %s in population %s namespace %s'
                % (state_variable, pop_name, str(namespace_id)))
        cell_set = set(cell_index)

        # Limit to max_units
        sampled_gids = None
        if gid is None:
            if (max_units is not None) and (len(cell_set) > max_units):
                logger.info(
                    '  Reading only randomly sampled %i out of %i units for population %s'
                    % (max_units, len(cell_set), pop_name))
                # Draw distinct units; every unit, including the last one,
                # is eligible.
                # NOTE(review): with a multi-rank comm each rank draws its
                # own sample unless the random state is synchronised.
                cell_set_lst = sorted(cell_set)
                sample_inds = np.random.choice(len(cell_set_lst),
                                               size=int(max_units),
                                               replace=False)
                gid_set = set([cell_set_lst[i] for i in sample_inds])
                sampled_gids = gid_set
            else:
                gid_set = cell_set
        else:
            gid_set = set(gid)

        state_dict = {}
        if gid is None:
            valiter = read_cell_attributes(input_file,
                                           pop_name,
                                           namespace=namespace_id,
                                           comm=comm)
        else:
            valiter = read_cell_attribute_selection(input_file,
                                                    pop_name,
                                                    namespace=namespace_id,
                                                    selection=list(gid_set),
                                                    comm=comm)

        for cellind, vals in valiter:
            if cellind is None:
                continue
            # Apply the max_units sample; without this the full population
            # was returned regardless of the sample drawn above.
            if sampled_gids is not None and cellind not in sampled_gids:
                continue

            distance = vals.get('distance', [None])[0]
            section = vals.get('section', [None])[0]
            loc = vals.get('loc', [None])[0]

            if time_range is None:
                tvals = np.asarray(vals[time_variable], dtype=np.float32)
                svals = np.asarray(vals[state_variable], dtype=np.float32)
            else:
                # Keep only the samples inside the requested time window.
                tinds = np.flatnonzero(
                    (vals[time_variable] <= time_range[1])
                    & (vals[time_variable] >= time_range[0]))
                tvals = np.asarray(vals[time_variable][tinds],
                                   dtype=np.float32).reshape((-1, ))
                svals = np.asarray(vals[state_variable][tinds],
                                   dtype=np.float32).reshape((-1, ))

            trial_tvals, trial_svals = _split_state_trials(
                tvals, svals, n_trials)
            state_dict[cellind] = (trial_tvals, trial_svals, distance,
                                   section, loc)

        pop_state_dict[pop_name] = state_dict

    return {
        'states': pop_state_dict,
        'time_variable': time_variable,
        'state_variable': state_variable
    }
Exemplo n.º 11
0
def main(config_file, config_prefix, input_path, population, template_paths,
         dataset_prefix, results_path, results_file_id, results_namespace_id,
         v_init, io_size, chunk_size, value_chunk_size, write_size, verbose):
    """
    Run the cell clamp measurements on every cell of a population and
    append the results to the results file.

    For each gid produced by `NeuroH5TreeGen`, the morphology, the
    'Synapse Attributes' namespace, the optional weights namespaces and the
    connectivity selection are gathered into a `cell_dict`, which is passed to
    `cell_clamp.measure_passive`, `measure_ap`, `measure_ap_rate` and
    `measure_fi`. The returned attribute dictionaries are merged per gid and
    written with `append_cell_attributes`.

    :param config_file: forwarded to `Env` through `locals()`
    :param config_prefix: forwarded to `Env` through `locals()`
    :param input_path: optional path used for both cell and connectivity
        data; when None the paths configured in `env` are used
    :param population: name of the population to measure
    :param template_paths: forwarded to `Env` through `locals()`
    :param dataset_prefix: forwarded to `Env` through `locals()`
    :param results_path: when None, no results are written
    :param results_file_id: identifier of the results file; when None a UUID
        is generated on rank 0 and broadcast to all ranks
    :param results_namespace_id: results namespace; defaults to
        'Cell Clamp Results'
    :param v_init: value passed to every `cell_clamp.measure_*` call
    :param io_size: number of I/O ranks; -1 means all ranks
    :param chunk_size: passed to `append_cell_attributes`
    :param value_chunk_size: passed to `append_cell_attributes`
    :param write_size: results are flushed every `write_size` iterations
    :param verbose: passed to `utils.config_logging`
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if io_size == -1:
        io_size = comm.size

    if results_file_id is None:
        # Generate the identifier once, on rank 0, and share it so that all
        # ranks agree on the results file. The value must be stored in
        # `results_file_id` itself, otherwise None would be broadcast.
        if rank == 0:
            results_file_id = uuid.uuid4()
        results_file_id = comm.bcast(results_file_id, root=0)
    if results_namespace_id is None:
        results_namespace_id = 'Cell Clamp Results'
    np.seterr(all='raise')
    # NOTE(review): Env always receives verbose=True regardless of the
    # command-line value; kept as in the original -- confirm this is intended.
    verbose = True
    # Every local defined up to this point is forwarded to Env as a keyword
    # argument, so no additional locals may be introduced above this line.
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)
    if rank == 0:
        io_utils.mkout(env, env.results_file_path)
    env.comm.barrier()
    env.cell_selection = {}
    template_class = load_cell_template(env, population)

    if input_path is not None:
        env.data_file_path = input_path
        env.load_celltypes()

    synapse_config = env.celltypes[population]['synapses']

    # Resolve which weights namespaces (if any) must be read for each cell.
    has_weights = synapse_config.get('weights', False)
    weights_namespaces = []
    if has_weights:
        if 'weights namespace' in synapse_config:
            weights_namespaces.append(synapse_config['weights namespace'])
        elif 'weights namespaces' in synapse_config:
            weights_namespaces.extend(synapse_config['weights namespaces'])
        else:
            weights_namespaces.append('Weights')

    if input_path is None:
        cell_path = env.data_file_path
        connectivity_path = env.connectivity_file_path
    else:
        cell_path = input_path
        connectivity_path = input_path

    measurements = (cell_clamp.measure_passive, cell_clamp.measure_ap,
                    cell_clamp.measure_ap_rate, cell_clamp.measure_fi)

    count = 0
    gid_count = 0
    attr_dict = {}
    for gid, morph_dict in NeuroH5TreeGen(cell_path,
                                          population,
                                          io_size=io_size,
                                          comm=env.comm,
                                          topology=True):
        # Every rank takes part in the split on each iteration; ranks that
        # received a cell (color 0) share the communicator used for the
        # per-cell selection reads below.
        color = 0 if gid is not None else 1
        comm0 = comm.Split(color, 0)

        if gid is None:
            logger.info('Rank %i gid is None' % (rank))
        else:
            logger.info('Rank %i gid: %i' % (rank, gid))
            cell_dict = {'morph': morph_dict}
            synapses_iter = read_cell_attribute_selection(cell_path,
                                                          population, [gid],
                                                          'Synapse Attributes',
                                                          comm=comm0)
            _, synapse_dict = next(synapses_iter)
            cell_dict['synapse'] = synapse_dict

            if has_weights:
                # Maps each weights namespace to the iterator returned by the
                # selection read for that namespace.
                cell_dict['weight'] = {
                    weights_namespace:
                    read_cell_attribute_selection(cell_path,
                                                  population, [gid],
                                                  weights_namespace,
                                                  comm=comm0)
                    for weights_namespace in weights_namespaces
                }

            (graph,
             a) = read_graph_selection(file_name=connectivity_path,
                                       selection=[gid],
                                       namespaces=['Synapses', 'Connections'],
                                       comm=comm0)
            cell_dict['connectivity'] = (graph, a)

            gid_count += 1

            attr_dict[gid] = {}
            for measure in measurements:
                attr_dict[gid].update(
                    measure(gid, population, v_init, env,
                            cell_dict=cell_dict))

        comm0.Free()

        count += 1
        if (results_path is not None) and (count % write_size == 0):
            append_cell_attributes(env.results_file_path,
                                   population,
                                   attr_dict,
                                   namespace=env.results_namespace_id,
                                   comm=env.comm,
                                   io_size=env.io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)
            attr_dict = {}

    env.comm.barrier()
    if results_path is not None:
        # Flush whatever has accumulated since the last periodic write.
        append_cell_attributes(env.results_file_path,
                               population,
                               attr_dict,
                               namespace=env.results_namespace_id,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
    # Collective call: every rank must take part even though the gathered
    # counts are not used afterwards.
    global_count = env.comm.gather(gid_count, root=0)

    MPI.Finalize()
Exemplo n.º 12
0
def write_input_cell_selection(env,
                               input_sources,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Writes out predefined spike trains when only a subset of the network is instantiated.

    For every population in `input_sources`, the source gids that are not
    part of `env.cell_selection` are gathered from all ranks, assigned to
    ranks by `gid % nhosts == rank`, read from the available spike input
    locations and written to `write_selection_file_path` under the namespace
    `env.spike_input_ns`.

    :param env: an instance of the `dentate.Env` class
    :param input_sources: a dictionary of the form { pop_name, gid_sources }
    :param write_selection_file_path: path of the file to write to
    :param populations: optional list of population names to restrict the
        output to; defaults to all populations in `env.celltypes`
    :param write_kwds: optional extra keyword arguments for
        `write_cell_attributes`; 'comm' and 'io_size' default to the values
        held by `env`. The dictionary passed in is not modified.
    """

    # Work on a private copy: filling in defaults on a shared (default)
    # dictionary would leak one call's communicator into later calls.
    write_kwds = dict(write_kwds) if write_kwds is not None else {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    input_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.celltypes.keys())
    else:
        pop_names = populations

    for pop_name, gid_range in sorted(viewitems(input_sources)):

        if pop_name not in pop_names:
            continue

        spikes_output_dict = {}

        if (env.cell_selection is not None) and (pop_name
                                                 in env.cell_selection):
            local_gid_range = gid_range.difference(
                set(env.cell_selection[pop_name]))
        else:
            local_gid_range = gid_range

        # Pool the source gids known to every rank, then keep the share that
        # belongs to this rank.
        this_gid_range = {
            gid
            for rank_gid_range in env.comm.allgather(local_gid_range)
            for gid in rank_gid_range
            if gid % nhosts == rank
        }

        # Spike trains may live in the dedicated spike input file and/or in
        # the cell attribute file; collect every location that has them.
        has_spike_train = False
        spike_input_source_loc = []
        if (env.spike_input_attribute_info is not None) and (env.spike_input_ns
                                                             is not None):
            if (pop_name in env.spike_input_attribute_info) and \
                    (env.spike_input_ns in env.spike_input_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (env.spike_input_path, env.spike_input_ns))
        if (env.cell_attribute_info is not None) and (env.spike_input_ns
                                                      is not None):
            if (pop_name in env.cell_attribute_info) and \
                    (env.spike_input_ns in env.cell_attribute_info[pop_name]):
                has_spike_train = True
                spike_input_source_loc.append(
                    (input_file_path, env.spike_input_ns))

        if rank == 0:
            logger.info(
                '*** Reading spike trains for population %s: %d cells: has_spike_train = %s'
                % (pop_name, len(this_gid_range), str(has_spike_train)))

        if has_spike_train:

            vecstim_attr_set = set(['t'])
            if env.spike_input_attr is not None:
                vecstim_attr_set.add(env.spike_input_attr)
            if 'spike train' in env.celltypes[pop_name]:
                vecstim_attr_set.add(
                    env.celltypes[pop_name]['spike train']['attribute'])

            selection = list(this_gid_range)
            cell_spikes_iters = [
                read_cell_attribute_selection(input_path,
                                              pop_name,
                                              selection,
                                              namespace=input_ns,
                                              mask=vecstim_attr_set,
                                              comm=env.comm)
                for (input_path, input_ns) in spike_input_source_loc
            ]

            # Later locations override earlier ones for the same gid.
            for cell_spikes_iter in cell_spikes_iters:
                spikes_output_dict.update(dict(list(cell_spikes_iter)))

        if rank == 0:
            logger.info('*** Writing spike trains for population %s: %s' %
                        (pop_name, str(spikes_output_dict)))

        # Called on every rank even when nothing was read for this
        # population.
        write_cell_attributes(write_selection_file_path,
                              pop_name,
                              spikes_output_dict,
                              namespace=env.spike_input_ns,
                              **write_kwds)
Exemplo n.º 13
0
def write_connection_selection(env,
                               write_selection_file_path,
                               populations=None,
                               write_kwds=None):
    """
    Loads NeuroH5 connectivity file, and writes the corresponding
    synapse and network connection mechanisms for the selected postsynaptic cells.

    For each postsynaptic population in `env.projection_dict` that is part
    of the selection, this writes the 'Synapse Attributes' namespace, any
    configured weights namespaces, and the 'Synapses'/'Connections' edge
    attributes ('syn_id' and 'distance') of every incoming projection.

    :param env: an instance of the `dentate.Env` class
    :param write_selection_file_path: path of the file to write to
    :param populations: optional list of postsynaptic population names;
        defaults to the populations in `env.cell_selection`
    :param write_kwds: optional extra keyword arguments for
        `write_cell_attributes`; 'comm' and 'io_size' default to the values
        held by `env`. The dictionary passed in is not modified.
    :return: dictionary mapping each population name in `env.celltypes` to
        the set of presynaptic gids encountered in the written projections
    :raises RuntimeError: if the edge attribute info of a projection is
        missing, or lacks the 'syn_id' or 'distance' attribute
    """

    # Work on a private copy: filling in defaults on a shared (default)
    # dictionary would leak one call's communicator into later calls.
    write_kwds = dict(write_kwds) if write_kwds is not None else {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    connectivity_file_path = env.connectivity_file_path
    forest_file_path = env.forest_file_path
    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    input_sources = {pop_name: set([]) for pop_name in env.celltypes}

    for (postsyn_name, presyn_names) in sorted(viewitems(env.projection_dict)):

        if rank == 0:
            logger.info('*** Writing connection selection of population %s' %
                        (postsyn_name))

        if postsyn_name not in pop_names:
            continue

        # Share of the selected postsynaptic cells handled by this rank.
        gid_range = [
            gid for gid in env.cell_selection[postsyn_name]
            if gid % nhosts == rank
        ]

        synapse_config = env.celltypes[postsyn_name]['synapses']

        has_weights = synapse_config.get('weights', False)
        weights_namespaces = []
        if has_weights:
            if 'weights namespace' in synapse_config:
                weights_namespaces.append(synapse_config['weights namespace'])
            elif 'weights namespaces' in synapse_config:
                weights_namespaces.extend(
                    synapse_config['weights namespaces'])
            else:
                weights_namespaces.append('Weights')

        if rank == 0:
            logger.info('*** Reading synaptic attributes of population %s' %
                        (postsyn_name))

        syn_attributes_iter = read_cell_attribute_selection(
            forest_file_path,
            postsyn_name,
            selection=gid_range,
            namespace='Synapse Attributes',
            comm=env.comm)

        syn_attributes_output_dict = dict(list(syn_attributes_iter))
        write_cell_attributes(write_selection_file_path,
                              postsyn_name,
                              syn_attributes_output_dict,
                              namespace='Synapse Attributes',
                              **write_kwds)
        # Release the attribute data before the next large read.
        del syn_attributes_output_dict
        del syn_attributes_iter

        if has_weights:
            for weights_namespace in sorted(weights_namespaces):
                weight_attributes_iter = read_cell_attribute_selection(
                    forest_file_path,
                    postsyn_name,
                    selection=gid_range,
                    namespace=weights_namespace,
                    comm=env.comm)
                weight_attributes_output_dict = dict(
                    list(weight_attributes_iter))
                write_cell_attributes(write_selection_file_path,
                                      postsyn_name,
                                      weight_attributes_output_dict,
                                      namespace=weights_namespace,
                                      **write_kwds)
                del weight_attributes_output_dict
                del weight_attributes_iter

        logger.info(
            '*** Rank %i: reading connectivity selection from file %s for postsynaptic population: %s: selection: %s'
            % (rank, connectivity_file_path, postsyn_name, str(gid_range)))

        sorted_presyn_names = sorted(presyn_names)
        (graph, attr_info) = read_graph_selection(
            connectivity_file_path,
            selection=gid_range,
            projections=[(presyn_name, postsyn_name)
                         for presyn_name in sorted_presyn_names],
            comm=env.comm,
            namespaces=['Synapses', 'Connections'])

        for presyn_name in sorted_presyn_names:
            gid_dict = {}
            edge_count = 0
            node_count = 0
            if postsyn_name in graph:

                if postsyn_name in attr_info and presyn_name in attr_info[
                        postsyn_name]:
                    edge_attr_info = attr_info[postsyn_name][presyn_name]
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                       (presyn_name, postsyn_name))

                if 'Synapses' in edge_attr_info and \
                        'syn_id' in edge_attr_info['Synapses'] and \
                        'Connections' in edge_attr_info and \
                        'distance' in edge_attr_info['Connections']:
                    syn_id_attr_index = edge_attr_info['Synapses']['syn_id']
                    distance_attr_index = edge_attr_info['Connections'][
                        'distance']
                else:
                    raise RuntimeError('write_connection_selection: missing edge attributes for projection %s -> %s' % \
                                           (presyn_name, postsyn_name))

                # The composed iterator records the presynaptic gids of each
                # edge set in input_sources as the edges are consumed below.
                edge_iter = compose_iter(lambda edgeset: input_sources[presyn_name].update(edgeset[1][0]), \
                                         graph[postsyn_name][presyn_name])
                for (postsyn_gid, edges) in edge_iter:

                    presyn_gids, edge_attrs = edges
                    edge_syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
                    edge_dists = edge_attrs['Connections'][distance_attr_index]

                    gid_dict[postsyn_gid] = (presyn_gids, {
                        'Synapses': {
                            'syn_id': edge_syn_ids
                        },
                        'Connections': {
                            'distance': edge_dists
                        }
                    })
                    edge_count += len(presyn_gids)
                    node_count += 1

            logger.info(
                '*** Rank %d: Writing projection %s -> %s selection: %d nodes, %d edges'
                % (rank, presyn_name, postsyn_name, node_count, edge_count))
            # Called on every rank, even when this rank has no edges for the
            # projection.
            write_graph(write_selection_file_path, \
                        src_pop_name=presyn_name, dst_pop_name=postsyn_name, \
                        edges=gid_dict, comm=env.comm, io_size=env.io_size)
        env.comm.barrier()

    return input_sources
Exemplo n.º 14
0
def write_cell_selection(env,
                         write_selection_file_path,
                         populations=None,
                         write_kwds=None):
    """
    Writes out the data necessary to instantiate the selected cells.

    For each population, the selected gids assigned to this rank
    (`gid % nhosts == rank`) are read either as trees (when the population
    has a 'Trees' namespace) or as 'Coordinates' attributes, and written to
    `write_selection_file_path`.

    :param env: an instance of the `dentate.Env` class
    :param write_selection_file_path: path of the file to write to
    :param populations: optional list of population names; defaults to the
        populations in `env.cell_selection`
    :param write_kwds: optional extra keyword arguments for
        `append_cell_trees` and `write_cell_attributes`; 'comm' and
        'io_size' default to the values held by `env`. The dictionary passed
        in is not modified.
    """

    # Work on a private copy: filling in defaults on a shared (default)
    # dictionary would leak one call's communicator into later calls.
    write_kwds = dict(write_kwds) if write_kwds is not None else {}
    if 'comm' not in write_kwds:
        write_kwds['comm'] = env.comm
    if 'io_size' not in write_kwds:
        write_kwds['io_size'] = env.io_size

    rank = int(env.comm.Get_rank())
    nhosts = int(env.comm.Get_size())

    data_file_path = env.data_file_path

    if populations is None:
        pop_names = sorted(env.cell_selection.keys())
    else:
        pop_names = populations

    for pop_name in pop_names:

        gid_range = [
            gid for gid in env.cell_selection[pop_name] if gid % nhosts == rank
        ]

        trees_output_dict = {}
        coords_output_dict = {}
        pop_has_attributes = pop_name in env.cell_attribute_info

        if pop_has_attributes and (
                'Trees' in env.cell_attribute_info[pop_name]):
            if rank == 0:
                logger.info("*** Reading trees for population %s" % pop_name)

            cell_tree_iter, _ = read_tree_selection(data_file_path,
                                                    pop_name,
                                                    selection=gid_range,
                                                    topology=False,
                                                    comm=env.comm)
            if rank == 0:
                logger.info("*** Done reading trees for population %s" %
                            pop_name)

            for gid, tree in cell_tree_iter:
                trees_output_dict[gid] = tree

            # Every requested gid must have produced exactly one tree.
            assert (len(trees_output_dict) == len(gid_range))

        elif pop_has_attributes and (
                'Coordinates' in env.cell_attribute_info[pop_name]):
            if rank == 0:
                logger.info("*** Reading coordinates for population %s" %
                            pop_name)

            cell_attributes_iter = read_cell_attribute_selection(
                data_file_path,
                pop_name,
                selection=gid_range,
                namespace='Coordinates',
                comm=env.comm)

            if rank == 0:
                logger.info("*** Done reading coordinates for population %s" %
                            pop_name)

            for gid, coords in cell_attributes_iter:
                coords_output_dict[gid] = coords

        if rank == 0:
            logger.info(
                "*** Writing cell selection for population %s to file %s" %
                (pop_name, write_selection_file_path))
        # Both writers are called on every rank for every population, with
        # an empty dictionary for the kind of data that was not read.
        append_cell_trees(write_selection_file_path,
                          pop_name,
                          trees_output_dict,
                          create_index=True,
                          **write_kwds)
        write_cell_attributes(write_selection_file_path,
                              pop_name,
                              coords_output_dict,
                              namespace='Coordinates',
                              **write_kwds)
Exemplo n.º 15
0
def main(config, coordinates, gid, field_width, peak_rate, input_features_path, input_features_namespaces,
         output_weights_path, output_features_path, weights_path, h5types_path, synapse_name, initial_weights_namespace,
         structured_weights_namespace, connections_path, destination, sources, arena_id, baseline_weight,
         field_width_scale, max_iter, verbose, dry_run, interactive):
    """
    Generate structured synaptic weights for a single destination cell.

    A place-field rate map is built for cell `gid` from `coordinates`,
    `field_width` and `peak_rate`; the incoming synapses and their initial
    weights are read, the input features of the source cells are loaded, and
    `synapses.generate_structured_weights` is called. Unless `dry_run` is
    set, the resulting weights (and optionally the destination features) are
    appended to the output files.

    :param config: str (path to .yaml file)
    :param coordinates: sequence of (x, y) place field centers
    :param gid: int; gid of the destination cell
    :param field_width: place field width applied to every field
    :param peak_rate: peak rate applied to every field
    :param input_features_path: str; file holding the source features
    :param input_features_namespaces: list of str; each is suffixed with
        `arena_id` to form the namespace that is read
    :param output_weights_path: str; defaults to `weights_path`
    :param output_features_path: str or None; when given, the destination
        features are appended to this file
    :param weights_path: str (path to .h5 file)
    :param h5types_path: str; source of '/H5Types' when `weights_path` is
        not given and the output file does not exist yet
    :param synapse_name: str; name of the weight attribute to read
    :param initial_weights_namespace: str
    :param structured_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param arena_id: str; arena key in the stimulus configuration
    :param baseline_weight: passed to `generate_structured_weights`
    :param field_width_scale: passed to `generate_structured_weights`
    :param max_iter: not used by this function
    :param verbose:
    :param dry_run:
    :param interactive: passed to `generate_structured_weights`
    :return:
    :raises RuntimeError: if required paths are missing, the initial weights
        namespace is not found, or a projection lacks the 'syn_id' edge
        attribute
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    env = Env(config_file=config)

    if output_weights_path is None:
        if weights_path is None:
            raise RuntimeError('Output weights path must be specified when weights path is not specified.')
        output_weights_path = weights_path

    if (not dry_run) and (not os.path.isfile(output_weights_path)):
        # A new output file needs the '/H5Types' group copied from an
        # existing file before attributes can be appended to it.
        if weights_path is not None:
            h5types_source_path = weights_path
        elif h5types_path is not None:
            h5types_source_path = h5types_path
        else:
            raise RuntimeError('h5types input path must be specified when weights path is not specified.')
        with h5py.File(h5types_source_path, 'r') as input_file:
            with h5py.File(output_weights_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = ['%s %s' % (input_features_namespace, arena_id)
                                      for input_features_namespace in input_features_namespaces]

    initial_weights_dict = None
    if weights_path is not None:
        logger.info('Reading initial weights data from %s...' % weights_path)
        cell_attributes_dict = read_cell_attribute_selection(weights_path, destination,
                                                             namespaces=[initial_weights_namespace],
                                                             selection=[gid])

        if initial_weights_namespace in cell_attributes_dict:
            initial_weights_iter = cell_attributes_dict[initial_weights_namespace]
            initial_weights_dict = {cell_gid: attr_dict for cell_gid, attr_dict in initial_weights_iter}
        else:
            raise RuntimeError('Initial weights namespace %s was not found in file %s' % (initial_weights_namespace, weights_path))

        logger.info('Rank %i; destination: %s; read synaptic weights for %i cells' %
                    (env.comm.rank, destination, len(initial_weights_dict)))

    features_attr_names = ['Num Fields', 'Field Width', 'Peak Rate', 'X Offset', 'Y Offset', 'Arena Rate Map']

    # Seed a private generator from the gid so results are reproducible per cell.
    local_random = np.random.RandomState()
    seed_offset = int(env.model_config['Random Seeds']['GC Structured Weights'])
    local_random.seed(int(gid + seed_offset))

    spatial_resolution = env.stimulus_config['Spatial Resolution'] # cm

    arena = env.stimulus_config['Arena'][arena_id]

    x, y = stimulus.get_2D_arena_spatial_mesh(arena, spatial_resolution)

    plasticity_kernel = lambda x, y, x_loc, y_loc, sx, sy: gauss2d(x-x_loc, y-y_loc, sx=sx, sy=sy)
    plasticity_kernel = np.vectorize(plasticity_kernel, excluded=[2,3,4,5])

    # Describe the place fields of the destination cell. The comprehension
    # variables are deliberately not called x/y so that the arena mesh above
    # can never be overwritten by them.
    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width]*num_fields, dtype=np.float32)
    this_peak_rate = np.array([peak_rate]*num_fields, dtype=np.float32)
    this_x0 = np.array([field_x for field_x, _ in coordinates], dtype=np.float32)
    this_y0 = np.array([field_y for _, field_y in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width, this_peak_rate, x, y),
                               dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel() }

    structured_weights_dict = {}
    source_syn_dict = defaultdict(lambda: defaultdict(list))
    syn_weight_dict = {}
    if weights_path is not None:
        initial_weights_iter = read_cell_attribute_selection(weights_path, destination,
                                                             namespace=initial_weights_namespace,
                                                             selection=[gid])
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[gid]['syn_id']
        weights = syn_weight_attr_dict[gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            syn_weight_dict[int(syn_id)] = float(weight)

        logger.info('destination: %s; gid %i; received synaptic weights for %i synapses' %
                        (destination, gid, len(syn_weight_dict)))

    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[gid],
                                                   namespaces=['Synapses'])
    for source, edge_iter in viewitems(graph[destination]):
        this_edge_attr_info = edge_attr_info[destination][source]
        # Resolve the index per projection so that an index found for one
        # source is never silently reused for another.
        if 'Synapses' in this_edge_attr_info and \
           'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']
        else:
            raise RuntimeError('missing syn_id edge attribute for projection %s -> %s' %
                               (source, destination))
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == gid
            source_gids, edge_attrs = edges

            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            this_source_syn_dict = source_syn_dict[source]
            count = 0
            for this_source_gid, this_syn_id in zip(source_gids, syn_ids):
                # Synapses without an initial weight default to 0.0.
                this_syn_wgt = syn_weight_dict.get(this_syn_id, 0.0)
                this_source_syn_dict[this_source_gid].append((this_syn_id, this_syn_wgt))
                count += 1
            logger.info('destination: %s; gid %i; %d synaptic weights from source population %s' %
                        (destination, gid, count, source))

    src_input_features = defaultdict(dict)
    for source in sources:
        source_gids = list(source_syn_dict[source].keys())
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(input_features_path, source,
                                                                namespace=input_features_namespace,
                                                                mask=set(features_attr_names),
                                                                selection=source_gids)
            this_src_input_features = src_input_features[source]
            count = 0
            # The loop variable must not be called `gid`: that would
            # overwrite the destination gid used below.
            for source_gid, attr_dict in input_features_iter:
                this_src_input_features[source_gid] = attr_dict
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' % (input_features_namespace, count, source))

    # `gid` is used directly (instead of a variable bound inside the edge
    # loop) so this also works when no edges were read for the cell.
    this_syn_weights = \
      synapses.generate_structured_weights(gid,
                                           destination,
                                           synapse_name,
                                           sources,
                                           dst_input_features,
                                           src_input_features,
                                           source_syn_dict,
                                           spatial_mesh=(x,y),
                                           plasticity_kernel=plasticity_kernel,
                                           field_width_scale=field_width_scale,
                                           baseline_weight=baseline_weight,
                                           local_random=local_random,
                                           interactive=interactive)

    assert this_syn_weights is not None
    structured_weights_dict[gid] = this_syn_weights
    logger.info('destination: %s; gid %i; generated structured weights for %i inputs'
                   % (destination, gid, len(this_syn_weights['syn_id'])))
    gc.collect()
    if not dry_run:

        logger.info('Destination: %s; appending structured weights...' % (destination))
        this_structured_weights_namespace = '%s %s' % (structured_weights_namespace, arena_id)
        append_cell_attributes(output_weights_path, destination, structured_weights_dict,
                               namespace=this_structured_weights_namespace)
        logger.info('Destination: %s; appended structured weights' % (destination))
        structured_weights_dict.clear()
        if output_features_path is not None:
            output_features_namespace = 'Place Selectivity %s' % arena_id
            cell_attr_dict = dst_input_features[destination]
            logger.info('Destination: %s; appending features...' % (destination))
            append_cell_attributes(output_features_path, destination,
                                   cell_attr_dict, namespace=output_features_namespace)

        gc.collect()
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_delta_weight, max_opt_iter, max_weight_decay_fraction,
         optimize_method, optimize_tol, optimize_grad, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig):
    """
    Generates structured (place field) synaptic weights for the cells of the
    destination population and appends them to a NeuroH5 file.

    For each destination gid yielded by the projection generators, the
    target place-field rate map is built from the cell's input features,
    the initial / non-structured / reference weights and the rate maps of
    its presynaptic cells are read, and
    synapses.generate_structured_weights is called. The resulting LTP and
    LTD weights are stored in the namespaces
    'LTP <output_weights_namespace> <arena_id>' and
    'LTD <output_weights_namespace> <arena_id>'.

    :param config: str (path to .yaml file)
    :param coordinates: indexable pair; when coordinates[0] is not None the
        target cell is given a single field at
        (coordinates[0], coordinates[1])
    :param field_width: optional override of 'Field Width' for every target
        field
    :param gid: sequence of destination gids; when non-empty, only these
        gids are processed
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str; each one is suffixed
        with arena_id before reading
    :param initial_weights_path: str (path to .h5 file) or None
    :param output_features_namespace: str or None; defaults to
        'Place Selectivity'; arena_id is appended on write
    :param output_features_path: str (path to .h5 file) or None; no
        features are written when None
    :param output_weights_path: str (path to .h5 file); created on rank 0
        when it does not exist, copying '/H5Types' from
        initial_weights_path or else from h5types_path
    :param reference_weights_path: str (path to .h5 file) or None
    :param h5types_path: str (path to .h5 file) or None
    :param synapse_name: str; name of the attribute holding the weights in
        the weights namespaces
    :param initial_weights_namespace: str
    :param output_weights_namespace: str
    :param reference_weights_namespace: str
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param non_structured_sources: list of str; their synapses are not
        counted as structured synapses
    :param non_structured_weights_namespace: str
    :param non_structured_weights_path: str (path to .h5 file) or None
    :param arena_id: str; key into env.stimulus_config['Arena']
    :param field_width_scale: passed as scale to the target rate map
    :param max_delta_weight: passed to generate_structured_weights
    :param max_opt_iter: passed to generate_structured_weights
    :param max_weight_decay_fraction: passed to generate_structured_weights
    :param optimize_method: passed to generate_structured_weights
    :param optimize_tol: passed to generate_structured_weights
    :param optimize_grad: passed to generate_structured_weights
    :param peak_rate: optional override of 'Peak Rate' for every target
        field
    :param reference_weights_are_delta: passed to
        generate_structured_weights
    :param arena_margin: the arena mesh margin is this value (clamped to
        >= 0) times the largest target field width
    :param target_amplitude: passed to generate_structured_weights
    :param io_size: number of I/O ranks; -1 selects comm.size
    :param chunk_size: passed to append_cell_attributes
    :param value_chunk_size: passed to append_cell_attributes
    :param cache_size: not used by this function
    :param write_size: outputs are flushed whenever
        iter_count % write_size == 0
    :param verbose:
    :param dry_run: when true, nothing is written
    :param plot: passed to generate_structured_weights
    :param show_fig: passed to generate_structured_weights; forced to True
        when plot is set and neither save_fig nor show_fig is
    :param save_fig: directory prefix for the saved figure, or None
    :return:
    :raises RuntimeError: on rank 0 when the output file must be created
        but neither initial_weights_path nor h5types_path is given
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%s: %i ranks have been allocated' % (__file__, comm.size))

    env = Env(comm=comm, config_file=config, io_size=io_size)

    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None:
                h5types_source_path = initial_weights_path
            elif h5types_path is not None:
                h5types_source_path = h5types_path
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            # Context managers guarantee both files are closed even if the
            # copy fails.
            with h5py.File(h5types_source_path, 'r') as input_file, \
                    h5py.File(output_weights_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    env.comm.barrier()

    LTD_output_weights_namespace = 'LTD %s %s' % (output_weights_namespace,
                                                  arena_id)
    LTP_output_weights_namespace = 'LTP %s %s' % (output_weights_namespace,
                                                  arena_id)
    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]

    selectivity_type_index = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    features_attrs = defaultdict(dict)
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    # Resolve the default features namespace once, so that the periodic
    # flush inside the loop and the final flush always write to the same
    # namespace.
    if output_features_namespace is None:
        output_features_namespace = '%s Selectivity' % target_selectivity_type_name.title(
        )

    local_random = np.random.RandomState()

    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)

    all_sources = sources + non_structured_sources

    connection_gen_list = [ NeuroH5ProjectionGen(connections_path, source, destination, namespaces=['Synapses'], comm=comm) \
                               for source in all_sources ]

    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    for iter_count, attr_gen_package in enumerate(
            zip_longest(*connection_gen_list)):

        local_time = time.time()
        this_gid = attr_gen_package[0][0]
        if not all([
                attr_gen_items[0] == this_gid
                for attr_gen_items in attr_gen_package
        ]):
            raise Exception(
                'Rank: %i; destination: %s; this_gid not matched across multiple attribute '
                'generators: %s' %
                (rank, destination,
                 [attr_gen_items[0] for attr_gen_items in attr_gen_package]))

        # NOTE(review): this skips the collective append/barrier calls at
        # the end of the iteration; presumably gid selection is only used
        # where that cannot desynchronize ranks -- confirm.
        if (target_gid_set is not None) and (this_gid not in target_gid_set):
            continue

        if this_gid is None:
            selection = []
            logger.info('Rank: %i received None' % rank)
        else:
            selection = [this_gid]
            local_random.seed(int(this_gid + seed_offset))

        has_structured_weights = False

        dst_input_features_attr_dict = {}
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                destination,
                namespace=input_features_namespace,
                mask=set(target_features_attr_names),
                comm=env.comm,
                selection=selection)
            count = 0
            for gid, attr_dict in input_features_iter:
                dst_input_features_attr_dict[gid] = attr_dict
                count += 1
            if rank == 0:
                logger.info(
                    'Read %s feature data for %i cells in population %s' %
                    (input_features_namespace, count, destination))

        arena_margin_size = 0.
        arena_margin = max(arena_margin, 0.)
        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}
        for gid in selection:
            target_selectivity_features_dict[
                gid] = dst_input_features_attr_dict.get(gid, {})
            target_selectivity_features_dict[gid][
                'Selectivity Type'] = np.asarray([target_selectivity_type],
                                                 dtype=np.uint8)

            num_fields = target_selectivity_features_dict[gid]['Num Fields'][0]

            # Explicit coordinates replace the stored fields by a single
            # field at the requested location.
            if coordinates[0] is not None:
                num_fields = 1
                target_selectivity_features_dict[gid]['X Offset'] = np.asarray(
                    [coordinates[0]], dtype=np.float32)
                target_selectivity_features_dict[gid]['Y Offset'] = np.asarray(
                    [coordinates[1]], dtype=np.float32)
                target_selectivity_features_dict[gid][
                    'Num Fields'] = np.asarray([num_fields], dtype=np.uint8)

            if field_width is not None:
                target_selectivity_features_dict[gid][
                    'Field Width'] = np.asarray([field_width] * num_fields,
                                                dtype=np.float32)
            else:
                this_field_width = target_selectivity_features_dict[gid][
                    'Field Width']
                target_selectivity_features_dict[gid][
                    'Field Width'] = this_field_width[:num_fields]

            if peak_rate is not None:
                target_selectivity_features_dict[gid][
                    'Peak Rate'] = np.asarray([peak_rate] * num_fields,
                                              dtype=np.float32)

            input_cell_config = stimulus.get_input_cell_config(
                target_selectivity_type,
                selectivity_type_index,
                selectivity_attr_dict=target_selectivity_features_dict[gid])
            if input_cell_config.num_fields > 0:
                arena_margin_size = max(
                    arena_margin_size,
                    np.max(input_cell_config.field_width) * arena_margin)
                target_field_width_dict[gid] = input_cell_config.field_width
                target_selectivity_config_dict[gid] = input_cell_config
                has_structured_weights = True

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)
        for gid, input_cell_config in viewitems(
                target_selectivity_config_dict):
            target_map = np.asarray(input_cell_config.get_rate_map(
                arena_x, arena_y, scale=field_width_scale),
                                    dtype=np.float32)
            target_selectivity_features_dict[gid][
                'Arena Rate Map'] = target_map

        # Cells without place fields still take part in the reads below,
        # but with an empty selection.
        if not has_structured_weights:
            selection = []

        initial_weights_by_syn_id_dict = defaultdict(lambda: dict())
        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())

        if initial_weights_path is not None:
            initial_weights_gid_count, initial_weights_syn_count = \
              _read_weights_by_syn_id(initial_weights_path, destination,
                                      initial_weights_namespace, selection,
                                      synapse_name,
                                      initial_weights_by_syn_id_dict)

            logger.info(
                'destination: %s; read initial synaptic weights for %i gids and %i syns'
                % (destination, initial_weights_gid_count,
                   initial_weights_syn_count))

        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = defaultdict(lambda: dict())
            non_structured_weights_by_source_gid_dict = defaultdict(
                lambda: dict())
        else:
            # Both dictionaries are None when there are no non-structured
            # sources; every use below is guarded accordingly.
            non_structured_weights_by_syn_id_dict = None
            non_structured_weights_by_source_gid_dict = None

        if (non_structured_weights_path is not None) and \
           (non_structured_weights_by_syn_id_dict is not None):
            # Read from the non-structured weights file (the initial
            # weights file was previously read here by mistake).
            non_structured_weights_gid_count, non_structured_weights_syn_count = \
              _read_weights_by_syn_id(non_structured_weights_path, destination,
                                      non_structured_weights_namespace,
                                      selection, synapse_name,
                                      non_structured_weights_by_syn_id_dict)

            logger.info(
                'destination: %s; read non-structured synaptic weights for %i gids and %i syns'
                % (
                    destination,
                    non_structured_weights_gid_count,
                    non_structured_weights_syn_count,
                ))

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = defaultdict(lambda: dict())
            reference_weights_gid_count, _ = \
              _read_weights_by_syn_id(reference_weights_path, destination,
                                      reference_weights_namespace, selection,
                                      synapse_name,
                                      reference_weights_by_syn_id_dict)

            logger.info(
                'destination: %s; read reference synaptic weights for %i gids'
                % (destination, reference_weights_gid_count))

        syn_count_by_source_gid_dict = defaultdict(int)
        source_gid_set_dict = defaultdict(set)
        syn_ids_by_source_gid_dict = defaultdict(list)
        structured_syn_id_count = 0

        if has_structured_weights:
            # Re-key the per-synapse weights by presynaptic gid; the first
            # synapse seen for a source gid determines its initial and
            # non-structured weight.
            for source, (destination_gid, (source_gid_array,
                                           conn_attr_dict)) in zip_longest(
                                               all_sources, attr_gen_package):
                syn_ids = conn_attr_dict['Synapses']['syn_id']
                count = 0
                this_initial_weights_by_syn_id_dict = None
                this_initial_weights_by_source_gid_dict = None
                this_reference_weights_by_syn_id_dict = None
                this_reference_weights_by_source_gid_dict = None
                this_non_structured_weights_by_syn_id_dict = None
                this_non_structured_weights_by_source_gid_dict = None
                if destination_gid is not None:
                    this_initial_weights_by_syn_id_dict = initial_weights_by_syn_id_dict[
                        destination_gid]
                    this_initial_weights_by_source_gid_dict = initial_weights_by_source_gid_dict[
                        destination_gid]
                    if reference_weights_by_syn_id_dict is not None:
                        this_reference_weights_by_syn_id_dict = reference_weights_by_syn_id_dict[
                            destination_gid]
                        this_reference_weights_by_source_gid_dict = reference_weights_by_source_gid_dict[
                            destination_gid]
                    if non_structured_weights_by_syn_id_dict is not None:
                        this_non_structured_weights_by_syn_id_dict = non_structured_weights_by_syn_id_dict[
                            destination_gid]
                        this_non_structured_weights_by_source_gid_dict = non_structured_weights_by_source_gid_dict[
                            destination_gid]

                for i in range(len(source_gid_array)):
                    this_source_gid = source_gid_array[i]
                    this_syn_id = syn_ids[i]
                    if this_syn_id in this_initial_weights_by_syn_id_dict:
                        this_syn_wgt = this_initial_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_initial_weights_by_source_gid_dict:
                            this_initial_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                        if this_reference_weights_by_syn_id_dict is not None:
                            this_reference_weights_by_source_gid_dict[this_source_gid] = \
                              this_reference_weights_by_syn_id_dict[this_syn_id]
                    elif (this_non_structured_weights_by_syn_id_dict is not None) and \
                         (this_syn_id in this_non_structured_weights_by_syn_id_dict):
                        this_syn_wgt = this_non_structured_weights_by_syn_id_dict[
                            this_syn_id]
                        if this_source_gid not in this_non_structured_weights_by_source_gid_dict:
                            this_non_structured_weights_by_source_gid_dict[
                                this_source_gid] = this_syn_wgt
                    source_gid_set_dict[source].add(this_source_gid)
                    syn_ids_by_source_gid_dict[this_source_gid].append(
                        this_syn_id)
                    syn_count_by_source_gid_dict[this_source_gid] += 1

                    count += 1
                if source not in non_structured_sources:
                    structured_syn_id_count += len(syn_ids)
                logger.info(
                    'Rank %i; destination: %s; gid %i; %d edges from source population %s'
                    % (rank, destination, this_gid, count, source))

        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None
        for source in all_sources:
            if has_structured_weights:
                source_gids = list(source_gid_set_dict[source])
            else:
                source_gids = []
            for input_features_namespace in this_input_features_namespaces:
                # Logged per namespace so that the message names the
                # namespace that is actually about to be read.
                if rank == 0:
                    logger.info(
                        'Reading %s feature data for %i cells in population %s...'
                        % (input_features_namespace, len(source_gids), source))
                input_features_iter = read_cell_attribute_selection(
                    input_features_path,
                    source,
                    namespace=input_features_namespace,
                    mask=set(source_features_attr_names),
                    comm=env.comm,
                    selection=source_gids)
                count = 0
                for gid, attr_dict in input_features_iter:
                    this_selectivity_type = attr_dict['Selectivity Type'][0]
                    this_selectivity_type_name = selectivity_type_index[
                        this_selectivity_type]
                    input_cell_config = stimulus.get_input_cell_config(
                        this_selectivity_type,
                        selectivity_type_index,
                        selectivity_attr_dict=attr_dict)
                    this_arena_rate_map = np.asarray(
                        input_cell_config.get_rate_map(arena_x, arena_y),
                        dtype=np.float32)
                    if source in non_structured_sources:
                        non_structured_input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    else:
                        input_rate_maps_by_source_gid_dict[
                            gid] = this_arena_rate_map
                    count += 1
                if rank == 0:
                    logger.info(
                        'Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

        if has_structured_weights:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = '%s/Structured Weights %s %d.png' % (
                    save_fig, destination, this_gid)

            # NOTE(review): non_structured_weights_dict is None when there
            # are no non-structured sources -- presumably accepted by
            # generate_structured_weights, like the None rate map dict;
            # confirm.
            normalized_LTP_delta_weights_dict, LTD_delta_weights_dict, arena_LS_map = \
              synapses.generate_structured_weights(target_map=target_selectivity_features_dict[this_gid]['Arena Rate Map'],
                                                initial_weight_dict=this_initial_weights_by_source_gid_dict,
                                                reference_weight_dict=this_reference_weights_by_source_gid_dict,
                                                reference_weights_are_delta=reference_weights_are_delta,
                                                reference_weights_namespace=reference_weights_namespace,
                                                input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                non_structured_weights_dict=this_non_structured_weights_by_source_gid_dict,
                                                syn_count_dict=syn_count_by_source_gid_dict,
                                                max_delta_weight=max_delta_weight,
                                                max_opt_iter=max_opt_iter,
                                                max_weight_decay_fraction=max_weight_decay_fraction,
                                                target_amplitude=target_amplitude,
                                                arena_x=arena_x, arena_y=arena_y,
                                                optimize_method=optimize_method,
                                                optimize_tol=optimize_tol,
                                                optimize_grad=optimize_grad,
                                                verbose=verbose, plot=plot, show_fig=show_fig,
                                                save_fig=save_fig_path,
                                                fig_kwargs={'gid': this_gid,
                                                            'field_width': target_field_width_dict[this_gid]})
            gc.collect()

            this_selectivity_dict = target_selectivity_features_dict[this_gid]
            output_features_dict[this_gid] = {
                fld: this_selectivity_dict[fld]
                for fld in [
                    'Selectivity Type', 'Num Fields', 'Field Width',
                    'Peak Rate', 'X Offset', 'Y Offset'
                ]
            }
            output_features_dict[this_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            # NOTE(review): the arrays are sized by the number of structured
            # synapses and filled per source gid of the LTP result; this
            # assumes the result covers exactly the structured sources --
            # confirm.
            output_syn_ids = np.empty(structured_syn_id_count, dtype='uint32')
            LTD_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(structured_syn_id_count,
                                          dtype='float32')
            i = 0
            # Every synapse of a source gid receives that gid's weight.
            for source_gid in normalized_LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = normalized_LTP_delta_weights_dict[
                        source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            LTP_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[this_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            logger.info(
                'Rank %i; destination: %s; gid %i; generated structured weights for %i inputs in %.2f '
                's' % (rank, destination, this_gid, len(output_syn_ids),
                       time.time() - local_time))
            gid_count += 1

        # Periodic flush of the accumulated results.
        if iter_count % write_size == 0:
            if not dry_run:
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = comm.reduce(len(LTP_output_weights_dict),
                                    op=MPI.SUM,
                                    root=0)
                if rank == 0:
                    logger.info(
                        'Destination: %s; appended weights for %i cells' %
                        (destination, count))
                if output_features_path is not None:
                    this_output_features_namespace = '%s %s' % (
                        output_features_namespace, arena_id)
                    logger.info(str(output_features_dict))
                    append_cell_attributes(
                        output_features_path,
                        destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = comm.reduce(len(output_features_dict),
                                        op=MPI.SUM,
                                        root=0)
                    if rank == 0:
                        logger.info(
                            'Destination: %s; appended selectivity features for %i cells'
                            % (destination, count))

            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

    # Final flush of whatever was accumulated since the last periodic write.
    if not dry_run:
        append_cell_attributes(output_weights_path,
                               destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path,
                               destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info('Destination: %s; appended weights for %i cells' %
                        (destination, count))
        if output_features_path is not None:
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            append_cell_attributes(output_features_path,
                                   destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = comm.reduce(len(output_features_dict), op=MPI.SUM, root=0)
            if rank == 0:
                logger.info(
                    'Destination: %s; appended selectivity features for %i cells'
                    % (destination, count))

    env.comm.barrier()
    global_count = comm.gather(gid_count, root=0)
    if rank == 0:
        logger.info(
            'destination: %s; %i ranks assigned structured weights to %i cells in %.2f s'
            % (destination, comm.size, np.sum(global_count),
               time.time() - start_time))


def _read_weights_by_syn_id(weights_path, population, namespace, selection,
                            synapse_name, weights_by_syn_id_dict):
    """
    Reads the synaptic weights of the selected gids and stores them in
    weights_by_syn_id_dict as {gid: {syn_id (int): weight (float)}}.

    :param weights_path: str (path to .h5 file)
    :param population: str
    :param namespace: str
    :param selection: list of gids to read
    :param synapse_name: str; name of the attribute holding the weights
    :param weights_by_syn_id_dict: dict of dicts updated in place; indexed
        by gid, so a missing gid entry must be creatable (e.g. defaultdict)
    :return: tuple (number of gids read, number of synapses read)
    """
    gid_count = 0
    syn_count = 0
    weights_iter = read_cell_attribute_selection(weights_path,
                                                 population,
                                                 namespace=namespace,
                                                 selection=selection)
    for this_gid, syn_weight_attr_dict in weights_iter:
        syn_ids = syn_weight_attr_dict['syn_id']
        weights = syn_weight_attr_dict[synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            weights_by_syn_id_dict[this_gid][int(syn_id)] = float(weight)
        gid_count += 1
        syn_count += len(syn_ids)
    return gid_count, syn_count
Exemplo n.º 17
0
def init_circuit_context(env,
                         pop_name,
                         gid,
                         load_edges=False,
                         connection_graph=None,
                         load_weights=False,
                         weight_dict=None,
                         load_synapses=False,
                         synapses_dict=None,
                         set_edge_delays=True,
                         **kwargs):
    """
    Initializes the synapse, weight and edge attributes of one cell in
    env.synapse_attributes, either from the given in-memory data or by
    reading the files referenced by env.

    :param env: an instance of env.Env; this function reads
        synapse_attributes, celltypes, cell_attribute_info, data_file_path,
        connectivity_file_path and comm from it
    :param pop_name: population name
    :param gid: gid
    :param load_edges: bool; read the edges of gid from
        env.connectivity_file_path
    :param connection_graph: None or tuple (graph, attribute info); used
        instead of reading the connectivity file
    :param load_weights: bool; read the weights namespaces configured for
        the population from env.data_file_path
    :param weight_dict: None or dict {namespace: weights dict with a
        'syn_id' entry and one entry per synapse name}; used instead of
        reading the weights
    :param load_synapses: bool; read 'Synapse Attributes' from
        env.data_file_path
    :param synapses_dict: None or dict of keyword arguments passed to
        init_syn_id_attrs; used instead of reading the synapse attributes
    :param set_edge_delays: bool; passed to init_edge_attrs_from_iter
    :param kwargs: ignored
    :raises RuntimeError: when required data is neither provided nor
        available in the files referenced by env
    """

    syn_attrs = env.synapse_attributes
    synapse_config = env.celltypes[pop_name]['synapses']

    has_weights = False
    weight_config = []
    if 'weights' in synapse_config:
        has_weights = True
        weight_config = synapse_config['weights']

    # Edges and weights both require the synapse attributes to be
    # initialized first.
    init_synapses = False
    init_weights = False
    init_edges = False
    if load_edges or (connection_graph is not None):
        init_synapses = True
        init_edges = True
    if has_weights and (load_weights or (weight_dict is not None)):
        init_synapses = True
        init_weights = True
    if load_synapses or (synapses_dict is not None):
        init_synapses = True

    if init_synapses:
        if synapses_dict is not None:
            syn_attrs.init_syn_id_attrs(gid, **synapses_dict)
        elif load_synapses or load_edges:
            if (pop_name in env.cell_attribute_info) and (
                    'Synapse Attributes' in env.cell_attribute_info[pop_name]):
                synapses_iter = read_cell_attribute_selection(
                    env.data_file_path,
                    pop_name, [gid],
                    'Synapse Attributes',
                    mask={
                        'syn_ids', 'syn_locs', 'syn_secs', 'syn_layers',
                        'syn_types', 'swc_types'
                    },
                    comm=env.comm)
                syn_attrs.init_syn_id_attrs_from_iter(synapses_iter)
            else:
                raise RuntimeError(
                    'init_circuit_context: synapse attributes not found for %s: gid: %i'
                    % (pop_name, gid))
        else:
            raise RuntimeError(
                "init_circuit_context: invalid synapses parameters")

    # init_weights is only set when the population has a weights config.
    if init_weights:

        for weight_config_dict in weight_config:

            expr_closure = weight_config_dict.get('closure', None)
            weights_namespaces = weight_config_dict['namespace']

            cell_weights_dicts = {}
            if weight_dict is not None:
                for weights_namespace in weights_namespaces:
                    if weights_namespace in weight_dict:
                        cell_weights_dicts[weights_namespace] = weight_dict[
                            weights_namespace]

            elif load_weights:
                if (env.data_file_path is None):
                    raise RuntimeError(
                        'init_circuit_context: load_weights=True but data file path is not specified '
                    )

                for weights_namespace in weights_namespaces:
                    cell_weights_iter = read_cell_attribute_selection(
                        env.data_file_path,
                        pop_name,
                        selection=[gid],
                        namespace=weights_namespace,
                        comm=env.comm)
                    for cell_weights_gid, cell_weights_dict in cell_weights_iter:
                        assert (cell_weights_gid == gid)
                        cell_weights_dicts[
                            weights_namespace] = cell_weights_dict

            else:
                raise RuntimeError(
                    "init_circuit_context: invalid weights parameters")
            if len(weights_namespaces) != len(cell_weights_dicts):
                logger.warning(
                    "init_circuit_context: Unable to load all weights namespaces: %s"
                    % str(weights_namespaces))

            # The first namespace in the list is added with
            # multiple='error' and the optional closure; every following
            # namespace is added with append=True and
            # multiple='overwrite', without the closure.
            multiple_weights = 'error'
            append_weights = False
            for weights_namespace in weights_namespaces:
                if weights_namespace in cell_weights_dicts:
                    cell_weights_dict = cell_weights_dicts[weights_namespace]
                    weights_syn_ids = cell_weights_dict['syn_id']
                    for syn_name in (syn_name for syn_name in cell_weights_dict
                                     if syn_name != 'syn_id'):
                        weights_values = cell_weights_dict[syn_name]
                        if expr_closure:
                            weight_attrs = [{
                                'weight': Promise(expr_closure, [x])
                            } for x in weights_values]
                        else:
                            weight_attrs = [{
                                'weight': x
                            } for x in weights_values]
                        syn_attrs.add_mech_attrs_from_iter(
                            gid,
                            syn_name,
                            zip_longest(weights_syn_ids, weight_attrs),
                            multiple=multiple_weights,
                            append=append_weights)
                        logger.info(
                            'init_circuit_context: gid: %i; found %i %s synaptic weights in namespace %s'
                            % (gid, len(cell_weights_dict[syn_name]), syn_name,
                               weights_namespace))
                        logger.info(
                            'weight_values min/max/mean: %.02f / %.02f / %.02f'
                            % (np.min(weights_values), np.max(weights_values),
                               np.mean(weights_values)))
                expr_closure = None
                append_weights = True
                multiple_weights = 'overwrite'

    (graph, a) = None, None
    if init_edges:
        if connection_graph is not None:
            (graph, a) = connection_graph
        elif load_edges:
            if env.connectivity_file_path is None:
                raise RuntimeError(
                    'init_circuit_context: load_edges=True but connectivity file path is not specified '
                )
            # A missing file must be reported here; previously this case
            # fell through and left graph undefined.
            if not os.path.isfile(env.connectivity_file_path):
                raise RuntimeError(
                    'init_circuit_context: connection file %s not found' %
                    env.connectivity_file_path)
            (graph, a) = read_graph_selection(
                file_name=env.connectivity_file_path,
                selection=[gid],
                namespaces=['Synapses', 'Connections'],
                comm=env.comm)

    if graph is not None:
        if pop_name in graph:
            for presyn_name in graph[pop_name].keys():
                edge_iter = graph[pop_name][presyn_name]
                syn_attrs.init_edge_attrs_from_iter(pop_name, presyn_name, a,
                                                    edge_iter, set_edge_delays)
        else:
            message = (
                'init_circuit_context: connection attributes not found for %s: gid: %i'
                % (pop_name, gid))
            logger.error(message)
            raise RuntimeError(message)
Exemplo n.º 18
0
def load_biophys_cell_dicts(env,
                            pop_name,
                            gid_set,
                            data_file_path=None,
                            load_connections=True,
                            validate_tree=True):
    """
    Loads the data necessary to instantiate BiophysCell for each gid in gid_set.

    The morphology of every requested gid is always read. When
    load_connections is true, the 'Synapse Attributes' namespace, any weights
    namespaces listed under the population's synapse configuration, and the
    incoming connections from env.connectivity_file_path are read as well.

    :param env: an instance of env.Env
    :param pop_name: population name
    :param gid_set: collection of gids to load
    :param data_file_path: str or None; file to read trees, synapse attributes
        and weights from. Falls back to env.data_file_path when None.
    :param load_connections: bool; when False only morphologies are loaded
    :param validate_tree: bool; forwarded to read_tree_selection as `validate`
    :return: dict mapping gid -> dict with key 'morph' and, when
        load_connections is true, the keys 'synapse', 'connectivity'
        (a (connection graph, graph attribute info) tuple) and 'weight'
        (dict keyed by weights namespace, or None if no weights were read).
    :raises KeyError: if no tree (or, with load_connections, no synapse
        attributes) was returned for one of the requested gids.

    Environment can be instantiated as:
    env = Env(config_file, template_paths, dataset_prefix, config_prefix)
    :param template_paths: str; colon-separated list of paths to directories containing hoc cell templates
    :param dataset_prefix: str; path to directory containing required neuroh5 data files
    :param config_prefix: str; path to directory containing network and cell mechanism config files
    """

    synapse_config = env.celltypes[pop_name]['synapses']

    weights_config = None
    if 'weights' in synapse_config:
        weights_config = synapse_config['weights']

    # Honor an explicitly supplied data file; otherwise use the one configured
    # in the environment (previously the argument was silently ignored).
    if data_file_path is None:
        data_file_path = env.data_file_path

    # Materialize once so that gid_set may be any iterable and every pass
    # below sees the same gids.
    gid_list = list(gid_set)

    ## Loads cell morphological data, synaptic attributes and connection data

    synapses_dicts = {}
    weight_dicts = {}
    connection_graphs = {gid: {pop_name: {}} for gid in gid_list}
    graph_attr_info = None

    tree_attr_iter, _ = read_tree_selection(data_file_path,
                                            pop_name,
                                            gid_list,
                                            comm=env.comm,
                                            topology=True,
                                            validate=validate_tree)
    # The iterator yields (gid, tree_dict) pairs.
    tree_dicts = dict(tree_attr_iter)

    if load_connections:
        synapses_iter = read_cell_attribute_selection(
            data_file_path,
            pop_name,
            gid_list,
            'Synapse Attributes',
            mask={
                'syn_ids', 'syn_locs', 'syn_secs', 'syn_layers', 'syn_types',
                'swc_types'
            },
            comm=env.comm)
        synapses_dicts = dict(synapses_iter)

        if weights_config:
            for config in weights_config:
                weights_namespaces = config['namespace']
                # All per-namespace iterators are created before any of them
                # is consumed, preserving the original call order.
                cell_weights_iters = [
                    read_cell_attribute_selection(data_file_path,
                                                  pop_name,
                                                  gid_list,
                                                  weights_namespace,
                                                  comm=env.comm)
                    for weights_namespace in weights_namespaces
                ]
                # Both sequences have one entry per namespace, so plain zip
                # pairs them exactly.
                for weights_namespace, cell_weights_iter in zip(
                        weights_namespaces, cell_weights_iters):
                    for gid, cell_weights_dict in cell_weights_iter:
                        weight_dicts.setdefault(
                            gid, {})[weights_namespace] = cell_weights_dict

        graph, graph_attr_info = read_graph_selection(
            file_name=env.connectivity_file_path,
            selection=gid_list,
            namespaces=['Synapses', 'Connections'],
            comm=env.comm)
        if pop_name in graph:
            for presyn_name in graph[pop_name].keys():
                edge_iter = graph[pop_name][presyn_name]
                for (postsyn_gid, edges) in edge_iter:
                    # Stored as a one-element list of (gid, edges) so each
                    # per-cell graph has the same shape as an edge iterator.
                    connection_graphs[postsyn_gid][pop_name][presyn_name] = [
                        (postsyn_gid, edges)
                    ]

    cell_dicts = {}
    for gid in gid_list:
        this_cell_dict = {'morph': tree_dicts[gid]}

        if load_connections:
            this_cell_dict['synapse'] = synapses_dicts[gid]
            this_cell_dict['connectivity'] = (connection_graphs[gid],
                                              graph_attr_info)
            this_cell_dict['weight'] = weight_dicts.get(gid, None)
        cell_dicts[gid] = this_cell_dict

    return cell_dicts