Example #1
0
def findWindow(node, branch, window_size=4, scale=(0.22, 0.22, 0.3)):
    """
    Given a `node` on a `branch`, return the nodes in a window around it.

    Starting at `node`, nodes are collected towards the parent side (via
    ``utility.nextNode``) and towards the child side (via ``utility.prevNode``,
    following the first child), until the path length accumulated with
    ``utility.dist3D`` reaches ``window_size / 2`` on each side or the branch
    ends on that side.

    Parameters
    ----------
    node : np.ndarray or list
        Node around which to return a window. If an array is given, a list copy
        is taken and its field at index 1 is set to 2 before searching.
    branch : np.ndarray
        Branch containing `node`.
    window_size : float
        Total path length of the window (half of it on each side of `node`).
    scale : tuple of floats
        x, y and z scales of the images underlying the analysis.

    Returns
    -------
    window_nodes : np.ndarray
        Parent-side nodes (starting with `node`) followed by child-side nodes.

    Raises
    ------
    NameError
        If parent-side and child-side nodes cannot be concatenated.
    """
    if isinstance(node, np.ndarray):
        node = node.tolist()
        # NOTE(review): the field at index 1 is forced to 2 before matching
        # against `branch`; the reason is not visible here — confirm.
        node[1] = 2
    half_window = window_size / 2

    # Walk towards the parent side; `node` itself is the first entry.
    parent_nodes = []
    parent_distance = 0
    current_parent_node = node
    while parent_distance < half_window:
        parent_nodes.append(current_parent_node)
        next_parent_node = utility.nextNode(current_parent_node,
                                            branch).tolist()
        if not isinstance(next_parent_node, list):
            break  # not enough parent nodes
        parent_distance += utility.dist3D(current_parent_node,
                                          next_parent_node,
                                          scale=scale)
        current_parent_node = next_parent_node

    # Walk towards the child side, following the first child at every step.
    child_nodes = []
    child_distance = 0
    current_child_node = node
    while child_distance < half_window:
        child_nodes.append(current_child_node)
        children = utility.prevNode(current_child_node, branch).tolist()
        if children == []:
            break  # no child left: end of the branch on this side
        next_child_node = children[0]
        if not isinstance(next_child_node, list):
            break  # not enough child nodes
        child_distance += utility.dist3D(current_child_node,
                                         next_child_node,
                                         scale=scale)
        current_child_node = next_child_node
    # `node` is already the first parent-side entry; do not include it twice.
    child_nodes = child_nodes[1:]

    if not child_nodes:
        return np.array(parent_nodes)
    try:
        window_nodes = np.concatenate(
            (np.array(parent_nodes), np.array(child_nodes)), axis=0)
    except ValueError as err:
        raise NameError('NoWindowFound') from err
    return window_nodes
Example #2
0
def interpolateNodes(start, end, idx, radius=None):
    """
    Interpolate nodes between `start` and `end`.

    One node is created per whole unit of ``utility.dist3D(start, end)``
    (computed without a scale, i.e. in pixels). Coordinates (columns 2-4) run
    linearly from `start` to `end`, both endpoints included. Node ids
    (column 0) are ``idx + 1, idx + 2, ...`` and each node's parent (column 6)
    is its id minus one, so callers typically rewire the first node's parent.

    Parameters
    ----------
    start : np.ndarray
        Start node; its type (column 1) is copied to every interpolated node.
    end : np.ndarray
        End node.
    idx : int
        Lower bound to start index count for interpolated nodes.
    radius : float or None
        If not None, constant radius (column 5) for interpolated nodes, else
        the radius is interpolated linearly between `start` and `end`.

    Returns
    -------
    nodes : np.ndarray
        Interpolated nodes, shape ``(n, 7)``. ``n`` is 0 when `start` and
        `end` are less than one pixel apart.
    """
    n_nodes = int(utility.dist3D(start, end))
    nodes = np.zeros((n_nodes, 7))
    nodes[:, 0] = np.arange(idx + 1, idx + 1 + n_nodes)
    nodes[:, 1] = start[1]
    # np.linspace with equal endpoints already yields a constant column,
    # so no special case for identical coordinates is needed.
    for column in (2, 3, 4):
        nodes[:, column] = np.linspace(start[column], end[column], n_nodes)
    # Compare against None explicitly so that a radius of 0 is honoured.
    if radius is not None:
        nodes[:, 5] = radius
    else:
        nodes[:, 5] = np.linspace(start[5], end[5], n_nodes)
    nodes[:, 6] = nodes[:, 0] - 1
    return nodes
