def fetch_compartments_meshes(compartments,
                              segment_id,
                              split_index=0,
                              original_mesh=None,
                              verbose=False,
                              plot_mesh=False,
                              mesh_alpha=1):
    """
    Purpose: to get the requested
    compartment meshes saved off

    Parameters
    ----------
    compartments : str or list of str
        Compartment name(s); normalized with nu.convert_to_array_like so a
        single name is accepted too.
    segment_id :
        Segment id of the neuron.
    split_index : int
        Split index of the neuron.
    original_mesh : optional
        Pre-fetched segment mesh. When None it is fetched once with
        du.fetch_segment_id_mesh and reused for every compartment.
    verbose : bool
        Forwarded to pv.fetch_compartment_mesh for every compartment.
    plot_mesh : bool
        If True, plot the compartment meshes over the original mesh.
    mesh_alpha :
        Alpha passed to nviz.plot_objects for the compartment meshes.

    Returns
    -------
    list
        One mesh per compartment, in the same order as `compartments`.

    Ex: 
    import apical_utils as apu
    pv.fetch_compartments_meshes(apu.default_compartment_order,
                                segment_id,
                                split_index,
                                 original_mesh = original_mesh,
                                plot_mesh=True)
    """
    compartments = nu.convert_to_array_like(compartments)

    # Fetch the full segment mesh once instead of once per compartment.
    if original_mesh is None:
        original_mesh = du.fetch_segment_id_mesh(segment_id)

    comp_meshes = [
        pv.fetch_compartment_mesh(c,
                                  segment_id,
                                  split_index,
                                  original_mesh=original_mesh,
                                  verbose=verbose)
        for c in compartments
    ]

    if plot_mesh:
        comp_meshes_colors = apu.colors_from_compartments(compartments)
        nviz.plot_objects(original_mesh,
                          meshes=comp_meshes,
                          meshes_colors=comp_meshes_colors,
                          mesh_alpha=mesh_alpha)

    return comp_meshes
def fetch_proofread_mesh(segment_id, split_index=0,
                         original_mesh=None,
                         return_error_mesh=False,
                         plot_mesh=False,
                         ):
    """
    Purpose: Fetch the proofread mesh of a segment, i.e. the "neuron"
    compartment mesh, optionally together with the error mesh.

    Parameters
    ----------
    segment_id :
        Segment id of the neuron.
    split_index : int
        Split index of the neuron.
    original_mesh : optional
        Pre-fetched segment mesh; fetched with du.fetch_segment_id_mesh
        when None.
    return_error_mesh : bool
        If True, also compute the error mesh as
        tu.subtract_mesh(original_mesh, proofread_mesh).
    plot_mesh : bool
        If True, plot the proofread mesh (with the error mesh in red when
        return_error_mesh is True).

    Returns
    -------
    proofread mesh, or (proofread mesh, error mesh) when return_error_mesh
    """
    segment_mesh = original_mesh
    if segment_mesh is None:
        segment_mesh = du.fetch_segment_id_mesh(segment_id)

    proofread = pv.fetch_compartment_mesh(
        "neuron",
        segment_id,
        split_index,
        original_mesh=segment_mesh,
    )

    # Simple case: only the proofread mesh is wanted.
    if not return_error_mesh:
        if plot_mesh:
            nviz.plot_objects(proofread)
        return proofread

    # Error mesh = whatever of the original mesh is not in the proofread one.
    error = tu.subtract_mesh(segment_mesh, proofread)
    if plot_mesh:
        nviz.plot_objects(proofread,
                          meshes=[error],
                          meshes_colors=["red"])
    return proofread, error
def nucleus_mesh_from_segment_id(
    segment_id,
    split_index = 0,
    nuclei_size=2500,
    verbose = False,
    plot_nuclei_with_mesh=False):
    """
    Purpose: Build a sphere mesh standing in for the nucleus of a segment.

    The sphere is centered at pv.nucleus_center_from_segment_id(...) with
    radius `nuclei_size`.

    Parameters
    ----------
    segment_id, split_index :
        Identify the neuron.
    nuclei_size :
        Radius passed to tu.sphere_mesh.
    verbose : bool
        Print the nucleus center.
    plot_nuclei_with_mesh : bool
        Plot the sphere (red) and its center over the proofread neuron.

    Returns
    -------
    The sphere mesh, or None when no nucleus center is found.
    """
    nucleus_center = pv.nucleus_center_from_segment_id(segment_id, split_index)
    if verbose:
        print(f"nucleus_center= {nucleus_center}")

    if nucleus_center is None:
        return None

    sphere = tu.sphere_mesh(nucleus_center, radius=nuclei_size)

    if plot_nuclei_with_mesh:
        # NOTE(review): du.fetch_proofread_neuron is used as the main mesh
        # here while sibling helpers use pv.fetch_proofread_mesh -- confirm
        # it returns a mesh.
        nviz.plot_objects(
            main_mesh=du.fetch_proofread_neuron(segment_id, split_index),
            meshes=[sphere],
            meshes_colors="red",
            mesh_alpha=1,
            scatters=[nucleus_center.reshape(-1, 3)],
            scatter_size=1,
        )

    return sphere
示例#4
0
def fetch_segment_id_mesh(segment_id, plot_mesh=False):
    """
    Purpose: Fetch the "mesh" attribute stored for `segment_id` in
    decimated_mesh_table, optionally plotting it.

    NOTE(review): this was previously documented as "the undecimated
    mesh", but the lookup is against `decimated_mesh_table` -- confirm
    which mesh that table actually stores.

    fetch1 expects the restriction to match exactly one row.
    """
    restricted_table = decimated_mesh_table & dict(segment_id=segment_id)
    seg_mesh = restricted_table.fetch1("mesh")

    if not plot_mesh:
        return seg_mesh

    import neuron_visualizations as nviz
    nviz.plot_objects(seg_mesh)
    return seg_mesh
