def features_from_proofread_table(
    segment_id,
    feature_names,
    split_index=None,
    return_dict=False,
):
    """
    Purpose: To fetch one or more features for a
    segment id / split index by delegating to
    du.segment_id_to_autoproofread_neuron_features
    (always called with validation = False)

    Arguments:
    segment_id, split_index: normalized together by
        pv.segment_id_and_split_index before the lookup
    feature_names: a single name or a collection of names
        (converted with nu.convert_to_array_like)
    return_dict: forwarded unchanged to the lookup

    Returns: element [0] of the lookup result when exactly
    one feature name was requested, otherwise the lookup
    result as is

    NOTE(review): the [0] unwrap is applied even when
    return_dict = True -- confirm the dict form supports it
    """
    requested_features = nu.convert_to_array_like(feature_names)

    seg_id, sp_idx = pv.segment_id_and_split_index(segment_id, split_index)

    fetched = du.segment_id_to_autoproofread_neuron_features(
        segment_id=seg_id,
        split_index=sp_idx,
        statistic_names=requested_features,
        validation=False,
        return_dict=return_dict,
    )

    # a single requested feature is unwrapped for convenience
    return fetched[0] if len(requested_features) == 1 else fetched
def fetch_compartments_meshes(compartments,
                              segment_id,
                              split_index=0,
                              original_mesh=None,
                              verbose=False,
                              plot_mesh=False,
                              mesh_alpha=1):
    """
    Purpose: To fetch one mesh per requested compartment
    (via pv.fetch_compartment_mesh) and optionally plot them
    on top of the original mesh

    Arguments:
    compartments: one compartment name or a collection of them
    segment_id, split_index: identify the neuron
    original_mesh: fetched with du.fetch_segment_id_mesh
        when not supplied
    verbose: accepted but not used inside this function
    plot_mesh: if True the meshes are sent to nviz.plot_objects,
        colored by apu.colors_from_compartments
    mesh_alpha: forwarded to nviz.plot_objects

    Returns: list of compartment meshes in the same order
    as the requested compartments

    Ex:
    import apical_utils as apu
    pv.fetch_compartments_meshes(apu.default_compartment_order,
                                segment_id,
                                split_index,
                                 original_mesh = original_mesh,
                                plot_mesh=True)
    """
    compartment_names = nu.convert_to_array_like(compartments)

    if original_mesh is None:
        original_mesh = du.fetch_segment_id_mesh(segment_id)

    fetched_meshes = []
    for comp_name in compartment_names:
        fetched_meshes.append(
            pv.fetch_compartment_mesh(comp_name,
                                      segment_id,
                                      split_index,
                                      original_mesh=original_mesh))

    if not plot_mesh:
        return fetched_meshes

    nviz.plot_objects(
        original_mesh,
        meshes=fetched_meshes,
        meshes_colors=apu.colors_from_compartments(compartment_names),
        mesh_alpha=mesh_alpha,
    )
    return fetched_meshes
def set_global_parameters_and_attributes_by_data_type(module,
                                                      data_type=None,
                                                      algorithms=None,
                                                      set_default_first=True,
                                                      verbose=False):
    """
    Purpose: For every module given, collect the global
    parameter / attribute dictionaries for a data type and
    write them onto the module

    Arguments:
    module: one module specification or a collection of them;
        each is unpacked by modu.extract_module_algorithm
    data_type: name of the data type, "default" when None
    algorithms: fallback algorithms used when a module
        specification does not carry its own
    set_default_first: forwarded as include_default to
        modu.collect_global_parameters_and_attributes_by_data_type
    verbose: print progress and a summary at the end

    Side effects (per module): sets module.data_type and
    module.algorithms, then applies the collected dictionaries
    with modu.set_global_parameters_by_dict and
    modu.set_attributes_by_dict

    Returns: None
    """
    data_type = "default" if data_type is None else data_type

    module_specs = nu.convert_to_array_like(module)

    for spec in module_specs:
        mod, spec_algorithms, spec_algorithms_only = modu.extract_module_algorithm(
            spec,
            return_parameters=True,
        )

        # algorithms attached to the specification win over the argument
        if spec_algorithms is None:
            spec_algorithms = algorithms

        if verbose:
            print(
                f"Setting data type dicts for module {mod.__name__} "
                f"with data_type = {data_type}, algorithms = {spec_algorithms}"
            )

        parameters_to_set, attributes_to_set = (
            modu.collect_global_parameters_and_attributes_by_data_type(
                module=mod,
                data_type=data_type,
                algorithms=spec_algorithms,
                include_default=set_default_first,
                algorithms_only=spec_algorithms_only,
                verbose=verbose))

        # record what was applied so it can be queried later
        mod.data_type = data_type
        mod.algorithms = spec_algorithms
        modu.set_global_parameters_by_dict(mod, parameters_to_set)
        modu.set_attributes_by_dict(mod, attributes_to_set)

    if not verbose:
        return

    print(f"\n--After Setting global parameters and attributes --")
    for spec in module_specs:
        mod, _ = modu.extract_module_algorithm(
            spec,
            return_parameters=False,
        )
        print(
            f"   module {mod.__name__}: data_type = {mod.data_type}, algorithms = {mod.algorithms}"
        )