Example #3
0
def kinkPositions(fully_annotated_mainbranch, scale=(0.22, 0.22, 0.3)):
    """
    Return the positions of kinks and outgrowth events along the mainbranch.

    The branch is walked from its first endpoint (as returned by
    ``utility.findEndpoints``) using ``utility.nextNode`` until that no longer
    returns an array. Every visited node after the starting one is checked:
    a radius (column 5) above 0.5 marks a kink, a type (column 1) equal to 2
    marks an outgrowth event.

    Parameters
    ----------
    fully_annotated_mainbranch : np.ndarray
        Mainbranch with annotated kinks (encoded in the radius) and outgrowth
        events (encoded with type set to 2).
    scale : tuple of floats
        x, y and z scales of the images underlying the analysis.

    Returns
    -------
    kink_positions : np.ndarray
        Path distances of kinks from the starting endpoint.
    outgrowth_positions : np.ndarray
        Path distances of outgrowth events from the starting endpoint.
    """
    radii = fully_annotated_mainbranch[:, 5]
    node_types = fully_annotated_mainbranch[:, 1]
    # Regular nodes carry radius 0.5; anything larger is a kink.
    kink_positions = np.zeros(np.sum(radii > 0.5))
    outgrowth_positions = np.zeros(np.sum(node_types == 2))

    node = utility.findEndpoints(fully_annotated_mainbranch,
                                 return_node=True)[0]
    successor = utility.nextNode(node, fully_annotated_mainbranch)
    travelled = 0
    kinks_found = 0
    outgrowths_found = 0

    while isinstance(successor, np.ndarray):
        travelled += utility.dist3D(node, successor, scale=scale)
        if successor[5] > 0.5:
            kink_positions[kinks_found] = travelled
            kinks_found += 1
        if successor[1] == 2:
            outgrowth_positions[outgrowths_found] = travelled
            outgrowths_found += 1
        node = successor
        successor = utility.nextNode(node, fully_annotated_mainbranch)
    return kink_positions, outgrowth_positions
Example #4
0
def traceBranch(endpoint, tree, main_nodes=None, soma_nodes=None, scale=(1, 1, 1)):
    '''
    Trace from an endpoint to a root node or to any node in `main_nodes` or `soma_nodes`.

    Parameters
    ----------
    endpoint : int, float or np.ndarray
        Index of the endpoint, or the endpoint node itself (its id is read
        from index 0).
    tree : np.ndarray or list
        Tree on which to trace the branch.
    main_nodes : list, np.ndarray or None
        Nodes at which tracing stops (e.g. mainbranch nodes).
    soma_nodes : list, np.ndarray or None
        Further nodes at which tracing stops (e.g. soma nodes).
    scale : tuple of floats
        x, y and z scales passed to ``utility.dist3D``.

    Returns
    -------
    branch_array : np.ndarray
        Traced nodes, starting at the endpoint.
    length : float
        Accumulated ``utility.dist3D`` length of the traced branch.
    '''
    if isinstance(endpoint, (float, int, np.number)):
        endpoint_index = endpoint
    else:
        endpoint_index = endpoint[0]
    if isinstance(tree, list):
        tree = np.array(tree)
    if isinstance(main_nodes, np.ndarray):
        main_nodes = main_nodes.tolist()
    if isinstance(soma_nodes, np.ndarray):
        soma_nodes = soma_nodes.tolist()
    # None defaults avoid sharing one mutable list between calls.
    if main_nodes is None:
        main_nodes = []
    if soma_nodes is None:
        soma_nodes = []

    # Any tracing eventually stops in a root node.
    root_nodes = utility.findRoots(tree, return_node=True).tolist()

    branch = []
    length = 0
    current_node = utility.thisNode(endpoint_index, tree, as_list=False)
    branch.append(current_node)
    # At most 10000 steps, guarding against cycles in malformed trees.
    for _ in range(10000):
        current_as_list = current_node.tolist()
        if (current_as_list in root_nodes or current_as_list in main_nodes
                or current_as_list in soma_nodes):
            break
        next_node = utility.nextNode(current_node, tree)
        try:
            # Both the distance computation and the accumulation may fail
            # with TypeError when `next_node` is not a valid node.
            length += utility.dist3D(current_node, next_node, scale=scale)
        except TypeError:
            break
        current_node = next_node
        branch.append(current_node)

    branch_array = np.array(branch)
    return branch_array, length
