예제 #1
0
def main(config_path, spike_events_path, spike_events_namespace,
         spike_train_attr_name, populations, include_artificial, t_max, t_min,
         trajectory_path, arena_id, trajectory_id, bin_size, min_pf_width,
         min_pf_rate, font_size, output_file_path, plot_dir_path, save_fig,
         fig_format, progress, verbose):
    """Configure logging and delegate to ``plot.plot_place_fields``.

    :param config_path: forwarded as ``config_path``
    :param spike_events_path: spike events file (first positional argument)
    :param spike_events_namespace: namespace of the spike events
    :param spike_train_attr_name: forwarded as ``spike_train_attr_name``
    :param populations: forwarded as ``populations``
    :param include_artificial: forwarded as ``include_artificial``
    :param t_max: upper end of the time window
    :param t_min: lower end of the time window
    :param trajectory_path: trajectory file
    :param arena_id: arena identifier
    :param trajectory_id: trajectory identifier
    :param bin_size: forwarded as ``bin_size``
    :param min_pf_width: forwarded as ``min_pf_width``
    :param min_pf_rate: forwarded as ``min_pf_rate``
    :param font_size: forwarded as ``fontSize``
    :param output_file_path: forwarded as ``output_file_path``
    :param plot_dir_path: forwarded as ``plot_dir_path``
    :param save_fig: forwarded as ``saveFig``
    :param fig_format: forwarded as ``figFormat``
    :param progress: forwarded as ``progress``
    :param verbose: logging verbosity, also forwarded as ``verbose``
    """
    utils.config_logging(verbose)

    # The time window is always passed as a two-element list, even when one
    # or both bounds are None.
    plot_kwargs = {
        'config_path': config_path,
        'populations': populations,
        'include_artificial': include_artificial,
        'bin_size': bin_size,
        'min_pf_width': min_pf_width,
        'min_pf_rate': min_pf_rate,
        'spike_train_attr_name': spike_train_attr_name,
        'time_range': [t_min, t_max],
        'fontSize': font_size,
        'output_file_path': output_file_path,
        'plot_dir_path': plot_dir_path,
        'progress': progress,
        'saveFig': save_fig,
        'figFormat': fig_format,
        'verbose': verbose,
    }
    plot.plot_place_fields(spike_events_path, spike_events_namespace,
                           trajectory_path, arena_id, trajectory_id,
                           **plot_kwargs)
예제 #2
0
def config_worker():
    """Initialise per-worker state on the shared ``context`` object.

    Configures logging and a script logger. It assigns a timestamped
    ``results_id`` when the context does not already have one. The first
    time no ``env`` is present, it calls ``init_network()`` and then sets
    ``context.bin_size`` to 5.0. Exceptions from ``init_network`` are logged
    and re-raised.
    """
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))

    if 'results_id' not in context():
        worker_id = context.interface.worker_id
        stamp = datetime.datetime.today().strftime('%Y%m%d_%H%M')
        context.results_id = 'DG_test_network_subworlds_%s_%s' % (worker_id, stamp)

    if 'env' in context():
        # Network already initialised for this worker.
        return

    try:
        init_network()
    except Exception as err:
        context.logger.exception(err)
        raise err
    context.bin_size = 5.0
예제 #3
0
def main(config_path, input_path, t_max, t_min, window_size, overlap,
         frequency_range, font_size, verbose):
    """Plot an LFP spectrogram through ``plot.plot_lfp_spectrogram``.

    The time window is ``None`` when ``t_max`` is not given. Otherwise it is
    ``[t_min, t_max]``, with ``t_min`` defaulting to 0.0. The figure is
    always saved (``saveFig=True``).
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]

    plot.plot_lfp_spectrogram(config_path,
                              input_path,
                              time_range=time_range,
                              window_size=window_size,
                              overlap=overlap,
                              frequency_range=frequency_range,
                              fontSize=font_size,
                              saveFig=True)
예제 #4
0
def main(connectivity_path, coords_path, distances_namespace, destination,
         source, bin_size, font_size, fig_size, verbose):
    """Forward the arguments to ``plot.plot_vertex_distribution`` on MPI.COMM_WORLD.

    The figure is always saved (``saveFig=True``).
    """
    utils.config_logging(verbose)

    positional = (connectivity_path, coords_path, distances_namespace,
                  destination, source, bin_size)
    plot.plot_vertex_distribution(*positional,
                                  fontSize=font_size,
                                  saveFig=True,
                                  figSize=fig_size,
                                  comm=MPI.COMM_WORLD)
예제 #5
0
def config_controller():
    """Initialise controller state on the shared ``context`` object.

    ``context.init_params`` is the same dict object as ``context.kwargs``
    (not a copy). The target rate-map keys are therefore written into
    ``context.kwargs`` as well. Also sets ``gid`` (as int), an empty
    ``target_val`` dict and, if missing, a timestamped ``results_file_id``.
    """
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))

    params = context.kwargs
    context.init_params = params
    params['target_rate_map_arena'] = params['arena_id']
    params['target_rate_map_trajectory'] = params['trajectory_id']
    context.gid = int(params['gid'])
    context.target_val = {}

    if 'results_file_id' not in context():
        worker_id = context.interface.worker_id
        stamp = datetime.datetime.today().strftime('%Y%m%d_%H%M')
        context.results_file_id = 'DG_optimize_pf_%s_%s' % (worker_id, stamp)
예제 #6
0
def config_controller():
    """Initialise controller state on the shared ``context`` object.

    Configures logging and a script logger, and assigns a timestamped
    ``results_file_id`` if missing. When no ``env`` exists, it stores
    ``MPI.COMM_WORLD`` as ``context.comm``; environment initialisation itself
    is currently disabled. Any exception is logged and re-raised.
    """
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))

    if 'results_file_id' not in context():
        worker_id = context.interface.worker_id
        stamp = datetime.datetime.today().strftime('%Y%m%d_%H%M')
        context.results_file_id = 'DG_optimize_network_subworlds_%s_%s' % (worker_id, stamp)

    if 'env' in context():
        return

    try:
        context.comm = MPI.COMM_WORLD
        # NOTE: the init_env() call was disabled in the original source.
    except Exception as err:
        context.logger.exception(err)
        raise err
예제 #7
0
def main(forest_path, state_path, state_namespace, state_namespace_pattern,
         population, gid, t_variable, state_variable, t_max, t_min, font_size,
         colormap, query, verbose):
    """Plot intracellular state for one cell over its morphology.

    :param forest_path: morphology (forest) file
    :param state_path: file with recorded state
    :param state_namespace: explicit state namespace to plot (may be None)
    :param state_namespace_pattern: regex; matching namespaces returned by the
        state query are also plotted (may be None)
    :param population: population of the cell
    :param gid: cell gid
    :param t_variable: name of the time variable
    :param state_variable: name of the state variable
    :param t_max: end of the time window; None means no window
    :param t_min: start of the time window; defaults to 0.0 when t_max is given
    :param font_size: forwarded as ``fontSize``
    :param colormap: forwarded as ``colormap``
    :param query: if true, print the queried namespaces/attributes and exit
    :param verbose: logging verbosity
    """
    utils.config_logging(verbose)

    if t_max is None:
        time_range = None
    else:
        time_range = [0.0 if t_min is None else t_min, t_max]

    namespace_id_lst, attr_info_dict = statedata.query_state(
        state_path, [population], namespace_ids=[state_namespace])
    if query:
        for this_namespace_id in namespace_id_lst:
            print("Namespace: %s" % str(this_namespace_id))
            # The query was made for ``population``; the original indexed an
            # undefined ``pop_name`` here, which raised NameError.
            for attr_name, attr_cell_index in attr_info_dict[population][
                    this_namespace_id]:
                print("\tAttribute: %s" % str(attr_name))
                for i in attr_cell_index:
                    print("\t%d" % i)
        sys.exit()

    state_namespaces = []
    if state_namespace is not None:
        state_namespaces.append(state_namespace)

    if state_namespace_pattern is not None:
        state_namespaces.extend(
            namespace_id for namespace_id in namespace_id_lst
            if re.match(state_namespace_pattern, namespace_id))

    plot.plot_intracellular_state_in_tree(gid,
                                          population,
                                          forest_path,
                                          state_path,
                                          state_namespaces,
                                          time_range=time_range,
                                          time_variable=t_variable,
                                          state_variable=state_variable,
                                          fontSize=font_size,
                                          colormap=colormap,
                                          saveFig=True)
예제 #8
0
def main(config_file, population, gid, template_paths, dataset_prefix,
         config_prefix, load_synapses, syn_types, syn_sources,
         syn_source_threshold, font_size, bgcolor, colormap, verbose):
    """Build one biophysical cell and plot its tree with synapses.

    :param config_file: Env configuration file
    :param population: population of the cell
    :param gid: cell gid
    :param load_synapses: also used as ``load_edges`` for ``get_biophys_cell``
    :param syn_types: synapse types to show; empty or None means no filter
    :param syn_sources: synapse sources to show; empty or None means no filter
    :param syn_source_threshold: forwarded to ``plot_biophys_cell_tree``
    :param bgcolor: forwarded as ``bgcolor``
    :param colormap: forwarded as ``colormap``
    :param verbose: logging verbosity

    All local variables at the time ``params`` is captured (every argument,
    plus ``logger``) are passed to ``Env`` as keyword arguments.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)

    # Use the population's mechanism configuration file if one is configured.
    mech_file_path = env.celltypes[population].get('mech_file_path') \
        if 'mech_file_path' in env.celltypes[population] else None

    logger.info('loading cell %i', gid)

    load_weights = False
    biophys_cell = get_biophys_cell(env,
                                    population,
                                    gid,
                                    load_synapses=load_synapses,
                                    load_weights=load_weights,
                                    load_edges=load_synapses,
                                    mech_file_path=mech_file_path)

    # Empty (or missing) selections mean "no filter"; `not x` also copes with
    # None, where the original `len(x) == 0` raised TypeError.
    syn_types = list(syn_types) if syn_types else None
    syn_sources = list(syn_sources) if syn_sources else None

    plot.plot_biophys_cell_tree(env,
                                biophys_cell,
                                saveFig=True,
                                syn_source_threshold=syn_source_threshold,
                                synapse_filters={
                                    'syn_types': syn_types,
                                    'sources': syn_sources
                                },
                                bgcolor=bgcolor,
                                colormap=colormap)