def ease_query(bounding_box_list=None,
               query_type=None,
               query_type_joiner="OR",
               verbose=False):
    """
    Purpose: To build a restriction string selecting neurons
    whose dendrite bounding box overlaps, and/or whose soma
    center lies inside, any of the given bounding boxes

    Arguments:
    bounding_box_list: collection of bounding boxes; each box is
        indexed as box[0][axis] (lower corner) and box[1][axis]
        (upper corner) after being multiplied by
        alu.voxel_to_nm_scaling. Defaults to
        [ease_bounding_box, ease_bounding_box_2]
        NOTE(review): boxes are presumably numpy arrays in voxel
        units -- confirm against callers
    query_type: "dendrite", "soma" or a collection of them,
        defaults to both
    query_type_joiner: keyword placed between the clauses of the
        different query types
    verbose: print the clause built for every query type

    Returns: the restriction string. Clauses of the boxes are
    combined with OR inside one query type; every query type
    clause is parenthesized before being joined so that a joiner
    of "AND" is not mis-parsed (AND binds tighter than OR).
    An empty query_type gives back an empty list (kept for
    backward compatibility)

    Raises: Exception for an unknown query type
    """
    if query_type is None:
        query_type = ["dendrite", "soma"]

    query_type = nu.convert_to_array_like(query_type)

    if bounding_box_list is None:
        bounding_box_list = [ease_bounding_box, ease_bounding_box_2]

    axis_names = ("x", "y", "z")

    query_type_list = []
    for q_type in query_type:
        bbox_query_list = []
        # honor the caller supplied boxes (previously the two module level
        # boxes were always used, ignoring bounding_box_list)
        for bbox_coords in bounding_box_list:
            bbox_nm = bbox_coords * alu.voxel_to_nm_scaling
            if q_type == "dendrite":
                # overlap test: NOT (entirely below the box OR entirely above it)
                queries = [
                    f"(NOT ((dendrite_bbox_{axis_name}_max < {bbox_nm[0][axis_idx]}) OR (dendrite_bbox_{axis_name}_min > {bbox_nm[1][axis_idx]})))"
                    for axis_idx, axis_name in enumerate(axis_names)
                ]
            elif q_type == "soma":
                # containment test: soma center strictly inside the box
                queries = [
                    f"((soma_{axis_name}_nm > {bbox_nm[0][axis_idx]}) AND (soma_{axis_name}_nm < {bbox_nm[1][axis_idx]}))"
                    for axis_idx, axis_name in enumerate(axis_names)
                ]
            else:
                raise Exception(f"Unknown query_type = {q_type}")

            bbox_query = f' ({" AND ".join(queries)}) '
            bbox_query_list.append(bbox_query)

        total_bbox_query = "OR".join(bbox_query_list)

        if verbose:
            print(f"total_bbox_query for {q_type} = {total_bbox_query}")

        # group the OR-ed boxes so the joiner applies to the whole clause
        query_type_list.append(f" ({total_bbox_query}) ")

    if len(query_type_list) > 0:
        total_query_list_query = query_type_joiner.join(query_type_list)
    else:
        total_query_list_query = query_type_list

    return total_query_list_query
def fetch_compartments_skeletons(
    compartments,
    segment_id,
    split_index=0,
    verbose=False,

    #plotting arguments
    plot_skeleton=False,
    original_mesh=None,
):
    """
    Purpose: To fetch one skeleton per requested compartment
    (via pv.fetch_compartment_skeleton) and optionally plot
    them on top of the original mesh

    Arguments:
    compartments: one compartment name or a collection of them
    segment_id, split_index: identify the neuron
    verbose: forwarded to pv.fetch_compartment_skeleton
    plot_skeleton: if True the skeletons are sent to
        nviz.plot_objects, colored by apu.colors_from_compartments
    original_mesh: only needed for plotting; fetched with
        du.fetch_segment_id_mesh when not supplied

    Returns: list of skeletons in the same order as the
    requested compartments

    Ex:
    import apical_utils as apu
    pv.fetch_compartments_skeletons(apu.default_compartment_order,
                                segment_id,
                                split_index,
                                 original_mesh = original_mesh,
                                plot_skeleton=True)

    """
    compartment_names = nu.convert_to_array_like(compartments)

    skeletons = []
    for comp_name in compartment_names:
        skeletons.append(
            pv.fetch_compartment_skeleton(comp_name,
                                          segment_id,
                                          split_index,
                                          verbose=verbose))

    if plot_skeleton:
        # the mesh is only fetched when it is actually going to be drawn
        backdrop_mesh = (du.fetch_segment_id_mesh(segment_id)
                         if original_mesh is None else original_mesh)
        skeleton_colors = apu.colors_from_compartments(compartment_names)

        nviz.plot_objects(backdrop_mesh,
                          skeletons=skeletons,
                          skeletons_colors=skeleton_colors)

    return skeletons
