def test_contains(self):
    """A tree contains itself and every descendant, but no descendant contains the root."""
    root = Tree(n_children=4)
    # A node should contain itself.
    self.assertTrue(root.contains(root))

    # Walk four generations down along the first child.
    current = root
    for _ in range(4):
        current.split()
        current = current.get_child(0)
        self.assertTrue(root.contains(current))
        self.assertFalse(current.contains(root))
def test_record(self):
    """After forest.record(1), every traversed node should carry flag 1."""
    forest = Forest([Tree(2), Tree(2), Tree(2)])

    # Randomly split trees over several rounds to vary the forest's shape.
    for _ in range(5):
        for tree in forest.get_children():
            if np.random.rand() > 0.5:
                tree.split()

    forest.record(1)
    for node in forest.traverse():
        self.assertTrue(node.is_marked(1))
def test_constructor(self):
    """Tree() requires n_children, reserves that many slots, and can be built in a forest."""
    # Omitting the number of children should fail.
    self.assertRaises(Exception, Tree)

    # A quadtree node reserves four child slots.
    quad = Tree(n_children=4)
    self.assertEqual(len(quad._children), 4, 'There should be 4 children.')

    # A tree may be created directly inside an existing forest.
    forest = Forest()
    node = Tree(forest=forest, n_children=2)
def test_get_children(self):
    """Forest.get_children yields trees in order, with matching position and arity."""
    # An empty forest has no children.
    self.assertEqual(Forest().get_children(), [])

    expected_arity = [1, 3, 10]
    forest = Forest([Tree(k) for k in expected_arity])
    for position, tree in enumerate(forest.get_children()):
        self.assertEqual(tree.get_node_position(), position)
        self.assertEqual(tree.n_children(), expected_arity[position])
def test_set_node_type(self):
    """set_node_type must reject type changes that contradict the tree structure.

    The node type is passed as ONE positional argument. Writing
    ``*('BRANCH')`` would unpack the string into six one-character
    arguments, so the call would raise TypeError regardless of what
    set_node_type does, and the assertion would pass vacuously.
    """
    # A freshly created node is a ROOT; it must not become a BRANCH.
    node = Tree(n_children=3)
    self.assertRaises(Exception, node.set_node_type, 'BRANCH')

    # After splitting, a child is a LEAF; it can be neither ROOT nor BRANCH.
    node.split()
    child = node.get_child(0)
    self.assertRaises(Exception, child.set_node_type, 'ROOT')
    self.assertRaises(Exception, child.set_node_type, 'BRANCH')
def test_delete_children(self):
    """delete_children(i) clears one child slot; delete_children() clears them all."""
    parent = Tree(n_children=4)
    parent.split()

    # Remove only child 2.
    parent.delete_children(2)
    self.assertIsNone(parent.get_child(2))

    # Remove all remaining children.
    parent.delete_children()
    self.assertFalse(parent.has_children())
def test_get_child(self):
    """get_child(i) returns the same object stored in the internal child slot i."""
    parent = Tree(n_children=5)
    parent.split()

    # Fetch child 4 straight from the internal list and flag it.
    direct = parent._children[4]
    direct.mark(1)

    # The public accessor must hand back that very object.
    via_api = parent.get_child(4)
    self.assertTrue(via_api.is_marked(1), 'Child 4 should be marked 1.')
    self.assertEqual(via_api, direct, 'Children should be the same.')
def test_get_parent(self):
    """get_parent(flag) returns the nearest ancestor carrying that flag."""
    # (tree, index of the child to split, index of the node to query)
    cases = [(Tree(n_children=2), 0, 1), (Tree(n_children=4), 0, 3)]
    for node, split_pos, query_pos in cases:
        node.mark(1)
        node.split()
        child = node.get_child(split_pos)
        child.split()
        self.assertEqual(node, node.get_child(query_pos).get_parent(1),
                         'First marked ancestor should be node.')

        # Once the child is flagged too, it becomes the nearest marked ancestor.
        child.mark(1)
        self.assertEqual(child, child.get_child(query_pos).get_parent(1),
                         'First marked ancestor should be child.')
