Ejemplo n.º 1
0
def main(config, config_prefix, population, gid, ref_axis, input_file, template_name, output_file, dry_run, verbose):
    """Instantiate a HOC cell template and append its tree to a forest file.

    Loads ``input_file`` into NEURON, instantiates ``template_name`` with
    arguments ``(0, 0)``, converts the cell with ``export_swc_dict`` and
    appends the result to ``output_file`` under ``population`` unless
    ``dry_run`` is set.

    :param config: model configuration file passed to ``Env``.
    :param config_prefix: directory prefix used to resolve ``config``.
    :param population: population name the tree is stored under.
    :param gid: requested gid. If it is ``None`` or lies outside the
        population's half-open gid range ``[start, start + count)``, the
        population's first gid is used instead.
    :param ref_axis: reference axis forwarded to ``export_swc_dict``.
    :param input_file: HOC file defining ``template_name``.
    :param template_name: name of the HOC template to instantiate.
    :param output_file: forest HDF5 file; H5 types are created if it is missing.
    :param dry_run: if true, only log the tree instead of writing it.
    :param verbose: enable verbose logging and print the NEURON topology.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    h.load_file("nrngui.hoc")
    h.load_file("import3d.hoc")

    env = Env(config_file=config, config_prefix=config_prefix)

    if not os.path.isfile(output_file):
        io_utils.make_h5types(env, output_file)

    (forest_pop_ranges, _) = read_population_ranges(output_file)
    (forest_population_start, forest_population_count) = forest_pop_ranges[population]
    # Exclusive upper bound: valid gids are start .. start + count - 1.
    forest_population_end = forest_population_start + forest_population_count

    h.load_file(input_file)
    cell = getattr(h, template_name)(0, 0)
    if verbose:
        h.topology()
    tree_dict = export_swc_dict(cell, ref_axis=ref_axis)

    # ``forest_population_end`` is exclusive, so a gid equal to it is out of range.
    if (gid is None) or (gid < forest_population_start) or (gid >= forest_population_end):
        gid = forest_population_start
    trees_dict = {gid: tree_dict}

    logger.info(pprint.pformat(trees_dict))

    if not dry_run:
        append_cell_trees(output_file, population, trees_dict)
Ejemplo n.º 2
0
def main(spike_events_path, spike_events_namespace, populations,
         include_artificial, bin_size, smooth, t_variable, t_max, t_min,
         quantity, font_size, graph_type, overlay, save_format, progress,
         verbose):
    """Plot a histogram of spike events and save the figure.

    :param spike_events_path: file holding the spike events.
    :param spike_events_namespace: namespace of the spike events in that file.
    :param populations: populations to include; an empty selection becomes
        ``['eachPop']``.
    :param t_max: end of the time window; ``None`` means no time window.
    :param t_min: start of the time window; ``0.0`` when unset and
        ``t_max`` is given.

    All other options are forwarded unchanged to ``plot.plot_spike_histogram``
    (which is always asked for population rates and to save the figure).
    """
    utils.config_logging(verbose)

    if t_max is not None:
        time_range = [0.0 if t_min is None else t_min, t_max]
    else:
        time_range = None

    histogram_options = {
        'include': populations or ['eachPop'],
        'time_variable': t_variable,
        'time_range': time_range,
        'pop_rates': True,
        'bin_size': bin_size,
        'smooth': smooth,
        'quantity': quantity,
        'fontSize': font_size,
        'overlay': overlay,
        'graph_type': graph_type,
        'progress': progress,
        'include_artificial': include_artificial,
        'saveFig': True,
        'figFormat': save_format,
    }
    plot.plot_spike_histogram(spike_events_path, spike_events_namespace,
                              **histogram_options)
Ejemplo n.º 3
0
def main(spike_events_path, spike_events_namespace, populations, max_spikes,
         spike_hist_bin, t_variable, t_max, t_min, font_size, fig_size, labels,
         save_format, include_artificial, verbose):
    """Plot a spike raster (with a spike-histogram subplot) and save the figure.

    :param populations: populations to include; an empty selection becomes
        ``['eachPop']``.
    :param t_max: end of the time window; ``None`` means no time window.
    :param t_min: start of the time window; ``0.0`` when unset and
        ``t_max`` is given.

    The remaining options are forwarded unchanged to ``plot.plot_spike_raster``.
    """
    utils.config_logging(verbose)

    time_range = None
    if t_max is not None:
        window_start = t_min if t_min is not None else 0.0
        time_range = [window_start, t_max]

    raster_options = dict(
        include=populations if populations else ['eachPop'],
        time_range=time_range,
        time_variable=t_variable,
        pop_rates=True,
        spike_hist='subplot',
        max_spikes=max_spikes,
        spike_hist_bin=spike_hist_bin,
        include_artificial=include_artificial,
        fontSize=font_size,
        figSize=fig_size,
        labels=labels,
        saveFig=True,
        figFormat=save_format,
    )
    plot.plot_spike_raster(spike_events_path, spike_events_namespace,
                           **raster_options)
Ejemplo n.º 4
0
def main(config, config_prefix, verbose):
    """Configure logging, take the MPI world communicator and build ``Env``.

    NOTE(review): in this excerpt ``logger``, ``rank``, ``size`` and ``env``
    are assigned but not used afterwards.
    """
    config_logging(verbose)
    logger = get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    env = Env(config_file=config, config_prefix=config_prefix, comm=comm)
Ejemplo n.º 5
0
def main(celltype_path, input_path, output_path, output_namespace, output_npy,
         verbose, progress):
    """Import a spike raster via ``io_utils.import_spikeraster``.

    The three paths are passed positionally. ``output_namespace`` is passed as
    ``namespace``, and ``output_npy``/``progress`` are forwarded unchanged.
    """
    utils.config_logging(verbose)

    paths = (celltype_path, input_path, output_path)
    options = {
        'namespace': output_namespace,
        'output_npy': output_npy,
        'progress': progress,
    }
    io_utils.import_spikeraster(*paths, **options)
Ejemplo n.º 6
0
def main(config_path, input_path, t_max, t_min, window_size, overlap,
         frequency_range, dt, font_size, verbose):
    """Plot an LFP spectrogram and save the figure.

    :param t_max: end of the time window; ``None`` means no time window.
    :param t_min: start of the time window; ``0.0`` when unset and
        ``t_max`` is given.

    The remaining options are forwarded to ``plot.plot_lfp_spectrogram``.
    """
    utils.config_logging(verbose)

    time_range = (None if t_max is None
                  else [0.0 if t_min is None else t_min, t_max])

    spectrogram_options = {
        'config_path': config_path,
        'time_range': time_range,
        'window_size': window_size,
        'overlap': overlap,
        'frequency_range': frequency_range,
        'fontSize': font_size,
        'dt': dt,
        'saveFig': True,
    }
    plot.plot_lfp_spectrogram(input_path, **spectrogram_options)
Ejemplo n.º 7
0
def main(config_path, input_path, t_max, t_min, psd, window_size, overlap,
         frequency_range, bandpass_filter, dt, font_size, verbose):
    """Plot an LFP trace (optionally with its PSD) and save the figure.

    :param config_path: configuration file forwarded to ``plot.plot_lfp``.
    :param input_path: file holding the LFP data.
    :param t_max: end of the time window; ``None`` means no time window.
    :param t_min: start of the time window; ``0.0`` when unset and
        ``t_max`` is given.
    :param psd: forwarded as ``compute_psd``.
    :param bandpass_filter: filter bounds forwarded to ``plot.plot_lfp``.
        Filtering is disabled (``None`` is passed) when the value is ``None``,
        empty, or its first element is ``None``.
    :param verbose: enable verbose logging.
    """
    utils.config_logging(verbose)

    if t_max is None:
        time_range = None
    elif t_min is None:
        time_range = [0.0, t_max]
    else:
        time_range = [t_min, t_max]

    # NOTE(review): presumably a (low, high) pair from the CLI. Guard the
    # None/empty case, which previously raised on ``bandpass_filter[0]``.
    if (bandpass_filter is None or len(bandpass_filter) == 0
            or bandpass_filter[0] is None):
        bandpass_filter = None

    plot.plot_lfp(input_path, config_path=config_path, time_range=time_range,
                  compute_psd=psd, window_size=window_size,
                  overlap=overlap, frequency_range=frequency_range,
                  bandpass_filter=bandpass_filter, dt=dt,
                  fontSize=font_size, saveFig=True)
Ejemplo n.º 8
0
def main(config, coords_path, coords_namespace, populations, subpopulation, scale, subvol, mayavi, verbose):
    """Plot the cell coordinates of ``populations`` inside the model volume.

    Delegates to ``plot.plot_coords_in_volume``; every option is forwarded
    unchanged.
    """
    utils.config_logging(verbose)

    volume_plot_options = dict(subpopulation=subpopulation,
                               subvol=subvol,
                               scale=scale,
                               verbose=verbose,
                               mayavi=mayavi)
    plot.plot_coords_in_volume(populations, coords_path, coords_namespace,
                               config, **volume_plot_options)
Ejemplo n.º 9
0
def main(config, config_prefix, types_path, geometry_path, output_path,
         output_namespace, populations, resolution, alpha_radius, nodeiter,
         dispersion_delta, snap_delta, io_size, chunk_size, value_chunk_size,
         verbose):
    """Generate soma coordinates for CA1 populations and append them to ``output_path``.

    Rank 0 performs the generation phase:
      * builds or loads (from ``geometry_path``) alpha shapes for the whole
        volume and for every layer;
      * places nodes per layer with ``gen_min_energy_nodes``;
      * drops near-coincident nodes and disperses the rest.

    The nodes are then broadcast and mapped to UVL and back. All ranks filter
    them with ``uvl_in_bounds`` in round-robin fashion.

    Rank 0 finalizes each population:
      * tops it up with random positions when too few remain;
      * subsamples to the population size and sorts on the U coordinate;
      * writes the coordinates.

    :param config: model configuration file passed to ``Env``.
    :param config_prefix: directory prefix used to resolve ``config``.
    :param types_path: HDF5 file whose ``/H5Types`` group seeds a new output file.
    :param geometry_path: optional HDF5 file used to cache alpha shapes.
    :param output_path: HDF5 file receiving the coordinates.
    :param output_namespace: attribute namespace for the coordinates.
    :param populations: populations to place; empty means every population
        found in ``output_path``.
    :param resolution: volume resolution; must be a 3-tuple because it is
        also formatted into the alpha-shape cache path.
    :param alpha_radius: alpha radius for ``make_alpha_shape``.
    :param nodeiter: forwarded to ``gen_min_energy_nodes``.
    :param dispersion_delta: forwarded to ``gen_min_energy_nodes`` and ``disperse``.
    :param snap_delta: forwarded to ``gen_min_energy_nodes``.
    :param io_size: number of I/O ranks; ``-1`` means all ranks.
    :param chunk_size: forwarded to ``append_cell_attributes``.
    :param value_chunk_size: forwarded to ``append_cell_attributes``.
    :param verbose: enable verbose logging (including the ``rbf.pde.nodes`` logger).
    :raises KeyError: if a requested population is not in ``output_path``.
    :raises RuntimeError: if a population size disagrees with its layer distribution.
    """
    config_logging(verbose)
    logger = get_script_logger(script_name)

    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size

    np.seterr(all='raise')

    if io_size == -1:
        io_size = comm.size
    if rank == 0:
        logger.info('%i ranks have been allocated' % comm.size)

    if rank == 0 and not os.path.isfile(output_path):
        # Seed a new output file with the HDF5 type definitions.
        with h5py.File(types_path, 'r') as input_file:
            with h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
    comm.barrier()

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix)

    random_seed = int(env.model_config['Random Seeds']['Soma Locations'])
    random.seed(random_seed)

    layer_extents = env.geometry['Parametric Surface']['Layer Extents']
    rotate = env.geometry['Parametric Surface']['Rotation']

    (extent_u, extent_v, extent_l) = get_total_extents(layer_extents)
    vol = make_CA1_volume(extent_u,
                          extent_v,
                          extent_l,
                          rotate=rotate,
                          resolution=resolution)
    layer_alpha_shape_path = 'Layer Alpha Shape/%d/%d/%d' % resolution

    vol_domain = None
    if rank == 0:
        logger.info("Constructing alpha shape for volume: extents: %s..." %
                    str((extent_u, extent_v, extent_l)))
        vol_alpha_shape_path = '%s/all' % (layer_alpha_shape_path)
        vol_alpha_shape = None
        if geometry_path:
            vol_alpha_shape = load_alpha_shape(geometry_path,
                                               vol_alpha_shape_path)
        # Build the shape when there is no cache or the cache has no entry
        # (load_alpha_shape can return None, as the per-layer code below also
        # expects), and store it in the cache when one is configured.
        if vol_alpha_shape is None:
            vol_alpha_shape = make_alpha_shape(vol, alpha_radius=alpha_radius)
            if geometry_path:
                save_alpha_shape(geometry_path, vol_alpha_shape_path,
                                 vol_alpha_shape)
        vert = vol_alpha_shape.points
        smp = np.asarray(vol_alpha_shape.bounds, dtype=np.int64)
        vol_domain = (vert, smp)

    layer_alpha_shapes = {}
    layer_extent_transformed_vals = {}
    if rank == 0:
        for layer, extents in viewitems(layer_extents):
            (extent_u, extent_v,
             extent_l) = get_layer_extents(layer_extents, layer)
            layer_extent_transformed_vals[layer] = CA1_volume_transform(
                extent_u, extent_v, extent_l)
            this_layer_alpha_shape_path = '%s/%s' % (layer_alpha_shape_path,
                                                     layer)
            this_layer_alpha_shape = None
            if geometry_path:
                this_layer_alpha_shape = load_alpha_shape(
                    geometry_path, this_layer_alpha_shape_path)
            if this_layer_alpha_shape is None:
                logger.info(
                    "Constructing alpha shape for layers %s: extents: %s..." %
                    (layer, str(extents)))
                layer_vol = make_CA1_volume(extent_u,
                                            extent_v,
                                            extent_l,
                                            rotate=rotate,
                                            resolution=resolution)
                this_layer_alpha_shape = make_alpha_shape(
                    layer_vol, alpha_radius=alpha_radius)
                if geometry_path:
                    save_alpha_shape(geometry_path,
                                     this_layer_alpha_shape_path,
                                     this_layer_alpha_shape)
            layer_alpha_shapes[layer] = this_layer_alpha_shape

    comm.barrier()
    population_ranges = read_population_ranges(output_path, comm)[0]
    if not populations:
        populations = sorted(population_ranges.keys())

    # Validate on every rank, so an unknown name fails everywhere instead of
    # only on rank 0 (which would leave the other ranks waiting in a collective).
    unknown_populations = [p for p in populations if p not in population_ranges]
    if unknown_populations:
        raise KeyError('unknown population(s): %s' %
                       ', '.join(map(str, unknown_populations)))

    all_xyz_coords1 = None
    generated_coords_count_dict = defaultdict(int)
    if rank == 0:
        if verbose:
            # getLogger also works if rbf has not registered this logger yet.
            logging.getLogger('rbf.pde.nodes').setLevel(logging.DEBUG)

        # Node arrays for all populations, in population order, then layer order.
        all_xyz_coords_lst = []
        for population in populations:
            gc.collect()

            (_, population_count) = population_ranges[population]

            pop_layers = env.geometry['Cell Distribution'][population]
            pop_constraint = None
            if 'Cell Constraints' in env.geometry:
                if population in env.geometry['Cell Constraints']:
                    pop_constraint = env.geometry['Cell Constraints'][
                        population]
            logger.info("Population %s: layer distribution is %s" %
                        (population, str(pop_layers)))

            pop_layer_count = sum(count for _, count in viewitems(pop_layers))
            if population_count != pop_layer_count:
                raise RuntimeError(
                    'Population %s: size %i does not match layer distribution total %i'
                    % (population, population_count, pop_layer_count))

            for layer, count in viewitems(pop_layers):
                if count <= 0:
                    continue

                alpha = layer_alpha_shapes[layer]

                vert = alpha.points
                smp = np.asarray(alpha.bounds, dtype=np.int64)

                # Clamp the alpha-shape vertices in place to the layer's
                # transformed per-dimension (min, max) extents.
                extents_xyz = layer_extent_transformed_vals[layer]
                for (vvi, vv) in enumerate(vert):
                    for (vi, v) in enumerate(vv):
                        if v < extents_xyz[vi][0]:
                            vert[vvi][vi] = extents_xyz[vi][0]
                        elif v > extents_xyz[vi][1]:
                            vert[vvi][vi] = extents_xyz[vi][1]

                N = int(count * 2)  # layer-specific number of nodes
                logger.info(
                    "Generating %i nodes in layer %s for population %s..." %
                    (N, layer, population))

                min_energy_constraint = None
                if pop_constraint is not None and layer in pop_constraint:
                    min_energy_constraint = pop_constraint[layer]

                nodes = gen_min_energy_nodes(count, (vert, smp),
                                             min_energy_constraint, nodeiter,
                                             dispersion_delta, snap_delta)

                layer_xyz_coords = nodes.reshape(-1, 3)
                all_xyz_coords_lst.append(layer_xyz_coords)
                generated_coords_count_dict[population] += len(layer_xyz_coords)

        # Additional dispersion step to ensure no overlapping cell positions:
        # repeatedly drop the nearest neighbour of every node that has another
        # node within ~1e-3, until no such pair remains.
        all_xyz_coords = np.vstack(all_xyz_coords_lst)
        mask = np.ones((all_xyz_coords.shape[0], ), dtype=bool)
        while True:
            active = np.flatnonzero(mask)
            active_coords = all_xyz_coords[active, :]
            kdt = cKDTree(active_coords)
            nndist, nnindices = kdt.query(active_coords, k=2)
            # Column 0 is normally the query point itself; column 1 is the
            # nearest other node.
            close = np.isclose(nndist[:, 1], 0.0, atol=1e-3, rtol=1e-3)
            if not np.any(close):
                break
            mask[active[nnindices[close, 1]]] = False

        # Recount the surviving nodes per population (they are contiguous
        # per population in all_xyz_coords).
        coords_offset = 0
        for population in populations:
            pop_coords_count = generated_coords_count_dict[population]
            pop_mask = mask[coords_offset:coords_offset + pop_coords_count]
            generated_coords_count_dict[population] = np.count_nonzero(
                pop_mask)
            coords_offset += pop_coords_count

        logger.info("Dispersion of %i nodes..." % np.count_nonzero(mask))
        all_xyz_coords1 = disperse(all_xyz_coords[mask, :],
                                   vol_domain,
                                   delta=dispersion_delta)

    all_xyz_coords_interp = None
    all_uvl_coords_interp = None
    if rank == 0:
        logger.info("Computing UVL coordinates of %i nodes..." %
                    len(all_xyz_coords1))
        # Map to UVL and back to XYZ; the round-trip difference is reported
        # below as the mean XYZ error.
        all_uvl_coords_interp = vol.inverse(all_xyz_coords1)
        all_xyz_coords_interp = vol(all_uvl_coords_interp[:, 0],
                                    all_uvl_coords_interp[:, 1],
                                    all_uvl_coords_interp[:, 2],
                                    mesh=False).reshape(3, -1).T
        logger.info("Broadcasting generated nodes...")

    dispersed_xyz_coords = comm.bcast(all_xyz_coords1, root=0)
    all_xyz_coords_interp = comm.bcast(all_xyz_coords_interp, root=0)
    all_uvl_coords_interp = comm.bcast(all_uvl_coords_interp, root=0)
    generated_coords_count_dict = comm.bcast(dict(generated_coords_count_dict),
                                             root=0)

    coords_offset = 0
    pop_coords_dict = {}
    for population in populations:
        xyz_error = np.zeros((3, ))

        pop_layers = env.geometry['Cell Distribution'][population]
        gen_coords_count = generated_coords_count_dict[population]

        coords = []
        # Round-robin: rank r handles local indices r, r + size, r + 2 * size, ...
        for i in range(rank, gen_coords_count, size):
            coord_ind = coords_offset + i

            uvl_coords = all_uvl_coords_interp[coord_ind, :].ravel()
            xyz_coords1 = all_xyz_coords_interp[coord_ind, :].ravel()
            if uvl_in_bounds(all_uvl_coords_interp[coord_ind, :],
                             layer_extents, pop_layers):
                xyz_error = np.add(
                    xyz_error,
                    np.abs(
                        np.subtract(dispersed_xyz_coords[coord_ind, :],
                                    xyz_coords1)))

                logger.info('Rank %i: %s cell %i: %f %f %f' %
                            (rank, population, i, uvl_coords[0],
                             uvl_coords[1], uvl_coords[2]))

                coords.append(
                    (xyz_coords1[0], xyz_coords1[1], xyz_coords1[2],
                     uvl_coords[0], uvl_coords[1], uvl_coords[2]))
            else:
                logger.debug(
                    'Rank %i: %s cell %i not in bounds: %f %f %f' %
                    (rank, population, i, uvl_coords[0], uvl_coords[1],
                     uvl_coords[2]))

        total_xyz_error = np.zeros((3, ))
        comm.Allreduce(xyz_error, total_xyz_error, op=MPI.SUM)

        coords_count = np.sum(np.asarray(comm.allgather(len(coords))))

        # With np.seterr(all='raise') a division by a zero count raises
        # FloatingPointError, so only average when some coordinates survived.
        if coords_count > 0:
            mean_xyz_error = total_xyz_error / coords_count
        else:
            mean_xyz_error = np.zeros((3, ))

        pop_coords_dict[population] = coords
        coords_offset += gen_coords_count

        if rank == 0:
            logger.info(
                'Total %i coordinates generated for population %s: mean XYZ error: %f %f %f'
                % (coords_count, population, mean_xyz_error[0],
                   mean_xyz_error[1], mean_xyz_error[2]))

    ## comm0 includes only rank 0
    color = 1 if rank == 0 else 0
    comm0 = comm.Split(color, 0)

    for population in populations:

        pop_start, pop_count = population_ranges[population]
        pop_layers = env.geometry['Cell Distribution'][population]
        pop_constraint = None
        if 'Cell Constraints' in env.geometry:
            if population in env.geometry['Cell Constraints']:
                pop_constraint = env.geometry['Cell Constraints'][population]

        coords_lst = comm.gather(pop_coords_dict[population], root=0)
        if rank == 0:
            all_coords = [item for sublist in coords_lst for item in sublist]
            coords_count = len(all_coords)

            if coords_count < pop_count:
                delta = pop_count - coords_count
                logger.warning(
                    "Generating additional %i coordinates for population %s..."
                    % (delta, population))

                # Keep random positions slightly inside the layer boundaries.
                safety = 0.01
                for _ in range(delta):
                    for layer, count in viewitems(pop_layers):
                        if count <= 0:
                            continue
                        min_extent = layer_extents[layer][0]
                        max_extent = layer_extents[layer][1]
                        coord_u = np.random.uniform(
                            min_extent[0] + safety, max_extent[0] - safety)
                        coord_v = np.random.uniform(
                            min_extent[1] + safety, max_extent[1] - safety)
                        # Same rule as node generation: a constraint applies
                        # only to layers it lists.
                        if pop_constraint is None or layer not in pop_constraint:
                            coord_l = np.random.uniform(
                                min_extent[2] + safety,
                                max_extent[2] - safety)
                        else:
                            coord_l = np.random.uniform(
                                pop_constraint[layer][0] + safety,
                                pop_constraint[layer][1] - safety)
                        xyz_coords = CA1_volume(coord_u,
                                                coord_v,
                                                coord_l,
                                                rotate=rotate).ravel()
                        all_coords.append(
                            (xyz_coords[0], xyz_coords[1], xyz_coords[2],
                             coord_u, coord_v, coord_l))

            sampled_coords = random_subset(all_coords, int(pop_count))
            sampled_coords.sort(
                key=lambda coord: coord[3])  ## sort on U coordinate

            coords_dict = {
                pop_start + i: {
                    'X Coordinate': np.asarray([x_coord], dtype=np.float32),
                    'Y Coordinate': np.asarray([y_coord], dtype=np.float32),
                    'Z Coordinate': np.asarray([z_coord], dtype=np.float32),
                    'U Coordinate': np.asarray([u_coord], dtype=np.float32),
                    'V Coordinate': np.asarray([v_coord], dtype=np.float32),
                    'L Coordinate': np.asarray([l_coord], dtype=np.float32)
                }
                for (i, (x_coord, y_coord, z_coord, u_coord, v_coord,
                         l_coord)) in enumerate(sampled_coords)
            }

            append_cell_attributes(output_path,
                                   population,
                                   coords_dict,
                                   namespace=output_namespace,
                                   io_size=io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size,
                                   comm=comm0)

        comm.barrier()

    comm0.Free()
Ejemplo n.º 10
0
def main(config_file, population, gid, template_paths, dataset_prefix,
         config_prefix, data_file, load_synapses, syn_types, syn_sources,
         syn_source_threshold, font_size, bgcolor, colormap, plot_method,
         verbose):
    """Build one biophysical cell and plot its morphology tree.

    :param config_file: model configuration file (forwarded to ``Env``).
    :param population: population of the cell; also used to look up an
        optional ``mech_file_path`` in ``env.celltypes``.
    :param gid: gid of the cell to build.
    :param template_paths: forwarded to ``Env``.
    :param dataset_prefix: forwarded to ``Env``; required unless ``data_file``
        is given.
    :param config_prefix: forwarded to ``Env``.
    :param data_file: used as ``env.data_file_path`` when ``Env`` did not set one.
    :param load_synapses: load synapses and initialise their mechanisms.
    :param syn_types: synapse types to show; empty means no filter (``None``).
    :param syn_sources: synapse sources to show; empty means no filter (``None``).
    :param syn_source_threshold: forwarded to ``plot.plot_biophys_cell_tree``.
    :param font_size: not used directly; only reaches ``Env`` via ``locals()``.
    :param bgcolor: forwarded to ``plot.plot_biophys_cell_tree``.
    :param colormap: forwarded to ``plot.plot_biophys_cell_tree``.
    :param plot_method: forwarded to ``plot.plot_biophys_cell_tree``.
    :param verbose: enable verbose logging.
    :raises RuntimeError: if neither ``dataset_prefix`` nor ``data_file`` is given.
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(script_name)

    if dataset_prefix is None and data_file is None:
        raise RuntimeError(
            'Either --dataset-prefix or --data-file must be provided.')

    # Every local defined so far (all arguments plus ``logger``) is passed to
    # Env as a keyword argument. NOTE(review): this relies on Env accepting
    # these names (presumably via **kwargs); any new local added above this
    # line is also passed to Env.
    params = dict(locals())
    env = Env(**params)
    configure_hoc_env(env)

    # Fall back to the explicit data file when Env did not resolve one.
    if env.data_file_path is None:
        env.data_file_path = data_file
        env.load_celltypes()

    ## Determine if a mechanism configuration file exists for this cell type
    if 'mech_file_path' in env.celltypes[population]:
        mech_file_path = env.celltypes[population]['mech_file_path']
    else:
        mech_file_path = None

    logger.info('loading cell %i' % gid)

    # Only the morphology (and optionally synapses) is needed for plotting.
    load_weights = False
    load_edges = False
    biophys_cell = make_biophys_cell(env,
                                     population,
                                     gid,
                                     load_synapses=load_synapses,
                                     load_weights=load_weights,
                                     load_edges=load_edges,
                                     mech_file_path=mech_file_path)
    cells.init_biophysics(biophys_cell,
                          reset_cable=True,
                          correct_cm=False,
                          correct_g_pas=False,
                          env=env)
    if load_synapses:
        init_syn_mech_attrs(biophys_cell, env, update_targets=True)

    cells.report_topology(env, biophys_cell)

    # Empty selections become None, i.e. "no filter" for the plot.
    if len(syn_types) == 0:
        syn_types = None
    else:
        syn_types = list(syn_types)
    if len(syn_sources) == 0:
        syn_sources = None
    else:
        syn_sources = list(syn_sources)

    plot.plot_biophys_cell_tree(env,
                                biophys_cell,
                                saveFig=True,
                                syn_source_threshold=syn_source_threshold,
                                synapse_filters={
                                    'syn_types': syn_types,
                                    'sources': syn_sources
                                },
                                bgcolor=bgcolor,
                                colormap=colormap,
                                plot_method=plot_method)
