예제 #1
0
class TestNodeVisitors(unittest.TestCase):
    """Exercise the node visitors on a small hand-built node graph."""

    def setUp(self):
        """Create a root node with five children; `par_a` gets `par_b` as child."""
        self.root = RootNode()
        self.empty = Empty(name="empty")
        self.par_a = Parameter(2, name="par_a")
        self.par_b = Parameter(3, name="par_b")
        self.par_c = Parameter(4, name="par_c")
        self.func_1 = self.par_a + self.par_b
        self.func_2 = self.par_a * self.par_c
        self.func_3 = Function(lambda a, b: a + b)
        self.alias = Alias(self.par_a, name="alias")
        self.par_a.set_children([self.par_b])

        root_children = [
            self.empty,
            self.func_1,
            self.func_2,
            self.func_3,
            self.alias,
        ]
        self.root.set_children(root_children)

    # -- NodeChildrenPrinter

    def test_node_children_printer(self):
        """Smoke test: printing the descendants of the root must not raise."""
        self.root.print_descendants()

    # -- NodeCycleChecker

    def test_node_cycle_checker(self):
        """The checker passes on the initial graph and raises once a loop exists."""
        # graph as built in `setUp`: the check runs through without raising
        NodeCycleChecker(self.par_a).run()

        # `par_b` is already a child of `par_a`; adding the reverse link
        # makes each node a child of the other
        self.par_b.set_children([self.par_a])

        with self.assertRaises(ValueError):
            NodeCycleChecker(self.par_a).run()
예제 #2
0
    def test_node_add_parent(self):
        """Calling `add_parent` directly must raise `NodeException`."""
        child = Parameter(3, name='par')
        would_be_parent = Parameter(2, name='par_2')

        # parents may not be attached by hand
        with self.assertRaises(NodeException):
            child.add_parent(would_be_parent)
예제 #3
0
    def test_function_replace(self):
        """A replaced function node keeps the parents, children and parameters."""
        par = Parameter(3, name='par')
        selector_parameters = (
            Parameter('bla', name='first'),
            par,
            Parameter('blup', name='third'),
        )
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=selector_parameters,
        )
        # never read again: the product only exists so that `selector`
        # takes part in a larger expression (presumably this is what gives
        # it a parent to compare below)
        context = Parameter(2, name='pre_factor') * selector

        # snapshot of the relations of the old node
        old_children = selector.get_children()
        old_parameters = selector.parameters
        old_parents = selector.get_parents()

        # swap in a node that picks the third argument instead of the second
        selector_new = Function(lambda *args: args[2], name='selector_new')
        selector.replace(selector_new, other_children=False)

        # relations of the replacement node
        new_children = selector_new.get_children()
        new_parents = selector_new.get_parents()
        new_parameters = selector_new.parameters

        self.assertEqual(old_children, new_children)
        self.assertEqual(old_parents, new_parents)
        self.assertEqual(old_parameters, new_parameters)

        self.assertEqual(selector_new.value, 'blup')
예제 #4
0
    def test_fallback(self):
        """`Fallback` yields 'DIV0' only for the configured exception type."""
        par_a = Parameter(14)
        par_b = Parameter(2)

        quotient = par_a / par_b
        div_a_b = Fallback((quotient, 'DIV0'),
                           exception_type=ZeroDivisionError)

        # non-zero denominator: the first alternative is used
        actual = div_a_b.value
        expected = par_a.value / par_b.value
        self.assertEqual(actual, expected)

        # zero denominator: division raises the configured exception type,
        # so the fallback alternative is used
        par_b.value = 0
        self.assertEqual(div_a_b.value, 'DIV0')

        # configured type differs from the raised one: the error propagates
        with self.assertRaises(ZeroDivisionError):
            unguarded = Fallback((par_a / par_b, 'DIV0'),
                                 exception_type=TypeError)
            unguarded.value