예제 #9
0
def main(config_path, params_id, output_file_name, verbose):
    """Expand stored network parameter values into labelled parameter lists.

    Reads ``param_spec``, ``param_values`` and ``target_populations`` from the
    YAML evaluation config. For the selected parameter set (or every set when
    ``params_id`` is None), each value is paired with its parameter tuple.
    The result is pretty-printed and optionally written to YAML.

    :param config_path: evaluation config YAML file
    :param params_id: key in ``param_values``; None selects all keys
    :param output_file_name: YAML output file, or None to only print
    :param verbose: logging verbosity
    """
    config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    eval_config = read_from_yaml(config_path,
                                 include_loader=utils.IncludeLoader)
    network_param_spec_src = eval_config['param_spec']
    network_param_values = eval_config['param_values']
    target_populations = eval_config['target_populations']

    network_param_spec = make_param_spec(target_populations,
                                         network_param_spec_src)

    def from_param_list(x):
        """Pair each spec parameter tuple with the value at the same index of ``x``."""
        return [(param_tuple, x[i])
                for i, (_, param_tuple) in enumerate(
                    zip(network_param_spec.param_names,
                        network_param_spec.param_tuples))]

    if params_id is None:
        params_id_list = list(network_param_values.keys())
    else:
        params_id_list = [params_id]

    param_output_dict = dict()
    for this_params_id in params_id_list:
        param_tuple_values = from_param_list(network_param_values[this_params_id])
        param_output_dict[this_params_id] = [
            (param_tuple.population, param_tuple.source, param_tuple.sec_type,
             param_tuple.syn_name, param_tuple.param_path, param_value)
            for param_tuple, param_value in param_tuple_values
        ]

    pprint.pprint(param_output_dict)
    if output_file_name is not None:
        write_to_yaml(output_file_name, param_output_dict)
예제 #10
0
def main(config, coords_path, coords_namespace, distances_namespace,
         populations, verbose):
    """Read soma coordinates and distances for each population.

    Both attribute namespaces are broadcast from rank 0 to every rank. The
    results are collected per population in ``soma_coords`` and
    ``soma_distances``.

    :param config: Env configuration file
    :param coords_path: file holding the coordinate and distance namespaces
    :param coords_namespace: namespace with U/V/L coordinates
    :param distances_namespace: namespace with U/V distances
    :param populations: populations to read
    :param verbose: logging verbosity
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config)

    soma_coords = {}
    soma_distances = {}

    if rank == 0:
        logger.info('Reading population coordinates and distances...')

    for population in populations:

        coords = bcast_cell_attributes(coords_path,
                                       population,
                                       0,
                                       namespace=coords_namespace,
                                       comm=comm)
        soma_coords[population] = {
            k:
            (v['U Coordinate'][0], v['V Coordinate'][0], v['L Coordinate'][0])
            for (k, v) in coords
        }
        del coords
        gc.collect()

        distances = bcast_cell_attributes(coords_path,
                                          population,
                                          0,
                                          namespace=distances_namespace,
                                          comm=comm)
        # Key by population like soma_coords; the original rebound the whole
        # dict here, keeping only the last population's distances.
        soma_distances[population] = {
            k: (v['U Distance'][0], v['V Distance'][0])
            for (k, v) in distances
        }
        del distances
        gc.collect()
예제 #11
0
def main(spike_events_path, spike_events_namespace, populations,
         include_artificial, bin_size, t_variable, t_max, t_min, quantity,
         graph_type, font_size, overlay, unit, verbose):
    """Plot spike distributions either per cell or per time bin.

    ``unit == 'cell'`` calls ``plot.plot_spike_distribution_per_cell`` (with
    ``graph_type``). ``unit == 'time'`` calls
    ``plot.plot_spike_distribution_per_time`` (with ``bin_size`` as
    ``time_bin_size``). Any other value of ``unit`` plots nothing.
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]
    include = populations if populations else ['eachPop']

    shared = dict(include=include,
                  include_artificial=include_artificial,
                  time_variable=t_variable,
                  time_range=time_range,
                  quantity=quantity,
                  fontSize=font_size,
                  overlay=overlay,
                  saveFig=True)

    if unit == 'cell':
        plot.plot_spike_distribution_per_cell(spike_events_path,
                                              spike_events_namespace,
                                              graph_type=graph_type,
                                              **shared)
    elif unit == 'time':
        plot.plot_spike_distribution_per_time(spike_events_path,
                                              spike_events_namespace,
                                              time_bin_size=bin_size,
                                              **shared)
예제 #12
0
def main(config, config_prefix, spike_events_path, spike_events_namespace,
         spike_train_attr_name, populations, include_artificial, t_max, t_min,
         trajectory_path, arena_id, trajectory_id, bin_size, output_file_path,
         save_fig, save_fig_dir, font_size, fig_size, fig_format, verbose):
    """Configure logging and delegate to ``plot.plot_spatial_information``.

    :param config: str (file name); accepted but not used in this function
    :param config_prefix: str (path to dir); accepted but not used here
    :param spike_events_path: str (path to file)
    :param spike_events_namespace: str
    :param spike_train_attr_name: str
    :param populations: list of str
    :param include_artificial: forwarded as ``include_artificial``
    :param t_max: float; upper bound of the ``[t_min, t_max]`` time window
    :param t_min: float; lower bound of the ``[t_min, t_max]`` time window
    :param trajectory_path: str (path to file)
    :param arena_id: str
    :param trajectory_id: str
    :param bin_size: forwarded as ``position_bin_size``
    :param output_file_path: str (path to file)
    :param save_fig: str (base file name), forwarded as ``saveFig``
    :param save_fig_dir: str (path to dir); accepted but not used here
    :param font_size: float, forwarded as ``fontSize``
    :param fig_size: forwarded as ``figSize``
    :param fig_format: str, forwarded as ``figFormat``
    :param verbose: bool
    """
    utils.config_logging(verbose)

    positional = (spike_events_path, spike_events_namespace,
                  trajectory_path, arena_id, trajectory_id)
    plot.plot_spatial_information(*positional,
                                  populations=populations,
                                  include_artificial=include_artificial,
                                  position_bin_size=bin_size,
                                  spike_train_attr_name=spike_train_attr_name,
                                  time_range=[t_min, t_max],
                                  fontSize=font_size,
                                  verbose=verbose,
                                  output_file_path=output_file_path,
                                  saveFig=save_fig,
                                  figFormat=fig_format,
                                  figSize=fig_size)