示例#5
0
def mesh_from_seg_id(segment_id, cloudvolume_obj=cloudvolume_raw, plot=False):
    """
    Purpose: To return a mesh object from downloading through the
    cloudvolume interface

    Only the first entry of the mapping returned by
    `cloudvolume_obj.mesh.get` is used; its vertices and faces are wrapped
    in a trimesh.Trimesh.
    """
    downloaded = cloudvolume_obj.mesh.get(segment_id)
    first_entry = list(downloaded.values())[0]

    mesh_tri = trimesh.Trimesh(
        vertices=first_entry.vertices,
        faces=first_entry.faces,
    )

    if plot:
        nviz.plot_objects(meshes=[mesh_tri])
    return mesh_tri
示例#6
0
def plot_df_xyz(df,branch_size = 1,soma_size = 4,
               soma_color = "blue",branch_color = "red",
               col_suffix = "",
                flip_y = True,
                **kwargs,):
    """
    Purpose: Scatter-plot the x/y/z coordinates of every row of `df`
    together with the soma center.

    Parameters
    ----------
    df :
        Dataframe with columns x{col_suffix}, y{col_suffix}, z{col_suffix}.
    branch_size, branch_color :
        Size / color of the per-row points.
    soma_size, soma_color :
        Size / color of the soma-center point.
    col_suffix : str
        Suffix on the coordinate column names; also passed to
        soma_center_from_df.
    flip_y : bool
        Forwarded to nviz.plot_objects.
    **kwargs :
        Forwarded to nviz.plot_objects.
    """
    coord_columns = [f"x{col_suffix}", f"y{col_suffix}", f"z{col_suffix}"]
    all_points = df[coord_columns].to_numpy().reshape(-1, 3)

    # Only the suffix-aware soma center is needed. The previous extra
    # call without col_suffix was unused, and presumably fails when the
    # un-suffixed columns are absent.
    s_center = soma_center_from_df(df, col_suffix=col_suffix)

    nviz.plot_objects(scatters=[all_points, s_center.reshape(-1, 3)],
                      scatters_colors=[branch_color, soma_color],
                      scatter_size=[branch_size, soma_size],
                      flip_y=flip_y,
                      axis_box_off=False,
                      **kwargs)
示例#7
0
    def make(self, key):
        """
        Purpose: Decimate the proofread mesh for `key`, write it to a file
        via self.make_file and insert the resulting row.

        Reads the module-level names `verbose`, `plot_decimation` and
        `decimation_ratio_global`.
        """
        print(key)
        segment_id = key['segment_id']
        split_index = key['split_index']
        version = key.get("decimation_version", 0)

        mesh = pv.fetch_proofread_mesh(segment_id, split_index=split_index)

        if verbose:
            print(
                f"Mesh size BEFORE DECIMATION: n_vertices = {len(mesh.vertices)}, n_faces = {len(mesh.faces)}"
            )
            # start time is only needed for the verbose timing print below
            st = time.time()

        dec_mesh = tu.decimate(mesh, decimation_ratio_global)

        if verbose:
            print(f"Total time for decimation: {time.time() - st}")
            print(
                f"Mesh size AFTER DECIMATION: n_vertices = {len(dec_mesh.vertices)}, n_faces = {len(dec_mesh.faces)}"
            )

        if plot_decimation:
            print("Decimated Mesh")
            nviz.plot_objects(dec_mesh)

        vertices, faces = dec_mesh.vertices, dec_mesh.faces
        filepath = self.make_file(
            segment_id=segment_id,
            split_index=split_index,
            decimation_ratio=decimation_ratio_global,
            vertices=vertices,
            faces=faces,
            version=version,
        )

        if verbose:
            print(f"filepath = {filepath}")

        new_row = dict(
            key,
            n_vertices=len(vertices),
            n_faces=len(faces),
            mesh=filepath,
            decimation_ratio_after_proof=decimation_ratio_global,
        )
        self.insert1(new_row, allow_direct_insert=True)
def fetch_compartment_mesh(compartment,
                           segment_id,
                           split_index=0,
                           original_mesh=None,
                           verbose=False,
                           plot_mesh = False):
    """
    Purpose: To get the mesh belonging to a certain compartment

    The compartment's face indices come from pv.fetch_compartment_faces.
    "apical_total" is treated as the union of every compartment in
    apu.apical_total. The faces are cut out of `original_mesh` with
    submesh(..., append=True). If that result is not a mesh
    (tu.is_mesh is False), an empty mesh (tu.empty_mesh()) is returned
    instead.

    Ex: 
    original_mesh = du.fetch_segment_id_mesh(segment_id)

    comp_mesh = pv.fetch_compartment_mesh("apical_shaft",
                              segment_id,
                              split_index,
                            original_mesh=original_mesh,
                                          verbose = True,
                                          plot_mesh = True,
                             )
    """
    if compartment != "apical_total":
        compartment_faces = pv.fetch_compartment_faces(
            compartment=compartment,
            segment_id=segment_id,
            split_index=split_index,
        )
    else:
        faces_per_sub_compartment = [
            pv.fetch_compartment_faces(
                compartment=sub_compartment,
                segment_id=segment_id,
                split_index=split_index,
            )
            for sub_compartment in apu.apical_total
        ]
        compartment_faces = np.concatenate(faces_per_sub_compartment).astype("int")

    if verbose:
        print(f"# of faces = {len(compartment_faces)}")

    if original_mesh is None:
        original_mesh = du.fetch_segment_id_mesh(segment_id)

    compartment_mesh = original_mesh.submesh([compartment_faces], append=True)
    if not tu.is_mesh(compartment_mesh):
        compartment_mesh = tu.empty_mesh()

    if plot_mesh:
        print(f"Plotting {compartment}")
        nviz.plot_objects(original_mesh,
                          meshes=[compartment_mesh],
                          meshes_colors="red")

    return compartment_mesh