def test_has_trees(self):
    """Forest.has_children reports whether the forest holds trees, optionally flagged."""
    # An empty forest has no trees.
    self.assertFalse(Forest().has_children())

    # Two unflagged trees: trees exist, but none carries flag 0.
    first, second = Tree(2), Tree(3)
    forest = Forest([first, second])
    self.assertTrue(forest.has_children())
    self.assertFalse(forest.has_children(flag=0))

    # Flag one tree with 0: flag 0 is now present, flag 1 still is not.
    first.mark(0)
    self.assertTrue(forest.has_children(flag=0))
    self.assertFalse(forest.has_children(flag=1))
def test_has_parent(self):
    """has_parent()/has_parent(flag) is true for children and false for roots."""
    for root in [Tree(n_children=2), Tree(n_children=4)]:
        root.split()
        for kid in root.get_children():
            self.assertTrue(kid.has_parent(),
                            'Nodes children should have a parent.')

        # Flag the root and check flag-specific parent queries.
        root.mark(1)
        for kid in root.get_children():
            self.assertTrue(kid.has_parent(1),
                            'Children should have parent labeled 1.')
            self.assertFalse(kid.has_parent(2),
                             'Children do not have parent labeled 2.')

        # The root itself has no parent at all.
        self.assertFalse(root.has_parent(1),
                         'Root node should not have parents of type 1.')
        self.assertFalse(root.has_parent(),
                         'Root node should not have parents.')
def test_get_root(self):
    """Every descendant reports the original ROOT via get_root()."""
    root = Tree(n_children=3)
    current = root
    for _ in range(10):
        current.split()
        current = current.get_child(0)
        self.assertEqual(current.get_root(), root,
                         'All children should have the same root node')
def test_add_remove_tree(self):
    """add_tree/remove_tree change the forest's tree count; non-Trees are rejected."""
    forest = Forest()
    self.assertEqual(forest.n_children(), 0)

    forest.add_tree(Tree(3))
    self.assertEqual(forest.n_children(), 1)

    # Adding something that is not a Tree should fail.
    self.assertRaises(Exception, forest.add_tree, 1)

    forest.remove_tree(0)
    self.assertEqual(forest.n_children(), 0)
def test_depth(self):
    """Forest.depth is the depth of its deepest tree."""
    forest = Forest([Tree(2), Tree(2)])

    # Grow tree 0 three generations deep along its first child.
    node = forest.get_child(0)
    for _ in range(3):
        node.split()
        node = node.get_child(0)
    self.assertEqual(forest.depth(), 3)

    # Removing the deep tree leaves a single unsplit tree: depth 0.
    forest.remove_tree(0)
    self.assertEqual(forest.depth(), 0)
def test_get_depth(self):
    """get_depth is a node's distance from its ROOT."""
    root = Tree(n_children=2)
    self.assertEqual(root.get_depth(), 0, 'ROOT node should have depth 0')

    # Descend ten generations along child 1.
    leaf = root
    for _ in range(10):
        leaf.split()
        leaf = leaf.get_child(1)

    self.assertEqual(leaf.get_depth(), 10, 'Node should have depth 10.')
    self.assertEqual(leaf.get_root().get_depth(), 0,
                     'ROOT node should have depth 0')
def test_find_node(self):
    """find_node returns None for a missing address and the node once it exists."""
    root = Tree(n_children=4)
    address = [0, 2, 3, 0]

    # No node exists at this address yet.
    self.assertIsNone(root.find_node(address))

    # Build the path described by the address, then look it up from the root.
    target = root
    for pos in address:
        target.split()
        target = target.get_child(pos)
    self.assertEqual(target, target.get_root().find_node(address))
def test_is_regular(self):
    """is_regular follows the ``regular`` constructor argument."""
    irregular = Tree(regular=False, n_children=2)
    self.assertFalse(irregular.is_regular(), 'Node is not regular.')

    # Built without ``regular=``: expected to be regular.
    regular = Tree(n_children=3)
    self.assertTrue(regular.is_regular(), 'Node should be regular.')