예제 #13
0
def main(spike_events_path, spike_events_namespace, coords_path,
         distances_namespace, populations, max_spikes, t_variable, t_max,
         t_min, t_step, font_size, save_fig, verbose):
    """Plot a spatial spike raster via ``plot.plot_spatial_spike_raster``.

    The time window is ``None`` when ``t_max`` is not given. Otherwise it is
    ``[t_min, t_max]``, with ``t_min`` defaulting to 0.0. An empty
    ``populations`` selection is replaced by ``['eachPop']``.
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]
    include = populations if populations else ['eachPop']

    plot.plot_spatial_spike_raster(spike_events_path,
                                   spike_events_namespace,
                                   coords_path,
                                   distances_namespace,
                                   include=include,
                                   time_range=time_range,
                                   time_variable=t_variable,
                                   time_step=t_step,
                                   max_spikes=max_spikes,
                                   fontSize=font_size,
                                   saveFig=save_fig)
예제 #14
0
def main(coords_path, io_size, chunk_size, value_chunk_size, config=None,
         verbose=False):
    """Write each population's cells, sorted by U coordinate, as 'Sorted Coordinates'.

    For every population in ``coords_path``, the 'Interpolated Coordinates'
    attributes are broadcast from rank 0. Cells are sorted by their
    'U Coordinate'. The attributes are then re-keyed as
    ``population_start + rank_in_sort`` and appended under the
    'Sorted Coordinates' namespace.

    :param coords_path: neuroh5 file read and appended to
    :param io_size: number of I/O ranks; -1 means all ranks
    :param chunk_size: forwarded to ``append_cell_attributes``
    :param value_chunk_size: forwarded to ``append_cell_attributes``
    :param config: optional Env configuration file. The original body used
        this name without declaring it (NameError).
    :param verbose: logging verbosity. The original body used this name
        without declaring it (NameError).
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    # The Env object is not used below; it is only built when a config is given.
    if config is not None:
        Env(comm=comm, config_file=config)

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated', comm.size)

    # NOTE(review): this script calls read_population_ranges() and
    # bcast_cell_attributes() with different signatures than other scripts
    # in this collection; confirm against the installed neuroh5 API.
    source_population_ranges = read_population_ranges(coords_path)

    for population in list(source_population_ranges.keys()):
        if rank == 0:
            logger.info('population: %s', population)
        soma_coords = bcast_cell_attributes(0, coords_path, population,
                                            namespace='Interpolated Coordinates', comm=comm)
        gids = []
        u_coords = []
        for gid, attrs in viewitems(soma_coords):
            gids.append(gid)
            u_coords.append(attrs['U Coordinate'])
        u_coordv = np.asarray(u_coords, dtype=np.float32)
        gidv = np.asarray(gids, dtype=np.uint32)
        # Each 'U Coordinate' value is indexed as a length-1 row below,
        # so the sort indices are taken along axis 0.
        sort_idx = np.argsort(u_coordv, axis=0)
        offset = source_population_ranges[population][0]
        sorted_coords_dict = {
            offset + i: soma_coords[gidv[sort_idx[i][0]]]
            for i in range(sort_idx.size)
        }

        append_cell_attributes(coords_path, population, sorted_coords_dict,
                               namespace='Sorted Coordinates', io_size=io_size, chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size, comm=comm)
예제 #15
0
def main(config_path, spike_events_path, spike_events_namespace, populations,
         arena_id, trajectory_id, target_input_features_path,
         target_input_features_namespace, include_artificial, max_units,
         t_variable, t_max, t_min, threshold, bin_size, meansub, graph_type,
         progress, fig_size, font_size, save_format, verbose):
    """Plot spike rates alongside target input features.

    The time window is ``None`` when ``t_max`` is not given. Otherwise it is
    ``[t_min, t_max]``, with ``t_min`` defaulting to 0.0. An empty
    ``populations`` selection becomes ``['eachPop']``. The figure is always
    saved, using ``save_format`` as the figure format.
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]
    include = populations if populations else ['eachPop']

    feature_options = dict(arena_id=arena_id,
                           trajectory_id=trajectory_id,
                           target_input_features_path=target_input_features_path,
                           target_input_features_namespace=target_input_features_namespace)
    figure_options = dict(fontSize=font_size,
                          figSize=fig_size,
                          saveFig=True,
                          figFormat=save_format)

    plot.plot_spike_rates_with_features(spike_events_path,
                                        spike_events_namespace,
                                        config_path=config_path,
                                        include=include,
                                        max_units=max_units,
                                        time_range=time_range,
                                        time_variable=t_variable,
                                        threshold=threshold,
                                        meansub=meansub,
                                        bin_size=bin_size,
                                        graph_type=graph_type,
                                        progress=progress,
                                        include_artificial=include_artificial,
                                        **feature_options,
                                        **figure_options)
예제 #16
0
def main(config_path, input_path, t_max, t_min, psd, window_size, overlap,
         frequency_range, bandpass_filter, font_size, verbose):
    """Plot LFP traces (optionally with a PSD) through ``plot.plot_lfp``.

    :param t_max: end of the time window; None means no window
    :param t_min: start of the time window; defaults to 0.0 when t_max is given
    :param psd: forwarded as ``compute_psd``
    :param bandpass_filter: filter band; None, empty, or a first element of
        None disables filtering
    :param font_size: forwarded as ``fontSize``
    :param verbose: logging verbosity
    """
    utils.config_logging(verbose)

    if t_max is None:
        time_range = None
    else:
        time_range = [0.0 if t_min is None else t_min, t_max]

    # Guard against a missing/empty option: indexing it directly raised
    # TypeError/IndexError in the original.
    if not bandpass_filter or bandpass_filter[0] is None:
        bandpass_filter = None

    plot.plot_lfp(config_path, input_path, time_range=time_range,
                  compute_psd=psd, window_size=window_size,
                  overlap=overlap, frequency_range=frequency_range,
                  bandpass_filter=bandpass_filter,
                  fontSize=font_size, saveFig=True)
예제 #17
0
def main(config, config_prefix, connectivity_path, coords_path,
         vertex_metrics_namespace, distances_namespace, destination, sources,
         normed, metric, graph_type, bin_size, font_size, verbose):
    """Build an Env from the config and plot vertex metrics with ``plot.plot_vertex_metrics``.

    The figure is always saved (``saveFig=True``).
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    env = Env(config_file=config, config_prefix=config_prefix)

    data_args = (connectivity_path, coords_path, vertex_metrics_namespace,
                 distances_namespace, destination, sources)
    plot.plot_vertex_metrics(env,
                             *data_args,
                             metric=metric,
                             normed=normed,
                             bin_size=bin_size,
                             fontSize=font_size,
                             graph_type=graph_type,
                             saveFig=True)
예제 #18
0
def main(connectivity_path, coords_path, distances_namespace, destination,
         source, bin_size, cache_size, verbose):
    """Compute vertex distributions and store them in the connectivity file.

    ``graph.vertex_distribution`` runs on all ranks. Rank 0 prints the result
    and writes, for each of 'Total distance', 'U distance' and 'V distance',
    one float32 dataset per (destination, source) pair under
    ``Vertex Distribution/<distance type>/<destination>/<source>``.
    All ranks synchronise on a barrier at the end.
    """
    utils.config_logging(verbose)
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    vertex_distribution_dict = graph.vertex_distribution(connectivity_path,
                                                         coords_path,
                                                         distances_namespace,
                                                         destination,
                                                         source,
                                                         bin_size,
                                                         cache_size,
                                                         comm=comm)

    if rank == 0:
        print(vertex_distribution_dict)
        # The context manager closes the file even if a write fails.
        with h5py.File(connectivity_path, 'r+') as f:
            for distance_type in ('Total distance', 'U distance', 'V distance'):
                for dst, src_dict in utils.viewitems(
                        vertex_distribution_dict[distance_type]):
                    grp = f.create_group('Vertex Distribution/%s/%s' %
                                         (distance_type, dst))
                    for src, bins in utils.viewitems(src_dict):
                        grp[src] = np.asarray(bins, dtype=np.float32)

    comm.Barrier()
