예제 #1
0
def mkstim(env):
    """Attach vector (spike-train) stimuli to cells of populations that declare one.

    Populations in ``env.celltypes`` are visited in sorted order. For every
    population whose entry has a ``'vectorStimulus'`` key, that value is used
    as a cell-attribute namespace and read with ``scatter_read_cell_attributes``
    from the dataset's ``'Cell Data'`` file. Each cell's ``'spiketrain'`` is
    then played into the object returned by ``env.pc.gid2cell(gid)``.

    :param env: simulation environment; reads ``pc``, ``comm``,
        ``datasetPrefix``, ``datasetName``, ``modelConfig``, ``celltypes``,
        ``nodeRanks``, ``IOsize`` and ``verbose``.
    """
    rank = int(env.pc.id())

    datasetPath = os.path.join(env.datasetPrefix, env.datasetName)
    inputFilePath = os.path.join(datasetPath, env.modelConfig['Cell Data'])

    # sorted() and `in` work on both Python 2 and 3, unlike
    # dict.keys().sort() and dict.has_key(), which are Python-2 only.
    for popName in sorted(env.celltypes):
        celltype = env.celltypes[popName]
        if 'vectorStimulus' not in celltype:
            continue
        vecstim_namespace = celltype['vectorStimulus']

        # node_rank_map is only forwarded when an explicit mapping exists.
        read_kwargs = {'namespaces': [vecstim_namespace], 'io_size': env.IOsize}
        if env.nodeRanks is not None:
            read_kwargs['node_rank_map'] = env.nodeRanks
        cell_attributes_dict = scatter_read_cell_attributes(env.comm, inputFilePath, popName,
                                                            **read_kwargs)

        for gid, vecstim_dict in cell_attributes_dict[vecstim_namespace]:
            spiketrain = vecstim_dict['spiketrain']
            if env.verbose and rank == 0:
                if len(spiketrain) > 0:
                    print("*** Spike train for gid %i is of length %i (first spike at %g ms)" % (
                        gid, len(spiketrain), spiketrain[0]))
                else:
                    print("*** Spike train for gid %i is of length %i" % (gid, len(spiketrain)))

            cell = env.pc.gid2cell(gid)
            cell.play(h.Vector(spiketrain))
def main(coords_path, coords_namespace, io_size):
    """Print the U/V coordinates of every 'GC' cell assigned to this MPI rank.

    :param coords_path: path to the NeuroH5 file with cell coordinates
    :param coords_namespace: cell-attribute namespace holding the
        'U Coordinate' and 'V Coordinate' attributes
    :param io_size: number of I/O ranks forwarded to scatter_read_cell_attributes
    """
    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    print ('Allocated %i ranks' % size)

    # The result is not used below; the read is kept so the file access
    # pattern stays the same.
    _population_ranges = read_population_ranges(coords_path)[0]

    for population in ('GC',):
        cell_attrs = scatter_read_cell_attributes(coords_path, population,
                                                  namespaces=[coords_namespace],
                                                  io_size=io_size)
        for cell_gid, coords_dict in cell_attrs[coords_namespace]:
            u_coord = coords_dict['U Coordinate']
            v_coord = coords_dict['V Coordinate']
            print ('Rank %i: gid = %i u = %f v = %f' % (rank, cell_gid, u_coord, v_coord))
예제 #3
0
def main(syn_path, syn_namespace, io_size):
    """Pretty-print the synapse attributes of every cell read on this MPI rank.

    :param syn_path: path to the NeuroH5 file holding synapse attributes
    :param syn_namespace: cell-attribute namespace to read for each population
    :param io_size: number of I/O ranks forwarded to scatter_read_cell_attributes
    """
    comm = MPI.COMM_WORLD
    rank = comm.rank

    print('Allocated %i ranks' % comm.size)

    population_ranges = read_population_ranges(syn_path)[0]
    print(population_ranges)

    for population in population_ranges.keys():
        syn_attr_dict = scatter_read_cell_attributes(syn_path,
                                                     population,
                                                     namespaces=[syn_namespace],
                                                     io_size=io_size)
        # A distinct name for the per-cell dict avoids shadowing syn_attr_dict.
        for cell_gid, cell_syn_attrs in syn_attr_dict[syn_namespace]:
            print('Rank %i: gid = %i syn attrs:' % (rank, cell_gid))
            pprint.pprint(cell_syn_attrs)
