Ejemplo n.º 1
0
    def test_function_replace(self):
        """Swapping a Function for a fresh one transfers its graph links.

        The replacement is constructed without parameters; after ``replace``
        it must report the same children, parents and parameters as the
        original node and evaluate using those inherited parameters.
        """
        middle = Parameter(3, name='par')
        candidates = (
            Parameter('bla', name='first'),
            middle,
            Parameter('blup', name='third'),
        )
        selector = Function(lambda *args: args[1], name='selector',
                            parameters=candidates)
        # Kept only for its side effect: the product node becomes a parent of
        # `selector`, so there is a parent link for replace() to move over.
        context = Parameter(2, name='pre_factor') * selector

        # snapshot the old node's links before replacing it
        old_children = selector.get_children()
        old_parameters = selector.parameters
        old_parents = selector.get_parents()

        replacement = Function(lambda *args: args[2], name='selector_new')
        selector.replace(replacement, other_children=False)

        # the replacement must now carry exactly the same links
        new_children = replacement.get_children()
        new_parents = replacement.get_parents()
        new_parameters = replacement.parameters

        self.assertEqual(old_children, new_children)
        self.assertEqual(old_parents, new_parents)
        self.assertEqual(old_parameters, new_parameters)

        # the new callable picks the third inherited parameter
        self.assertEqual(replacement.value, 'blup')
Ejemplo n.º 2
0
    def test_function_value_lambda(self):
        """A lambda-backed Function evaluates its callable on parameter values."""
        first = Parameter(3)
        second = Parameter(7)

        combined = Function(
            lambda a, b: a * 10 + b,
            name='_lambda',
            parameters=(first, second),
        )

        actual = combined.value
        expected = combined.func(first.value, second.value)
        self.assertEqual(actual, expected)
Ejemplo n.º 3
0
    def setUp(self):
        """Build fresh fixtures before every test."""
        # two operands and a Function node combining them via sum_function
        self.par_a, self.par_b = Parameter(3), Parameter(7)
        self.func_sum_a_b = Function(
            TestNodes.sum_function,
            parameters=(self.par_a, self.par_b),
        )

        self.empty_par = Empty()

        # integer accumulators that tests may update (e.g. from callbacks)
        self.counter = self.sum = 0

        self.array_content = [0.1, 2.3, 4.5, 6.7, 8.9]
Ejemplo n.º 4
0
    def test_function_auto_parameters(self):
        """Parameters assigned after construction are stored and used for evaluation."""
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)

        par_a = Parameter(3)
        par_b = Parameter(7)

        func_a_b.parameters = [par_a, par_b]

        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )
        # the assigned parameters must also drive evaluation: 3 * 10 + 7
        self.assertEqual(func_a_b.value, 37)
Ejemplo n.º 5
0
    def test_parameter_replace(self):
        """Replacing a Parameter moves its graph links to the new Parameter."""
        par = Parameter(3, name='par')
        candidates = (
            Parameter('bla', name='first'),
            par,
            Parameter('blup', name='third'),
        )
        selector = Function(lambda *args: args[1], name='selector',
                            parameters=candidates)
        # `selector` returns its middle parameter, so context == 2 * middle
        context = Parameter(2, name='pre_factor') * selector

        children_then, parents_then = par.get_children(), par.get_parents()

        par_new = Parameter(7, name='par_new')
        par.replace(par_new)

        # a change on the replacement must propagate up through `selector`
        par_new.value = 6
        self.assertEqual(context.value, 12)

        children_now, parents_now = par_new.get_children(), par_new.get_parents()
        self.assertEqual(children_then, children_now)
        self.assertEqual(parents_then, parents_now)
Ejemplo n.º 6
0
    def test_function_auto_parameters_add_one_at_a_time(self):
        """add_parameter appends in call order and the parameters drive evaluation."""
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)

        operands = [Parameter(3), Parameter(7)]
        for operand in operands:
            func_a_b.add_parameter(operand)

        self.assertEqual(func_a_b.parameters, operands)
        # 3 * 10 + 7
        self.assertEqual(func_a_b.value, 37)