# Example #6
# 0
def neuroglancer_scatter_from_limb_branch_dicts(
    neuron_obj,
    limb_branch_dicts,
    name_list=None,
    color_list=None,
    verbose=False,
    output_type="server",
    transparency=0.5,
):
    """
    Purpose: To turn one or more limb branch dicts into a
    neuroglancer link where the skeleton nodes of every
    limb branch dict form their own annotation group

    Arguments:
    neuron_obj: neuron the limb branch dicts refer to; its
        segment_id is handed to neuroglancer
    limb_branch_dicts: one limb branch dict or a collection
    name_list: group names, defaults to limb_branch_0, limb_branch_1, ...
    color_list: group colors, defaults to
        mu.generate_non_randon_named_color_list(len(name_list))
    output_type, transparency: forwarded to
        alu.coordinate_group_dict_to_neuroglancer

    Returns: whatever alu.coordinate_group_dict_to_neuroglancer
    returns for the assembled groups

    Note: dicts, names and colors are paired with zip, so the
    shortest of the three decides how many groups are created.
    Node coordinates are divided by alu.voxel_to_nm_scaling
    (presumably converting nm to voxels -- confirm)

    Example:

    import allen_proofreading_utils as apru
    n_link = apru.neuroglancer_scatter_from_limb_branch_dicts(
        neuron_obj,
        limb_branch_dicts = limb_branch_dict_error,
        color_list = "red",
        name_list = f"{limb_branch_filter_name}"
    )
    """
    limb_branch_dicts = nu.convert_to_array_like(limb_branch_dicts)

    if name_list is None:
        name_list = []
        for group_idx in range(len(limb_branch_dicts)):
            name_list.append(f"limb_branch_{group_idx}")
    name_list = nu.convert_to_array_like(name_list)

    if color_list is None:
        color_list = mu.generate_non_randon_named_color_list(len(name_list))
    color_list = nu.convert_to_array_like(color_list)

    if verbose:
        print(f"name_list = {name_list}")
        print(f"color_list = {color_list}")

    groups = {}
    for lb, lb_name, lb_color in zip(limb_branch_dicts, name_list, color_list):
        nodes = nru.skeleton_nodes_from_limb_branch(
            neuron_obj,
            limb_branch_dict=lb,
            plot_nodes=False,
        )
        coordinates = nodes / alu.voxel_to_nm_scaling

        if verbose:
            print(f"--Working on {lb_name} ({lb_color}) with {len(coordinates)} nodes")

        groups[f"{lb_name}"] = {
            "color": lb_color,
            "coordinates": list(coordinates),
        }

    return alu.coordinate_group_dict_to_neuroglancer(
        neuron_obj.segment_id,
        groups,
        output_type=output_type,
        transparency=transparency,
    )
def visualize_graph_connections_by_method(
    G,
    segment_ids,
    method='meshafterparty',  #'neuroglancer'
    synapse_ids=None,  #the synapse ids we should be restricting to
    segment_ids_colors=None,
    synapse_color="red",
    plot_soma_centers=True,
    verbose=False,
    plot_synapse_skeletal_paths=True,
    plot_proofread_skeleton=False,

    #arguments for the synapse path information
    synapse_path_presyn_color="aqua",
    synapse_path_postsyn_color="orange",
    synapse_path_donwsample_factor=None,

    #arguments for neuroglancer
    transparency=0.3,
    output_type="html",

    #arguments for meshAfterParty
    plot_compartments=False,
    plot_error_mesh=False,
    synapse_scatter_size=0.2,  #0.2
    synapse_path_scatter_size=0.1,
):
    """
    Purpose: A generic function that will
    prepare the visualization information
    for either plotting in neuroglancer or meshAfterParty

    Arguments:
    G: graph handed to the mgu helper functions
    segment_ids: node names of the neurons to show
    method: "neuroglancer" or "meshafterparty" (case insensitive)
    synapse_ids: optional restriction of the synapses to show
    segment_ids_colors: one color per segment id, generated with
        mu.generate_non_randon_named_color_list when None
    synapse_color: color of the synapse group
    plot_soma_centers: add one "neuron_{idx}" group per soma center
    plot_synapse_skeletal_paths: add the presyn / postsyn skeletal
        path of every synapse as its own group
    synapse_path_presyn_color, synapse_path_postsyn_color: colors
        of the skeletal path groups
    synapse_path_donwsample_factor: accepted but not used here
    transparency, output_type: only used for neuroglancer
    plot_proofread_skeleton, plot_compartments, plot_error_mesh,
    synapse_scatter_size, synapse_path_scatter_size: only used
        for meshAfterParty

    Returns: the result of alu.coordinate_group_dict_to_neuroglancer
    for method "neuroglancer", None for "meshafterparty" (the
    plot is produced by pv.plot_multiple_proofread_neuron)

    Raises: Exception for an unknown method or when no synapse
    is found between the segments

    Pseudocode:
    0) Determine if whether should return things in nm
    1) Get the segment id colors
    2) Get the synapses for all the segment pairs
    3) Get the soma centers if requested
    4) Get the regular int names for segment_ids (if plotting in neuroglancer)
    """
    method = method.lower()

    #0) Determine if whether should return things in nm
    if method == "neuroglancer":
        return_nm = False
    elif method == "meshafterparty":
        return_nm = True
    else:
        raise Exception(
            f"Unknown method = {method!r}: expected 'neuroglancer' or 'meshafterparty'"
        )

    #1) Get the segment id colors
    if segment_ids_colors is None:
        segment_ids_colors = mu.generate_non_randon_named_color_list(
            len(segment_ids))
    else:
        segment_ids_colors = nu.convert_to_array_like(segment_ids_colors)

    if verbose:
        print(f"segment_ids_colors = {segment_ids_colors}")

    #2) Get the synapses for all the segment pairs
    syn_coords, syn_ids = mgu.synapses_from_segment_id_edges(
        G,
        segment_ids=segment_ids,
        synapse_ids=synapse_ids,
        return_nm=return_nm,
        return_synapse_ids=True,
        verbose=verbose)

    if len(syn_coords) == 0:
        raise Exception("No synapses to plot")

    annotations_info = dict(
        presyn=dict(color=synapse_color, coordinates=list(syn_coords)))

    if verbose:
        print(f"syn_coords = {syn_coords}")
        print(f"syn_ids = {syn_ids}")

    #3) Get the soma centers if requested
    if plot_soma_centers:
        soma_centers = mgu.soma_centers_from_segment_ids(G,
                                                         segment_ids,
                                                         return_nm=return_nm)

        if verbose:
            print(f"soma_centers = {soma_centers}")

        for s_idx, (sc, col) in enumerate(zip(soma_centers,
                                              segment_ids_colors)):
            annotations_info[f"neuron_{s_idx}"] = dict(
                color=col, coordinates=list(sc.reshape(-1, 3)))

    # names of the skeletal path groups, remembered so they can get their
    # own scatter size (none of the group names contains the word "path")
    path_group_names = set()

    if plot_synapse_skeletal_paths:
        for synapse_id in syn_ids:
            pre_name, post_name = mgu.pre_post_node_names_from_synapse_id(
                G, synapse_id, node_names=segment_ids)
            pre_idx = [
                k for k, seg in enumerate(segment_ids) if pre_name == seg
            ][0]
            post_idx = [
                k for k, seg in enumerate(segment_ids) if post_name == seg
            ][0]
            presyn_path, postsyn_path = mgu.presyn_postsyn_skeletal_path_from_synapse_id(
                G,
                synapse_id=synapse_id,
                segment_ids=segment_ids,
                return_nm=return_nm,
                verbose=verbose,
                plot_skeletal_paths=False,
            )

            # groups are named after the neuron pair, so a later synapse
            # between the same ordered pair replaces the earlier entry
            pre_group = f"pre_n{pre_idx}_to_n{post_idx}"
            post_group = f"post_n{pre_idx}_to_n{post_idx}"

            annotations_info[pre_group] = dict(
                color=synapse_path_presyn_color, coordinates=list(presyn_path))
            annotations_info[post_group] = dict(
                color=synapse_path_postsyn_color,
                coordinates=list(postsyn_path))

            path_group_names.update((pre_group, post_group))

    if not return_nm:
        #4) Get the regular int names for segment_ids
        seg_ids_int = [
            mgu.segment_id_from_seg_split_id(k) for k in segment_ids
        ]
        if verbose:
            print(f"seg_ids_int = {seg_ids_int}")

        return alu.coordinate_group_dict_to_neuroglancer(
            seg_ids_int[0],
            annotations_info,
            output_type=output_type,
            fixed_ids=seg_ids_int,
            fixed_id_colors=segment_ids_colors,
            transparency=transparency)

    # Plotting in meshAfterParty:
    # every group that is not a soma center ("neuron" in its name)
    # becomes one scatter with its own color and size
    scatters = []
    scatter_size = []
    scatters_colors = []

    for group_name, group_info in annotations_info.items():
        if "neuron" in group_name:
            continue

        if group_name in path_group_names:
            curr_syn_size = synapse_path_scatter_size
        else:
            curr_syn_size = synapse_scatter_size

        scatters.append(np.array(group_info["coordinates"]).reshape(-1, 3))
        scatter_size.append(curr_syn_size)
        scatters_colors.append(group_info["color"])

    # None leaves the compartment choice to the plotting function,
    # an empty list requests no compartments
    compartments = None if plot_compartments else []

    pv.plot_multiple_proofread_neuron(
        segment_ids=segment_ids,
        plot_proofread_skeleton=plot_proofread_skeleton,
        proofread_mesh_color=segment_ids_colors,
        proofread_skeleton_color=segment_ids_colors,
        plot_nucleus=True,
        plot_synapses=False,
        compartments=compartments,
        plot_error_mesh=plot_error_mesh,
        scatters=scatters,
        scatter_size=scatter_size,
        scatters_colors=scatters_colors,
    )