示例#9
0
def skeleton_from_seg_id(
    segment_id,
    cloudvolume_obj=cloudvolume_raw,
    plot=False,
):
    """
    Purpose: To return a skeleton object from downloading through the
    cloudvolume interface

    The downloaded skeleton's vertices/edges are converted with
    sk.convert_nodes_edges_to_skeleton.
    """
    downloaded_skel = cloudvolume_obj.skeleton.get(segment_id)
    skeleton = sk.convert_nodes_edges_to_skeleton(
        downloaded_skel.vertices,
        downloaded_skel.edges,
    )

    if plot:
        nviz.plot_objects(skeletons=[skeleton])

    return skeleton
def fetch_soma_mesh(segment_id,
                   split_index = None,
                   plot_soma = False,
                   return_sdf=False,
                    verbose = False,
                   ):
    """
    Purpose: To retrieve the soma of the 
    proofread mesh

    1) Normalize (segment_id, split_index) with pv.segment_id_and_split_index
    2) Read soma_x, soma_y, soma_z from the proofread table
    3) Get the soma mesh (and sdf) at that center via
       du.get_soma_mesh_list_singular

    Returns
    -------
    soma mesh, or (soma mesh, soma sdf) when return_sdf is True
    """
    segment_id, split_index = pv.segment_id_and_split_index(segment_id,
                                                            split_index)

    soma_x, soma_y, soma_z = pv.features_from_proofread_table(
        segment_id=segment_id,
        split_index=split_index,
        feature_names=["soma_x", "soma_y", "soma_z"],
        return_dict=False,
    )

    if verbose:
        print(f"For segment_id = {segment_id}, split_index = {split_index} ")
        print(f"soma_x = {soma_x}, soma_y= {soma_y}, soma_z= {soma_z}")

    # the run-time value returned in the middle is not used here
    soma_mesh, _, soma_sdf = du.get_soma_mesh_list_singular(
        segment_id=segment_id,
        soma_center=[soma_x, soma_y, soma_z],
        verbose=verbose,
    )

    if plot_soma:
        print(f"Plotting soma: {soma_mesh}")
        nviz.plot_objects(soma_mesh)

    return (soma_mesh, soma_sdf) if return_sdf else soma_mesh
def fetch_error_mesh(segment_id,split_index = 0,
                    original_mesh = None,
                    plot_mesh = False):
    """
    Purpose: Return only the error mesh produced by
    pv.fetch_proofread_mesh(..., return_error_mesh=True), optionally
    plotting it in red over the proofread mesh.

    Ex: 
    pv.fetch_error_mesh(segment_id,
                   split_index,
                   original_mesh=original_mesh,
                   plot_mesh = True)
    """
    proof_and_error = pv.fetch_proofread_mesh(
        segment_id,
        split_index=split_index,
        original_mesh=original_mesh,
        return_error_mesh=True,
    )
    proof_mesh, error_mesh = proof_and_error

    if plot_mesh:
        nviz.plot_objects(proof_mesh,
                          meshes=[error_mesh],
                          meshes_colors=["red"])

    return error_mesh
def fetch_compartment_skeleton(compartment,
                          segment_id,
                          split_index = 0,
                           verbose = False,
                               plot_skeleton = False,
                               original_mesh = None,
                          ):
    """
    Purpose: To retrieve the datajoint
    stored skeleton for that compartment

    The skeleton is read from the "{compartment}_skeleton" attribute of
    du.proofreading_stats_table(). "apical_total" stacks the skeletons
    of every compartment in apu.apical_total. An empty result is
    normalized to an array of shape (0, 2, 3).

    Ex: 
    comp_skeleton = pv.fetch_compartment_skeleton("apical_shaft",
                             segment_id,
                             split_index,
                            plot_skeleton = True)
    """
    segment_id, split_index = pv.segment_id_and_split_index(segment_id, split_index)
    restriction = dict(segment_id=segment_id, split_index=split_index)

    def _stored_skeleton(comp_name):
        # one fetch1 per requested compartment attribute
        return (du.proofreading_stats_table() & restriction).fetch1(f"{comp_name}_skeleton")

    if compartment != "apical_total":
        comp_skeleton = _stored_skeleton(compartment)
    else:
        comp_skeleton = sk.stack_skeletons(
            [_stored_skeleton(sub_comp) for sub_comp in apu.apical_total])

    if len(comp_skeleton) == 0:
        comp_skeleton = np.array([]).reshape(-1, 2, 3)

    if verbose:
        print(f"{compartment} skeleton = {sk.calculate_skeleton_distance(comp_skeleton)}")

    if plot_skeleton:
        background_mesh = original_mesh
        if background_mesh is None:
            background_mesh = du.fetch_segment_id_mesh(segment_id)
        nviz.plot_objects(background_mesh,
                          skeletons=[comp_skeleton])

    return comp_skeleton
示例#13
0
def plot_axon_dendrite_skeletons(segment_id,
                                 split_index=0,
                                 table=None,
                                 axon_skeleton_color="black",
                                 plot_mesh=True,
                                 axon_color="black",
                                 dendrite_color="aqua",
                                 mesh_color="green",
                                 verbose=True):
    """
    Purpose: Fetch the old mesh and all of the axon 
    skeletons and graph

    Pseudocode: 
    1) Fetch the old mesh (only when plot_mesh is True)
    2) fetch the axon and dendrite skeletons via hdju.axon_dendrite_skeleton
    3) graph together

    NOTE(review): `axon_skeleton_color` is accepted but unused; the axon
    skeleton is drawn with `axon_color`.
    """
    neuron_mesh = fetch_segment_id_mesh(segment_id) if plot_mesh else None

    axon_skeleton, dendrite_skeleton = hdju.axon_dendrite_skeleton(
        segment_id=segment_id,
        split_index=split_index,
        table=table,
    )

    if verbose:
        axon_length = sk.calculate_skeleton_distance(axon_skeleton)
        print(f"axon_skeleton length = {axon_length}")
        dendrite_length = sk.calculate_skeleton_distance(dendrite_skeleton)
        print(f"dendrite_skeleton length = {dendrite_length}")

    nviz.plot_objects(
        main_mesh=neuron_mesh,
        main_mesh_color=mesh_color,
        skeletons=[axon_skeleton, dendrite_skeleton],
        skeletons_colors=[axon_color, dendrite_color],
    )