def test_tree_depth(self):
    """tree_depth is the depth of the whole tree, identical from any node."""
    root = Tree(n_children=2)
    self.assertEqual(root.tree_depth(), 0, 'Tree should have depth 0')

    # Descend ten generations along child 1.
    leaf = root
    for _ in range(10):
        leaf.split()
        leaf = leaf.get_child(1)

    # Queried from the deepest node or from the root, the answer is the same.
    self.assertEqual(leaf.tree_depth(), 10, 'Tree should have depth 10.')
    self.assertEqual(leaf.get_root().tree_depth(), 10,
                     'Tree should have depth 10.')
def test_remove(self):
    """A removed child leaves an empty slot in its parent."""
    parent = Tree(n_children=4)
    parent.split()
    parent.get_child(2).remove()
    self.assertIsNone(parent.get_child(2))
def test_find_node(self):
    """Forest.find_node resolves a full address and returns None for a bad one."""
    first, second = Tree(2), Tree(3)
    forest = Forest([first, second])

    first.split()
    grandparent_child = first.get_child(0)
    grandparent_child.split()
    target = grandparent_child.get_child(1)

    self.assertEqual(forest.find_node([0, 0, 1]), target)
    # An address that does not exist.
    self.assertIsNone(forest.find_node([4, 5, 6]))
def test_in_forest(self):
    """in_forest tracks membership as trees are added to and removed from a forest."""
    # A tree handed to the Forest constructor is in the forest.
    planted = Tree(n_children=2)
    Forest([planted])
    self.assertTrue(planted.in_forest(), 'Node should be in forest')

    # A fresh tree is in no forest until it is added ...
    stray = Tree(n_children=2)
    forest = Forest()
    self.assertFalse(stray.in_forest(), 'Node should not be in a forest.')
    forest.add_tree(stray)
    self.assertTrue(stray.in_forest(), 'Node should be in a forest.')

    # ... and removing it (by position) takes it out again.
    forest.remove_tree(stray.get_node_position())
    self.assertFalse(stray.in_forest(),
                     'Node should no longer be in the forest.')
def test_get_node_type(self):
    """Node type is ROOT for the root, LEAF for unsplit children, BRANCH once split."""
    root_msg = 'Output "node_type" should be "ROOT".'

    root = Tree(n_children=2)
    self.assertEqual(root.get_node_type(), 'ROOT', root_msg)

    # After splitting, the root stays ROOT and its child is a LEAF.
    root.split()
    child = root.get_child(0)
    self.assertEqual(root.get_node_type(), 'ROOT', root_msg)
    self.assertEqual(child.get_node_type(), 'LEAF',
                     'Output "node_type" should be "LEAF".')

    # Splitting the child turns it into a BRANCH.
    child.split()
    self.assertEqual(child.get_node_type(), 'BRANCH',
                     'Output "node_type" should be "BRANCH".')
def test_n_children(self):
    """n_children is set by the constructor and inherited by children on split."""
    for arity in range(10):
        parent = Tree(n_children=arity)
        self.assertEqual(parent.n_children(), arity,
                         'Number of children incorrect')

        # Children of a regular tree share the parent's arity.
        parent.split()
        for kid in parent.get_children():
            self.assertEqual(kid.n_children(), arity,
                             'Children have incorrect number of children.')
def test_constructor(self):
    """Forest accepts only ROOT Tree objects."""
    t0 = Tree(2)
    # A non-Tree entry must be rejected.
    self.assertRaises(Exception, Forest, trees=[t0, 0])

    # A non-root Tree must be rejected.
    t0.split()
    non_root = t0.get_child(0)
    self.assertRaises(Exception, Forest, trees=[t0, non_root])

    forest = Forest([t0, Tree(4)])
    self.assertEqual(len(forest._trees), 2)
