Ejemplo n.º 1
0
 def test_contains(self):
     """A node contains itself and its descendants, not its ancestors."""
     # A node always contains itself.
     top = Tree(n_children=4)
     self.assertTrue(top.contains(top))

     # Descend four generations along the first child.
     deepest = top
     for _ in range(4):
         deepest.split()
         deepest = deepest.get_child(0)

     # Ancestor contains descendant, but not the other way round.
     self.assertTrue(top.contains(deepest))
     self.assertFalse(deepest.contains(top))
Ejemplo n.º 2
0
 def test_record(self):
     """After forest.record(1), every traversed node is marked 1."""
     forest = Forest([Tree(2), Tree(2), Tree(2)])

     # Randomly split the forest's top-level trees a few times.
     for _ in range(5):
         for tree in forest.get_children():
             if np.random.rand() > 0.5:
                 tree.split()

     forest.record(1)
     for visited in forest.traverse():
         self.assertTrue(visited.is_marked(1))
Ejemplo n.º 3
0
    def test_constructor(self):
        """Check Tree construction requirements and basic layouts."""
        # Omitting n_children must raise.
        self.assertRaises(Exception, Tree)

        # A quadtree node has four child slots.
        quad = Tree(n_children=4)
        self.assertEqual(len(quad._children), 4, 'There should be 4 children.')

        # A tree may also be attached to a forest at construction time.
        # NOTE(review): nothing is asserted about this tree yet.
        owner = Forest()
        Tree(forest=owner, n_children=2)
Ejemplo n.º 4
0
    def test_get_children(self):
        """Forest children come back in order, each with its own arity."""
        # An empty forest has no children.
        forest = Forest()
        self.assertEqual(forest.get_children(), [])

        # A populated forest returns its trees in construction order.
        n_children = [1, 3, 10]
        forest = Forest([Tree(n) for n in n_children])
        children = forest.get_children()

        # Guard against a short list silently skipping the per-child checks.
        self.assertEqual(len(children), len(n_children))
        for i, child in enumerate(children):
            self.assertEqual(child.get_node_position(), i)
            self.assertEqual(child.n_children(), n_children[i])
Ejemplo n.º 5
0
    def test_set_node_type(self):
        """Invalid node-type changes must be rejected."""
        # Define a new node
        node = Tree(n_children=3)

        # Should complain about changing root to branch.
        # The type is passed as ONE argument: the former ``*('BRANCH')``
        # unpacked the string into six characters, so the call raised a
        # TypeError for the wrong reason and the test could never fail.
        self.assertRaises(Exception, node.set_node_type, 'BRANCH')

        # Split; a fresh child (a LEAF) may become neither ROOT nor BRANCH.
        node.split()
        child = node.get_child(0)
        self.assertRaises(Exception, child.set_node_type, 'ROOT')
        self.assertRaises(Exception, child.set_node_type, 'BRANCH')
Ejemplo n.º 6
0
 def test_delete_children(self):
     """delete_children removes one child by index, or all of them."""
     quad = Tree(n_children=4)
     quad.split()

     # Remove only the child at position 2; its slot becomes None.
     quad.delete_children(2)
     self.assertIsNone(quad.get_child(2))

     # Without an argument every remaining child is removed.
     quad.delete_children()
     self.assertFalse(quad.has_children())
Ejemplo n.º 7
0
    def test_get_child(self):
        """get_child(i) returns the very object stored in _children[i]."""
        parent = Tree(n_children=5)
        parent.split()

        # Grab the last child through the private list and flag it.
        direct = parent._children[4]
        direct.mark(1)

        # Fetch the same child through the public accessor.
        via_api = parent.get_child(4)

        # The flag is visible through the accessor and both handles match.
        self.assertTrue(via_api.is_marked(1), 'Child 4 should be marked 1.')
        self.assertEqual(via_api, direct, 'Children should be the same.')