def fetch_compartments_skeletons(
    compartments,
    segment_id,
    split_index=0,
    verbose=False,

    #plotting arguments
    plot_skeleton = False,
    original_mesh=None,):
    """
    Purpose: Fetch the stored skeleton of each requested compartment
    (one pv.fetch_compartment_skeleton call per compartment, in order).

    `original_mesh` is only used as the plotting background; it is
    fetched with du.fetch_segment_id_mesh when plotting and it is None.

    Ex: 
    import apical_utils as apu
    pv.fetch_compartments_skeletons(apu.default_compartment_order,
                                segment_id,
                                split_index,
                                 original_mesh = original_mesh,
                                plot_skeleton=True)
    """
    compartment_list = nu.convert_to_array_like(compartments)

    comp_skeletons = []
    for comp in compartment_list:
        comp_skeletons.append(
            pv.fetch_compartment_skeleton(comp, segment_id, split_index,
                                          verbose=verbose))

    if not plot_skeleton:
        return comp_skeletons

    background_mesh = original_mesh
    if background_mesh is None:
        background_mesh = du.fetch_segment_id_mesh(segment_id)

    nviz.plot_objects(background_mesh,
                      skeletons=comp_skeletons,
                      skeletons_colors=apu.colors_from_compartments(compartment_list))

    return comp_skeletons
示例#15
0
    def make(self,key):
        """
        Purpose: To decompose a neuron mesh with the parameters of a
        prescribed decomposition method, then save the neuron object and
        its skeleton and insert the results.

        Pseudocode: 
        1) Get the current mesh, somas, glia and nuclei faces
        2) Get the decomposition-method parameters and run the neuron
           preprocessing (neuron.Neuron), then combine the somas
        3) Compute the neuron stats
        4) Save the compressed neuron object
        5) Compute skeleton stats and save the skeleton
        6) Insert into this table, self.Object, SkeletonDecomposition and
           SkeletonDecomposition.Object

        Relies on module-level names such as verbose, plotting,
        preprocess_neuron_kwargs, spines_kwargs, voxel_adjustment_vector,
        dataset, target_dir_decomp, target_dir_sk and process_version.
        """
        global_time = time.time()

        segment_id = key["segment_id"]
        decomposition_hash = key["decomposition_method"]
        ver = key["ver"]  # unused below; reading it still requires "ver" in key

        if verbose:
            print(f"\n\n--Working on {segment_id}: (decomposition_hash = {decomposition_hash})")

        #1) Collect mesh, somas, glia faces and nuclei faces
        st = time.time()

        mesh = hdju.fetch_segment_id_mesh(segment_id)
        somas = hdju.get_soma_mesh_list_filtered(segment_id)
        print(f"somas = {somas}")

        glia_faces,nuclei_faces = hdju.get_segment_glia_nuclei_faces(segment_id)

        if plotting:
            soma_mesh_center = hdju.soma_info_center(segment_id,return_nm=True)
            # rotating the mesh
            nviz.plot_objects(hu.align_mesh_from_soma_coordinate(
                mesh,
                soma_center=soma_mesh_center,
            ))

        if verbose:
            print(f"Collecting Mesh Info: {time.time() - st}")

        #2) Look up the parameters of this decomposition method. The hash is
        # the one read from key["decomposition_method"] above (previously an
        # undefined name, `decomposition_method_hash`, was used here).
        preprocess_args = DecompositionMethod.restrict_one_part_with_hash(
            decomposition_hash).fetch1()

        fill_hole_size = preprocess_args["fill_hole_size"]

        current_preprocess_neuron_kwargs = {
            k:v for k,v in preprocess_args.items() if k in preprocess_neuron_kwargs.keys()}

        current_spines_kwargs = {
            k:v for k,v in preprocess_args.items() if k in spines_kwargs.keys()}

        print(f"current_preprocess_neuron_kwargs = \n{current_preprocess_neuron_kwargs}")
        print(f"current_spines_kwargs = \n{current_spines_kwargs}")

        params_to_change = [k for k in current_preprocess_neuron_kwargs if k.split("_")[-1] in ["cgal","map"]]
        if verbose:
            print(f"params_to_change = {params_to_change}")

        # The stored parameter names use lower-case map/cgal suffixes; they are
        # renamed to the upper-case MAP/CGAL spelling before being handed to
        # the preprocessing (presumably the spelling it expects -- confirm).
        for stored_name, kwarg_name in (
            ("width_threshold_map", "width_threshold_MAP"),
            ("size_threshold_map", "size_threshold_MAP"),
            ("max_stitch_distance_cgal", "max_stitch_distance_CGAL"),
        ):
            current_preprocess_neuron_kwargs[kwarg_name] = (
                current_preprocess_neuron_kwargs.pop(stored_name))

        description = "0_25"
        st = time.time()

        # NOTE(review): `current_spines_kwargs` (method-specific values) is
        # computed and printed above, but the module-level `spines_kwargs` is
        # what gets passed here -- confirm which one is intended.
        neuron_obj = neuron.Neuron(
                mesh = mesh,
                somas = somas,
                segment_id=segment_id,
                description=description,
                suppress_preprocessing_print=False,
                suppress_output=False,
                calculate_spines=True,
                widths_to_calculate=["no_spine_median_mesh_center"],
                glia_faces=glia_faces,
                nuclei_faces = nuclei_faces,
                decomposition_type = "meshafterparty",
                preprocess_neuron_kwargs=current_preprocess_neuron_kwargs,
                spines_kwargs=spines_kwargs,
                fill_hole_size=fill_hole_size
                        )

        neuron_obj_comb = nru.combined_somas_neuron_obj(neuron_obj,
                                                inplace = False,
                                                verbose = verbose,
                                                plot_soma_limb_network = plotting)
        if verbose:
            print(f"\n\n\n---- Total preprocessing time = {time.time() - st}")

        if plotting:
            nviz.visualize_neuron(neuron_obj_comb,
                     limb_branch_dict="all",
                     mesh_whole_neuron=True)

        #3) Compute neuron stats
        st = time.time()
        stats_dict = neuron_obj_comb.neuron_stats(stats_to_ignore = [
                    "n_boutons",
                     "axon_length",
                     "axon_area",
                     "max_soma_volume",
                     "max_soma_n_faces",],
            include_skeletal_stats = True,
            include_centroids= True,
            voxel_adjustment_vector=voxel_adjustment_vector,
        )

        if verbose:
            print(f"-- Generating Stats: {time.time() - st}")

        #4) Save the neuron object in a certain location
        file_name = f"{neuron_obj_comb.segment_id}_{decomposition_hash}"
        file_name_decomp = f"{file_name}_{dataset}_decomposition"
        output_folder=str(target_dir_decomp)

        st = time.time()
        ret_file_path = neuron_obj_comb.save_compressed_neuron(
            output_folder=output_folder,
            file_name= file_name_decomp,
            return_file_path=True,
            export_mesh=False,
            suppress_output=True,
            )

        # the saved neuron file carries a ".pbz2" extension on disk
        ret_file_path_str = str(ret_file_path.absolute()) + ".pbz2"

        if verbose:
            print(f"-- Neuron Object Save time: {time.time() - st}")

        #5) Outputting skeleton object, computing stats and saving
        st = time.time()

        sk_stats = nst.skeleton_stats_from_neuron_obj(
            neuron_obj_comb,
            include_centroids=True,
            voxel_adjustment_vector=voxel_adjustment_vector,
            verbose = True)

        skeleton = neuron_obj_comb.skeleton
        file_name_decomp_sk = f"{file_name}_{dataset}_decomposition_sk"
        ret_sk_filepath = su.compressed_pickle(
            skeleton,
            filename = file_name_decomp_sk,
            folder=str(target_dir_sk),
            return_filepath=True)

        if verbose:
            print(f"ret_sk_filepath = {ret_sk_filepath}")
            print(f"-- Skeleton Generation and Save time: {time.time() - st}")

        # 6) make the insertions
        run_time = np.round(time.time() - global_time,4)

        # -- decomp table --
        decomp_dict = dict(key,
                           process_version = process_version,
                           index = 0,
                           multiplicity=1,
                           decomposition=ret_file_path_str,
                           run_time = run_time)
        decomp_dict.update(stats_dict)

        self.insert1(decomp_dict,
                     allow_direct_insert = True,
                     ignore_extra_fields = True,
                     skip_duplicates=True)
        self.Object.insert1(
                    decomp_dict,
                    allow_direct_insert = True,
                    ignore_extra_fields = True,
                    skip_duplicates=True)

        #-- sk table
        sk_dict = dict(key,
                       process_version = process_version,
                       index = 0,
                       multiplicity=1,
                       skeleton=ret_sk_filepath,
                       run_time = run_time)
        sk_dict.update(sk_stats)

        SkeletonDecomposition.insert1(sk_dict,
                     allow_direct_insert = True,
                     ignore_extra_fields = True,
                     skip_duplicates=True)
        SkeletonDecomposition.Object.insert1(
                    sk_dict,
                    allow_direct_insert = True,
                    ignore_extra_fields = True,
                    skip_duplicates=True)