def test_refine(self):
    """Exercise Forest.refine: plain, flagged, sub-forest and new-label refinement."""

    def n_nodes(nodes):
        # Exhaust the iterable and return how many nodes it yielded.
        return sum(1 for _ in nodes)

    # -----------------------------------------------------------------
    # Simple refinement
    # -----------------------------------------------------------------
    # Two binary trees, refined indiscriminately: 2 roots + 4 children.
    forest = Forest([Tree(2), Tree(2)])
    forest.refine()
    self.assertEqual(n_nodes(forest.traverse()), 6)

    # -----------------------------------------------------------------
    # Refinement flag
    # -----------------------------------------------------------------
    # Flag the second tree and refine only flagged nodes.
    forest.get_child(1).mark(1)
    forest.refine(refinement_flag=1)
    # Nothing should change, because the flagged tree is not a leaf.
    self.assertEqual(n_nodes(forest.traverse()), 6)

    # Flagging one of its leaves does trigger a split.
    forest.get_child(1).get_child(1).mark(1)
    forest.refine(refinement_flag=1)
    self.assertEqual(n_nodes(forest.traverse()), 8)

    # -----------------------------------------------------------------
    # Refinement of a sub-forest
    # -----------------------------------------------------------------
    forest = Forest([Tree(2), Tree(2)])
    forest.refine()
    # Define the sub-forest (flag 1).
    forest.get_child(0).get_child(0).mark(1)
    forest.root_subtrees(1)
    self.assertEqual(n_nodes(forest.traverse(1)), 4)

    forest.refine(subforest_flag=1)
    self.assertEqual(n_nodes(forest.traverse(1)), 10)

    forest.coarsen(subforest_flag=1)
    self.assertEqual(n_nodes(forest.traverse(1)), 4)

    # Now restrict refinement with a refinement flag as well.
    forest.get_child(1).mark(2)
    forest.refine(subforest_flag=1, refinement_flag=2)
    self.assertEqual(n_nodes(forest.traverse(1)), 6)

    # -----------------------------------------------------------------
    # Refinement with a new label
    # -----------------------------------------------------------------
    forest = Forest([Tree(2), Tree(2)])
    forest.refine()
    # Define the sub-forest (flag 1).
    forest.get_child(0).get_child(0).mark(1)
    forest.root_subtrees(1)
    self.assertEqual(n_nodes(forest.traverse(1)), 4)

    forest.refine(subforest_flag=1, new_label=4)
    # Nodes carrying the new label ...
    self.assertEqual(n_nodes(forest.traverse(4)), 10)
    # ... while the original sub-forest keeps its size.
    self.assertEqual(n_nodes(forest.traverse(1)), 4)

    # Refinement marker combined with a new label.
    forest.get_child(1).mark(3)
    forest.refine(subforest_flag=1, refinement_flag=3, new_label=5)
    self.assertEqual(n_nodes(forest.traverse(5)), 6)
def test_traverse(self):
    """Traversal visits every node exactly once, in breadth- or depth-first order.

    Each traversal is collected into a list of node addresses and compared
    with the expected list as a whole. This way a traversal that skips
    nodes, or visits extra ones, fails, not only one whose first few
    nodes are out of order.
    """

    def addresses_of(nodes):
        # Node addresses in the order the traversal yields them.
        return [n.get_node_address() for n in nodes]

    #
    # Binary tree: split root, split child 0, remove grandchild [0, 0, 1]
    #
    node = Tree(2)
    forest = Forest([node])
    node.split()
    node.get_child(0).split()
    node.get_child(0).get_child(1).remove()
    addresses = {
        'breadth-first': [[0], [0, 0], [0, 1], [0, 0, 0]],
        'depth-first': [[0], [0, 0], [0, 0, 0], [0, 1]],
    }
    for mode in ['depth-first', 'breadth-first']:
        self.assertEqual(addresses_of(forest.traverse(mode=mode)),
                         addresses[mode])

    #
    # QuadTree: split root, split child 1, remove grandchild [0, 1, 2]
    #
    node = Tree(4)
    forest = Forest([node])  # puts the tree at position 0 in a forest
    node.split()
    node.get_child(1).split()
    node.get_child(1).get_child(2).remove()
    addresses = [[0], [0, 0], [0, 1], [0, 2], [0, 3],
                 [0, 1, 0], [0, 1, 1], [0, 1, 3]]
    self.assertEqual(addresses_of(node.traverse(mode='breadth-first')),
                     addresses, 'Incorrect address.')

    #
    # Forest with one binary tree and one quadtree
    #
    bi = Tree(2)
    quad = Tree(4)
    forest = Forest([bi, quad])
    bi.split()
    bi.get_child(0).split()
    quad.split()
    addresses = {
        'breadth-first': [[0], [1], [0, 0], [0, 1], [1, 0], [1, 1],
                          [1, 2], [1, 3], [0, 0, 0], [0, 0, 1]],
        'depth-first': [[0], [0, 0], [0, 0, 0], [0, 0, 1], [0, 1],
                        [1], [1, 0], [1, 1], [1, 2], [1, 3]],
    }
    for mode in ['depth-first', 'breadth-first']:
        self.assertEqual(addresses_of(forest.traverse(mode=mode)),
                         addresses[mode])