예제 #19
0
def main(config_path, spike_events_path, spike_events_namespace, coords_path,
         coords_namespace, forest_path, populations, compute_rates, plot_trees,
         t_variable, t_max, t_min, t_step, font_size, rotate_anim,
         marker_scale, save_fig, verbose):
    """Plot spikes in 3D volume via ``plot.plot_spikes_in_volume``.

    The time window is ``None`` when ``t_max`` is not given. Otherwise it is
    ``[t_min, t_max]``, with ``t_min`` defaulting to 0.0. An empty
    ``populations`` selection is replaced by ``['eachPop']``.
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]
    selected = populations if populations else ['eachPop']

    plot.plot_spikes_in_volume(config_path,
                               selected,
                               coords_path,
                               coords_namespace,
                               spike_events_path,
                               spike_events_namespace,
                               forest_path=forest_path,
                               plot_trees=plot_trees,
                               time_range=time_range,
                               time_variable=t_variable,
                               time_step=t_step,
                               compute_rates=compute_rates,
                               fontSize=font_size,
                               marker_scale=marker_scale,
                               rotate_anim=rotate_anim,
                               saveFig=save_fig)
예제 #20
0
def main(config, features_path, features_namespace, arena_id, include,
         font_size, verbose, save_fig):
    """Plot the stimulus rate map of each selected population.

    :param config: Env configuration file
    :param features_path: file with stimulus features
    :param features_namespace: namespace of the features
    :param arena_id: forwarded as ``arena_id``
    :param include: populations to plot, one call per population
    :param font_size: forwarded as ``fontSize``
    :param verbose: logging verbosity
    :param save_fig: forwarded as ``saveFig``
    """
    utils.config_logging(verbose)

    env = Env(config_file=config)

    options = dict(arena_id=arena_id, fontSize=font_size, saveFig=save_fig)
    for pop_name in include:
        plot.plot_stimulus_ratemap(env, features_path, features_namespace,
                                   pop_name, **options)
예제 #21
0
def main(input_path, spike_namespace, state_namespace, populations, gid, n_trials, spike_hist_bin, 
         lowpass_plot_type, state_variable, t_variable, t_max, t_min, font_size, line_width, verbose):
    """Plot network-clamp results through ``plot.plot_network_clamp``.

    :param input_path: results file
    :param spike_namespace: namespace of the spike data
    :param state_namespace: namespace of the intracellular state data
    :param populations: populations to include; empty means ``['eachPop']``
    :param gid: cell gid
    :param n_trials: forwarded as ``n_trials``
    :param spike_hist_bin: forwarded as ``spike_hist_bin``
    :param lowpass_plot_type: forwarded as ``lowpass_plot_type``
    :param state_variable: forwarded as ``intracellular_variable``
    :param t_variable: forwarded as ``time_variable``
    :param t_max: end of the time window; None means no window
    :param t_min: start of the time window; defaults to 0.0 when t_max is given
    :param font_size: forwarded as ``fontSize``
    :param line_width: forwarded as ``lw``
    :param verbose: bool

    When the module-level ``is_interactive`` flag is set, all local
    variables are copied into ``context`` at the end. No extra locals are
    introduced, so that snapshot matches the arguments plus ``time_range``.
    """
    utils.config_logging(verbose)

    time_range = None if t_max is None else [0.0 if t_min is None else t_min, t_max]
    populations = populations or ['eachPop']

    plot.plot_network_clamp(input_path,
                            spike_namespace,
                            state_namespace,
                            gid=gid,
                            include=populations,
                            time_range=time_range,
                            time_variable=t_variable,
                            intracellular_variable=state_variable,
                            spike_hist_bin=spike_hist_bin,
                            lowpass_plot_type=lowpass_plot_type,
                            n_trials=n_trials,
                            fontSize=font_size,
                            saveFig=True,
                            lw=line_width)

    if is_interactive:
        context.update(locals())
def main(config, config_prefix, features_path, coords_path, features_namespace,
         arena_id, trajectory_id, distances_namespace, include, bin_size,
         from_spikes, normed, font_size, verbose, save_fig):
    """Build an Env and plot stimulus spatial rate maps.

    All data arguments are forwarded positionally, in the order expected by
    ``plot.plot_stimulus_spatial_rate_map``. The options are passed as
    keywords.
    """
    utils.config_logging(verbose)

    logger = utils.get_script_logger(os.path.basename(script_name))

    env = Env(config_file=config, config_prefix=config_prefix)

    data_args = (features_path, coords_path, arena_id, trajectory_id,
                 features_namespace, distances_namespace, include)
    plot.plot_stimulus_spatial_rate_map(env,
                                        *data_args,
                                        bin_size=bin_size,
                                        from_spikes=from_spikes,
                                        normed=normed,
                                        fontSize=font_size,
                                        saveFig=save_fig,
                                        verbose=verbose)
예제 #23
0
def _dendrite_layer_measures(cell, morph_dict):
    """Sum apical/basal dendritic length and membrane area per layer for one cell.

    Returns ``(area, length)``, two float32 arrays indexed by layers 1..4 in
    ascending order.
    """
    secnodes_dict = morph_dict['section_topology']['nodes']
    dendrite_idxs = set(cell.apicalidx) | set(cell.basalidx)

    # Layers 1..4 are pre-allocated; a node in any other layer raises KeyError
    # (same as before).
    area_by_layer = {k: 0.0 for k in range(1, 5)}
    length_by_layer = {k: 0.0 for k in range(1, 5)}

    for i, sec in enumerate(cell.sections):
        if i not in dendrite_idxs:
            continue
        secnodes = secnodes_dict[i]
        prev_layer = None
        for seg in sec.allseg():
            layer = cells.get_node_attribute('layer', morph_dict, seg.sec,
                                             secnodes, seg.x)
            # Nodes without a positive layer inherit the previous node's
            # layer, or layer 1 at the start of the section.
            layer = layer if layer > 0 else (
                prev_layer if prev_layer is not None else 1)
            prev_layer = layer
            # allseg() also yields the zero-area end nodes at x=0 and x=1.
            # Counting L/nseg for them over-counted each section's length by
            # 2*L/nseg, so only interior segments contribute.
            if 0.0 < seg.x < 1.0:
                length_by_layer[layer] += seg.sec.L / seg.sec.nseg
                # sec= is required: h.area(x) alone measures the currently
                # accessed section, not the section being iterated.
                area_by_layer[layer] += h.area(seg.x, sec=seg.sec)

    area = np.asarray([area_by_layer[k] for k in sorted(area_by_layer)],
                      dtype=np.float32)
    length = np.asarray([length_by_layer[k] for k in sorted(length_by_layer)],
                        dtype=np.float32)
    return area, length


def main(config, template_path, output_path, forest_path, populations, io_size,
         chunk_size, value_chunk_size, cache_size, verbose):
    """Compute per-layer dendritic area and length and store them as 'Tree Measurements'.

    :param config: Env configuration file
    :param template_path: cell template search path(s)
    :param output_path: output file; defaults to ``forest_path``. If it does
        not exist, rank 0 creates it with the '/H5Types' group copied from
        ``forest_path``.
    :param forest_path: morphology (forest) file
    :param populations: populations to measure
    :param io_size: number of I/O ranks; -1 means all ranks
    :param chunk_size: forwarded to ``append_cell_attributes``
    :param value_chunk_size: forwarded to ``append_cell_attributes``
    :param cache_size: forwarded to ``append_cell_attributes``
    :param verbose: logging verbosity

    Calls ``MPI.Finalize()`` before returning.
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=MPI.COMM_WORLD,
              config_file=config,
              template_paths=template_path)
    h('objref nil, pc, templatePaths')
    h.load_file("nrngui.hoc")
    h.load_file("./templates/Value.hoc")
    h.xopen("./lib.hoc")
    h.pc = h.ParallelContext()

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated', comm.size)

    h.templatePaths = h.List()
    for path in env.templatePaths:
        h.templatePaths.append(h.Value(1, path))

    if output_path is None:
        output_path = forest_path

    if rank == 0 and not os.path.isfile(output_path):
        with h5py.File(forest_path, 'r') as input_file, \
                h5py.File(output_path, 'w') as output_file:
            input_file.copy('/H5Types', output_file)
    comm.barrier()

    for population in populations:
        logger.info('Rank %i population: %s', rank, population)
        template_name = env.celltypes[population]['template']
        h.find_template(h.pc, h.templatePaths, template_name)
        # getattr avoids eval() on a name taken from the configuration.
        template_class = getattr(h, template_name)
        measures_dict = {}
        for gid, morph_dict in NeuroH5TreeGen(forest_path,
                                              population,
                                              io_size=io_size,
                                              comm=comm,
                                              topology=True):
            if gid is None:
                logger.info('Rank %i gid is None', rank)
                continue
            logger.info('Rank %i gid: %i', rank, gid)
            cell = cells.make_neurotree_cell(template_class,
                                             neurotree_dict=morph_dict,
                                             gid=gid)
            dendrite_area, dendrite_length = _dendrite_layer_measures(cell, morph_dict)
            measures_dict[gid] = {'dendrite_area': dendrite_area,
                                  'dendrite_length': dendrite_length}
            del cell
        append_cell_attributes(output_path,
                               population,
                               measures_dict,
                               namespace='Tree Measurements',
                               comm=comm,
                               io_size=io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size,
                               cache_size=cache_size)
    MPI.Finalize()
예제 #24
0
def main(config_file, config_prefix, erev, population, presyn_name, gid,
         load_weights, measurements, template_paths, dataset_prefix,
         results_path, results_file_id, results_namespace_id, syn_mech_name,
         syn_weight, syn_count, syn_layer, swc_type, stim_amp, v_init, dt,
         use_cvode, verbose):
    """
    Run the requested single-cell clamp measurements for one gid and, when
    ``results_path`` is given, append them to the results file.

    All arguments (plus ``comm``) are forwarded to ``Env`` via ``locals()``,
    so no extra local variable may be introduced before that call.

    :param population: str (population name of the cell)
    :param gid: int (cell to measure)
    :param measurements: comma-separated str of measurement names
        ('passive', 'ap', 'ap_rate', 'fi', 'gap', 'psp'), or None for none
    :param presyn_name: str (required for 'psp')
    :param syn_mech_name: str (required for 'psp')
    :param erev: float (required for 'psp')
    :param syn_weight: float (required for 'psp')
    :param results_path: str or None (results are only written if not None)
    :param results_file_id: results file id; a fresh uuid4 is used if None
    :param results_namespace_id: str; defaults to 'Cell Clamp Results'
    :param stim_amp: stimulus amplitude used by the 'ap_rate' measurement
    :param v_init: float (initial membrane potential)
    :param verbose: bool
    """

    config_logging(verbose)

    if results_file_id is None:
        results_file_id = uuid.uuid4()
    if results_namespace_id is None:
        results_namespace_id = 'Cell Clamp Results'
    comm = MPI.COMM_WORLD
    np.seterr(all='raise')
    # NOTE: this snapshot of locals() (including ``comm``) is the Env
    # configuration; keep new locals below this line.
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)
    io_utils.mkout(env, env.results_file_path)
    env.cell_selection = {}

    if measurements is not None:
        measurements = [x.strip() for x in measurements.split(",")]
    else:
        # No measurements requested: an empty list keeps the membership
        # tests below valid (``'x' in None`` would raise TypeError).
        measurements = []

    attr_dict = {}
    attr_dict[gid] = {}
    if 'passive' in measurements:
        attr_dict[gid].update(measure_passive(gid, population, v_init, env))
    if 'ap' in measurements:
        attr_dict[gid].update(measure_ap(gid, population, v_init, env))
    if 'ap_rate' in measurements:
        logger.info('ap_rate')
        attr_dict[gid].update(
            measure_ap_rate(gid, population, v_init, env, stim_amp=stim_amp))
    if 'fi' in measurements:
        attr_dict[gid].update(measure_fi(gid, population, v_init, env))
    if 'gap' in measurements:
        # The gap-junction result is not stored in attr_dict.
        measure_gap_junction_coupling(gid, population, v_init, env)
    if 'psp' in measurements:
        assert (presyn_name is not None)
        assert (syn_mech_name is not None)
        assert (erev is not None)
        assert (syn_weight is not None)
        attr_dict[gid].update(
            measure_psp(gid,
                        population,
                        presyn_name,
                        syn_mech_name,
                        swc_type,
                        env,
                        v_init,
                        erev,
                        syn_layer=syn_layer,
                        syn_count=syn_count,
                        weight=syn_weight,
                        load_weights=load_weights))

    if results_path is not None:
        append_cell_attributes(env.results_file_path,
                               population,
                               attr_dict,
                               namespace=env.results_namespace_id,
                               comm=env.comm,
                               io_size=env.io_size)