def synapses_from_segment_id_edges(G,
                                   segment_id_edges=None,
                                   segment_ids=None,
                                   synapse_ids=None,
                                   return_synapse_coordinates=True,
                                   return_synapse_ids=False,
                                   return_nm=return_nm_default,
                                   return_in_dict=False,
                                   verbose=False):
    """
    Purpose: For all segment_ids get the
    synapses  or synapse coordinates for
    the edges between them from the graph

    Arguments:
    G: graph handed to mgu.synapse_ids_and_coord_from_segment_ids_edge
    segment_id_edges: (seg_1, seg_2) pairs to look at; when None they
        are generated from segment_ids with
        nu.all_directed_choose_2_combinations
    segment_ids: node names, only used to generate the edges
    synapse_ids: optional restriction; only synapses whose id is in
        this collection are kept
    return_synapse_coordinates, return_synapse_ids: choose between
        coordinates, (coordinates, ids) or ids only
    return_nm: forwarded to the per edge lookup
    return_in_dict: if True results stay in nested dicts
        {seg_1: {seg_2: values}}, otherwise they are concatenated
        into flat arrays

    Returns:
    - coordinates if return_synapse_coordinates and not return_synapse_ids
    - (coordinates, ids) if both flags are set
    - ids if return_synapse_coordinates is False
    With return_in_dict = False and no synapse found, empty arrays
    are returned (shape (0, 3) for the coordinates) so the return
    type does not depend on whether anything was found

    Raises: Exception when neither segment_id_edges nor segment_ids
    is given

    Ex:
    seg_split_ids = ["864691136388279671_0",
                "864691135403726574_0",
                "864691136194013910_0"]

    mgu.synapses_from_segment_id_edges(G,segment_ids = seg_split_ids,
                                  return_nm=True)
    """
    if synapse_ids is not None:
        synapse_ids = nu.convert_to_array_like(synapse_ids)

    if segment_id_edges is None:
        if segment_ids is None:
            raise Exception(
                "Either segment_id_edges or segment_ids must be provided")
        segment_id_edges = nu.all_directed_choose_2_combinations(segment_ids)

    if verbose:
        print(f"segment_id_edges = {segment_id_edges}")

    synapses_dict = dict()
    synapses_coord_dict = dict()
    for seg_1, seg_2 in segment_id_edges:
        syn_ids, syn_coords = mgu.synapse_ids_and_coord_from_segment_ids_edge(
            G, seg_1, seg_2, return_nm=return_nm, verbose=verbose)

        if synapse_ids is not None and len(syn_ids) > 0:
            # intersect1d returns the sorted common ids and where they sit
            # in syn_ids, which is used to pick the matching coordinates
            syn_ids, syn_ids_idx, _ = np.intersect1d(syn_ids,
                                                     synapse_ids,
                                                     return_indices=True)
            syn_coords = syn_coords[syn_ids_idx]

            if verbose:
                print(f"After synapse restriction: syn_ids= {syn_ids}")
                print(f"syn_coords= {syn_coords}")

        # edges without synapses are left out of the result
        if len(syn_ids) > 0:
            synapses_dict.setdefault(seg_1, dict())[seg_2] = syn_ids
            synapses_coord_dict.setdefault(seg_1, dict())[seg_2] = syn_coords

    if not return_in_dict:
        if len(synapses_dict) > 0:
            edge_keys = [(seg_1, seg_2)
                         for seg_1, targets in synapses_dict.items()
                         for seg_2 in targets]
            synapses_coord_dict = np.concatenate(
                [synapses_coord_dict[seg_1][seg_2] for seg_1, seg_2 in edge_keys])
            synapses_dict = np.concatenate(
                [synapses_dict[seg_1][seg_2] for seg_1, seg_2 in edge_keys])
        else:
            # np.concatenate rejects an empty list, so build the empty
            # arrays directly instead of leaking the (empty) dicts
            synapses_coord_dict = np.empty((0, 3))
            # NOTE(review): assumes synapse ids are integers -- confirm
            synapses_dict = np.array([], dtype=int)

    if return_synapse_coordinates:
        if return_synapse_ids:
            return synapses_coord_dict, synapses_dict
        else:
            return synapses_coord_dict
    else:
        return synapses_dict