예제 #5
0
    def test_function_value_update_frozen(self):
        """A frozen function keeps its value until it is unfrozen again."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        initial = func_a_b.value
        self.assertEqual(initial, func(par_a.value, par_b.value))

        # while frozen, a parameter change must not show up in the value
        func_a_b.freeze()
        par_a.value = -3

        frozen = func_a_b.value
        self.assertEqual(frozen, func(3, 7))

        # after unfreezing, the value follows the current parameters
        func_a_b.unfreeze()
        thawed = func_a_b.value
        self.assertEqual(thawed, func(par_a.value, par_b.value))
예제 #6
0
    def setUp(self):
        """Create two parameters, a function node over both and an empty node."""
        self.par_a = Parameter(3)
        self.par_b = Parameter(7)

        summands = (self.par_a, self.par_b)
        self.func_sum_a_b = Function(TestNodes.sum_function,
                                     parameters=summands)

        self.empty_par = Empty()
예제 #7
0
    def test_node_get_children_get_parents(self):
        """Expressions list their operands as children and vice versa."""
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2

        expected_parents = [par_lc_1, par_lc_2]
        expected_children = [par_1, par_2]

        # both operands are used by both expressions
        for operand in (par_1, par_2):
            self.assertEqual(operand.get_parents(), expected_parents)
        # both expressions are built from both operands
        for expression in (par_lc_1, par_lc_2):
            self.assertEqual(expression.get_children(), expected_children)
예제 #8
0
    def test_add_existing_ignore(self):
        """With `existing_behavior='ignore'` the first node of a name stays."""
        first = Parameter('my_original_value', name='par')
        self._nexus.add(first)

        # second node under the same name
        second = Parameter('my_new_value', name='par')
        self._nexus.add(second, existing_behavior='ignore')

        registered = self._nexus.get('par')
        self.assertIs(registered, first)
예제 #9
0
    def test_add_expression_compare_value(self):
        """A named sum expression added to the nexus evaluates to 3 + 5."""
        a = Parameter(3)
        b = Parameter(5)

        total = a + b
        total.name = 'sum'

        for node in (a, b, total):
            self._nexus.add(node)

        registered = self._nexus.get('sum')
        self.assertEqual(registered.value, 8)
예제 #10
0
    def test_add_function_custom_par_names(self):
        """`par_names` selects the nodes used as function arguments."""
        x = self._nexus.add(Parameter(7, name='x'))
        y = self._nexus.add(Parameter(8, name='y'))

        def my_func(a, b):
            return 2 * a + b

        # argument names `a`, `b` differ from the node names `x`, `y`
        names = ['x', 'y']
        func_node = self._nexus.add_function(func=my_func, par_names=names)

        # the node value equals a direct call with the parameter values
        actual = func_node.value
        self.assertEqual(actual, my_func(x.value, y.value))
예제 #11
0
    def test_function_value_lambda(self):
        """A function node wrapping a lambda evaluates like the lambda itself."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        arguments = (par_a, par_b)
        func_a_b = Function(
            lambda a, b: a * 10 + b,
            name='_lambda',
            parameters=arguments,
        )

        actual = func_a_b.value
        expected = func_a_b.func(par_a.value, par_b.value)
        self.assertEqual(actual, expected)
예제 #12
0
    def test_add_function_varargs(self):
        """A `*args` function can be added when `par_names` are given."""
        x = self._nexus.add(Parameter(7, name='x'))
        y = self._nexus.add(Parameter(8, name='y'))

        def my_func(*args):
            return sum(args)

        names = ['x', 'y']
        func_node = self._nexus.add_function(func=my_func, par_names=names)

        # the node value equals a direct call with the parameter values
        actual = func_node.value
        self.assertEqual(actual, my_func(x.value, y.value))
예제 #13
0
    def test_node_iter_children_iter_parents(self):
        """The iterators yield the same relations in the same order."""
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2

        # both operands are used by both expressions
        for operand in (par_1, par_2):
            self.assertEqual(list(operand.iter_parents()),
                             [par_lc_1, par_lc_2])
        # both expressions are built from both operands
        for expression in (par_lc_1, par_lc_2):
            self.assertEqual(list(expression.iter_children()),
                             [par_1, par_2])
예제 #14
0
    def test_node_remove_child(self):
        """Removing one of two children leaves only the other one."""
        node = Parameter(3, name='par')
        removed = Parameter(1, name='child_1')
        kept = Parameter(2, name='child_2')

        node.set_children([removed, kept])
        node.remove_child(removed)

        remaining = node.get_children()
        self.assertEqual(remaining, [kept])
예제 #15
0
    def test_node_remove_parent(self):
        """Removing one of two parents leaves only the other one."""
        node = Parameter(3, name='par')
        removed = Parameter(1, name='parent_1')
        kept = Parameter(2, name='parent_2')

        node.set_parents([removed, kept])
        node.remove_parent(removed)

        remaining = node.get_parents()
        self.assertEqual(remaining, [kept])