Ejemplo n.º 8
0
 def test_get_parent(self):
     """get_parent(flag) returns the first ancestor marked with flag."""
     # (tree, child to split, child index to probe)
     cases = [(Tree(n_children=2), 0, 1), (Tree(n_children=4), 0, 3)]
     for root, split_pos, probe_pos in cases:
         root.mark(1)
         root.split()
         child = root.get_child(split_pos)
         child.split()
         self.assertEqual(root, root.get_child(probe_pos).get_parent(1),
                          'First marked ancestor should be node.')

         # Once the child is marked too, it becomes the nearest match.
         child.mark(1)
         self.assertEqual(child, child.get_child(probe_pos).get_parent(1),
                          'First marked ancestor should be child.')
Ejemplo n.º 9
0
    def test_has_trees(self):
        """has_children reports trees, optionally filtered by a flag."""
        # No trees in an empty forest.
        self.assertFalse(Forest().has_children())

        # Two unmarked trees: present, but none carries flag 0.
        binary, ternary = Tree(2), Tree(3)
        forest = Forest([binary, ternary])
        self.assertTrue(forest.has_children())
        self.assertFalse(forest.has_children(flag=0))

        # After marking one tree with 0, flag 0 is found but flag 1 is not.
        binary.mark(0)
        self.assertTrue(forest.has_children(flag=0))
        self.assertFalse(forest.has_children(flag=1))
Ejemplo n.º 10
0
    def test_has_parent(self):
        """has_parent detects any parent, or a parent carrying a flag."""
        for root in (Tree(n_children=2), Tree(n_children=4)):
            root.split()

            # Every child of a split node has some parent.
            for kid in root.get_children():
                self.assertTrue(kid.has_parent(),
                                'Nodes children should have a parent.')

            # Once the root carries flag 1, children see it (and not flag 2).
            root.mark(1)
            for kid in root.get_children():
                self.assertTrue(kid.has_parent(1),
                                'Children should have parent labeled 1.')
                self.assertFalse(kid.has_parent(2),
                                 'Children do not have parent labeled 2.')

            # The root itself has no parent, flagged or otherwise.
            self.assertFalse(root.has_parent(1),
                             'Root node should not have parents of type 1.')
            self.assertFalse(root.has_parent(),
                             'Root node should not have parents.')
Ejemplo n.º 11
0
 def test_get_root(self):
     """Every descendant reports the original node as its root."""
     top = Tree(n_children=3)
     current = top
     # Walk ten generations down the first child, checking at each level.
     for _ in range(10):
         current.split()
         current = current.get_child(0)
         self.assertEqual(current.get_root(), top,
                          'All children should have the same root node')
Ejemplo n.º 12
0
 def test_add_remove_tree(self):
     """Trees can be added to and removed from a forest."""
     forest = Forest()
     self.assertEqual(forest.n_children(), 0)

     # Adding a Tree increases the count by one.
     forest.add_tree(Tree(3))
     self.assertEqual(forest.n_children(), 1)

     # Adding a non-Tree (here an int) must raise.
     self.assertRaises(Exception, forest.add_tree, 1)

     # Removing position 0 empties the forest again.
     forest.remove_tree(0)
     self.assertEqual(forest.n_children(), 0)
Ejemplo n.º 13
0
    def test_depth(self):
        """Forest depth follows its most refined tree."""
        forest = Forest([Tree(2), Tree(2)])

        # Refine tree 0 three generations deep along its first child.
        cursor = forest.get_child(0)
        for _ in range(3):
            cursor.split()
            cursor = cursor.get_child(0)
        self.assertEqual(forest.depth(), 3)

        # With the refined tree removed, only an unsplit tree remains.
        forest.remove_tree(0)
        self.assertEqual(forest.depth(), 0)
Ejemplo n.º 14
0
    def test_get_depth(self):
        """get_depth counts the generations between a node and its root."""
        current = Tree(n_children=2)
        self.assertEqual(current.get_depth(), 0, 'ROOT node should have depth 0')

        # Walk ten generations down through child 1.
        for _ in range(10):
            current.split()
            current = current.get_child(1)

        self.assertEqual(current.get_depth(), 10, 'Node should have depth 10.')
        self.assertEqual(current.get_root().get_depth(), 0,
                         'ROOT node should have depth 0')