Example #5
0
def cleanup(infilename='data/trees/plm.swc',
            outfilename='data/trees/plm_clean.swc',
            neurontype='PLM',
            scale=(0.223, 0.223, 0.3),
            visualize=True):
    """
    Remove tracing errors where side branches run parallel to the mainbranch.

    The tree is read from `infilename`. The longest endpoint-to-root trace is
    used as mainbranch; every endpoint is then traced back to the mainbranch
    (or a soma node) to obtain side branches. For each side branch, the
    minimum distance of its nodes to a window of mainbranch nodes around its
    attachment point is computed, together with the slope of a linear fit
    over every run of 4 consecutive distances. The first position whose slope
    exceeds 0.02 or whose distance exceeds 0.5 becomes the new start of the
    side branch, which is reattached to the closest node of its window. The
    cleaned tree is written to `outfilename`.

    Parameters
    ----------
    infilename : str
        Path of the .swc file to read (via ``utility.readSWC``).
    outfilename : str
        Path of the cleaned .swc file (written via ``utility.saveSWC``).
    neurontype : str {'ALM', 'PLM'}
        For 'ALM', soma nodes are detected with ``utility.findSomaNodes``.
    scale : tuple of floats
        x, y and z scales of the images underlying the analysis.
    visualize : bool
        Whether to plot distances and slopes of every side branch.
    """
    tree = utility.readSWC(infilename)
    endpoints = utility.findEndpoints(tree)

    # For ALM neurons detect the soma_nodes, i.e. all nodes connected to the
    # root that are above a threshold.
    if neurontype == 'ALM':
        soma_nodes = utility.findSomaNodes(tree, scale=scale)
    else:
        soma_nodes = []

    # Trace from every endpoint to a root; the longest trace is the mainbranch.
    branches = []
    lengths = np.zeros(len(endpoints))
    for i, endpoint in enumerate(endpoints):
        branch, length = traceBranch(endpoint,
                                     tree,
                                     soma_nodes=soma_nodes,
                                     scale=scale)
        branches.append(branch)
        lengths[i] = length
    mainbranch = branches[lengths.argmax()]

    # Trace from every endpoint to a mainbranch (or soma) node; flipping makes
    # each side branch start at the node where tracing stopped.
    side_branches = []
    for endpoint in endpoints:
        branch, _ = traceBranch(endpoint,
                                tree,
                                main_nodes=mainbranch,
                                soma_nodes=soma_nodes,
                                scale=scale)
        side_branches.append(np.flip(branch, axis=0))

    # Check whether side branches are close and parallel to the mainbranch.
    if visualize:
        _, axes = plt.subplots(2, 1, sharex='col')
    n = 4  # number of consecutive distances per linear fit
    x = np.arange(n)
    all_distances = []
    all_slopes = []
    windows = []
    for side_branch in side_branches:
        root = utility.findRoots(side_branch, return_node=True)[0]
        if root.tolist() in soma_nodes:
            # ALM soma outgrowth side branch: search only around the root.
            window = [root]
        else:
            window = utility.findWindow(root,
                                        mainbranch,
                                        window_size=40,
                                        scale=scale)
        windows.append(window)

        min_distance_from_mainbranch = [
            min(utility.dist3D(node, main_node, scale=scale)
                for main_node in window)
            for node in side_branch
        ]
        # NOTE(review): the first 5 distances are dropped, so position j in the
        # distance/slope lists corresponds to side-branch node j + 5, yet the
        # branch is cut at node j below — confirm whether this offset is intended.
        min_distance_from_mainbranch = min_distance_from_mainbranch[5:]
        if visualize:
            axes[0].plot(min_distance_from_mainbranch)
        all_distances.append(min_distance_from_mainbranch)

        # Slope of a linear fit over every run of n distances, preceded by n
        # zeros so that positions line up with the distance list.
        out = [0.0] * n
        for k in range(len(min_distance_from_mainbranch) - n):
            data = min_distance_from_mainbranch[k:k + n]
            try:
                slope = linregress(x, data)[0]
            except ValueError:
                break
            out.append(slope)

        if visualize:
            axes[1].plot(out, '.')
        all_slopes.append(out)

    if visualize:
        plt.show()

    # First position where the side branch departs from the mainbranch
    # (0 keeps the side branch unchanged).
    start_node_index = np.zeros(len(all_slopes))
    for i, (slopes, distances) in enumerate(zip(all_slopes, all_distances)):
        if len(slopes) == len(distances):
            for j, (slope, distance) in enumerate(zip(slopes, distances)):
                if slope > 0.02 or distance > 0.5:
                    start_node_index[i] = j
                    break

    clean_side_branches = []
    for side_branch, start, window in zip(side_branches, start_node_index,
                                          windows):
        if start == 0:
            clean_side_branches.append(side_branch)
            continue
        new_side_branch = side_branch[int(start):]
        # Reattach the trimmed branch to the closest node of its window.
        distances = [
            utility.dist3D(new_side_branch[0], main_node) for main_node in window
        ]
        connection_node = window[int(np.argmin(distances))]
        new_side_branch[0][6] = connection_node[0]
        clean_side_branches.append(new_side_branch)

    # Connect everything again and save the clean .swc file.
    full_clean_tree = list(mainbranch) + list(soma_nodes)
    for clean_side_branch in clean_side_branches:
        full_clean_tree.extend(clean_side_branch)

    full_clean = utility.removeDoubleNodes(np.array(full_clean_tree))
    utility.saveSWC(outfilename, full_clean)