Ejemplo n.º 7
0
    def test_function_parameters(self):
        """Constructor parameters appear both as children and as parameters."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        expected = [par_a, par_b]
        self.assertEqual(func_a_b.get_children(), expected)
        self.assertEqual(func_a_b.parameters, expected)
Ejemplo n.º 8
0
    def setUp(self):
        """Create the shared fixtures afresh for each test."""
        self.par_a, self.par_b = Parameter(3), Parameter(7)

        # Function node combining the two fixture parameters via sum_function
        self.func_sum_a_b = Function(
            TestNodes.sum_function,
            parameters=(self.par_a, self.par_b),
        )

        self.empty_par = Empty()
Ejemplo n.º 9
0
    def test_function_value_update_frozen(self):
        """A frozen Function keeps its pre-freeze value until unfrozen."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        def current():
            # what the Function should report when tracking its parameters
            return func(par_a.value, par_b.value)

        self.assertEqual(func_a_b.value, current())

        # while frozen, parameter updates must not leak through
        func_a_b.freeze()
        par_a.value = -3
        self.assertEqual(func_a_b.value, func(3, 7))

        # once unfrozen, the Function follows its parameters again
        func_a_b.unfreeze()
        self.assertEqual(func_a_b.value, current())
Ejemplo n.º 10
0
 def setUp(self):
     """Build a small node graph whose top-level nodes hang off a RootNode."""
     self.root = RootNode()
     self.empty = Empty(name="empty")
     par_a = self.par_a = Parameter(2, name="par_a")
     par_b = self.par_b = Parameter(3, name="par_b")
     par_c = self.par_c = Parameter(4, name="par_c")
     # derived nodes built from the parameters above
     self.func_1 = par_a + par_b
     self.func_2 = par_a * par_c
     self.func_3 = Function(lambda a, b: a + b)
     self.alias = Alias(par_a, name="alias")
     # par_b additionally becomes a child of par_a
     par_a.set_children([par_b])
     top_level = [self.empty, self.func_1, self.func_2, self.func_3, self.alias]
     self.root.set_children(top_level)
Ejemplo n.º 11
0
    def test_nodes_parametric_constructor(self):
        """Node types that take constructor arguments build and stringify."""
        str(Parameter(None))
        str(Alias(Parameter(None)))

        identity = Function(lambda x: x, name='_lambda')
        str(identity)

        # the Function above serves as the single Fallback alternative
        str(Fallback((identity, )))
        str(Tuple([]))