def presyn_postsyn_skeletal_path_from_synapse_id(
    G,
    synapse_id,
    synapse_coordinate=None,
    segment_ids=None,
    return_nm=return_nm_default,

    #arguments for plotting the paths
    plot_skeletal_paths=False,
    synapse_color="red",
    synapse_scatter_size=0.2,  #,
    path_presyn_color="yellow",
    path_postsyn_color="blue",
    path_scatter_size=0.05,  #0.2,
    plot_meshes=True,
    mesh_presyn_color="orange",
    mesh_postsyn_color="aqua",
    verbose=False,
    remove_soma_synanpse_nodes=True,
):
    """
    Purpose: To develop a skeletal path coordinates between 
    two segments ids that go through a soma 

    Application: Can then be sent to a plotting function

    Pseudocode: 
    1) Get the segment ids paired with that synapse id 
    (get synapse coordinate if not precomputed)
    2) Get the proofread skeletons associated with the segment_ids
    3) Get the soma center coordinates to determine paths
    4) Get the closest skeleton node to the coordinate
    5) Find the skeletal path in coordinates
    6) Plot the skeletal paths

    Returns
    -------
    list of two coordinate arrays: [presyn path, postsyn path], each going
    from the skeleton node closest to the soma to the skeleton node closest
    to the synapse. When remove_soma_synanpse_nodes is True those two end
    nodes are removed (a single-node path is replaced by the midpoint of
    the two). Coordinates are in nm unless return_nm is False, in which
    case they are divided by alu.voxel_to_nm_scaling.

    Ex: 
    G["864691136830770542_0"]["864691136881594990_0"]
    synapse_id = 299949435
    segment_ids = ["864691136830770542_0","864691136881594990_0"]

    mgu.presyn_postsyn_skeletal_path_from_synapse_id(
        G,
        synapse_id = synapse_id,
        synapse_coordinate = None,
        segment_ids = segment_ids,
        return_nm = True,
        verbose = True,
        plot_skeletal_paths=True,
        path_scatter_size = 0.04,

    )
    """

    #1) Get the presyn / postsyn node names paired with that synapse id
    pre_seg, post_seg = mgu.pre_post_node_names_from_synapse_id(
        G,
        node_names=segment_ids,
        synapse_id=synapse_id,
        return_one=True,
    )

    if synapse_coordinate is None:
        synapse_coordinate = mgu.synapse_coordinate_from_seg_split_syn_id(
            G, pre_seg, post_seg, synapse_id, return_nm=True)

    if verbose:
        print(
            f"syn_id = {synapse_id}: pre_seg = {pre_seg}, post_seg = {post_seg}"
        )
        print(f"synapse_coordinate = {synapse_coordinate}")

    segment_ids = [pre_seg, post_seg]

    # Node names are split into (segment_id, split_index) once and reused
    # for both the skeleton and the mesh lookups.
    seg_split_pairs = [pv.segment_id_and_split_index(k) for k in segment_ids]

    #2) Get the proofread skeletons associated with the segment_ids
    seg_sks = [pv.fetch_proofread_skeleton(*pair) for pair in seg_split_pairs]

    #3) Get the soma center coordinates to determine paths
    soma_coordinates = [
        mgu.soma_center_from_segment_id(G, s, return_nm=True)
        for s in segment_ids
    ]

    soma_coordinates_closest = [
        sk.closest_skeleton_coordinate(curr_sk, soma_c)
        for curr_sk, soma_c in zip(seg_sks, soma_coordinates)
    ]

    if verbose:
        print(f"\n Soma Information:")
        print(f"soma_coordinates = {soma_coordinates}")
        print(f"soma_coordinates_closest = {soma_coordinates_closest}")

    #4) Get the closest skeleton node to the coordinate
    closest_sk_coords = [
        sk.closest_skeleton_coordinate(curr_sk, synapse_coordinate)
        for curr_sk in seg_sks
    ]

    #5) Find the skeletal path in coordinates
    skeletal_coord_paths = [
        np.array(
            sk.convert_skeleton_to_nodes(
                sk.skeleton_path_between_skeleton_coordinates(
                    starting_coordinate=soma_c,
                    destination_coordinate=syn_c,
                    skeleton=curr_sk,
                    plot_skeleton_path=False,
                    return_singular_node_path_if_no_path=True)).reshape(-1, 3))
        for curr_sk, soma_c, syn_c in zip(seg_sks,
                                          soma_coordinates_closest,
                                          closest_sk_coords)
    ]

    if verbose:
        print(f"Path lengths = {[len(k) for k in skeletal_coord_paths]}")

    if remove_soma_synanpse_nodes:
        skeletal_coord_paths_revised = []
        for soma_c, syn_c, curr_path in zip(soma_coordinates_closest,
                                            closest_sk_coords,
                                            skeletal_coord_paths):
            if len(curr_path) > 1:
                skeletal_coord_paths_revised.append(
                    nu.setdiff2d(curr_path, np.array([soma_c, syn_c])))
            else:
                # a single-node path is replaced by the soma/synapse midpoint
                skeletal_coord_paths_revised.append(
                    np.mean(np.vstack([soma_c, syn_c]), axis=0).reshape(-1, 3))

        skeletal_coord_paths = skeletal_coord_paths_revised

    #6) Plot the skeletal paths
    if plot_skeletal_paths:
        print(f"Plotting synapse paths")

        # np.asarray so a caller-supplied list coordinate can be reshaped too
        scatters = [np.asarray(synapse_coordinate).reshape(-1, 3)]
        scatter_size = [synapse_scatter_size]
        scatters_colors = [synapse_color]

        path_colors = [path_presyn_color, path_postsyn_color]

        for p_sc, p_col in zip(skeletal_coord_paths, path_colors):
            scatters.append(p_sc)
            scatter_size.append(path_scatter_size)
            scatters_colors.append(p_col)

        meshes = []
        meshes_colors = [
            mesh_presyn_color,
            mesh_postsyn_color,
        ]

        if plot_meshes:
            # same (segment_id, split_index) unpacking as the skeleton fetch
            meshes = [pv.fetch_proofread_mesh(seg, split)
                      for seg, split in seg_split_pairs]

        nviz.plot_objects(meshes=meshes,
                          meshes_colors=meshes_colors,
                          skeletons=seg_sks,
                          skeletons_colors=meshes_colors,
                          scatters=scatters,
                          scatter_size=scatter_size,
                          scatters_colors=scatters_colors)

    if not return_nm:
        skeletal_coord_paths = [
            k / alu.voxel_to_nm_scaling for k in skeletal_coord_paths
        ]

    return skeletal_coord_paths