Example #6
0
def calculateAnglesWithLinearRegression(node,
                                        branch,
                                        window_size=3.2,
                                        scale=(0.22, 0.22, 0.3),
                                        visualize=True,
                                        fixed_node=False):
    """
    Angle between the branch directions on the parent and child side of `node`.

    Parent-side nodes (via ``utility.nextNode``) and child-side nodes (via
    ``utility.prevNode``, following the first child) are collected until the
    path length on each side reaches ``window_size / 2``. Both point sets
    start with `node`. Each set is fitted with a line by SVD, once centred on
    the points' mean and once centred on `node` itself ("fixed node"). Every
    direction is oriented from the first towards the last point of its set.
    The angle between the parent and child directions is computed by
    ``utility.vectorAngle3D``.

    Parameters
    ----------
    node : np.ndarray or list
        Node at which to measure the angle.
    branch : np.ndarray
        Branch containing `node`.
    window_size : float
        Total path length of the fitting window (half on each side).
    scale : tuple of floats
        x, y and z scales used for the path-length computation.
    visualize : bool
        If True, show a 3D plot of the fits (blocks in ``plt.show()``).
    fixed_node : bool
        If True, return the angle of the fits centred on `node`.

    Returns
    -------
    angle : float
        Angle from ``utility.vectorAngle3D``. 180 is returned when either side
        does not have enough nodes or a step along the branch fails.
        NOTE(review): 180 as "no kink" suggests degrees — confirm.
    """
    if isinstance(node, np.ndarray):
        node = node.tolist()
    half_window = window_size / 2

    # Select parent nodes in the given window.
    parent_nodes = []
    parent_distance = 0
    current_parent_node = node
    while parent_distance < half_window:
        parent_nodes.append(current_parent_node)
        try:
            next_parent_node = utility.nextNode(current_parent_node,
                                                branch).tolist()
        except Exception:
            return 180
        if not isinstance(next_parent_node, list):
            return 180  # not enough parent nodes
        parent_distance += utility.dist3D(current_parent_node,
                                          next_parent_node,
                                          scale=scale)
        current_parent_node = next_parent_node

    # Select child nodes in the given window.
    child_nodes = []
    child_distance = 0
    current_child_node = node
    while child_distance < half_window:
        child_nodes.append(current_child_node)
        next_child_node = utility.prevNode(current_child_node, branch).tolist()
        if next_child_node == []:
            return 180  # no child left: not enough child nodes
        next_child_node = next_child_node[0]
        if not isinstance(next_child_node, list):
            return 180  # not enough child nodes
        try:
            child_distance += utility.dist3D(current_child_node,
                                             next_child_node,
                                             scale=scale)
            current_child_node = next_child_node
        except Exception:
            return 180

    # Take the coordinates (columns 2-4) from the nodes.
    parent_points = np.array(parent_nodes)[:, 2:5]
    child_points = np.array(child_nodes)[:, 2:5]
    parent_mean = parent_points.mean(axis=0)
    child_mean = child_points.mean(axis=0)

    def _oriented_axis(points, centre):
        # First right-singular vector of the centred points, flipped (in
        # place) so that it points from points[0] towards points[-1].
        direction = np.linalg.svd(points - centre)[2][0]
        if utility.dist3D(points[0] + direction, points[-1]) > utility.dist3D(
                points[0] - direction, points[-1]):
            direction *= -1
        return direction

    parent_vector = _oriented_axis(parent_points, parent_mean)
    child_vector = _oriented_axis(child_points, child_mean)
    angle = utility.vectorAngle3D(parent_vector, child_vector)

    parent_vector_fixednode = _oriented_axis(parent_points, parent_points[0])
    child_vector_fixednode = _oriented_axis(child_points, child_points[0])
    angle_fixednode = utility.vectorAngle3D(parent_vector_fixednode,
                                            child_vector_fixednode)

    if visualize:
        import matplotlib.pyplot as plt
        # Importing mplot3d registers the '3d' projection on older matplotlib.
        import mpl_toolkits.mplot3d  # noqa: F401

        # Axes3D(fig) no longer adds itself to the figure on recent
        # matplotlib; add_subplot(projection='3d') works across versions.
        ax = plt.figure().add_subplot(projection='3d')
        ax.scatter3D(*parent_points.T, color='red')
        ax.quiver(*parent_points[0], *parent_vector, color='red')
        ax.quiver(*parent_points[0], *parent_vector_fixednode,
                  color='orangered')
        ax.scatter3D(*child_points.T, color='blue')
        ax.quiver(*child_points[0], *child_vector, color='blue')
        ax.quiver(*child_points[0], *child_vector_fixednode, color='cyan')
        string = 'angles: ' + str(angle) + '/' + str(angle_fixednode)
        ax.text(child_points[0][0] + 1,
                child_points[0][1],
                child_points[0][2],
                s=string)
        plt.show()

    return angle_fixednode if fixed_node else angle