예제 #16
0
    def setUp(self):
        """Create the shared nodes, the two accumulators and the sample values."""
        self.par_a = Parameter(3)
        self.par_b = Parameter(7)

        summands = (self.par_a, self.par_b)
        self.func_sum_a_b = Function(TestNodes.sum_function,
                                     parameters=summands)

        self.empty_par = Empty()

        # both accumulators start from zero for every test
        self.counter = 0
        self.sum = 0

        self.array_content = [0.1, 2.3, 4.5, 6.7, 8.9]
예제 #17
0
 def setUp(self):
     """Create a root node with five children; `par_a` gets `par_b` as child."""
     self.root = RootNode()
     self.empty = Empty(name="empty")
     self.par_a = Parameter(2, name="par_a")
     self.par_b = Parameter(3, name="par_b")
     self.par_c = Parameter(4, name="par_c")
     self.func_1 = self.par_a + self.par_b
     self.func_2 = self.par_a * self.par_c
     self.func_3 = Function(lambda a, b: a + b)
     self.alias = Alias(self.par_a, name="alias")
     self.par_a.set_children([self.par_b])

     root_children = [
         self.empty,
         self.func_1,
         self.func_2,
         self.func_3,
         self.alias,
     ]
     self.root.set_children(root_children)
예제 #18
0
    def test_add_expression_test_dependents_added(self):
        """Adding an expression also registers the parameters it is built from.

        Only the expression itself is passed to `add`; the two operand
        parameters must afterwards be retrievable from the nexus under
        their own names.
        """
        a = Parameter(3)
        b = Parameter(5)

        expr = (a - b) / (a + b)
        expr.name = 'asymm'

        # `a` and `b` are deliberately NOT added explicitly
        self._nexus.add(expr)

        self.assertIs(self._nexus.get(a.name), a)
        self.assertIs(self._nexus.get(b.name), b)
예제 #19
0
    def test_function_auto_parameters(self):
        """Parameters assigned after construction are reported back in order."""
        def func(a, b):
            return a * 10 + b

        # no parameters passed at construction time
        func_a_b = Function(func)

        par_a = Parameter(3)
        par_b = Parameter(7)

        func_a_b.parameters = [par_a, par_b]

        assigned = func_a_b.parameters
        self.assertEqual(assigned, [par_a, par_b])
예제 #20
0
    def test_nodes_parametric_constructor(self):
        """Smoke test: each node type can be constructed and passed to `str`."""
        parameter = Parameter(None)
        str(parameter)

        alias = Alias(Parameter(None))
        str(alias)

        function = Function(lambda x: x, name='_lambda')
        str(function)

        # single-alternative fallback wrapping the function node above
        fallback = Fallback((function, ))
        str(fallback)

        empty_tuple = Tuple([])
        str(empty_tuple)
예제 #21
0
    def test_add_dependency(self):
        """A function only goes stale on `y` once the dependency is declared."""
        def test_func(x=3):
            return 2 * x

        func_node = self._nexus.add_function(func=test_func, par_names=['x'])

        y = self._nexus.add(Parameter(4, name='y'))

        self.assertEqual(func_node.value, 6)
        self.assertEqual(func_node.stale, False)

        # no dependency declared yet: updating 'y' leaves 'test_func' fresh
        y.value = 8
        self.assertEqual(func_node.stale, False)

        self._nexus.add_dependency('test_func', depends_on='y')

        # dependency declared: updating 'y' marks 'test_func' as stale,
        # while its value (a function of 'x' only) is unchanged
        y.value = 23
        self.assertEqual(func_node.stale, True)
        self.assertEqual(func_node.value, 6)

        # invalid requests are rejected with a `ValueError`
        invalid_requests = (
            ('DEADBEEF', 'y'),
            ('y', ['a', 'b']),
        )
        for node_name, dependencies in invalid_requests:
            with self.assertRaises(ValueError):
                self._nexus.add_dependency(node_name, depends_on=dependencies)
예제 #22
0
 def test_tuple_setitem(self):
     """Items can be replaced by a node or by a plain value."""
     node = Tuple((14, 2))

     # first slot: a parameter node, second slot: a bare literal
     node[0] = Parameter(15)
     node[1] = 1

     self.assertEqual(node.value, (15, 1))
     # both slots are readable as nodes afterwards
     self.assertEqual(node[0].value, 15)
     self.assertEqual(node[1].value, 1)