def main(config, coordinates, gid, field_width, peak_rate, input_features_path,
         input_features_namespaces, output_features_namespace,
         output_weights_path, output_features_path, initial_weights_path,
         reference_weights_path, h5types_path, synapse_name,
         initial_weights_namespace, reference_weights_namespace,
         output_weights_namespace, reference_weights_are_delta,
         connections_path, optimize_method, destination, sources, arena_id,
         max_delta_weight, field_width_scale, max_iter, verbose, dry_run,
         plot):
    """
    Generate structured synaptic weights for a single target cell.

    Builds a target rate map from the given place-field coordinates, reads the
    target cell's connections, initial/reference weights and the rate maps of
    its presynaptic inputs, calls ``synapses.generate_structured_weights``,
    and (unless ``dry_run``) appends the resulting weights, and optionally the
    cell's features, to HDF5 files.

    :param config: str (path to .yaml file)
    :param coordinates: sequence of (x, y) float pairs, one per place field
    :param gid: int (required; id of the single target cell)
    :param field_width: float
    :param peak_rate: float
    :param input_features_path: str (path to .h5 file)
    :param input_features_namespaces: iterable of str (arena_id is appended)
    :param output_features_namespace: str
    :param output_weights_path: str (path to .h5 file; required unless dry_run)
    :param output_features_path: str (path to .h5 file) or None
    :param initial_weights_path: str (path to .h5 file) or None
    :param reference_weights_path: str (path to .h5 file) or None
    :param h5types_path: str (path to .h5 file)
    :param synapse_name: str
    :param initial_weights_namespace: str
    :param reference_weights_namespace: str
    :param output_weights_namespace: str
    :param reference_weights_are_delta: bool
    :param connections_path: str (path to .h5 file)
    :param optimize_method: str (passed to synapses.generate_structured_weights)
    :param destination: str (population name)
    :param sources: list of str (population name)
    :param arena_id: str
    :param max_delta_weight: float
    :param field_width_scale: float
    :param max_iter: int (not used by this function)
    :param verbose: bool
    :param dry_run: bool (if True, nothing is appended to the output files)
    :param plot: bool
    :raises RuntimeError: if gid is None, if output_weights_path is missing
        when not dry_run, or if no H5Types source file is available
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    # This function operates on exactly one target cell.
    if gid is None:
        raise RuntimeError('Missing required argument: gid.')
    target_gid = gid

    env = Env(config_file=config)

    if not dry_run:
        if output_weights_path is None:
            raise RuntimeError(
                'Missing required argument: output_weights_path.')
        if not os.path.isfile(output_weights_path):
            if initial_weights_path is not None and os.path.isfile(
                    initial_weights_path):
                input_file_path = initial_weights_path
            elif h5types_path is not None and os.path.isfile(h5types_path):
                input_file_path = h5types_path
            else:
                raise RuntimeError(
                    'Missing required source for h5types: either an initial_weights_path or an '
                    'h5types_path must be provided.')
            with h5py.File(output_weights_path, 'a') as output_file:
                with h5py.File(input_file_path, 'r') as input_file:
                    input_file.copy('/H5Types', output_file)

    this_input_features_namespaces = [
        '%s %s' % (input_features_namespace, arena_id)
        for input_features_namespace in input_features_namespaces
    ]
    features_attr_names = ['Arena Rate Map']
    spatial_resolution = env.stimulus_config['Spatial Resolution']  # cm
    arena = env.stimulus_config['Arena'][arena_id]
    default_run_vel = arena.properties['default run velocity']  # cm/s
    arena_x, arena_y = stimulus.get_2D_arena_spatial_mesh(
        arena, spatial_resolution)
    dim_x = len(arena_x)
    dim_y = len(arena_y)

    dst_input_features = defaultdict(dict)
    num_fields = len(coordinates)
    this_field_width = np.array([field_width] * num_fields, dtype=np.float32)
    this_scaled_field_width = np.array([field_width * field_width_scale] *
                                       num_fields,
                                       dtype=np.float32)
    this_peak_rate = np.array([peak_rate] * num_fields, dtype=np.float32)
    this_x0 = np.array([x for x, y in coordinates], dtype=np.float32)
    this_y0 = np.array([y for x, y in coordinates], dtype=np.float32)
    this_rate_map = np.asarray(get_rate_map(this_x0, this_y0, this_field_width,
                                            this_peak_rate, arena_x, arena_y),
                               dtype=np.float32)
    # The optimization target uses field widths scaled by field_width_scale.
    target_map = np.asarray(get_rate_map(this_x0, this_y0,
                                         this_scaled_field_width,
                                         this_peak_rate, arena_x, arena_y),
                            dtype=np.float32)
    selectivity_type = env.selectivity_types['place']
    dst_input_features[destination][target_gid] = {
        'Selectivity Type': np.array([selectivity_type], dtype=np.uint8),
        'Num Fields': np.array([num_fields], dtype=np.uint8),
        'Field Width': this_field_width,
        'Peak Rate': this_peak_rate,
        'X Offset': this_x0,
        'Y Offset': this_y0,
        'Arena Rate Map': this_rate_map.ravel()
    }

    initial_weights_by_syn_id_dict = dict()
    selection = [target_gid]
    if initial_weights_path is not None:
        initial_weights_iter = \
            read_cell_attribute_selection(initial_weights_path, destination, namespace=initial_weights_namespace,
                                          selection=selection)
        syn_weight_attr_dict = dict(initial_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            initial_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read initial synaptic weights for %i synapses'
            % (destination, target_gid, len(initial_weights_by_syn_id_dict)))

    reference_weights_by_syn_id_dict = None
    if reference_weights_path is not None:
        reference_weights_by_syn_id_dict = dict()
        reference_weights_iter = \
            read_cell_attribute_selection(reference_weights_path, destination, namespace=reference_weights_namespace,
                                          selection=selection)
        syn_weight_attr_dict = dict(reference_weights_iter)

        syn_ids = syn_weight_attr_dict[target_gid]['syn_id']
        weights = syn_weight_attr_dict[target_gid][synapse_name]

        for (syn_id, weight) in zip(syn_ids, weights):
            reference_weights_by_syn_id_dict[int(syn_id)] = float(weight)

        logger.info(
            'destination: %s; gid %i; read reference synaptic weights for %i synapses'
            % (destination, target_gid, len(reference_weights_by_syn_id_dict)))

    source_gid_set_dict = defaultdict(set)
    syn_ids_by_source_gid_dict = defaultdict(list)
    initial_weights_by_source_gid_dict = dict()
    if reference_weights_by_syn_id_dict is None:
        reference_weights_by_source_gid_dict = None
    else:
        reference_weights_by_source_gid_dict = dict()
    (graph, edge_attr_info) = read_graph_selection(file_name=connections_path,
                                                   selection=[target_gid],
                                                   namespaces=['Synapses'])
    syn_id_attr_index = None
    for source, edge_iter in viewitems(graph[destination]):
        if source not in sources:
            continue
        this_edge_attr_info = edge_attr_info[destination][source]
        if 'Synapses' in this_edge_attr_info and \
           'syn_id' in this_edge_attr_info['Synapses']:
            syn_id_attr_index = this_edge_attr_info['Synapses']['syn_id']
        for (destination_gid, edges) in edge_iter:
            assert destination_gid == target_gid
            source_gids, edge_attrs = edges
            syn_ids = edge_attrs['Synapses'][syn_id_attr_index]
            count = 0
            for i in range(len(source_gids)):
                this_source_gid = int(source_gids[i])
                source_gid_set_dict[source].add(this_source_gid)
                this_syn_id = int(syn_ids[i])
                # Synapses without an explicit initial weight get the
                # connection config's default weight for this mechanism.
                if this_syn_id not in initial_weights_by_syn_id_dict:
                    this_weight = \
                        env.connection_config[destination][source].mechanisms['default'][synapse_name]['weight']
                    initial_weights_by_syn_id_dict[this_syn_id] = this_weight
                syn_ids_by_source_gid_dict[this_source_gid].append(this_syn_id)
                # Per-source weights are taken from the first synapse seen
                # for that source gid.
                if this_source_gid not in initial_weights_by_source_gid_dict:
                    initial_weights_by_source_gid_dict[this_source_gid] = \
                        initial_weights_by_syn_id_dict[this_syn_id]
                    if reference_weights_by_source_gid_dict is not None:
                        reference_weights_by_source_gid_dict[this_source_gid] = \
                            reference_weights_by_syn_id_dict[this_syn_id]
                count += 1
            logger.info(
                'destination: %s; gid %i; set initial synaptic weights for %d inputs from source population '
                '%s' % (destination, destination_gid, count, source))

    syn_count_by_source_gid_dict = dict()
    for source_gid in syn_ids_by_source_gid_dict:
        syn_count_by_source_gid_dict[source_gid] = len(
            syn_ids_by_source_gid_dict[source_gid])

    input_rate_maps_by_source_gid_dict = dict()
    for source in sources:
        source_gids = list(source_gid_set_dict[source])
        for input_features_namespace in this_input_features_namespaces:
            input_features_iter = read_cell_attribute_selection(
                input_features_path,
                source,
                namespace=input_features_namespace,
                mask=set(features_attr_names),
                selection=source_gids)
            count = 0
            for source_gid, attr_dict in input_features_iter:
                input_rate_maps_by_source_gid_dict[source_gid] = attr_dict[
                    'Arena Rate Map'].reshape((dim_x, dim_y))
                count += 1
            logger.info('Read %s feature data for %i cells in population %s' %
                        (input_features_namespace, count, source))

    # NOTE(review): ``is_interactive`` and ``context`` are not defined in this
    # function; presumably module-level globals elsewhere in the file.
    if is_interactive:
        context.update(locals())

    normalized_delta_weights_dict, arena_LS_map = \
        synapses.generate_structured_weights(target_map=target_map,
                                             initial_weight_dict=initial_weights_by_source_gid_dict,
                                             input_rate_map_dict=input_rate_maps_by_source_gid_dict,
                                             syn_count_dict=syn_count_by_source_gid_dict,
                                             max_delta_weight=max_delta_weight, arena_x=arena_x, arena_y=arena_y,
                                             reference_weight_dict=reference_weights_by_source_gid_dict,
                                             reference_weights_are_delta=reference_weights_are_delta,
                                             reference_weights_namespace=reference_weights_namespace,
                                             optimize_method=optimize_method, verbose=verbose, plot=plot)

    # Emit one entry per synapse that actually received a weight.
    # Preallocating with len(initial_weights_by_syn_id_dict) could leave
    # uninitialized trailing entries when that dict also holds synapses from
    # sources that are not being optimized.
    output_syn_id_list = []
    output_weight_list = []
    for source_gid, this_weight in viewitems(normalized_delta_weights_dict):
        for syn_id in syn_ids_by_source_gid_dict[source_gid]:
            output_syn_id_list.append(syn_id)
            output_weight_list.append(this_weight)
    output_syn_ids = np.asarray(output_syn_id_list, dtype='uint32')
    output_weights = np.asarray(output_weight_list, dtype='float32')
    output_weights_dict = {
        target_gid: {
            'syn_id': output_syn_ids,
            synapse_name: output_weights
        }
    }

    logger.info('destination: %s; gid %i; generated %s for %i synapses' %
                (destination, target_gid, output_weights_namespace,
                 len(output_weights)))

    if not dry_run:
        this_output_weights_namespace = '%s %s' % (output_weights_namespace,
                                                   arena_id)
        logger.info('Destination: %s; appending %s ...' %
                    (destination, this_output_weights_namespace))
        append_cell_attributes(output_weights_path,
                               destination,
                               output_weights_dict,
                               namespace=this_output_weights_namespace)
        logger.info('Destination: %s; appended %s' %
                    (destination, this_output_weights_namespace))
        output_weights_dict.clear()
        if output_features_path is not None:
            this_output_features_namespace = '%s %s' % (
                output_features_namespace, arena_id)
            cell_attr_dict = dst_input_features[destination]
            cell_attr_dict[target_gid]['Arena State Map'] = np.asarray(
                arena_LS_map.ravel(), dtype=np.float32)
            logger.info('Destination: %s; appending %s ...' %
                        (destination, this_output_features_namespace))
            append_cell_attributes(output_features_path,
                                   destination,
                                   cell_attr_dict,
                                   namespace=this_output_features_namespace)

    if is_interactive:
        context.update(locals())
예제 #26
0
def main(arena_id, bin_sample_count, config, config_prefix, dataset_prefix,
         distances_namespace, distance_bin_extent, input_features_path,
         input_features_namespaces, populations, spike_input_path,
         spike_input_namespace, spike_input_attr, output_path, io_size,
         trajectory_id, write_selection, verbose):
    """
    Select a sample of cells per U-distance bin for each population, write
    the selection to ``<output_path>/DG_slice.yaml`` (rank 0), and optionally
    write a cell/connection selection HDF5 file.

    :param arena_id: str or None (appended to feature namespaces when given)
    :param bin_sample_count: int (max number of gids sampled per distance bin)
    :param config: str (path to config file)
    :param config_prefix: str
    :param dataset_prefix: str
    :param distances_namespace: str (namespace holding 'U Distance'/'V Distance')
    :param distance_bin_extent: float (width of the U-distance bins)
    :param input_features_path: str or None; if given, only cells with
        'Num Fields' > 0 are considered
    :param input_features_namespaces: iterable of str
    :param populations: iterable of str; empty/None means all populations
    :param spike_input_path: str (passed to Env)
    :param spike_input_namespace: str (passed to Env)
    :param spike_input_attr: str (passed to Env)
    :param output_path: str (directory for outputs; Env results_path)
    :param io_size: int (-1 means use all ranks)
    :param trajectory_id: str (passed to Env)
    :param write_selection: bool (write the selection .h5 file)
    :param verbose: bool
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              dataset_prefix=dataset_prefix,
              results_path=output_path,
              spike_input_path=spike_input_path,
              spike_input_namespace=spike_input_namespace,
              spike_input_attr=spike_input_attr,
              arena_id=arena_id,
              trajectory_id=trajectory_id)

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    pop_ranges, pop_size = read_population_ranges(env.connectivity_file_path,
                                                  comm=comm)

    distance_U_dict = {}
    distance_V_dict = {}
    range_U_dict = {}
    range_V_dict = {}

    selection_dict = defaultdict(set)

    # Rank 0 gets its own communicator so it can read attributes alone.
    comm0 = env.comm.Split(2 if rank == 0 else 0, 0)

    # Fixed seed so the sampled selection is reproducible.
    local_random = np.random.RandomState()
    local_random.seed(1000)

    if not populations:
        populations = sorted(pop_ranges.keys())

    if rank == 0:
        for population in populations:
            distances = read_cell_attributes(env.data_file_path,
                                             population,
                                             namespace=distances_namespace,
                                             comm=comm0)

            soma_distances = {}
            if input_features_path is not None:
                num_fields_dict = {}
                for input_features_namespace in input_features_namespaces:
                    if arena_id is not None:
                        this_features_namespace = '%s %s' % (
                            input_features_namespace, arena_id)
                    else:
                        this_features_namespace = input_features_namespace
                    input_features_iter = read_cell_attributes(
                        input_features_path,
                        population,
                        namespace=this_features_namespace,
                        mask=set(['Num Fields']),
                        comm=comm0)
                    count = 0
                    for gid, attr_dict in input_features_iter:
                        num_fields_dict[gid] = attr_dict['Num Fields']
                        count += 1
                    logger.info(
                        'Read feature data from namespace %s for %i cells in population %s'
                        % (this_features_namespace, count, population))

                # Keep only cells that have at least one field.
                for (gid, v) in distances:
                    num_fields = num_fields_dict.get(gid, 0)
                    if num_fields > 0:
                        soma_distances[gid] = (v['U Distance'][0],
                                               v['V Distance'][0])
            else:
                for (gid, v) in distances:
                    soma_distances[gid] = (v['U Distance'][0],
                                           v['V Distance'][0])

            numitems = len(list(soma_distances.keys()))
            logger.info('read %s distances (%i elements)' %
                        (population, numitems))

            if numitems == 0:
                continue

            gid_array = np.asarray([gid for gid in soma_distances])
            distance_U_array = np.asarray(
                [soma_distances[gid][0] for gid in gid_array])
            distance_V_array = np.asarray(
                [soma_distances[gid][1] for gid in gid_array])

            U_min = np.min(distance_U_array)
            U_max = np.max(distance_U_array)
            V_min = np.min(distance_V_array)
            V_max = np.max(distance_V_array)

            range_U_dict[population] = (U_min, U_max)
            range_V_dict[population] = (V_min, V_max)

            distance_U = {
                gid: soma_distances[gid][0]
                for gid in soma_distances
            }
            distance_V = {
                gid: soma_distances[gid][1]
                for gid in soma_distances
            }

            distance_U_dict[population] = distance_U
            distance_V_dict[population] = distance_V

            distance_bins = np.arange(U_min, U_max, distance_bin_extent)
            distance_bin_array = np.digitize(distance_U_array, distance_bins)

            selection_set = set([])
            # np.digitize yields indices 0..len(distance_bins), hence the +1.
            for bin_index in range(len(distance_bins) + 1):
                bin_gids = gid_array[np.where(
                    distance_bin_array == bin_index)[0]]
                if len(bin_gids) > 0:
                    # Sampling without replacement cannot draw more gids
                    # than the bin holds (numpy raises ValueError).
                    sample_size = min(bin_sample_count, len(bin_gids))
                    selected_bin_gids = local_random.choice(
                        bin_gids, replace=False, size=sample_size)
                    for gid in selected_bin_gids:
                        selection_set.add(int(gid))
            selection_dict[population] = selection_set

        yaml_output_dict = {}
        for k, v in utils.viewitems(selection_dict):
            yaml_output_dict[k] = list(sorted(v))

        yaml_output_path = '%s/DG_slice.yaml' % output_path
        with open(yaml_output_path, 'w') as outfile:
            yaml.dump(yaml_output_dict, outfile)

        del (yaml_output_dict)

    env.comm.barrier()

    write_selection_file_path = None
    if write_selection:
        write_selection_file_path = "%s/%s_selection.h5" % (env.results_path,
                                                            env.modelName)

    if write_selection_file_path is not None:
        if rank == 0:
            io_utils.mkout(env, write_selection_file_path)
        env.comm.barrier()
        # Only rank 0 computed the selection; share it with all ranks.
        selection_dict = env.comm.bcast(dict(selection_dict), root=0)
        env.cell_selection = selection_dict
        io_utils.write_cell_selection(env,
                                      write_selection_file_path,
                                      populations=populations)
        input_selection = io_utils.write_connection_selection(
            env, write_selection_file_path, populations=populations)

        if env.spike_input_ns is not None:
            io_utils.write_input_cell_selection(env,
                                                input_selection,
                                                write_selection_file_path,
                                                populations=populations)
    env.comm.barrier()
    MPI.Finalize()