def findWindow(node, branch, window_size=4, scale=(0.22, 0.22, 0.3)):
    """
    Given a `node` on a `branch`, return a window of size `window_size` [um].

    Starting at `node`, nodes are collected towards the parent side (via
    ``utility.nextNode``) and towards the child side (via ``utility.prevNode``,
    following the first child), until the path length accumulated with
    ``utility.dist3D`` reaches ``window_size / 2`` on each side or the branch
    ends on that side.

    Parameters
    ----------
    node : np.ndarray
        Node around which to return a window. If an array is given, a list copy
        is taken and its field at index 1 is set to 2 before searching.
    branch : np.ndarray
        Parent branch of `node`.
    window_size : float [um]
        Maximum distance of the returned window along the branch.
    scale : tuple of floats
        x, y and z scales of the images underlying the analysis.

    Returns
    -------
    window_nodes : np.ndarray
        Array of nodes belonging to the window: parent-side nodes (starting
        with `node`) followed by child-side nodes.

    Raises
    ------
    NameError
        If parent-side and child-side nodes cannot be concatenated.
    """
    if isinstance(node, np.ndarray):
        node = node.tolist()
        # NOTE(review): the field at index 1 is forced to 2 before matching
        # against `branch`; the reason is not visible here — confirm.
        node[1] = 2
    half_window = window_size / 2

    # Walk towards the parent side; `node` itself is the first entry.
    parent_nodes = []
    parent_distance = 0
    current_parent_node = node
    while parent_distance < half_window:
        parent_nodes.append(current_parent_node)
        next_parent_node = utility.nextNode(current_parent_node,
                                            branch).tolist()
        if not isinstance(next_parent_node, list):
            break  # not enough parent nodes
        parent_distance += utility.dist3D(current_parent_node,
                                          next_parent_node,
                                          scale=scale)
        current_parent_node = next_parent_node

    # Walk towards the child side, following the first child at every step.
    child_nodes = []
    child_distance = 0
    current_child_node = node
    while child_distance < half_window:
        child_nodes.append(current_child_node)
        children = utility.prevNode(current_child_node, branch).tolist()
        if children == []:
            break  # no child left: end of the branch on this side
        next_child_node = children[0]
        if not isinstance(next_child_node, list):
            break  # not enough child nodes
        child_distance += utility.dist3D(current_child_node,
                                         next_child_node,
                                         scale=scale)
        current_child_node = next_child_node
    # `node` is already the first parent-side entry; do not include it twice.
    child_nodes = child_nodes[1:]

    if not child_nodes:
        return np.array(parent_nodes)
    try:
        window_nodes = np.concatenate(
            (np.array(parent_nodes), np.array(child_nodes)), axis=0)
    except ValueError as err:
        raise NameError('NoWindowFound') from err
    return window_nodes