def output_global_parameters_and_attributes_from_current_data_type(
        module,
        algorithms=None,
        verbose=True,
        lowercase=True,
        output_types=("global_parameters"),
        include_default=True,
        algorithms_only=False,
        abbreviate_keywords=False,
        **kwargs):
    """
    Purpose: To output, as one merged dictionary, the global
    parameters / attributes that belong to the data type
    currently stored on each module (module.data_type)

    Arguments:
    module: one module specification or a collection of them;
        each is unpacked by modu.extract_module_algorithm
    algorithms: fallback when the specification carries no
        algorithms; if still None, module.algorithms is used
    verbose: print the module, data type and algorithms used
    lowercase: lowercase the keys of every per module dictionary
    output_types: which dictionary kinds to collect,
        output_types_global when None
    include_default: forwarded to
        modu.collect_global_parameters_and_attributes_by_data_type
    algorithms_only: fallback when the specification carries
        no algorithms_only value
    abbreviate_keywords, **kwargs: accepted but not used here

    Returns: gu.merge_dicts over the per module dictionaries
    """
    if output_types is None:
        output_types = output_types_global

    per_module_dicts = []
    for spec in nu.convert_to_array_like(module):

        mod, spec_algorithms, spec_algorithms_only = modu.extract_module_algorithm(
            spec,
            return_parameters=True,
        )

        # values attached to the specification win over the arguments
        if spec_algorithms is None:
            spec_algorithms = algorithms
        if spec_algorithms_only is None:
            spec_algorithms_only = algorithms_only

        current_data_type = mod.data_type

        # last resort: the algorithms recorded on the module itself
        if spec_algorithms is None:
            spec_algorithms = mod.algorithms

        if verbose:
            print(
                f"module: {mod.__name__} data_type set to {current_data_type}, algorithms = {spec_algorithms}"
            )

        parameters_found, attributes_found = (
            modu.collect_global_parameters_and_attributes_by_data_type(
                module=mod,
                data_type=current_data_type,
                algorithms=spec_algorithms,
                include_default=include_default,
                output_types=output_types,
                algorithms_only=spec_algorithms_only,
                verbose=verbose))

        module_dict = gu.merge_dicts([parameters_found, attributes_found])

        if lowercase:
            if isinstance(module_dict, dsu.DictType):
                module_dict = module_dict.lowercase()
            else:
                module_dict = {
                    key.lower(): value
                    for key, value in module_dict.items()
                }

        per_module_dicts.append(module_dict)

    return gu.merge_dicts(per_module_dicts)