예제 #4
0
def read_spike_events(input_file,
                      population_names,
                      namespace_id,
                      spike_train_attr_name='t',
                      time_range=None,
                      max_spikes=None,
                      n_trials=-1,
                      merge_trials=False,
                      comm=None,
                      io_size=0,
                      include_artificial=True):
    """
    Reads spike trains from a NeuroH5 file, and returns a dictionary with spike times and cell indices.

    :param input_file: str (path to file)
    :param population_names: list of str
    :param namespace_id: str (cell attribute namespace holding the spike trains)
    :param spike_train_attr_name: str (name of the spike-time attribute)
    :param time_range: sequence of two floats ``[start, stop]``; spikes with
        ``start <= t <= stop`` are kept. A None start is treated as 0.0 and a
        None stop as unbounded. The caller's sequence is not modified.
    :param max_spikes: float; a population with more spikes than this is
        randomly subsampled, without replacement, to ``int(max_spikes)`` spikes
    :param n_trials: int; keep trials with index ``0 .. n_trials-1``, or -1 to
        use the number of distinct trial indices of the first cell with spikes
    :param merge_trials: bool; if True, each spike time is shifted by the summed
        duration of all preceding trials and the per-trial arrays of a
        population are concatenated; otherwise each population contributes a
        list of per-trial arrays
    :param comm: MPI communicator forwarded to scatter_read_cell_attributes
    :param io_size: int, forwarded to scatter_read_cell_attributes
    :param include_artificial: bool; if False, cells whose 'artificial'
        attribute is > 0 are skipped
    :return: dict with keys 'spkpoplst', 'spktlst', 'spkindlst', 'tmin',
        'tmax', 'pop_active_cells', 'num_cell_spks' and 'n_trials'. Populations
        without any spike in range are absent from 'spkpoplst'.
    """
    assert (n_trials >= 1) or (n_trials == -1)

    trial_index_attr = 'Trial Index'
    trial_dur_attr = 'Trial Duration'
    artificial_attr = 'artificial'
    spike_train_attr_set = {
        spike_train_attr_name, trial_index_attr, trial_dur_attr,
        artificial_attr
    }

    # Resolve the time window once. Unlike assigning time_range[0] in place,
    # this leaves the caller's object untouched, works for tuples, and gives a
    # None stop a usable bound.
    if time_range is not None:
        t_start = 0.0 if time_range[0] is None else time_range[0]
        t_stop = np.inf if time_range[1] is None else time_range[1]

    spkpoplst = []
    spkindlst = []
    spktlst = []
    num_cell_spks = {}
    pop_active_cells = {}

    tmin = float('inf')
    tmax = 0.

    for pop_name in population_names:

        if time_range is None or time_range[1] is None:
            logger.info('Reading spike data for population %s...' % pop_name)
        else:
            logger.info(
                'Reading spike data for population %s in time range %s...' %
                (pop_name, str(time_range)))

        spkiter_dict = scatter_read_cell_attributes(input_file,
                                                    pop_name,
                                                    namespaces=[namespace_id],
                                                    mask=spike_train_attr_set,
                                                    comm=comm,
                                                    io_size=io_size)
        spkiter = spkiter_dict[namespace_id]

        this_num_cell_spks = 0
        active_set = set()

        pop_spkindlst = []
        pop_spktlst = []
        pop_spktriallst = []

        logger.info('Read spike cell attributes for population %s...' %
                    pop_name)

        for spkind, spkattrs in spkiter:
            is_artificial_flag = spkattrs.get(artificial_attr, None)
            if (is_artificial_flag is not None and is_artificial_flag[0] > 0
                    and not include_artificial):
                continue
            spkts = spkattrs[spike_train_attr_name]
            slen = len(spkts)
            if slen == 0:
                # An empty spike train contributes nothing. It must also not be
                # used to infer n_trials, since it would yield 0 trials.
                continue
            trial_dur = spkattrs.get(trial_dur_attr, np.asarray([0.]))
            trial_ind = spkattrs.get(trial_index_attr,
                                     np.zeros((slen, ), dtype=np.uint8))[:slen]
            if n_trials == -1:
                n_trials = len(set(trial_ind))
            # Trial indices are 0-based (the default index array is all zeros,
            # and spikes are grouped with range(n_trials) below), so the
            # valid trials are 0 .. n_trials-1.
            filtered_spk_idxs_by_trial = np.argwhere(
                trial_ind < n_trials).ravel()
            filtered_spkts = spkts[filtered_spk_idxs_by_trial]
            filtered_trial_ind = trial_ind[filtered_spk_idxs_by_trial]
            if time_range is not None:
                filtered_spk_idxs_by_time = np.argwhere(
                    np.logical_and(filtered_spkts >= t_start,
                                   filtered_spkts <= t_stop)).ravel()
                filtered_spkts = filtered_spkts[filtered_spk_idxs_by_time]
                filtered_trial_ind = filtered_trial_ind[
                    filtered_spk_idxs_by_time]
            pop_spkindlst.append(
                np.repeat([spkind], len(filtered_spkts)).astype(np.uint32))
            pop_spktriallst.append(filtered_trial_ind)
            this_num_cell_spks += len(filtered_spkts)
            if len(filtered_spkts) > 0:
                active_set.add(spkind)
            for spkt, trial_i in zip(filtered_spkts, filtered_trial_ind):
                if merge_trials:
                    # Offset by the total duration of the preceding trials.
                    spkt += np.sum(trial_dur[:trial_i])
                pop_spktlst.append(spkt)
                tmin = min(tmin, spkt)
                tmax = max(tmax, spkt)

        pop_active_cells[pop_name] = active_set
        num_cell_spks[pop_name] = this_num_cell_spks

        if not active_set:
            continue

        pop_spkts = np.asarray(pop_spktlst, dtype=np.float32)
        del pop_spktlst
        pop_spkinds = np.concatenate(pop_spkindlst, dtype=np.uint32)
        del pop_spkindlst
        pop_spktrials = np.concatenate(pop_spktriallst, dtype=np.uint32)
        del pop_spktriallst

        # Limit to max_spikes. tmin/tmax already cover every spike read above.
        if (max_spikes is not None) and (len(pop_spkts) > max_spikes):
            logger.warning(
                ' Reading only randomly sampled %i out of %i spikes for population %s'
                % (max_spikes, len(pop_spkts), pop_name))
            # Sample distinct indices over the full range [0, len).
            sample_inds = np.random.choice(len(pop_spkts),
                                           size=int(max_spikes),
                                           replace=False)
            pop_spkts = pop_spkts[sample_inds]
            pop_spkinds = pop_spkinds[sample_inds]
            pop_spktrials = pop_spktrials[sample_inds]

        spkpoplst.append(pop_name)
        pop_trial_spkindlst = []
        pop_trial_spktlst = []
        for trial_i in range(n_trials):
            trial_idxs = np.where(pop_spktrials == trial_i)[0]
            sorted_trial_idxs = np.argsort(pop_spkts[trial_idxs])
            pop_trial_spktlst.append(
                np.take(pop_spkts[trial_idxs], sorted_trial_idxs))
            pop_trial_spkindlst.append(
                np.take(pop_spkinds[trial_idxs], sorted_trial_idxs))

        del pop_spkts
        del pop_spkinds
        del pop_spktrials

        if merge_trials:
            spkindlst.append(np.concatenate(pop_trial_spkindlst))
            spktlst.append(np.concatenate(pop_trial_spktlst))
        else:
            spkindlst.append(pop_trial_spkindlst)
            spktlst.append(pop_trial_spktlst)

        logger.info(' Read %i spikes and %i trials for population %s' %
                    (this_num_cell_spks, n_trials, pop_name))

    return {
        'spkpoplst': spkpoplst,
        'spktlst': spktlst,
        'spkindlst': spkindlst,
        'tmin': tmin,
        'tmax': tmax,
        'pop_active_cells': pop_active_cells,
        'num_cell_spks': num_cell_spks,
        'n_trials': n_trials
    }
