예제 #1
0
def _test_nonzero_derivatives_of_noncompounds_produce_the_right_types_and_shapes(self, collection):
    """Check that AD of ``derivative(t, var)`` has the expected shape for every noncompound ``t``.

    For every expression ``t`` in ``collection.noncompounds`` and every ``var`` in
    ``collection.shared_objects`` (``u``, ``v``, ``w``), ``derivative(t, var)`` is passed
    through ``ad_algorithm`` and compared with ``0*t``:

    * if ``var`` occurs in ``t``, the result must have the same total shape as
      ``0*t`` but must not be equal to it;
    * otherwise the result must equal ``0*t``.

    Conditionals whose condition (first operand) contains ``var`` are skipped.
    """
    debug = 0

    u = collection.shared_objects.u
    v = collection.shared_objects.v
    w = collection.shared_objects.w

    for t in collection.noncompounds:
        # Traverse t (and a Conditional's condition) once per t rather than
        # once per var.
        t_nodes = tuple(unique_post_traversal(t))
        if isinstance(t, Conditional):
            condition_nodes = tuple(unique_post_traversal(t.ufl_operands[0]))
        else:
            condition_nodes = ()

        for var in (u, v, w):
            # Include d/dx [z ? y: x] but not d/dx [x ? f: z]
            if var in condition_nodes:
                if debug:
                    print("Depends on %s :: %s" % (str(var), str(t)))
                continue

            # Debug output passes separate arguments to print() so the values
            # are emitted space-separated rather than as a single tuple repr.
            if debug:
                print('\n', '...:   ', t.ufl_shape, var.ufl_shape, '\n')
            before = derivative(t, var)
            if debug:
                print('\n', 'before:   ', str(before), '\n')
            after = ad_algorithm(before)
            if debug:
                print('\n', 'after:    ', str(after), '\n')
            expected_shape = 0*t
            if debug:
                print('\n', 'expected_shape: ', str(expected_shape), '\n')

            if var in t_nodes:
                # var occurs in t: same total shape as 0*t, but must not equal 0*t.
                self.assertEqualTotalShape(after, expected_shape)
                self.assertNotEqual(after, expected_shape)
            else:
                # var does not occur in t: the result must be exactly 0*t.
                assert after == expected_shape
예제 #2
0
def _test_nonzero_diffs_of_noncompounds_produce_the_right_types_and_shapes(
        self, collection):
    """Check that AD of ``diff(t, var)`` has the expected shape for every noncompound ``t``.

    In each expression ``t`` of ``collection.noncompounds``, the shared objects
    ``u``, ``v`` and ``w`` are replaced by ``variable(...)`` wrappers. For each
    wrapper ``var``, ``diff(t, var)`` is passed through ``ad_algorithm`` and
    compared with ``0 * outer(t, var)``, which is used for its shape, not its value:

    * if ``var`` occurs in ``t``, the result must have the same total shape as
      ``0 * outer(t, var)`` but must not be equal to it;
    * otherwise the result must equal ``0 * outer(t, var)``.

    Conditionals whose condition (first operand) contains ``var`` are skipped.
    """
    debug = 0
    u = collection.shared_objects.u
    v = collection.shared_objects.v
    w = collection.shared_objects.w

    vu = variable(u)
    vv = variable(v)
    vw = variable(w)

    for t in collection.noncompounds:
        t = replace(t, {u: vu, v: vv, w: vw})

        # Traverse t (and a Conditional's condition) once per t rather than
        # once per var.
        t_nodes = tuple(unique_post_traversal(t))
        if isinstance(t, Conditional):
            condition_nodes = tuple(unique_post_traversal(t.ufl_operands[0]))
        else:
            condition_nodes = ()

        for var in (vu, vv, vw):
            # Include d/dx [z ? y: x] but not d/dx [x ? f: z]
            if var in condition_nodes:
                if debug:
                    print("Depends on %s :: %s" % (str(var), str(t)))
                continue

            # Debug output passes separate arguments to print() so the values
            # are emitted space-separated rather than as a single tuple repr.
            before = diff(t, var)
            if debug:
                print('\n', 'before:   ', str(before), '\n')
            after = ad_algorithm(before)
            if debug:
                print('\n', 'after:    ', str(after), '\n')
            # expected shape, not necessarily value
            expected_shape = 0 * outer(t, var)
            if debug:
                print('\n', 'expected_shape: ', str(expected_shape), '\n')

            if var in t_nodes:
                # var occurs in t: same total shape as the reference, but not equal to it.
                self.assertEqualTotalShape(after, expected_shape)
                self.assertNotEqual(after, expected_shape)
            else:
                # var does not occur in t: the result must be exactly the zero reference.
                assert after == expected_shape