示例#17
0
def plot_mesh_with_somas(
    segment_id,
    split_index=None,
    plot_glia=False,
    plot_nuclei=False,
    soma_color=["red", "purple", "pink"],
    #soma_color = 'red',
    nuclei_color="black",
    glia_color="aqua",
    main_mesh_color=[0., 1., 0., 0.2],
    plot_soma_center=True,
    soma_center_size=2,
    with_skeleton=False,
    skeleton_color="black",
    align_from_soma_center=True,
    verbose=True,
):
    """
    To plot the extracted somas and the 
    glia and nuclei if found

    The segment mesh is drawn as the main mesh, with the extracted somas
    (and optionally glia / nuclei meshes, the soma center and the
    decomposition skeleton) on top. With align_from_soma_center, every
    object is aligned with the hu.align_* helpers around the soma center.

    Ex: 
    curr_idx = 12_700
    segment_id = single_seg_ids[curr_idx] 
    segment_id = 316084337
    print(f"segment_id = {segment_id}")
    print((h01mat.SomaInfo() & dict(segment_id=segment_id)).fetch("celltype","layer"))
    hdju.plot_mesh_with_somas(
        segment_id,
        plot_glia=True,
        plot_nuclei=True)

    """
    import neuron_visualizations as nviz

    if verbose:
        print(f"segment_id = {segment_id}")
        print((h01mat.SomaInfo() & dict(segment_id=segment_id)).fetch(
            "celltype", "layer"))

    seg_mesh = hdju.fetch_segment_id_mesh(segment_id)

    # Glia / nuclei entries may not exist for every segment; a failed
    # lookup is treated as "no mesh" (KeyboardInterrupt is no longer
    # swallowed).
    try:
        glia_mesh = hdju.glia_mesh(segment_id, mesh=seg_mesh)
    except Exception:
        if verbose:
            print(f"No glia_mesh entry detected")
        glia_mesh = None

    try:
        nuclei_mesh = hdju.nuclei_mesh(segment_id, mesh=seg_mesh)
    except Exception:
        if verbose:
            print(f"No nuclie entry detected")
        nuclei_mesh = None

    meshes = list(hdju.get_extracted_somas(segment_id))
    meshes_colors = mu.generate_non_randon_named_color_list(
        len(meshes), user_colors=soma_color)

    # The "empty" messages now only print when the mesh really has no faces.
    if plot_glia and glia_mesh is not None:
        if len(glia_mesh.faces) > 0:
            meshes.append(glia_mesh)
            meshes_colors.append(glia_color)
        elif verbose:
            print(f"Empty Glia Mesh")
    if plot_nuclei and nuclei_mesh is not None:
        if len(nuclei_mesh.faces) > 0:
            meshes.append(nuclei_mesh)
            meshes_colors.append(nuclei_color)
        elif verbose:
            print(f"Empty Nuclie Mesh")

    if plot_soma_center:
        scatters = hdju.soma_info_center(segment_id)
        scatters = [np.array(scatters).reshape(-1, 3)]
    else:
        # an empty list (previously the tuple `(None,)`, which broke the
        # alignment loop below)
        scatters = []

    if with_skeleton:
        try:
            skeletons = [
                np.array(skeleton_from_decomposition(segment_id, split_index))
            ]
        except Exception:
            # fall back to the lookup without a split index
            skeletons = [np.array(skeleton_from_decomposition(segment_id))]
    else:
        skeletons = []

    if align_from_soma_center:
        import human_utils as hu
        soma_mesh_center = hdju.soma_info_center(segment_id, return_nm=True)
        meshes = [
            hu.align_mesh(k, soma_center=soma_mesh_center, verbose=False)
            for k in meshes
        ]
        seg_mesh = hu.align_mesh(seg_mesh,
                                 soma_center=soma_mesh_center,
                                 verbose=False)
        skeletons = [
            hu.align_skeleton(
                k,
                soma_center=soma_mesh_center,
                verbose=False,
            ) for k in skeletons
        ]
        scatters = [
            hu.align_array(k, soma_center=soma_mesh_center, verbose=False)
            for k in scatters
        ]

    nviz.plot_objects(
        seg_mesh,
        main_mesh_color=main_mesh_color,
        meshes=meshes,
        meshes_colors=meshes_colors,
        scatters=scatters,
        scatter_size=soma_center_size,
        skeletons=skeletons,
        skeletons_colors=skeleton_color,
    )