def test_coarsen(self):
    """Exercise Forest.coarsen: plain, flagged, sub-forest and new-label coarsening."""

    def n_nodes(nodes):
        # Exhaust the iterable and return how many nodes it yielded.
        return sum(1 for _ in nodes)

    # -----------------------------------------------------------------
    # Simple coarsening
    # -----------------------------------------------------------------
    # Two unsplit quadtrees: coarsening has nothing to remove.
    forest = Forest([Tree(4), Tree(4)])
    forest.coarsen()
    self.assertEqual(n_nodes(forest.traverse()), 2)

    # Refining and then coarsening gives back the original forest.
    forest.refine()
    forest.coarsen()
    self.assertEqual(n_nodes(forest.traverse()), 2)

    # -----------------------------------------------------------------
    # Coarsening flag
    # -----------------------------------------------------------------
    forest.refine()
    # 2 roots + 2 * 4 children
    self.assertEqual(n_nodes(forest.traverse()), 10)

    # Flag tree 0 only: it loses its children, tree 1 keeps its 4.
    forest.get_child(0).mark(1)
    forest.coarsen(coarsening_flag=1)
    self.assertEqual(n_nodes(forest.traverse()), 6)
    self.assertFalse(forest.get_child(0).has_children())
    self.assertTrue(forest.get_child(1).has_children())
    # No node should still carry the coarsening flag.
    self.assertFalse(any(
        nd.is_marked(1) for nd in forest.traverse()))

    # -----------------------------------------------------------------
    # Coarsening with sub-forests
    # -----------------------------------------------------------------
    forest = Forest([Tree(2), Tree(2)])
    forest.refine()
    # Sub-forest (flag 2): 0, 00, 01, 1
    forest.get_child(0).get_child(0).mark(2)
    forest.root_subtrees(2)
    self.assertEqual(n_nodes(forest.traverse(flag=2)), 4)

    # Coarsening the sub-forest leaves only [0], [1] in it ...
    forest.coarsen(subforest_flag=2)
    self.assertEqual(n_nodes(forest.traverse(flag=2)), 2)
    # ... while the full forest still has 6 nodes (0, 00, 01, 1, 10, 11).
    self.assertEqual(n_nodes(forest.traverse()), 6)

    # Rebuild the sub-forest: 0, 00, 01, 1
    forest.get_child(0).get_child(0).mark(2)
    forest.root_subtrees(2)
    self.assertEqual(n_nodes(forest.traverse(2)), 4)

    # A coarsening flag on a node outside the sub-forest leaves it unchanged.
    forest.get_child(1).get_child(1).mark(1)
    forest.coarsen(subforest_flag=2, coarsening_flag=1)
    self.assertEqual(n_nodes(forest.traverse(2)), 4)

    # A coarsening flag on a sub-forest node removes its children from it.
    forest.get_child(0).mark(1)
    forest.coarsen(subforest_flag=2, coarsening_flag=1)
    self.assertEqual(n_nodes(forest.traverse(2)), 2)
    # The coarsening flag must be cleared afterwards.
    self.assertFalse(forest.get_child(0).is_marked(1))

    # -----------------------------------------------------------------
    # Coarsening with new_label
    # -----------------------------------------------------------------
    # Sub-forest (flag 2): 0, 00, 01, 1
    forest.get_child(0).get_child(0).mark(2)
    forest.root_subtrees(2)
    # Flag tree [0] for coarsening.
    forest.get_child(0).mark(1)
    # Coarsen the sub-forest, recording the result under label 3.
    forest.coarsen(subforest_flag=2, coarsening_flag=1, new_label=3,
                   debug=True)
    # The flag-2 sub-forest keeps all of its nodes ...
    self.assertEqual(n_nodes(forest.traverse(2)), 4)
    # ... while the label-3 version has fewer.
    self.assertEqual(n_nodes(forest.traverse(3)), 2)

    #
    # Same again, without a sub-forest
    #
    self.assertEqual(n_nodes(forest.traverse()), 6)
    forest.get_child(0).mark(1)
    forest.coarsen(coarsening_flag=1, new_label=4)
    # The forest itself still has 6 nodes ...
    self.assertEqual(n_nodes(forest.traverse()), 6)
    # ... and label 4 marks 4 of them.
    self.assertEqual(n_nodes(forest.traverse(4)), 4)