Ejemplo n.º 15
0
    def test_find_node(self):
        """find_node locates a node from a list address, or returns None."""
        start = Tree(n_children=4)
        path = [0, 2, 3, 0]

        # Nothing lives at this address yet.
        self.assertIsNone(start.find_node(path))

        # Build the path, then look it up again from the root.
        target = start
        for idx in path:
            target.split()
            target = target.get_child(idx)
        self.assertEqual(target, target.get_root().find_node(path))
Ejemplo n.º 16
0
    def test_is_regular(self):
        """is_regular reflects the ``regular`` constructor option."""
        irregular = Tree(regular=False, n_children=2)
        self.assertFalse(irregular.is_regular(), 'Node is not regular.')

        # Without the option, the tree is regular.
        regular = Tree(n_children=3)
        self.assertTrue(regular.is_regular(), 'Node should be regular.')
Ejemplo n.º 17
0
    def test_tree_depth(self):
        """tree_depth is the same from any node of the tree."""
        current = Tree(n_children=2)
        self.assertEqual(current.tree_depth(), 0, 'Tree should have depth 0')

        # Grow the tree ten generations deep through child 1.
        for _ in range(10):
            current.split()
            current = current.get_child(1)

        # The deepest node and the root agree on the tree's depth.
        deepest_view = current.tree_depth()
        self.assertEqual(deepest_view, 10, 'Tree should have depth 10.')
        root_view = current.get_root().tree_depth()
        self.assertEqual(root_view, 10, 'Tree should have depth 10.')
Ejemplo n.º 18
0
 def test_remove(self):
     """A child that removes itself leaves an empty slot in its parent."""
     parent = Tree(n_children=4)
     parent.split()
     parent.get_child(2).remove()
     self.assertIsNone(parent.get_child(2))
Ejemplo n.º 19
0
    def test_find_node(self):
        """Forest.find_node resolves addresses whose first entry is a tree."""
        left, right = Tree(2), Tree(3)
        forest = Forest([left, right])

        # Refine tree 0 and its first child, then pick grandchild [0, 0, 1].
        left.split()
        first = left.get_child(0)
        first.split()
        target = first.get_child(1)

        self.assertEqual(forest.find_node([0, 0, 1]), target)
        # An address that matches nothing yields None.
        self.assertIsNone(forest.find_node([4, 5, 6]))
Ejemplo n.º 20
0
    def test_in_forest(self):
        """in_forest tracks membership through construction, add and remove."""
        # A tree passed to the Forest constructor is a member.
        member = Tree(n_children=2)
        Forest([member])
        self.assertTrue(member.in_forest(), 'Node should be in forest')

        # A fresh tree is outside any forest until it is added.
        loner = Tree(n_children=2)
        forest = Forest()
        self.assertFalse(loner.in_forest(), 'Node should not be in a forest.')

        forest.add_tree(loner)
        self.assertTrue(loner.in_forest(), 'Node should be in a forest.')

        # Removing it by position detaches it again.
        forest.remove_tree(loner.get_node_position())
        self.assertFalse(loner.in_forest(),
                         'Node should no longer be in the forest.')
Ejemplo n.º 21
0
    def test_get_node_type(self):
        """Node type reflects position in the tree: ROOT, LEAF or BRANCH."""
        root_msg = 'Output "node_type" should be "ROOT".'
        root = Tree(n_children=2)
        self.assertEqual(root.get_node_type(), 'ROOT', root_msg)

        # After a split the root stays ROOT and its children are LEAFs.
        root.split()
        leaf = root.get_child(0)
        self.assertEqual(root.get_node_type(), 'ROOT', root_msg)
        self.assertEqual(leaf.get_node_type(), 'LEAF',
                         'Output "node_type" should be "LEAF".')

        # Splitting that non-root child turns it into a BRANCH.
        leaf.split()
        self.assertEqual(leaf.get_node_type(), 'BRANCH',
                         'Output "node_type" should be "BRANCH".')
Ejemplo n.º 22
0
    def test_n_children(self):
        """Regular trees pass their number of children on to their children."""
        for arity in range(10):
            parent = Tree(n_children=arity)
            self.assertEqual(parent.n_children(), arity,
                             'Number of children incorrect')

            # Children created by split inherit the same arity.
            parent.split()
            for kid in parent.get_children():
                self.assertEqual(kid.n_children(), arity,
                                 'Children have incorrect number of children.')