Example #8
0
def cleanup(tree,
            neurontype='PLM',
            scale=(0.223, 0.223, 0.3),
            visualize=False):
    """
    Cleanup tracing errors where outgrowth events move parallel along the mainbranch.

    The longest endpoint-to-root trace is used as mainbranch; every endpoint
    is then traced back to the mainbranch (or a soma node) to obtain side
    branches. For each side branch the minimum ``utility.dist3DWithRadius``
    distance to a window of mainbranch nodes is computed node by node
    (stopping once the plain ``utility.dist3D`` distance exceeds 5). The last
    node closer than 0.25 becomes the new start of the side branch, which is
    reconnected to the closest window node through nodes from
    ``interpolateNodes``.

    Parameters
    ----------
    tree : np.ndarray
        Tree on which to perform cleanup.
    neurontype : str {'ALM', 'PLM'}
        Process ALM or PLM neurons.
    scale :  tuple of floats
        x, y and z scales of the images underlying the analysis.
    visualize : bool
        Whether to visualize the results.

    Returns
    -------
    full_clean_tree:
        Clean tree.
    """
    endpoints = utility.findEndpoints(tree)

    # For ALM neurons detect the soma_nodes i.e. all nodes connected to the
    # root that have radius above a threshold.
    if neurontype == 'ALM':
        soma_nodes = utility.findSomaNodes(tree, scale=scale)
    else:
        soma_nodes = []

    # Trace from every endpoint to a root; the longest trace is the mainbranch.
    branches = []
    lengths = np.zeros(len(endpoints))
    for i, endpoint in enumerate(endpoints):
        branch, length = traceBranch(endpoint, tree, soma_nodes=soma_nodes, scale=scale)
        branches.append(branch)
        lengths[i] = length
    mainbranch = branches[lengths.argmax()]

    # Trace from every endpoint to a mainbranch (or soma) node; flipping makes
    # each side branch start at the node where tracing stopped.
    side_branches = []
    for endpoint in endpoints:
        branch, _ = traceBranch(endpoint, tree, main_nodes=mainbranch, soma_nodes=soma_nodes, scale=scale)
        side_branches.append(np.flip(branch, axis=0))

    # Check whether side branches are close and parallel to the mainbranch.
    if visualize:
        _, axes = plt.subplots(3, 1, sharex='col')
    all_distances = []
    windows = []
    for side_branch in side_branches:
        root = utility.findRoots(side_branch, return_node=True)[0]
        if root.tolist() in soma_nodes:
            # ALM soma outgrowth side branch: search only around the root.
            window = [root]
        else:
            window = utility.findWindow(root, mainbranch, window_size=40, scale=scale)
        windows.append(window)
        min_distance_from_mainbranch = []
        min_distance_from_mainbranch2 = []
        for node in side_branch:
            min_distance = min(utility.dist3D(node, main_node, scale=scale)
                               for main_node in window)
            min_distance2 = min(utility.dist3DWithRadius(node, main_node, scale=scale)
                                for main_node in window)
            min_distance_from_mainbranch.append(min_distance)
            min_distance_from_mainbranch2.append(min_distance2)
            if min_distance > 5:
                break  # branch has left the mainbranch; no need to go further
        if visualize:
            axes[0].plot(min_distance_from_mainbranch)
            axes[1].plot(min_distance_from_mainbranch2)
            # running mean of the dist3DWithRadius distances
            axes[2].plot(np.cumsum(min_distance_from_mainbranch2) / (np.arange(len(min_distance_from_mainbranch2)) + 1))
        all_distances.append(min_distance_from_mainbranch2)

    if visualize:
        plt.show()

    # For each side branch, the LAST node closer than d_th to the mainbranch
    # (0 keeps the side branch unchanged).
    d_th = 0.25
    start_node_index = np.zeros(len(all_distances))
    for i, distances in enumerate(all_distances):
        for j, dist in enumerate(distances):
            if dist < d_th:
                start_node_index[i] = j

    # Ids of interpolated nodes are derived from idx, which starts above every
    # id already present in the tree.
    idx = np.max(tree[:, 0]) + 1
    clean_side_branches = []
    for side_branch, start, window in zip(side_branches, start_node_index, windows):
        if start == 0:
            clean_side_branches.append(side_branch)
            continue
        start = int(start)
        # Radius of the bridging nodes: mean radius of the discarded part,
        # including the new start node.
        radius = np.mean(side_branch[:start + 1, 5])
        new_side_branch = side_branch[start:]

        distances = [utility.dist3D(new_side_branch[0], main_node) for main_node in window]
        connection_node = window[int(np.argmin(distances))]

        nodes = interpolateNodes(connection_node, new_side_branch[0], idx, radius)
        if len(nodes) == 0:
            # The gap is too short for any bridging node: attach directly,
            # instead of failing on np.max / nodes[-1] of an empty array.
            new_side_branch[0][6] = connection_node[0]
            clean_side_branches.append(new_side_branch)
            continue
        idx = np.max(nodes[:, 0]) + 1
        # Chain: connection_node -> interpolated nodes -> new side branch.
        nodes[0][6] = connection_node[0]
        new_side_branch[0][6] = nodes[-1][0]
        clean_side_branches.append(np.concatenate((nodes, new_side_branch)))

    # Connect everything again.
    full_clean_tree = list(mainbranch) + list(soma_nodes)
    for clean_side_branch in clean_side_branches:
        full_clean_tree.extend(clean_side_branch)

    return utility.removeDoubleNodes(np.array(full_clean_tree))