from mesh import QuadMesh, Vertex, HalfEdge, QuadCell, Mesh1D, Interval, Tree
from fem import DofHandler, QuadFE, GaussRule, Function
from mesh import convert_to_array
from plot import Plot
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

# Scratch script: refine a 1D mesh into a labelled sub-forest, print the
# leaves before and after, and distribute 'Q2' dofs on the mesh.
#
# NOTE(review): these statements are at module top level, so they run (and
# print) whenever this file is imported, e.g. by a test runner. Consider
# moving them under ``if __name__ == '__main__':``.
# NOTE(review): QuadMesh, Vertex, HalfEdge, QuadCell, Interval, Function,
# convert_to_array, Plot, Axes3D and plt are not used by this script; confirm
# nothing else in the file needs them before removing.

# Non-regular Tree used only as a label generator: a node's address,
# converted to a tuple, serves as a flag for mesh cells.
mtags = Tree(regular=False)
mesh = Mesh1D(resolution=(1, ))

# Label the current cells with the root address of the label tree
# (presumably marks every cell, like Forest.record -- confirm).
flag = tuple(mtags.get_node_address())
mesh.cells.record(flag)

# A child of the label tree supplies the label for the refined cells.
mtags.add_child()
new_flag = tuple(mtags.get_child(0).get_node_address())
# Refine the sub-forest flagged `flag`, recording the result under `new_flag`.
mesh.cells.refine(subforest_flag=flag, \
                  new_label=new_flag)

# Show the leaves of the original sub-forest, a separator, then the refined one.
for leaf in mesh.cells.get_leaves(subforest_flag=flag):
    leaf.info()
print('==' * 20)
for leaf in mesh.cells.get_leaves(subforest_flag=new_flag):
    leaf.info()

# Build a 1D 'Q2' element and distribute its dofs on the mesh.
element = QuadFE(1, 'Q2')
dofhandler = DofHandler(mesh, element)
dofhandler.distribute_dofs()
dofhandler.set_dof_vertices()
def test_child(self):
    """Forest.get_child fails on an empty forest and returns the i-th tree otherwise."""
    # Asking an empty forest for a tree should fail.
    self.assertRaises(Exception, Forest().get_child, 0)

    forest = Forest([Tree(1), Tree(4)])
    self.assertEqual(forest.get_child(1).n_children(), 4)