def collect_global_parameters_and_attributes_by_data_type(
        module,
        data_type,
        include_default=True,
        algorithms=None,
        output_types=None,
        algorithms_only=None,
        verbose=False):
    """
    Purpose: To compile the dictionaries to either
    set or output

    The dictionaries are looked up as attributes of the module
    under the names
        {output_type}_dict_{data_type}
        {output_type}_dict_{data_type}_{algorithm}
    and a copy of every one that exists is collected.
    Missing ones are skipped (reported when verbose).

    Arguments:
    module: object the dictionaries are read from with getattr
    data_type: data type name; "default" is looked up first
        as well when include_default is True
    algorithms: one algorithm name or a collection, none when None
    output_types: dictionary kinds to collect,
        output_types_global when None
    algorithms_only: if truthy the dictionaries without an
        algorithm suffix are skipped
    verbose: print every lookup and the dictionary found

    Returns: (global_parameters_dict, attributes_dict), each the
    gu.merge_dicts result of the dictionaries collected for the
    "global_parameters" / "attributes" output type, or {} when
    that output type was not requested. Dictionaries are handed
    to gu.merge_dicts in lookup order: "default" before the
    data type, no-suffix before the algorithm specific ones
    """
    if algorithms is not None:
        algorithms = nu.convert_to_array_like(algorithms)
    else:
        algorithms = []

    if output_types is None:
        output_types = output_types_global
    else:
        output_types = nu.convert_to_array_like(output_types)

    if include_default and data_type != "default":
        total_data_types = ["default", data_type]
    else:
        total_data_types = [data_type]

    def _collect_if_present(dict_name, collected):
        """Append a copy of module.<dict_name> to collected if it exists"""
        try:
            curr_dict = getattr(module, dict_name).copy()
        except AttributeError:
            # only a missing (or non copyable) attribute is expected here;
            # anything else should surface instead of being swallowed
            if verbose:
                print(f"Unknown dict_name = {dict_name}")
            return

        if verbose:
            print(f"Collecting {dict_name}")
            print(f"curr_dict = {curr_dict}")
        collected.append(curr_dict)

    p_list = dict()
    for dict_type in output_types:
        collected = p_list[dict_type] = []
        for curr_data_type in total_data_types:
            if not algorithms_only:
                _collect_if_present(f"{dict_type}_dict_{curr_data_type}",
                                    collected)

            for alg in algorithms:
                _collect_if_present(
                    f"{dict_type}_dict_{curr_data_type}_{alg}", collected)

    #compiling all the dicts
    if "global_parameters" in p_list:
        global_parameters_dict = gu.merge_dicts(p_list["global_parameters"])
    else:
        global_parameters_dict = {}

    if "attributes" in p_list:
        attributes_dict = gu.merge_dicts(p_list["attributes"])
    else:
        attributes_dict = {}

    return global_parameters_dict, attributes_dict
def plot_multiple_proofread_neuron(
    segment_ids,
    split_indexes = None,
    cell_type = None,
    original_mesh = None,
    plot_proofread_skeleton = False, 
    
    proofread_mesh_color = "green",
    proofread_mesh_alpha = None,
    proofread_skeleton_color = "black",

    plot_nucleus = True,
    nucleus_size = 1,
    nucleus_color = "proofread_mesh_color",


    plot_synapses = True,
    synapses_size = 0.05,
    synapse_plot_type = "spine_bouton",#"compartment"#  "valid_error"
    synapse_compartments = None,
    synapse_spine_bouton_labels = None,
    plot_error_synapses = False,
    valid_synapses_color = "orange",
    error_synapses_color = "aliceblue",
    synapse_queries = None,
    synapse_queries_colors = None,

    plot_error_mesh = False,
    error_mesh_color = "black",
    error_mesh_alpha = 1,


    compartments = None,
    #compartments = ["apical_total"]
    #compartments= ["axon","dendrite"]
    plot_compartment_meshes = True,
    compartment_mesh_alpha = 0.3,
    plot_compartment_skeletons = True,
    
    
    # for adding new scatter points:
    scatters = None,
    scatter_size = 0.2,
    scatters_colors = "yellow",

    verbose = False,
    print_spine_colors = True,
    print_compartment_colors = True,
    ):
    """
    Purpose: To plot several proofread neurons in one
    figure by calling pv.plot_proofread_neuron once per
    segment id

    Arguments:
    segment_ids: one segment id or a collection of them
    split_indexes: one split index per segment id,
        all 0 when None
    proofread_mesh_color: one color per neuron; when the number
        of colors differs from the number of segment ids the
        converted colors are multiplied by that number
    scatters, scatter_size, scatters_colors: extra scatter
        points; only handed to the call for the last neuron
        (as scatters, scatter_sizes, scatters_colors)
    print_spine_colors, print_compartment_colors: only honored
        for the last neuron, False for all the others
    all other arguments: forwarded unchanged to every
        pv.plot_proofread_neuron call

    How the figure is assembled:
    - the ipyvolume figure is cleared once at the start
    - the first neuron is drawn with append_figure = False
      (unless it is also the last one), all later ones with
      append_figure = True
    - only the last neuron is drawn with show_at_end = True

    Returns: None
    """
    import ipyvolume as ipv
    ipv.clear()

    su.ignore_warnings()

    segment_ids = nu.convert_to_array_like(segment_ids)
    if verbose:
        print(f"segment_ids = {segment_ids}")

    proofread_mesh_color = nu.convert_to_array_like(proofread_mesh_color)
    n_neurons = len(segment_ids)

    if len(proofread_mesh_color) != n_neurons:
        proofread_mesh_color = proofread_mesh_color*n_neurons
    if split_indexes is None:
        split_indexes = [0]*n_neurons

    # arguments that are the same for every neuron
    shared_kwargs = dict(
        cell_type=cell_type,
        original_mesh=original_mesh,
        plot_proofread_skeleton=plot_proofread_skeleton,
        proofread_mesh_alpha=proofread_mesh_alpha,
        proofread_skeleton_color=proofread_skeleton_color,
        plot_nucleus=plot_nucleus,
        nucleus_size=nucleus_size,
        nucleus_color=nucleus_color,
        plot_synapses=plot_synapses,
        synapses_size=synapses_size,
        synapse_plot_type=synapse_plot_type,
        synapse_compartments=synapse_compartments,
        synapse_spine_bouton_labels=synapse_spine_bouton_labels,
        plot_error_synapses=plot_error_synapses,
        valid_synapses_color=valid_synapses_color,
        error_synapses_color=error_synapses_color,
        synapse_queries=synapse_queries,
        synapse_queries_colors=synapse_queries_colors,
        plot_error_mesh=plot_error_mesh,
        error_mesh_color=error_mesh_color,
        error_mesh_alpha=error_mesh_alpha,
        compartments=compartments,
        plot_compartment_meshes=plot_compartment_meshes,
        compartment_mesh_alpha=compartment_mesh_alpha,
        plot_compartment_skeletons=plot_compartment_skeletons,
        verbose=verbose,
    )

    last_idx = n_neurons - 1
    neuron_specs = zip(segment_ids, split_indexes, proofread_mesh_color)
    for j, (seg_id, sp_idx, proof_col) in enumerate(neuron_specs):

        if verbose:
            print(f"\n{seg_id}_{sp_idx}: {proof_col}")

        # the last neuron finishes the figure: it shows it, prints the
        # color legends and carries the extra scatter points
        is_last = (j == last_idx)

        pv.plot_proofread_neuron(
            seg_id,
            sp_idx,
            proofread_mesh_color=proof_col,
            print_spine_colors=print_spine_colors if is_last else False,
            print_compartment_colors=(print_compartment_colors
                                      if is_last else False),
            show_at_end=is_last,
            # only a first neuron that is not also the last starts fresh
            append_figure=(is_last or j != 0),
            scatters=scatters if is_last else None,
            scatter_sizes=scatter_size if is_last else None,
            scatters_colors=scatters_colors if is_last else None,
            **shared_kwargs,
        )