def traceBranch(endpoint, tree, main_nodes=None, soma_nodes=None, scale=(1,1,1)):
    """
    Trace from an endpoint to a root node or any other node specified in `main_nodes` or `soma_nodes`.

    Parameters
    ----------
    endpoint : int or np.ndarray
        Index of endpoint or endpoint-node of the branch to be traced.
    tree : np.ndarray
        Tree on which to trace a branch.
    main_nodes : list or None
        List of nodes that are classified as mainbranch-nodes.
    soma_nodes : list or None
        List of nodes that are classified as soma-nodes.
    scale : tuple of floats [um]
         x, y and z scales of the images underlying the analysis.

    Returns
    -------
    branch_array : np.ndarray
        Tree containing only the traced branch, starting at the endpoint.
    length : float
        Length of the traced branch.
    """
    if isinstance(endpoint, (float, int, np.number)):
        endpoint_index = endpoint
    else:
        endpoint_index = endpoint[0]
    if isinstance(tree, list):
        tree = np.array(tree)
    if isinstance(main_nodes, np.ndarray):
        main_nodes = main_nodes.tolist()
    if isinstance(soma_nodes, np.ndarray):
        soma_nodes = soma_nodes.tolist()
    if main_nodes is None:
        main_nodes = []
    if soma_nodes is None:
        soma_nodes = []

    # Any tracing eventually stops in a root node.
    root_nodes = utility.findRoots(tree, return_node=True).tolist()

    branch = []
    length = 0
    current_node = utility.thisNode(endpoint_index, tree, as_list=False)
    branch.append(current_node)
    # At most 10000 steps, guarding against cycles in malformed trees.
    for _ in range(10000):
        current_as_list = current_node.tolist()
        if (current_as_list in root_nodes or current_as_list in main_nodes
                or current_as_list in soma_nodes):
            break
        next_node = utility.nextNode(current_node, tree)
        try:
            # Both the distance computation and the accumulation may fail
            # with TypeError when `next_node` is not a valid node.
            length += utility.dist3D(current_node, next_node, scale=scale)
        except TypeError:
            break
        current_node = next_node
        branch.append(current_node)

    branch_array = np.array(branch)
    return branch_array, length