예제 #3
0
def _extract_variables(a):
    """Return the distinct Variable objects found in *a*.

    *a* may be a Form, Integral or Expr; every expression yielded by
    ``iter_expressions(a)`` is walked with ``unique_post_traversal``.
    Variables are identified by their label (second operand), and only the
    first Variable seen for each label is kept. Because the traversal visits
    operands before the nodes that use them, the resulting list obeys
    dependency order.
    """
    seen_labels = set()
    found = []
    for expression in iter_expressions(a):
        for node in unique_post_traversal(expression):
            if not isinstance(node, Variable):
                continue
            _, var_label = node.ufl_operands
            if var_label in seen_labels:
                continue
            seen_labels.add(var_label)
            found.append(node)
    return found
예제 #4
0
def _test_nonzero_diffs_of_noncompounds_produce_the_right_types_and_shapes(self, collection):
    """Check that AD of ``diff(t, var)`` has the expected shape for every noncompound ``t``.

    In each expression ``t`` of ``collection.noncompounds``, the shared objects
    ``u``, ``v`` and ``w`` are replaced by ``variable(...)`` wrappers. For each
    wrapper ``var``, ``diff(t, var)`` is passed through ``ad_algorithm`` and
    compared with ``0*outer(t, var)``, which is used for its shape, not its value:

    * if ``var`` occurs in ``t``, the result must have the same total shape as
      ``0*outer(t, var)`` but must not be equal to it;
    * otherwise the result must equal ``0*outer(t, var)``.

    Conditionals whose condition (first operand) contains ``var`` are skipped.
    """
    debug = 0
    u = collection.shared_objects.u
    v = collection.shared_objects.v
    w = collection.shared_objects.w

    vu = variable(u)
    vv = variable(v)
    vw = variable(w)

    for t in collection.noncompounds:
        t = replace(t, {u: vu, v: vv, w: vw})

        # Traverse t (and a Conditional's condition) once per t rather than
        # once per var.
        t_nodes = tuple(unique_post_traversal(t))
        if isinstance(t, Conditional):
            condition_nodes = tuple(unique_post_traversal(t.ufl_operands[0]))
        else:
            condition_nodes = ()

        for var in (vu, vv, vw):
            # Include d/dx [z ? y: x] but not d/dx [x ? f: z]
            if var in condition_nodes:
                if debug:
                    print("Depends on %s :: %s" % (str(var), str(t)))
                continue

            # Debug output passes separate arguments to print() so the values
            # are emitted space-separated rather than as a single tuple repr.
            before = diff(t, var)
            if debug:
                print('\n', 'before:   ', str(before), '\n')
            after = ad_algorithm(before)
            if debug:
                print('\n', 'after:    ', str(after), '\n')
            expected_shape = 0*outer(t, var)  # expected shape, not necessarily value
            if debug:
                print('\n', 'expected_shape: ', str(expected_shape), '\n')

            if var in t_nodes:
                # var occurs in t: same total shape as the reference, but not equal to it.
                self.assertEqualTotalShape(after, expected_shape)
                self.assertNotEqual(after, expected_shape)
            else:
                # var does not occur in t: the result must be exactly the zero reference.
                assert after == expected_shape
예제 #5
0
def _test_nonzero_derivatives_of_noncompounds_produce_the_right_types_and_shapes(
        self, collection):
    """Check that AD of ``derivative(t, var)`` has the expected shape for every noncompound ``t``.

    For every expression ``t`` in ``collection.noncompounds`` and every ``var`` in
    ``collection.shared_objects`` (``u``, ``v``, ``w``), ``derivative(t, var)`` is
    passed through ``ad_algorithm`` and compared with ``0 * t``:

    * if ``var`` occurs in ``t``, the result must have the same total shape as
      ``0 * t`` but must not be equal to it;
    * otherwise the result must equal ``0 * t``.

    Conditionals whose condition (first operand) contains ``var`` are skipped.
    """
    debug = 0

    u = collection.shared_objects.u
    v = collection.shared_objects.v
    w = collection.shared_objects.w

    for t in collection.noncompounds:
        # Traverse t (and a Conditional's condition) once per t rather than
        # once per var.
        t_nodes = tuple(unique_post_traversal(t))
        if isinstance(t, Conditional):
            condition_nodes = tuple(unique_post_traversal(t.ufl_operands[0]))
        else:
            condition_nodes = ()

        for var in (u, v, w):
            # Include d/dx [z ? y: x] but not d/dx [x ? f: z]
            if var in condition_nodes:
                if debug:
                    print("Depends on %s :: %s" % (str(var), str(t)))
                continue

            # Debug output passes separate arguments to print() so the values
            # are emitted space-separated rather than as a single tuple repr.
            if debug:
                print('\n', '...:   ', t.ufl_shape, var.ufl_shape, '\n')
            before = derivative(t, var)
            if debug:
                print('\n', 'before:   ', str(before), '\n')
            after = ad_algorithm(before)
            if debug:
                print('\n', 'after:    ', str(after), '\n')
            expected_shape = 0 * t
            if debug:
                print('\n', 'expected_shape: ', str(expected_shape), '\n')

            if var in t_nodes:
                # var occurs in t: same total shape as 0 * t, but must not equal 0 * t.
                self.assertEqualTotalShape(after, expected_shape)
                self.assertNotEqual(after, expected_shape)
            else:
                # var does not occur in t: the result must be exactly 0 * t.
                assert after == expected_shape