def plot_proofread_neuron(
    segment_id,
    split_index = 0,
    cell_type = None,
    original_mesh = None,
    
    plot_proofread_skeleton = False,
    
    proofread_mesh_color = "green",
    proofread_mesh_alpha = None,
    proofread_skeleton_color = "black",

    plot_nucleus = True,
    nucleus_size = 1,
    nucleus_color = "proofread_mesh_color",#"black",


    plot_synapses = True,
    synapses_size = 0.05,
    synapse_plot_type = "spine_bouton",#"compartment"#  "valid_error" #"valid_presyn_postsyn"
    synapse_compartments = None,
    synapse_spine_bouton_labels = None,
    plot_error_synapses = False,
    valid_synapses_color = "orange",
    error_synapses_color = "aliceblue",
    synapse_queries = None,
    synapse_queries_colors = None,

    plot_error_mesh = False,
    error_mesh_color = "black",
    error_mesh_alpha = 1,


    compartments = None,
    #compartments = ["apical_total"]
    #compartments= ["axon","dendrite"]
    plot_compartment_meshes = True,
    compartment_mesh_alpha = 0.3,
    plot_compartment_skeletons = True,

    verbose = False,
    print_spine_colors = True,
    print_compartment_colors = True,
    
    #arguments for plotting more scatters
    scatters = None,
    scatter_sizes = 0.2,
    scatters_colors = "yellow",
    
    show_at_end = True,
    append_figure = False,
    ):
    
    """
    Purpose: Will plot the saved
    proofread information of a neuron
    (proofread mesh, and optionally the error mesh,
    proofread skeleton, nucleus, synapses,
    compartment meshes/skeletons and any extra scatters)
    with a single nviz.plot_objects call.

    Arguments (grouped as in the signature):
    segment_id, split_index: neuron to plot. If segment_id
        is a string it is treated as a node name and both
        values are parsed from it with
        pv.segment_id_and_split_index_from_node_name
    cell_type: looked up with pv.cell_type_from_segment_id
        when None; used to pick the default compartments
    original_mesh: fetched with du.fetch_segment_id_mesh
        when None
    plot_proofread_skeleton, proofread_*: control the main
        mesh / main skeleton passed to nviz.plot_objects
    plot_nucleus, nucleus_size, nucleus_color: the nucleus
        center is added as a scatter; the special value
        "proofread_mesh_color" reuses proofread_mesh_color
    plot_synapses, synapse*, *_synapses_color: forwarded to
        syu.synapse_plot_items_by_type_or_query
    plot_error_mesh, error_mesh_*: adds the error mesh to
        the plotted meshes
    compartments, plot_compartment_*, compartment_mesh_alpha:
        compartment meshes/skeletons to overlay
        (defaults to apu.compartments_to_plot(cell_type))
    scatters, scatter_sizes, scatters_colors: extra scatters
        to plot; a single color/size is repeated for every
        scatter. These arguments are copied and never
        modified in place, so callers can safely reuse them.
    show_at_end, append_figure: forwarded to nviz.plot_objects;
        when append_figure is False the ipyvolume figure is
        cleared first

    Returns: None (plotting / printing side effects only)

    Ex: 
    #trying on inhibitory
    segment_id,split_index = (864691134917559306,0)

    original_mesh = du.fetch_segment_id_mesh(segment_id)

    pv.plot_proofread_neuron(
        segment_id,
        split_index,
        original_mesh=original_mesh,
        plot_error_mesh=False,
        verbose = True)


    """
    if not append_figure:
        import ipyvolume as ipv
        ipv.clear()
    
    su.ignore_warnings()
    
    # isinstance (rather than an exact type comparison) so that
    # str subclasses are also treated as node names
    if isinstance(segment_id,str):
        segment_id,split_index = pv.segment_id_and_split_index_from_node_name(segment_id)
    
    if verbose:
        print(f"Plotting {segment_id}_{split_index} (nucleus_id={pv.nucleus_id_from_segment_id(segment_id,split_index)})")
    
    if cell_type is None:
        cell_type = pv.cell_type_from_segment_id(segment_id,split_index)
        if verbose:
            print(f"cell_type = {cell_type}")
    if synapse_compartments is None:
        synapse_compartments = apu.compartments_to_plot(cell_type)
        
    if synapse_spine_bouton_labels is None:
        synapse_spine_bouton_labels = spu.spine_bouton_labels_to_plot()
        
    if compartments is None:
        compartments = apu.compartments_to_plot(cell_type)

    meshes = []
    meshes_colors = []
    skeletons = []
    skeletons_colors = []
    meshes_alpha = []
    
    
    if scatters is not None:
        # Shallow copies: the nucleus/synapse items are appended
        # below, and extending the caller's own lists in place would
        # leak those items into later calls that reuse the same
        # arguments. Converting to list also makes "+=" an append
        # (not an element-wise add) if an array was passed.
        scatters = list(scatters)
        scatters_colors = list(nu.convert_to_array_like(scatters_colors))
        if len(scatters_colors) == 1:
            scatters_colors = scatters_colors*len(scatters)
        scatter_sizes = list(nu.convert_to_array_like(scatter_sizes))
        if len(scatter_sizes) == 1:
            scatter_sizes = scatter_sizes*len(scatters)
    else:
        scatters = []
        scatters_colors = []
        scatter_sizes = []

    if original_mesh is None:
        original_mesh = du.fetch_segment_id_mesh(segment_id)

    # label -> color legend that is printed at the end
    compartment_color_dict = dict(valid_mesh=proofread_mesh_color)
        
    proof_mesh,error_mesh = pv.fetch_proofread_mesh(segment_id,
                            split_index = split_index,
                            original_mesh=original_mesh,
                            return_error_mesh=True)

    if plot_proofread_skeleton:
        proof_skeleton = pv.fetch_proofread_skeleton(segment_id,
                                   split_index,
                                   plot_skeleton=False,
                                                    )
    else:
        proof_skeleton = None

    if plot_error_mesh:
        meshes.append(error_mesh)
        meshes_colors.append(error_mesh_color)
        meshes_alpha.append(error_mesh_alpha)
        compartment_color_dict["error_mesh"] = error_mesh_color

    if plot_nucleus:
        nuc_center = pv.nucleus_center_from_segment_id(segment_id,
                                         split_index)
        
        # only compare against the sentinel when the color is a string
        # so array-like colors do not hit an ambiguous comparison
        if isinstance(nucleus_color,str) and nucleus_color == "proofread_mesh_color":
            nucleus_color = proofread_mesh_color
        
        if nuc_center is None:
            print("No nucleus to plot")
        else:
            scatters.append(nuc_center.reshape(-1,3))
            scatters_colors.append(nucleus_color)
            scatter_sizes.append(nucleus_size)


    #get the synapse groups
    if plot_synapses:
    
        synapses_objs = pv.syanpse_objs_from_segment_id(segment_id,split_index)
        
        (syn_scatters,
        syn_colors,
        syn_sizes) = syu.synapse_plot_items_by_type_or_query(
                        synapses_objs,
                        synapses_size = synapses_size,
                        synapse_plot_type = synapse_plot_type,#"compartment"#  "valid_error"
                        synapse_compartments = synapse_compartments,
                        synapse_spine_bouton_labels = synapse_spine_bouton_labels,
                        plot_error_synapses = plot_error_synapses,
                        valid_synapses_color = valid_synapses_color,
                        error_synapses_color = error_synapses_color,
                        synapse_queries = synapse_queries,
                        synapse_queries_colors = synapse_queries_colors,
        
                        verbose = verbose,
                        print_spine_colors = print_spine_colors)

        scatters += syn_scatters
        scatters_colors += syn_colors
        scatter_sizes += syn_sizes

    if compartments is not None and len(compartments) > 0:
        comp_colors = apu.colors_from_compartments(compartments)
        if plot_compartment_meshes:
            comp_meshes = pv.fetch_compartments_meshes(compartments,
                                                      segment_id,
                                                     split_index,
                                                       original_mesh = original_mesh,
                                                     )
            meshes += comp_meshes
            meshes_colors += comp_colors
            meshes_alpha += [compartment_mesh_alpha]*len(comp_meshes)

        if plot_compartment_skeletons:
            comp_sk = pv.fetch_compartments_skeletons(compartments,
                                                      segment_id,
                                                     split_index,
                                                     )

            skeletons += comp_sk
            skeletons_colors += comp_colors
            
        compartment_color_dict.update(zip(compartments,comp_colors))

    if print_compartment_colors:
        print("\nCompartment Colors:")
        for k,v in compartment_color_dict.items():
            print(f"  {k}:{v}")

    nviz.plot_objects(main_mesh = proof_mesh,
                      main_mesh_alpha=proofread_mesh_alpha,
                      main_mesh_color=proofread_mesh_color,

                      main_skeleton=proof_skeleton,
                      main_skeleton_color=proofread_skeleton_color,

                     skeletons=skeletons,
                     skeletons_colors=skeletons_colors,

                     meshes=meshes,
                     meshes_colors=meshes_colors,
                     mesh_alpha=meshes_alpha,

                     scatters=scatters,
                     scatter_size=scatter_sizes,
                     scatters_colors=scatters_colors,
                      
                     show_at_end = show_at_end,
                    append_figure = append_figure,
                     )