def test_constructor(self):
    """Kernel evaluates combinations of Explicit, Nodal and Constant functions.

    Covers single functions, combinations through a user-supplied ``F``,
    derivatives (``derivatives=['f', 'fx']``), and sampled (multi-column)
    inputs. The final sub-mesh section is still incomplete (see the TODO).
    """
    # =====================================================================
    # Test 1D
    # =====================================================================
    #
    # Kernel consists of a single explicit Function:
    #
    f1 = lambda x: x+2
    f = Explicit(f1, dim=1)
    k = Kernel(f)

    x = np.linspace(0,1,100)
    n_points = len(x)

    # Check that it evaluates correctly.
    self.assertTrue(np.allclose(f1(x), k.eval(x).ravel()))

    # Check shape of kernel
    self.assertEqual(k.eval(x).shape, (n_points,1))

    #
    # Kernel consists of a combination of two explicit functions
    #
    f1 = Explicit(lambda x: x+2, dim=1)
    f2 = Explicit(lambda x: x**2 + 1, dim=1)
    F = lambda f1, f2: f1**2 + f2
    f_t = lambda x: (x+2)**2 + x**2 + 1
    k = Kernel([f1,f2], F=F)

    # Check evaluation
    self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

    # Check shape
    self.assertEqual(k.eval(x).shape, (n_points,1))

    #
    # Same thing as above, but with nodal functions
    #
    mesh = Mesh1D(resolution=(1,))
    Q1 = QuadFE(1,'Q1')
    Q2 = QuadFE(1,'Q2')
    dQ1 = DofHandler(mesh,Q1)
    dQ2 = DofHandler(mesh,Q2)

    # Distribute dofs (a plain loop: the calls are made for their side effects)
    for dQ in (dQ1, dQ2):
        dQ.distribute_dofs()

    # Basis functions
    phi1 = Basis(dQ1,'u')
    phi2 = Basis(dQ2,'u')

    f1 = Nodal(lambda x: x+2, basis=phi1)
    f2 = Nodal(lambda x: x**2 + 1, basis=phi2)
    k = Kernel([f1,f2], F=F)

    # Check evaluation
    self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

    #
    # Replace f2 above with its derivative
    #
    # With derivatives=['f', 'fx'], F receives (f1, f2') = (x+2, 2x),
    # so the kernel should equal (x+2)**2 + 2x.
    k = Kernel([f1,f2], derivatives=['f', 'fx'], F=F)
    f_t = lambda x: (x+2)**2 + 2*x

    # Check derivative evaluation
    self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

    #
    # Sampling
    #
    one = Constant(1)
    f1 = Explicit(lambda x: x**2 + 1, dim=1)

    # Sampled function
    a = np.linspace(0,1,11)
    n_samples = len(a)

    # Define Dofhandler
    dh = DofHandler(mesh, Q2)
    dh.distribute_dofs()
    dh.set_dof_vertices()
    xv = dh.get_dof_vertices()
    n_dofs = dh.n_dofs()
    phi = Basis(dh, 'u')

    # Evaluate parameterized function at mesh dof vertices:
    # column i holds x + a[i]*x**2.
    f2_m = np.empty((n_dofs, n_samples))
    for i in range(n_samples):
        f2_m[:,i] = xv.ravel() + a[i]*xv.ravel()**2
    f2 = Nodal(data=f2_m, basis=phi)

    # Define kernel
    F = lambda f1, f2, one: f1 + f2 + one
    k = Kernel([f1,f2,one], F=F)

    # Evaluate on a fine mesh (once; the result does not depend on i)
    x = np.linspace(0,1,100)
    n_points = len(x)
    kx = k.eval(x)
    self.assertEqual(kx.shape, (n_points, n_samples))
    for i in range(n_samples):
        # Check evaluation
        # NOTE(review): f1 is a single (unsampled) Explicit function; confirm
        # that f1.eval(x) really has n_samples columns here.
        self.assertTrue(np.allclose(kx[:,i], f1.eval(x)[:,i] + x + a[i]*x**2+ 1))

    #
    # Sample multiple constant functions
    #
    f1 = Constant(data=a)
    f2 = Explicit(lambda x: 1 + x**2, dim=1)
    f3 = Nodal(data=f2_m[:,-1], basis=phi)

    F = lambda f1, f2, f3: f1 + f2 + f3
    k = Kernel([f1,f2,f3], F=F)

    x = np.linspace(0,1,100)
    kx = k.eval(x)
    for i in range(n_samples):
        self.assertTrue(np.allclose(kx[:,i], \
                                    a[i] + f2.eval(x)[:,i] + f3.eval(x)[:,i]))

    #
    # Submeshes
    #
    # TODO: this section is unfinished -- mesh_labels, kernel and x are set
    # up but nothing is asserted yet.
    mesh_labels = Tree(regular=False)

    mesh = Mesh1D(resolution=(1,))
    Q1 = QuadFE(1,'Q1')
    Q2 = QuadFE(1,'Q2')
    dQ1 = DofHandler(mesh,Q1)
    dQ2 = DofHandler(mesh,Q2)

    # Distribute dofs
    for dQ in (dQ1, dQ2):
        dQ.distribute_dofs()

    # Basis
    p1 = Basis(dQ1)
    p2 = Basis(dQ2)

    f1 = Nodal(lambda x: x, basis=p1)
    f2 = Nodal(lambda x: -2+2*x**2, basis=p2)
    one = Constant(np.array([1,2]))

    F = lambda f1, f2, one: 2*f1**2 + f2 + one

    I = mesh.cells.get_child(0)

    kernel = Kernel([f1,f2, one], F=F)

    rule1D = GaussRule(5,shape='interval')
    x = I.reference_map(rule1D.nodes())