예제 #6
0
def test_pre_and_post_traversal():
    """Pin the node order produced by the four traversal helpers on ``f*v + g*v``."""
    cg1 = FiniteElement("CG", "triangle", 1)
    v = TestFunction(cg1)
    f = Coefficient(cg1)
    g = Coefficient(cg1)
    p1 = f * v
    p2 = g * v
    s = p1 + p2

    # The traversal helpers only promise parent-before-child (pre) or
    # child-before-parent (post); the exact sequences below pin the current
    # ordering, not a documented guarantee.
    cases = (
        (pre_traversal, [s, p2, g, v, p1, f, v]),
        (post_traversal, [g, v, p2, f, v, p1, s]),
        (unique_pre_traversal, [s, p2, g, v, p1, f]),
        (unique_post_traversal, [v, f, p1, g, p2, s]),
    )
    for traverse, expected_order in cases:
        assert list(traverse(s)) == expected_order
예제 #7
0
def compute_expression_hashdata(expression, terminal_hashdata) -> bytes:
    """Return a SHA-512 digest describing the structure of *expression*.

    Each distinct node is hashed exactly once, children before parents:
    a terminal node hashes ``[terminal_hashdata[node]]``, and any other node
    hashes ``[node._ufl_typecode_, <digest of each operand>...]``. For example,
    ``a + b*c`` becomes ``hash([+, hash(a), hash([*, hash(b), hash(c)])])``.
    The digest of *expression* itself is returned.
    """
    digests = {}
    for node in unique_post_traversal(expression):
        if node._ufl_is_terminal_:
            payload = [terminal_hashdata[node]]
        else:
            # Operands were visited earlier by the post-order walk, so their
            # digests are already available.
            payload = [node._ufl_typecode_] + [digests[child] for child in node.ufl_operands]
        digests[node] = hashlib.sha512(str(payload).encode("utf-8")).digest()
    return digests[expression]
예제 #8
0
def test_pre_and_post_traversal():
    """Check the concrete node sequences the traversal helpers yield for ``f*v + g*v``."""
    cg1 = FiniteElement("CG", "triangle", 1)
    v = TestFunction(cg1)
    f = Coefficient(cg1)
    g = Coefficient(cg1)
    p1, p2 = f * v, g * v
    s = p1 + p2

    # NB: only parent-before-child (pre) / child-before-parent (post) is
    # guaranteed; these exact orderings reflect the current implementation.
    pre_order = list(pre_traversal(s))
    assert pre_order == [s, p2, g, v, p1, f, v]

    post_order = list(post_traversal(s))
    assert post_order == [g, v, p2, f, v, p1, s]

    unique_pre_order = list(unique_pre_traversal(s))
    assert unique_pre_order == [s, p2, g, v, p1, f]

    unique_post_order = list(unique_post_traversal(s))
    assert unique_post_order == [v, f, p1, g, p2, s]
예제 #9
0
파일: map_dag.py 프로젝트: mrambausek/ufl
 def traversal(expression):
     """Return ``unique_post_traversal`` of *expression* using the enclosing ``visited``.

     ``visited`` is a closure variable from the enclosing scope, which is not
     shown here. NOTE(review): presumably it is shared across calls so that
     nodes already seen are skipped; confirm against the enclosing code.
     """
     return unique_post_traversal(expression, visited)
예제 #10
0
파일: map_dag.py 프로젝트: FEniCS/ufl
 def traversal(expression):
     """Return ``unique_post_traversal`` of *expression* using the enclosing ``visited``.

     ``visited`` is a closure variable from the enclosing scope, which is not
     shown here. NOTE(review): presumably it is shared across calls so that
     nodes already seen are skipped; confirm against the enclosing code.
     """
     return unique_post_traversal(expression, visited)