Ejemplo n.º 23
0
    def test_constructor(self):
        """Forest accepts only root Tree objects."""
        binary = Tree(2)

        # Non-Tree entries are rejected.
        self.assertRaises(Exception, Forest, trees=[binary, 0])

        # Non-root Trees are rejected as well.
        binary.split()
        self.assertRaises(Exception, Forest,
                          trees=[binary, binary.get_child(0)])

        # Two root trees form a valid forest.
        forest = Forest([binary, Tree(4)])
        self.assertEqual(len(forest._trees), 2)
Ejemplo n.º 24
0
    def test_refine(self):
        """Exercise Forest.refine with flags, subforests and new labels."""

        def n_nodes(f, *args):
            # Number of nodes yielded by f.traverse(*args).
            return sum(1 for _ in f.traverse(*args))

        # =====================================================================
        # Simple refinement
        # =====================================================================
        # Two binary trees refined everywhere: 2 roots + 4 children.
        forest = Forest([Tree(2), Tree(2)])
        forest.refine()
        self.assertEqual(n_nodes(forest), 6)

        # =====================================================================
        # Refinement flag
        # =====================================================================
        # Flagging the second tree changes nothing: it is no longer a leaf.
        forest.get_child(1).mark(1)
        forest.refine(refinement_flag=1)
        self.assertEqual(n_nodes(forest), 6)

        # Flagging one of its leaves splits that leaf into two.
        forest.get_child(1).get_child(1).mark(1)
        forest.refine(refinement_flag=1)
        self.assertEqual(n_nodes(forest), 8)

        # =====================================================================
        # Refinement of a subforest
        # =====================================================================
        forest = Forest([Tree(2), Tree(2)])
        forest.refine()

        # Flag node [0, 0] and root a subforest on flag 1.
        forest.get_child(0).get_child(0).mark(1)
        forest.root_subtrees(1)
        self.assertEqual(n_nodes(forest, 1), 4)

        forest.refine(subforest_flag=1)
        self.assertEqual(n_nodes(forest, 1), 10)

        forest.coarsen(subforest_flag=1)
        self.assertEqual(n_nodes(forest, 1), 4)

        # Restrict the subforest refinement to nodes flagged 2.
        forest.get_child(1).mark(2)
        forest.refine(subforest_flag=1, refinement_flag=2)
        self.assertEqual(n_nodes(forest, 1), 6)

        # =====================================================================
        # Refinement with a new label
        # =====================================================================
        forest = Forest([Tree(2), Tree(2)])
        forest.refine()

        forest.get_child(0).get_child(0).mark(1)
        forest.root_subtrees(1)
        self.assertEqual(n_nodes(forest, 1), 4)

        # The refined subforest is recorded under label 4 ...
        forest.refine(subforest_flag=1, new_label=4)
        self.assertEqual(n_nodes(forest, 4), 10)
        # ... while the original subforest (label 1) keeps its 4 nodes.
        self.assertEqual(n_nodes(forest, 1), 4)

        # Refine only nodes flagged 3, storing the result under label 5.
        forest.get_child(1).mark(3)
        forest.refine(subforest_flag=1, refinement_flag=3, new_label=5)
        self.assertEqual(n_nodes(forest, 5), 6)