def test_get_leaves(self):
    """get_leaves returns LEAF nodes in the requested order, optionally restricted by flag.

    Ordered traversals are compared as complete address lists, so missing or
    extra leaves also fail. Assertion messages state the same addresses and
    child indices that are actually asserted.
    """
    #
    # 1D
    #
    node = Tree(2)
    forest = Forest([node])
    leaves = forest.get_leaves()

    # Only a ROOT node, it should be the only LEAF
    self.assertEqual(leaves, [node], 'Cell should be its own leaf.')

    #
    # Split cell and L child - find leaves
    #
    node.split()
    l_child = node.get_child(0)
    l_child.split()
    leaves = forest.get_leaves()
    self.assertEqual(len(leaves), 3, 'Cell should have 3 leaves.')

    #
    # Depth first order
    #
    addresses_depth_first = [[0, 0, 0], [0, 0, 1], [0, 1]]
    leaves = forest.get_leaves(mode='depth-first')
    self.assertEqual([leaf.get_node_address() for leaf in leaves],
                     addresses_depth_first,
                     'Incorrect order, depth first search.')

    #
    # Breadth first order
    #
    addresses_breadth_first = [[0, 1], [0, 0, 0], [0, 0, 1]]
    leaves = node.get_leaves(mode='breadth-first')
    self.assertEqual([leaf.get_node_address() for leaf in leaves],
                     addresses_breadth_first,
                     'Incorrect order, breadth first search.')

    # Flag two leaves and restrict to the rooted subtree they span.
    node.get_child(0).get_child(0).mark('1')
    node.get_child(1).mark('1')
    node.make_rooted_subtree('1')
    leaves = node.get_leaves(subtree_flag='1')
    self.assertEqual(len(leaves), 2, \
                     'There should only be 2 flagged leaves')

    #
    # 2D
    #
    node = Tree(4)
    forest = Forest([node])

    #
    # Split cell and SW child - find leaves
    #
    node.split()
    sw_child = node.get_child(0)
    sw_child.split()
    leaves = node.get_leaves()
    self.assertEqual(len(leaves), 7, 'Node should have 7 leaves.')

    #
    # Nested traversal
    #
    leaves = node.get_leaves()
    self.assertEqual(leaves[0].get_node_address(), [0, 1], \
        'The first leaf in the nested enumeration should have address [0, 1]')

    leaves = node.get_leaves(mode='depth-first')
    self.assertEqual(leaves[0].get_node_address(), [0, 0, 0], \
        'First leaf in un-nested enumeration should be [0, 0, 0].')

    #
    # Merge SW child - find leaves
    #
    sw_child.delete_children()

    leaves = node.get_leaves()
    self.assertEqual(len(leaves), 4, 'Node should have 4 leaves.')

    #
    # Marked Leaves
    #
    node = Tree(4)
    node.mark(1)
    forest = Forest([node])
    self.assertIn(node, forest.get_leaves(flag=1), \
                  'Node should be a marked leaf node.')
    self.assertIn(node, forest.get_leaves(), \
                  'Node should be a leaf node.')

    node.split()
    sw_child = node.get_child(0)
    sw_child.split()
    sw_child.mark(1)
    self.assertEqual(node.get_leaves(subtree_flag=1), \
                     [sw_child], 'SW child should be only marked leaf')

    sw_child.remove()
    self.assertEqual(forest.get_leaves(subforest_flag=1), \
                     [node], 'node should be only marked leaf')

    #
    # Nested traversal
    #
    node = Tree(4)
    node.split()
    forest = Forest([node])
    for child in node.get_children():
        child.split()
    node.get_child(1).mark(1, recursive=True)
    node.get_child(3).mark(1)

    forest.root_subtrees(1)

    leaves = forest.get_leaves(subforest_flag=1)
    self.assertEqual(len(leaves), 7, 'This tree has 7 flagged LEAF nodes.')
    self.assertEqual(leaves[0], node.get_child(0),
                     'The first flagged leaf should be child 0.')
    self.assertEqual(leaves[3], node.get_child(1).get_child(0),
                     '4th flagged leaf should be child 0 of child 1.')