Ejemplo n.º 12
0
    def test_function_value(self):
        """A Function's value tracks its parameters and cannot be assigned."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        def expected():
            return func(par_a.value, par_b.value)

        self.assertEqual(func_a_b.value, expected())

        # the value follows parameter updates
        par_a.value = -3
        self.assertEqual(func_a_b.value, expected())

        # the value is derived, so direct assignment is rejected
        with self.assertRaises(NodeException):
            func_a_b.value = 5
Ejemplo n.º 13
0
    def test_parameter_replace_literal(self):
        """A Parameter can be replaced by a plain literal value."""
        par = Parameter(3, name='par')
        candidates = (
            Parameter('bla', name='first'),
            par,
            Parameter('blup', name='third'),
        )
        selector = Function(lambda *args: args[1], name='selector',
                            parameters=candidates)
        # `selector` returns its middle parameter, so context == 2 * middle
        context = Parameter(2, name='pre_factor') * selector

        par.replace(7)

        self.assertEqual(context.value, 2 * 7)
Ejemplo n.º 14
0
class TestNodes(unittest.TestCase):
    """Unit tests for the node classes and their parent/child graph.

    Covers NodeBase/RootNode/Empty construction, graph manipulation
    (children, parents, replace), callbacks, and the value semantics of
    Parameter, Alias, Function, Fallback, Tuple and Array.
    """

    # Node types instantiated without arguments in test_nodes_default_constructor.
    DEFAULT_CONSTRUCTIBLE_NODE_TYPES = (NodeBase, Empty, RootNode)
    # Node types that take constructor arguments (not iterated by any test here).
    PARAMETRIC_NODE_TYPES = (Parameter, Alias, Function, Fallback, Tuple)

    @staticmethod
    def sum_function(a, b):
        """Return ``a + b``; the callable wrapped by ``func_sum_a_b``."""
        return a + b

    def setUp(self):
        """Create fresh fixtures for every test."""
        self.par_a = Parameter(3)
        self.par_b = Parameter(7)

        # Function node whose value is par_a + par_b
        self.func_sum_a_b = Function(TestNodes.sum_function,
                                     parameters=(self.par_a, self.par_b))

        self.empty_par = Empty()

        # accumulators updated by the callbacks in test_callback
        self.counter = 0
        self.sum = 0

        # values backing the Array tests
        self.array_content = [0.1, 2.3, 4.5, 6.7, 8.9]

    # -- NodeBase

    def test_nodes_default_constructor(self):
        """Each default-constructible node type can be built and stringified."""
        for _node_type in self.DEFAULT_CONSTRUCTIBLE_NODE_TYPES:
            # NodeBase constructor cannot be called because NodeBase has an abstract method.
            if _node_type is NodeBase:
                continue
            str(_node_type())

    def test_nodes_parametric_constructor(self):
        """Node types that take arguments can be built and stringified."""
        p = Parameter(None)
        str(p)

        a = Alias(Parameter(None))
        str(a)

        f = Function(lambda x: x, name='_lambda')
        str(f)

        fb = Fallback((f, ))
        str(fb)

        t = Tuple([])
        str(t)

    def test_nodes_equality(self):
        """A node equals itself; distinct Parameters with equal values do not."""
        p_1 = Parameter(None)
        p_2 = Parameter(None)
        self.assertEqual(p_1, p_1)
        self.assertNotEqual(p_1, p_2)

    def test_node_get_children_get_parents(self):
        """Binary operations link operands and results in both directions."""
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2

        for _par in par_1, par_2:
            self.assertEqual(_par.get_parents(), [par_lc_1, par_lc_2])
        for _par in par_lc_1, par_lc_2:
            self.assertEqual(_par.get_children(), [par_1, par_2])

    def test_node_iter_children_iter_parents(self):
        """iter_parents/iter_children yield results in creation order, operands left to right."""
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2

        for _par in par_1, par_2:
            self.assertEqual(list(_par.iter_parents()), [par_lc_1, par_lc_2])
        for _par in par_lc_1, par_lc_2:
            self.assertEqual(list(_par.iter_children()), [par_1, par_2])

    def test_node_add_child(self):
        """add_child accepts plain literals; the children expose them via ``.value``."""
        par = Parameter(3, name='par')
        par.add_child(1)
        par.add_child(2)

        self.assertEqual([p.value for p in par.iter_children()], [1, 2])

    def test_node_add_parent(self):
        """Parents cannot be added directly; they appear via the parent's add_child."""
        par = Parameter(3, name='par')
        par_2 = Parameter(2, name='par_2')

        with self.assertRaises(TypeError):
            par.add_parent("notanode")

        # Manually adding parents is not allowed:
        with self.assertRaises(NodeException):
            par.add_parent(par_2)

        par_2.add_child(par)
        self.assertEqual(par.get_parents(), [par_2])

    def test_node_set_children(self):
        """set_children installs the given children in order."""
        par = Parameter(3, name='par')
        child_1 = Parameter(1, name='child_1')
        child_2 = Parameter(2, name='child_2')

        par.set_children([child_1, child_2])

        self.assertEqual(par.get_children(), [child_1, child_2])

    def test_node_remove_child(self):
        """remove_child drops the given child and rejects non-nodes."""
        par = Parameter(3, name='par')
        child_1 = Parameter(1, name='child_1')
        child_2 = Parameter(2, name='child_2')

        par.set_children([child_1, child_2])
        par.remove_child(child_1)

        self.assertEqual(par.get_children(), [child_2])

        with self.assertRaises(TypeError):
            par.remove_child("notanode")

    def test_node_remove_parent(self):
        """Parents cannot be removed directly; the parent's remove_child unlinks them."""
        par = Parameter(3, name='par')
        parent_1 = Parameter(1, name='parent_1')
        parent_2 = Parameter(2, name='parent_2')

        parent_1.add_child(par)
        parent_2.add_child(par)
        with self.assertRaises(NodeException):
            par.remove_parent(parent_1)
        self.assertEqual(par.get_parents(), [parent_1, parent_2])
        parent_1.remove_child(par)
        self.assertEqual(par.get_parents(), [parent_2])

        with self.assertRaises(TypeError):
            par.remove_parent("notanode")

    def test_parameter_replace(self):
        """Replacing a Parameter moves its graph links to the new Parameter."""
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        # `selector` returns its middle parameter, so context == 2 * middle
        context = Parameter(2, name='pre_factor') * selector

        # store parents and children for comparison
        _children_before = par.get_children()
        _parents_before = par.get_parents()

        # replace node
        par_new = Parameter(7, name='par_new')
        par.replace(par_new)

        par_new.value = 6

        self.assertEqual(context.value, 12)

        # retrieve new parents and children for comparison
        _children_after = par_new.get_children()
        _parents_after = par_new.get_parents()

        self.assertEqual(_children_before, _children_after)
        self.assertEqual(_parents_before, _parents_after)

    def test_parameter_replace_literal(self):
        """A Parameter can be replaced by a plain literal value."""
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        # `selector` returns its middle parameter, so context == 2 * middle
        context = Parameter(2, name='pre_factor') * selector

        # replace node
        par.replace(7)

        self.assertEqual(context.value, 14)

    def test_function_replace(self):
        """Replacing a Function hands its parents and parameters to the new node."""
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        # Kept for its side effect: gives `selector` a parent to transfer.
        context = Parameter(2, name='pre_factor') * selector

        # store parents and children for comparison
        _children_before = selector.get_children()
        _parameters_before = selector.parameters
        _parents_before = selector.get_parents()

        # replace node
        selector_new = Function(lambda *args: args[2], name='selector_new')
        selector.replace(selector_new, other_children=False)

        # retrieve new parents and children for comparison
        _children_after = selector_new.get_children()
        _parents_after = selector_new.get_parents()
        _parameters_after = selector_new.parameters

        self.assertEqual(_children_before, _children_after)
        self.assertEqual(_parents_before, _parents_after)
        self.assertEqual(_parameters_before, _parameters_after)

        # the new callable picks the third inherited parameter
        self.assertEqual(selector_new.value, 'blup')

    def test_callback(self):
        """Callbacks fire on each value change, including on dependent nodes.

        Positional ``args`` given at registration are forwarded to the
        callback; a callback's return value is ignored.
        """
        def increment_counter():
            self.counter += 1

        def update_sum(summand):
            self.sum += summand

        self.par_a.register_callback(lambda: 42)
        self.par_a.register_callback(increment_counter)
        # func_sum_a_b depends on par_a, so its callback fires too
        self.func_sum_a_b.register_callback(update_sum, args=[10])
        self.par_a.value = 5
        self.assertEqual(self.counter, 1)
        self.assertEqual(self.sum, 10)
        self.par_a.value = 6
        self.assertEqual(self.counter, 2)
        self.assertEqual(self.sum, 20)

    def test_replace_child(self):
        """replace_child swaps an existing child, rejects unknown ones, wraps literals."""
        par_a = Parameter(4)
        par_b = Parameter(5)
        par_c = Parameter(6)
        par_a.add_child(par_b)
        self.assertEqual(par_a.get_children(), [par_b])
        # par_c is not a child yet
        with self.assertRaises(NodeException):
            par_a.replace_child(current_child=par_c, new_child=par_b)
        par_a.replace_child(current_child=par_b, new_child=par_c)
        # par_b is no longer a child
        with self.assertRaises(NodeException):
            par_a.replace_child(current_child=par_b, new_child=par_c)
        self.assertEqual(par_a.get_children(), [par_c])
        # a literal replacement is wrapped in a new Parameter
        par_a.replace_child(current_child=par_c, new_child=7)
        self.assertNotEqual(par_a.get_children(), [par_b])
        self.assertNotEqual(par_a.get_children(), [par_c])
        self.assertIs(type(par_a.get_children()[0]), Parameter)

    # -- RootNode

    def test_root_node_replace(self):
        """A RootNode cannot be replaced."""
        r = RootNode()
        with self.assertRaises(TypeError):
            r.replace(7)

    def test_root_node_parents(self):
        """A RootNode has no parents and rejects adding one."""
        r = RootNode()
        with self.assertRaises(TypeError):
            r.add_parent(2)
        self.assertEqual(list(r.iter_parents()), [])
        self.assertEqual(r.get_parents(), [])

    # -- Empty

    def test_empty_value(self):
        """Reading or writing the value of an Empty node raises NodeException."""
        e = Empty()
        with self.assertRaises(NodeException):
            _ = e.value
        with self.assertRaises(NodeException):
            e.value = 33

    # -- Parameter

    def test_parameter_get_value(self):
        """A Parameter reports the value it was constructed with."""
        par = Parameter(3)
        self.assertEqual(par.value, 3)

    def test_parameter_set_value(self):
        """Assigning ``value`` updates a Parameter."""
        par = Parameter(3)
        par.value = 7
        self.assertEqual(par.value, 7)

    def test_parameter_name(self):
        """A Parameter exposes the name it was given."""
        par = Parameter(3, name='bla')
        self.assertEqual(par.name, 'bla')

    def test_parameter_invalid_name_raise(self):
        """A name starting with a digit is rejected."""
        with self.assertRaises(NodeException):
            Parameter(3, name='2bla')

    def test_parameter_reserved_name_raise(self):
        """Every reserved parameter name is rejected."""
        for _rn in NodeBase.RESERVED_PARAMETER_NAMES:
            with self.assertRaises(NodeException):
                Parameter(3, name=_rn)

    def test_tuple_iter_raise(self):
        """A (scalar) Parameter is not iterable.

        NOTE(review): despite the method name this exercises Parameter, not Tuple.
        """
        par = Parameter(3)
        with self.assertRaises(TypeError):
            iter(par)

    def test_parameter_binary_operation(self):
        """Every supported binary operator on two Parameters matches the plain operator."""
        par_a = Parameter(3)
        par_b = Parameter(7)
        for _op, _ in _OPERATORS.values():
            par_expr = _op(par_a, par_b)
            self.assertEqual(par_expr.value, _op(par_a.value, par_b.value))

    def test_parameter_binary_operation_with_literal(self):
        """Binary operators also work with a literal on either side."""
        par_a = Parameter(3)
        for _op, _ in _OPERATORS.values():
            par_expr = _op(par_a, 7)
            self.assertEqual(par_expr.value, _op(par_a.value, 7))
            # reversed order
            par_expr = _op(7, par_a)
            self.assertEqual(par_expr.value, _op(7, par_a.value))

    def test_parameter_unary_operation(self):
        """Unary +, - and ~ on a Parameter match the plain operators."""
        par = Parameter(3)

        self.assertEqual((+par).value, par.value)
        self.assertEqual((-par).value, -(par.value))
        self.assertEqual((~par).value, ~(par.value))

    # -- Alias

    def test_alias_value(self):
        """An Alias reports its target's value and cannot be assigned."""
        par = Parameter(3)
        alias = Alias(par, name='alias')
        self.assertEqual(alias.value, par.value)
        with self.assertRaises(NodeException):
            alias.value = 5

    def test_multiple_alias_value(self):
        """Nested aliases resolve to the innermost target's value."""
        par = Parameter(3)
        alias = Alias(Alias(Alias(par)))
        self.assertEqual(alias.value, par.value)

    def test_alias_ref(self):
        """``Alias.ref`` is the aliased node."""
        par = Parameter(3)
        alias = Alias(par, name='alias')
        self.assertEqual(alias.ref, par)

    # -- Function

    def test_function_value(self):
        """A Function's value tracks its parameters and cannot be assigned."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        # test value update
        par_a.value = -3

        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        with self.assertRaises(NodeException):
            func_a_b.value = 5

    def test_function_value_update_frozen(self):
        """A frozen Function keeps its pre-freeze value until unfrozen."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        # test value update
        func_a_b.freeze()
        par_a.value = -3

        self.assertEqual(
            func_a_b.value,
            func(3, 7),
        )

        func_a_b.unfreeze()
        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

    def test_function_value_lambda(self):
        """A lambda-backed Function evaluates its callable on parameter values."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        func_a_b = Function(lambda a, b: a * 10 + b,
                            name='_lambda',
                            parameters=(par_a, par_b))

        self.assertEqual(
            func_a_b.value,
            func_a_b.func(par_a.value, par_b.value),
        )

    def test_function_parameters(self):
        """Constructor parameters appear both as children and as parameters."""
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))

        self.assertEqual(
            func_a_b.get_children(),
            [par_a, par_b],
        )
        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )

    def test_function_auto_parameters(self):
        """Parameters assigned after construction are stored and used for evaluation."""
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)

        par_a = Parameter(3)
        par_b = Parameter(7)

        func_a_b.parameters = [par_a, par_b]

        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )
        self.assertEqual(func_a_b.value, 37)

    def test_function_auto_parameters_add_one_at_a_time(self):
        """add_parameter appends in call order and the parameters drive evaluation."""
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)

        par_a = Parameter(3)
        par_b = Parameter(7)

        func_a_b.add_parameter(par_a)
        func_a_b.add_parameter(par_b)

        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )
        self.assertEqual(func_a_b.value, 37)

    # -- Fallback

    def test_fallback(self):
        """Fallback uses the next alternative only for the configured exception type."""
        par_a = Parameter(14)
        par_b = Parameter(2)

        div_a_b = Fallback((par_a / par_b, 'DIV0'),
                           exception_type=ZeroDivisionError)

        # no exception yet
        self.assertEqual(
            div_a_b.value,
            par_a.value / par_b.value,
        )

        # make denominator zero
        par_b.value = 0

        # exception matches type -> fallback
        self.assertEqual(
            div_a_b.value,
            'DIV0',
        )

        # exception does not match type -> raise
        with self.assertRaises(ZeroDivisionError):
            _ = Fallback((par_a / par_b, 'DIV0'), exception_type=TypeError).value

    def test_fallback_no_good_alternative(self):
        """Fallback raises FallbackError when every alternative raises."""
        par_a = Parameter(14)

        div_a_b = Fallback((par_a / 0, 2 * par_a / 0),
                           exception_type=ZeroDivisionError)

        # every alternative divides by zero -> no usable value
        with self.assertRaises(FallbackError):
            _ = div_a_b.value

    # -- Tuple

    def test_tuple_value(self):
        """A Tuple's value is a plain ``tuple`` of its element values."""
        tuple_a = Tuple((14, 2))
        self.assertIs(type(tuple_a.value), tuple)
        self.assertEqual(tuple_a.value, (14, 2))

    def test_tuple_len(self):
        """len() of a Tuple is its number of elements."""
        tuple_a = Tuple((14, 2))
        self.assertEqual(len(tuple_a), 2)

    def test_tuple_getitem(self):
        """Indexing a Tuple returns nodes whose values are the elements."""
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.value, (14, 2))
        self.assertEqual(tuple_a[0].value, 14)
        self.assertEqual(tuple_a[1].value, 2)

    def test_tuple_setitem(self):
        """Tuple items can be set from nodes or literals."""
        tuple_a = Tuple((14, 2))
        tuple_a[0] = Parameter(15)
        tuple_a[1] = 1
        self.assertEqual(tuple_a.value, (15, 1))
        self.assertEqual(tuple_a[0].value, 15)
        self.assertEqual(tuple_a[1].value, 1)

    def test_tuple_nodes(self):
        """Assigning ``nodes`` replaces a Tuple's elements."""
        tuple_a = Tuple((14, 2))
        tuple_a.nodes = (9, 4)
        self.assertEqual(tuple_a.value, (9, 4))

    def test_tuple_iter(self):
        """Iterating a Tuple yields its ``nodes`` in order."""
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.nodes, list(tuple_a))

    def test_tuple_iter_values(self):
        """iter_values yields the same values as ``Tuple.value``."""
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.value, tuple(tuple_a.iter_values()))

    # -- Array

    def test_array_value(self):
        """An Array's value is a numpy array of its element values."""
        array_a = Array(self.array_content)
        self.assertIs(type(array_a.value), np.ndarray)
        self.assertTrue(np.all(array_a.value == self.array_content))

    def test_array_getitem(self):
        """Indexing an Array returns nodes whose values are the elements."""
        array_a = Array(self.array_content)
        for i, expected_value_i in enumerate(self.array_content):
            self.assertEqual(array_a[i].value, expected_value_i)

    def test_array_setitem(self):
        """Assigning a literal to an Array item wraps it in a Parameter."""
        array_a = Array(self.array_content)
        array_a[3] = 0
        self.assertIs(type(array_a[3]), Parameter)
        self.assertTrue(np.all(array_a.value == [0.1, 2.3, 4.5, 0.0, 8.9]))

    def test_array_iter_values(self):
        """iter_values yields the same values as ``Array.value``."""
        array_a = Array(self.array_content)
        self.assertTrue(
            np.all(array_a.value == np.array(list(array_a.iter_values()))))
Ejemplo n.º 15
0
 def test_nodes_parametric_constructor(self):
     """Constructing each argument-taking node type raises no error."""
     Parameter(None)
     Alias(Parameter(None))
     identity = Function(lambda x: x, name='_lambda')
     # the Function above is the single Fallback alternative
     Fallback((identity, ))
     Tuple([])