Ejemplo n.º 25
0
    def test_traverse(self):
        """Traversal visits nodes in breadth-first or depth-first order.

        Each check compares the complete list of visited addresses, so a
        traversal that yields too few (or too many) nodes fails. The former
        per-index loops passed silently when nodes were missing.
        """
        #
        # Binary Tree: split the root and child [0, 0], then drop [0, 0, 1].
        #
        node = Tree(2)
        forest = Forest([node])

        node.split()
        node.get_child(0).split()
        node.get_child(0).get_child(1).remove()
        addresses = {
            'breadth-first': [[0], [0, 0], [0, 1], [0, 0, 0]],
            'depth-first': [[0], [0, 0], [0, 0, 0], [0, 1]]
        }

        for mode in ['depth-first', 'breadth-first']:
            visited = [n.get_node_address()
                       for n in forest.traverse(mode=mode)]
            self.assertEqual(visited, addresses[mode])

        #
        # QuadTree: split the root and child [0, 1], then drop [0, 1, 2].
        #
        node = Tree(4)
        forest = Forest([node])
        node.split()
        node.get_child(1).split()
        node.get_child(1).get_child(2).remove()
        addresses = [[0], [0, 0], [0, 1], [0, 2], [0, 3], [0, 1, 0], [0, 1, 1],
                     [0, 1, 3]]
        visited = [n.get_node_address()
                   for n in node.traverse(mode='breadth-first')]
        self.assertEqual(visited, addresses, 'Incorrect address.')

        #
        # Forest with one quadtree and one bitree
        #
        bi = Tree(2)
        quad = Tree(4)
        forest = Forest([bi, quad])
        bi.split()
        bi.get_child(0).split()
        quad.split()
        addresses = {
            'breadth-first': [[0], [1], [0, 0], [0, 1], [1, 0], [1, 1], [1, 2],
                              [1, 3], [0, 0, 0], [0, 0, 1]],
            'depth-first': [[0], [0, 0], [0, 0, 0], [0, 0, 1], [0, 1], [1],
                            [1, 0], [1, 1], [1, 2], [1, 3]]
        }

        for mode in ['depth-first', 'breadth-first']:
            visited = [n.get_node_address()
                       for n in forest.traverse(mode=mode)]
            self.assertEqual(visited, addresses[mode])
Ejemplo n.º 26
0
    def test_coarsen(self):
        """Exercise Forest.coarsen with flags, subforests and new labels."""

        def n_nodes(f, *args, **kwargs):
            # Number of nodes yielded by f.traverse(*args, **kwargs).
            return sum(1 for _ in f.traverse(*args, **kwargs))

        # =====================================================================
        # Simple Coarsening
        # =====================================================================
        # Define a new forest with two quadtrees
        forest = Forest([Tree(4), Tree(4)])

        # Coarsening unrefined trees does nothing: still just the 2 roots.
        forest.coarsen()
        self.assertEqual(n_nodes(forest), 2)

        # Refine and coarsen again: back to the 2 roots.
        forest.refine()
        forest.coarsen()
        self.assertEqual(n_nodes(forest), 2)

        # =====================================================================
        # Coarsening Flag
        # =====================================================================
        # Refine: 2 roots + 2 x 4 children = 10 nodes.
        forest.refine()
        self.assertEqual(n_nodes(forest), 10)

        # Flag tree 0 for coarsening.
        forest.get_child(0).mark(1)
        forest.coarsen(coarsening_flag=1)

        # Tree 0 has no children left, tree 1 keeps its 4: 6 nodes in total.
        self.assertEqual(n_nodes(forest), 6)
        self.assertFalse(forest.get_child(0).has_children())
        self.assertTrue(forest.get_child(1).has_children())

        # The coarsening flag has been cleared: nothing is marked 1.
        self.assertFalse(any(
            child.is_marked(1) for child in forest.traverse()))

        # =====================================================================
        # Coarsening with subforests
        # =====================================================================
        forest = Forest([Tree(2), Tree(2)])
        forest.refine()

        # Make a subforest [0], [0, 0], [0, 1], [1]
        forest.get_child(0).get_child(0).mark(2)
        forest.root_subtrees(2)
        self.assertEqual(n_nodes(forest, flag=2), 4)

        # Coarsen the subforest: it now contains only [0] and [1].
        forest.coarsen(subforest_flag=2)
        self.assertEqual(n_nodes(forest, flag=2), 2)

        # The forest itself still has 6 nodes: 0, 1, 00, 01, 10, 11.
        self.assertEqual(n_nodes(forest), 6)

        # Rebuild the subforest [0], [0, 0], [0, 1], [1].
        forest.get_child(0).get_child(0).mark(2)
        forest.root_subtrees(2)
        self.assertEqual(n_nodes(forest, 2), 4)

        # A coarsening flag on a node outside the subforest changes nothing.
        forest.get_child(1).get_child(1).mark(1)
        forest.coarsen(subforest_flag=2, coarsening_flag=1)
        self.assertEqual(n_nodes(forest, 2), 4)

        # A coarsening flag on subforest node [0] collapses its children,
        # leaving 2 subforest nodes.
        forest.get_child(0).mark(1)
        forest.coarsen(subforest_flag=2, coarsening_flag=1)
        self.assertEqual(n_nodes(forest, 2), 2)

        # Make sure the coarsening flag is deleted.
        self.assertFalse(forest.get_child(0).is_marked(1))

        # =====================================================================
        # Coarsening with new_label
        # =====================================================================
        # Subforest is: [0], [0, 0], [0, 1], [1]
        forest.get_child(0).get_child(0).mark(2)
        forest.root_subtrees(2)

        # Mark node [0] with the coarsening flag
        forest.get_child(0).mark(1)

        # Coarsen subforest 2 and store the result under new_label 3
        forest.coarsen(subforest_flag=2,
                       coarsening_flag=1,
                       new_label=3,
                       debug=True)

        # Subforest 2 still has the same 4 nodes ...
        self.assertEqual(n_nodes(forest, 2), 4)

        # ... while the new subforest 3 has only 2.
        self.assertEqual(n_nodes(forest, 3), 2)

        #
        # Now without a subforest
        #
        # The forest still has 6 nodes
        self.assertEqual(n_nodes(forest), 6)

        # Mark node [0] with the coarsening flag and coarsen into label 4
        forest.get_child(0).mark(1)
        forest.coarsen(coarsening_flag=1, new_label=4)

        # The forest itself is unchanged: still 6 nodes
        self.assertEqual(n_nodes(forest), 6)

        # The coarsened copy labelled 4 has 4 nodes
        self.assertEqual(n_nodes(forest, 4), 4)