예제 #27
0
def main(config, template_path, output_path, forest_path, populations,
         distance_bin_size, io_size, chunk_size, value_chunk_size, cache_size,
         verbose):
    """
    Compute dendritic measurements for each cell of the given populations and
    append them to the 'Tree Measurements' namespace of ``output_path``.

    For every apical/basal segment the length, area, diameter and distance to
    the soma are collected. These are summarized as a histogram of segment
    counts and mean diameter over soma distance, plus the total dendritic
    area and length per layer.

    :param config: str (path to config file)
    :param template_path: template path(s) passed to Env
    :param output_path: str (path to .h5 output file); defaults to forest_path
    :param forest_path: str (path to .h5 forest file)
    :param populations: iterable of str (population names)
    :param distance_bin_size: float (width of the soma-distance histogram bins)
    :param io_size: int (-1 means use all ranks)
    :param chunk_size: int
    :param value_chunk_size: int
    :param cache_size: int
    :param verbose: bool
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=MPI.COMM_WORLD,
              config_file=config,
              template_paths=template_path)
    configure_hoc_env(env)

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if output_path is None:
        output_path = forest_path

    if rank == 0:
        if not os.path.isfile(output_path):
            # Seed the new output file with the forest's H5Types; the context
            # managers close both files even if the copy fails.
            with h5py.File(forest_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    layers = env.layers
    # Map layer index -> layer name for the layers that are measured.
    layer_idx_dict = {
        layers[layer_name]: layer_name
        for layer_name in ['GCL', 'IML', 'MML', 'OML', 'Hilus']
    }

    (pop_ranges, _) = read_population_ranges(forest_path, comm=comm)
    for population in populations:
        logger.info('Rank %i population: %s' % (rank, population))
        count = 0
        (population_start, _) = pop_ranges[population]
        template_class = load_cell_template(env,
                                            population,
                                            bcast_template=True)
        measures_dict = {}
        for gid, morph_dict in NeuroH5TreeGen(forest_path,
                                              population,
                                              io_size=io_size,
                                              comm=comm,
                                              topology=True):
            if gid is not None:
                logger.info('Rank %i gid: %i' % (rank, gid))
                cell = cells.make_neurotree_cell(template_class,
                                                 neurotree_dict=morph_dict,
                                                 gid=gid)
                secnodes_dict = morph_dict['section_topology']['nodes']

                apicalidx = set(cell.apicalidx)
                basalidx = set(cell.basalidx)
                # Reference section for distances; loop-invariant per cell.
                soma_sec = list(cell.soma)[0]

                dendrite_area_dict = {k: 0.0 for k in layer_idx_dict}
                dendrite_length_dict = {k: 0.0 for k in layer_idx_dict}
                dendrite_distances = []
                dendrite_diams = []
                for (i, sec) in enumerate(cell.sections):
                    if (i in apicalidx) or (i in basalidx):
                        secnodes = secnodes_dict[i]
                        for seg in sec.allseg():
                            L = seg.sec.L
                            nseg = seg.sec.nseg
                            seg_l = L / nseg
                            seg_area = h.area(seg.x)
                            seg_diam = seg.diam
                            seg_distance = get_distance_to_node(
                                cell, soma_sec, seg.sec, seg.x)
                            dendrite_diams.append(seg_diam)
                            dendrite_distances.append(seg_distance)
                            layer = synapses.get_node_attribute(
                                'layer', morph_dict, seg.sec, secnodes, seg.x)
                            dendrite_length_dict[layer] += seg_l
                            dendrite_area_dict[layer] += seg_area

                dendrite_distance_array = np.asarray(dendrite_distances)
                dendrite_diam_array = np.asarray(dendrite_diams)
                if dendrite_distance_array.size > 0:
                    dendrite_distance_bin_range = int(
                        ((np.max(dendrite_distance_array)) -
                         np.min(dendrite_distance_array)) / distance_bin_size) + 1
                else:
                    # No apical/basal segments: use a single (empty) bin
                    # instead of failing on np.max of an empty array.
                    dendrite_distance_bin_range = 1
                dendrite_distance_counts, dendrite_distance_edges = np.histogram(
                    dendrite_distance_array,
                    bins=dendrite_distance_bin_range,
                    density=False)
                dendrite_diam_sums, _ = np.histogram(
                    dendrite_distance_array,
                    weights=dendrite_diam_array,
                    bins=dendrite_distance_bin_range,
                    density=False)
                # Mean diameter per bin; empty bins stay 0.
                dendrite_mean_diam_hist = np.zeros_like(dendrite_diam_sums)
                np.divide(dendrite_diam_sums,
                          dendrite_distance_counts,
                          where=dendrite_distance_counts > 0,
                          out=dendrite_mean_diam_hist)

                dendrite_area_per_layer = np.asarray([
                    dendrite_area_dict[k]
                    for k in sorted(dendrite_area_dict.keys())
                ],
                                                     dtype=np.float32)
                dendrite_length_per_layer = np.asarray([
                    dendrite_length_dict[k]
                    for k in sorted(dendrite_length_dict.keys())
                ],
                                                       dtype=np.float32)

                measures_dict[gid] = {
                    'dendrite_distance_hist_edges':
                    np.asarray(dendrite_distance_edges, dtype=np.float32),
                    'dendrite_distance_counts':
                    np.asarray(dendrite_distance_counts, dtype=np.int32),
                    'dendrite_mean_diam_hist':
                    np.asarray(dendrite_mean_diam_hist, dtype=np.float32),
                    'dendrite_area_per_layer':
                    dendrite_area_per_layer,
                    'dendrite_length_per_layer':
                    dendrite_length_per_layer
                }

                del cell
                count += 1
            else:
                logger.info('Rank %i gid is None' % rank)
        append_cell_attributes(output_path,
                               population,
                               measures_dict,
                               namespace='Tree Measurements',
                               comm=comm,
                               io_size=io_size,
                               chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size,
                               cache_size=cache_size)
    MPI.Finalize()
예제 #28
0
def main(config, config_prefix, max_section_length, population, forest_path,
         template_path, output_path, io_size, chunk_size, value_chunk_size,
         dry_run, verbose):
    """
    Resize the sections of every tree in ``population`` so that no section
    exceeds ``max_section_length``, and append the new trees to
    ``output_path`` unless ``dry_run`` is set.

    :param config: str (path to config file)
    :param config_prefix: str
    :param max_section_length: float (passed to cells.resize_tree_sections)
    :param population: str
    :param forest_path: str (path)
    :param template_path: template path(s) passed to Env
    :param output_path: str (path); created with the forest's H5Types if absent
    :param io_size: int (-1 means use all ranks)
    :param chunk_size: int
    :param value_chunk_size: int
    :param dry_run: bool (if True, the resized trees are not appended)
    :param verbose: bool
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=template_path)

    if rank == 0:
        if not os.path.isfile(output_path):
            # Seed the new output file with the forest's H5Types; the context
            # managers close both files even if the copy fails.
            with h5py.File(forest_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    # These lookups also verify that the population exists in both files.
    (forest_pop_ranges, _) = read_population_ranges(forest_path)
    (forest_population_start,
     forest_population_count) = forest_pop_ranges[population]

    (pop_ranges, _) = read_population_ranges(output_path)

    (population_start, population_count) = pop_ranges[population]

    new_trees_dict = {}
    for gid, tree_dict in NeuroH5TreeGen(forest_path,
                                         population,
                                         io_size=io_size,
                                         comm=comm,
                                         topology=False):
        if gid is not None:
            logger.info("Rank %d received gid %d" % (rank, gid))
            logger.info(pprint.pformat(tree_dict))
            new_tree_dict = cells.resize_tree_sections(tree_dict,
                                                       max_section_length)
            logger.info(pprint.pformat(new_tree_dict))
            new_trees_dict[gid] = new_tree_dict

    if not dry_run:
        append_cell_trees(output_path,
                          population,
                          new_trees_dict,
                          io_size=io_size,
                          comm=comm)

    comm.barrier()
    if (not dry_run) and (rank == 0):
        logger.info('Appended resized trees to %s' % output_path)
예제 #29
0
def main(arena_id, config, config_prefix, dataset_prefix, distances_namespace, spike_input_path, spike_input_namespace, spike_input_attr, input_features_namespaces, input_features_path, selection_path, output_path, io_size, trajectory_id, verbose):
    """
    Write an HDF5 cell-selection file for the gids listed in ``selection_path``.

    ``selection_path`` is a text file with one integer gid per line (blank
    lines are ignored). On rank 0, for every population in the connectivity
    file's population ranges, the requested gids are intersected with the
    gids that have soma distance attributes in ``distances_namespace``.
    The resulting per-population selection is broadcast to all ranks, stored
    in ``env.cell_selection`` and written to
    ``<results_path>/<modelName>_selection.h5``. The same file also receives
    the output of ``io_utils.write_connection_selection``, the input cell
    selection when ``spike_input_path`` is given, and, when
    ``input_features_path`` is given, the cell attributes of each
    ``"<namespace> <arena_id>"`` namespace in ``input_features_namespaces``
    for the selected input cells.

    :param io_size: number of I/O ranks; -1 means use all ranks.
    :raises ValueError: if a non-blank line of the selection file is not an integer.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank
    if io_size == -1:
        io_size = comm.size

    env = Env(comm=comm, config_file=config,
              config_prefix=config_prefix, dataset_prefix=dataset_prefix,
              results_path=output_path, spike_input_path=spike_input_path,
              spike_input_namespace=spike_input_namespace, spike_input_attr=spike_input_attr,
              arena_id=arena_id, trajectory_id=trajectory_id, io_size=io_size)

    # One gid per line; skip blank lines (e.g. a trailing newline) instead of
    # letting int('') abort the run.
    with open(selection_path, 'r') as f:
        selection = {int(line) for line in f if line.strip()}

    pop_ranges, _ = read_population_ranges(env.connectivity_file_path, comm=comm)

    selection_dict = defaultdict(set)

    # Rank 0 gets a unique color, so it reads the distance attributes on a
    # communicator containing only itself.
    comm0 = env.comm.Split(2 if rank == 0 else 0, 0)

    if rank == 0:
        for population in pop_ranges:
            distances = read_cell_attributes(env.data_file_path, population,
                                             namespace=distances_namespace, comm=comm0)
            soma_distances = {k: (v['U Distance'][0], v['V Distance'][0]) for (k, v) in distances}
            del distances

            # Populations without distance attributes get no selection entry.
            if not soma_distances:
                continue

            selection_dict[population] = {gid for gid in soma_distances if gid in selection}

    # The split communicator is not used past this point on any rank.
    comm0.Free()

    env.comm.barrier()

    write_selection_file_path = "%s/%s_selection.h5" % (env.results_path, env.modelName)

    if rank == 0:
        io_utils.mkout(env, write_selection_file_path)
    env.comm.barrier()
    selection_dict = env.comm.bcast(dict(selection_dict), root=0)
    env.cell_selection = selection_dict
    io_utils.write_cell_selection(env, write_selection_file_path)
    input_selection = io_utils.write_connection_selection(env, write_selection_file_path)
    if spike_input_path:
        io_utils.write_input_cell_selection(env, input_selection, write_selection_file_path)
    if input_features_path:
        for this_input_features_namespace in sorted(input_features_namespaces):
            for population in sorted(input_selection):
                logger.info(f"Extracting input features {this_input_features_namespace} for population {population}...")
                it = read_cell_attribute_selection(input_features_path, population,
                                                   namespace=f"{this_input_features_namespace} {arena_id}",
                                                   selection=input_selection[population], comm=env.comm)
                output_features_dict = {cell_gid: cell_features_dict for cell_gid, cell_features_dict in it}
                append_cell_attributes(write_selection_file_path, population, output_features_dict,
                                       namespace=f"{this_input_features_namespace} {arena_id}",
                                       io_size=io_size, comm=env.comm)
    env.comm.barrier()
예제 #30
0
def config_worker():
    """
    Prepare this worker's ``context`` for network-parameter optimization.

    Configures logging and assigns ``context.results_file_id`` if it is not
    yet set. When ``context.env`` is absent, it sets ``context.comm``, calls
    ``init_network()`` (logging and re-raising any failure) and sets
    ``context.bin_size``.

    It then flattens the netclamp optimization configuration of every
    population in ``context.target_populations`` into:

    - ``context.param_names``: dotted keys
      ``pop.source.sec_type.syn_name.param[.const]``, in sorted config order;
    - ``context.bounds``: the configured range of each name, same order;
    - ``context.x0``: dict of initial values, the midpoint of each range;
    - ``context.target_val`` / ``context.target_range``: dict keyed by
      ``"<pop_name> <target_name>"`` built from each population's ``'Targets'``;
    - ``context.from_param_vector`` / ``context.to_param_vector``:
      converters between a flat value vector and per-parameter tuples.

    :raises RuntimeError: if a target population has no optimization
        configuration.
    """
    utils.config_logging(context.verbose)
    context.logger = utils.get_script_logger(os.path.basename(__file__))
    if 'results_file_id' not in context():
        context.results_file_id = 'DG_optimize_network_subworlds_%s_%s' % \
                             (context.interface.worker_id, datetime.datetime.today().strftime('%Y%m%d_%H%M'))
    if 'env' not in context():
        try:
            context.comm = MPI.COMM_WORLD
            init_network()
        except Exception as err:
            context.logger.exception(err)
            raise
        context.bin_size = 5.0

    param_bounds = {}
    param_names = []
    param_initial_dict = {}
    param_range_tuples = []
    opt_targets = {}

    def add_param(pop_name, source, sec_type, syn_name, param_path, key_suffix, param_range):
        # Register one optimizable parameter under key
        # 'pop.source.sec_type.syn_name.<key_suffix>'. Its range is indexed as
        # [low, high]; the initial value is the midpoint, so it always lies
        # inside the bounds. The former (high - low) / 2 was the half-width
        # and fell outside the bounds whenever low > (high - low) / 2.
        param_range_tuples.append((pop_name, source, sec_type, syn_name, param_path, param_range))
        param_key = '%s.%s.%s.%s.%s' % (pop_name, source, sec_type, syn_name, key_suffix)
        param_initial_dict[param_key] = (param_range[0] + param_range[1]) / 2.0
        param_bounds[param_key] = param_range
        param_names.append(param_key)

    for pop_name in context.target_populations:
        optimize_parameters = context.env.netclamp_config.optimize_parameters
        if pop_name not in optimize_parameters:
            raise RuntimeError(
                "optimize_network_subworlds: population %s does not have optimization configuration" % pop_name)
        opt_params = optimize_parameters[pop_name]
        param_ranges = opt_params['Parameter ranges']

        for target_name, target_val in viewitems(opt_params['Targets']):
            opt_targets['%s %s' % (pop_name, target_name)] = target_val

        for source, source_dict in sorted(viewitems(param_ranges), key=lambda kv: kv[0]):
            for sec_type, sec_type_dict in sorted(viewitems(source_dict), key=lambda kv: kv[0]):
                for syn_name, syn_mech_dict in sorted(viewitems(sec_type_dict), key=lambda kv: kv[0]):
                    for param_fst, param_rst in sorted(viewitems(syn_mech_dict), key=lambda kv: kv[0]):
                        if isinstance(param_rst, dict):
                            # Nested entry: one parameter per named constant of param_fst.
                            for const_name, const_range in sorted(viewitems(param_rst)):
                                add_param(pop_name, source, sec_type, syn_name,
                                          (param_fst, const_name),
                                          '%s.%s' % (param_fst, const_name),
                                          const_range)
                        else:
                            add_param(pop_name, source, sec_type, syn_name,
                                      param_fst, param_fst, param_rst)

    def from_param_vector(params):
        """Pair each value in ``params`` with its parameter description.

        Returns a list of ``(pop_name, source, sec_type, syn_name, param_path,
        value)`` tuples, in ``context.param_names`` order.
        """
        assert (len(params) == len(param_range_tuples))
        return [(pop_name, source, sec_type, syn_name, param_name, params[i])
                for i, (pop_name, source, sec_type, syn_name, param_name, _)
                in enumerate(param_range_tuples)]

    def to_param_vector(params):
        """Return the value (last element) of each parameter tuple, in order.

        Accepts the 6-tuples produced by ``from_param_vector``. 5-tuples
        without the population name are also accepted. The previous fixed
        5-way unpacking made a from/to round-trip raise ValueError.
        """
        return [param_tuple[-1] for param_tuple in params]

    context.param_names = param_names
    context.bounds = [param_bounds[key] for key in param_names]
    context.x0 = param_initial_dict
    context.from_param_vector = from_param_vector
    context.to_param_vector = to_param_vector
    context.target_val = opt_targets
    context.target_range = opt_targets