예제 #23
0
 def test_parameter_binary_operation_with_literal(self):
     """Every operator in `_OPERATORS` works with a literal on either side."""
     par_a = Parameter(3)
     literal = 7
     for _op, _ in _OPERATORS.values():
         # parameter on the left, literal on the right
         forward = _op(par_a, literal)
         self.assertEqual(forward.value, _op(par_a.value, literal))
         # literal on the left, parameter on the right
         backward = _op(literal, par_a)
         self.assertEqual(backward.value, _op(literal, par_a.value))
예제 #24
0
    def test_function_parameters(self):
        """Constructor parameters show up as both children and parameters."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        children = func_a_b.get_children()
        self.assertEqual(children, [par_a, par_b])

        parameters = func_a_b.parameters
        self.assertEqual(parameters, [par_a, par_b])
예제 #25
0
    def test_add_function_existing_parameter_with_default(self):
        """Existing nodes win over argument defaults, and a warning is issued."""
        a = self._nexus.add(Parameter(4, name='a'))
        b = self._nexus.add(Parameter(9, name='b'))

        # default of `b` (200) differs from the value of the node 'b' (9)
        def my_func(a=4, b=200):
            return 2 * a + b

        with self.assertWarns(UserWarning) as caught:
            func_node = self._nexus.add_function(func=my_func)

        # the function is wired to the nodes that were already present
        parameter_values = [p.value for p in func_node.parameters]
        self.assertEqual(parameter_values, [a.value, b.value])

        message = caught.warning.args[0]
        self.assertIn('Ignoring default value', message)
        self.assertIn('conflicting value', message)
예제 #26
0
    def test_function_auto_parameters_add_one_at_a_time(self):
        """Parameters added one by one keep their order and feed the value."""
        def func(a, b):
            return a * 10 + b

        # no parameters passed at construction time
        func_a_b = Function(func)

        par_a = Parameter(3)
        par_b = Parameter(7)

        for parameter in (par_a, par_b):
            func_a_b.add_parameter(parameter)

        attached = func_a_b.parameters
        self.assertEqual(attached, [par_a, par_b])
        # 3 * 10 + 7
        self.assertEqual(func_a_b.value, 37)
예제 #27
0
    def test_add_get_explicit_name(self):
        """A node added with an explicit name is returned under that name."""
        node = Parameter('my_value', name='par')
        self._nexus.add(node)

        registered = self._nexus.get('par')
        self.assertIs(registered, node)
예제 #28
0
    def test_add_get(self):
        """A node created without a name is found under its `name` attribute."""
        node = Parameter('my_value')
        self._nexus.add(node)

        registered = self._nexus.get(node.name)
        self.assertIs(registered, node)
예제 #29
0
 def test_get_value_dict(self):
     """`get_value_dict` returns all or selected values and validates options."""
     self._nexus.add(Parameter(1, name="a"))
     self._nexus.add(Parameter(2, name="b"))
     self._nexus.add(Parameter(3, name="c"))
     self._nexus.add_function(lambda a, b, c: a + b * c, func_name="func")

     # without arguments: one entry per node
     all_values = self._nexus.get_value_dict()
     self.assertEqual(len(all_values), 4)
     self.assertEqual(all_values["a"], 1)
     self.assertEqual(all_values["b"], 2)
     self.assertEqual(all_values["c"], 3)
     self.assertEqual(all_values["func"], 7)

     # with `node_names`: only the requested entries
     selected_values = self._nexus.get_value_dict(
         node_names=["b", "a", "func"])
     self.assertEqual(len(selected_values), 3)
     self.assertEqual(selected_values["a"], 1)
     self.assertEqual(selected_values["b"], 2)
     self.assertEqual(selected_values["func"], 7)

     # an unknown `error_behavior` is rejected
     with self.assertRaises(ValueError):
         self._nexus.get_value_dict(error_behavior="bogus")
예제 #30
0
    def test_add_function_existing_parameters(self):
        """A function is wired to existing nodes named like its arguments."""
        a = self._nexus.add(Parameter(4, name='a'))
        b = self._nexus.add(Parameter(5, name='b'))

        def my_func(a, b):
            return 2 * a + b

        func_node = self._nexus.add_function(func=my_func)

        # the node takes its name from the Python function
        self.assertEqual(func_node.name, 'my_func')

        # the parameters carry the values of the nodes added beforehand
        parameter_values = [p.value for p in func_node.parameters]
        self.assertEqual(parameter_values, [a.value, b.value])

        # the node value equals a direct call with those values
        self.assertEqual(func_node.value, my_func(a.value, b.value))