Ejemplo n.º 27
0
from mesh import QuadMesh, Vertex, HalfEdge, QuadCell, Mesh1D, Interval, Tree
from fem import DofHandler, QuadFE, GaussRule, Function
from mesh import convert_to_array
from plot import Plot
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

# Label tree for submeshes: the root's address flags the coarse mesh.
mtags = Tree(regular=False)
mesh = Mesh1D(resolution=(1, ))

flag = tuple(mtags.get_node_address())
mesh.cells.record(flag)

# A child label addresses the refined submesh built from the flagged one.
mtags.add_child()
new_flag = tuple(mtags.get_child(0).get_node_address())
mesh.cells.refine(subforest_flag=flag, new_label=new_flag)

# Print the leaves of the coarse submesh, a separator, then the refined one.
for position, label in enumerate((flag, new_flag)):
    if position:
        print('==' * 20)
    for leaf in mesh.cells.get_leaves(subforest_flag=label):
        leaf.info()

# Distribute Q2 dofs on the mesh and compute their vertices.
element = QuadFE(1, 'Q2')
dofhandler = DofHandler(mesh, element)
dofhandler.distribute_dofs()
dofhandler.set_dof_vertices()
Ejemplo n.º 28
0
 def test_child(self):
     """Forest.get_child indexes trees; an empty forest raises."""
     self.assertRaises(Exception, Forest().get_child, 0)

     pair = Forest([Tree(1), Tree(4)])
     self.assertEqual(pair.get_child(1).n_children(), 4)