def main(config, coordinates, field_width, gid, input_features_path,
         input_features_namespaces, initial_weights_path,
         output_features_namespace, output_features_path, output_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, output_weights_namespace,
         reference_weights_namespace, connections_path, destination, sources,
         non_structured_sources, non_structured_weights_namespace,
         non_structured_weights_path, arena_id, field_width_scale,
         max_opt_iter, max_weight_decay_fraction, optimize_tol, peak_rate,
         reference_weights_are_delta, arena_margin, target_amplitude, io_size,
         chunk_size, value_chunk_size, cache_size, write_size, verbose,
         dry_run, plot, show_fig, save_fig, debug):
    """
    Generate LTP/LTD structured weight deltas for cells of the `destination`
    population. Append them, and optionally the target selectivity features,
    to NeuroH5 output files.

    Destination gids are distributed round-robin over MPI ranks and processed
    one per rank per iteration. For each gid, the rate maps of its `sources`
    and `non_structured_sources` inputs are built from their selectivity
    features. They are then passed to synapses.generate_structured_weights
    together with the cell's target 'Arena Rate Map'.

    :param config: str (path to .yaml file)
    :param coordinates: passed to init_selectivity_config; when non-empty,
        destination cells are kept even if their 'Num Fields' is 0
    :param gid: sequence of destination gids; if empty, the gid set comes from
        graph_info of the first projection (presumably the destination node
        index -- TODO confirm)
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: list of str; each is suffixed with
        `arena_id`
    :param initial_weights_path: str (path to .h5 file)
    :param initial_weights_namespace: str
    :param output_features_namespace: str or None; defaults to
        'Place Selectivity'. ' <arena_id>' is appended.
    :param output_features_path: str or None (path to .h5 file); no features
        are written when None
    :param output_weights_path: str (path to .h5 file); if missing, it is
        created with the /H5Types group copied from `initial_weights_path`
        (or `h5types_path`)
    :param output_weights_namespace: str; written as 'LTD <ns> <arena_id>' and
        'LTP <ns> <arena_id>'
    :param connections_path: str (path to .h5 file)
    :param destination: str
    :param sources: list of str
    :param non_structured_sources: list of str; weights for these are read from
        `non_structured_weights_path` and passed separately to
        generate_structured_weights
    :param io_size: int; -1 means use all ranks
    :param chunk_size:
    :param value_chunk_size:
    :param cache_size: currently unused
    :param write_size: int; if > 0, accumulated outputs are appended every
        `write_size` iterations, starting with the first; the remainder is
        appended at the end
    :param verbose:
    :param dry_run: bool; if True, no output file is created or written
    :param debug: bool; if True, stop after 11 iterations
    :raises RuntimeError: if a (source, destination) projection is missing
        from the connections file, or if the output weights file must be
        created but neither `initial_weights_path` nor `h5types_path` is given
    :return:
    """

    utils.config_logging(verbose)
    script_name = __file__
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    nranks = comm.size

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    env = Env(comm=comm, config_file=config, io_size=io_size)
    env.comm.barrier()

    if plot and (not save_fig) and (not show_fig):
        show_fig = True

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None:
                h5types_src_path = initial_weights_path
            elif h5types_path is not None:
                h5types_src_path = h5types_path
            else:
                raise RuntimeError(
                    'h5types input path must be specified when weights path is not specified.'
                )
            # `with` closes both handles even if the copy fails.
            with h5py.File(h5types_src_path, 'r') as input_file, \
                    h5py.File(output_weights_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    env.comm.barrier()

    LTD_output_weights_namespace = f'LTD {output_weights_namespace} {arena_id}'
    LTP_output_weights_namespace = f'LTP {output_weights_namespace} {arena_id}'
    this_input_features_namespaces = [
        f'{input_features_namespace} {arena_id}'
        for input_features_namespace in input_features_namespaces
    ]

    selectivity_type_index = {
        i: n
        for n, i in viewitems(env.selectivity_types)
    }
    target_selectivity_type_name = 'place'
    target_selectivity_type = env.selectivity_types[
        target_selectivity_type_name]
    source_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'Module ID', 'Grid Spacing', 'Grid Orientation',
        'Field Width Concentration Factor', 'X Offset', 'Y Offset'
    ]
    target_features_attr_names = [
        'Selectivity Type', 'Num Fields', 'Field Width', 'Peak Rate',
        'X Offset', 'Y Offset'
    ]

    # NOTE(review): seed_offset and default_run_vel are not used below; they
    # are kept because the lookups fail fast on an incomplete config.
    seed_offset = int(
        env.model_config['Random Seeds']['GC Structured Weights'])
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm

    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s

    # Resolve the selectivity-features namespace once, so the periodic and
    # final writes always use the same namespace regardless of write_size.
    if output_features_namespace is None:
        output_features_namespace = f'{target_selectivity_type_name.title()} Selectivity'
    this_output_features_namespace = f'{output_features_namespace} {arena_id}'

    # Number of destination cells for which weights were generated on this rank.
    gid_count = 0
    start_time = time.time()

    target_gid_set = None
    if len(gid) > 0:
        target_gid_set = set(gid)
    projections = [(source, destination) for source in sources]
    graph_info = read_graph_info(connections_path,
                                 namespaces=['Connections', 'Synapses'],
                                 read_node_index=True)
    for projection in projections:
        if projection not in graph_info:
            raise RuntimeError(
                f'Projection {projection[0]} -> {projection[1]} is not present in connections file.'
            )
        if target_gid_set is None:
            target_gid_set = set(graph_info[projection][1])

    all_sources = sources + non_structured_sources
    src_input_features_attr_dict = {source: {} for source in all_sources}
    for source in sorted(all_sources):
        this_src_input_features_attr_dict = {}
        for this_input_features_namespace in this_input_features_namespaces:
            logger.info(
                f'Rank {rank}: Reading {this_input_features_namespace} feature data for cells in population {source}'
            )
            input_features_dict = scatter_read_cell_attributes(
                input_features_path,
                source,
                namespaces=[this_input_features_namespace],
                mask=set(source_features_attr_names),
                comm=env.comm,
                io_size=env.io_size)
            for gid, attr_dict in input_features_dict[
                    this_input_features_namespace]:
                this_src_input_features_attr_dict[gid] = attr_dict
        src_input_features_attr_dict[
            source] = this_src_input_features_attr_dict
        source_gid_count = env.comm.reduce(
            len(this_src_input_features_attr_dict), op=MPI.SUM, root=0)
        if rank == 0:
            logger.info(
                f'Rank {rank}: Read feature data for {source_gid_count} cells in population {source}'
            )

    dst_gids = []
    if target_gid_set is not None:
        for i, gid in enumerate(target_gid_set):
            if i % nranks == rank:
                dst_gids.append(gid)

    dst_input_features_attr_dict = {}
    for this_input_features_namespace in this_input_features_namespaces:
        feature_count = 0
        # Separate counter: reusing gid_count here would inflate the final
        # count of cells with generated weights.
        read_gid_count = 0
        logger.info(
            f'Rank {rank}: reading {this_input_features_namespace} feature data for {len(dst_gids)} cells in population {destination}'
        )
        input_features_iter = scatter_read_cell_attribute_selection(
            input_features_path,
            destination,
            namespace=this_input_features_namespace,
            mask=set(target_features_attr_names),
            selection=dst_gids,
            io_size=env.io_size,
            comm=env.comm)
        for gid, attr_dict in input_features_iter:
            read_gid_count += 1
            if (len(coordinates) > 0) or (attr_dict['Num Fields'][0] > 0):
                dst_input_features_attr_dict[gid] = attr_dict
                feature_count += 1

        logger.info(
            f'Rank {rank}: read {this_input_features_namespace} feature data for '
            f'{read_gid_count} / {feature_count} cells in population {destination}')
        feature_count = env.comm.reduce(feature_count, op=MPI.SUM, root=0)
        env.comm.barrier()
        if rank == 0:
            logger.info(
                f'Read {this_input_features_namespace} feature data for {feature_count} cells in population {destination}'
            )

    feature_dst_gids = list(dst_input_features_attr_dict.keys())
    all_feature_gids_per_rank = comm.allgather(feature_dst_gids)
    all_feature_gids = sorted(
        [item for sublist in all_feature_gids_per_rank for item in sublist])
    request_dst_gids = []
    for i, gid in enumerate(all_feature_gids):
        if i % nranks == rank:
            request_dst_gids.append(gid)

    dst_input_features_attr_dict = exchange_input_features(
        env.comm, request_dst_gids, dst_input_features_attr_dict)
    dst_gids = list(dst_input_features_attr_dict.keys())

    if rank == 0:
        logger.info(
            f"Rank {rank} feature dict is {dst_input_features_attr_dict}")

    dst_count = env.comm.reduce(len(dst_gids), op=MPI.SUM, root=0)

    logger.info(f"Rank {rank} has {len(dst_gids)} feature gids")
    if rank == 0:
        logger.info(f'Total {dst_count} feature gids')

    max_dst_count = env.comm.allreduce(len(dst_gids), op=MPI.MAX)
    env.comm.barrier()

    # Loop-invariant: clamp the arena margin once instead of every iteration.
    arena_margin = max(arena_margin, 0.)

    # Every rank runs the same number of iterations so that the collective
    # calls below stay matched, even after it runs out of gids.
    max_iter_count = max_dst_count
    output_features_dict = {}
    LTP_output_weights_dict = {}
    LTD_output_weights_dict = {}
    for iter_count in range(max_iter_count):

        gc.collect()

        local_time = time.time()
        selection = []
        if len(dst_gids) > 0:
            dst_gid = dst_gids.pop()
            selection.append(dst_gid)
            logger.info(f'Rank {rank} received gid {dst_gid}')

        env.comm.barrier()

        arena_margin_size = 0.

        target_selectivity_features_dict = {}
        target_selectivity_config_dict = {}
        target_field_width_dict = {}

        for destination_gid in selection:
            arena_margin_size = init_selectivity_config(
                destination_gid,
                spatial_resolution,
                arena,
                arena_margin,
                arena_margin_size,
                coordinates,
                field_width,
                field_width_scale,
                peak_rate,
                target_selectivity_type,
                selectivity_type_index,
                dst_input_features_attr_dict,
                target_selectivity_features_dict,
                target_selectivity_config_dict,
                target_field_width_dict,
                logger=logger)

        arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
            arena, spatial_resolution, margin=arena_margin_size)

        selection = list(target_selectivity_features_dict.keys())

        initial_weights_by_source_gid_dict = defaultdict(lambda: dict())
        initial_weights_by_syn_id_dict = \
          read_weights(initial_weights_path, initial_weights_namespace, synapse_name,
                       destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                       logger=logger if rank == 0 else None)

        non_structured_weights_by_source_gid_dict = defaultdict(lambda: dict())
        non_structured_weights_by_syn_id_dict = None
        if len(non_structured_sources) > 0:
            non_structured_weights_by_syn_id_dict = \
             read_weights(non_structured_weights_path, non_structured_weights_namespace, synapse_name,
                          destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                          logger=logger if rank == 0 else None)

        reference_weights_by_syn_id_dict = None
        reference_weights_by_source_gid_dict = defaultdict(lambda: dict())
        if reference_weights_path is not None:
            reference_weights_by_syn_id_dict = \
             read_weights(reference_weights_path, reference_weights_namespace, synapse_name,
                          destination, selection, env.comm, env.io_size, defaultdict(lambda: dict()),
                          logger=logger if rank == 0 else None)

        source_gid_set_dict = defaultdict(set)
        syn_count_by_source_gid_dict = defaultdict(lambda: defaultdict(int))
        syn_ids_by_source_gid_dict = defaultdict(lambda: defaultdict(list))
        structured_syn_id_count = defaultdict(int)
        non_structured_syn_id_count = defaultdict(int)

        projections = [(source, destination) for source in all_sources]
        edge_iter_dict, edge_attr_info = scatter_read_graph_selection(
            connections_path,
            selection=selection,
            namespaces=['Synapses'],
            projections=projections,
            comm=env.comm,
            io_size=env.io_size)

        syn_counts_by_source = init_syn_weight_dicts(
            destination, non_structured_sources, edge_iter_dict,
            edge_attr_info, initial_weights_by_syn_id_dict,
            initial_weights_by_source_gid_dict,
            non_structured_weights_by_syn_id_dict,
            non_structured_weights_by_source_gid_dict,
            reference_weights_by_syn_id_dict,
            reference_weights_by_source_gid_dict, source_gid_set_dict,
            syn_count_by_source_gid_dict, syn_ids_by_source_gid_dict,
            structured_syn_id_count, non_structured_syn_id_count)

        for source in syn_counts_by_source:
            for this_gid in syn_counts_by_source[source]:
                count = syn_counts_by_source[source][this_gid]
                logger.info(
                    f'Rank {rank}: destination: {destination}; gid {this_gid}; '
                    f'{count} edges from source population {source}')

        input_rate_maps_by_source_gid_dict = {}
        if len(non_structured_sources) > 0:
            non_structured_input_rate_maps_by_source_gid_dict = {}
        else:
            non_structured_input_rate_maps_by_source_gid_dict = None

        for source in all_sources:
            source_gids = list(source_gid_set_dict[source])
            if rank == 0:
                logger.info(
                    f'Rank {rank}: getting feature data for {len(source_gids)} cells in population {source}'
                )
            this_src_input_features = exchange_input_features(
                env.comm, source_gids, src_input_features_attr_dict[source])

            for this_gid in source_gids:
                attr_dict = this_src_input_features[this_gid]
                this_selectivity_type = attr_dict['Selectivity Type'][0]
                input_cell_config = stimulus.get_input_cell_config(
                    this_selectivity_type,
                    selectivity_type_index,
                    selectivity_attr_dict=attr_dict)
                this_arena_rate_map = np.asarray(
                    input_cell_config.get_rate_map(arena_x, arena_y),
                    dtype=np.float32)
                if source in non_structured_sources:
                    non_structured_input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map
                else:
                    input_rate_maps_by_source_gid_dict[
                        this_gid] = this_arena_rate_map

        for destination_gid in selection:

            if is_interactive:
                context.update(locals())

            save_fig_path = None
            if save_fig is not None:
                save_fig_path = f'{save_fig}/Structured Weights {destination} {destination_gid}.png'

            reference_weight_dict = None
            if reference_weights_path is not None:
                reference_weight_dict = reference_weights_by_source_gid_dict[
                    destination_gid]

            LTP_delta_weights_dict, LTD_delta_weights_dict, arena_structured_map = \
               synapses.generate_structured_weights(destination_gid,
                                                 target_map=target_selectivity_features_dict[destination_gid]['Arena Rate Map'],
                                                 initial_weight_dict=initial_weights_by_source_gid_dict[destination_gid],
                                                 #reference_weight_dict=reference_weight_dict,
                                                 #reference_weights_are_delta=reference_weights_are_delta,
                                                 #reference_weights_namespace=reference_weights_namespace,
                                                 input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                                 non_structured_input_rate_map_dict=non_structured_input_rate_maps_by_source_gid_dict,
                                                 non_structured_weights_dict=non_structured_weights_by_source_gid_dict[destination_gid],
                                                 syn_count_dict=syn_count_by_source_gid_dict[destination_gid],
                                                 max_opt_iter=max_opt_iter,
                                                 max_weight_decay_fraction=max_weight_decay_fraction,
                                                 target_amplitude=target_amplitude,
                                                 arena_x=arena_x, arena_y=arena_y,
                                                 optimize_tol=optimize_tol,
                                                 verbose=verbose if rank == 0 else False,
                                                 plot=plot, show_fig=show_fig,
                                                 save_fig=save_fig_path,
                                                 fig_kwargs={'gid': destination_gid,
                                                             'field_width': target_field_width_dict[destination_gid]})

            target_map_flat = target_selectivity_features_dict[
                destination_gid]['Arena Rate Map'].flat
            arena_map_residual_mae = np.mean(
                np.abs(arena_structured_map - target_map_flat))
            output_features_dict[destination_gid] = \
               { fld: target_selectivity_features_dict[destination_gid][fld]
                 for fld in ['Selectivity Type',
                             'Num Fields',
                             'Field Width',
                             'Peak Rate',
                             'X Offset',
                             'Y Offset',]}
            output_features_dict[destination_gid][
                'Rate Map Residual Mean Error'] = np.asarray(
                    [arena_map_residual_mae], dtype=np.float32)

            this_structured_syn_id_count = structured_syn_id_count[
                destination_gid]
            output_syn_ids = np.empty(this_structured_syn_id_count,
                                      dtype='uint32')
            LTD_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            LTP_output_weights = np.empty(this_structured_syn_id_count,
                                          dtype='float32')
            i = 0
            for source_gid in LTP_delta_weights_dict:
                for syn_id in syn_ids_by_source_gid_dict[destination_gid][
                        source_gid]:
                    output_syn_ids[i] = syn_id
                    LTP_output_weights[i] = LTP_delta_weights_dict[source_gid]
                    LTD_output_weights[i] = LTD_delta_weights_dict[source_gid]
                    i += 1
            # np.empty leaves unfilled slots uninitialized; keep only the
            # entries that were actually written.
            output_syn_ids = output_syn_ids[:i]
            LTP_output_weights = LTP_output_weights[:i]
            LTD_output_weights = LTD_output_weights[:i]
            LTP_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTP_output_weights
            }
            LTD_output_weights_dict[destination_gid] = {
                'syn_id': output_syn_ids,
                synapse_name: LTD_output_weights
            }

            logger.info(
                f'Rank {rank}; destination: {destination}; gid {destination_gid}; '
                f'generated structured weights for {len(output_syn_ids)} inputs in {time.time() - local_time:.2f} s; '
                f'residual error is {arena_map_residual_mae:.2f}')
            gid_count += 1
            gc.collect()

        # Release the input rate maps only after every destination in this
        # selection has used them.
        input_rate_maps_by_source_gid_dict.clear()

        env.comm.barrier()
        if (write_size > 0) and (iter_count % write_size == 0):
            if not dry_run:
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTD_output_weights_dict,
                                       namespace=LTD_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                append_cell_attributes(output_weights_path,
                                       destination,
                                       LTP_output_weights_dict,
                                       namespace=LTP_output_weights_namespace,
                                       comm=env.comm,
                                       io_size=env.io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                count = env.comm.reduce(len(LTP_output_weights_dict),
                                        op=MPI.SUM,
                                        root=0)
                env.comm.barrier()

                if rank == 0:
                    logger.info(
                        f'Destination: {destination}; appended weights for {count} cells'
                    )
                if output_features_path is not None:
                    append_cell_attributes(
                        output_features_path,
                        destination,
                        output_features_dict,
                        namespace=this_output_features_namespace)
                    count = env.comm.reduce(len(output_features_dict),
                                            op=MPI.SUM,
                                            root=0)
                    env.comm.barrier()

                    if rank == 0:
                        logger.info(
                            f'Destination: {destination}; appended selectivity features for {count} cells'
                        )

            LTP_output_weights_dict.clear()
            LTD_output_weights_dict.clear()
            output_features_dict.clear()
            gc.collect()

        env.comm.barrier()

        if (iter_count >= 10) and debug:
            break

    env.comm.barrier()
    if not dry_run:
        append_cell_attributes(output_weights_path,
                               destination,
                               LTD_output_weights_dict,
                               namespace=LTD_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        append_cell_attributes(output_weights_path,
                               destination,
                               LTP_output_weights_dict,
                               namespace=LTP_output_weights_namespace,
                               comm=env.comm,
                               io_size=env.io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size)
        count = comm.reduce(len(LTP_output_weights_dict), op=MPI.SUM, root=0)
        env.comm.barrier()

        if rank == 0:
            logger.info(
                f'Destination: {destination}; appended weights for {count} cells'
            )
        if output_features_path is not None:
            append_cell_attributes(output_features_path,
                                   destination,
                                   output_features_dict,
                                   namespace=this_output_features_namespace)
            count = env.comm.reduce(len(output_features_dict),
                                    op=MPI.SUM,
                                    root=0)
            env.comm.barrier()

            if rank == 0:
                logger.info(
                    f'Destination: {destination}; appended selectivity features for {count} cells'
                )

    env.comm.barrier()
    global_count = env.comm.gather(gid_count, root=0)
    env.comm.barrier()

    if rank == 0:
        total_count = np.sum(global_count)
        total_time = time.time() - start_time
        logger.info(
            f'Destination: {destination}; '
            f'{env.comm.size} ranks assigned structured weights to {total_count} cells in {total_time:.2f} s'
        )
예제 #6
0
def mkcells(env):
    h('objref templatePaths, templatePathValue')

    rank = int(env.pc.id())
    nhosts = int(env.pc.nhost())

    v_sample_seed = int(env.modelConfig['Random Seeds']['Intracellular Voltage Sample'])
    ranstream_v_sample = np.random.RandomState()
    ranstream_v_sample.seed(v_sample_seed)

    datasetPath = os.path.join(env.datasetPrefix, env.datasetName)

    h.templatePaths = h.List()
    for path in env.templatePaths:
        h.templatePathValue = h.Value(1, path)
        h.templatePaths.append(h.templatePathValue)
    popNames = env.celltypes.keys()
    popNames.sort()
    for popName in popNames:
        templateName = env.celltypes[popName]['template']
        h.find_template(env.pc, h.templatePaths, templateName)

    dataFilePath = os.path.join(datasetPath, env.modelConfig['Cell Data'])

    if rank == 0:
        print 'cell attributes: ', env.cellAttributeInfo

    for popName in popNames:

        if env.verbose:
            if env.pc.id() == 0:
                print "*** Creating population %s" % popName

        templateName = env.celltypes[popName]['template']
        templateClass = eval('h.%s' % templateName)

        if env.celltypes[popName].has_key('synapses'):
            synapses = env.celltypes[popName]['synapses']
        else:
            synapses = {}

        v_sample_set = set([])
        env.v_dict[popName] = {}

        for gid in xrange(env.celltypes[popName]['start'],
                          env.celltypes[popName]['start'] + env.celltypes[popName]['num']):
            if ranstream_v_sample.uniform() <= env.vrecordFraction:
                v_sample_set.add(gid)

        if env.cellAttributeInfo.has_key(popName) and env.cellAttributeInfo[popName].has_key('Trees'):
            if env.verbose:
                if env.pc.id() == 0:
                    print "*** Reading trees for population %s" % popName

            if env.nodeRanks is None:
                (trees, forestSize) = scatter_read_trees(env.comm, dataFilePath, popName, io_size=env.IOsize)
            else:
                (trees, forestSize) = scatter_read_trees(env.comm, dataFilePath, popName, io_size=env.IOsize,
                                                         node_rank_map=env.nodeRanks)
            if env.verbose:
                if env.pc.id() == 0:
                    print "*** Done reading trees for population %s" % popName

            h.numCells = 0
            i = 0
            for (gid, tree) in trees:
                if env.verbose:
                    if env.pc.id() == 0:
                        print "*** Creating gid %i" % gid

                verboseflag = 0
                model_cell = cells.make_neurotree_cell(templateClass, neurotree_dict=tree, gid=gid, local_id=i,
                                                       dataset_path=datasetPath)
                if env.verbose:
                    if (rank == 0) and (i == 0):
                        for sec in list(model_cell.all):
                            h.psection(sec=sec)
                env.gidlist.append(gid)
                env.cells.append(model_cell)
                env.pc.set_gid2node(gid, int(env.pc.id()))
                ## Tell the ParallelContext that this cell is a spike source
                ## for all other hosts. NetCon is temporary.
                nc = model_cell.connect2target(h.nil)
                env.pc.cell(gid, nc, 1)
                ## Record spikes of this cell
                env.pc.spike_record(gid, env.t_vec, env.id_vec)
                ## Record voltages from a subset of cells
                if gid in v_sample_set:
                    v_vec = h.Vector()
                    soma = list(model_cell.soma)[0]
                    v_vec.record(soma(0.5)._ref_v)
                    env.v_dict[popName][gid] = v_vec
                i = i + 1
                h.numCells = h.numCells + 1
            if env.verbose:
                if env.pc.id() == 0:
                    print "*** Created %i cells" % i

        elif env.cellAttributeInfo.has_key(popName) and env.cellAttributeInfo[popName].has_key('Coordinates'):
            if env.verbose:
                if env.pc.id() == 0:
                    print "*** Reading coordinates for population %s" % popName

            if env.nodeRanks is None:
                cell_attributes_dict = scatter_read_cell_attributes(env.comm, dataFilePath, popName,
                                                                    namespaces=['Coordinates'],
                                                                    io_size=env.IOsize)
            else:
                cell_attributes_dict = scatter_read_cell_attributes(env.comm, dataFilePath, popName,
                                                                    namespaces=['Coordinates'],
                                                                    node_rank_map=env.nodeRanks,
                                                                    io_size=env.IOsize)
            if env.verbose:
                if env.pc.id() == 0:
                    print "*** Done reading coordinates for population %s" % popName

            coords = cell_attributes_dict['Coordinates']

            h.numCells = 0
            i = 0
            for (gid, _) in coords:
                if env.verbose:
                    if env.pc.id() == 0:
                        print "*** Creating gid %i" % gid

                verboseflag = 0
                model_cell = cells.make_cell(templateClass, gid=gid, local_id=i, dataset_path=datasetPath)
                env.gidlist.append(gid)
                env.cells.append(model_cell)
                env.pc.set_gid2node(gid, int(env.pc.id()))
                ## Tell the ParallelContext that this cell is a spike source
                ## for all other hosts. NetCon is temporary.
                nc = model_cell.connect2target(h.nil)
                env.pc.cell(gid, nc, 1)
                ## Record spikes of this cell
                env.pc.spike_record(gid, env.t_vec, env.id_vec)
                i = i + 1
                h.numCells = h.numCells + 1