Ejemplo n.º 11
0
def main(config, coords_path, coords_namespace, geometry_path, populations, interp_chunk_size, resolution, alpha_radius, nsample, io_size, chunk_size, value_chunk_size, cache_size, verbose):
    """Compute U/V arc distances of cell somata and store them in ``coords_path``.

    Soma UVL coordinates are read (and broadcast) from ``coords_namespace``.
    The distance interpolant is loaded from ``geometry_path`` when cached
    there. Otherwise it is built with ``make_distance_interpolant`` and cached
    if ``geometry_path`` is given. Distances are appended to the
    ``'Arc Distances'`` namespace, and rank 0 records the reference U/V ranges
    as attributes of that namespace.

    :param config: model configuration file passed to ``Env``.
    :param coords_path: HDF5 file with soma coordinates; also the output file.
    :param coords_namespace: namespace holding the U/V/L coordinates.
    :param geometry_path: optional HDF5 file caching the distance interpolant.
    :param populations: populations whose coordinates are read.
    :param interp_chunk_size: not used by this function.
    :param resolution: 3-tuple resolution; also part of the cache path.
    :param alpha_radius: not used by this function.
    :param nsample: forwarded to ``make_distance_interpolant``.
    :param io_size: forwarded to ``append_cell_attributes``.
    :param chunk_size: forwarded to ``append_cell_attributes``.
    :param value_chunk_size: forwarded to ``append_cell_attributes``.
    :param cache_size: forwarded to ``append_cell_attributes``.
    :param verbose: enable verbose logging.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(__file__)

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config)
    output_path = coords_path

    soma_coords = {}

    if rank == 0:
        logger.info('Reading population coordinates...')

    for population in sorted(populations):
        coords = bcast_cell_attributes(coords_path, population, 0,
                                       namespace=coords_namespace, comm=comm)

        soma_coords[population] = {k: (v['U Coordinate'][0], v['V Coordinate'][0], v['L Coordinate'][0])
                                   for (k, v) in coords}
        del coords
        gc.collect()

    ip_dist_path = 'Distance Interpolant/%d/%d/%d' % resolution
    pkl_path = f'{ip_dist_path}/ip_dist.pkl'

    ip_dist = None
    if rank == 0 and geometry_path is not None:
        with h5py.File(geometry_path, "a") as f:
            if pkl_path in f:
                # SECURITY NOTE: pickle.loads can execute arbitrary code; the
                # geometry file must come from a trusted source.
                (origin_ranges, ip_dist_u, ip_dist_v) = pickle.loads(base64.b64decode(f[pkl_path][()]))
                ip_dist = (origin_ranges, ip_dist_u, ip_dist_v)
    # Only rank 0 reads the cache. Broadcast the interpolant itself (not just
    # a flag) so that every rank passes a usable interpolant to
    # measure_distances.
    ip_dist = env.comm.bcast(ip_dist, root=0)

    if ip_dist is None:
        if rank == 0:
            logger.info('Creating distance interpolant...')
        (origin_ranges, ip_dist_u, ip_dist_v) = make_distance_interpolant(env.comm, geometry_config=env.geometry,
                                                                          make_volume=make_CA1_volume,
                                                                          resolution=resolution, nsample=nsample)
        ip_dist = (origin_ranges, ip_dist_u, ip_dist_v)
        if rank == 0 and geometry_path is not None:
            with h5py.File(geometry_path, 'a') as f:
                f[pkl_path] = base64.b64encode(pickle.dumps(ip_dist))

    origin_ranges = ip_dist[0]
    if rank == 0:
        logger.info('Measuring soma distances...')

    soma_distances = measure_distances(env.comm, env.geometry, soma_coords, ip_dist, resolution=resolution)

    for population in sorted(soma_distances.keys()):

        if rank == 0:
            logger.info(f'Writing distances for population {population}...')

        dist_dict = soma_distances[population]
        attr_dict = {k: {'U Distance': np.asarray([v[0]], dtype=np.float32),
                         'V Distance': np.asarray([v[1]], dtype=np.float32)}
                     for k, v in viewitems(dist_dict)}
        append_cell_attributes(output_path, population, attr_dict,
                               namespace='Arc Distances', comm=comm,
                               io_size=io_size, chunk_size=chunk_size,
                               value_chunk_size=value_chunk_size, cache_size=cache_size)
        if rank == 0:
            with h5py.File(output_path, 'a') as f:
                attrs = f['Populations'][population]['Arc Distances'].attrs
                attrs['Reference U Min'] = origin_ranges[0][0]
                attrs['Reference U Max'] = origin_ranges[0][1]
                attrs['Reference V Min'] = origin_ranges[1][0]
                attrs['Reference V Max'] = origin_ranges[1][1]
        # Keep the other ranks from reopening output_path for the next
        # (presumably collective) append while rank 0 still has it open.
        comm.barrier()

    comm.Barrier()
Ejemplo n.º 12
0
def main(config, config_prefix, include, forest_path, connectivity_path,
         connectivity_namespace, coords_path, coords_namespace,
         synapses_namespace, distances_namespace, resolution,
         interp_chunk_size, io_size, chunk_size, value_chunk_size, cache_size,
         write_size, verbose, dry_run, debug, nsample=None):
    """Generate distance-dependent connectivity for destination populations.

    Rank 0 reads soma UVL coordinates and, if present, U/V arc distances from
    ``coords_path``; both are broadcast to all ranks. If no arc distances are
    found, they are computed with ``make_distance_interpolant`` and
    ``measure_distances``. For each destination population a
    ``ConnectionProb`` is built and ``generate_uv_distance_connections`` is
    called. The function ends with ``MPI.Finalize()``.

    :param config: model configuration file passed to ``Env``.
    :param config_prefix: directory prefix used to resolve ``config``.
    :param include: destination populations to process; empty/``None`` means
        every population in ``forest_path``. Names not in the forest are ignored.
    :param forest_path: forest HDF5 file (source of destination populations).
    :param connectivity_path: output file; created with ``/H5Types`` copied
        from ``coords_path`` when missing (and not ``dry_run``).
    :param coords_path: HDF5 file with soma coordinates and arc distances.
    :param coords_namespace: namespace of the U/V/L coordinates.
    :param distances_namespace: namespace of the U/V distances.
    :param resolution: forwarded to the distance interpolation helpers.
    :param interp_chunk_size: not used by this function.
    :param nsample: optional value forwarded as ``nsample`` to
        ``make_distance_interpolant``; when ``None`` the argument is omitted.

    The remaining arguments are forwarded to ``generate_uv_distance_connections``.
    """
    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    comm = MPI.COMM_WORLD
    rank = comm.rank

    env = Env(comm=comm, config_file=config, config_prefix=config_prefix)
    configure_hoc_env(env)

    connection_config = env.connection_config

    if (not dry_run) and (rank == 0):
        if not os.path.isfile(connectivity_path):
            # Seed a new connectivity file with the HDF5 type definitions.
            with h5py.File(coords_path, 'r') as input_file:
                with h5py.File(connectivity_path, 'w') as output_file:
                    input_file.copy('/H5Types', output_file)
    comm.barrier()

    population_ranges = read_population_ranges(coords_path)[0]
    populations = sorted(population_ranges.keys())

    # comm0 contains only rank 0, which reads the attribute files on its own.
    color = 1 if rank == 0 else 0
    comm0 = comm.Split(color, 0)

    soma_distances = {}
    soma_coords = {}
    if rank == 0:
        for population in populations:
            logger.info(f'Reading {population} coordinates...')
            coords_iter = read_cell_attributes(
                coords_path,
                population,
                comm=comm0,
                mask=set(['U Coordinate', 'V Coordinate', 'L Coordinate']),
                namespace=coords_namespace)
            distances_iter = read_cell_attributes(
                coords_path,
                population,
                comm=comm0,
                mask=set(['U Distance', 'V Distance']),
                namespace=distances_namespace)

            soma_coords[population] = {
                k: (float(v['U Coordinate'][0]), float(v['V Coordinate'][0]),
                    float(v['L Coordinate'][0]))
                for (k, v) in coords_iter
            }

            distances = {
                k: (float(v['U Distance'][0]), float(v['V Distance'][0]))
                for (k, v) in distances_iter
            }

            if len(distances) > 0:
                soma_distances[population] = distances

            gc.collect()

    comm.barrier()
    comm0.Free()

    soma_distances = comm.bcast(soma_distances, root=0)
    soma_coords = comm.bcast(soma_coords, root=0)

    forest_populations = sorted(read_population_names(forest_path))
    if not include:
        destination_populations = forest_populations
    else:
        destination_populations = [p for p in include if p in forest_populations]
    if rank == 0:
        logger.info(
            f'Generating connectivity for populations {destination_populations}...'
        )

    if len(soma_distances) == 0:
        # ``nsample`` used to be referenced here without being defined
        # (NameError); it is now an optional keyword of this function.
        interp_kwargs = {} if nsample is None else {'nsample': nsample}
        (origin_ranges, ip_dist_u,
         ip_dist_v) = make_distance_interpolant(env,
                                                resolution=resolution,
                                                **interp_kwargs)
        ip_dist = (origin_ranges, ip_dist_u, ip_dist_v)
        soma_distances = measure_distances(env,
                                           soma_coords,
                                           ip_dist,
                                           resolution=resolution)

    # These settings do not depend on the destination population.
    random_seeds = env.model_config['Random Seeds']
    synapse_seed = int(random_seeds['Synapse Projection Partitions'])
    connectivity_seed = int(random_seeds['Distance-Dependent Connectivity'])
    cluster_seed = int(random_seeds['Connectivity Clustering'])
    populations_dict = env.model_config['Definitions']['Populations']

    for destination_population in destination_populations:

        if rank == 0:
            logger.info(
                f'Generating connection probabilities for population {destination_population}...'
            )

        connection_prob = ConnectionProb(destination_population, soma_coords,
                                         soma_distances, env.connection_extents)

        if rank == 0:
            logger.info(
                f'Generating connections for population {destination_population}...'
            )

        generate_uv_distance_connections(comm,
                                         populations_dict,
                                         connection_config,
                                         connection_prob,
                                         forest_path,
                                         synapse_seed,
                                         connectivity_seed,
                                         cluster_seed,
                                         synapses_namespace,
                                         connectivity_namespace,
                                         connectivity_path,
                                         io_size,
                                         chunk_size,
                                         value_chunk_size,
                                         cache_size,
                                         write_size,
                                         dry_run=dry_run,
                                         debug=debug)
    MPI.Finalize()
Ejemplo n.º 13
0
def main(config, config_prefix, template_path, output_path, forest_path,
         populations, distribution, io_size, chunk_size, value_chunk_size,
         write_size, verbose, dry_run, debug):
    """
    Distribute synapse locations over the morphologies of the given
    populations and append them to the 'Synapse Attributes' namespace.

    :param config: configuration file passed to ``Env`` as ``config_file``
    :param config_prefix: passed to ``Env`` as ``config_prefix``
    :param template_path: passed to ``Env`` as ``template_paths``
    :param output_path: file receiving the synapse attributes; defaults to
        ``forest_path`` when None. If it does not exist (and this is not a dry
        run), rank 0 creates it with a copy of ``/H5Types`` from ``forest_path``.
    :param forest_path: file from which population ranges and morphologies
        (via ``NeuroH5TreeGen``) are read
    :param populations: names of the populations to process
    :param distribution: synapse placement scheme, 'uniform' or 'poisson'
    :param io_size: number of I/O ranks; -1 means all ranks
    :param chunk_size: forwarded to ``append_cell_attributes``
    :param value_chunk_size: forwarded to ``append_cell_attributes``
    :param write_size: when > 0, accumulated results are written every
        ``write_size`` generator iterations; otherwise they are written once
        per population
    :param verbose: passed to ``utils.config_logging``
    :param dry_run: compute and check synapses but write nothing
    :param debug: stop each population after 5 generator iterations
    :raises ValueError: if ``distribution`` is not 'uniform' or 'poisson'
    :raises KeyError: if a population is missing from ``forest_path``
    """

    utils.config_logging(verbose)
    logger = utils.get_script_logger(os.path.basename(__file__))

    # Validate the distribution on every rank before any work is done, so all
    # ranks fail together rather than only those that happen to receive a gid
    # (the others would otherwise go on to the comm-based writes below).
    distribute_fns = {
        'uniform': synapses.distribute_uniform_synapses,
        'poisson': synapses.distribute_poisson_synapses,
    }
    if distribution not in distribute_fns:
        raise ValueError('Unknown distribution type: %s' % distribution)
    distribute_synapses = distribute_fns[distribution]

    comm = MPI.COMM_WORLD
    rank = comm.rank

    if rank == 0:
        logger.info(f'{comm.size} ranks have been allocated')

    env = Env(comm=comm,
              config_file=config,
              config_prefix=config_prefix,
              template_paths=template_path)

    configure_hoc_env(env)

    if io_size == -1:
        io_size = comm.size

    if output_path is None:
        output_path = forest_path

    if not dry_run:
        if rank == 0 and not os.path.isfile(output_path):
            # Seed a new output file with the type definitions of the forest
            # file; context managers close both files even if the copy fails.
            with h5py.File(forest_path, 'r') as input_file, \
                    h5py.File(output_path, 'w') as output_file:
                input_file.copy('/H5Types', output_file)
        comm.barrier()

    def new_syn_stats():
        """Return an empty excitatory/inhibitory synapse-count accumulator."""
        def zero_counts():
            return {'excitatory': 0, 'inhibitory': 0}
        return {'section': defaultdict(zero_counts),
                'layer': defaultdict(zero_counts),
                'swc_type': defaultdict(zero_counts),
                'total': zero_counts()}

    (pop_ranges, _) = read_population_ranges(forest_path, comm=comm)
    syn_stats = {population: new_syn_stats() for population in populations}

    for population in populations:
        logger.info(f'Rank {rank} population: {population}')
        # Timed per population so the summary below reports this population only.
        start_time = time.time()
        if population not in pop_ranges:
            raise KeyError(population)
        template_class = load_cell_template(env,
                                            population,
                                            bcast_template=True)

        # Collect, per synapse type, the SWC section types and layers that the
        # density configuration assigns synapses to; used by check_syns below.
        density_dict = env.celltypes[population]['synapses']['density']
        layer_set_dict = defaultdict(set)
        swc_set_dict = defaultdict(set)
        for sec_name, sec_dict in viewitems(density_dict):
            for syn_type, syn_dict in viewitems(sec_dict):
                swc_set_dict[syn_type].add(env.SWC_Types[sec_name])
                for layer_name in syn_dict:
                    if layer_name != 'default':
                        layer = env.layers[layer_name]
                        layer_set_dict[syn_type].add(layer)

        syn_stats_dict = new_syn_stats()
        # Register this population's accumulator for the global summary.
        syn_stats[population] = syn_stats_dict

        count = 0
        gid_count = 0
        synapse_dict = {}
        for gid, morph_dict in NeuroH5TreeGen(forest_path,
                                              population,
                                              io_size=io_size,
                                              comm=comm,
                                              topology=True):
            local_time = time.time()
            if gid is not None:
                logger.info(f'Rank {rank} gid: {gid}: {morph_dict}')
                cell = cells.make_neurotree_hoc_cell(template_class,
                                                     neurotree_dict=morph_dict,
                                                     gid=gid)
                cell_sec_dict = {
                    'apical': (cell.apical_list, None),
                    'basal': (cell.basal_list, None),
                    'soma': (cell.soma_list, None),
                    'ais': (cell.ais_list, None),
                    'hillock': (cell.hillock_list, None)
                }
                cell_secidx_dict = {
                    'apical': cell.apicalidx,
                    'basal': cell.basalidx,
                    'soma': cell.somaidx,
                    'ais': cell.aisidx,
                    'hillock': cell.hilidx
                }

                # Per-gid seed offset so each cell gets its own random stream.
                random_seed = env.model_config['Random Seeds'][
                    'Synapse Locations'] + gid
                syn_dict, seg_density_per_sec = distribute_synapses(
                    random_seed, env.Synapse_Types, env.SWC_Types,
                    env.layers, density_dict, morph_dict, cell_sec_dict,
                    cell_secidx_dict)

                synapse_dict[gid] = syn_dict
                this_syn_stats = update_syn_stats(env, syn_stats_dict,
                                                  syn_dict)
                check_syns(gid, morph_dict, this_syn_stats,
                           seg_density_per_sec, layer_set_dict, swc_set_dict,
                           env, logger)

                # Drop the hoc cell before the next gid; gc.collect() follows.
                del cell
                num_syns = len(synapse_dict[gid]['syn_ids'])
                logger.info(
                    f'Rank {rank} took {time.time() - local_time:.2f} s to compute {num_syns} synapse locations for {population} gid: {gid}\n'
                    f'{local_syn_summary(this_syn_stats)}')
                gid_count += 1
            else:
                logger.info(f'Rank {rank} gid is None')
            gc.collect()
            count += 1
            # Flush on the generator iteration count rather than gid_count:
            # gid_count stops advancing on ranks that receive gid None, so
            # ranks could disagree about calling append_cell_attributes, which
            # takes comm. NOTE(review): assumes NeuroH5TreeGen iterates the
            # same number of times on every rank (ranks without a gid get None).
            if (not dry_run) and (write_size > 0) and (count % write_size == 0):
                append_cell_attributes(output_path,
                                       population,
                                       synapse_dict,
                                       namespace='Synapse Attributes',
                                       comm=comm,
                                       io_size=io_size,
                                       chunk_size=chunk_size,
                                       value_chunk_size=value_chunk_size)
                synapse_dict = {}
            if debug and count == 5:
                break

        # Write whatever remains since the last periodic flush.
        if not dry_run:
            append_cell_attributes(output_path,
                                   population,
                                   synapse_dict,
                                   namespace='Synapse Attributes',
                                   comm=comm,
                                   io_size=io_size,
                                   chunk_size=chunk_size,
                                   value_chunk_size=value_chunk_size)

        global_count, summary = global_syn_summary(comm,
                                                   syn_stats,
                                                   gid_count,
                                                   root=0)
        if rank == 0:
            logger.info(
                f'Population: {population}, {comm.size} ranks took {time.time() - start_time:.2f} s '
                f'to compute synapse locations for {np.sum(global_count)} cells'
            )
            logger.info(summary)

        comm.barrier()

    MPI.Finalize()