Ejemplo n.º 29
0
    def test_constructor(self):
        """Kernels built from explicit, nodal, constant and sampled functions."""
        # =====================================================================
        # Test 1D
        # =====================================================================

        #
        # Kernel consists of a single explicit Function:
        #
        f1 = lambda x: x + 2
        f = Explicit(f1, dim=1)
        k = Kernel(f)
        x = np.linspace(0, 1, 100)
        n_points = len(x)

        # Check that it evaluates correctly.
        self.assertTrue(np.allclose(f1(x), k.eval(x).ravel()))

        # Check shape of kernel
        self.assertEqual(k.eval(x).shape, (n_points, 1))

        #
        # Kernel consists of a combination of two explicit functions
        #
        f1 = Explicit(lambda x: x + 2, dim=1)
        f2 = Explicit(lambda x: x**2 + 1, dim=1)
        F = lambda f1, f2: f1**2 + f2
        f_t = lambda x: (x + 2)**2 + x**2 + 1
        k = Kernel([f1, f2], F=F)

        # Check evaluation
        self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

        # Check shape
        self.assertEqual(k.eval(x).shape, (n_points, 1))

        #
        # Same thing as above, but with nodal functions
        #
        mesh = Mesh1D(resolution=(1,))
        Q1 = QuadFE(1, 'Q1')
        Q2 = QuadFE(1, 'Q2')

        dQ1 = DofHandler(mesh, Q1)
        dQ2 = DofHandler(mesh, Q2)

        # Distribute dofs (a plain loop: the calls run only for side effects)
        for dQ in (dQ1, dQ2):
            dQ.distribute_dofs()

        # Basis functions
        phi1 = Basis(dQ1, 'u')
        phi2 = Basis(dQ2, 'u')

        f1 = Nodal(lambda x: x + 2, basis=phi1)
        f2 = Nodal(lambda x: x**2 + 1, basis=phi2)
        k = Kernel([f1, f2], F=F)

        # Check evaluation
        self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

        #
        # Replace f2 above with its derivative
        #
        k = Kernel([f1, f2], derivatives=['f', 'fx'], F=F)
        f_t = lambda x: (x + 2)**2 + 2*x

        # Check derivative evaluation F = F(f1, df2_dx)
        self.assertTrue(np.allclose(f_t(x), k.eval(x).ravel()))

        #
        # Sampling
        #
        one = Constant(1)
        f1 = Explicit(lambda x: x**2 + 1, dim=1)

        # Sampled function
        a = np.linspace(0, 1, 11)
        n_samples = len(a)

        # Define Dofhandler
        dh = DofHandler(mesh, Q2)
        dh.distribute_dofs()
        dh.set_dof_vertices()
        xv = dh.get_dof_vertices()
        n_dofs = dh.n_dofs()

        phi = Basis(dh, 'u')

        # Evaluate parameterized function at mesh dof vertices
        f2_m = np.empty((n_dofs, n_samples))
        for i in range(n_samples):
            f2_m[:, i] = xv.ravel() + a[i]*xv.ravel()**2
        f2 = Nodal(data=f2_m, basis=phi)

        # Define kernel
        F = lambda f1, f2, one: f1 + f2 + one
        k = Kernel([f1, f2, one], F=F)

        # Evaluate on a fine mesh
        x = np.linspace(0, 1, 100)
        n_points = len(x)
        self.assertEqual(k.eval(x).shape, (n_points, n_samples))
        for i in range(n_samples):
            # Check evaluation
            self.assertTrue(np.allclose(k.eval(x)[:, i],
                                        f1.eval(x)[:, i] + x + a[i]*x**2 + 1))

        #
        # Sample multiple constant functions
        #
        f1 = Constant(data=a)
        f2 = Explicit(lambda x: 1 + x**2, dim=1)
        f3 = Nodal(data=f2_m[:, -1], basis=phi)

        F = lambda f1, f2, f3: f1 + f2 + f3
        k = Kernel([f1, f2, f3], F=F)

        x = np.linspace(0, 1, 100)
        for i in range(n_samples):
            self.assertTrue(np.allclose(k.eval(x)[:, i],
                                        a[i] + f2.eval(x)[:, i] + f3.eval(x)[:, i]))

        #
        # Submeshes
        #
        # NOTE(review): this section is unfinished -- mesh_labels, kernel and
        # x are built below but nothing is asserted about them yet.
        mesh_labels = Tree(regular=False)

        mesh = Mesh1D(resolution=(1,))
        Q1 = QuadFE(1, 'Q1')
        Q2 = QuadFE(1, 'Q2')

        dQ1 = DofHandler(mesh, Q1)
        dQ2 = DofHandler(mesh, Q2)

        # Distribute dofs
        for dQ in (dQ1, dQ2):
            dQ.distribute_dofs()

        # Basis
        p1 = Basis(dQ1)
        p2 = Basis(dQ2)

        f1 = Nodal(lambda x: x, basis=p1)
        f2 = Nodal(lambda x: -2 + 2*x**2, basis=p2)
        one = Constant(np.array([1, 2]))

        F = lambda f1, f2, one: 2*f1**2 + f2 + one

        I = mesh.cells.get_child(0)

        kernel = Kernel([f1, f2, one], F=F)

        rule1D = GaussRule(5, shape='interval')
        x = I.reference_map(rule1D.nodes())