def plot_proofread_neuron(
    segment_id,
    split_index = 0,
    cell_type = None,
    original_mesh = None,

    plot_proofread_skeleton = False,

    proofread_mesh_color = "green",
    proofread_mesh_alpha = None,
    proofread_skeleton_color = "black",

    plot_nucleus = True,
    nucleus_size = 1,
    nucleus_color = "proofread_mesh_color",#"black",

    plot_synapses = True,
    synapses_size = 0.05,
    synapse_plot_type = "spine_bouton",#"compartment"#  "valid_error" #"valid_presyn_postsyn"
    synapse_compartments = None,
    synapse_spine_bouton_labels = None,
    plot_error_synapses = False,
    valid_synapses_color = "orange",
    error_synapses_color = "aliceblue",
    synapse_queries = None,
    synapse_queries_colors = None,

    plot_error_mesh = False,
    error_mesh_color = "black",
    error_mesh_alpha = 1,

    compartments = None,
    #compartments = ["apical_total"]
    #compartments= ["axon","dendrite"]
    plot_compartment_meshes = True,
    compartment_mesh_alpha = 0.3,
    plot_compartment_skeletons = True,

    verbose = False,
    print_spine_colors = True,
    print_compartment_colors = True,

    #arguments for plotting more scatters
    scatters = None,
    scatter_sizes = 0.2,
    scatters_colors = "yellow",

    show_at_end = True,
    append_figure = False,
    ):
    """
    Purpose: Will plot the saved
    proofread information of a neuron
    (proofread mesh, optional error mesh / skeleton,
    nucleus center, synapses and compartment meshes/skeletons)

    Parameters
    ----------
    segment_id : int or str
        Segment id, or a node name string that is split into
        (segment_id, split_index) via
        pv.segment_id_and_split_index_from_node_name.
    split_index : int
        Split index of the neuron (overwritten when segment_id is a str).
    cell_type : optional
        Looked up with pv.cell_type_from_segment_id when None; used to
        choose default compartments to plot.
    original_mesh : optional
        Full segment mesh; fetched with du.fetch_segment_id_mesh when None.
    plot_proofread_skeleton : bool
        Also fetch and draw the proofread skeleton.
    proofread_mesh_color, proofread_mesh_alpha, proofread_skeleton_color
        Styling of the main (proofread) mesh and skeleton.
    plot_nucleus, nucleus_size, nucleus_color
        Draw the nucleus center as a scatter. The special value
        "proofread_mesh_color" for nucleus_color reuses proofread_mesh_color.
    plot_synapses, synapses_size, synapse_plot_type, synapse_compartments,
    synapse_spine_bouton_labels, plot_error_synapses, valid_synapses_color,
    error_synapses_color, synapse_queries, synapse_queries_colors
        Forwarded to syu.synapse_plot_items_by_type_or_query.
    plot_error_mesh, error_mesh_color, error_mesh_alpha
        Draw the part of original_mesh that was removed by proofreading.
    compartments : str or list, optional
        Compartment name(s) to draw; defaults to
        apu.compartments_to_plot(cell_type). A single name is accepted.
    plot_compartment_meshes, compartment_mesh_alpha, plot_compartment_skeletons
        Which compartment representations to draw.
    verbose, print_spine_colors, print_compartment_colors : bool
        Printing controls.
    scatters : list, optional
        Extra scatter point groups to draw. The caller's list (and the
        color/size lists) are copied, never modified.
    scatter_sizes, scatters_colors
        Size/color per extra scatter group; a single value is broadcast
        to every group.
    show_at_end, append_figure : bool
        Forwarded to nviz.plot_objects; when append_figure is False the
        current ipyvolume figure is cleared first.

    Ex: 
    #trying on inhibitory
    segment_id,split_index = (864691134917559306,0)

    original_mesh = du.fetch_segment_id_mesh(segment_id)

    pv.plot_proofread_neuron(
        segment_id,
        split_index,
        original_mesh=original_mesh,
        plot_error_mesh=False,
        verbose = True)
    """
    if not append_figure:
        # start from an empty figure unless layering onto an existing one
        import ipyvolume as ipv
        ipv.clear()

    su.ignore_warnings()

    # node names encode both the segment id and the split index
    if isinstance(segment_id, str):
        segment_id,split_index = pv.segment_id_and_split_index_from_node_name(segment_id)

    if verbose:
        print(f"Plotting {segment_id}_{split_index} (nucleus_id={pv.nucleus_id_from_segment_id(segment_id,split_index)})")

    if cell_type is None:
        cell_type = pv.cell_type_from_segment_id(segment_id,split_index)
        if verbose:
            print(f"cell_type = {cell_type}")
    if synapse_compartments is None:
        synapse_compartments = apu.compartments_to_plot(cell_type)

    if synapse_spine_bouton_labels is None:
        synapse_spine_bouton_labels = spu.spine_bouton_labels_to_plot()

    if compartments is None:
        compartments = apu.compartments_to_plot(cell_type)
    if compartments is not None:
        # a single compartment name would otherwise be iterated character by
        # character when pairing names with colors below
        compartments = nu.convert_to_array_like(compartments)

    meshes = []
    meshes_colors = []
    skeletons = []
    skeletons_colors = []
    meshes_alpha = []

    if scatters is not None:
        # copy into fresh lists: nucleus/synapse scatters are appended below
        # and must not leak into (or accumulate in) the caller's objects
        scatters = list(scatters)
        scatters_colors = list(nu.convert_to_array_like(scatters_colors))
        if len(scatters_colors) == 1:
            scatters_colors = scatters_colors*len(scatters)
        scatter_sizes = list(nu.convert_to_array_like(scatter_sizes))
        if len(scatter_sizes) == 1:
            scatter_sizes = scatter_sizes*len(scatters)
    else:
        scatters = []
        scatters_colors = []
        scatter_sizes = []

    if original_mesh is None:
        original_mesh = du.fetch_segment_id_mesh(segment_id)

    compartment_color_dict = dict(valid_mesh=proofread_mesh_color)

    proof_mesh,error_mesh = pv.fetch_proofread_mesh(segment_id,
                            split_index = split_index,
                            original_mesh=original_mesh,
                            return_error_mesh=True)

    if plot_proofread_skeleton:
        proof_skeleton = pv.fetch_proofread_skeleton(segment_id,
                                   split_index,
                                   plot_skeleton=False,
                                   )
    else:
        proof_skeleton = None

    if plot_error_mesh:
        meshes.append(error_mesh)
        meshes_colors.append(error_mesh_color)
        meshes_alpha.append(error_mesh_alpha)
        compartment_color_dict["error_mesh"] = error_mesh_color

    if plot_nucleus:
        nuc_center = pv.nucleus_center_from_segment_id(segment_id,
                                         split_index)

        # sentinel value: reuse the proofread mesh color for the nucleus
        if nucleus_color == "proofread_mesh_color":
            nucleus_color = proofread_mesh_color

        if nuc_center is None:
            print("No nucleus to plot")
        else:
            scatters += [nuc_center.reshape(-1,3)]
            scatters_colors += [nucleus_color]
            scatter_sizes += [nucleus_size]

    #get the synapse groups
    if plot_synapses:
        synapses_objs = pv.syanpse_objs_from_segment_id(segment_id,split_index)

        (syn_scatters,
        syn_colors,
        syn_sizes) = syu.synapse_plot_items_by_type_or_query(
                        synapses_objs,
                        synapses_size = synapses_size,
                        synapse_plot_type = synapse_plot_type,#"compartment"#  "valid_error"
                        synapse_compartments = synapse_compartments,
                        synapse_spine_bouton_labels = synapse_spine_bouton_labels,
                        plot_error_synapses = plot_error_synapses,
                        valid_synapses_color = valid_synapses_color,
                        error_synapses_color = error_synapses_color,
                        synapse_queries = synapse_queries,
                        synapse_queries_colors = synapse_queries_colors,
                        verbose = verbose,
                        print_spine_colors = print_spine_colors)

        scatters += syn_scatters
        scatters_colors += syn_colors
        scatter_sizes += syn_sizes

    if compartments is not None and len(compartments) > 0:
        comp_colors = apu.colors_from_compartments(compartments)
        if plot_compartment_meshes:
            comp_meshes = pv.fetch_compartments_meshes(compartments,
                                                      segment_id,
                                                     split_index,
                                                       original_mesh = original_mesh,
                                                     )
            meshes += comp_meshes
            meshes_colors += comp_colors
            meshes_alpha += [compartment_mesh_alpha]*len(comp_meshes)

        if plot_compartment_skeletons:
            comp_sk = pv.fetch_compartments_skeletons(compartments,
                                                      segment_id,
                                                     split_index,
                                                     )

            skeletons += comp_sk
            skeletons_colors += comp_colors

        compartment_color_dict.update(dict(zip(compartments,comp_colors)))

    if print_compartment_colors:
        print("\nCompartment Colors:")
        for k,v in compartment_color_dict.items():
            print(f"  {k}:{v}")

    nviz.plot_objects(main_mesh = proof_mesh,
                      main_mesh_alpha=proofread_mesh_alpha,
                      main_mesh_color=proofread_mesh_color,

                      main_skeleton=proof_skeleton,
                      main_skeleton_color=proofread_skeleton_color,

                     skeletons=skeletons,
                     skeletons_colors=skeletons_colors,

                     meshes=meshes,
                     meshes_colors=meshes_colors,
                     mesh_alpha=meshes_alpha,

                     scatters=scatters,
                     scatter_size=scatter_sizes,
                     scatters_colors=scatters_colors,

                     show_at_end = show_at_end,
                    append_figure = append_figure,
                     )