Ejemplo n.º 30
0
    def test_get_leaves(self):
        """get_leaves: traversal order, flags, subtrees and subforests.

        Order checks compare the whole address list, so missing or extra
        leaves fail the test. The former per-index loops ignored missing ones.
        """
        #
        # 1D
        #
        node = Tree(2)
        forest = Forest([node])
        leaves = forest.get_leaves()

        # Only a ROOT node, it should be the only LEAF
        self.assertEqual(leaves, [node], 'Cell should be its own leaf.')

        #
        # Split cell and L child - find leaves
        #
        node.split()
        l_child = node.get_child(0)
        l_child.split()
        leaves = forest.get_leaves()
        self.assertEqual(len(leaves), 3, 'Cell should have 3 leaves.')

        #
        # Depth first order
        #
        addresses_depth_first = [[0, 0, 0], [0, 0, 1], [0, 1]]
        leaves = forest.get_leaves(mode='depth-first')
        self.assertEqual([leaf.get_node_address() for leaf in leaves],
                         addresses_depth_first,
                         'Incorrect order, depth first search.')

        #
        # Breadth first order
        #
        addresses_breadth_first = [[0, 1], [0, 0, 0], [0, 0, 1]]
        leaves = node.get_leaves(mode='breadth-first')
        self.assertEqual([leaf.get_node_address() for leaf in leaves],
                         addresses_breadth_first,
                         'Incorrect order, breadth first search.')

        # Leaves restricted to a rooted subtree flagged '1'
        node.get_child(0).get_child(0).mark('1')
        node.get_child(1).mark('1')
        node.make_rooted_subtree('1')
        leaves = node.get_leaves(subtree_flag='1')
        self.assertEqual(len(leaves), 2,
                         'There should only be 2 flagged leaves')

        #
        # 2D
        #
        node = Tree(4)
        forest = Forest([node])

        #
        # Split cell and SW child - find leaves
        #
        node.split()
        sw_child = node.get_child(0)
        sw_child.split()
        leaves = node.get_leaves()
        self.assertEqual(len(leaves), 7, 'Node should have 7 leaves.')

        #
        # Nested traversal
        #
        leaves = node.get_leaves()
        self.assertEqual(leaves[0].get_node_address(), [0, 1],
            'The first leaf in the nested enumeration should have address [0, 1]')

        leaves = node.get_leaves(mode='depth-first')
        self.assertEqual(leaves[0].get_node_address(), [0, 0, 0],
                         'First leaf in un-nested enumeration should be [0, 0, 0].')

        #
        # Merge SW child - find leaves
        #
        sw_child.delete_children()

        leaves = node.get_leaves()
        self.assertEqual(len(leaves), 4, 'Node should have 4 leaves.')

        #
        # Marked Leaves
        #
        node = Tree(4)
        node.mark(1)
        forest = Forest([node])
        self.assertTrue(node in forest.get_leaves(flag=1),
                        'Node should be a marked leaf node.')
        self.assertTrue(node in forest.get_leaves(),
                        'Node should be a leaf node.')

        node.split()
        sw_child = node.get_child(0)
        sw_child.split()
        sw_child.mark(1)
        self.assertEqual(node.get_leaves(subtree_flag=1),
                         [sw_child], 'SW child should be only marked leaf')

        sw_child.remove()
        self.assertEqual(forest.get_leaves(subforest_flag=1),
                         [node], 'node should be only marked leaf')

        #
        # Nested traversal
        #
        node = Tree(4)
        node.split()
        forest = Forest([node])
        for child in node.get_children():
            child.split()

        node.get_child(1).mark(1, recursive=True)
        node.get_child(3).mark(1)
        forest.root_subtrees(1)
        leaves = forest.get_leaves(subforest_flag=1)
        self.assertEqual(len(leaves), 7, 'This tree has 7 flagged LEAF nodes.')
        # Child 0 is the SW child (see the sw_child naming above).
        self.assertEqual(leaves[0], node.get_child(0),
                         'The first leaf should be the SW child.')
        self.assertEqual(leaves[3],
                         node.get_child(1).get_child(0),
                         '4th flagged leaf should be child 0 of child 1.')